advent-of-code/2024/20/two_correct.py
2024-12-25 12:59:49 +01:00

213 lines
5.2 KiB
Python

#!/usr/bin/env python3
import collections
import colorsys
import sys
import rich.console
import rich.text
import rich.progress
console = rich.console.Console()
input_file = sys.argv[1]
# Read the grid row by row, dropping trailing whitespace. Blank lines (e.g. an
# extra empty line at the end of the file) are skipped: kept, they would become
# zero-width rows that later grid indexing cannot handle.
with open(input_file) as fd:
    lines = [line.rstrip() for line in fd if line.strip()]
if not lines:
    raise SystemExit(f"{input_file}: no grid found")
height = len(lines)
width = len(lines[0])
# Adjust parameters for part/file:
#   minic - a cheat is only counted when it saves at least this many steps;
#   skips - maximum Manhattan length of a cheat (2 for part 1, 20 for part 2).
part = int(sys.argv[2])
if part not in (1, 2):
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise SystemExit(f"part must be 1 or 2, got {part}")
if input_file.startswith("input"):
    minic = 100
elif input_file.startswith("reddit_part3"):
    minic = 30
elif part == 1:
    minic = 1
else:
    minic = 50
skips = 2 if part == 1 else 20
# Expected histogram {saving: number of cheats} for the "demo" file; used only
# to compare against the computed result.
canon: collections.Counter[int] = collections.Counter()
demo = {}
if input_file == "demo":
    if part == 1:
        demo = {
            2: 14, 4: 14, 6: 2, 8: 4, 10: 2,
            12: 3, 20: 1, 36: 1, 38: 1, 40: 1, 64: 1,
        }
    else:
        demo = {
            50: 32, 52: 31, 54: 29, 56: 39, 58: 25, 60: 23,
            62: 20, 64: 19, 66: 12, 68: 14, 70: 12, 72: 22, 74: 4, 76: 3,
        }
# canon starts empty, so update() simply copies the counts.
canon.update(demo)
# Grid coordinate as a (row, column) pair.
vec = tuple[int, int]
# Unit steps as (row delta, column delta), clockwise from North: ^ > v <
directions = [(-1, 0), (0, 1), (1, 0), (0, -1)]
# Find the start ("S") and end ("E") tiles as (row, column) pairs.
start: vec | None = None
stop: vec | None = None
for i, line in enumerate(lines):
    if "S" in line:
        start = i, line.index("S")
    if "E" in line:
        stop = i, line.index("E")
if start is None or stop is None:
    # Fail here with a clear message instead of a NameError further down.
    raise SystemExit(f"{input_file}: grid must contain both 'S' and 'E'")
# Visit forward: layer-by-layer breadth-first search from the start tile.
# forward[i][j] ends up as the number of steps from S to (i, j), or None for
# walls and tiles that cannot be reached.
# NOTE: neighbours are indexed without bounds checks, so this relies on the
# grid being enclosed by "#" walls.
forward: list[list[int | None]] = [[None] * width for _ in range(height)]
forward[start[0]][start[1]] = 0
frontier: set[vec] = {start}
steps = 0
while frontier:
    steps += 1
    next_frontier: set[vec] = set()
    for i, j in frontier:
        for di, dj in directions:
            ii, jj = i + di, j + dj
            # Walls are impassable; a tile that already has a distance was
            # reached in an earlier (or this) layer, so it is already minimal.
            if lines[ii][jj] == "#" or forward[ii][jj] is not None:
                continue
            forward[ii][jj] = steps
            next_frontier.add((ii, jj))
    frontier = next_frontier
# Length of the honest (no-cheat) shortest path from S to E.
normal = forward[stop[0]][stop[1]]
if normal is None:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise SystemExit(f"{input_file}: 'E' is not reachable from 'S'")
# Visit backwards: layer-by-layer breadth-first search from the end tile.
# backward[i][j] ends up as the number of steps from (i, j) to E, or None for
# walls and tiles not connected to E.
# NOTE: neighbours are indexed without bounds checks, so this relies on the
# grid being enclosed by "#" walls.
backward: list[list[int | None]] = [[None] * width for _ in range(height)]
backward[stop[0]][stop[1]] = 0
frontier = {stop}
steps = 0
while frontier:
    steps += 1
    next_frontier = set()
    for i, j in frontier:
        for di, dj in directions:
            ii, jj = i + di, j + dj
            # Walls are impassable; a tile that already has a distance was
            # reached in an earlier (or this) layer, so it is already minimal.
            if lines[ii][jj] == "#" or backward[ii][jj] is not None:
                continue
            backward[ii][jj] = steps
            next_frontier.add((ii, jj))
    frontier = next_frontier
# Sanity check: the distance from S back to E must equal the forward path
# length. This must be checked at S: E is this search's origin and is never
# revisited, so a check on reaching "E" could never run.
assert backward[start[0]][start[1]] == normal
# Print
def perc2color(perc: float) -> str:
    """Return a rich colour string for a fraction, used as an HSV hue.

    Saturation and value are fixed at 1.0, so 0.0 maps to pure red and
    0.5 to cyan. Each channel is scaled to 0-255 and rounded.
    """
    red, green, blue = (
        round(255 * channel) for channel in colorsys.hsv_to_rgb(perc, 1.0, 1.0)
    )
    return f"rgb({red},{green},{blue})"
# Render the grid. Each reached open tile shows its forward distance modulo 10,
# with the foreground coloured by forward distance and the background by
# backward distance, both as fractions of the honest path length `normal`.
text = rich.text.Text()
for i in range(height):
    for j in range(width):
        fg = "white"
        bg = "black"
        char = lines[i][j]
        forw = forward[i][j]
        if char == "." and forw is not None:
            fg = perc2color(forw / normal)
            char = str(forw % 10)
        bckw = backward[i][j]
        if bckw is not None:
            bg = perc2color(bckw / normal)
        if char == "#":
            # Draw walls as a solid block. An empty string would drop the cell
            # entirely and shift every following column of the row.
            char = "█"
        text.append(char, style=f"{fg} on {bg}")
    text.append("\n")
console.print(text)
# Find cheats. For every open tile (i, j) reached from S, try every open tile
# (ii, jj) within Manhattan distance `skips`. Jumping there costs
# forward[i][j] + manhattan + backward[ii][jj] steps in total. Record how many
# steps that saves over the honest path, when the saving is at least `minic`.
saves: collections.Counter[int] = collections.Counter()
for i in rich.progress.track(range(1, height - 1), description="Finding cheats"):
    for j in range(1, width - 1):
        if lines[i][j] == "#":
            continue
        ovis = forward[i][j]
        # From a tile at or beyond the honest path length, the saving cannot
        # be positive, so it can never reach minic.
        if ovis is None or ovis >= normal:
            continue
        # Steps the honest route still needs from (i, j); the same for every
        # candidate target, so compute it once here.
        orem = normal - ovis
        for ii in range(max(1, i - skips), min(height - 1, i + skips) + 1):
            # Column budget left after the vertical offset. This keeps every
            # (ii, jj) inside the Manhattan diamond, so no distance re-check
            # is needed.
            rem = skips - abs(ii - i)
            for jj in range(max(1, j - rem), min(width - 1, j + rem) + 1):
                if lines[ii][jj] == "#":
                    continue
                nvis = backward[ii][jj]
                if nvis is None:
                    continue
                manh = abs(i - ii) + abs(j - jj)
                save = orem - (nvis + manh)
                if save >= minic:
                    saves[save] += 1
# Report the honest path length and the histogram of savings. For the demo
# file, also compare against the expected histogram.
log = console.log
log(f"{normal=}")
log(f"{dict(sorted(saves.items()))=}")
if demo:
    log(f"{dict(sorted(canon.items()))=}")
    # expected - found, per saving; subtract() keeps zero and negative counts.
    diff = canon.copy()
    diff.subtract(saves)
    log(f"{dict(sorted(diff.items()))=}")
    log(f"{(saves == canon)=}")
    log(f"{saves.total()=}")
    log(f"{canon.total()=}")
    # Total absolute mismatch summed over all savings.
    difft = sum(map(abs, diff.values()))
    log(f"{difft=}")
# Number of cheats that save at least `minic` steps.
print(saves.total())
# 1119834 too high
# 982425 correct!