advent-of-code/2024/24/two_test.py

164 lines
3.5 KiB
Python
Raw Permalink Normal View History

2024-12-25 12:58:02 +01:00
#!/usr/bin/env python3
import functools
import sys
import typing
import matplotlib.pyplot as plt
import networkx as nx
# Path to the puzzle input, taken from the first command-line argument.
input_file = sys.argv[1]
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# locale encoding; trailing whitespace (including newlines) is stripped.
with open(input_file, encoding="utf-8") as fd:
    lines = [line.rstrip() for line in fd]
# Output wire -> (left operand wire, operator function, right operand wire).
gates: dict[str, tuple[str, typing.Callable, str]] = dict()
# Wire -> initial value, from the section before the first blank line.
varis: dict[str, int] = dict()
# Operator keyword in the input -> integer operation that implements it.
funs = {
    "AND": int.__and__,  # orange
    "OR": int.__or__,  # green
    "XOR": int.__xor__,  # purple
}
# Directed wiring graph: one edge from each operand wire to the gate output.
G = nx.DiGraph()
# Pairs of gate output wires to exchange while parsing: a gate whose output is
# listed here is recorded under its partner's name instead.
swaps = [
    # ("ncd", "nfj"),
    # ("z37", "vkg"),
    # ("z20", "cqr"),
    # ("z15", "qnw"),
]
# Symmetric lookup built from `swaps`: each wire maps to its partner.
swapdict: dict[str, str] = dict()
for a, b in swaps:
    swapdict[a] = b
    swapdict[b] = a
# Becomes True at the first blank line; before it, lines are `wire: value`
# assignments, after it they are gate definitions.
step = False
for line in lines:
    if not line:
        step = True
    elif step:
        # Five whitespace-separated tokens: left wire, operator, right wire,
        # an ignored separator token, output wire.
        a, op, b, _, dest = line.split()
        dest = swapdict.get(dest, dest)
        fun = funs[op]
        gates[dest] = (a, fun, b)
        G.add_node(dest, op=op)
        G.add_edge(a, dest)
        G.add_edge(b, dest)
    else:
        dest, val = line.split(":")
        varis[dest] = int(val)
def swap(a: str, b: str) -> None:
    """Exchange the gate definitions that drive wires *a* and *b* in ``gates``.

    ``get_var`` memoises wire values with ``functools.cache``. Its cache is
    therefore cleared after the swap; otherwise values computed before the
    swap would keep being returned. This helper must only be called after
    ``get_var`` has been defined.

    NOTE(review): ``G``, ``swapdict`` and the ``get_node_pos`` cache are not
    updated here.
    """
    gates[a], gates[b] = gates[b], gates[a]
    get_var.cache_clear()
@functools.cache
def get_var(var: str) -> int:
    """Return the value carried by wire *var*.

    A wire present in ``varis`` yields its initial value. Any other wire is
    looked up in ``gates``, and its gate operator is applied to the values of
    its two operand wires, evaluated recursively (left operand first).
    Results are memoised by ``functools.cache``. Later changes to ``gates`` or
    ``varis`` are not seen until ``get_var.cache_clear()`` is called.
    """
    if var not in varis:
        left, op, right = gates[var]
        return op(get_var(left), get_var(right))
    return varis[var]
# Every wire name that has a gate or an initial value, in descending string
# order (so, for zero-padded names, z45 comes before z00).
all_keys = sorted([*gates, *varis], reverse=True)
def get_number(prefix: str) -> int:
    """Assemble the integer encoded by the wires named ``<prefix>NN``.

    Each wire in ``all_keys`` whose name is *prefix* followed only by decimal
    digits contributes its value (from ``get_var``) at bit position ``NN``, so
    ``z00`` is the least-significant bit.

    Bits are placed by their numeric suffix rather than by list order. The
    result therefore does not depend on the names being zero-padded, sorted,
    or contiguous. Wires that merely share the first letter (e.g. a gate
    output named ``"xgs"``) are not mistaken for bits of the number.
    """
    total = 0
    for key in all_keys:
        if not key.startswith(prefix):
            continue
        index = key[len(prefix):]
        if index.isdecimal():
            # OR (not add) so a name listed twice in all_keys counts once.
            total |= get_var(key) << int(index)
    return total
# Evaluate the x, y and z buses (in that order) and report whether Z == X + Y.
X, Y, Z = map(get_number, "xyz")
print(f"{X+Y=} = {X=} + {Y=}")
print(f" {Z=} {Z == X + Y=}")
# List every wire named in `swaps`, sorted and comma-separated.
print(",".join(sorted(swapdict)))
# Viz
@functools.cache
def get_node_pos(node: str) -> tuple[float, float]:
    """Return a deterministic (x, y) drawing position for *node*.

    ``xNN`` is placed at (-2*NN, 0) and ``yNN`` at (-2*NN - 1, 0). ``zNN`` is
    placed at (-2*NN, 100). Any other wire is looked up in ``gates``. It is
    centred horizontally between its two operands and placed one unit above
    the higher of their y-coordinates.

    NOTE(review): nothing in the visible code calls this function.
    """
    kind = node[:1]
    if kind == "x":
        return -int(node[1:]) * 2, 0
    if kind == "y":
        return -int(node[1:]) * 2 - 1, 0
    if kind == "z":
        return -int(node[1:]) * 2, 100
    left, _, right = gates[node]
    left_x, left_y = get_node_pos(left)
    right_x, right_y = get_node_pos(right)
    return (left_x + right_x) / 2, max(left_y, right_y) + 1
# Node fill colour per gate type. Nodes without an "op" attribute (wires that
# only appear as gate inputs) fall back to cyan for x-wires, blue otherwise.
colors = {
    "AND": "orange",
    "OR": "green",
    "XOR": "purple",
}
node_colors = []
# Initial position of every node, plus the nodes the spring layout must keep
# at their initial position.
node_pos: dict[str, tuple[float, float]] = dict()
node_fixed: set[str] = set()
x: float
y: float
for node in G.nodes():
    op = G.nodes[node].get("op")
    fallback = "cyan" if node.startswith("x") else "blue"
    node_colors.append(colors.get(op, fallback))
    kind = node[:1]
    if kind in ("x", "y", "z"):
        # Pinned wires: xNN at x = -2*NN, yNN one unit further left (both at
        # y = 0); zNN shares xNN's x-coordinate at y = 50.
        bit = int(node[1:])
        x = -bit * 2 - (1 if kind == "y" else 0)
        y = 50 if kind == "z" else 0
        node_fixed.add(node)
    else:
        # All other wires start from one shared seed point and are left free.
        x, y = -23, 25
    node_pos[node] = x, y
# # My own layout
# for i in range(50):
# for node in G.nodes():
# if node in node_fixed:
# continue
# neighs = list(G.succ[node]) + list(G.pred[node])
# x = sum(node_pos[neigh][0] for neigh in neighs) / len(neighs)
# y = sum(node_pos[neigh][1] for neigh in neighs) / len(neighs)
# node_pos[node] = x, y
# node_fixed = set(G.nodes())
# Force-directed layout on an undirected copy of the wiring graph. It starts
# from node_pos and keeps the nodes in node_fixed at their initial positions.
undirected = G.to_undirected()
node_layout = nx.layout.spring_layout(
    undirected,
    k=1,
    iterations=1000,
    pos=node_pos,
    fixed=node_fixed,
)
# Draw the directed graph at the computed positions and show the figure.
nx.draw(
    G,
    pos=node_layout,
    node_color=node_colors,
    with_labels=True,
)
plt.show()