advent-of-code/2024/16/networkx_test.py
2024-12-25 12:59:49 +01:00

87 lines
2.4 KiB
Python

#!/usr/bin/env python3
import sys
import typing
import matplotlib.pyplot as plt
import networkx as nx
# Read the maze grid from the file named on the command line.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} INPUT_FILE")
input_file = sys.argv[1]
with open(input_file, encoding="utf-8") as fd:
    # Strip line endings and drop blank lines: a trailing empty line would
    # otherwise be counted as a grid row and make the parser raise IndexError.
    lines = [line.rstrip() for line in fd if line.strip()]
if not lines:
    sys.exit(f"{input_file}: no maze found")
height = len(lines)
width = len(lines[0])
# Headings as (row delta, col delta), in clockwise order. The index into this
# list is part of every graph node, and turning is modelled as index +/- 1
# (mod 4), so the order must stay clockwise. Index 1 (East) is the heading
# used for the S tile.
directions = [
    (-1, 0),  # ^ North
    (0, 1),  # > East
    (1, 0),  # v South
    (0, -1),  # < West
]
# Parse input into a weighted directed graph whose nodes are
# (row, col, heading-index) states:
#   - moving one tile forward in the current heading costs 1,
#   - turning 90 degrees in place (either way) costs 1000,
#   - a virtual "start" node links to the S tile facing East (index 1),
#   - every heading on the E tile links to a virtual "end" node at cost 0,
#     so the final heading does not matter.
g = nx.DiGraph()
n_directions = len(directions)
for i in range(height):
    for j in range(width):
        char = lines[i][j]
        if char == "#":
            continue  # walls are not part of the graph
        for d, direction in enumerate(directions):
            cur = (i, j, d)
            # Start/end
            if char == "E":
                g.add_edge(cur, "end", weight=0)
            elif char == "S" and d == 1:
                g.add_edge("start", cur, weight=0)
            # Rotate clockwise / counter-clockwise in place
            g.add_edge(cur, (i, j, (d + 1) % n_directions), weight=1000)
            g.add_edge(cur, (i, j, (d - 1) % n_directions), weight=1000)
            # Advance one tile, unless that leaves the grid or hits a wall.
            # The explicit bounds check matters when the maze is not fully
            # walled in: a -1 index would silently wrap to the opposite edge,
            # and an index equal to height/width would raise IndexError.
            ii, jj = i + direction[0], j + direction[1]
            if not (0 <= ii < height and 0 <= jj < width):
                continue
            if lines[ii][jj] == "#":
                continue
            g.add_edge(cur, (ii, jj, d), weight=1)
# Part 1: the lowest achievable score is the weighted distance from the
# virtual "start" node to the virtual "end" node. All edge weights are
# non-negative (0, 1 or 1000), so Dijkstra applies directly.
score = nx.dijkstra_path_length(g, "start", "end", weight="weight")
print(f"score={score!r}")
# Part 2: count the grid tiles lying on at least one minimum-score path.
# Collect every (row, col, heading) state visited by any best path, then drop
# the virtual endpoints and project away the heading to get distinct tiles.
paths = nx.all_shortest_paths(g, "start", "end", weight="weight")
best_orientations = set()
for path in paths:
    best_orientations.update(path)
# Edges of the last best path yielded above; only used to highlight a single
# best path when drawing. Built once (not once per path) and as a set so
# membership tests are O(1).
path_edges = set(zip(path, path[1:]))
best_places = {bo[:2] for bo in best_orientations - {"start", "end"}}
print(f"{len(best_places)=}")
# Draw graph: states on some best path are blue (others cyan), and the edges
# of one best path are red (others black). Graphs with more than 1000 nodes
# exit here without plotting (presumably the plot would be unreadable).
if len(g.nodes) > 1000:
    sys.exit(0)

# Set copy so each edge lookup is O(1) instead of a scan of the path.
highlighted_edges = set(path_edges)
node_colors = ["blue" if node in best_orientations else "cyan" for node in g.nodes()]
edge_colors = ["red" if edge in highlighted_edges else "black" for edge in g.edges()]

node_pos: dict[typing.Any, tuple[float, float]] = {}
for node in g.nodes():
    pos: tuple[float, float]
    if node == "start":
        pos = height - 1, 0  # bottom-left corner of the grid
    elif node == "end":
        pos = 0, width - 1  # top-right corner of the grid
    else:
        # Offset each heading a third of a tile toward the way it faces so
        # the four states of one tile are drawn apart.
        i, j, d = node
        direction = directions[d]
        pos = i + direction[0] / 3, j + direction[1] / 3
    # Plot as (x, y) = (column, -row) so row 0 appears at the top.
    node_pos[node] = pos[1], pos[0] * -1
nx.draw_networkx_nodes(g, node_pos, node_color=node_colors)
nx.draw_networkx_edges(g, node_pos, edge_color=edge_colors)
# Optional debugging aids: uncomment to label nodes and edge weights.
# nx.draw_networkx_labels(g, node_pos)
# nx.draw_networkx_edge_labels(
#     g, node_pos, edge_labels={(u, v): d["weight"] for u, v, d in g.edges(data=True)}
# )
plt.show()