213 lines
5.2 KiB
Python
213 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import collections
|
|
import colorsys
|
|
import sys
|
|
|
|
import rich.console
|
|
|
|
import rich.text
|
|
import rich.progress
|
|
|
|
# Shared rich console used for the map rendering and the final log lines.
console = rich.console.Console()

# First command-line argument: path of the maze file.
input_file = sys.argv[1]

# Load the maze as a list of rows, trailing whitespace/newlines removed.
with open(input_file) as fd:
    lines = [raw.rstrip() for raw in fd]

# Grid dimensions; the width is taken from the first row.
height, width = len(lines), len(lines[0])
|
|
|
|
# Adjust parameters for part/file
part = int(sys.argv[2])
# Validate explicitly rather than with `assert`: assertions are stripped under
# `python -O`, which would let any other value silently run with the part-2
# settings chosen below.
if part not in (1, 2):
    raise ValueError(f"part must be 1 or 2, got {part}")

# Smallest saving a cheat must achieve to be counted; depends on the name of
# the input file and, for other files, on the part.
if input_file.startswith("input"):
    minic = 100
elif input_file.startswith("reddit_part3"):
    minic = 30
else:
    minic = 1 if part == 1 else 50

# Maximum Manhattan length of a cheat.
skips = 2 if part == 1 else 20

# Reference histogram {saving: number of cheats}. It is only filled in for the
# file named "demo"; the report at the end of the script compares against it
# when `demo` is non-empty.
canon: collections.Counter[int] = collections.Counter()
demo: dict[int, int] = {}
if input_file == "demo":
    if part == 1:
        demo = {
            2: 14, 4: 14, 6: 2, 8: 4, 10: 2,
            12: 3, 20: 1, 36: 1, 38: 1, 40: 1, 64: 1,
        }
    else:
        demo = {
            50: 32, 52: 31, 54: 29, 56: 39, 58: 25, 60: 23,
            62: 20, 64: 19, 66: 12, 68: 14, 70: 12, 72: 22, 74: 4, 76: 3,
        }
    # `canon` starts empty, so update() simply copies the reference counts.
    canon.update(demo)
|
|
|
|
# A grid position or offset, as (row, column).
vec = tuple[int, int]

# (row delta, column delta) for each heading, clockwise starting from north:
# ^ North, > East, v South, < West.
directions = [(-1, 0), (0, 1), (1, 0), (0, -1)]
|
|
|
|
# Find start position
# Scan every row for the "S" and "E" markers. If a marker appears on several
# rows the last such row wins; within a row the first occurrence is used.
for row_index, row_text in enumerate(lines):
    s_col = row_text.find("S")
    if s_col != -1:
        start = row_index, s_col
    e_col = row_text.find("E")
    if e_col != -1:
        stop = row_index, e_col
|
|
|
|
|
|
# Visit forward
# Breadth-first flood from the start cell: forward[r][c] receives the number
# of steps needed to reach (r, c) from "S"; cells that are never reached
# (walls included) stay None.
# NOTE(review): neighbours are indexed without bounds checks, so this
# presumably relies on the grid being enclosed by "#" - confirm.
normal = None  # step count at which "E" is first reached
forward: list[list[int | None]] = [[None] * width for _ in range(height)]
forward[start[0]][start[1]] = 0
frontier: set[vec] = {start}
step = 0
while frontier:
    step += 1
    discovered: set[vec] = set()
    for i, j in frontier:
        for di, dj in directions:
            ii, jj = i + di, j + dj

            cchar = lines[ii][jj]
            if cchar == "#":
                continue

            # Only cells that are new, or were already labelled during this
            # same step, are (re)labelled and queued for the next round.
            previs = forward[ii][jj]
            if previs is None or previs >= step:
                forward[ii][jj] = step
                if cchar == "E" and normal is None:
                    normal = step
                discovered.add((ii, jj))
    frontier = discovered
# The end cell must have been reached.
assert normal
|
|
|
|
# Visit backwards
# Breadth-first flood from the end cell: backward[r][c] receives the number of
# steps needed to reach "E" from (r, c); cells that are never reached (walls
# included) stay None.
# NOTE(review): neighbours are indexed without bounds checks, so this
# presumably relies on the grid being enclosed by "#" - confirm.
backward: list[list[int | None]] = [[None] * width for _ in range(height)]
backward[stop[0]][stop[1]] = 0
wave: set[vec] = {stop}
dist = 0
while wave:
    dist += 1
    next_wave: set[vec] = set()
    for i, j in wave:
        for di, dj in directions:
            ii, jj = i + di, j + dj

            cchar = lines[ii][jj]
            if cchar == "#":
                continue

            # Skip cells already labelled by an earlier (shorter) wave.
            previs = backward[ii][jj]
            if previs is not None and previs < dist:
                continue
            backward[ii][jj] = dist

            # Sanity check: the distance E -> S must equal the distance
            # S -> E found by the forward pass. The check is made on "S"
            # because the end cell is seeded with 0 and is therefore always
            # skipped by the test above; a check on "E" could never run.
            if cchar == "S":
                assert dist == normal
            next_wave.add((ii, jj))
    wave = next_wave
|
|
|
|
# Print
|
|
|
|
|
|
def perc2color(perc: float) -> str:
    """Return an ``rgb(r,g,b)`` colour string for the hue *perc*.

    *perc* is passed to ``colorsys.hsv_to_rgb`` as the hue, with saturation
    and value both fixed at 1.0; each resulting channel is scaled to 0-255
    and rounded to an integer.
    """
    red, green, blue = (
        round(channel * 255)
        for channel in colorsys.hsv_to_rgb(perc, 1.0, 1.0)
    )
    return f"rgb({red},{green},{blue})"
|
|
|
|
|
|
# Render the maze: walls become solid blocks; open "." cells show the last
# digit of their forward distance, coloured by forward distance (foreground)
# and backward distance (background), both taken as a fraction of `normal`.
# NOTE(review): the listing this was recovered from had lost its indentation;
# the backward lookup is assumed to belong to the "." branch - confirm.
text = rich.text.Text()
for i in range(height):
    for j in range(width):
        glyph = lines[i][j]
        fg, bg = "white", "black"
        forw = forward[i][j]
        if glyph == ".":
            if forw is not None:
                fg = perc2color(forw / normal)
                glyph = str(forw % 10)
            bckw = backward[i][j]
            if bckw is not None:
                bg = perc2color(bckw / normal)
        elif glyph == "#":
            glyph = "█"
        text.append(glyph, style=f"{fg} on {bg}")
    text.append("\n")
console.print(text)
|
|
|
|
|
|
# Find cheats
# For every reachable non-wall cell (the cheat start) look at every non-wall
# cell within Manhattan distance `skips` (the cheat end) and histogram the
# number of steps saved, keeping only savings of at least `minic`.
saves: collections.Counter[int] = collections.Counter()
for i in rich.progress.track(range(1, height - 1), description="Finding cheats"):
    for j in range(1, width - 1):
        if lines[i][j] == "#":
            continue
        ovis = forward[i][j]
        if ovis is None or ovis >= normal:
            continue
        # Steps left on the honest route from here; constant for this cell.
        orem = normal - ovis
        row_lo = max(1, i - skips)
        row_hi = min(height - 1, i + skips)
        for ii in range(row_lo, row_hi + 1):
            # Horizontal budget left after spending the vertical distance.
            reach = skips - abs(ii - i)
            col_lo = max(1, j - reach)
            col_hi = min(width - 1, j + reach)
            for jj in range(col_lo, col_hi + 1):
                manh = abs(i - ii) + abs(j - jj)
                if manh > skips or lines[ii][jj] == "#":
                    continue
                nvis = backward[ii][jj]
                if nvis is None:
                    continue
                # Steps left when cheating: the jump itself plus the honest
                # remainder from the landing cell to the end.
                save = orem - (nvis + manh)
                if save >= minic:
                    saves[save] += 1
|
|
|
|
log = console.log

# Honest path length and the histogram {saving: number of cheats}.
log(f"{normal=}")
log(f"{dict(sorted(saves.items()))=}")
if demo:
    # Reference data is available: show it next to the per-saving difference.
    log(f"{dict(sorted(canon.items()))=}")
    # Counter.subtract keeps zero and negative counts, so both missing and
    # surplus cheats remain visible in `diff`.
    diff = canon.copy()
    diff.subtract(saves)
    log(f"{dict(sorted(diff.items()))=}")
    log(f"{(saves == canon)=}")
    log(f"{saves.total()=}")
    log(f"{canon.total()=}")
    # Total absolute deviation from the reference histogram.
    difft = sum(abs(v) for v in diff.values())
    log(f"{difft=}")
# The answer: number of cheats that met the minimum saving.
print(saves.total())
# 1119834 too high
# 982425 correct!
|