#!/usr/bin/env python3
"""Count time-saving shortcuts ("cheats") in a grid maze.

Usage: <script> INPUT_FILE PART

INPUT_FILE is a text grid using ``#`` for walls, ``S`` for the start and
``E`` for the end.  PART is 1 or 2 and selects the maximum shortcut length.
"""

import collections
import colorsys
import sys

import rich.console

# Rich console used to render the coloured map.
console = rich.console.Console()

if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} INPUT_FILE PART")

input_file = sys.argv[1]

# Read the grid, dropping trailing whitespace (including the newline) per row.
with open(input_file) as fd:
    lines = [line.rstrip() for line in fd]

if not lines:
    raise ValueError(f"{input_file!r} is empty, expected a grid")

# Grid dimensions; the width is taken from the first row.
height = len(lines)
width = len(lines[0])

# Puzzle part selected on the command line.
part = int(sys.argv[2])
if part not in (1, 2):
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    raise ValueError(f"part must be 1 or 2, got {part}")
# Adjust parameters for part/file

# Minimum saving a cheat must achieve to be counted.  It depends on the input
# file name and, for any other file, on the selected part.
if input_file.startswith("input"):
    minic = 100
elif input_file.startswith("reddit_part3"):
    minic = 30
elif part == 1:
    minic = 1
else:
    minic = 50

# Maximum Manhattan length of a cheat.
skips = 2 if part == 1 else 20

# Expected histogram (saving -> number of cheats) for the "demo" input, used
# to cross-check the computed result.  It stays empty for every other input.
demo: dict[int, int] = {}
if input_file == "demo":
    if part == 1:
        demo = {
            2: 14, 4: 14, 6: 2, 8: 4, 10: 2,
            12: 3, 20: 1, 36: 1, 38: 1, 40: 1, 64: 1,
        }
    elif part == 2:
        demo = {
            50: 32, 52: 31, 54: 29, 56: 39, 58: 25, 60: 23,
            62: 20, 64: 19, 66: 12, 68: 14, 70: 12, 72: 22, 74: 4, 76: 3,
        }

# Same data as ``demo`` but as a Counter, so it can be diffed against the
# computed histogram.
canon: collections.Counter[int] = collections.Counter(demo)
# A grid coordinate or offset, as (row, column).
vec = tuple[int, int]

# The four orthogonal neighbour offsets, as (row delta, column delta).
directions = [
    (-1, 0),  # ^ North
    (0, 1),  # > East
    (1, 0),  # v South
    (0, -1),  # < West
]
# Find start position
# ``start`` is the (row, column) of the first "S" found scanning top to bottom.
for i, line in enumerate(lines):
    j = line.find("S")
    if j != -1:
        start = i, j
        break
else:
    # Without this, a grid lacking "S" would only fail later with a NameError.
    raise ValueError("no start cell 'S' found in the grid")
# Visit normally
# Breadth-first flood fill from ``start`` over every non-wall cell.
# ``visited[i][j]`` ends up holding the number of steps needed to reach the
# cell from the start, or None if it is a wall or unreachable.  ``normal`` is
# the step count at which "E" is first reached.  The fill does not stop at
# "E"; it continues until no new cell can be reached.
normal = None
visited: list[list[int | None]] = [[None] * width for _ in range(height)]
visited[start[0]][start[1]] = 0
stack: set[vec] = {start}  # current BFS frontier (a set, despite the name)
s = 0  # distance assigned to the cells discovered in the current round
while stack:
    s += 1
    nstack: set[vec] = set()  # frontier for the next round
    for i, j in stack:
        for di, dj in directions:
            ii, jj = i + di, j + dj

            # Guard against grids with no surrounding wall: a negative index
            # would silently wrap to the opposite edge.
            if not (0 <= ii < height and 0 <= jj < width):
                continue

            cchar = lines[ii][jj]
            if cchar == "#":
                continue

            # Distances only grow from round to round, so a cell that already
            # has a value was reached by a path that is at least as short.
            if visited[ii][jj] is not None:
                continue
            visited[ii][jj] = s

            if cchar == "E" and normal is None:
                normal = s
            nstack.add((ii, jj))
    stack = nstack

if normal is None:
    raise ValueError("the end cell 'E' is not reachable from 'S'")
# Print
# Render the grid with rich markup: walls as solid blocks, visited "." cells
# with a background hue proportional to their distance (relative to
# ``normal``) and the last digit of that distance, everything else in green.
for i in range(height):
    cells = []
    for j in range(width):
        char = lines[i][j]
        vis = visited[i][j]
        # NOTE(review): (19, 1) and (15, 1) are hard-coded coordinates,
        # presumably leftover debugging markers for one specific input;
        # confirm before removing.
        if (i, j) == (19, 1):
            cell = "[bold red on black]@"
        elif (i, j) == (15, 1):
            cell = "[bold red on black]a"
        elif vis is not None and char == ".":
            hue = vis / normal
            r, g, b = (round(c * 255) for c in colorsys.hsv_to_rgb(hue, 1.0, 1.0))
            cell = f"[on rgb({r},{g},{b})]{vis % 10}"
        elif char == "#":
            cell = "[white]█"
        else:
            cell = f"[bold green on black]{char}"
        cells.append(cell)
    console.print("".join(cells))
print()
# Find cheats
# For every visited cell that comes before the end (distance < ``normal``),
# look at every visited cell within Manhattan distance ``skips`` and compute
# how many steps jumping straight there would save.  ``saves`` maps each
# saving of at least ``minic`` to the number of (from, to) pairs achieving it.
saves: collections.Counter[int] = collections.Counter()
for i in range(1, height - 1):
    if height > 100:
        # Progress output for large grids.
        print(103, i, "/", height - 2)
    for j in range(1, width - 1):
        # Walls are never given a distance by the flood fill, so the None
        # test also rejects them.
        ovis = visited[i][j]
        if ovis is None or ovis >= normal:
            continue
        orem = normal - ovis  # steps left to the end without cheating

        # Scan the diamond of cells within ``skips`` steps, clipped to the grid.
        for ii in range(max(1, i - skips), min(height - 1, i + skips) + 1):
            di = abs(ii - i)
            rem = skips - di  # column budget left on this row
            for jj in range(max(1, j - rem), min(width - 1, j + rem) + 1):
                nvis = visited[ii][jj]
                if nvis is None:
                    continue
                manh = di + abs(jj - j)
                # Works if there's space after the E, but catches unrelated paths
                nrem = abs(normal - nvis) + manh
                save = orem - nrem
                if save >= minic:
                    saves[save] += 1
# Report: the normal path length, the histogram of savings, and finally the
# total number of cheats counted.
print(f"{normal=}")
print(f"{dict(sorted(saves.items()))=}")
if demo:
    # Compare against the expected histogram for the demo input.
    # NOTE(review): the source's indentation was lost; the extent of this
    # ``if`` body is a reconstruction and should be confirmed.
    print(f"{dict(sorted(canon.items()))=}")
    diff = canon.copy()
    diff.subtract(saves)  # expected minus computed, per saving
    print(f"{dict(sorted(diff.items()))=}")
    print(f"{(saves == canon)=}")
    print(f"{saves.total()=}")
    print(f"{canon.total()=}")
    # Total absolute discrepancy across all savings.
    difft = sum(abs(v) for v in diff.values())
    print(f"{difft=}")
print(saves.total())
# Submitted answers:
# 1119834 too high
# 982425 correct!