advent-of-code/2024/13/two_simpy.py

83 lines
2.2 KiB
Python
Raw Permalink Normal View History

2024-12-25 12:58:02 +01:00
#!/usr/bin/env python3
"""
Someone mentioned sympy on Reddit; I wanted to see what I could do with it.
"""
import re
import sys
import sympy
Coords = tuple[int, int]

# One (label, compiled pattern) pair per line position inside a four-line
# machine record. Position 3 is the separator line and is never parsed.
_RECORD_PATTERNS = (
    ("Button A", re.compile(r"^Button A: X\+([0-9]+), Y\+([0-9]+)$")),
    ("Button B", re.compile(r"^Button B: X\+([0-9]+), Y\+([0-9]+)$")),
    ("Prize", re.compile(r"^Prize: X=([0-9]+), Y=([0-9]+)$")),
)


def parse_machines(
    lines: list[str], prize_offset: int = 0
) -> tuple[list[tuple[Coords, Coords]], list[Coords]]:
    """Parse the machine records found in the input lines.

    The input is read as consecutive groups of four lines: a ``Button A``
    line, a ``Button B`` line, a ``Prize`` line, and one separator line whose
    content is ignored.

    Args:
        lines: Input lines, already stripped of trailing whitespace.
        prize_offset: Value added to both coordinates of every prize. The
            default of 0 leaves prizes as written in the input; the original
            script carried a commented-out variant adding 10000000000000.

    Returns:
        ``(buttons, prizes)`` where ``buttons[i]`` is the
        ``((Ax, Ay), (Bx, By))`` pair and ``prizes[i]`` the ``(Px, Py)`` pair
        of the i-th machine.

    Raises:
        ValueError: If a line does not match the format expected at its
            position, or if the last record has buttons but no prize.
    """
    buttons: list[tuple[Coords, Coords]] = []
    prizes: list[Coords] = []
    button_a = None
    for li, line in enumerate(lines):
        offset = li % 4
        if offset == 3:
            # Separator between two machines: nothing to parse.
            continue
        label, pattern = _RECORD_PATTERNS[offset]
        match = pattern.match(line)
        if match is None:
            raise ValueError(
                f"line {li + 1}: expected a {label!r} line, got {line!r}"
            )
        pair = int(match[1]), int(match[2])
        if offset == 0:
            button_a = pair
        elif offset == 1:
            # button_a is always set here: position 0 of this record was
            # parsed successfully just before.
            buttons.append((button_a, pair))
        else:
            prizes.append((pair[0] + prize_offset, pair[1] + prize_offset))
    if len(prizes) != len(buttons):
        raise ValueError("input ends in the middle of a machine record")
    return buttons, prizes


def main() -> None:
    """Read the input file named on the command line and print the token total.

    For every machine, the press counts ``a`` and ``b`` solving
    ``a*A + b*B == Prize`` are computed; machines without a non-negative
    integer solution contribute nothing. The cost of a solution is
    ``3*a + b``, and the sum of the cheapest cost per machine is printed last.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_FILE")
    input_file = sys.argv[1]
    with open(input_file) as fd:
        lines = [line.rstrip() for line in fd]
    buttons, prizes = parse_machines(lines)

    sympy.init_printing()
    a, b, Ax, Ay, Bx, By, Px, Py = sympy.symbols(
        "a b Ax Ay Bx By Px Py", positive=True, integer=True
    )
    x_eq = sympy.Eq(a * Ax + b * Bx, Px)
    y_eq = sympy.Eq(a * Ay + b * By, Py)
    tokens = 3 * a + 1 * b
    # Solve once with symbolic coefficients; each machine then only needs a
    # substitution of its own numbers into the general solution.
    sols = sympy.solve([x_eq, y_eq], a, b, dict=True)
    # NOTE: linsolve would be a better fit than solve for this linear system,
    # and would allow substituting into the whole solution set at once.

    ttoks = sympy.Integer(0)
    for (button_a, button_b), prize in zip(buttons, prizes):
        # Debug trace kept from the original script; the leading number is a
        # fixed tag printed as-is.
        print(43, prize, button_a, button_b)
        values = {
            Ax: button_a[0],
            Ay: button_a[1],
            Bx: button_b[0],
            By: button_b[1],
            Px: prize[0],
            Py: prize[1],
        }
        toks = None
        for sol in sols:
            a_presses = sol[a].subs(values)
            b_presses = sol[b].subs(values)
            # Fractional press counts (and undefined results from a singular
            # system) are not usable solutions.
            if not a_presses.is_integer or not b_presses.is_integer:
                continue
            # A button cannot be pressed a negative number of times.
            if a_presses.is_negative or b_presses.is_negative:
                continue
            ntoks = tokens.subs({a: a_presses, b: b_presses})
            if toks is None or ntoks < toks:
                toks = ntoks
        print(76, toks)
        if toks is not None:
            ttoks += toks

    assert ttoks.is_integer
    # Convert the exact sympy integer directly: going through evalf() would
    # round it to a limited-precision float first.
    print(int(ttoks))


if __name__ == "__main__":
    main()