├── .gitignore ├── 01.py ├── 02.py ├── 03.py ├── 04.py ├── 05.py ├── 05_post.py ├── 06.py ├── 06_np.py ├── 07.py ├── 08.py ├── 08_post.py ├── 08_z3.py ├── 09.py ├── 10.py ├── 11.py ├── 12.py ├── 12_post.py ├── 13.py ├── 14.py ├── 14_post.py ├── 15.py ├── 15_post.py ├── 16.py ├── 17.py ├── 17_post.py ├── 18.py ├── 19.py ├── 19_post.py ├── 20.py ├── 21.py ├── 22.py ├── 22_post.py ├── 23.py ├── 23_post.py ├── 23e.in ├── 24.py ├── 24_post.py ├── 24_post_partial.py ├── 24_z3_optimized.py ├── 25.py ├── get_input.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .session 3 | *.in 4 | -------------------------------------------------------------------------------- /01.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('py', 'in')) 5 | 6 | L = list(ints()) 7 | 8 | 9 | prev = 0 10 | r = 0 11 | 12 | for j, i in enumerate(L): 13 | i = sum(L[j:j+3]) 14 | if i > prev: 15 | r += 1 16 | 17 | prev = i 18 | 19 | print(r-1) 20 | -------------------------------------------------------------------------------- /02.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('py', 'in')) 5 | 6 | 7 | depth = pos = 0 8 | d2 = 0 9 | 10 | for l in lines(): 11 | a, x = l.split() 12 | x = int(x) 13 | if a == 'forward': 14 | pos += x 15 | d2 += x * depth 16 | elif a == 'down': 17 | depth += x 18 | else: 19 | depth -= x 20 | 21 | 22 | print(pos * depth) 23 | print(pos * d2) 24 | 25 | -------------------------------------------------------------------------------- /03.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('py', 'in')) 5 | 6 | g = e = 0 7 | 8 | L = lines() 9 | N = len(L[0]) 10 | 11 | def f(t): 12 | P = list(L) 13 | 14 | for x in range(N): 15 | C = Counter() 16 | for l in P: 17 | C[l[x]] += 1 18 | 19 | o, z = C['1'], C['0'] 20 | 21 | if t: 22 | keep = '1' if o >= z else '0' 23 | else: 24 | keep = '0' if z <= o else '1' 25 | 26 | P = [l for l in P if l[x] == keep] 27 | if len(P) == 1: 28 | print('found', P[0]) 29 | return int(P[0], 2) 30 | 31 | assert P 32 | 33 | print(f(True) * f(False)) 34 | exit() 35 | 36 | for x in range(N): 37 | C = Counter() 38 | for l in L: 39 | C[l[x]] += 1 40 | 41 | g <<= 1 42 | e <<= 1 43 | o, z = C['1'], C['0'] 44 | if o > z: 45 | g += 1 46 | else: 47 | e += 1 48 | 49 | 50 | print(g * e) 51 | -------------------------------------------------------------------------------- /04.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('py', 'in')) 5 | input = sys.stdin.readline 6 | 7 | I = list(map(int, input().split(','))) 8 | 9 | boards = [] 10 | M = defaultdict(list) 11 | 12 | bs = 0 13 | while True: 14 | _ = input() 15 | if _ == '': break 16 | 17 | B = [list(ints(input())) for _ in range(5)] 18 | boards.append(B) 19 | 20 | for c in range(5): 21 | s = {B[y][c] for y in range(5)} 22 | for i in s: M[i].append((s, bs)) 23 | 24 | for y in range(5): 25 | s = {B[y][c] for c in range(5)} 26 | for i in s: M[i].append((s, bs)) 27 | 28 | bs += 1 29 | 30 | 31 | rem = set(range(bs)) 32 | D = set() 33 | for i in I: 34 | D.add(i) 35 | 36 | for s, bi in M[i]: 37 | if i in s: 38 | s.remove(i) 39 | 40 | if not s: 41 | rem.discard(bi) 42 | 43 | if not rem: 44 | r = sum(v for l in boards[bi] for v in l if v not in D) 45 | print(r * i) 46 | exit() 47 | 48 | """ 49 | for s, bi in M[i]: 50 | assert i in s 51 | s.remove(i) 52 | 53 | if not s: 54 | r = sum(v for l in boards[bi] for v in l if v not in D) 55 | print(r * i) 56 | exit() 57 | """ 58 | -------------------------------------------------------------------------------- /05.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('py', 'in')) 5 | 6 | l = lines() 7 | 8 | O = Counter() 9 | 10 | def f(s): 11 | return map(int, s.split(',')) 12 | 13 | for s in l: 14 | a, b = s.split(' -> ') 15 | 16 | x1, y1 = f(a) 17 | x2, y2 = f(b) 18 | 19 | #if not (x1 == x2 or y1 == y2): continue 20 | 21 | if x1 == x2: 22 | dx = 0 23 | else: 24 | dx = 1 if x2 > x1 else -1 25 | 26 | if y1 == y2: 27 | dy = 0 28 | else: 29 | dy = 1 if y2 > y1 else -1 30 | 31 | x, y = x1, y1 32 | while True: 33 | O[(x, y)] += 1 34 | 35 | if (x, y) == (x2, y2): break 36 | 37 | x += dx; y += dy 38 | 39 | print(sum(v > 1 for v in O.values())) 40 | -------------------------------------------------------------------------------- /05_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('_post.py', '.in')) 5 | 6 | O = Counter() 7 | 8 | def f(s) -> Point[int]: 9 | return Point.of(*ints(s)) 10 | 11 | for p1, p2 in map(lambda s: map(f, s.split(' -> ')), lines()): 12 | diff = p2 - p1 13 | d = Point.of(*map(sign, diff)) 14 | #if not (p1.x == p2.x or p1.y == p2.y): continue 15 | O.update(p1 + d * i for i in range(max(abs(diff))+1)) 16 | 17 | print(sum(v > 1 for v in O.values())) 18 | -------------------------------------------------------------------------------- /06.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 34/8 3 | from util import * 4 | 5 | sys.stdin = open(__file__.replace('py', 'in')) 6 | 7 | C = Counter(ints()) 8 | 9 | for _ in range(256): 10 | NC = Counter() 11 | for k, v in C.items(): 12 | if k == 0: 13 | NC[6] += v 14 | NC[8] += v 15 | else: 16 | NC[k-1] += v 17 | 18 | C = NC 19 | 20 | 21 | prints(sum(C.values())) 22 | -------------------------------------------------------------------------------- /06_np.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace("_np.py", ".in")) 5 | 6 | import numpy as np 7 | 8 | C = Counter(ints()) 9 | V = [C[i] for i in range(9)] 10 | 11 | A = np.eye(9, k=1, dtype=int) 12 | 13 | A[6][0] = 1 14 | A[8][0] = 1 15 | 16 | def f(n: int) -> int: 17 | return sum(np.linalg.matrix_power(A, n) @ V) 18 | 19 | print(f(80)) 20 | print(f(256)) 21 | -------------------------------------------------------------------------------- /07.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | sys.stdin = open(__file__.replace('py', 'in')) 5 | 6 | V = list(ints()) 7 | V.sort() 8 | n = len(V)//2 9 | 10 | x = V[n//2] 11 | 12 | # Using the "median" didn't work because I divided len(V) by 4... 13 | print(min(sum(abs(y - x) for y in V) for x in V)) 14 | 15 | prints(min(sum(abs(y - x) * (abs(y - x) + 1) // 2 for y in V) for x in range(min(V), max(V)+1))) 16 | -------------------------------------------------------------------------------- /08.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace('py', 'in')) 6 | 7 | nums = [ 8 | (0, 1, 2, 4, 5, 6), 9 | (2, 5), 10 | (0, 2, 3, 4, 6), 11 | (0, 2, 3, 5, 6), 12 | (1, 2, 3, 5), 13 | (0, 1, 3, 5, 6), 14 | (0, 1, 3, 4, 5, 6), 15 | (0, 2, 5), 16 | tuple(range(7)), 17 | (0, 1, 2, 3, 5, 6), 18 | ] 19 | 20 | r = 0 21 | for l in lines(): 22 | pat, e = l.split(' | ') 23 | 24 | ins = pat.split() 25 | out = e.split() 26 | 27 | M = [-1] * 7 28 | 29 | def f(c): 30 | return 'abcdefg'.index(c) 31 | 32 | ins = [set(map(f, s)) for s in ins] 33 | 34 | def get(l): 35 | for s in ins: 36 | if len(s) == l: 37 | return s 38 | 39 | ones = get(2) 40 | sev = get(3) 41 | 42 | poss = [set(range(7)) for _ in range(7)] 43 | poss[0] = {(sev - ones).pop()} 44 | 45 | for i in range(1, 7): poss[i] -= poss[0] 46 | 47 | poss[2] = set(ones) 48 | poss[5] = set(ones) 49 | 50 | four = get(4) 51 | 52 | for p in ins: 53 | for v, req in enumerate(nums): 54 | if len(p) == len(req) and len(p) in (2, 3, 4, 7): 55 | for i in req: 56 | poss[i] &= p 57 | 58 | for k, req2 in enumerate(nums): 59 | if v != k and set(req) < set(req2): 60 | for i in (set(req2) - set(req)): 61 | poss[i] -= p 62 | 63 | # look for 9 64 | for p in ins: 65 | if len(p) != 6: continue 66 | if sev < p and four < p: 67 | rem = p - sev - four 68 | poss[6] = rem 69 | 70 | for i in range(1, 6): 71 | poss[i] -= rem 72 | 73 | break 74 | 75 | # look for 0 76 | for p in ins: 77 | if len(p) != 6: continue 78 | if (sev | poss[4] | poss[6]) < p: 79 | rem = p - sev - poss[4] - poss[6] 80 | assert len(rem) == 1, rem 81 | 82 | poss[1] = rem 83 | for i in range(7): 84 | if i != 1: 85 | poss[i] -= rem 86 | #for v, req in enumerate(nums): 87 | 88 | rem = four - ones - poss[1] 89 | assert len(rem) == 1, rem 90 | poss[3] = rem 91 | for i in range(7): 92 | if i != 3: 93 | poss[i] -= rem 94 | 95 | # look for 5 96 | for p in ins: 97 | if len(p) != 5: continue 98 | 99 | if (poss[0] | poss[1] | poss[3] | poss[6]) < p: 100 | rem = p - (poss[0] | poss[1] | poss[3] | poss[6]) 101 | poss[5] = rem 102 | poss[2] = ones - rem 103 | #rem -= ones 104 | 105 | M = [v.pop() for v in poss] 106 | rev = [-1] * 7 107 | for i, v in enumerate(M): 108 | rev[v] = i 109 | 110 | res = 0 111 | for l in out: 112 | res *= 10 113 | 114 | l = list(map(f, l)) 115 | 116 | on = set(rev[v] for v in l) 117 | 118 | for i, v in enumerate(nums): 119 | if set(v) == on: 120 | res += i 121 | break 122 | else: 123 | assert False 124 | 125 | r += res 126 | 127 | prints(r) 128 | -------------------------------------------------------------------------------- /08_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace('_post.py', '.in')) 6 | 7 | 8 | nums = [ 9 | (0, 1, 2, 4, 5, 6), 10 | (2, 5), 11 | (0, 2, 3, 4, 6), 12 | (0, 2, 3, 5, 6), 13 | (1, 2, 3, 5), 14 | (0, 1, 3, 5, 6), 15 | (0, 1, 3, 4, 5, 6), 16 | (0, 2, 5), 17 | tuple(range(7)), 18 | (0, 1, 2, 3, 5, 6), 19 | ] 20 | 21 | r = 0 22 | for l in lines(): 23 | pat, e = l.split(' | ') 24 | 25 | ins = pat.split() 26 | out = e.split() 27 | 28 | def f(c): 29 | return 'abcdefg'.index(c) 30 | 31 | ins = [list(map(f, s)) for s in ins] 32 | 33 | res = 0 34 | for perm in permutations(range(7)): 35 | if all(tuple(sorted(perm[j] for j in i)) in nums for i in ins): 36 | for o in out: 37 | res *= 10 38 | 39 | x = tuple(sorted(perm[f(c)] for c in o)) 40 | res += nums.index(x) 41 | 42 | break 43 | else: assert False 44 | 45 | r += res 46 | 47 | prints(r) 48 | -------------------------------------------------------------------------------- /08_z3.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | from util import * 3 | 4 | from functools import reduce 5 | from z3 import And, BitVec, Distinct, Or, Solver 6 | 7 | 8 | if len(sys.argv) == 1: 9 | sys.stdin = open(__file__.replace("_z3.py", ".in")) 10 | 11 | nums = [ 12 | (0, 1, 2, 4, 5, 6), 13 | (2, 5), 14 | (0, 2, 3, 4, 6), 15 | (0, 2, 3, 5, 6), 16 | (1, 2, 3, 5), 17 | (0, 1, 3, 5, 6), 18 | (0, 1, 3, 4, 5, 6), 19 | (0, 2, 5), 20 | tuple(range(7)), 21 | (0, 1, 2, 3, 5, 6), 22 | ] 23 | 24 | # Represent the segments as bitmasks 25 | numb = [sum(1 << i for i in t) for t in nums] 26 | 27 | # X represents the permutation, the values should therefore be distinct and 28 | # range from 0 to 6 29 | X = [BitVec(f"X__{i}", 7) for i in range(7)] 30 | S = Solver() 31 | S.add(Distinct(X), *[And(0 <= x, x < 7) for x in X]) 32 | 33 | 34 | def f(c): 35 | return "abcdefg".index(c) 36 | 37 | 38 | r = 0 39 | for l in lines(): 40 | S.push() 41 | 42 | pat, e = l.split(" | ") 43 | 44 | ins = pat.split() 45 | out = e.split() 46 | 47 | for i, signal_ in enumerate(ins): 48 | signal = list(map(f, signal_)) 49 | 50 | # Make a helper variable that is equal to the bitmask of the signal 51 | # after the permutation is applied. 52 | bv = BitVec(f"bv_{i}", 7) 53 | signal_shifts = [1 << X[v] for v in signal] 54 | S.add(bv == reduce(lambda expr, v: expr | v, signal_shifts)) 55 | 56 | # The bitmask should be equal to one of the valid bitmasks. 57 | # We can save some time by only considering bitmasks with the same number of bits set. 58 | S.add(Or([bv == exp for t, exp in zip(nums, numb) if len(t) == len(signal)])) 59 | 60 | print(S.check()) 61 | m = S.model() 62 | 63 | M = [m.evaluate(x).as_long() for x in X] 64 | 65 | res = 0 66 | for l in out: 67 | res *= 10 68 | 69 | v = sum(1 << M[i] for i in map(f, l)) 70 | res += numb.index(v) 71 | 72 | S.pop() 73 | r += res 74 | 75 | prints(r) 76 | -------------------------------------------------------------------------------- /09.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 61/8 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace('py', 'in')) 7 | 8 | M = [list(map(int, l)) for l in lines()] 9 | H = len(M) 10 | W = len(M[0]) 11 | 12 | Q = [] 13 | r = 0 14 | for y, l in enumerate(M): 15 | for x, d in enumerate(l): 16 | m = 10 17 | for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)): 18 | nx, ny = x +dx, y + dy 19 | if not 0 <= nx < W or not 0 <= ny < H: continue 20 | m = min(m, M[ny][nx]) 21 | 22 | if d < m: 23 | Q.append((x, y)) 24 | r += d + 1 25 | 26 | print(r) 27 | 28 | B = [] 29 | V = [[False] * W for _ in range(H)] 30 | for sx, sy in Q: 31 | 32 | NQ = [(sx, sy)] 33 | V[sy][sx] = True 34 | 35 | for x, y in NQ: 36 | for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)): 37 | nx, ny = x +dx, y + dy 38 | if not 0 <= nx < W or not 0 <= ny < H: continue 39 | m = M[ny][nx] 40 | if m == 9 or V[ny][nx]: continue 41 | V[ny][nx] = True 42 | NQ.append((nx, ny)) 43 | 44 | B.append(len(NQ)) 45 | 46 | B.sort(reverse=True) 47 | a, b, c = B[:3] 48 | prints(a * b *c) 49 | -------------------------------------------------------------------------------- /10.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 13/6 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace('py', 'in')) 7 | 8 | r = 0 9 | 10 | pt = {')': 3, ']': 57, '}': 1197, '>': 25137} 11 | op = {')': '(', ']': '[', '}': '{', '>': '<'} 12 | 13 | score = ' )]}>' 14 | 15 | sc = [] 16 | 17 | for l in lines(): 18 | S = [] 19 | for c in l: 20 | if c in op: 21 | if not S or S[-1] != op[c]: 22 | r += pt[c] 23 | break 24 | 25 | S.pop() 26 | else: 27 | S.append(c) 28 | 29 | else: 30 | x = 0 31 | while S: 32 | x *= 5 33 | c = S.pop() 34 | 35 | for k, v in op.items(): 36 | if v == c: 37 | x += score.index(k) 38 | break 39 | else: 40 | assert False 41 | 42 | sc.append(x) 43 | 44 | sc.sort() 45 | 46 | print(r) 47 | prints(sc[len(sc)//2]) 48 | -------------------------------------------------------------------------------- /11.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace('py', 'in')) 6 | 7 | r = 0 8 | 9 | G = [list(map(int, input())) for _ in range(10)] 10 | 11 | for i in range(10 ** 10): 12 | V = [[False] * 10 for _ in range(10)] 13 | 14 | Q = [] 15 | for y, l in enumerate(G): 16 | for x, c in enumerate(l): 17 | G[y][x] += 1 18 | if G[y][x] > 9: 19 | V[y][x] = True 20 | Q.append((x, y)) 21 | 22 | for x, y in Q: 23 | for dx in range(-1, 2): 24 | for dy in range(-1, 2): 25 | nx = x + dx 26 | ny = y + dy 27 | if nx < 0 or ny < 0 or nx >= 10 or ny >= 10: continue 28 | G[ny][nx] += 1 29 | 30 | if not V[ny][nx] and G[ny][nx] > 9: 31 | V[ny][nx] = True 32 | Q.append((nx, ny)) 33 | 34 | r += len(Q) 35 | 36 | if len(Q) == 10 * 10: 37 | prints(i+1) 38 | exit() 39 | 40 | for x, y in Q: 41 | G[y][x] = 0 42 | 43 | prints(r) 44 | -------------------------------------------------------------------------------- /12.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 1/37 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace('py', 'in')) 7 | 8 | r = 0 9 | 10 | adj = defaultdict(list) 11 | for l in lines(): 12 | a, b = l.split('-') 13 | adj[a].append(b) 14 | adj[b].append(a) 15 | 16 | 17 | def f(i, v, b): 18 | global r 19 | 20 | if i.islower() and v[i]: 21 | if i == 'start' or b: return 22 | b = True 23 | 24 | if i == 'end': 25 | r += 1 26 | return 27 | 28 | v[i] += 1 29 | 30 | for j in adj[i]: 31 | f(j, v, b) 32 | 33 | v[i] -= 1 34 | 35 | 36 | f('start', Counter(), True) 37 | print(r) 38 | 39 | r = 0 40 | f('start', Counter(), False) 41 | prints(r) 42 | -------------------------------------------------------------------------------- /12_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace('_post.py', '.in')) 6 | 7 | adj = make_adj((l.split('-') for l in lines()), both=True) 8 | idx = {name: i for i, name in enumerate(a for a in adj if a.islower())} 9 | 10 | @lru_cache(maxsize=None) 11 | def f(i: str, v: int, d: bool) -> int: 12 | if i == 'end': 13 | return 1 14 | 15 | r = 0 16 | for j in adj[i]: 17 | if j.isupper(): r += f(j, v, d) 18 | elif j != 'start': 19 | seen = v & (1 << idx[j]) != 0 20 | if not seen or not d: 21 | r += f(j, v | (1 << idx[j]), seen or d) 22 | 23 | return r 24 | 25 | 26 | s = 'start' 27 | print(f(s, 1 << idx[s], True)) 28 | prints(f(s, 1 << idx[s], False)) 29 | 30 | 31 | -------------------------------------------------------------------------------- /13.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace('py', 'in')) 6 | 7 | 8 | A, B = sys.stdin.read().split('\n\n') 9 | 10 | dots = tile(list(ints(A)), 2) 11 | 12 | for s in lines(B): 13 | c = next(ints(s)) 14 | is_x = 'x' in s 15 | 16 | ndots = set() 17 | 18 | for a, b in dots: 19 | if is_x: a = min(a, 2 * c - a) 20 | else: b = min(b, 2 * c - b) 21 | ndots.add((a, b)) 22 | 23 | dots = ndots 24 | 25 | print_coords(dots) 26 | -------------------------------------------------------------------------------- /14.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 42/82 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | 9 | s = input() 10 | input() 11 | 12 | adj = {} 13 | 14 | for l in sys.stdin: 15 | l = l.rstrip() 16 | a, b = l.split(" -> ") 17 | adj[a] = b 18 | 19 | after = defaultdict(Counter) 20 | for c1, c2 in zip(s, s[1:]): 21 | after[c1][c2] += 1 22 | 23 | for _ in range(40): 24 | nafter = defaultdict(Counter) 25 | for c1, v in after.items(): 26 | for c2, cnt in v.items(): 27 | nafter[c1][adj[c1 + c2]] += cnt 28 | nafter[adj[c1 + c2]][c2] += cnt 29 | 30 | after = nafter 31 | 32 | C = sum(after.values(), start=Counter()) 33 | C[s[0]] += 1 34 | 35 | S = sorted(C.values()) 36 | prints(S[-1] - S[0]) 37 | -------------------------------------------------------------------------------- /14_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | 8 | start, _, *rest = lines() 9 | adj = dict(l.split(" -> ") for l in rest) 10 | 11 | pairs = Counter(a + b for a, b in zip(start, start[1:])) 12 | 13 | for steps in (10, 30): 14 | for i in range(steps): 15 | npairs = Counter() 16 | for p, cnt in pairs.items(): 17 | npairs[p[0] + adj[p]] += cnt 18 | npairs[adj[p] + p[1]] += cnt 19 | 20 | pairs = npairs 21 | 22 | C = Counter(start[0]) 23 | for p, cnt in pairs.items(): C[p[1]] += cnt 24 | print(max(C.values()) - min(C.values())) 25 | -------------------------------------------------------------------------------- /15.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 16/44 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | 9 | G = [list(map(int, l)) for l in lines()] 10 | Q = [(0, 0, 0)] 11 | INF = 10 ** 10 12 | H = len(G) 13 | W = len(G[0]) 14 | 15 | for _ in range(4): 16 | G.extend([] for _ in range(H)) 17 | 18 | for dy in range(5): 19 | sy = dy * H 20 | 21 | for dx in range(5): 22 | if dx == dy == 0: continue 23 | 24 | dist = dy + dx 25 | 26 | for y in range(H): 27 | G[sy + y].extend((G[y][x] + dist - 1) % 9 + 1 for x in range(W)) 28 | 29 | 30 | W *= 5 31 | H *= 5 32 | 33 | 34 | D = [[INF] * W for _ in range(H)] 35 | D[0][0] = 0 36 | 37 | while Q: 38 | d, x, y = heappop(Q) 39 | if d > D[y][x]: continue 40 | 41 | if x == W-1 and y == H-1: 42 | print(d) 43 | 44 | for dx, dy in DIR: 45 | nx, ny = x + dx, y + dy 46 | if not 0 <= nx < W or not 0 <= ny < H: continue 47 | nd = d + G[ny][nx] 48 | if nd < D[ny][nx]: 49 | D[ny][nx] = nd 50 | heappush(Q, (nd, nx, ny)) 51 | 52 | -------------------------------------------------------------------------------- /15_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | 8 | G = [list(map(int, l)) for l in lines()] 9 | Q = [(0, 0, 0)] 10 | INF = 10 ** 10 11 | H = len(G) 12 | W = len(G[0]) 13 | 14 | RW, RH = W * 5, H * 5 15 | 16 | D = [[INF] * RW for _ in range(RH)] 17 | D[0][0] = 0 18 | 19 | V = set(product(range(RW), range(RH))) # valid tiles 20 | 21 | while Q: 22 | d, x, y = heappop(Q) 23 | 24 | if (x+1, y+1) in ((RW, RH), (W, H)): 25 | prints(d) 26 | 27 | for nx, ny in neighbours(x, y, V=V): 28 | # It's simple to implicitly extend the graph in this way 29 | dist = nx // W + ny // H 30 | nd = d + (G[ny % H][nx % W] + dist - 1) % 9 + 1 31 | if nd < D[ny][nx]: 32 | D[ny][nx] = nd 33 | heappush(Q, (nd, nx, ny)) 34 | 35 | -------------------------------------------------------------------------------- /16.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 27/22 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | 9 | DEC = '''0 = 0000 10 | 1 = 0001 11 | 2 = 0010 12 | 3 = 0011 13 | 4 = 0100 14 | 5 = 0101 15 | 6 = 0110 16 | 7 = 0111 17 | 8 = 1000 18 | 9 = 1001 19 | A = 1010 20 | B = 1011 21 | C = 1100 22 | D = 1101 23 | E = 1110 24 | F = 1111'''.split('\n') 25 | 26 | D = {} 27 | for l in DEC: 28 | a, b = l.split(' = ') 29 | D[a] = b 30 | 31 | I = ''.join(D[v] for v in input()) 32 | 33 | packets = [] 34 | S = [] 35 | 36 | def f(i): 37 | V = int(I[i:i+3], 2) 38 | i += 3 39 | ID = int(I[i:i+3], 2) 40 | i += 3 41 | 42 | if ID == 4: 43 | s = [] 44 | while True: 45 | s.append(I[i+1:i+5]) 46 | if I[i] == '0': break 47 | i += 5 48 | 49 | i += 5 50 | 51 | packets.append((V, ID, int(''.join(s), 2))) 52 | S.append(packets[-1][-1]) 53 | else: 54 | lid = int(I[i]) 55 | i += 1 56 | 57 | 58 | if lid == 0: 59 | le = int(I[i:i+15], 2) 60 | i += 15 61 | 62 | start = i 63 | cnt = 0 64 | while i - start < le: 65 | i = f(i) 66 | cnt += 1 67 | else: 68 | cnt = int(I[i:i+11], 2) 69 | i += 11 70 | 71 | for _ in range(cnt): 72 | i = f(i) 73 | 74 | packets.append((V, ID, cnt)) 75 | 76 | sub = [S.pop() for _ in range(cnt)] 77 | sub.reverse() 78 | 79 | if ID == 0: 80 | r = sum(sub) 81 | elif ID == 1: 82 | r = 1 83 | for v in sub: r *= v 84 | elif ID == 2: 85 | r = min(sub) 86 | elif ID == 3: 87 | r = max(sub) 88 | elif ID == 5: 89 | r = int(sub[0] > sub[1]) 90 | elif ID == 6: 91 | r = int(sub[0] < sub[1]) 92 | elif ID == 7: 93 | r = int(sub[0] == sub[1]) 94 | 95 | S.append(r) 96 | 97 | return i 98 | 99 | f(0) 100 | 101 | print(sum(v for v, _, _ in packets)) 102 | prints(S.pop()) 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /17.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 17/7 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | x1, x2, y1, y2 = ints(input()) 9 | 10 | best = -1000 11 | r = 0 12 | for yvel in range(y1, 300): 13 | for xvel in range(x2+1): 14 | vel = Point.of(xvel, yvel) 15 | p = Point.of(0, 0) 16 | 17 | hy = 0 18 | 19 | while True: 20 | p += vel 21 | vel.x -= sign(vel.x) 22 | vel.y -= 1 23 | 24 | hy = max(hy, p.y) 25 | 26 | if x1 <= p.x <= x2 and y1 <= p.y <= y2: 27 | best = max(best, hy) 28 | r += 1 29 | break 30 | 31 | if p.y < y1 or p.x > x2: break 32 | 33 | print(best) 34 | prints(r) 35 | -------------------------------------------------------------------------------- /17_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | x1, x2, y1, y2 = ints(input()) 8 | 9 | def tri(x: int) -> int: 10 | return x * (x+1) // 2 11 | 12 | best = -1000 13 | r = 0 14 | for yvel in range(y1, abs(y1)): 15 | def ypos(t: int) -> int: 16 | return yvel * t - tri(t-1) 17 | 18 | start = binary_search(lambda t: ypos(t) <= y2, 1) 19 | for xvel in range(int(x1 ** .5), x2+1): 20 | def xpos(t: int) -> int: 21 | if t <= xvel: 22 | return xvel * t - tri(t-1) 23 | return tri(xvel) 24 | 25 | t = start 26 | while y1 <= ypos(t) <= y2: 27 | if x1 <= xpos(t) <= x2: 28 | r += 1 29 | if yvel >= 0: 30 | best = max(best, tri(yvel)) 31 | break 32 | 33 | t += 1 34 | 35 | print(best) 36 | prints(r) 37 | -------------------------------------------------------------------------------- /18.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 23/19 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | L = lines() 9 | res = eval(L[0]) 10 | 11 | def addl(a, v): 12 | assert isinstance(a, list) 13 | if isinstance(a[1], int): 14 | a[1] += v 15 | else: 16 | addl(a[1], v) 17 | 18 | def addr(a, v): 19 | assert isinstance(a, list) 20 | if isinstance(a[0], int): 21 | a[0] += v 22 | else: 23 | addr(a[0], v) 24 | 25 | def explode(a, depth=0): 26 | if isinstance(a, list): 27 | if depth == 4: 28 | return True, a[0], a[1] 29 | 30 | r = explode(a[0], depth+1) 31 | if r is not None: 32 | imm, b, c = r 33 | if imm: 34 | a[0] = 0 35 | 36 | if isinstance(a[1], int): 37 | a[1] += c 38 | else: 39 | addr(a[1], c) 40 | 41 | c = 0 42 | return False, b, c 43 | 44 | r = explode(a[1], depth+1) 45 | if r is not None: 46 | imm, b, c = r 47 | if imm: 48 | a[1] = 0 49 | 50 | if isinstance(a[0], int): 51 | a[0] += b 52 | else: 53 | addl(a[0], b) 54 | 55 | b = 0 56 | return False, b, c 57 | 58 | def split(a): 59 | if isinstance(a, list): 60 | for i in range(2): 61 | if isinstance(a[i], int): 62 | if a[i] >= 10: 63 | a[i] = [a[i] // 2, (a[i] + 1) // 2] 64 | return True 65 | elif split(a[i]): 66 | return True 67 | 68 | 69 | def reduce(): 70 | while True: 71 | if explode(res): 72 | pass 73 | elif split(res): 74 | pass 75 | else: 76 | break 77 | 78 | def magn(a): 79 | if isinstance(a, int): 80 | return a 81 | 82 | return 3 * magn(a[0]) + 2 * magn(a[1]) 83 | 84 | print(res) 85 | reduce() 86 | print(res) 87 | 88 | for l in L[1:]: 89 | b = eval(l) 90 | res = [res, b] 91 | reduce() 92 | print(res) 93 | 94 | 95 | print(magn(res)) 96 | 97 | 98 | best = 0 99 | for i in range(len(L)): 100 | for j in range(len(L)): 101 | if i != j: 102 | res = [eval(L[i]), eval(L[j])] 103 | reduce() 104 | best = max(best, magn(res)) 105 | 106 | prints(best) 107 | -------------------------------------------------------------------------------- /19.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 48/48 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | S = sys.stdin.read().split('\n\n') 9 | 10 | scan = [] 11 | for i, s in enumerate(S): 12 | L = lines(s) 13 | seen = [] 14 | for l in L[1:]: 15 | seen.append(list(ints(l))) 16 | 17 | #scan.append((i, seen)) 18 | scan.append(seen) 19 | print(seen) 20 | 21 | """ 22 | from collections import defaultdict 23 | 24 | pts = defaultdict(set) 25 | for p in scan[0][1]: 26 | pts[tuple(p)].add(0) 27 | print(pts) 28 | """ 29 | 30 | here = [set(map(tuple, scan.pop(0)))] 31 | #scan.pop(0) 32 | P = [Point.of(0, 0, 0)] 33 | 34 | def mods(scanner): 35 | for xyz in permutations(range(3)): 36 | for flip in range(2**3): 37 | nscan = [] 38 | 39 | for p in scanner: 40 | np = [p[i] * (((flip >> j)&1) * 2 - 1) for j, i in enumerate(xyz)] 41 | 42 | nscan.append(np) 43 | 44 | yield nscan 45 | 46 | def f(next): 47 | for me in mods(next): 48 | """ 49 | for ref1 in pts: 50 | for ref2 in me: 51 | rref1 = Point.of(*ref1) 52 | rref2 = Point.of(*ref2) 53 | 54 | s2 = rref1 - rref2 55 | C = Counter() 56 | 57 | for i, ps in enumerate(me): 58 | pp = s2 + Point.of(*ps) 59 | C.update(pts.get(tuple(pp), [])) 60 | 61 | if any(v >= 12 for v in C.values()): 62 | return me, s2 63 | 64 | """ 65 | for prev in here: 66 | for ref1 in prev: 67 | for ref2 in me: 68 | 69 | rref1 = Point.of(*ref1) 70 | rref2 = Point.of(*ref2) 71 | 72 | s2 = rref1 - rref2 73 | 74 | overlap = 0 75 | for i, ps in enumerate(me): 76 | if overlap + len(me) - i < 12: break 77 | 78 | pp = s2 + Point.of(*ps) 79 | if tuple(pp) in prev: 80 | overlap += 1 81 | 82 | if overlap >= 12: 83 | return me, s2 84 | 85 | import random 86 | 87 | while scan: 88 | random.shuffle(scan) 89 | for i in range(len(scan)): 90 | print(i, len(scan)) 91 | sc = scan[i] 92 | r = f(sc) 93 | if r is not None: 94 | scan.pop(i) 95 | me, s2 = r 96 | here.append({tuple(s2 + p) for p in me}) 97 | P.append(s2) 98 | #for p in me: 99 | # pts[tuple(s2 + p)].add(j) 100 | break 101 | else: 102 | assert False 103 | 104 | S = set() 105 | for ps in here: S |= ps 106 | 107 | #print(len(pts)) 108 | print(len(S)) 109 | 110 | 111 | best = 0 112 | for i in range(len(P)): 113 | for j in range(i): 114 | best = max(best, (P[i] - P[j]).manh_dist()) 115 | print(best) 116 | -------------------------------------------------------------------------------- /19_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | S = sys.stdin.read().split("\n\n") 8 | 9 | scan = [] 10 | for i, s in enumerate(S): 11 | beacons = {tuple(ints(l)) for l in lines(s)[1:]} 12 | fp = sorted((Point.of(*a) - b).dist2() for a, b in combinations(beacons, r=2)) 13 | scan.append((beacons, fp)) 14 | 15 | here = [scan.pop(0)] 16 | P = [Point.of(0, 0, 0)] 17 | 18 | F = [Point(l) for x in (-1, 1) for l in ([x, 0, 0], [0, x, 0], [0, 0, x])] 19 | rotation_matrices = [] 20 | for a, b in permutations(F, r=2): 21 | if (a + b).manh_dist() != 2: continue 22 | c = a.cross_3d(b) 23 | 24 | # Make matrix with basis as columns (transposed basis rows) 25 | rotation_matrices.append(tuple(zip(a, b, c))) 26 | 27 | 28 | def matvec(A: Sequence[Sequence[int]], v: Point[int]): 29 | return [sum(A[i][j] * v[j] for j in range(3)) for i in range(3)] 30 | 31 | 32 | def mods(scanner: Sequence[Point[int]]) -> Iterable[List[List[int]]]: 33 | for m in rotation_matrices: 34 | yield [matvec(m, v) for v in scanner] 35 | 36 | 37 | def fp_matches(a, b): 38 | cnt = 0 39 | i = 0 40 | for x in a: 41 | while i < len(b) and b[i] < x: i += 1 42 | if i < len(b) and b[i] == x: 43 | cnt += 1 44 | i += 1 45 | 46 | return cnt >= 12 * (12-1) // 2 47 | 48 | 49 | def matches(next, prev): 50 | prev_points, prev_fp = prev 51 | if not fp_matches(next[1], prev_fp): return 52 | 53 | for me in mods(next[0]): 54 | for ref1 in prev_points: 55 | for ref2 in me: 56 | s2 = Point.of(*ref1) - ref2 57 | 58 | overlap = 0 59 | for i, ps in enumerate(me): 60 | if overlap + len(me) - i < 12: 61 | break 62 | 63 | if tuple(s2 + ps) in prev_points: 64 | overlap += 1 65 | 66 | if overlap >= 12: 67 | return me, s2 68 | 69 | 70 | for prev in here: 71 | i = 0 72 | while i < len(scan): 73 | sc = scan[i] 74 | r = matches(sc, prev) 75 | if r is not None: 76 | scan.pop(i) 77 | me, s2 = r 78 | here.append(({tuple(s2 + p) for p in me}, sc[1])) 79 | P.append(s2) 80 | else: 81 | i += 1 82 | 83 | assert not scan 84 | 85 | print(len(set().union(*(points for points, _ in here)))) 86 | 87 | print(max((a - b).manh_dist() for a, b in combinations(P, r=2))) 88 | -------------------------------------------------------------------------------- /20.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from functools import reduce 3 | 4 | from util import * 5 | 6 | if len(sys.argv) == 1: 7 | sys.stdin = open(__file__.replace("py", "in")) 8 | 9 | alg, _, *I = lines() 10 | 11 | D = {(x, y): c == "#" for y, l in enumerate(I) for x, c in enumerate(l)} 12 | 13 | H, W = len(I), len(I[0]) 14 | 15 | for time in range(50): 16 | D = { 17 | (x, y): alg[ 18 | reduce( 19 | lambda acc, v: acc * 2 + v, 20 | ( 21 | D.get((x + dx, y + dy), time % 2) 22 | for dy in range(-1, 2) 23 | for dx in range(-1, 2) 24 | ), 25 | ) 26 | ] 27 | == "#" 28 | for y in range(-time - 1, H + time + 1) 29 | for x in range(-time - 1, W + time + 1) 30 | } 31 | 32 | if time in (1, 49): 33 | prints(sum(D.values())) 34 | -------------------------------------------------------------------------------- /21.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 3/19 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | 9 | _, a, _, b = ints() 10 | 11 | @lru_cache(maxsize=None) 12 | def f(a, b, s1, s2, my_turn): 13 | 14 | res = 0 15 | tot = 0 16 | for roll in product(range(1, 4), repeat=3): 17 | na = (a + sum(roll) - 1) % 10 + 1 18 | 19 | if s1 + na >= 21: 20 | res += my_turn 21 | tot += 1 22 | else: 23 | x, y = f(b, na, s2, s1 + na, not my_turn) 24 | res += x 25 | tot += y 26 | 27 | return res, tot 28 | 29 | 30 | w1, total = f(a, b, 0, 0, True) 31 | print(max(w1, total - w1)) 32 | 33 | rolls = 0 34 | s1 = s2 = 0 35 | 36 | die = cycle(range(1, 101)) 37 | 38 | while True: 39 | rolls += 3 40 | r = sum(next(die) for _ in range(3)) 41 | a = (a + r - 1) % 10 + 1 42 | s1 += a 43 | 44 | if s1 >= 1000: 45 | break 46 | 47 | a, b, s1, s2 = b, a, s2, s1 48 | 49 | print(s2 * rolls) 50 | -------------------------------------------------------------------------------- /22.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 3/>100 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | 9 | L = 50 10 | cub = [] 11 | D = {} 12 | for l in lines(): 13 | x1, x2, y1, y2, z1, z2 = ints(l) 14 | on = int(l.startswith("on")) 15 | 16 | me = (x1, x2, y1, y2, z1, z2) 17 | 18 | ncub = [] 19 | for pcub in cub: 20 | avoid = any(me[i + 1] < pcub[i] or me[i] > pcub[i + 1] for i in range(0, 6, 2)) 21 | 22 | if avoid: 23 | ncub.append(pcub) 24 | else: 25 | for split_at in range(0, 6, 2): 26 | def f(cur, i): 27 | if i == 6: 28 | yield cur 29 | else: 30 | (x, y), (a, b) = me[i : i + 2], pcub[i : i + 2] 31 | if i < split_at: 32 | yield from f(cur + (max(x, a), min(y, b)), i + 2) 33 | elif i > split_at: 34 | yield from f(cur + (a, b), i + 2) 35 | else: 36 | if a < x: 37 | yield from f(cur + (a, x - 1), i + 2) 38 | if y < b: 39 | yield from f(cur + (y + 1, b), i + 2) 40 | 41 | ncub.extend(f((), 0)) 42 | 43 | cub = ncub 44 | if on: 45 | cub.append(me) 46 | 47 | for x in range(max(x1, -L), min(x2, L) + 1): 48 | for y in range(max(y1, -L), min(y2, L) + 1): 49 | for z in range(max(z1, -L), min(z2, L) + 1): 50 | D[(x, y, z)] = on 51 | 52 | print(sum(D.values())) 53 | 54 | 55 | res = 0 56 | for x1, x2, y1, y2, z1, z2 in cub: 57 | res += (x2 + 1 - x1) * (y2 + 1 - y1) * (z2 + 1 - z1) 58 | 59 | prints(res) 60 | -------------------------------------------------------------------------------- /22_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | 8 | L = 50 9 | cub = [] 10 | D = {} 11 | xs, ys, zs = ([10 ** 10] for _ in range(3)) 12 | for l in lines(): 13 | x1, x2, y1, y2, z1, z2 = ints(l) 14 | on = int(l.startswith("on")) 15 | 16 | xs += x1, x2 + 1 17 | ys += y1, y2 + 1 18 | zs += z1, z2 + 1 19 | 20 | cub.append((x1, x2, y1, y2, z1, z2, on)) 21 | 22 | for x in range(max(x1, -L), min(x2, L) + 1): 23 | for y in range(max(y1, -L), min(y2, L) + 1): 24 | for z in range(max(z1, -L), min(z2, L) + 1): 25 | D[(x, y, z)] = on 26 | 27 | L = [] 28 | for l in (xs, ys, zs): 29 | l[:] = sorted(set(l)) 30 | L.append(len(l)) 31 | 32 | X, Y, Z = L 33 | 34 | print(sum(D.values())) 35 | 36 | # From PyRival 37 | class BitArray: 38 | """implements bitarray using bytearray""" 39 | 40 | def __init__(self, size): 41 | self.bytes = bytearray((size >> 3) + 1) 42 | 43 | def __len__(self): 44 | return len(self.bytes) * 8 45 | 46 | def __getitem__(self, index): 47 | return (self.bytes[index >> 3] >> (index & 7)) & 1 48 | 49 | def __setitem__(self, index, value): 50 | if value: 51 | self.bytes[index >> 3] |= 1 << (index & 7) 52 | else: 53 | self.bytes[index >> 3] &= ~(1 << (index & 7)) 54 | 55 | 56 | C = [[BitArray(Z) for _ in range(Y)] for _ in range(X)] 57 | 58 | from bisect import * 59 | 60 | from tqdm import tqdm 61 | 62 | lmask = 2 ** 8 - 1 63 | 64 | for *me, on in tqdm(cub): 65 | xb, yb, zb = ( 66 | (bisect_left(l, a), bisect_right(l, b)) 67 | for a, b, l in zip( 68 | me[::2], 69 | me[1::2], 70 | (xs, ys, zs), 71 | ) 72 | ) 73 | z1, z2 = zb 74 | 75 | for x in range(*xb): 76 | for y in range(*yb): 77 | b1 = (z1 & ~7) + 8 78 | for z in range(z1, min(b1, z2)): 79 | C[x][y][z] = on 80 | 81 | while b1 + 8 <= z2: 82 | C[x][y].bytes[b1 // 8] = lmask * on 83 | b1 += 8 84 | 85 | for z in range(b1, z2): 86 | C[x][y][z] = on 87 | 88 | 89 | idx = [[]] 90 | for bs in range(1, 2 ** 8): 91 | nl = [i + 1 for i in idx[bs >> 1]] 92 | if bs & 1: 93 | nl.append(0) 94 | idx.append(nl) 95 | 96 | 97 | res = 0 98 | for x in range(X-1): 99 | xdiff = (xs[x + 1] - xs[x]) 100 | for y in range(Y-1): 101 | area = xdiff * (ys[y + 1] - ys[y]) 102 | for zb, b in enumerate(C[x][y].bytes): 103 | zb *= 8 104 | for i in idx[b]: 105 | res += area * (zs[zb + i + 1] - zs[zb + i]) 106 | 107 | prints(res) 108 | -------------------------------------------------------------------------------- /23.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # >100/41 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | S = 'ABCD' 9 | D = dict(zip(S, (1, 10, 100, 1000))) 10 | dest = dict(zip(S, (3, 5, 7, 9))) 11 | 12 | P = tuple([] for _ in range(4)) 13 | 14 | for y, l in enumerate(lines()): 15 | for x, c in enumerate(l): 16 | if c in S: 17 | P[S.index(c)].append((x, y)) 18 | 19 | def fix(state): 20 | return tuple(tuple(sorted(l)) for l in state) 21 | 22 | def mkrev(state): 23 | D = {} 24 | for i, l in enumerate(state): 25 | for x, y in l: 26 | D[(x, y)] = S[i] 27 | return D 28 | 29 | P = fix(P) 30 | dist = {P: 0} 31 | Q = [(0, P)] 32 | 33 | def push(nd, nstate): 34 | #nstate = fix(nstate) 35 | pd = dist.get(nstate, -1) 36 | if pd == -1 or nd < pd: 37 | dist[nstate] = nd 38 | heappush(Q, (nd, nstate)) 39 | 40 | 41 | def is_correct(rev, x, y): 42 | c = rev.get((x, y)) 43 | if c is None: return False 44 | return x == dest[c] 45 | 46 | W = len('#...B.......#') 47 | H = max(y for l in P for _, y in l) + 1 48 | print(H) 49 | 50 | while Q: 51 | d, state = heappop(Q) 52 | if d > dist[state]: continue 53 | #print(d) 54 | 55 | ok = True 56 | 57 | rev = mkrev(state) 58 | 59 | for idx, l in enumerate(state): 60 | 61 | for i, (x, y) in enumerate(l): 62 | if y > 1: 63 | corr = is_correct(rev, x, y) 64 | ok &= corr 65 | 66 | #if not corr or y == 2 and not is_correct(rev, x, y+1): 67 | if not corr or any(not is_correct(rev, x, ny) for ny in range(y+1, H)): 68 | #if y == 3 and (x, y-1) in rev: continue 69 | if y > 2 and any((x, ny) in rev for ny in range(2, y)): 70 | continue 71 | 72 | for sig in (-1, 1): 73 | nx = x + sig 74 | while (nx, 1) not in rev and 1 <= nx < W-1: 75 | if nx % 2 and 3 <= nx <= 9: 76 | nx += sig 77 | continue 78 | 79 | nstate = list(state) 80 | new_l = list(nstate[idx]) 81 | new_l[i] = (nx, 1) 82 | nstate[idx] = tuple(sorted(new_l)) 83 | 84 | move = (abs(x - nx) + y - 1) * D[S[idx]] 85 | push(d + move, tuple(nstate)) 86 | 87 | nx += sig 88 | else: 89 | ok = False 90 | assert y == 1 91 | 92 | nx = dest[S[idx]] 93 | if (nx, 2) in rev: continue 94 | 95 | rny = 2 96 | will_move = True 97 | for ny in range(3, H): 98 | if (nx, ny) in rev: 99 | if not is_correct(rev, nx, ny): 100 | will_move = False 101 | break 102 | else: 103 | rny = ny 104 | 105 | if not will_move: continue 106 | 107 | #if (nx, 2) in rev or ((nx, 3) in rev and not is_correct(rev, nx, 3)): 108 | # continue 109 | 110 | ny = rny 111 | xmove = abs(nx - x) 112 | 113 | way = sign(nx - x) 114 | while x != nx: 115 | x += way 116 | if (x, y) in rev: 117 | will_move = False 118 | break 119 | 120 | if not will_move: 121 | continue 122 | 123 | nstate = list(state) 124 | new_l = list(nstate[idx]) 125 | new_l[i] = (nx, ny) 126 | nstate[idx] = tuple(sorted(new_l)) 127 | 128 | """ 129 | print_coords(rev, '.') 130 | print() 131 | print_coords(mkrev(nstate), '.') 132 | print('-' * W) 133 | """ 134 | 135 | move = (xmove + abs(ny - y)) * D[S[idx]] 136 | push(d + move, tuple(nstate)) 137 | 138 | if ok: 139 | print_coords(rev, '.') 140 | prints(d) 141 | break 142 | -------------------------------------------------------------------------------- /23_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | S = 'ABCD' 8 | D = (1, 10, 100, 1000) 9 | dest = dict(zip(S, (3, 5, 7, 9))) 10 | 11 | Pnt = Tuple[int, int] 12 | P: Tuple[List[Pnt], ...] = tuple([] for _ in range(4)) 13 | 14 | for y, line in enumerate(lines()): 15 | for x, c in enumerate(line): 16 | if c in S: 17 | P[S.index(c)].append((x, y)) 18 | 19 | 20 | def mkrev(state): 21 | "Build a reverse mapping from positions to amphipods" 22 | return {p: S[i] for i, l in enumerate(state) for p in l} 23 | 24 | def new_state(state, idx, i, new_p): 25 | """ 26 | Make a new normalised state from 'state' where the i'th amphipod of type 27 | 'idx' is at 'new_p' 28 | """ 29 | nstate = list(state) 30 | new_l = list(nstate[idx]) 31 | new_l[i] = new_p 32 | nstate[idx] = tuple(sorted(new_l)) 33 | return tuple(nstate) 34 | 35 | P = tuple(tuple(sorted(l)) for l in P) 36 | dist = {P: 0} 37 | Q = [(0, P)] 38 | 39 | def push(nd, nstate): 40 | pd = dist.get(nstate, -1) 41 | if pd == -1 or nd < pd: 42 | dist[nstate] = nd 43 | heappush(Q, (nd, nstate)) 44 | 45 | 46 | def is_correct(rev, x, y): 47 | c = rev.get((x, y)) 48 | return c is not None and x == dest[c] 49 | 50 | W = len('#...B.......#') 51 | H = max(y for l in P for _, y in l) + 1 52 | 53 | while Q: 54 | d, state = heappop(Q) 55 | if d > dist[state]: continue 56 | 57 | ok = True 58 | rev = mkrev(state) 59 | 60 | up_moves = [] 61 | 62 | for idx, l in enumerate(state): 63 | for i, (x, y) in enumerate(l): 64 | if y > 1: 65 | corr = is_correct(rev, x, y) 66 | ok &= corr 67 | 68 | # We have to move up if any amphipod below us (including ourself) 69 | # is in the wrong position. 70 | if any(not is_correct(rev, x, ny) for ny in range(y, H)): 71 | # We cannot move up if there is an amphipod above us 72 | if any((x, ny) in rev for ny in range(2, y)): 73 | continue 74 | 75 | for sig in (-1, 1): 76 | nx = x + sig 77 | while (nx, 1) not in rev and 1 <= nx < W-1: 78 | if nx % 2 and 3 <= nx <= 9: 79 | nx += sig 80 | continue 81 | 82 | move = (abs(x - nx) + y - 1) * D[idx] 83 | up_moves.append((d + move, new_state(state, idx, i, (nx, 1)))) 84 | 85 | nx += sig 86 | else: 87 | ok = False 88 | nx = dest[S[idx]] 89 | 90 | rny = 2 91 | will_move = True 92 | for ny in range(2, H): 93 | if (nx, ny) in rev: 94 | if not is_correct(rev, nx, ny): 95 | will_move = False 96 | break 97 | else: 98 | rny = ny 99 | 100 | if not will_move: continue 101 | 102 | xmove = abs(nx - x) 103 | way = sign(nx - x) 104 | while x != nx: 105 | x += way 106 | if (x, y) in rev: 107 | will_move = False 108 | break 109 | 110 | if not will_move: continue 111 | 112 | move = (xmove + abs(rny - y)) * D[idx] 113 | push(d + move, new_state(state, idx, i, (nx, rny))) 114 | # It is always optimal to move an amphipod into its final position 115 | break 116 | else: 117 | for nd, nstate in up_moves: 118 | push(nd, nstate) 119 | 120 | 121 | if ok: 122 | print_coords(rev, '.') 123 | prints(d) 124 | break 125 | -------------------------------------------------------------------------------- /23e.in: -------------------------------------------------------------------------------- 1 | ############# 2 | #...........# 3 | ###D#B#C#C### 4 | #D#C#B#A# 5 | #D#B#A#C# 6 | #D#A#B#A# 7 | ######### 8 | -------------------------------------------------------------------------------- /24.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("py", "in")) 6 | 7 | FS = sys.stdin.read() 8 | secs = FS.split('inp w\n')[1:] 9 | secs = [[s.split() for s in lines(sec)] for sec in secs] 10 | N = len(secs) 11 | assert N == 14 12 | 13 | cache = [set() for _ in range(14)] 14 | VARS = 'wxyz' 15 | 16 | def run(sec_no: int, pz: int) -> int: 17 | if sec_no == N or pz > 10 ** 7: 18 | return 0 if pz == 0 else -1 19 | 20 | if pz in cache[sec_no]: return -1 21 | 22 | #for w in range(9, 0, -1): 23 | for w in range(1, 10): 24 | vs = [0] * 4 25 | vs[0] = vs[1] = w 26 | vs[3] = pz 27 | 28 | def get(x): 29 | if x[0] == '-' or x.isdigit(): 30 | return int(x) 31 | return vs[ord(x) - ord('w')] 32 | 33 | for instr, *args in secs[sec_no]: 34 | if instr == 'inp': 35 | assert False 36 | else: 37 | x, y = args 38 | x = ord(x) - ord('w') 39 | y = get(y) 40 | 41 | if instr == 'add': 42 | vs[x] += y 43 | elif instr == 'mul': 44 | vs[x] *= y 45 | elif instr == 'div': 46 | assert y != 0 47 | vs[x] //= y 48 | elif instr == 'mod': 49 | assert vs[x] >= 0 50 | assert y > 0 51 | vs[x] %= y 52 | elif instr == 'eql': 53 | vs[x] = int(vs[x] == y) 54 | 55 | r = run(sec_no + 1, vs[-1]) 56 | if r != -1: 57 | return r + w * 10 ** (N - sec_no - 1) 58 | 59 | cache[sec_no].add(pz) 60 | return -1 61 | 62 | print(run(0, 0)) 63 | -------------------------------------------------------------------------------- /24_post.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | if len(sys.argv) == 1: 5 | sys.stdin = open(__file__.replace("_post.py", ".in")) 6 | 7 | FS = sys.stdin.read() 8 | secs = FS.split('inp w\n')[1:] 9 | secs = [[s.split() for s in lines(sec)] for sec in secs] 10 | N = len(secs) 11 | assert N == 14 12 | 13 | op_map = dict(zip('add mul div mod'.split(), '+ * // %'.split())) 14 | 15 | generated_code = '' 16 | for i, sec in enumerate(secs): 17 | opt = [] 18 | for instr, x, y in sec: 19 | if instr == 'eql': 20 | opt.append(f'{x} = int({x} == {y})') 21 | else: 22 | opt.append(f'{x} {op_map[instr]}= {y}') 23 | 24 | generated_code += f'def f_{i}(w, z, x=0, y=0):\n' 25 | opt.append('return z') 26 | generated_code += '\n'.join('\t' + s for s in opt) + '\n\n' 27 | 28 | generated_code += 'funs = [' + ', '.join(f'f_{i}' for i in range(14)) + ']' 29 | exec(generated_code) 30 | 31 | cache = [set() for _ in range(14)] 32 | 33 | def run(sec_no: int, pz: int) -> int: 34 | if sec_no == N: #or pz > 10 ** 7: 35 | return 0 if pz == 0 else -1 36 | 37 | if pz in cache[sec_no]: return -1 38 | 39 | for w in range(9, 0, -1): 40 | #for w in range(1, 10): 41 | nz = funs[sec_no](w, pz) 42 | r = run(sec_no + 1, nz) 43 | if r != -1: 44 | return r + w * 10 ** (N - sec_no - 1) 45 | 46 | cache[sec_no].add(pz) 47 | return -1 48 | 49 | print(run(0, 0)) 50 | -------------------------------------------------------------------------------- /24_post_partial.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | from util import * 3 | 4 | FS = sys.stdin.read() 5 | secs = FS.split("inp w\n")[1:] 6 | secs = [[s.split() for s in lines(sec)] for sec in secs] 7 | N = len(secs) 8 | assert N == 14 9 | 10 | 11 | def add(x, y, vs): 12 | vs[x] += y(vs) 13 | 14 | def mul(x, y, vs): 15 | vs[x] *= y(vs) 16 | 17 | def div(x, y, vs): 18 | vs[x] //= y(vs) 19 | 20 | def mod(x, y, vs): 21 | vs[x] %= y(vs) 22 | 23 | def eql(x, y, vs): 24 | vs[x] = int(vs[x] == y(vs)) 25 | 26 | 27 | from functools import partial 28 | 29 | fun_map = {fun.__name__: fun for fun in (add, mul, div, mod, eql)} 30 | 31 | opt_secs = [] 32 | for sec in secs: 33 | funs = [] 34 | for instr, x, y in sec: 35 | x = ord(x) - ord("w") 36 | yf = ( 37 | (lambda vs, y=int(y): y) 38 | if y[0] == "-" or y.isdigit() 39 | else (lambda vs, i=(ord(y) - ord("w")): vs[i]) 40 | ) 41 | funs.append(partial(fun_map[instr], x, yf)) 42 | 43 | opt_secs.append(funs) 44 | 45 | 46 | cache = [set() for _ in range(14)] 47 | VARS = "wxyz" 48 | 49 | 50 | def run(sec_no: int, pz: int) -> int: 51 | if sec_no == N or pz > 10 ** 7: 52 | return 0 if pz == 0 else -1 53 | 54 | if pz in cache[sec_no]: return -1 55 | 56 | for w in range(9, 0, -1): 57 | # for w in range(1, 10): 58 | vs = [0] * 4 59 | vs[0] = w 60 | vs[3] = pz 61 | 62 | for fun in opt_secs[sec_no]: fun(vs) 63 | 64 | r = run(sec_no + 1, vs[-1]) 65 | if r != -1: 66 | return r + w * 10 ** (N - sec_no - 1) 67 | 68 | cache[sec_no].add(pz) 69 | return -1 70 | 71 | 72 | print(run(0, 0)) 73 | -------------------------------------------------------------------------------- /24_z3_optimized.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | from util import * 3 | 4 | program = sys.stdin.read() 5 | 6 | parts = re.findall( 7 | r"""inp w 8 | mul x 0 9 | add x z 10 | mod x 26 11 | div z (\d+) 12 | add x (-?\d+) 13 | eql x w 14 | eql x 0 15 | mul y 0 16 | add y 25 17 | mul y x 18 | add y 1 19 | mul z y 20 | mul y 0 21 | add y w 22 | add y (\d+) 23 | mul y x 24 | add z y""", 25 | program, 26 | ) 27 | 28 | from z3 import * 29 | 30 | o = Optimize() 31 | 32 | cnt = Counter() 33 | VARS = "zw" 34 | vs = {(c, 0): Int(f"{c}_0") for c in VARS} 35 | o.add(vs[("z", 0)] == 0) 36 | 37 | 38 | def pv(x): 39 | a = vs[(x, cnt[x])] 40 | cnt[x] += 1 41 | b = Int(f"{x}_{cnt[x]}") 42 | vs[(x, cnt[x])] = b 43 | return a, b 44 | 45 | 46 | """ 47 | x = ((z % 26) + add_x != w) 48 | z //= div_z 49 | 50 | if x == 1: 51 | z = z * 26 + (w + add_y) 52 | """ 53 | 54 | objective = 0 55 | 56 | for args in parts: 57 | div, addx, addy = map(int, args) 58 | 59 | _, w = pv("w") 60 | objective = objective * 10 + w 61 | o.add(And(1 <= w, w <= 9)) 62 | 63 | pz, nz = pv("z") 64 | o.add(nz == If((pz % 26) + addx == w, pz / div, pz / div * 26 + w + addy)) 65 | 66 | o.add(nz == 0) 67 | 68 | print(o) 69 | 70 | o.set(priority="box") 71 | max_id = o.maximize(objective) 72 | min_id = o.minimize(objective) 73 | 74 | assert o.check() == sat 75 | 76 | print(max_id.value()) 77 | print(min_id.value()) 78 | -------------------------------------------------------------------------------- /25.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env pypy3 2 | # 69/54 3 | from util import * 4 | 5 | if len(sys.argv) == 1: 6 | sys.stdin = open(__file__.replace("py", "in")) 7 | 8 | G = [l.rstrip() for l in sys.stdin] 9 | H = len(G) 10 | W = len(G[0]) 11 | 12 | D = {(x, y): c for y, l in enumerate(G) for x, c in enumerate(l) if c != '.'} 13 | 14 | for step in range(10 ** 10): 15 | moved = False 16 | 17 | for C, (dx, dy) in zip('>v', ((1, 0), (0, 1))): 18 | ND = {} 19 | for (x, y), c in D.items(): 20 | nx, ny = (x + dx) % W, (y + dy) % H 21 | if c == C and (nx, ny) not in D: 22 | moved = True 23 | ND[(nx, ny)] = c 24 | else: 25 | ND[(x, y)] = c 26 | 27 | D = ND 28 | #print_coords(D, '.') 29 | 30 | if not moved: 31 | prints(step+1) 32 | break 33 | 34 | 35 | -------------------------------------------------------------------------------- /get_input.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Fetch AoC input as files. 5 | By Asger Hautop Drewsen: https://github.com/Tyilo 6 | """ 7 | 8 | import sys 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import requests 13 | 14 | YEAR = 2021 15 | URL_PREFIX = f"https://adventofcode.com/{YEAR}" 16 | 17 | 18 | def validate_session(session): 19 | test_url = f"{URL_PREFIX}/settings" 20 | r = session.get(test_url) 21 | return r.status_code == 200 and r.url == test_url 22 | 23 | 24 | session_cookie: Optional[str] 25 | try: 26 | with open(".session", "r") as f: 27 | session_cookie = f.read().strip() 28 | except FileNotFoundError: 29 | session_cookie = None 30 | 31 | while True: 32 | if not session_cookie: 33 | session_cookie = input("Session cookie value: ").strip() 34 | with open(".session", "w") as f: 35 | f.write(session_cookie) 36 | 37 | session = requests.Session() 38 | session.cookies.set("session", session_cookie, domain=".adventofcode.com", path="/") 39 | if validate_session(session): 40 | break 41 | 42 | print("That session cookie doesn't seem to work. Try again.") 43 | session_cookie = None 44 | 45 | 46 | for i in range(1, 26): 47 | path = Path(f"{i:02}.in") 48 | 49 | if path.exists(): 50 | continue 51 | 52 | r = session.get(f"{URL_PREFIX}/day/{i}/input") 53 | if r.ok: 54 | with path.open("wb") as fb: 55 | fb.write(r.content) 56 | print(f"Downloaded {path.name}") 57 | else: 58 | if r.status_code == 404: 59 | print(f"Day {i} not released yet") 60 | break 61 | else: 62 | sys.exit(f"Got unknown status code: {r.status_code}\n{r.text.strip()}") 63 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | import sys 5 | from collections import Counter, defaultdict, deque 6 | from functools import lru_cache, total_ordering 7 | from heapq import * 8 | from itertools import combinations 9 | from itertools import combinations_with_replacement as combr 10 | from itertools import cycle, permutations, product, repeat 11 | from typing import (Any, Callable, Collection, DefaultDict, Dict, Generic, 12 | Iterable, Iterator, List, Mapping, MutableSet, Sequence, 13 | Tuple, TypeVar) 14 | 15 | sys.setrecursionlimit(1 << 30) 16 | 17 | DIR = ((1, 0), (0, 1), (-1, 0), (0, -1)) 18 | 19 | 20 | def ints(inp: str = None) -> Iterator[int]: 21 | return map(int, re.findall(r"-?\d+", inp or sys.stdin.read())) 22 | 23 | 24 | def floats(inp: str = None) -> Iterator[float]: 25 | return map(float, re.findall(r"-?\d+(?:\.\d*)?", inp or sys.stdin.read())) 26 | 27 | 28 | def lines(inp: str = None) -> List[str]: 29 | return (inp or sys.stdin.read()).splitlines() 30 | 31 | 32 | def prints(*args): 33 | """ 34 | Function for printing the solution to a puzzle. 35 | Also copies the solution to the clipboard. 36 | """ 37 | from subprocess import run 38 | 39 | ans = " ".join(map(str, args)) 40 | print(ans) 41 | run(["xsel", "-bi"], input=ans, check=True, text=True) 42 | print("(Copied to clipboard)") 43 | 44 | 45 | T = TypeVar("T", int, float) 46 | 47 | 48 | def sign(x: T) -> int: 49 | return (x > 0) - (x < 0) 50 | 51 | 52 | @total_ordering 53 | class Point(Generic[T]): 54 | c: List[T] 55 | __slots__ = ("c",) 56 | 57 | def __init__(self, c: List[T]): 58 | self.c = c 59 | 60 | @classmethod 61 | def of(cls, *c: T) -> Point[T]: 62 | return cls(list(c)) 63 | 64 | # Points are generally immutable except that you can set coordinates 65 | 66 | @property 67 | def x(s) -> T: 68 | return s.c[0] 69 | 70 | @x.setter 71 | def x(s, v: T): 72 | s.c[0] = v 73 | 74 | @property 75 | def y(s) -> T: 76 | return s.c[1] 77 | 78 | @y.setter 79 | def y(s, v: T): 80 | s.c[1] = v 81 | 82 | @property 83 | def z(s) -> T: 84 | return s.c[2] 85 | 86 | @z.setter 87 | def z(s, v: T): 88 | s.c[2] = v 89 | 90 | # Standard object methods 91 | 92 | def __lt__(s, o: Point[T]) -> bool: 93 | return s.c < o.c 94 | 95 | def __eq__(s, o) -> bool: 96 | return isinstance(o, Point) and s.c == o.c 97 | 98 | def __hash__(s) -> int: 99 | return hash(tuple(s.c)) 100 | 101 | def __str__(s) -> str: 102 | return f'({", ".join(map(str, s))})' 103 | 104 | def __repr__(s) -> str: 105 | return f"Point({s.c})" 106 | 107 | def __len__(s) -> int: 108 | return len(s.c) 109 | 110 | def __iter__(s) -> Iterator[T]: 111 | return iter(s.c) 112 | 113 | def __getitem__(s, key): 114 | return s.c[key] 115 | 116 | # Geometry stuff 117 | 118 | def __add__(s, o: Iterable[T]) -> Point[T]: 119 | return Point([a + b for a, b in zip(s, o)]) 120 | 121 | def __sub__(s, o: Iterable[T]) -> Point[T]: 122 | return Point([a - b for a, b in zip(s, o)]) 123 | 124 | def __neg__(s) -> Point[T]: 125 | return Point([-x for x in s]) 126 | 127 | def __abs__(s) -> Point[T]: 128 | return Point.of(*map(lambda x: abs(x), s)) 129 | 130 | def __mul__(s, d: T) -> Point[T]: 131 | return Point([a * d for a in s]) 132 | 133 | __rmul__ = __mul__ 134 | 135 | def __floordiv__(s, d: T) -> Point[T]: 136 | return Point([a // d for a in s]) 137 | 138 | def __truediv__(s, d: T) -> Point[float]: 139 | return Point([a / d for a in s]) 140 | 141 | def dot(s, o: Iterable[T]) -> T: 142 | return sum(a * b for a, b in zip(s, o)) 143 | 144 | __matmul__ = dot 145 | 146 | def cross(a, b: Point[T]) -> T: 147 | assert len(a) == 2 148 | return a.x * b.y - a.y * b.x 149 | 150 | def cross2(s, a: Point[T], b: Point[T]) -> T: 151 | return (a - s).cross(b - s) 152 | 153 | def cross_3d(a, b: Point[T]) -> Point[T]: 154 | assert len(a) == 3 155 | return Point.of( 156 | a.y * b.z - a.z * b.y, -a.x * b.z + a.z * b.x, a.x * b.y - a.y * b.x 157 | ) 158 | 159 | def cross2_3d(s, a: Point[T], b: Point[T]) -> Point[T]: 160 | return (a - s).cross_3d(b - s) 161 | 162 | def manh_dist(s) -> T: 163 | return sum(map(lambda x: abs(x), s)) 164 | 165 | def dist2(s) -> T: 166 | return sum(x * x for x in s) 167 | 168 | def dist(s) -> float: 169 | return s.dist2() ** 0.5 170 | 171 | 172 | def make_adj(edges, both=False) -> DefaultDict[Any, List]: 173 | adj = defaultdict(list) 174 | for a, b in edges: 175 | adj[a].append(b) 176 | if both: 177 | adj[b].append(a) 178 | return adj 179 | 180 | 181 | def make_wadj(edges, both=False) -> DefaultDict[Any, List[Tuple[Any, Any]]]: 182 | adj = defaultdict(list) 183 | for a, b, w in edges: 184 | adj[a].append((b, w)) 185 | if both: 186 | adj[b].append((a, w)) 187 | return adj 188 | 189 | 190 | def bfs(s, adj): 191 | D: DefaultDict[Any, float] = defaultdict(lambda: float("inf")) 192 | D[s] = 0 193 | Q = [s] 194 | for i in Q: 195 | d = D[i] 196 | for j in adj[i]: 197 | if j in D: 198 | continue 199 | D[j] = d + 1 200 | Q.append(j) 201 | return D, Q 202 | 203 | 204 | def dijkstra(s, adj): 205 | D: DefaultDict[Any, float] = defaultdict(lambda: float("inf")) 206 | V = set() 207 | D[s] = 0 208 | Q = [(0, s)] 209 | while Q: 210 | d, i = heappop(Q) 211 | if i in V: 212 | continue 213 | V.add(i) 214 | for j, w in adj[i]: 215 | if j in V: 216 | continue 217 | nd = d + w 218 | if nd >= D[j]: 219 | continue 220 | D[j] = nd 221 | heappush(Q, (nd, j)) 222 | return D 223 | 224 | 225 | def topsort(adj): 226 | indeg: DefaultDict[Any, int] = defaultdict(int) 227 | for i, l in adj.items(): 228 | for j in l: 229 | indeg[j] += 1 230 | Q = [i for i in adj if indeg[i] == 0] 231 | for i in Q: 232 | for j in adj[i]: 233 | indeg[j] -= 1 234 | if indeg[j] == 0: 235 | Q.append(j) 236 | return Q 237 | 238 | 239 | _U = TypeVar("_U") 240 | 241 | 242 | def tile(L: Sequence[_U], S: int) -> List[Sequence[_U]]: 243 | assert len(L) % S == 0 244 | return [L[i : i + S] for i in range(0, len(L), S)] 245 | 246 | 247 | def print_coords(L: Collection[Tuple[int, int]], empty = ' '): 248 | import collections.abc 249 | 250 | xs, ys = zip(*L) 251 | min_x, max_x = min(xs), max(xs) 252 | min_y, max_y = min(ys), max(ys) 253 | print("X", min_x, max_x) 254 | print("Y", min_y, max_y) 255 | 256 | R = [[empty] * (max_x - min_x + 1) for _ in range(max_y - min_y + 1)] 257 | 258 | if isinstance(L, collections.abc.Mapping): 259 | for (x, y), c in L.items(): 260 | assert len(c) == 1, ((x, y), c) 261 | R[y - min_y][x - min_x] = c 262 | else: 263 | for x, y in L: 264 | R[y - min_y][x - min_x] = "#" 265 | 266 | print(*map("".join, R), sep="\n") 267 | 268 | 269 | def binary_search(f: Callable[[int], bool], lo: int, hi: int = None) -> int: 270 | " Returns the first i >= lo such that f(i) == True " 271 | if hi is None: 272 | assert lo >= 0 273 | hi = lo + 1 274 | while not f(hi): 275 | lo, hi = hi, hi * 2 276 | 277 | assert lo <= hi 278 | while lo < hi: 279 | m = (lo + hi) // 2 280 | if f(m): 281 | hi = m 282 | else: 283 | lo = m + 1 284 | 285 | return lo 286 | 287 | 288 | def neighbours( 289 | x: int, y: int, dirs: Iterable[Tuple[int, int]] = DIR, V=None 290 | ) -> Iterator[Tuple[int, int]]: 291 | for dx, dy in dirs: 292 | nx, ny = x + dx, y + dy 293 | if V is None or (nx, ny) in V: 294 | yield nx, ny 295 | --------------------------------------------------------------------------------