├── .gitignore ├── python ├── tak │ ├── model │ │ ├── __init__.py │ │ ├── grpc.py │ │ ├── heads.py │ │ ├── batches.py │ │ ├── losses.py │ │ ├── wrapper.py │ │ └── server.py │ ├── alphazero │ │ ├── __init__.py │ │ ├── hooks │ │ │ ├── __init__.py │ │ │ ├── timing.py │ │ │ ├── wandb.py │ │ │ ├── test_loss.py │ │ │ └── saving.py │ │ ├── stats.py │ │ ├── losses.py │ │ ├── config.py │ │ └── data.py │ ├── symmetry │ │ ├── __init__.py │ │ └── symmetry.py │ ├── ptn │ │ ├── __init__.py │ │ └── tps.py │ ├── __init__.py │ ├── proto │ │ ├── corpus_entry_pb2_grpc.py │ │ ├── corpus_entry_pb2.py │ │ ├── analysis_pb2.py │ │ └── analysis_pb2_grpc.py │ ├── pieces.py │ └── moves.py ├── pytest.ini ├── setup.cfg ├── xformer │ ├── __init__.py │ ├── train │ │ ├── __init__.py │ │ ├── hooks │ │ │ ├── __init__.py │ │ │ ├── wandb.py │ │ │ ├── saving.py │ │ │ ├── test_loss.py │ │ │ └── profile.py │ │ ├── lr_schedules.py │ │ ├── wandb.py │ │ ├── run.py │ │ └── trainer.py │ ├── yaml_ext.py │ ├── loading.py │ └── data │ │ └── __init__.py ├── .gitignore ├── bench │ ├── load_data.py │ ├── symmetry.py │ ├── probs.py │ └── mcts.py ├── setup.py ├── requirements.txt ├── test │ ├── test_yaml.py │ ├── test_xformer.py │ ├── test_encoding.py │ ├── corpus.csv │ ├── test_mcts.py │ ├── test_move.py │ └── test_data.py ├── .gitattributes ├── pyproject.toml ├── scripts │ ├── resume_alphazero │ ├── analysis_server │ └── self_play.py └── ext │ └── tak.cpp ├── testdata ├── openings │ ├── 5 │ │ ├── 1.ptn │ │ ├── 2.ptn │ │ ├── 4.ptn │ │ ├── 5.ptn │ │ ├── 6.ptn │ │ ├── 13.ptn │ │ ├── 15.ptn │ │ ├── 16.ptn │ │ ├── 17.ptn │ │ ├── 18.ptn │ │ ├── 19.ptn │ │ ├── 20.ptn │ │ ├── 11.ptn │ │ ├── 12.ptn │ │ ├── 3.ptn │ │ ├── 7.ptn │ │ ├── 10.ptn │ │ ├── 8.ptn │ │ ├── 9.ptn │ │ ├── 14.ptn │ │ └── openings.txt │ ├── 6 │ │ ├── 6.ptn │ │ ├── 7.ptn │ │ ├── 9.ptn │ │ ├── 1.ptn │ │ ├── 2.ptn │ │ ├── 3.ptn │ │ ├── 4.ptn │ │ ├── 5.ptn │ │ ├── 8.ptn │ │ └── openings.txt │ └── generate.sh ├── zoo │ ├── 5x5-open.ptn │ ├── 6x6-open.ptn │ 
├── puzzle-2016-06-22.ptn │ ├── puzzle-2016-06-10.ptn │ ├── _pending │ │ ├── 79555.ptn │ │ ├── 82415.ptn │ │ ├── 78016.ptn │ │ ├── 74639.ptn │ │ ├── 82143.ptn │ │ └── 76201.ptn │ ├── 90504.ptn │ ├── 53752.ptn │ ├── puzzle-2016-09-15.ptn │ ├── puzzle-2016-05-26-6x6.ptn │ ├── 70709.ptn │ ├── 2016-round7-game2.ptn │ ├── 100675.ptn │ ├── match1-game1-PonchoPal.ptn │ ├── 115508.ptn │ ├── 86085.ptn │ ├── 74359.ptn │ └── 117569.ptn ├── ai │ ├── regression.ptn │ ├── puzzle1.ptn │ ├── puzzle2.ptn │ ├── ally-table.ptn │ ├── benwo-v1.ptn │ ├── kakaburra-v3-short.ptn │ ├── kakaburra-v1.ptn │ ├── applemonkeyman_viper.ptn │ ├── applemonkeyman-v4.ptn │ ├── kakaburra-v3.ptn │ ├── asgardiator-v4.ptn │ ├── treffnon-v2.ptn │ └── fwwwwibib-v3.ptn └── puzzles │ ├── puzzle-2016-11-14.ptn │ ├── puzzle-2016-06-22.ptn │ ├── puzzle-2016-06-10.ptn │ ├── puzzle1.ptn │ ├── puzzle-2017-06-21.ptn │ ├── puzzle2.ptn │ ├── puzzle3.ptn │ ├── puzzle-2016-09-15.ptn │ └── puzzle-2016-05-26-6x6.ptn ├── .dockerignore ├── ai ├── types.go ├── random.go ├── json.go ├── json_test.go ├── mcts │ ├── policy_test.go │ └── debug.go ├── opening_test.go ├── moves_test.go ├── feature_string.go ├── moves.go ├── opening.go └── minimax_test.go ├── Dockerfile ├── bitboard ├── bits_19.go ├── bits_18.go ├── bits.go └── bits_test.go ├── bin ├── build-proto ├── browse-ptn ├── sync └── import-rankings ├── tei ├── time.go └── server_test.go ├── proto └── tak │ └── proto │ ├── analysis.proto │ ├── corpus_entry.proto │ └── taktician.proto ├── cmd ├── internal │ ├── playtak │ │ ├── parse.go │ │ ├── book.go │ │ └── taktician.go │ ├── selfplay │ │ └── stats.go │ ├── tei │ │ └── tei.go │ ├── importptn │ │ └── sql.go │ ├── canonicalize │ │ └── main.go │ ├── opt │ │ └── opt.go │ └── genopenings │ │ └── genopenings.go └── taktician │ └── main.go ├── .travis.yml ├── tests ├── helpers.go ├── bench_test.go ├── play_game_test.go └── hash_test.go ├── Gopkg.toml ├── tak ├── movetype_string.go ├── slide_test.go ├── hash_test.go ├── slide.go 
├── pieces.go ├── hash.go └── alloc.go ├── .circleci └── config.yml ├── cli └── player.go ├── logs ├── repository.go └── sql.go ├── .github └── workflows │ └── main.yml ├── go.mod ├── COPYING ├── Makefile ├── playtak ├── commands.go ├── client_test.go ├── move_test.go └── bot │ └── mock_test.go ├── ptn ├── iterator.go ├── iterator_test.go └── move_test.go ├── README.md ├── doc ├── friendly.md └── bitboards.md └── taktest └── utils.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.test 2 | -------------------------------------------------------------------------------- /python/tak/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = test 3 | -------------------------------------------------------------------------------- /testdata/openings/5/1.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e5 d3 4 | -------------------------------------------------------------------------------- /testdata/openings/5/2.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e5 c3 4 | -------------------------------------------------------------------------------- /testdata/openings/5/4.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e5 d4 4 | -------------------------------------------------------------------------------- /testdata/openings/5/5.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e5 d2 4 | -------------------------------------------------------------------------------- /testdata/openings/5/6.ptn: 
-------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e5 e3 4 | -------------------------------------------------------------------------------- /testdata/openings/6/6.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 f2 4 | -------------------------------------------------------------------------------- /testdata/openings/6/7.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 d4 4 | -------------------------------------------------------------------------------- /testdata/openings/6/9.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 e2 4 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | **/*.test 2 | games 3 | _scratch 4 | _train 5 | -------------------------------------------------------------------------------- /python/tak/alphazero/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | -------------------------------------------------------------------------------- /python/tak/symmetry/__init__.py: -------------------------------------------------------------------------------- 1 | from .symmetry import * 2 | -------------------------------------------------------------------------------- /testdata/openings/5/13.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 d3 4 | -------------------------------------------------------------------------------- /testdata/openings/5/15.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 d2 a2 4 | 
-------------------------------------------------------------------------------- /testdata/openings/5/16.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 d2 b4 4 | -------------------------------------------------------------------------------- /testdata/openings/5/17.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 c3 4 | -------------------------------------------------------------------------------- /testdata/openings/5/18.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 d3 4 | -------------------------------------------------------------------------------- /testdata/openings/5/19.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e3 e2 4 | -------------------------------------------------------------------------------- /testdata/openings/5/20.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 b4 c3 4 | -------------------------------------------------------------------------------- /testdata/openings/6/1.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f6 d4 d3 4 | -------------------------------------------------------------------------------- /testdata/openings/6/2.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 e3 d4 4 | -------------------------------------------------------------------------------- /testdata/openings/6/3.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 e3 e2 4 | -------------------------------------------------------------------------------- /testdata/openings/6/4.ptn: 
-------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 e3 e4 4 | -------------------------------------------------------------------------------- /testdata/openings/6/5.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 e3 Cd4 4 | -------------------------------------------------------------------------------- /testdata/openings/5/11.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 e3 d3 4 | -------------------------------------------------------------------------------- /testdata/openings/5/12.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 d4 d2 4 | -------------------------------------------------------------------------------- /testdata/openings/5/3.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e5 e4 e3 d4 4 | -------------------------------------------------------------------------------- /testdata/openings/5/7.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 e3 d2 c2 4 | -------------------------------------------------------------------------------- /testdata/openings/6/8.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | 3 | a1 f1 d3 c3 d4 4 | -------------------------------------------------------------------------------- /python/tak/ptn/__init__.py: -------------------------------------------------------------------------------- 1 | from .ptn import * 2 | from .tps import * 3 | -------------------------------------------------------------------------------- /testdata/openings/5/10.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 e3 d2 b1 d3 4 | 
-------------------------------------------------------------------------------- /testdata/openings/5/8.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 e3 d2 a2 d3 d1 4 | -------------------------------------------------------------------------------- /testdata/openings/5/9.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 e3 d2 a2 d3 c2 4 | -------------------------------------------------------------------------------- /testdata/openings/5/14.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | 3 | a1 e1 e2 b4 e3 d2 e4 e5 4 | -------------------------------------------------------------------------------- /python/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503,E203 3 | 4 | [isort] 5 | skip = tak/proto/ 6 | -------------------------------------------------------------------------------- /python/xformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Config, Transformer 2 | 3 | from . import yaml_ext 4 | -------------------------------------------------------------------------------- /python/xformer/train/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hooks 2 | from .run import * 3 | from .trainer import * 4 | -------------------------------------------------------------------------------- /testdata/zoo/5x5-open.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Depth "8"] 3 | [Move "1 white"] 4 | [Move "1 black"] 5 | 6 | 1. 
a1 7 | -------------------------------------------------------------------------------- /testdata/zoo/6x6-open.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Depth "7"] 3 | [Move "1 white"] 4 | [Move "1 black"] 5 | 6 | 1. a1 7 | -------------------------------------------------------------------------------- /testdata/ai/regression.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "2,x4/x2,2,x2/x,2,2,x2/x2,12,2,1/1,1,21,2,1 1 9"] 3 | [Depth "3"] 4 | -------------------------------------------------------------------------------- /python/tak/__init__.py: -------------------------------------------------------------------------------- 1 | from .game import * 2 | from .moves import * 3 | from .pieces import * 4 | 5 | VERSION = "0.1" 6 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle-2016-11-14.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [TPS "2,x4,1/2,x2,1,2C,x/2,x,2S,21C,1,x/1,x,12,112S,1,1/x2,21,1221,2,x/x3,1,2,x 1 8"] 3 | -------------------------------------------------------------------------------- /testdata/zoo/puzzle-2016-06-22.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "1,1,x3/1,1,x,2,1/x,1,x3/x,2,x2,2/2,1C,2S,2,2 1 3"] 3 | [Depth "9"] 4 | [Move "0"] 5 | [GoodMove "b3+"] 6 | -------------------------------------------------------------------------------- /python/xformer/train/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .profile import Profile 2 | from .saving import Save 3 | from .test_loss import TestLoss 4 | from .wandb import Wandb 5 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle-2016-06-22.ptn: 
-------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "1,1,x3/1,1,x,2,1/x,1,x3/x,2,x2,2/2,1C,2S,2,2 1 3"] 3 | [Depth "9"] 4 | [Move "0"] 5 | [GoodMove "b3+"] 6 | -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | .cache/ 2 | __pycache__ 3 | *.egg-info/ 4 | .coverage 5 | htmlcov/ 6 | .pytest_cache 7 | *.pyc 8 | *.so 9 | _feat 10 | /build 11 | /dist/ 12 | -------------------------------------------------------------------------------- /python/tak/alphazero/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .timing import TimingHook 2 | from .wandb import WandB 3 | from .saving import SavingHook 4 | from .test_loss import TestLoss 5 | -------------------------------------------------------------------------------- /testdata/zoo/puzzle-2016-06-10.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "1,1,x,1,x/2,x,1,x,1/2,1,x,1,x/1,2,2,1,212S/2,2,2,2C,1C 2 2"] 3 | [Depth "9"] 4 | [Move "0"] 5 | [GoodMove "2e2<"] 6 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle-2016-06-10.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "1,1,x,1,x/2,x,1,x,1/2,1,x,1,x/1,2,2,1,212S/2,2,2,2C,1C 2 2"] 3 | [Depth "9"] 4 | [Move "0"] 5 | [GoodMove "2e2<"] 6 | -------------------------------------------------------------------------------- /testdata/openings/6/openings.txt: -------------------------------------------------------------------------------- 1 | a1 f6 d4 d3 2 | a1 f1 e3 d4 3 | a1 f1 e3 e2 4 | a1 f1 e3 e4 5 | a1 f1 e3 Cd4 6 | a1 f1 f2 7 | a1 f1 d4 8 | a1 f1 d3 c3 d4 9 | a1 f1 e2 10 | -------------------------------------------------------------------------------- 
/testdata/ai/puzzle1.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "2,x2,121C,1/x2,2,12,1/x2,2,12S,2/x3,1,1/x4,1 1 2"] 3 | [Depth "7"] 4 | [Seed "1462667932"] 5 | [Move "1 white"] 6 | [GoodMove "2d5-11"] 7 | 8 | 1. 9 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle1.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [TPS "2,x2,121C,1/x2,2,12,1/x2,2,12S,2/x3,1,1/x4,1 1 2"] 3 | [Depth "7"] 4 | [Seed "1462667932"] 5 | [Move "1 white"] 6 | [GoodMove "2d5-11"] 7 | 8 | 1. 9 | -------------------------------------------------------------------------------- /python/bench/load_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tak.train 4 | 5 | 6 | def main(args): 7 | tak.train.load_corpus(args[1], True) 8 | 9 | 10 | if __name__ == "__main__": 11 | main(sys.argv) 12 | -------------------------------------------------------------------------------- /python/tak/proto/corpus_entry_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | -------------------------------------------------------------------------------- /ai/types.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/nelhage/taktician/tak" 7 | ) 8 | 9 | type TakPlayer interface { 10 | GetMove(ctx context.Context, p *tak.Position) tak.Move 11 | } 12 | -------------------------------------------------------------------------------- /python/tak/alphazero/stats.py: -------------------------------------------------------------------------------- 1 | from attrs import define, field 2 | from typing import Any 3 | 4 | 5 | @define(slots=False) 6 | class Elapsed: 7 | step: int = 0 8 | positions: int = 0 9 | epoch: int = 0 10 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils import cpp_extension 3 | 4 | setup( 5 | ext_modules=[cpp_extension.CppExtension('tak_ext', ['ext/tak.cpp'])], 6 | cmdclass={'build_ext': cpp_extension.BuildExtension} 7 | ) 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:alpine 2 | RUN apk update && apk add gcc libc-dev sqlite-dev 3 | 4 | ADD . /go/src/github.com/nelhage/taktician/ 5 | WORKDIR /go/src/github.com/nelhage/taktician/ 6 | 7 | RUN go install -v github.com/nelhage/taktician/... 
8 | -------------------------------------------------------------------------------- /bitboard/bits_19.go: -------------------------------------------------------------------------------- 1 | // +build go1.9 2 | // +build !test_18 3 | 4 | package bitboard 5 | 6 | import "math/bits" 7 | 8 | func Popcount(x uint64) int { 9 | return bits.OnesCount64(x) 10 | } 11 | 12 | func TrailingZeros(x uint64) uint { 13 | return uint(bits.TrailingZeros64(x)) 14 | } 15 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | --index-url https://pypi.python.org/simple/ 2 | attrs 3 | pytest 4 | pytest-cov 5 | numpy 6 | protobuf==3.20.* 7 | grpcio 8 | grpcio-tools 9 | black>=22.3.0 10 | python-lsp-server>=1.4.1 11 | zstandard 12 | py-spy 13 | torch==1.12.* 14 | wandb 15 | tqdm 16 | shed 17 | -------------------------------------------------------------------------------- /python/test/test_yaml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | import io 4 | import xformer.yaml_ext # noqa 5 | 6 | 7 | def test_yaml_dtype(): 8 | out = yaml.dump({"dtype": torch.float32}) 9 | got = yaml.unsafe_load(io.StringIO(out)) 10 | assert got["dtype"] == torch.float32 11 | -------------------------------------------------------------------------------- /python/bench/symmetry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tak.ptn 4 | import tak.symmetry 5 | 6 | 7 | def main(args): 8 | p = tak.ptn.parse_tps(args[1]) 9 | for i in range(1000): 10 | tak.symmetry.symmetries(p) 11 | 12 | 13 | if __name__ == "__main__": 14 | main(sys.argv) 15 | -------------------------------------------------------------------------------- /python/.gitattributes: -------------------------------------------------------------------------------- 1 | 
data/graphs/*.data-* filter=lfs diff=lfs merge=lfs -text 2 | data/graphs/*.meta filter=lfs diff=lfs merge=lfs -text 3 | data/graphs/*.index filter=lfs diff=lfs merge=lfs -text 4 | data/corpus/*/*.csv filter=lfs diff=lfs merge=lfs -text 5 | data/corpus/*/*.dat filter=lfs diff=lfs merge=lfs -text 6 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle-2017-06-21.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Event "Online Play"] 3 | [Date "2017.06.21"] 4 | [Time "19:58:33"] 5 | [Player1 "White"] 6 | [Player2 "Black"] 7 | [Clock "1:0 +5"] 8 | [Result "R-0"] 9 | [Size "5"] 10 | [TPS "1,1,1,x2/2,112,1,2,x/x2,121C,1,x/x3,1,2/2,2,2C,121S,x 2 14"] 11 | 12 | 14. d4- -------------------------------------------------------------------------------- /testdata/openings/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | dir=$1 4 | size=$(basename "$dir") 5 | rm -f "$dir"/*.ptn 6 | 7 | i=1 8 | while read -r line; do 9 | { 10 | echo "[Size \"$size\"]" 11 | echo 12 | echo "$line" 13 | } > "$dir/$i.ptn" 14 | let i=i+1 15 | done < "$dir/openings.txt" 16 | -------------------------------------------------------------------------------- /testdata/ai/puzzle2.ptn: -------------------------------------------------------------------------------- 1 | [Player1 "White"] 2 | [Player2 "Black"] 3 | [Result "R-0"] 4 | [Size "5"] 5 | [Depth "7"] 6 | [Move "12 white"] 7 | [GoodMove "b5-"] 8 | 9 | 1. a1 e5 10 | 2. d3 b2 11 | 3. d2 a3 12 | 4. d1 d4 13 | 5. e4 e3 14 | 6. Cc4 1d4-1 15 | 7. c3 c2 16 | 8. d4 2d3-2 17 | 9. b4 a4 18 | 10. a5 1a4>1 19 | 11. b5 c5 20 | 12. 
-------------------------------------------------------------------------------- /testdata/puzzles/puzzle2.ptn: -------------------------------------------------------------------------------- 1 | [Player1 "White"] 2 | [Player2 "Black"] 3 | [Result "R-0"] 4 | [Size "5"] 5 | [Depth "7"] 6 | [Move "12 white"] 7 | [GoodMove "b5-"] 8 | 9 | 1. a1 e5 10 | 2. d3 b2 11 | 3. d2 a3 12 | 4. d1 d4 13 | 5. e4 e3 14 | 6. Cc4 1d4-1 15 | 7. c3 c2 16 | 8. d4 2d3-2 17 | 9. b4 a4 18 | 10. a5 1a4>1 19 | 11. b5 c5 20 | 12. -------------------------------------------------------------------------------- /bin/build-proto: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | cd "$(dirname "$0")/.." 4 | protoc \ 5 | -I proto/ \ 6 | --go_out=plugins=grpc:pb \ 7 | $(find proto/ -name '*.proto') 8 | python \ 9 | -m grpc_tools.protoc \ 10 | -I proto \ 11 | --python_out=python/ \ 12 | --grpc_python_out=python/ \ 13 | $(find proto/ -name '*.proto') 14 | -------------------------------------------------------------------------------- /bin/browse-ptn: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import urllib.parse 4 | import subprocess 5 | 6 | PTNVIEWER = 'https://jsfiddle.net/bwochinski/043hpzwu/embedded/result/' 7 | 8 | if len(sys.argv) != 2: 9 | print("Usage: %s PTN" % sys.argv[0]) 10 | sys.exit(1) 11 | 12 | ptn = sys.argv[1] 13 | url = PTNVIEWER + '?ptn=' + urllib.parse.quote(open(ptn).read()) 14 | 15 | subprocess.check_call(['xdg-open', url]) 16 | -------------------------------------------------------------------------------- /testdata/openings/5/openings.txt: -------------------------------------------------------------------------------- 1 | a1 e5 d3 2 | a1 e5 c3 3 | a1 e5 e4 e3 d4 4 | a1 e5 d4 5 | a1 e5 d2 6 | a1 e5 e3 7 | a1 e1 e2 e3 d2 c2 8 | a1 e1 e2 e3 d2 a2 d3 d1 9 | a1 e1 e2 e3 d2 a2 d3 c2 10 | a1 e1 e2 e3 d2 b1 d3 11 | a1 e1 e2 e3 d3 12 | a1 e1 e2 d4 d2 13 | a1 e1 e2 
d3 14 | a1 e1 e2 b4 e3 d2 e4 e5 15 | a1 e1 d2 a2 16 | a1 e1 d2 b4 17 | a1 e1 c3 18 | a1 e1 d3 19 | a1 e1 e3 e2 20 | a1 e1 b4 c3 21 | -------------------------------------------------------------------------------- /tei/time.go: -------------------------------------------------------------------------------- 1 | package tei 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | ) 7 | 8 | type TimeControl struct { 9 | White time.Duration 10 | Black time.Duration 11 | WInc time.Duration 12 | BInc time.Duration 13 | } 14 | 15 | func formatTime(d time.Duration) string { 16 | ms := d / time.Millisecond 17 | if ms < 0 { 18 | ms = 0 19 | } 20 | return strconv.FormatUint(uint64(ms), 10) 21 | } 22 | -------------------------------------------------------------------------------- /testdata/ai/ally-table.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-22"] 3 | [Time "2016-05-22T19:48:11Z"] 4 | [Player1 "RTJR"] 5 | [Player2 "TakticianBot"] 6 | [Result "R-0"] 7 | [Id "34759"] 8 | [Depth "5"] 9 | [Seed "1463955533"] 10 | [Move "3 black"] 11 | [Move "4 black"] 12 | [Move "5 black"] 13 | [BadMove "d1"] 14 | 15 | 16 | 1. a1 e1 17 | 2. d4 c1 18 | 3. d3 c4 19 | 4. d5 e3 20 | 5. d2 d1 21 | 6. e1< 22 | R-0 23 | -------------------------------------------------------------------------------- /testdata/ai/benwo-v1.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-01"] 3 | [Player1 "BenWo"] 4 | [Player2 "TakticianBot"] 5 | [Result "R-0"] 6 | [Id "6454"] 7 | [Depth "5"] 8 | [Move "3 black"] 9 | [BadMove "a3"] 10 | 11 | 12 | 1. a1 e1 13 | 2. e2 a2 14 | 3. e3 a3 15 | 4. Ce4 Se5 16 | 5. d5 d4 17 | 6. c5 a4 18 | 7. Sa5 b5 19 | 8. a5- b3 20 | 9. 2a4- b2 21 | 10. 3a3> c4 22 | 11. e4+ e4 23 | 12. e5- d4+ 24 | 13. c5> e5< 25 | 14. 
e5 26 | -------------------------------------------------------------------------------- /testdata/ai/kakaburra-v3-short.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-08"] 3 | [Time "2016-05-08T07:31:09Z"] 4 | [Player1 "kakaburra"] 5 | [Player2 "TakticianBot"] 6 | [Result "R-0"] 7 | [Id "1787"] 8 | [Move "4 black"] 9 | 10 | 1. a5 e4 11 | 2. c3 b3 12 | 3. c4 c2 13 | 4. d3 e3 14 | 5. d2 c5 15 | 6. d4 d1 16 | 7. Ce2 e1 17 | 8. e2< c1 18 | 9. e5 e3+ 19 | 10. d5 2e4< 20 | 11. c4> c5> 21 | 12. e5< c2+ 22 | 13. d2- 23 | R-0 24 | -------------------------------------------------------------------------------- /testdata/zoo/_pending/79555.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-08-22"] 3 | [Time "2016-08-22T22:43:30Z"] 4 | [Player1 "dove_queen"] 5 | [Player2 "Abyss"] 6 | [Result "R-0"] 7 | [Id "79555"] 8 | 9 | 10 | 1. a5 a1 11 | 2. b2 c2 12 | 3. b3 Cc3 13 | 4. c4 b4 14 | 5. c5 b5 15 | 6. a2 b4- 16 | 7. Cb4 c3< 17 | 8. a4 a3 18 | 9. d4 a5- 19 | 10. e4 c3 20 | 11. a5 b5< 21 | 12. b4< b3+ 22 | 13. 3a4+ Sb5 23 | 14. 4a5-22 24 | R-0 25 | 26 | -------------------------------------------------------------------------------- /testdata/ai/kakaburra-v1.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-02"] 3 | [Player1 "TakticianBot"] 4 | [Player2 "kakaburra"] 5 | [Result "0-R"] 6 | [Id "7249"] 7 | [Depth "6"] 8 | [Move "3 white"] 9 | [BadMove "a3"] 10 | 11 | 12 | 1. a1 e5 13 | 2. a2 b2 14 | 3. a3 c3 15 | 4. a4 b1 16 | 5. a5 b3 17 | 6. b4 c4 18 | 7. c5 Cd4 19 | 8. a2> b1+ 20 | 9. d5 Sb5 21 | 10. e4 c2 22 | 11. c1 b1 23 | 12. d3 d4< 24 | 13. d3< b3> 25 | 14. c1< a1> 26 | 15. 
a1 c4+ 27 | -------------------------------------------------------------------------------- /proto/tak/proto/analysis.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tak.proto; 4 | option go_package="github.com/nelhage/taktician/pb/"; 5 | 6 | service Analysis { 7 | rpc Evaluate(EvaluateRequest) returns (EvaluateResponse) {} 8 | } 9 | 10 | message EvaluateRequest { 11 | repeated int32 position = 1; 12 | } 13 | 14 | message EvaluateResponse { 15 | repeated float move_probs = 1; 16 | float value = 2; 17 | bytes move_probs_bytes = 3; 18 | } 19 | -------------------------------------------------------------------------------- /python/test/test_xformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import xformer 4 | 5 | 6 | def test_xformer(): 7 | cfg = xformer.Config( 8 | n_vocab=256, 9 | d_model=2 * 128, 10 | n_layer=3, 11 | d_head=32, 12 | ) 13 | model = xformer.Transformer(cfg) 14 | 15 | tokens = torch.randint(0, cfg.n_vocab, (4, 1024)) 16 | logits = model(tokens) 17 | 18 | assert logits.shape == (4, 1024, cfg.n_vocab) 19 | assert logits.isnan().sum() == 0, "No NaNs" 20 | -------------------------------------------------------------------------------- /cmd/internal/playtak/parse.go: -------------------------------------------------------------------------------- 1 | package playtak 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | ) 7 | 8 | var commandRE = regexp.MustCompile(`^([^ :]+):?\s*([^ :]+):?\s*(.*)$`) 9 | 10 | func parseCommand(whoami string, msg string) (string, string) { 11 | gs := commandRE.FindStringSubmatch(msg) 12 | if gs == nil { 13 | return "", "" 14 | } 15 | if !strings.EqualFold(gs[1], whoami) && 16 | !strings.EqualFold(gs[1]+"bot", whoami) { 17 | return "", "" 18 | } 19 | return gs[2], gs[3] 20 | 21 | } 22 | -------------------------------------------------------------------------------- 
/testdata/ai/applemonkeyman_viper.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-08-02"] 3 | [Time "2016-08-02T16:10:53Z"] 4 | [Player1 "Guest79"] 5 | [Player2 "TakticianBot"] 6 | [Result "R-0"] 7 | [Id "72924"] 8 | [Depth "6"] 9 | [Move "4 black"] 10 | [BadMove "d5"] 11 | 12 | 13 | 1. a5 a1 14 | 2. b1 c1 15 | 3. b2 b3 16 | 4. Cc2 d5 17 | 5. d2 e2 18 | 6. d1 b3- 19 | 7. e1 e2- 20 | 8. e2 2e1+ 21 | 9. e1 c1< 22 | 10. c1 c5 23 | 11. c2< 2b1> 24 | 12. c2 3c1>12 25 | 13. d1> 3e2- 26 | 14. d1> Cb1 27 | 15. 4e1+112 Sa2 28 | 16. e5 29 | R-0 30 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle3.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Event "Online Play"] 3 | [Date "2016.09.05"] 4 | [Time "18:18:12"] 5 | [Player1 "fwwwwibib"] 6 | [Player2 "TakkerusBot"] 7 | [Result "0-R"] 8 | [Size "5"] 9 | [Depth "5"] 10 | [Move "14 black"] 11 | [GoodMove "4a4-22"] 12 | 13 | 1. a1 e1 14 | 2. e2 a2 15 | 3. e3 d3 16 | 4. e4 1d3>1 17 | 5. Cd3 2e3-2 18 | 6. d2 a3 19 | 7. a4 b4 20 | 8. b3 1b4<1 21 | 9. Sa5 Cb4 22 | 10. 1a5-1 1b4<1 23 | 11. Sa5 b4 24 | 12. 1b3<1 b5 25 | 13. 2a3-11 a3 26 | 14. 
2a1+2 {Black to win} 27 | -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tak" 3 | description = "An implementation of Tak" 4 | license = {text = "MIT"} 5 | dynamic = ["version"] 6 | 7 | [pycodestyle] 8 | ignore = "D100,D111,D102" 9 | 10 | [tool.black] 11 | extend-exclude=''' 12 | ^/tak/proto/ 13 | ''' 14 | 15 | [build-system] 16 | requires = ["setuptools", "torch ~= 1.12.0"] 17 | build-backend = "setuptools.build_meta" 18 | 19 | [tool.setuptools] 20 | packages = ["tak", "xformer"] 21 | 22 | [tool.setuptools.dynamic] 23 | version = {attr = "tak.VERSION"} 24 | -------------------------------------------------------------------------------- /proto/tak/proto/corpus_entry.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tak.proto; 4 | option go_package="github.com/nelhage/taktician/pb"; 5 | 6 | message CorpusEntry { 7 | string day = 1; 8 | int32 id = 2; 9 | int32 ply = 3; 10 | 11 | string tps = 4; 12 | 13 | string move = 5; 14 | float value = 6; 15 | int32 plies = 7; 16 | repeated int64 features = 8; 17 | 18 | enum InTak { 19 | UNSET = 0; 20 | NOT_IN_TAK = 1; 21 | IN_TAK = 2; 22 | }; 23 | InTak in_tak = 9; 24 | } 25 | -------------------------------------------------------------------------------- /ai/random.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "math/rand" 5 | 6 | "context" 7 | 8 | "github.com/nelhage/taktician/tak" 9 | ) 10 | 11 | type RandomAI struct { 12 | r *rand.Rand 13 | } 14 | 15 | func (r *RandomAI) GetMove(ctx context.Context, p *tak.Position) tak.Move { 16 | var buffer [100]tak.Move 17 | moves := p.AllMoves(buffer[:0]) 18 | i := r.r.Int31n(int32(len(moves))) 19 | return moves[i] 20 | } 21 | 22 | func NewRandom(seed int64) TakPlayer { 23 | 
return &RandomAI{ 24 | r: rand.New(rand.NewSource(seed)), 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /python/xformer/yaml_ext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | 4 | 5 | def dtype_representer(dumper, data): 6 | return dumper.represent_scalar("!dtype", data.__reduce__()) 7 | 8 | 9 | def dtype_constructor(loader, node): 10 | value = loader.construct_scalar(node) 11 | dtype = getattr(torch, value, None) 12 | if not isinstance(dtype, torch.dtype): 13 | raise ValueError(f"Invalid type: {value}") 14 | return dtype 15 | 16 | 17 | yaml.add_representer(torch.dtype, dtype_representer) 18 | yaml.add_constructor("!dtype", dtype_constructor) 19 | -------------------------------------------------------------------------------- /testdata/zoo/_pending/82415.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-08-31"] 3 | [Time "2016-08-31T22:40:48Z"] 4 | [Player1 "Ally"] 5 | [Player2 "tjrhodes"] 6 | [Result "R-0"] 7 | [Id "82415"] 8 | 9 | 10 | 1. a1 f1 11 | 2. e3 e4 12 | 3. Cd3 c4 13 | 4. d4 d5 14 | 5. c3 Cb3 15 | 6. e5 e6 16 | 7. f6 f5 17 | 8. d6 c5 18 | 9. e2 e6< 19 | 10. e6 e4+ 20 | 11. Se4 2e5+ 21 | 12. e5 b4 22 | 13. e5+ Se5 23 | 14. 3e6<12 b6 24 | 15. f2 e5< 25 | 16. d3+ Sd3 26 | 17. e5 f5< 27 | 18. f5 b2 28 | 19. f4 2d5>11 29 | 20. d4> d3> 30 | 21. f3 2e3> 31 | 22. e3 e5+ 32 | 23. 2d6- 33 | R-0 34 | 35 | -------------------------------------------------------------------------------- /testdata/ai/applemonkeyman-v4.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-14"] 3 | [Time "2016-05-14T23:53:10Z"] 4 | [Player1 "TakticianBot"] 5 | [Player2 "applemonkeyman"] 6 | [Result "0-R"] 7 | [Id "26520"] 8 | [Move "13 white"] 9 | 10 | 1. a1 b2 11 | 2. c3 a2 12 | 3. c2 a3 13 | 4. c4 a4 14 | 5. a5 b5 15 | 6. 
b2< b2 16 | 7. c1 Cc5 17 | 8. d4 b2> 18 | 9. c1+ b4 19 | 10. d5 c5- 20 | 11. Cd3 2c4-11 21 | 12. a5> a5 22 | 13. 2c3<11 4c2<22 23 | 14. 2a3+11 a4+ 24 | 15. Sa3 3a5> 25 | 16. b3- 5b5-122 26 | 17. a3> a3 27 | 18. 2b3< b4- 28 | 19. 3a3> 3a2+ 29 | 20. 5b3-14 a5 30 | 0-R 31 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | dist: trusty 3 | sudo: false 4 | 5 | addons: 6 | apt: 7 | sources: 8 | - deadsnakes 9 | packages: 10 | - python3.5 11 | - python-virtualenv 12 | 13 | cache: 14 | directories: 15 | - venv 16 | 17 | install: 18 | - virtualenv --python=python3.5 venv 19 | - venv/bin/pip install -r python/requirements.txt 20 | - venv/bin/pip install -e python 21 | - go get -t ./... 22 | 23 | script: 24 | - venv/bin/py.test python/test 25 | - go test -v ./... 26 | 27 | go: 28 | - 1.9 29 | - "1.10" 30 | - tip 31 | -------------------------------------------------------------------------------- /cmd/internal/selfplay/stats.go: -------------------------------------------------------------------------------- 1 | package selfplay 2 | 3 | import ( 4 | "math" 5 | "math/big" 6 | ) 7 | 8 | func binomprob(k, n int64, p float64) float64 { 9 | nk := big.NewFloat(0).SetInt(big.NewInt(0).Binomial(n, k)) 10 | nk.Mul(nk, big.NewFloat(math.Pow(p, float64(k)))) 11 | nk.Mul(nk, big.NewFloat(math.Pow(1-p, float64(n-k)))) 12 | f, _ := nk.Float64() 13 | return f 14 | } 15 | 16 | func binomTest(succ, fail int64, p float64) float64 { 17 | var r float64 18 | for t := succ; t < (fail + succ); t++ { 19 | r += binomprob(t, succ+fail, p) 20 | } 21 | return r 22 | } 23 | -------------------------------------------------------------------------------- /testdata/zoo/_pending/78016.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-08-18"] 3 | [Time "2016-08-18T15:55:26Z"] 4 | [Player1 
"Gerrek"] 5 | [Player2 "applemonkeyman"] 6 | [Result "R-0"] 7 | [Id "78016"] 8 | 9 | 10 | 1. a1 a6 11 | 2. c5 b5 12 | 3. Cc4 Sc3 13 | 4. b4 a4 14 | 5. b6 b3 15 | 6. d4 Cd5 16 | 7. c6 d6 17 | 8. e4 d5< 18 | 9. a3 a5 19 | 10. Sa2 d5 20 | 11. f4 d5- 21 | 12. d5 Sd3 22 | 13. d5- d3+ 23 | 14. c4> Sd5 24 | 15. c4 2c5- 25 | 16. b2 c5 26 | 17. e5 e6 27 | 18. f6 Sf5 28 | 19. e3 c2 29 | 20. f6< f5< 30 | 21. f5 e5+ 31 | 22. f6 3e6-12 32 | 23. f3 e6 33 | 24. 5d4-122 34 | R-0 35 | 36 | -------------------------------------------------------------------------------- /python/scripts/resume_alphazero: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import yaml 4 | import os.path 5 | from tak.alphazero import trainer 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description="Resume an alphazero run") 10 | parser.add_argument("run_dir", type=str) 11 | 12 | args = parser.parse_args() 13 | 14 | with open(os.path.join(args.run_dir, "run.yaml"), "r") as fh: 15 | config = yaml.unsafe_load(fh) 16 | 17 | train = trainer.TrainingRun(config=config) 18 | train.run() 19 | 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /testdata/ai/kakaburra-v3.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-07"] 3 | [Time "2016-05-07T20:46:45Z"] 4 | [Player1 "TakticianBot"] 5 | [Player2 "kakaburra"] 6 | [Result "0-R"] 7 | [Id "1573"] 8 | [Move "8 white"] 9 | 10 | 1. a1 a5 11 | 2. c3 c2 12 | 3. c4 b3 13 | 4. d3 d2 14 | 5. a3 b2 15 | 6. e2 d1 16 | 7. e1 e3 17 | 8. c5 a2 18 | 9. d3> Cd3 19 | 10. e1< d3> 20 | 11. c3- 3e3<12 21 | 12. e3 d4 22 | 13. a3- a3 23 | 14. 2a2+ a2 24 | 15. c4> 2c3<11 25 | 16. Sa4 4a3>13 26 | 17. b1 d3+ 27 | 18. a3 d3 28 | 19. a3- a1+ 29 | 20. 2c2<11 b2< 30 | 21. e3< 2d4- 31 | 22. 2d1+11 d2+ 32 | 23. b1+ 4a2> 33 | 24. 
e3 e4 34 | 0-R 35 | -------------------------------------------------------------------------------- /testdata/ai/asgardiator-v4.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-21"] 3 | [Time "2016-05-21T20:19:33Z"] 4 | [Player1 "asgardiator"] 5 | [Player2 "TakticianBot"] 6 | [Result "R-0"] 7 | [Id "33741"] 8 | [Move "2 black"] 9 | [BadMove "d4"] 10 | 11 | 1. a5 a1 12 | 2. b1 d4 13 | 3. c1 d1 14 | 4. d2 c2 15 | 5. e1 d1< 16 | 6. Cd1 b2 17 | 7. d1< Cd1 18 | 8. e2 c2> 19 | 9. c2 c5 20 | 10. e2< b2- 21 | 11. 2c1< d1< 22 | 12. d1 Sb2 23 | 13. b1+ 2c1> 24 | 14. 2d2> 3d1+ 25 | 15. 2b2- 4d2> 26 | 16. d2 5e2< 27 | 17. 5b1>14 b1 28 | 18. b2 e2- 29 | 19. e2 b3 30 | 20. a2 b3- 31 | 21. e2- d3 32 | 22. a1> 4d2<211 33 | 23. a1 34 | R-0 35 | -------------------------------------------------------------------------------- /python/xformer/train/lr_schedules.py: -------------------------------------------------------------------------------- 1 | from attrs import define 2 | 3 | from . 
import run 4 | 5 | 6 | @define 7 | class LinearWarmupCooldown: 8 | warmup_steps: int 9 | cooldown_steps: int 10 | cooldown_start: int 11 | 12 | def __call__(self, stats: run.Stats): 13 | if stats.step < self.warmup_steps: 14 | return stats.step / self.warmup_steps 15 | if stats.step > self.cooldown_start: 16 | end = self.cooldown_start + self.cooldown_steps 17 | remaining = end - stats.step 18 | return (remaining + 1) / self.cooldown_steps 19 | return 1.0 20 | -------------------------------------------------------------------------------- /bitboard/bits_18.go: -------------------------------------------------------------------------------- 1 | // +build !go1.9 test_18 2 | 3 | package bitboard 4 | 5 | func Popcount(x uint64) int { 6 | // bit population count, see 7 | // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel 8 | if x == 0 { 9 | return 0 10 | } 11 | x -= (x >> 1) & 0x5555555555555555 12 | x = (x>>2)&0x3333333333333333 + x&0x3333333333333333 13 | x += x >> 4 14 | x &= 0x0f0f0f0f0f0f0f0f 15 | x *= 0x0101010101010101 16 | return int(x >> 56) 17 | } 18 | 19 | func TrailingZeros(x uint64) uint { 20 | for i := uint(0); i < 64; i++ { 21 | if x&(1<<i) != 0 { 22 | return i 23 | } 24 | } 25 | return 64 26 | } 27 | -------------------------------------------------------------------------------- b4 28 | 15. 2d2+ e3+ 29 | 16. d4 2e4+ 30 | 17. d5 3e5< 31 | 18. e5 4d5> 32 | 19. d5 5e5< 33 | 20. e5 e4 34 | 21. e5< e4< 35 | 22. 4d5- c4> 36 | 23. 2d5- b3 37 | 24. 
3d3- 38 | R-0 39 | -------------------------------------------------------------------------------- /tests/helpers.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "path" 7 | "strings" 8 | 9 | "github.com/nelhage/taktician/ptn" 10 | ) 11 | 12 | func readPTNs(d string) (map[string]*ptn.PTN, error) { 13 | ents, e := ioutil.ReadDir(d) 14 | if e != nil { 15 | return nil, e 16 | } 17 | out := make(map[string]*ptn.PTN) 18 | for _, de := range ents { 19 | if !strings.HasSuffix(de.Name(), ".ptn") { 20 | continue 21 | } 22 | g, e := ptn.ParseFile(path.Join(d, de.Name())) 23 | if e != nil { 24 | log.Printf("parse(%s): %v", de.Name(), e) 25 | continue 26 | } 27 | out[de.Name()] = g 28 | } 29 | return out, nil 30 | } 31 | -------------------------------------------------------------------------------- /testdata/zoo/53752.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-06-15"] 3 | [Time "2016-06-15T20:07:02Z"] 4 | [Player1 "TakticianBot"] 5 | [Player2 "NohatCoder"] 6 | [Result "0-R"] 7 | [Id "53752"] 8 | [Depth "7"] 9 | [Move "4 black"] 10 | 11 | 1. e3 e5 12 | 2. d5 c4 13 | 3. c5 b5 14 | 4. d4 d3 15 | 5. b4 a4 16 | 6. c3 Cb3 17 | 7. Ca5 b3+ 18 | 8. a5> a5 19 | 9. a3 b3 20 | 10. a3+ 2b4< 21 | 11. 2b5< b5 22 | 12. c3> d2 23 | 13. 3a5> a5 24 | 14. c3 b3> 25 | 15. 2d3> Se4 26 | 16. 4b5< b5 27 | 17. b4 3a4> 28 | 18. 5a5> a5 29 | 19. c2 e4- 30 | 20. Se4 d3 31 | 21. c5- b3 32 | 22. c2+ 3e3<12 33 | 23. d4- 3c3> 34 | 24. 2c4- 5d3< 35 | 25. e4- b2 36 | 26. Sb1 d1 37 | 27. 2e3- c2 38 | 0-R 39 | -------------------------------------------------------------------------------- /Gopkg.toml: -------------------------------------------------------------------------------- 1 | # Gopkg.toml example 2 | # 3 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 4 | # for detailed Gopkg.toml documentation. 
5 | # 6 | 7 | [[constraint]] 8 | name = "github.com/golang/protobuf" 9 | version = "1.1.0" 10 | 11 | [[constraint]] 12 | name = "github.com/mattn/go-sqlite3" 13 | version = "1.6.0" 14 | 15 | [[constraint]] 16 | branch = "master" 17 | name = "golang.org/x/net" 18 | 19 | [[constraint]] 20 | name = "google.golang.org/grpc" 21 | version = "1.12.0" 22 | 23 | [prune] 24 | go-tests = true 25 | unused-packages = true 26 | 27 | [[constraint]] 28 | branch = "master" 29 | name = "github.com/google/subcommands" 30 | -------------------------------------------------------------------------------- /tak/movetype_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type=MoveType"; DO NOT EDIT. 2 | 3 | package tak 4 | 5 | import "strconv" 6 | 7 | const ( 8 | _MoveType_name_0 = "PassPlaceFlatPlaceStandingPlaceCapstoneSlideLeftSlideRightSlideUpSlideDown" 9 | _MoveType_name_1 = "TypeMask" 10 | ) 11 | 12 | var ( 13 | _MoveType_index_0 = [...]uint8{0, 4, 13, 26, 39, 48, 58, 65, 74} 14 | ) 15 | 16 | func (i MoveType) String() string { 17 | switch { 18 | case 1 <= i && i <= 8: 19 | i -= 1 20 | return _MoveType_name_0[_MoveType_index_0[i]:_MoveType_index_0[i+1]] 21 | case i == 15: 22 | return _MoveType_name_1 23 | default: 24 | return "MoveType(" + strconv.FormatInt(int64(i), 10) + ")" 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle-2016-09-15.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Event "Online Play"] 3 | [Date "2016.09.14"] 4 | [Time "23:46:59"] 5 | [Player1 "Abyss"] 6 | [Player2 "SultanPepper"] 7 | [Result "0-1"] 8 | [Size "6"] 9 | 10 | 1. a6 a1 11 | 2. c2 b2 12 | 3. c3 c4 13 | 4. Cd4 b3 14 | 5. c5 Cd3 15 | 6. 1d4<1 a3 16 | 7. b4 1d3<1 17 | 8. d3 d4 18 | 9. e4 e3 19 | 10. a4 1d4-1 20 | 11. f3 e2 21 | 12. f2 f1 22 | 13. 1f3<1 d2 23 | 14. 1f2<1 e1 24 | 15. 
Sd1 f2 25 | 16. 1d1+1 Sd4 26 | 17. c6 f3 27 | 18. 2d2+2 c1 28 | 19. d1 d2 29 | 20. 4d3-13 2c3-2 30 | 21. 4d1>13 3c2>3 31 | 22. e5 4d2>4 32 | 23. 4f1+4 c2 33 | 24. 2e3>2 f1 34 | 25. 1b4-1 1e2>1 35 | 26. 2e1>2 5f2<5 36 | 27. Sa2 e3 37 | 28. b4 e1 38 | 29. d5 {Black to force a road} 39 | -------------------------------------------------------------------------------- /testdata/zoo/puzzle-2016-09-15.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Event "Online Play"] 3 | [Date "2016.09.14"] 4 | [Time "23:46:59"] 5 | [Player1 "Abyss"] 6 | [Player2 "SultanPepper"] 7 | [Result "0-1"] 8 | [Size "6"] 9 | 10 | 1. a6 a1 11 | 2. c2 b2 12 | 3. c3 c4 13 | 4. Cd4 b3 14 | 5. c5 Cd3 15 | 6. 1d4<1 a3 16 | 7. b4 1d3<1 17 | 8. d3 d4 18 | 9. e4 e3 19 | 10. a4 1d4-1 20 | 11. f3 e2 21 | 12. f2 f1 22 | 13. 1f3<1 d2 23 | 14. 1f2<1 e1 24 | 15. Sd1 f2 25 | 16. 1d1+1 Sd4 26 | 17. c6 f3 27 | 18. 2d2+2 c1 28 | 19. d1 d2 29 | 20. 4d3-13 2c3-2 30 | 21. 4d1>13 3c2>3 31 | 22. e5 4d2>4 32 | 23. 4f1+4 c2 33 | 24. 2e3>2 f1 34 | 25. 1b4-1 1e2>1 35 | 26. 2e1>2 5f2<5 36 | 27. Sa2 e3 37 | 28. b4 e1 38 | 29. d5 {Black to force a road} 39 | -------------------------------------------------------------------------------- /testdata/zoo/puzzle-2016-05-26-6x6.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Player1 "Turing"] 3 | [Player2 "TakticianBot"] 4 | [Result "R-0"] 5 | [Size "6"] 6 | [Depth "7"] 7 | [Move "0"] 8 | [GoodMove "3c3<12"] 9 | [GoodMove "3c3<"] 10 | [GoodMove "2c3<"] 11 | 12 | 1. a6 f1 13 | 2. d2 d4 14 | 3. b3 d6 15 | 4. a3 c6 16 | 5. Sb6 b2 17 | 6. c3 e6 18 | 7. e2 d3 19 | 8. c2 1b2+1 20 | 9. b2 d1 21 | 10. c4 f6 22 | 11. f2 1d1+1 23 | 12. d1 2d2<2 24 | 13. d2 1d3-1 25 | 14. Cd3 Cb5 26 | 15. 1b6>1 3c2<3 27 | 16. Sc2 c5 28 | 17. Sd5 2b3>2 29 | 18. b3 a5 30 | 19. 
1c4-1 4b2+4 31 | {White can force a win here, the best solution I know of costs three pieces and takes five moves (9 ply) to complete a road, but I wouldn't be surprised if there's a better solution.} 32 | -------------------------------------------------------------------------------- /testdata/puzzles/puzzle-2016-05-26-6x6.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Player1 "Turing"] 3 | [Player2 "TakticianBot"] 4 | [Result "R-0"] 5 | [Size "6"] 6 | [Depth "7"] 7 | [Move "0"] 8 | [GoodMove "3c3<12"] 9 | [GoodMove "3c3<"] 10 | [GoodMove "2c3<"] 11 | 12 | 1. a6 f1 13 | 2. d2 d4 14 | 3. b3 d6 15 | 4. a3 c6 16 | 5. Sb6 b2 17 | 6. c3 e6 18 | 7. e2 d3 19 | 8. c2 1b2+1 20 | 9. b2 d1 21 | 10. c4 f6 22 | 11. f2 1d1+1 23 | 12. d1 2d2<2 24 | 13. d2 1d3-1 25 | 14. Cd3 Cb5 26 | 15. 1b6>1 3c2<3 27 | 16. Sc2 c5 28 | 17. Sd5 2b3>2 29 | 18. b3 a5 30 | 19. 1c4-1 4b2+4 31 | {White can force a win here, the best solution I know of costs three pieces and takes five moves (9 ply) to complete a road, but I wouldn't be surprised if there's a better solution.} 32 | -------------------------------------------------------------------------------- /testdata/zoo/70709.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-07-25"] 3 | [Time "2016-07-25T20:43:10Z"] 4 | [Player1 "dove_queen"] 5 | [Player2 "NohatCoder"] 6 | [Result "0-R"] 7 | [Id "70709"] 8 | [Depth "7"] 9 | [Move "1 black"] 10 | [Move "12 white"] 11 | 12 | 13 | 1. a6 f6 14 | 2. e5 e4 15 | 3. Cd4 d5 16 | 4. f5 c4 17 | 5. e3 d3 18 | 6. f4 e2 19 | 7. f3 f2 20 | 8. d2 d3- 21 | 9. d3 c2 22 | 10. d4- f2+ 23 | 11. e5- Ce5 24 | 12. f4- e5- 25 | 13. f4 2e4- 26 | 14. 3f3-12 f3 27 | 15. 2d3- c5 28 | 16. 4d2>13 3e3> 29 | 17. e1 c3 30 | 18. c1 c6 31 | 19. b1 d1 32 | 20. 4f2<121 Sa1 33 | 21. a2 Sb2 34 | 22. 2d2- 4f3-13 35 | 23. e3 5f1<14 36 | 24. Sd2 f2< 37 | 25. 2c2+ d4 38 | 26. e3+ 4e2+13 39 | 27. 
3c3>12 e2 40 | 28. d2+ 6d1+51 41 | 29. 3d2>12 2d3- 42 | 0-R 43 | -------------------------------------------------------------------------------- /testdata/zoo/2016-round7-game2.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Event "Online Play"] 3 | [Date "2016.10.30"] 4 | [Time "18:31:54"] 5 | [Player1 "nelhage"] 6 | [Player2 "pythoner6"] 7 | [Clock "10:0 +15"] 8 | [Result "F-0"] 9 | [Size "5"] 10 | [Depth "9"] 11 | [Move "12 white"] 12 | 13 | 1. a1 e1 14 | 2. d2 a2 15 | 3. a3 Cb3 16 | 4. Cb2 1b3<1 17 | 5. c2 a4 18 | 6. 1b2<1 Sb2 19 | 7. e2 b1 20 | 8. d3 d4 21 | 9. c4 b4 22 | 10. e4 Sd1 23 | 11. e3 Se5 24 | { There's what looks like a Tinue here after 12. e4<, but 25 | understanding it requires following a wall bouncing back and forth 26 | eating pieces. } 27 | 12. c5 1d4<1 28 | 13. d4 Sd5 29 | 14. 1d4<1 1b4>1 30 | 15. 1c5-1 Sc3 31 | 16. 4c4<22 2a3+2 32 | 17. b5 4a4>13 33 | 18. a3 b3 34 | 19. a5 c1 35 | 20. c5 d4 -------------------------------------------------------------------------------- /python/tak/alphazero/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from functools import cached_property 5 | 6 | 7 | class ReferenceAccuracy: 8 | def loss_and_metrics(self, batch, logits): 9 | v_logits = logits["values"] 10 | m_logits = logits["moves"] 11 | 12 | v_error = F.mse_loss(v_logits, batch.values) 13 | 14 | moves = batch.moves 15 | moves = moves.to(m_logits.dtype) 16 | probs = (torch.softmax(m_logits, -1) * moves).sum(-1) 17 | accuracy = probs.mean() 18 | 19 | metrics = { 20 | "v_error": v_error.item(), 21 | "accuracy": accuracy.item(), 22 | } 23 | 24 | return (v_error - accuracy), metrics 25 | -------------------------------------------------------------------------------- /testdata/ai/treffnon-v2.ptn: 
-------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-06"] 3 | [Player1 "TreffnonX"] 4 | [Player2 "TakticianBot"] 5 | [Result "F-0"] 6 | [Id "980"] 7 | [Move "6 black"] 8 | 9 | 1. a1 e1 10 | 2. a2 c3 11 | 3. b2 a3 12 | 4. c2 c1 13 | 5. d2 a1+ 14 | 6. a1 2a2> 15 | 7. c2< Sb3 16 | 8. 3b2>12 Sd3 17 | 9. Cd1 b3- 18 | 10. a2 b3 19 | 11. e2 e3 20 | 12. d1+ a3- 21 | 13. a1+ b2< 22 | 14. 3d2<111 2c2>11 23 | 15. d2> Cc2 24 | 16. 4a2> c4 25 | 17. Sc5 d4 26 | 18. c5- Sb4 27 | 19. e4 e3- 28 | 20. e1+ Se1 29 | 21. 5e2+122 Sd5 30 | 22. a1 d5> 31 | 23. a5 c5 32 | 24. 5b2+ b5 33 | 25. a4 b5< 34 | 26. a4+ b1 35 | 27. a4 b1< 36 | 28. a2- Sa3 37 | 29. a2 a3+ 38 | 30. d1 3e5-12 39 | 31. 3a5>12 d5 40 | 32. a3 Sb1 41 | 33. e2 b1< 42 | 34. b1 2a1>11 43 | 35. a5 c1< 44 | 36. e5 45 | F-0 46 | -------------------------------------------------------------------------------- /testdata/zoo/_pending/74639.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-08-08"] 3 | [Time "2016-08-08T15:12:57Z"] 4 | [Player1 "TakticianBot"] 5 | [Player2 "applemonkeyman"] 6 | [Result "0-F"] 7 | [Id "74639"] 8 | 9 | 10 | 1. e1 a1 11 | 2. b1 e2 12 | 3. e3 d2 13 | 4. c1 Cd3 14 | 5. c2 d4 15 | 6. d5 e4 16 | 7. c2> d1 17 | 8. e5 c5 18 | 9. d5- d5 19 | 10. e5- c4 20 | 11. 2d2- d2 21 | 12. e3- e1< 22 | 13. c1> d2- 23 | 14. Cc1 5d1+ 24 | 15. c1> d3- 25 | 16. c3 e1 26 | 17. c2 e1+ 27 | 18. Sb2 e1 28 | 19. 2d1> d1 29 | 20. 3e1< e1 30 | 21. Sd3 e5 31 | 22. 4d1> b5 32 | 23. 2e4+ 3e2+111 33 | 24. e4+ d5> 34 | 25. 5e1+1211 2e3-11 35 | 26. 4e5-22 5d2< 36 | 27. d3- c1 37 | 28. b1> d1 38 | 29. 2d2- d2 39 | 30. 3d1> 3c2- 40 | 31. b2> c1+ 41 | 32. c3+ 4c1<13 42 | 33. 4e1+ c1 43 | 34. 
b3 d1 44 | 0-F 45 | 46 | -------------------------------------------------------------------------------- /tak/slide_test.go: -------------------------------------------------------------------------------- 1 | package tak 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestMkSlides(t *testing.T) { 9 | cases := []struct { 10 | out uint32 11 | in []int 12 | }{ 13 | { 14 | 0, 15 | nil, 16 | }, 17 | { 18 | 0x1, 19 | []int{1}, 20 | }, 21 | { 22 | 0x321, 23 | []int{1, 2, 3}, 24 | }, 25 | } 26 | 27 | for _, tc := range cases { 28 | s := MkSlides(tc.in...) 29 | if uint32(s) != tc.out { 30 | t.Errorf("%v: got %x != %x", tc.in, s, tc.out) 31 | } 32 | 33 | var out []int 34 | if !s.Empty() { 35 | for it := s.Iterator(); it.Ok(); it = it.Next() { 36 | out = append(out, it.Elem()) 37 | } 38 | } 39 | if !reflect.DeepEqual(out, tc.in) { 40 | t.Errorf("rt(%v) = %v", tc.in, out) 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build: 4 | docker: 5 | # specify the version 6 | - image: debian:buster 7 | steps: 8 | - checkout 9 | - run: | 10 | apt-get update 11 | apt-get -y install python3.7 python-virtualenv curl git gcc 12 | virtualenv --python=python3.7 venv 13 | - run: | 14 | curl -Lo /tmp/go.tar.tgz https://dl.google.com/go/go1.14.1.linux-amd64.tar.gz 15 | tar -xzf /tmp/go.tar.tgz -C /usr/local/ 16 | - run: | 17 | venv/bin/pip install -r python/requirements.txt 18 | venv/bin/pip install -e python 19 | 20 | - run: venv/bin/py.test python/test 21 | - run: /usr/local/go/bin/go get -v -t -d ./... 22 | - run: /usr/local/go/bin/go test -v ./... 
23 | -------------------------------------------------------------------------------- /cli/player.go: -------------------------------------------------------------------------------- 1 | package cli 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "strings" 8 | 9 | "github.com/nelhage/taktician/ptn" 10 | "github.com/nelhage/taktician/tak" 11 | ) 12 | 13 | func NewCLIPlayer(out io.Writer, in *bufio.Reader) Player { 14 | return &cliPlayer{out, in} 15 | } 16 | 17 | type cliPlayer struct { 18 | out io.Writer 19 | in *bufio.Reader 20 | } 21 | 22 | func (c *cliPlayer) GetMove(p *tak.Position) tak.Move { 23 | for { 24 | fmt.Fprintf(c.out, "%s> ", p.ToMove()) 25 | line, err := c.in.ReadString('\n') 26 | if err != nil { 27 | panic(err) 28 | } 29 | line = strings.TrimRight(line, "\r\n") 30 | m, err := ptn.ParseMove(line) 31 | if err != nil { 32 | fmt.Fprintln(c.out, "parse error: ", err) 33 | continue 34 | } 35 | return m 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /testdata/zoo/100675.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-10-14"] 3 | [Time "2016-10-14T18:06:30Z"] 4 | [Player1 "Guest195"] 5 | [Player2 "TakticianBot"] 6 | [Result "F-0"] 7 | [Id "100675"] 8 | [Depth "7"] 9 | [Seed "1476468550"] 10 | [Move "5 black"] 11 | [Move "6 black"] 12 | 13 | 1. a5 e1 14 | 2. d2 a4 15 | 3. a2 b2 16 | 4. Cb3 c2 17 | 5. c3 Ca3 18 | 6. a1 a3- 19 | 7. a3 d3 20 | 8. e2 2a2+ 21 | 9. a2 b1 22 | 10. b3- 3a3- 23 | 11. a3 b3 24 | 12. c1 d1 25 | 13. b4 c2> 26 | 14. Sc2 b3< 27 | 15. 2b2- 4a2- 28 | 16. a2 b2 29 | 17. 3b1+ 5a1+ 30 | 18. a1 b5 31 | 19. c2> d1< 32 | 20. Sb1 c5 33 | 21. e5 e3 34 | 22. e4 e3+ 35 | 23. 3d2+12 d1 36 | 24. d2 Se3 37 | 25. b1> d5 38 | 26. 2d4> e3< 39 | 27. 2c1> d5> 40 | 28. 3e4+ b1 41 | 29. 4b2- 2d3- 42 | 30. 5e5<14 e3 43 | 31. c2 b3 44 | 32. c4 d3< 45 | 33. 
d3 46 | F-0 47 | -------------------------------------------------------------------------------- /testdata/zoo/match1-game1-PonchoPal.ptn: -------------------------------------------------------------------------------- 1 | [Site "PlayTak.com"] 2 | [Event "Online Play"] 3 | [Date "2016.09.15"] 4 | [Time "04:42:58"] 5 | [Player1 "nelhage"] 6 | [Player2 "PonchoPal"] 7 | [Clock "10:0 +15"] 8 | [Result "0-F"] 9 | [Size "5"] 10 | [Depth "7"] 11 | [Move "6 white"] 12 | [Move "10 black"] 13 | 14 | 1. a1 e1 15 | 2. d3 d1 16 | 3. d2 Ce2 17 | 4. c3 c2 18 | 5. c1 b1 19 | 6. 1d2<1 d2 20 | 7. c4 1d1<1 21 | 8. Sd1 1e2<1 22 | 9. 1d1<1 2d2<2 23 | 10. Cd2 c5 24 | 11. b4 b2 25 | 12. 2c1<2 a2 26 | 13. 1c4+1 b3 27 | 14. 1b4-1 3c2+3 28 | 15. d4 e3 29 | 16. 1d2+1 4c3<4 30 | 17. Sb4 b5 31 | 18. 2c5<2 d2 32 | 19. Se2 a5 33 | 20. Sa3 a4 34 | 21. 1a3-1 a3 35 | 22. 1b4<1 1a5>1 36 | 23. d1 b4 37 | 24. 2a4>2 3b5>12 38 | 25. e5 Se4 39 | 26. 1e5<1 1e4<1 40 | 27. 3d5<12 a5 41 | 28. c4 a4 42 | 29. 3b1+3 b1 43 | 30. 2a2+2 a2 -------------------------------------------------------------------------------- /logs/repository.go: -------------------------------------------------------------------------------- 1 | package logs 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | 7 | _ "github.com/mattn/go-sqlite3" // repository assumes sqlite 8 | ) 9 | 10 | type Repository struct { 11 | db *sql.DB 12 | } 13 | 14 | type Game struct { 15 | Day string 16 | ID int 17 | Timestamp time.Time 18 | Size int 19 | Player1, Player2 string 20 | Result string 21 | Winner string 22 | Moves int 23 | } 24 | 25 | func Open(db string) (*Repository, error) { 26 | sql, err := sql.Open("sqlite3", db) 27 | if err != nil { 28 | return nil, err 29 | } 30 | return &Repository{db: sql}, nil 31 | } 32 | 33 | func (r *Repository) Close() { 34 | r.db.Close() 35 | } 36 | 37 | func (r *Repository) DB() *sql.DB { 38 | return r.db 39 | } 40 | -------------------------------------------------------------------------------- 
/python/test/test_encoding.py: -------------------------------------------------------------------------------- 1 | import tak.ptn 2 | from tak.model import encoding 3 | 4 | 5 | def test_encoding_invariants(): 6 | assert len(encoding.Token.CAPSTONES) == encoding.MAX_CAPSTONES 7 | assert len(encoding.Token.RESERVES) == encoding.MAX_RESERVES 8 | 9 | assert set(encoding.Token.CAPSTONES) & set(encoding.Token.RESERVES) == set() 10 | 11 | 12 | def test_encode(): 13 | p1 = tak.ptn.parse_tps("12,x,22S/x2,1/x,21,1 2 8") 14 | p2 = tak.ptn.parse_tps("x,2,2/x,12S,1/x,1,2S 1 6") 15 | 16 | e1 = encoding.encode(p1) 17 | e2 = encoding.encode(p2) 18 | 19 | batch, mask = encoding.encode_batch([p1, p2]) 20 | 21 | assert [b for (b, m) in zip(batch[0].tolist(), mask[0].tolist()) if m] == e1 22 | assert [b for (b, m) in zip(batch[1].tolist(), mask[1].tolist()) if m] == e2 23 | -------------------------------------------------------------------------------- /tei/server_test.go: -------------------------------------------------------------------------------- 1 | package tei 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestCalcBudget(t *testing.T) { 11 | cases := []struct { 12 | Move time.Duration 13 | Game time.Duration 14 | Inc time.Duration 15 | Expect time.Duration 16 | }{ 17 | {0, 3 * time.Second, 3 * time.Second, 0}, 18 | {time.Second, 3 * time.Second, 3 * time.Second, time.Second}, 19 | {5 * time.Second, 3 * time.Second, 3 * time.Second, 0}, 20 | } 21 | for _, tc := range cases { 22 | got := calcBudget(tc.Move, tc.Game, tc.Inc) 23 | if tc.Expect != 0 { 24 | assert.Equal(t, tc.Expect, got) 25 | } 26 | if tc.Game != 0 { 27 | assert.Less(t, int64(got), int64(tc.Game)) 28 | } 29 | if tc.Move != 0 { 30 | assert.LessOrEqual(t, int64(got), int64(tc.Move)) 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /testdata/zoo/115508.ptn: 
-------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-11-27"] 3 | [Time "2016-11-27T16:02:18Z"] 4 | [Player1 "SultanPepper"] 5 | [Player2 "NohatCoder"] 6 | [Result "R-0"] 7 | [Id "115508"] 8 | [Depth "6"] 9 | [Move "38 white"] 10 | [Move "34 black"] 11 | 12 | 1. a6 a1 13 | 2. c3 c4 14 | 3. d3 d4 15 | 4. e3 b3 16 | 5. c2 b2 17 | 6. Cb4 Cd2 18 | 7. b4> c1 19 | 8. f3 b4 20 | 9. a3 b1 21 | 10. b5 a4 22 | 11. a5 d2+ 23 | 12. c5 d2 24 | 13. d5 e5 25 | 14. e4 e5< 26 | 15. e5 b6 27 | 16. e5< a6- 28 | 17. c2< Se5 29 | 18. c2 e5< 30 | 19. a2 2d3< 31 | 20. d3 a6 32 | 21. 2c4< c4 33 | 22. c6 4d5<22 34 | 23. e2 d2< 35 | 24. c6- 3b5> 36 | 25. Sb5 d5 37 | 26. 2b2> c1+ 38 | 27. Sd2 3c2<12 39 | 28. a1> a1 40 | 29. a3- b2< 41 | 30. b5< d6 42 | 31. 3b4> b4 43 | 32. d2< e6 44 | 33. 3c2<12 c6 45 | 34. 2a5+ b5 46 | 35. 6a2+15 c2 47 | 36. 3a4> 3c3> 48 | 37. d2 c2> 49 | 38. 3a6>111 e5 50 | 39. a6 51 | R-0 52 | -------------------------------------------------------------------------------- /python/xformer/train/wandb.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import attrs 4 | import wandb 5 | from attrs import define 6 | 7 | from .run import Hook, Run, Stats 8 | 9 | 10 | @define 11 | class WandbHook(Hook): 12 | job_name: T.Optional[str] = None 13 | project: T.Optional[str] = None 14 | group: T.Optional[str] = None 15 | config: T.Any = None 16 | 17 | def before_run(self, run: Run): 18 | job_name = self.job_name 19 | if job_name is not None and "{rand}" in job_name: 20 | job_name = job_name.format(rand=wandb.util.generate_id()) 21 | run.wandb = wandb.init( 22 | project=self.project, 23 | name=job_name, 24 | group=self.group, 25 | ) 26 | wandb.config.update(self.config) 27 | 28 | def after_step(self, run: Run, stats: Stats): 29 | wandb.log(attrs.asdict(stats), step=stats.step) 30 | -------------------------------------------------------------------------------- 
/testdata/ai/fwwwwibib-v3.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-05-14"] 3 | [Time "2016-05-14T18:20:58Z"] 4 | [Player1 "fwwwwibib"] 5 | [Player2 "TakticianBot"] 6 | [Result "R-0"] 7 | [Id "26271"] 8 | [Move "29 black"] 9 | 10 | 1. a1 b4 11 | 2. c4 c3 12 | 3. d4 c5 13 | 4. d3 a4 14 | 5. b3 a3 15 | 6. a5 e3 16 | 7. b5 e4 17 | 8. d5 c5< 18 | 9. b4+ Sc5 19 | 10. Cb4 e4< 20 | 11. d3+ c5< 21 | 12. e4 2b5< 22 | 13. b3< e3+ 23 | 14. b3 a4- 24 | 15. b3< 2e4< 25 | 16. d5- Sd3 26 | 17. 5d4> d3+ 27 | 18. 5e4-212 Se4 28 | 19. 4a3>22 e4- 29 | 20. c2 3e3<12 30 | 21. d2 b2 31 | 22. c5 4c3+13 32 | 23. a3 e3 33 | 24. b4> c3< 34 | 25. 3c4- b2> 35 | 26. d2< d2 36 | 27. 2c2> d3- 37 | 28. c2> 2d4-11 38 | 29. e2+ 5d2+122 39 | 30. 2c3> 2b3< 40 | 31. Sa2 Sd1 41 | 32. a2+ d1> 42 | 33. Se2 Se4 43 | 34. b2 e4- 44 | 35. c2 e5 45 | 36. a4 e4 46 | 37. 3a3-12 c4 47 | 38. b1 c1 48 | 39. 3d3+ 3e3<12 49 | 40. 5d4-14 e3 50 | 41. 2d3> 51 | R-0 52 | -------------------------------------------------------------------------------- /testdata/zoo/86085.ptn: -------------------------------------------------------------------------------- 1 | [Size "5"] 2 | [Date "2016-09-08"] 3 | [Time "2016-09-08T01:24:23Z"] 4 | [Player1 "applemonkeyman"] 5 | [Player2 "TakticianBot"] 6 | [Result "R-0"] 7 | [Id "86085"] 8 | [Depth "7"] 9 | [Move "35 black"] 10 | 11 | 12 | 1. a5 a1 13 | 2. b1 d2 14 | 3. b2 b3 15 | 4. Cc2 d3 16 | 5. c3 c4 17 | 6. d4 b5 18 | 7. e3 b3- 19 | 8. c1 c4> 20 | 9. c4 c5 21 | 10. d1 2b2- 22 | 11. Sd5 2d4< 23 | 12. d5< Cd4 24 | 13. e1 e2 25 | 14. Sb4 Sb2 26 | 15. c2< 3b1> 27 | 16. c2 d2- 28 | 17. d2 e2< 29 | 18. 2b2>11 d5 30 | 19. 2d2< e2 31 | 20. e3- d3- 32 | 21. 3c2> d4< 33 | 22. 4d2< e5 34 | 23. 2c5< 4c4+ 35 | 24. 3b5< 4c5<31 36 | 25. Sc5 5a5-1112 37 | 26. b4< b3 38 | 27. c5< c4 39 | 28. 4b5-112 b5 40 | 29. a5 c4< 41 | 30. c4 c5 42 | 31. 2b3+ b5- 43 | 32. 2a4> d3 44 | 33. b5 d3< 45 | 34. 4c2+ c5- 46 | 35. 
3b4> 2d1+ 47 | 36. d3 3a1> 48 | 37. e3 3d2+ 49 | 38. 5c3> 4c1+31 50 | 39. 3c2+ a3> 51 | 40. 4c3<22 52 | R-0 53 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | on: 2 | pull_request: {} 3 | push: {} 4 | 5 | name: Continuous integration 6 | 7 | jobs: 8 | ci-go: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - uses: actions/setup-go@v2 13 | with: 14 | go-version: '^1.18.3' 15 | - name: go test 16 | run: | 17 | go test ./... 18 | ci-python: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v3 22 | - uses: actions/setup-python@v3 23 | with: 24 | python-version: '3.10' 25 | architecture: x64 26 | cache: pip 27 | - name: pip install 28 | run: | 29 | pip install -r python/requirements.txt 30 | pip install -e python/ 31 | - name: pytest 32 | run: | 33 | cd python && pytest 34 | env: 35 | TEST_WANDB: true 36 | -------------------------------------------------------------------------------- /python/test/corpus.csv: -------------------------------------------------------------------------------- 1 | "x,22,2S/1,x,1S/1,21S,x 1 6",a3,+1.000000 2 | "2S,1,2/x,2,1/1S,112,21 1 10",b3>,+1.000000 3 | "1S,1,2S/2S,x,21S/2S,1,2 1 10",b2,+1.000000 4 | "1,x,1/2S,1,1S/2S,1,2 2 5",b3,-1.000000 5 | "1S,1,1S/22,x,1S/2S,2,1 1 7",b2,+1.000000 6 | "x3/1,x,2/1,x,2 1 3",a3,+1.000000 7 | "x3/1,x,2S/11,22S,x 1 6",a3,+1.000000 8 | "1S,1,2S/x,2S,1S/1S,2S,2 2 6",a2,+1.000000 9 | "21,x,2/x,212S,1/x2,1 1 7",2a3>11,+1.000000 10 | "x,1,x/1,x2/2,x,2 2 3",b1,+1.000000 11 | "2,1,x/2S,1,2S/22S,x,11S 1 7",b1,+1.000000 12 | "x,11S,1/22S,x,1S/2S,2S,1S 1 7",b3-,+1.000000 13 | "x,2,x/1,1S,2/2S,1,x 2 4",c3,+1.000000 14 | "2,x2/x,2S,1/x,1,x 1 3",c1,+1.000000 15 | "2,x,1/1S,2,x/2,x,1 1 4",c2,+1.000000 16 | "2,x,212/x,2,1S/x2,2 1 7",Sb3,-1.000000 17 | "1S,2S,1/21S,21,2S/x,11S,1 2 14",b3-,-1.000000 18 | "1,x,21/1S,2,2S/1S,2,1 2 
11",b3,+1.000000 19 | "2S,1,12S/x,2,2/2121S,21S,x 1 15",2b1+,-1.000000 20 | "2S,2S,1S/x,211S,x/1S,2,2 2 10",a2,+1.000000 21 | -------------------------------------------------------------------------------- /tak/hash_test.go: -------------------------------------------------------------------------------- 1 | package tak 2 | 3 | import "testing" 4 | 5 | func TestPositionEqual(t *testing.T) { 6 | p := New(Config{Size: 5}) 7 | if !p.Equal(p) { 8 | t.Fatal("New() != self!") 9 | } 10 | p2 := New(Config{Size: 5}) 11 | if !p.Equal(p2) { 12 | t.Fatal("New() != New()!") 13 | } 14 | l := moves([]Move{ 15 | Move{X: 0, Y: 0, Type: PlaceFlat}, 16 | Move{X: 4, Y: 4, Type: PlaceFlat}, 17 | Move{X: 0, Y: 4, Type: PlaceFlat}, 18 | Move{X: 4, Y: 0, Type: PlaceFlat}, 19 | }) 20 | r := moves([]Move{ 21 | Move{X: 4, Y: 0, Type: PlaceFlat}, 22 | Move{X: 0, Y: 4, Type: PlaceFlat}, 23 | Move{X: 4, Y: 4, Type: PlaceFlat}, 24 | Move{X: 0, Y: 0, Type: PlaceFlat}, 25 | }) 26 | if !l.Equal(r) { 27 | t.Fatalf("l != r") 28 | } 29 | if !r.Equal(l) { 30 | t.Fatalf("r != l") 31 | } 32 | if p.Equal(r) { 33 | t.Fatalf("New() == r") 34 | } 35 | if p.Equal(l) { 36 | t.Fatalf("New() == l") 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /python/tak/model/grpc.py: -------------------------------------------------------------------------------- 1 | from . 
import encoding 2 | from attrs import define, field 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from tak.proto import analysis_pb2, analysis_pb2_grpc 8 | import grpc 9 | 10 | 11 | @define 12 | class GRPCNetwork: 13 | host: str 14 | port: int 15 | stub: analysis_pb2_grpc.AnalysisStub = field(init=False) 16 | 17 | def __attrs_post_init__(self): 18 | channel = grpc.insecure_channel(f"{self.host}:{self.port}") 19 | self.stub = analysis_pb2_grpc.AnalysisStub(channel) 20 | 21 | def evaluate(self, pos): 22 | with torch.no_grad(): 23 | encoded = encoding.encode(pos) 24 | out = self.stub.Evaluate(analysis_pb2.EvaluateRequest(position=encoded)) 25 | move_probs = torch.from_numpy( 26 | np.frombuffer(out.move_probs_bytes, dtype=np.float32).copy() 27 | ) 28 | return move_probs, out.value 29 | -------------------------------------------------------------------------------- /python/xformer/train/hooks/wandb.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import attrs 4 | import torch 5 | from attrs import define 6 | 7 | from ..run import Hook, Run, Stats 8 | 9 | 10 | @define 11 | class Wandb(Hook): 12 | job_name: T.Optional[str] = None 13 | project: T.Optional[str] = None 14 | group: T.Optional[str] = None 15 | config: T.Any = None 16 | 17 | def before_run(self, run: Run): 18 | import wandb 19 | 20 | job_name = self.job_name 21 | if job_name is not None and "{rand}" in job_name: 22 | job_name = job_name.format(rand=wandb.util.generate_id()) 23 | run.wandb = wandb.init( 24 | project=self.project, 25 | name=job_name, 26 | group=self.group, 27 | ) 28 | wandb.config.update(self.config) 29 | 30 | def after_step(self, run: Run, stats: Stats): 31 | run.wandb.log(attrs.asdict(stats), step=stats.step) 32 | -------------------------------------------------------------------------------- /ai/json.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | 
"encoding/json" 5 | "fmt" 6 | ) 7 | 8 | var _ interface { 9 | json.Marshaler 10 | json.Unmarshaler 11 | } = &Weights{} 12 | 13 | var featureNames map[string]Feature 14 | 15 | func init() { 16 | featureNames = make(map[string]Feature) 17 | for i := Feature(0); i < MaxFeature; i++ { 18 | featureNames[i.String()] = i 19 | } 20 | } 21 | 22 | func (ws *Weights) MarshalJSON() ([]byte, error) { 23 | h := make(map[string]int64) 24 | for i, v := range ws { 25 | if v != 0 { 26 | h[Feature(i).String()] = v 27 | } 28 | } 29 | return json.Marshal(h) 30 | } 31 | 32 | func (ws *Weights) UnmarshalJSON(bs []byte) error { 33 | h := make(map[string]int64) 34 | e := json.Unmarshal(bs, &h) 35 | if e != nil { 36 | return e 37 | } 38 | for k, v := range h { 39 | f, ok := featureNames[k] 40 | if !ok { 41 | return fmt.Errorf("Unknown feature: %q", k) 42 | } 43 | ws[f] = v 44 | } 45 | return nil 46 | } 47 | -------------------------------------------------------------------------------- /python/tak/model/heads.py: -------------------------------------------------------------------------------- 1 | import typing as T # noqa 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from tak.model import encoding 7 | 8 | 9 | class PolicyValue(nn.Module): 10 | def __init__(self, cfg, dtype=None, device=None): 11 | super().__init__() 12 | self.final_ln = nn.LayerNorm( 13 | normalized_shape=(cfg.d_model,), dtype=dtype, device=device 14 | ) 15 | self.v_proj = nn.Linear(cfg.d_model, 1, dtype=dtype, device=device) 16 | self.move_proj = nn.Linear( 17 | cfg.d_model, encoding.MAX_MOVE_ID, dtype=dtype, device=device 18 | ) 19 | 20 | def init_weights(self, cfg): 21 | pass 22 | 23 | def forward(self, acts): 24 | acts = self.final_ln(acts)[:, 0] 25 | 26 | v = torch.tanh(self.v_proj(acts)) 27 | 28 | moves = self.move_proj(acts) 29 | 30 | return { 31 | "values": v.squeeze(-1), 32 | "moves": moves, 33 | } 34 | -------------------------------------------------------------------------------- /ai/json_test.go: 
-------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | func TestMarshalUnmarshal(t *testing.T) { 10 | cases := []struct { 11 | in Weights 12 | out string 13 | }{ 14 | {Weights{}, "{}"}, 15 | {Weights{TopFlat: 100}, `{"TopFlat":100}`}, 16 | {Weights{TopFlat: 100, Capstone: 150}, `{"Capstone":150,"TopFlat":100}`}, 17 | } 18 | for i, tc := range cases { 19 | t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { 20 | out, e := json.Marshal(&tc.in) 21 | if e != nil { 22 | t.Fatalf("Marshal(): %v", e) 23 | } 24 | if string(out) != tc.out { 25 | t.Fatalf("Marshal() = %q != %q", out, tc.out) 26 | } 27 | 28 | var back Weights 29 | e = json.Unmarshal(out, &back) 30 | if e != nil { 31 | t.Fatalf("Unmarshal(%q): %v", out, e) 32 | } 33 | for i, v := range back { 34 | if tc.in[i] != v { 35 | t.Errorf("roundtrip[%d] = %v != %v", i, v, tc.in[i]) 36 | } 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /logs/sql.go: -------------------------------------------------------------------------------- 1 | package logs 2 | 3 | const createGameTable = ` 4 | CREATE TABLE IF NOT EXISTS games ( 5 | day string not null, 6 | id integer not null, 7 | time datetime, 8 | size int, 9 | player1 varchar, 10 | player2 varchar, 11 | result string, 12 | winner string, 13 | moves int 14 | )` 15 | 16 | const createPlayerTable = ` 17 | CREATE VIEW IF NOT EXISTS player_games ( 18 | day, id, player, opponent, color, win, result, size, moves 19 | ) AS 20 | SELECT day, id, player2, player1, 'black', 21 | CASE winner WHEN 'white' THEN 'lose' WHEN 'black' THEN 'win' ELSE 'tie' END, 22 | result, size, moves 23 | FROM games 24 | UNION 25 | SELECT day, id, player1, player2, 'white', 26 | CASE winner WHEN 'white' THEN 'win' WHEN 'black' THEN 'lose' ELSE 'tie' END, 27 | result, size, moves 28 | FROM games 29 | ` 30 | 31 | const insertStmt = ` 
32 | INSERT INTO games (day, id, time, size, player1, player2, result, winner, moves) 33 | VALUES (?,?,?,?,?,?,?,?,?) 34 | ` 35 | -------------------------------------------------------------------------------- /testdata/zoo/74359.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-08-07"] 3 | [Time "2016-08-07T08:26:18Z"] 4 | [Player1 "NohatCoder"] 5 | [Player2 "TakticianBot"] 6 | [Result "F-0"] 7 | [Id "74359"] 8 | [Depth "7"] 9 | [Move "12 black"] 10 | [Move "24 black"] 11 | 12 | 1. a6 f1 13 | 2. e2 c6 14 | 3. d2 b6 15 | 4. d6 c5 16 | 5. c2 b2 17 | 6. b3 d5 18 | 7. c3 d3 19 | 8. a3 d3- 20 | 9. d3 b2+ 21 | 10. b2 d4 22 | 11. a2 Ce5 23 | 12. Sf5 b1 24 | 13. d3- b1+ 25 | 14. Sb5 f6 26 | 15. b5> d1 27 | 16. Cd3 e3 28 | 17. f2 e6 29 | 18. d6< e1 30 | 19. a3> e1+ 31 | 20. 2d2> e3- 32 | 21. d2> 2b2+ 33 | 22. c3< Sb2 34 | 23. 5b3+212 Se3 35 | 24. 5e2< a6> 36 | 25. b5+ c4 37 | 26. a3 e3- 38 | 27. a1 Sb5 39 | 28. c3 b5+ 40 | 29. Sb5 6b6< 41 | 30. Sa5 b2+ 42 | 31. Sb6 e4 43 | 32. e3 e4- 44 | 33. e1 b2 45 | 34. Se4 c1 46 | 35. b1 a4 47 | 36. 2c5- d1> 48 | 37. f3 2b3+ 49 | 38. e4- 4b4< 50 | 39. a2> 5a4-113 51 | 40. a4 a2> 52 | 41. b1+ a2 53 | 42. f4 c5 54 | 43. d1 a2> 55 | 44. c2< 2e2< 56 | 45. d3- 2e1< 57 | 46. e2 a1+ 58 | 47. c2 59 | F-0 60 | -------------------------------------------------------------------------------- /testdata/zoo/_pending/82143.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-08-31"] 3 | [Time "2016-08-31T09:27:23Z"] 4 | [Player1 "Ally"] 5 | [Player2 "TakticianBot"] 6 | [Result "F-0"] 7 | [Id "82143"] 8 | 9 | 10 | 1. a6 a1 11 | 2. c3 b2 12 | 3. b3 d3 13 | 4. Cd4 Cc4 14 | 5. c2 d2 15 | 6. e3 e2 16 | 7. f3 c4- 17 | 8. c4 a2 18 | 9. f2 e4 19 | 10. f1 e5 20 | 11. f4 2c3- 21 | 12. f5 e4> 22 | 13. f5- f5 23 | 14. e4 f5- 24 | 15. e4> Sf5 25 | 16. 4f4< e2> 26 | 17. e2 b4 27 | 18. d5 c3 28 | 19. e2> e5< 29 | 20. c5 b4> 30 | 21. 
c5> Sd6 31 | 22. e5 f5< 32 | 23. f5 f6 33 | 24. e6 f6- 34 | 25. d4> 3c2>111 35 | 26. e1 4f2<22 36 | 27. f2 3e2> 37 | 28. Se2 a5 38 | 29. b4 Sd4 39 | 30. a4 c2 40 | 31. c5 b5 41 | 32. c5- c3+ 42 | 33. b4> Sc3 43 | 34. e4< 2e5- 44 | 35. d4> c3+ 45 | 36. 2e4<11 2d4>11 46 | 37. 6c4>222 4d2+13 47 | 38. 3f4< 2d3>11 48 | 39. f4- 5d4+ 49 | 40. e2> c5 50 | 41. 4f2+112 d6> 51 | 42. 3f3<21 6d5-24 52 | 43. f6 Se5 53 | 44. b4 a5- 54 | 45. a3 a2+ 55 | 46. c4> 5d3+ 56 | 47. b1 c1 57 | 48. d1 a2 58 | 49. d2 59 | F-0 60 | 61 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/nelhage/taktician 2 | 3 | require ( 4 | github.com/golang/protobuf v1.5.2 5 | github.com/google/subcommands v0.0.0-20181012225330-46f0354f6315 6 | github.com/jmoiron/sqlx v1.2.0 7 | github.com/mattn/go-sqlite3 v1.10.0 8 | github.com/stretchr/testify v1.7.0 9 | golang.org/x/net v0.0.0-20220615171555-694bf12d69de 10 | golang.org/x/sync v0.0.0-20190423024810-112230192c58 11 | google.golang.org/grpc v1.47.0 12 | ) 13 | 14 | require ( 15 | github.com/davecgh/go-spew v1.1.0 // indirect 16 | github.com/go-sql-driver/mysql v1.4.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c // indirect 19 | golang.org/x/text v0.3.7 // indirect 20 | google.golang.org/appengine v1.4.0 // indirect 21 | google.golang.org/genproto v0.0.0-20220615141314-f1464d18c36b // indirect 22 | google.golang.org/protobuf v1.28.0 // indirect 23 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect 24 | ) 25 | 26 | go 1.18 27 | -------------------------------------------------------------------------------- /python/xformer/train/hooks/saving.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | 4 | import attrs 5 | import torch 6 | import yaml 7 | from 
attrs import define 8 | 9 | from xformer import loading 10 | from .. import run 11 | 12 | 13 | @define 14 | class Save(run.Hook): 15 | save_dir: str 16 | step_freq: int 17 | 18 | def after_step(self, run, stats): 19 | if stats.step % self.step_freq != 0: 20 | return 21 | self.save_run(run, stats) 22 | 23 | def after_run(self, run, stats): 24 | self.save_run(run, stats) 25 | 26 | def save_run(self, run: run.Run, stats: run.Stats): 27 | run_dir = os.path.join(self.save_dir, f"step_{stats.step:06d}") 28 | print(f"Saving to {run_dir}...") 29 | loading.save_model(run.model, run_dir) 30 | with open(os.path.join(run_dir, "stats.json"), "w") as fh: 31 | json.dump(attrs.asdict(stats), fh, indent=2) 32 | # torch.save(os.path.join(run_dir, "model.opt.pt"), run.optimizer.state_dict()) 33 | -------------------------------------------------------------------------------- /python/xformer/loading.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import torch 4 | from torch import nn 5 | import yaml 6 | 7 | from .model import Transformer 8 | 9 | 10 | def load_config(save_dir): 11 | with open(os.path.join(save_dir, "config.yaml")) as fh: 12 | return yaml.unsafe_load(fh) 13 | 14 | 15 | def load_snapshot(model: nn.Module, save_dir: str): 16 | state = torch.load(os.path.join(save_dir, "model.pt"), map_location="cpu") 17 | model.load_state_dict(state) 18 | 19 | 20 | def load_model(save_dir, device="cpu"): 21 | config = load_config(save_dir) 22 | model = Transformer(config, device=device) 23 | load_snapshot(model, save_dir) 24 | return model 25 | 26 | 27 | def save_model(model: Transformer, save_dir: str): 28 | os.makedirs(save_dir, exist_ok=True) 29 | 30 | torch.save( 31 | model.state_dict(), 32 | os.path.join(save_dir, "model.pt"), 33 | ) 34 | with open(os.path.join(save_dir, "config.yaml"), "w") as fh: 35 | yaml.dump(model.cfg, fh) 36 | -------------------------------------------------------------------------------- 
/proto/tak/proto/taktician.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tak.proto; 4 | option go_package="github.com/nelhage/taktician/pb"; 5 | 6 | service Taktician { 7 | rpc Analyze(AnalyzeRequest) returns (AnalyzeResponse) {} 8 | rpc Canonicalize(CanonicalizeRequest) returns (CanonicalizeResponse) {} 9 | rpc IsPositionInTak(IsPositionInTakRequest) returns (IsPositionInTakResponse) {} 10 | } 11 | 12 | message AnalyzeRequest { 13 | string position = 1; 14 | int32 depth = 2; 15 | bool precise = 3; 16 | } 17 | 18 | message AnalyzeResponse { 19 | repeated string pv = 1; 20 | int64 value = 2; 21 | int32 depth = 3; 22 | } 23 | 24 | message CanonicalizeRequest { 25 | int32 size = 1; 26 | repeated string moves = 2; 27 | } 28 | 29 | message CanonicalizeResponse { 30 | repeated string moves = 1; 31 | } 32 | 33 | message IsPositionInTakRequest { 34 | string position = 1; 35 | } 36 | 37 | message IsPositionInTakResponse { 38 | bool inTak = 1; 39 | string takMove = 2; 40 | } 41 | -------------------------------------------------------------------------------- /python/tak/pieces.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from attrs import define 4 | 5 | 6 | class Color(enum.Enum): 7 | WHITE = 0 8 | BLACK = 1 9 | 10 | def flip(self): 11 | return Color(1 - self.value) 12 | 13 | 14 | class Kind(enum.Enum): 15 | FLAT = 0 16 | STANDING = 1 17 | CAPSTONE = 2 18 | 19 | def is_road(self): 20 | return self == Kind.FLAT or self == Kind.CAPSTONE 21 | 22 | 23 | _piece_cache = [[None for k in Kind] for c in Color] 24 | 25 | 26 | @define(frozen=True, slots=True) 27 | class Piece(object): 28 | color: Color 29 | kind: Kind 30 | 31 | def is_road(self): 32 | return self.kind.is_road() 33 | 34 | @classmethod 35 | def _init_cache(cls): 36 | for c in Color: 37 | for k in Kind: 38 | _piece_cache[c.value][k.value] = cls(c, k) 39 | 40 | @classmethod 41 | 
def cached(self, color, kind): 42 | return _piece_cache[color.value][kind.value] 43 | 44 | 45 | Piece._init_cache() 46 | 47 | __all__ = ["Color", "Kind", "Piece"] 48 | -------------------------------------------------------------------------------- /python/tak/model/batches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from attrs import define 3 | 4 | 5 | @define 6 | class Position: 7 | data: dict[str, torch.Tensor] 8 | 9 | @property 10 | def inputs(self): 11 | return self.data["positions"][:, :-1] 12 | 13 | @property 14 | def extra_inputs(self): 15 | return (~self.mask,) 16 | 17 | @property 18 | def targets(self): 19 | return self.data["positions"][:, 1:] 20 | 21 | @property 22 | def mask(self): 23 | return self.data["mask"][:, :-1] 24 | 25 | 26 | @define 27 | class PositionValuePolicy: 28 | data: dict[str, torch.Tensor] 29 | 30 | @property 31 | def inputs(self): 32 | return self.data["positions"] 33 | 34 | @property 35 | def mask(self): 36 | return self.data["mask"] 37 | 38 | @property 39 | def extra_inputs(self): 40 | return (~self.mask,) 41 | 42 | @property 43 | def moves(self): 44 | return self.data["moves"] 45 | 46 | @property 47 | def values(self): 48 | return self.data["values"] 49 | -------------------------------------------------------------------------------- /bin/import-rankings: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | ROOT=$(readlink -f "$(dirname $0)/..") 4 | rankings=$1 5 | db=${2-$ROOT/games/games.db} 6 | 7 | sqlite3 "$db" 3<"$rankings" <1 20 | 8. d3 c2 21 | 9. b4 Sa4 22 | 10. b5 a5 23 | 11. b6 a6 24 | 12. d1 1e1+1 25 | 13. 1f3-1 1a4>1 26 | 14. c5 2e2>2 27 | 15. 1f1+1 Se2 28 | 16. 5f2+212 a4 29 | 17. c3 a3 30 | 18. Sa2 2e3>2 31 | 19. b2 4f3<13 32 | 20. 1b6<1 2b4+2 33 | 21. b4 1a5+1 34 | 22. c1 a5 35 | 23. b3 3b5-3 36 | 24. b1 4d3<4 37 | 25. d3 Se5 38 | 26. 1d4<1 4b4-4 39 | 27. d4 1e5>1 40 | 28. 2c4<11 2f5-2 41 | 29. 
2a4-2 5b3-5 42 | 30. a4 b3 43 | 31. 3a3>3 6b2>15 44 | 32. b2 c6 45 | 33. b6 Sb5 46 | 34. 1b6>1 3a6>12 47 | 35. 1c5+1 1b6>1 48 | 36. Sb6 5c6-32 49 | 37. 1b4>1 5c3+5 50 | 38. c3 4f4<13 51 | 39. a3 6c4-6 52 | 40. 1b6>1 1b5>1 53 | 41. b4 b5 54 | 42. f1 6d2<15 55 | 43. 1b3-1 1c5<1 56 | 44. 4b2+4 e1 57 | 45. 2b2>11 d5 58 | 46. 1f1<1 1e2-1 59 | 47. 2c6-2 6c3+51 60 | 48. 4c4<22 Sa6 61 | 49. e2 -------------------------------------------------------------------------------- /cmd/internal/tei/tei.go: -------------------------------------------------------------------------------- 1 | package tei 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | "os" 8 | 9 | "github.com/google/subcommands" 10 | "github.com/nelhage/taktician/cmd/internal/opt" 11 | "github.com/nelhage/taktician/tei" 12 | ) 13 | 14 | type Command struct { 15 | opt opt.Minimax 16 | } 17 | 18 | func (*Command) Name() string { return "tei" } 19 | func (*Command) Synopsis() string { return "Launch Taktician in TEI mode" } 20 | func (*Command) Usage() string { 21 | return `tei 22 | 23 | Launch the engine in TEI mode, a UCI-like protocol suitable for being 24 | driven by an external GUI or controller. 
25 | 26 | ` 27 | } 28 | 29 | func (c *Command) SetFlags(fs *flag.FlagSet) { 30 | c.opt.AddFlags(fs) 31 | } 32 | 33 | func (c *Command) Execute(ctx context.Context, flag *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { 34 | engine := tei.NewEngine(os.Stdin, os.Stdout) 35 | engine.ConfigFactory = c.opt.BuildConfig 36 | if err := engine.Run(ctx); err != nil { 37 | log.Println("tei: ", err.Error()) 38 | return subcommands.ExitFailure 39 | } 40 | 41 | return subcommands.ExitSuccess 42 | } 43 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2016 Nelson Elhage 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining 5 | a copy of this software and associated documentation files (the 6 | "Software"), to deal in the Software without restriction, including 7 | without limitation the rights to use, copy, modify, merge, publish, 8 | distribute, sublicense, and/or sell copies of the Software, and to 9 | permit persons to whom the Software is furnished to do so, subject to 10 | the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 19 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 20 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 21 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /ai/mcts/policy_test.go: -------------------------------------------------------------------------------- 1 | package mcts 2 | 3 | import ( 4 | "log" 5 | "testing" 6 | 7 | "github.com/nelhage/taktician/bitboard" 8 | "github.com/nelhage/taktician/tak" 9 | "github.com/nelhage/taktician/taktest" 10 | ) 11 | 12 | func TestFindPlaceWins(t *testing.T) { 13 | cases := []struct { 14 | board string 15 | x, y int8 16 | }{ 17 | {` 18 | . . . . 19 | . . W B 20 | . B W W 21 | B . . W 22 | `, 2, 0}, 23 | {` 24 | W B B . 25 | . W W B 26 | . B W W 27 | B . . W 28 | `, 0, 1}, 29 | {` 30 | W B B . 31 | W B B B 32 | W . W . 33 | B . W . 34 | `, 1, 2}, 35 | } 36 | for n, tc := range cases { 37 | board, err := taktest.Board(tc.board, tak.White) 38 | if err != nil { 39 | t.Errorf("%d: %v", n, err) 40 | continue 41 | } 42 | log.Printf("to move=%s", board.ToMove()) 43 | c := bitboard.Precompute(uint(board.Size())) 44 | mv := placeWinMove(&c, board) 45 | if mv.Type != tak.PlaceFlat { 46 | t.Errorf("%d: bad move: type=%s", n, mv.Type) 47 | continue 48 | } 49 | if mv.X != tc.x || mv.Y != tc.y { 50 | t.Errorf("%d: bad move: (%d, %d) != (%d, %d)", 51 | n, mv.X, mv.Y, tc.x, tc.y) 52 | } 53 | 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREFIX := github.com/nelhage/taktician 2 | 3 | PROTOS := $(wildcard proto/tak/proto/*.proto) 4 | PROTONAMES := $(foreach proto,$(PROTOS), $(basename $(notdir $(proto)))) 5 | GOPROTOSRC := $(foreach proto,$(PROTONAMES),pb/$(proto).pb.go) 6 | PYPROTOSRC := $(foreach proto,$(PROTONAMES),python/tak/proto/$(proto)_pb2.py) 7 | GENFILES := ai/feature_string.go $(GOPROTOSRC) $(PYPROTOSRC) 8 | 9 | 10 | ai/feature_string.go: ai/evaluate.go 11 | go generate $(PREFIX)/ai 12 | 13 | protoc: $(GOPROTOSRC) $(PYPROTOSRC) 14 | 15 | $(GOPROTOSRC) 
$(PYPROTOSRC): $(PROTOS) 16 | python -m grpc_tools.protoc\ 17 | -I proto/ \ 18 | --python_out=python/ \ 19 | --grpc_python_out=python/ \ 20 | --go_out=. \ 21 | --go_opt="module=$(PREFIX)" \ 22 | --go-grpc_out=. \ 23 | --go-grpc_opt="module=$(PREFIX)" \ 24 | proto/tak/proto/*.proto 25 | 26 | build: $(GENFILES) 27 | go build $(PREFIX)/... 28 | 29 | install: $(GENFILES) 30 | go install $(PREFIX)/... 31 | 32 | test: $(GENFILES) 33 | go test $(PREFIX)/... 34 | 35 | test-%: $(GENFILES) 36 | go test $(PREFIX)/$*... 37 | 38 | .PHONY: test install build protoc 39 | -------------------------------------------------------------------------------- /playtak/commands.go: -------------------------------------------------------------------------------- 1 | package playtak 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | ) 7 | 8 | type Commands struct { 9 | User string 10 | Client 11 | } 12 | 13 | func (c *Commands) SendClient(name string) { 14 | c.SendCommand("Client", name) 15 | } 16 | 17 | func (c *Commands) Login(user, pass string) error { 18 | for line := range c.Recv() { 19 | if strings.HasPrefix(line, "Login ") { 20 | break 21 | } 22 | } 23 | if pass == "" { 24 | c.SendCommand("Login", user) 25 | } else { 26 | c.SendCommand("Login", user, pass) 27 | } 28 | for line := range c.Recv() { 29 | if line == "Authentication failure" { 30 | return errors.New("bad username or password") 31 | } 32 | if strings.HasPrefix(line, "Welcome ") { 33 | c.User = user 34 | return nil 35 | } 36 | } 37 | return errors.New("login failed") 38 | } 39 | 40 | func (c *Commands) LoginGuest() error { 41 | return c.Login("Guest", "") 42 | } 43 | 44 | func (c *Commands) Shout(room, msg string) { 45 | if room == "" { 46 | c.SendCommand("Shout", msg) 47 | } else { 48 | c.SendCommand("ShoutRoom", room, msg) 49 | } 50 | } 51 | 52 | func (c *Commands) Tell(who, msg string) { 53 | c.SendCommand("Tell", who, msg) 54 | } 55 | -------------------------------------------------------------------------------- 
/python/xformer/train/hooks/test_loss.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from attrs import define 5 | 6 | from ..run import Dataset, Hook, Run, Stats 7 | 8 | 9 | @define 10 | class TestLoss(Hook): 11 | __test__ = False 12 | 13 | dataset: Dataset 14 | frequency: int 15 | 16 | def after_step(self, run: Run, stats: Stats): 17 | if stats.step > 1 and stats.step % self.frequency != 0: 18 | return 19 | 20 | with torch.no_grad(): 21 | losses = [] 22 | metrics = defaultdict(float) 23 | for batch in self.dataset: 24 | out = run.model(batch.inputs, *batch.extra_inputs) 25 | loss, batch_metrics = run.loss.loss_and_metrics(batch, out) 26 | losses.append(loss) 27 | for (k, v) in batch_metrics.items(): 28 | metrics[k] += v 29 | for (k, v) in metrics.items(): 30 | metrics[k] = v / len(losses) 31 | test_loss = torch.stack(losses).mean().item() 32 | 33 | for (k, v) in metrics.items(): 34 | stats.metrics[f"test.{k}"] = v 35 | stats.metrics["test_loss"] = test_loss 36 | -------------------------------------------------------------------------------- /cmd/internal/playtak/book.go: -------------------------------------------------------------------------------- 1 | package playtak 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/nelhage/taktician/ai" 8 | ) 9 | 10 | var books []*ai.OpeningBook 11 | 12 | const book5 = ` 13 | a1 e5 e4 14 | a1 e5 d4 15 | a1 e5 e3 16 | a1 e1 c3 17 | a1 e1 e2 d4 18 | a1 e1 e2 e3 d2 19 | a1 e1 e2 a2 e3 20 | a1 e1 e2 Ce3 d2 d3 21 | a1 e1 e2 a3 e3 22 | a1 e1 e3 e2 23 | a1 e1 d2 a2 24 | a1 e1 e4 25 | ` 26 | 27 | const book6 = ` 28 | a1 f6 e4 29 | a1 f6 d4 d3 c4 30 | a1 f1 e3 d4 31 | a1 f1 e3 e2 32 | a1 f1 e3 d3 33 | a1 f1 e3 e4 34 | a1 f1 e3 Cd4 35 | a1 f1 f2 36 | a1 f1 d4 37 | a1 f1 d3 c3 d4 38 | a1 f1 d3 d4 39 | ` 40 | 41 | func init() { 42 | books = make([]*ai.OpeningBook, 9) 43 | var e error 44 | books[5], e = ai.BuildOpeningBook(5, 
45 | strings.Split(strings.Trim(book5, " \n"), "\n")) 46 | if e == nil { 47 | books[6], e = ai.BuildOpeningBook(6, 48 | strings.Split(strings.Trim(book6, " \n"), "\n")) 49 | } 50 | if e != nil { 51 | panic(fmt.Sprintf("build: %v", e)) 52 | } 53 | } 54 | 55 | func (c *Command) wrapWithBook(size int, p ai.TakPlayer) ai.TakPlayer { 56 | if !c.book { 57 | return p 58 | } 59 | if size != 5 && size != 6 { 60 | return p 61 | } 62 | return ai.WithOpeningBook(p, books[size]) 63 | } 64 | -------------------------------------------------------------------------------- /python/tak/model/losses.py: -------------------------------------------------------------------------------- 1 | from attrs import define, field 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | class MaskedAR: 7 | def __init__(self): 8 | self.xent = torch.nn.CrossEntropyLoss(reduction="none") 9 | 10 | def train_and_metrics(self, batch, logits): 11 | return ( 12 | (self.xent(logits.permute(0, 2, 1), batch.targets) * batch.mask).mean(), 13 | {}, 14 | ) 15 | 16 | 17 | @define(slots=False) 18 | class PolicyValue: 19 | v_weight: float = 1.0 20 | policy_weight: float = 1.0 21 | 22 | def loss_and_metrics(self, batch, logits): 23 | v_logits = logits["values"] 24 | m_logits = logits["moves"] 25 | 26 | v_error = F.mse_loss(v_logits, batch.values) 27 | 28 | metrics = { 29 | "v_error": v_error.item(), 30 | } 31 | 32 | moves = batch.moves 33 | if moves.ndim == 1: 34 | with torch.no_grad(): 35 | argmax = torch.argmax(m_logits, dim=-1) 36 | match = argmax == moves 37 | metrics["acc@01"] = match.float().mean().item() 38 | 39 | return ( 40 | self.v_weight * v_error 41 | + self.policy_weight * F.cross_entropy(m_logits, moves) 42 | ), metrics 43 | -------------------------------------------------------------------------------- /testdata/zoo/_pending/76201.ptn: -------------------------------------------------------------------------------- 1 | [Size "6"] 2 | [Date "2016-08-12"] 3 | [Time "2016-08-12T16:01:10Z"] 4 | 
[Player1 "fwwwwibib"] 5 | [Player2 "Gerrek"] 6 | [Result "R-0"] 7 | [Id "76201"] 8 | 9 | 10 | 1. a1 f1 11 | 2. d3 c4 12 | 3. d4 c5 13 | 4. d5 Ce4 14 | 5. Cc3 d2 15 | 6. e3 b3 16 | 7. b2 a3 17 | 8. a2 c2 18 | 9. e2 d1 19 | 10. b4 e4- 20 | 11. c1 b1 21 | 12. f2 f3 22 | 13. c1+ 2e3< 23 | 14. c1 e3 24 | 15. e4 d1< 25 | 16. e4- d1 26 | 17. c3- a3- 27 | 18. c3 e4 28 | 19. 3c2- c6 29 | 20. c3+ a3 30 | 21. d6 Sc2 31 | 22. c3 3d3+ 32 | 23. d3 a4 33 | 24. 5c1<14 e4- 34 | 25. e2+ f3< 35 | 26. d3> Se4 36 | 27. 6e3-24 d1> 37 | 28. b5 4e1+22 38 | 29. b6 c2+ 39 | 30. 5a1+41 b3+ 40 | 31. b3 4d4< 41 | 32. 2a3+ 5c4< 42 | 33. 5a2+ 2c3<11 43 | 34. a4- d3 44 | 35. d6< d4 45 | 36. d5- c3 46 | 37. 2b3>11 e4< 47 | 38. Sf3 e4 48 | 39. f3< f4 49 | 40. 2a3+ 6b4+15 50 | 41. 2e3- 2d4- 51 | 42. e3+ 4d3<121 52 | 43. 4e2<13 6a3>42 53 | 44. 4a4> 6b3< 54 | 45. c1 6a3-42 55 | 46. Sa3 2a1> 56 | 47. a3- 3b1>21 57 | 48. 3c2- 4c3- 58 | 49. 5b4>14 b4 59 | 50. 5a2+113 b3 60 | 51. 5c1< 2c2> 61 | 52. c1+ 3d2< 62 | 53. 6b1>24 b1 63 | 54. 3a5> b6- 64 | 55. 4b6>1111 4b5+ 65 | 56. 2c1< c1 66 | 57. 5d1< d1 67 | 58. 6c1> c1 68 | 59. a4> 5b6>212 69 | 60. a4 b3+ 70 | 61. a4> 4c2< 71 | 62. a4 5b2+14 72 | 63. b3> 5b4-221 73 | 64. 
e3 74 | R-0 75 | 76 | -------------------------------------------------------------------------------- /python/test/test_mcts.py: -------------------------------------------------------------------------------- 1 | from tak import game, mcts 2 | import xformer 3 | from tak.model import heads, wrapper 4 | 5 | import torch 6 | import tak_ext 7 | 8 | 9 | def test_mcts(): 10 | cfg = xformer.Config( 11 | n_layer=1, 12 | d_model=64, 13 | d_head=32, 14 | n_ctx=128, 15 | n_vocab=256, 16 | autoregressive_mask=False, 17 | output_head=heads.PolicyValue, 18 | ) 19 | model = xformer.Transformer(cfg) 20 | 21 | engine = mcts.MCTS( 22 | config=mcts.Config( 23 | time_limit=0, 24 | simulation_limit=5, 25 | ), 26 | network=wrapper.ModelWrapper(model), 27 | ) 28 | 29 | print(engine.get_move(game.Position.from_config(game.Config(size=3)))) 30 | 31 | 32 | def test_solve_policy(): 33 | pi_theta = torch.tensor( 34 | [0.1818, 0.1651, 0.1377, 0.1367, 0.1307, 0.1033, 0.0655, 0.0558, 0.0235] 35 | ) 36 | q = torch.tensor( 37 | [-0.6232, 0.6529, 0.6529, 0.6529, 0.6529, 0.6529, 0.6529, 0.6529, 0.6529] 38 | ) 39 | lambda_n = 0.0899954085146515 40 | 41 | policy = tak_ext.solve_policy(pi_theta, q, lambda_n) 42 | 43 | py_policy = mcts.solve_policy_python(pi_theta, q, lambda_n) 44 | 45 | assert (policy >= 0).all() 46 | assert ((policy - py_policy).abs() <= 1e-2).all() 47 | -------------------------------------------------------------------------------- /tak/slide.go: -------------------------------------------------------------------------------- 1 | package tak 2 | 3 | import "fmt" 4 | 5 | // Slides is essentially a packed [8]uint4, used to represent the 6 | // slide counts in a Tak move in a space-efficient way. We store the 7 | // first drop count in (s&0xf), the next in (s&0xf0), and so on. 
8 | type Slides uint32 9 | 10 | func MkSlides(drops ...int) Slides { 11 | var out Slides 12 | for i := len(drops) - 1; i >= 0; i-- { 13 | if drops[i] > 8 { 14 | panic(fmt.Sprintf("bad drop: %#v", drops)) 15 | } 16 | out = out.Prepend(drops[i]) 17 | } 18 | return out 19 | } 20 | 21 | func (s Slides) Len() int { 22 | l := 0 23 | for s != 0 { 24 | l++ 25 | s >>= 4 26 | } 27 | return l 28 | } 29 | 30 | func (s Slides) Empty() bool { 31 | return s == 0 32 | } 33 | 34 | func (s Slides) Singleton() bool { 35 | return s > 0xf 36 | } 37 | 38 | func (s Slides) First() int { 39 | return int(s & 0xf) 40 | } 41 | 42 | func (s Slides) Prepend(next int) Slides { 43 | return (s << 4) | Slides(next) 44 | } 45 | 46 | type SlideIterator uint32 47 | 48 | func (s Slides) Iterator() SlideIterator { 49 | return SlideIterator(s) 50 | } 51 | 52 | func (s SlideIterator) Next() SlideIterator { 53 | return s >> 4 54 | } 55 | 56 | func (s SlideIterator) Ok() bool { 57 | return s != 0 58 | } 59 | 60 | func (s SlideIterator) Elem() int { 61 | return int(s & 0xf) 62 | } 63 | -------------------------------------------------------------------------------- /ai/mcts/debug.go: -------------------------------------------------------------------------------- 1 | package mcts 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "math" 8 | "os" 9 | 10 | "github.com/nelhage/taktician/ptn" 11 | ) 12 | 13 | func (mc *MonteCarloAI) dumpTree(t *tree) { 14 | f, e := os.Create(mc.cfg.DumpTree) 15 | if e != nil { 16 | log.Printf("DumpTree(%s): %v", mc.cfg.DumpTree, e) 17 | return 18 | } 19 | defer f.Close() 20 | 21 | fmt.Fprintf(f, "digraph G {\n") 22 | mc.dumpTreeNode(f, t) 23 | fmt.Fprintf(f, "}\n") 24 | } 25 | 26 | func (mc *MonteCarloAI) dumpTreeNode(f io.Writer, t *tree) { 27 | parent := 1 28 | if t.parent != nil { 29 | parent = t.parent.simulations 30 | if t.parent.proven != 0 && t.proven == 0 { 31 | return 32 | } 33 | 34 | } 35 | label := fmt.Sprintf("n=%d p=%d v=%.0f+%.0f", 36 | t.simulations, 37 | 
t.proven, 38 | float64(t.value)/float64(t.simulations), 39 | mc.cfg.C*math.Sqrt(math.Log(float64(t.simulations))/float64(parent))) 40 | 41 | fmt.Fprintf(f, ` n%p [label="%s"]`, t, label) 42 | fmt.Fprintln(f) 43 | if t.children == nil { 44 | return 45 | } 46 | 47 | for _, c := range t.children { 48 | if t.proven > 0 && c.proven >= 0 { 49 | continue 50 | } 51 | fmt.Fprintf(f, ` n%p -> n%p [label="%s"]`, 52 | t, c, ptn.FormatMove(c.move)) 53 | fmt.Fprintln(f) 54 | mc.dumpTreeNode(f, c) 55 | if c.proven < 0 { 56 | break 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /playtak/client_test.go: -------------------------------------------------------------------------------- 1 | package playtak 2 | 3 | import "testing" 4 | 5 | func TestParseShout(t *testing.T) { 6 | cases := []struct { 7 | in string 8 | who string 9 | msg string 10 | }{ 11 | {"zzzz", "", ""}, 12 | {"Shout zzzz", "", ""}, 13 | {"Shout <nelhage> hi there", "nelhage", "hi there"}, 14 | {"Shout hi there", "IRC", " hi there"}, 15 | } 16 | for i, tc := range cases { 17 | who, msg := ParseShout(tc.in) 18 | if who != tc.who { 19 | t.Errorf("[%d] got who=%q!=%q", 20 | i, who, tc.who) 21 | } 22 | if msg != tc.msg { 23 | t.Errorf("[%d] got msg=%q!=%q", 24 | i, msg, tc.msg) 25 | } 26 | } 27 | } 28 | 29 | func TestParseShoutRoom(t *testing.T) { 30 | cases := []struct { 31 | in string 32 | room string 33 | who string 34 | msg string 35 | }{ 36 | {"ShoutRoom zzzz", "", "", ""}, 37 | {"ShoutRoom Game1 <nelhage> hi there", "Game1", "nelhage", "hi there"}, 38 | {"ShoutRoom hi there", "", "", ""}, 39 | } 40 | for i, tc := range cases { 41 | room, who, msg := ParseShoutRoom(tc.in) 42 | if room != tc.room { 43 | t.Errorf("[%d] got room=%q!=%q", 44 | i, room, tc.room) 45 | } 46 | if who != tc.who { 47 | t.Errorf("[%d] got who=%q!=%q", 48 | i, who, tc.who) 49 | } 50 | if msg != tc.msg { 51 | t.Errorf("[%d] got msg=%q!=%q", 52 | i, msg, tc.msg) 53 | } 54 | } 55 | } 56 |
-------------------------------------------------------------------------------- /ai/opening_test.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | 7 | "github.com/nelhage/taktician/ptn" 8 | "github.com/nelhage/taktician/taktest" 9 | ) 10 | 11 | func TestOpeningBook(t *testing.T) { 12 | moves := []string{ 13 | `a1 f1`, 14 | `a1 f6`, 15 | } 16 | ob, err := BuildOpeningBook(6, moves) 17 | if err != nil { 18 | t.Fatal("build: ", err) 19 | } 20 | 21 | r := rand.New(rand.NewSource(1)) 22 | 23 | p := taktest.Position(6, "") 24 | m, ok := ob.GetMove(p, r) 25 | if !ok { 26 | t.Fatal("no move") 27 | } 28 | f := ptn.FormatMove(m) 29 | if f != "a1" { 30 | t.Fatal("wrong move: ", f) 31 | } 32 | 33 | p = taktest.Position(6, "f1") 34 | m, ok = ob.GetMove(p, r) 35 | if !ok { 36 | t.Fatal("no move f1") 37 | } 38 | 39 | pos := ob.book[p.Hash()] 40 | if len(pos.moves) != 2 { 41 | t.Fatal("wrong children n=", len(pos.moves)) 42 | } 43 | } 44 | 45 | func TestCollisions(t *testing.T) { 46 | ob, err := BuildOpeningBook(6, []string{`a1 f6 d4 d3 c4`}) 47 | if err != nil { 48 | t.Fatal("build ", err) 49 | } 50 | p := taktest.Position(6, "a1 f6 d4 d3") 51 | pos := ob.book[p.Hash()] 52 | if pos == nil { 53 | t.Fatal("did not store") 54 | } 55 | for _, c := range pos.moves { 56 | _, e := p.Move(c.move) 57 | if e != nil { 58 | t.Logf("children=%#v", pos.moves) 59 | t.Errorf("illegal move=%s w=%d", ptn.FormatMove(c.move), c.weight) 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /python/tak/alphazero/hooks/test_loss.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from attrs import define 5 | 6 | from xformer.data import Dataset 7 | from tak.alphazero import losses 8 | from xformer.train import LossFunction 9 | from ..trainer import Hook, 
TrainState 10 | from functools import cached_property 11 | 12 | 13 | @define(slots=False) 14 | class TestLoss(Hook): 15 | __test__ = False 16 | 17 | dataset: Dataset 18 | frequency: int 19 | loss: LossFunction = losses.ReferenceAccuracy() 20 | 21 | name: str = "test" 22 | 23 | def after_step(self, state: TrainState): 24 | if state.elapsed.step > 1 and state.elapsed.step % self.frequency != 0: 25 | return 26 | 27 | with torch.no_grad(): 28 | losses = [] 29 | metrics = defaultdict(float) 30 | for batch in self.dataset: 31 | out = state.model(batch.inputs, *batch.extra_inputs) 32 | loss, batch_metrics = self.loss.loss_and_metrics(batch, out) 33 | losses.append(loss) 34 | for (k, v) in batch_metrics.items(): 35 | metrics[k] += v 36 | for (k, v) in metrics.items(): 37 | metrics[k] = v / len(losses) 38 | test_loss = torch.stack(losses).mean().item() 39 | 40 | for (k, v) in metrics.items(): 41 | state.step_stats[f"{self.name}.{k}"] = v 42 | state.step_stats[f"{self.name}.loss"] = test_loss 43 | -------------------------------------------------------------------------------- /tak/pieces.go: -------------------------------------------------------------------------------- 1 | package tak 2 | 3 | import "fmt" 4 | 5 | type Color byte 6 | type Kind byte 7 | type Piece byte 8 | 9 | const ( 10 | White Color = 1 << 7 11 | Black Color = 1 << 6 12 | NoColor Color = 0 13 | 14 | colorMask byte = 3 << 6 15 | 16 | Flat Kind = 1 17 | Standing Kind = 2 18 | Capstone Kind = 3 19 | 20 | typeMask byte = 1<<2 - 1 21 | ) 22 | 23 | func MakePiece(color Color, kind Kind) Piece { 24 | return Piece(byte(color) | byte(kind)) 25 | } 26 | 27 | func (p Piece) Color() Color { 28 | return Color(byte(p) & colorMask) 29 | } 30 | 31 | func (p Piece) Kind() Kind { 32 | return Kind(byte(p) & typeMask) 33 | } 34 | 35 | func (p Piece) IsRoad() bool { 36 | return p.Kind() == Flat || p.Kind() == Capstone 37 | } 38 | 39 | func (p Piece) String() string { 40 | c := "" 41 | if p.Color() == White { 42 | c = "W" 43 | 
} else { 44 | c = "B" 45 | } 46 | switch p.Kind() { 47 | case Capstone: 48 | c += "C" 49 | case Standing: 50 | c += "S" 51 | } 52 | return c 53 | } 54 | 55 | func (c Color) String() string { 56 | switch c { 57 | case White: 58 | return "white" 59 | case Black: 60 | return "black" 61 | case NoColor: 62 | return "no color" 63 | default: 64 | panic(fmt.Sprintf("bad color: %x", int(c))) 65 | } 66 | } 67 | 68 | func (c Color) Flip() Color { 69 | switch c { 70 | case White: 71 | return Black 72 | case Black: 73 | return White 74 | case NoColor: 75 | return NoColor 76 | default: 77 | panic(fmt.Sprintf("bad color: %x", int(c))) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /tak/hash.go: -------------------------------------------------------------------------------- 1 | package tak 2 | 3 | import "math/rand" 4 | 5 | const ( 6 | fnvBasis = 14695981039346656037 7 | fnvPrime = 1099511628211 8 | ) 9 | 10 | var basis [64]uint64 11 | 12 | func init() { 13 | r := rand.New(rand.NewSource(0x7a3)) 14 | for i := 0; i < 64; i++ { 15 | basis[i] = uint64(r.Int63()) 16 | } 17 | } 18 | 19 | func hash8(basis uint64, b byte) uint64 { 20 | return (basis ^ uint64(b)) * fnvPrime 21 | } 22 | 23 | func hash64(basis uint64, w uint64) uint64 { 24 | h := basis 25 | h = (h ^ (w & 0xff)) * fnvPrime 26 | h = (h ^ ((w >> 8) & 0xff)) * fnvPrime 27 | h = (h ^ ((w >> 16) & 0xff)) * fnvPrime 28 | h = (h ^ (w >> 24)) * fnvPrime 29 | return h 30 | } 31 | 32 | func (p *Position) hashAt(i uint) uint64 { 33 | if p.Height[i] <= 1 { 34 | return 0 35 | } 36 | return hash64(hash8(basis[i], p.Height[i]), p.Stacks[i]) 37 | } 38 | 39 | func (p *Position) Hash() uint64 { 40 | h := p.hash 41 | h = hash64(h, p.White) 42 | h = hash64(h, p.Black) 43 | h = hash64(h, p.Standing) 44 | h = hash64(h, p.Caps) 45 | h = hash8(h, byte(p.ToMove())) 46 | return h 47 | } 48 | 49 | func (p *Position) Equal(rhs *Position) bool { 50 | if p.cfg.Size != rhs.cfg.Size || p.hash != 
rhs.hash || 51 | p.White != rhs.White || 52 | p.Black != rhs.Black || 53 | p.Standing != rhs.Standing || 54 | p.Caps != rhs.Caps || 55 | p.ToMove() != rhs.ToMove() { 56 | return false 57 | } 58 | for i := range p.Height { 59 | if p.Height[i] != rhs.Height[i] || 60 | p.Stacks[i] != rhs.Stacks[i] { 61 | return false 62 | } 63 | } 64 | return true 65 | } 66 | -------------------------------------------------------------------------------- /playtak/move_test.go: -------------------------------------------------------------------------------- 1 | package playtak 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/nelhage/taktician/tak" 8 | ) 9 | 10 | func TestParseServer(t *testing.T) { 11 | cases := []struct { 12 | in string 13 | out tak.Move 14 | }{ 15 | { 16 | "P A1", 17 | tak.Move{ 18 | X: 0, Y: 0, Type: tak.PlaceFlat, 19 | }, 20 | }, 21 | { 22 | "P H8 C", 23 | tak.Move{ 24 | X: 7, Y: 7, Type: tak.PlaceCapstone, 25 | }, 26 | }, 27 | { 28 | "P C1 W", 29 | tak.Move{ 30 | X: 2, Y: 0, Type: tak.PlaceStanding, 31 | }, 32 | }, 33 | { 34 | "M C1 C3 4 1", 35 | tak.Move{ 36 | X: 2, Y: 0, Type: tak.SlideUp, 37 | Slides: tak.MkSlides(4, 1), 38 | }, 39 | }, 40 | { 41 | "M D2 E2 1", 42 | tak.Move{ 43 | X: 3, Y: 1, Type: tak.SlideRight, 44 | Slides: tak.MkSlides(1), 45 | }, 46 | }, 47 | { 48 | "M D4 D1 1 1 1", 49 | tak.Move{ 50 | X: 3, Y: 3, Type: tak.SlideDown, 51 | Slides: tak.MkSlides(1, 1, 1), 52 | }, 53 | }, 54 | { 55 | "M D4 A4 3 1 1", 56 | tak.Move{ 57 | X: 3, Y: 3, Type: tak.SlideLeft, 58 | Slides: tak.MkSlides(3, 1, 1), 59 | }, 60 | }, 61 | } 62 | for _, tc := range cases { 63 | m, e := ParseServer(tc.in) 64 | if e != nil { 65 | t.Errorf("parse(%s): %v", tc.in, e) 66 | continue 67 | } 68 | if !reflect.DeepEqual(m, tc.out) { 69 | t.Errorf("parse(%s) = %#v not %#v", tc.in, m, tc.out) 70 | } 71 | back := FormatServer(m) 72 | if back != tc.in { 73 | t.Errorf("round-trip(%s) = %s", tc.in, back) 74 | } 75 | } 76 | } 77 | 
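The table in `TestParseServer` above effectively documents the playtak.com wire format for moves: `P <square> [C|W]` places a flat, capstone, or standing stone, and `M <from> <to> <drops...>` slides a stack, with the direction implied by the two squares. As a rough illustration only, the same mapping can be sketched in Python. The `Move` tuple, the string type labels, and the `parse_server` name here are hypothetical stand-ins for `tak.Move` and `ParseServer`, not the Go implementation.

```python
from typing import NamedTuple, Tuple


class Move(NamedTuple):
    """Hypothetical stand-in for tak.Move; `type` uses string labels."""
    x: int
    y: int
    type: str
    slides: Tuple[int, ...] = ()


def parse_square(sq: str) -> Tuple[int, int]:
    # "A1" -> (0, 0): the file letter gives x, the rank number gives y.
    return ord(sq[0].upper()) - ord("A"), int(sq[1:]) - 1


def parse_server(line: str) -> Move:
    words = line.split()
    if not words:
        raise ValueError("empty move")
    if words[0] == "P":
        x, y = parse_square(words[1])
        suffix = words[2] if len(words) > 2 else ""
        kinds = {"": "PlaceFlat", "C": "PlaceCapstone", "W": "PlaceStanding"}
        return Move(x, y, kinds[suffix])
    if words[0] == "M":
        x, y = parse_square(words[1])
        ex, ey = parse_square(words[2])
        # The slide direction is inferred from the start and end squares.
        if x == ex and ey > y:
            kind = "SlideUp"
        elif x == ex and ey < y:
            kind = "SlideDown"
        elif y == ey and ex > x:
            kind = "SlideRight"
        elif y == ey and ex < x:
            kind = "SlideLeft"
        else:
            raise ValueError(f"not a straight slide: {line!r}")
        return Move(x, y, kind, tuple(int(w) for w in words[3:]))
    raise ValueError(f"unknown move: {line!r}")
```

Under these assumptions, `parse_server("M C1 C3 4 1")` yields a move from `(2, 0)` sliding up with drops `(4, 1)`, matching the corresponding Go test case.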
-------------------------------------------------------------------------------- /python/tak/alphazero/config.py: -------------------------------------------------------------------------------- 1 | from attrs import define, field 2 | from tak import mcts 3 | import torch 4 | import secrets 5 | from functools import partial 6 | from typing import Optional, TYPE_CHECKING 7 | import xformer 8 | 9 | if TYPE_CHECKING: 10 | from .trainer import Hook 11 | 12 | 13 | def default_hooks() -> list["Hook"]: 14 | from . import hooks 15 | 16 | return [ 17 | hooks.TimingHook(), 18 | ] 19 | 20 | 21 | @define(slots=False) 22 | class Config: 23 | model: xformer.Config 24 | 25 | device: str = "cuda" 26 | server_port: int = 5432 27 | 28 | run_dir: Optional[str] = None 29 | load_model: Optional[str] = None 30 | 31 | lr: float = 1e-3 32 | 33 | size: int = 3 34 | 35 | rollout_config: mcts.Config = field( 36 | factory=lambda: mcts.Config( 37 | simulation_limit=25, 38 | root_noise_alpha=1.0, 39 | root_noise_mix=0.25, 40 | ) 41 | ) 42 | 43 | rollout_resignation_threshold: float = 0.95 44 | rollout_ply_limit: int = 100 45 | 46 | rollout_workers: int = 50 47 | rollouts_per_step: int = 100 48 | replay_buffer_steps: int = 4 49 | 50 | train_batch: int = 64 51 | train_positions: int = 1024 52 | 53 | train_dtype: torch.dtype = torch.float32 54 | serve_dtype: torch.dtype = torch.float16 55 | 56 | train_steps: int = 10 57 | 58 | hooks: list["Hook"] = field(factory=default_hooks) 59 | 60 | def __attrs_post_init__(self): 61 | if self.device == "cpu": 62 | self.serve_dtype = torch.float32 63 | -------------------------------------------------------------------------------- /tests/bench_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "testing" 7 | 8 | "github.com/nelhage/taktician/ai" 9 | "github.com/nelhage/taktician/ptn" 10 | "github.com/nelhage/taktician/tak" 11 | ) 12 | 13 | var seed = 
flag.Int64("seed", 4, "random seed") 14 | 15 | func BenchmarkMoveEmpty(b *testing.B) { 16 | p := tak.New(tak.Config{Size: 5}) 17 | n := tak.New(tak.Config{Size: 5}) 18 | ms := p.AllMoves(nil) 19 | b.ReportAllocs() 20 | b.ResetTimer() 21 | for i := 0; i < b.N; i++ { 22 | for { 23 | _, e := p.MovePreallocated(ms[i%len(ms)], n) 24 | if e == nil { 25 | break 26 | } 27 | } 28 | } 29 | } 30 | 31 | func BenchmarkMoveComplex(b *testing.B) { 32 | p, e := ptn.ParseTPS("112S,12,1112S,x2/x2,121C,12S,x/1,21,2,2,2/x,2,1,1,1/2,x3,21 2 24") 33 | n := tak.New(tak.Config{Size: 5}) 34 | if e != nil { 35 | panic("bad tps") 36 | } 37 | ms := p.AllMoves(nil) 38 | b.ReportAllocs() 39 | b.ResetTimer() 40 | for i := 0; i < b.N; i++ { 41 | j := i 42 | for { 43 | _, e := p.MovePreallocated(ms[j%len(ms)], n) 44 | if e == nil { 45 | break 46 | } 47 | j++ 48 | } 49 | } 50 | } 51 | 52 | func BenchmarkPuzzle1(b *testing.B) { 53 | p, e := ptn.ParseTPS("2,x2,121C,1/x2,2,12,1/x2,2,12S,2/x3,1,1/x4,1 1 2") 54 | if e != nil { 55 | panic("bad tps") 56 | } 57 | 58 | b.ReportAllocs() 59 | b.ResetTimer() 60 | 61 | for i := 0; i < b.N; i++ { 62 | mm := ai.NewMinimax(ai.MinimaxConfig{ 63 | Depth: 7, 64 | Seed: *seed, 65 | Size: p.Size(), 66 | }) 67 | mm.GetMove(context.Background(), p) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /cmd/internal/importptn/sql.go: -------------------------------------------------------------------------------- 1 | package importptn 2 | 3 | const createPTNTable = ` 4 | CREATE TABLE IF NOT EXISTS ptns ( 5 | id integer primary key, 6 | ptn string 7 | ) 8 | ` 9 | 10 | /* 11 | CREATE TABLE games ( 12 | id INTEGER PRIMARY KEY, 13 | date INT, 14 | size INT, 15 | player_white VARCHAR(20), 16 | player_black VARCHAR(20), 17 | notation TEXT, 18 | result VARCAR(10), 19 | timertime INT DEFAULT 0, 20 | timerinc INT DEFAULT 0, 21 | rating_white int default 1000, 22 | rating_black int default 1000, 23 | unrated int default 0, 24 | 
tournament int default 0, 25 | komi int default 0, 26 | pieces int default -1, 27 | capstones int default -1, 28 | rating_change_white int default 0, 29 | rating_change_black int default 0); 30 | */ 31 | 32 | type gameRow struct { 33 | Id int `db:"id"` 34 | Date int `db:"date"` 35 | Size int `db:"size"` 36 | 37 | PlayerWhite string `db:"player_white"` 38 | PlayerBlack string `db:"player_black"` 39 | 40 | Notation string `db:"notation"` 41 | Result string `db:"result"` 42 | 43 | TimerTime int `db:"timertime"` 44 | TimerInc int `db:"timerinc"` 45 | } 46 | 47 | type ptnRow struct { 48 | Id int `db:"id"` 49 | PTN string `db:"ptn"` 50 | } 51 | 52 | const selectTODO = ` 53 | SELECT g.id, g.date, g.size, g.player_white, g.player_black, g.notation, g.result, g.timertime, g.timerinc 54 | FROM games g LEFT OUTER JOIN ptns p 55 | ON (g.id = p.id) 56 | WHERE p.id is NULL 57 | AND g.notation IS NOT NULL 58 | AND g.notation != "" 59 | ` 60 | 61 | const insertPTN = ` 62 | INSERT INTO ptns (id, ptn) 63 | VALUES (:id, :ptn) 64 | ` 65 | -------------------------------------------------------------------------------- /cmd/internal/canonicalize/main.go: -------------------------------------------------------------------------------- 1 | package canonicalize 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "strconv" 9 | 10 | "github.com/google/subcommands" 11 | "github.com/nelhage/taktician/ptn" 12 | "github.com/nelhage/taktician/symmetry" 13 | "github.com/nelhage/taktician/tak" 14 | ) 15 | 16 | type Command struct{} 17 | 18 | func (*Command) Name() string { return "canonicalize" } 19 | func (*Command) Synopsis() string { return "Canonicalize the symmetry of a PTN" } 20 | func (*Command) Usage() string { 21 | return `canonicalize FILE.ptn 22 | 23 | Rewrite a PTN into a symmetric PTN in a canonical orientation. 
24 | ` 25 | } 26 | 27 | func (c *Command) SetFlags(flags *flag.FlagSet) { 28 | } 29 | 30 | func (c *Command) Execute(ctx context.Context, flag *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { 31 | if len(flag.Args()) == 0 { 32 | flag.Usage() 33 | return subcommands.ExitUsageError 34 | } 35 | 36 | g, e := ptn.ParseFile(flag.Arg(0)) 37 | if e != nil { 38 | log.Fatalf("read %s: %v", flag.Arg(0), e) 39 | } 40 | 41 | var ms []tak.Move 42 | for _, o := range g.Ops { 43 | if m, ok := o.(*ptn.Move); ok { 44 | ms = append(ms, m.Move) 45 | } 46 | } 47 | 48 | sz, e := strconv.ParseUint(g.FindTag("Size"), 10, 32) 49 | if e != nil { 50 | log.Fatalf("bad size: %v", e) 51 | } 52 | out, e := symmetry.Canonical(int(sz), ms) 53 | if e != nil { 54 | log.Fatalf("canonicalize: %v", e) 55 | } 56 | 57 | i := 0 58 | for _, o := range g.Ops { 59 | if m, ok := o.(*ptn.Move); ok { 60 | m.Move = out[i] 61 | i++ 62 | } 63 | } 64 | 65 | fmt.Print(g.Render()) 66 | return subcommands.ExitSuccess 67 | } 68 | -------------------------------------------------------------------------------- /python/xformer/train/hooks/profile.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import typing as T 3 | from functools import partial 4 | 5 | from attrs import define, field 6 | from torch.profiler import ProfilerAction, profile 7 | 8 | from ..run import Hook, Run, Stats 9 | 10 | 11 | @define 12 | class Profile(Hook): 13 | extra_steps: set[int] = field(factory=set) 14 | every: T.Optional[int] = None 15 | output_root: str = "profile/" 16 | 17 | profiler: profile = field(init=False) 18 | 19 | def before_run(self, run: Run): 20 | run.profiler = profile( 21 | schedule=self.schedule, 22 | with_stack=True, 23 | on_trace_ready=partial(self.save_profile, run=run), 24 | ) 25 | run.profiler.start() 26 | 27 | def should_profile(self, step): 28 | if step in self.extra_steps: 29 | return True 30 | if self.every is not None: 31 | return step % self.every
== 0 32 | return False 33 | 34 | def schedule(self, step): 35 | if self.should_profile(step): 36 | print(f"Profiling step {step}...") 37 | return ProfilerAction.RECORD_AND_SAVE 38 | if self.should_profile(step + 1): 39 | return ProfilerAction.WARMUP 40 | return ProfilerAction.NONE 41 | 42 | def save_profile(self, prof, run): 43 | os.makedirs(self.output_root, 0o755, True) 44 | prof.export_chrome_trace( 45 | os.path.join( 46 | self.output_root, f"step_{run.profiler.step_num-1}.pt.trace.json" 47 | ) 48 | ) 49 | 50 | def after_step(self, run: Run, stats: Stats): 51 | run.profiler.step() 52 | -------------------------------------------------------------------------------- /python/bench/probs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from functools import partial 4 | 5 | import tak 6 | from tak import mcts 7 | from tak.model import grpc 8 | from attrs import define, field 9 | 10 | import concurrent.futures 11 | 12 | import torch 13 | 14 | import time 15 | 16 | 17 | def main(argv): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--iterations", 21 | dest="iterations", 22 | type=int, 23 | default=100, 24 | ) 25 | parser.add_argument( 26 | "--size", 27 | dest="size", 28 | type=int, 29 | default=3, 30 | metavar="SIZE", 31 | ) 32 | parser.add_argument( 33 | "--host", 34 | type=str, 35 | default="localhost", 36 | ) 37 | parser.add_argument( 38 | "--port", 39 | type=int, 40 | default=5001, 41 | ) 42 | 43 | args = parser.parse_args(argv) 44 | 45 | network = grpc.GRPCNetwork(host=args.host, port=args.port) 46 | 47 | engine = mcts.MCTS( 48 | mcts.Config( 49 | network=network, 50 | simulation_limit=1, 51 | time_limit=0, 52 | ) 53 | ) 54 | 55 | tree = engine.analyze(tak.Position.from_config(tak.Config(size=args.size))) 56 | 57 | start = time.perf_counter() 58 | for _ in range(args.iterations): 59 | engine.tree_probs(tree) 60 | 61 | end = time.perf_counter() 62 | 63 | print( 64 | 
f"done loops={args.iterations}" 65 | + f" duration={end-start:.2f}" 66 | + f" us/loop={1_000_000*(end-start)/args.iterations:.2f}" 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | main(sys.argv[1:]) 72 | -------------------------------------------------------------------------------- /python/test/test_move.py: -------------------------------------------------------------------------------- 1 | import tak 2 | 3 | 4 | class TestMove(object): 5 | def test_is_slide(self): 6 | assert not tak.MoveType.PLACE_FLAT.is_slide() 7 | assert not tak.MoveType.PLACE_STANDING.is_slide() 8 | assert not tak.MoveType.PLACE_CAPSTONE.is_slide() 9 | assert tak.MoveType.SLIDE_LEFT.is_slide() 10 | assert tak.MoveType.SLIDE_RIGHT.is_slide() 11 | assert tak.MoveType.SLIDE_UP.is_slide() 12 | assert tak.MoveType.SLIDE_DOWN.is_slide() 13 | 14 | def test_direction(self): 15 | assert tak.MoveType.SLIDE_LEFT.direction() == (-1, 0) 16 | assert tak.MoveType.SLIDE_RIGHT.direction() == (1, 0) 17 | assert tak.MoveType.SLIDE_UP.direction() == (0, 1) 18 | assert tak.MoveType.SLIDE_DOWN.direction() == (0, -1) 19 | 20 | 21 | def test_all_moves(): 22 | moves = set(tak.all_moves_for_size(5)) 23 | 24 | assert tak.Move(0, 0) in moves 25 | assert tak.Move(0, 4) in moves 26 | assert tak.Move(4, 4) in moves 27 | 28 | assert tak.Move(0, 0, tak.MoveType.PLACE_CAPSTONE) in moves 29 | assert tak.Move(0, 0, tak.MoveType.PLACE_STANDING) in moves 30 | 31 | assert tak.Move(5, 5) not in moves 32 | 33 | assert tak.Move(0, 0, tak.MoveType.SLIDE_RIGHT, (1, 1, 1, 1)) in moves 34 | assert tak.Move(1, 0, tak.MoveType.SLIDE_RIGHT, (1, 1, 1)) in moves 35 | assert tak.Move(2, 0, tak.MoveType.SLIDE_RIGHT, (1, 1, 1)) not in moves 36 | 37 | assert all(m.type != tak.MoveType.SLIDE_LEFT for m in moves if m.x == 0) 38 | assert all(m.type != tak.MoveType.SLIDE_RIGHT for m in moves if m.x == 4) 39 | assert all(m.type != tak.MoveType.SLIDE_UP for m in moves if m.y == 4) 40 | assert all(m.type != tak.MoveType.SLIDE_DOWN for m in moves 
if m.y == 0) 41 | -------------------------------------------------------------------------------- /ptn/iterator.go: -------------------------------------------------------------------------------- 1 | package ptn 2 | 3 | import "github.com/nelhage/taktician/tak" 4 | 5 | type Iterator struct { 6 | ptn *PTN 7 | i int 8 | 9 | err error 10 | over bool 11 | 12 | initial bool 13 | position *tak.Position 14 | ptnMove int 15 | lastMove, move tak.Move 16 | } 17 | 18 | func (p *PTN) Iterator() *Iterator { 19 | pos, err := p.InitialPosition() 20 | return &Iterator{ 21 | ptn: p, 22 | position: pos, 23 | err: err, 24 | initial: true, 25 | } 26 | } 27 | 28 | func (i *Iterator) Err() error { 29 | return i.err 30 | } 31 | 32 | func (i *Iterator) apply() bool { 33 | next, e := i.position.Move(i.move) 34 | if e != nil { 35 | i.err = e 36 | return false 37 | } 38 | i.position = next 39 | i.lastMove = i.move 40 | i.move = tak.Move{} 41 | return true 42 | } 43 | 44 | func (i *Iterator) Next() bool { 45 | if i.err != nil || i.over { 46 | return false 47 | } 48 | 49 | if i.move.Type != 0 { 50 | if !i.apply() { 51 | return false 52 | } 53 | if ok, _ := i.position.GameOver(); ok { 54 | i.over = true 55 | return true 56 | } 57 | } 58 | 59 | for i.i < len(i.ptn.Ops) { 60 | op := i.ptn.Ops[i.i] 61 | i.i++ 62 | switch o := op.(type) { 63 | case *MoveNumber: 64 | i.ptnMove = o.Number 65 | case *Move: 66 | i.move = o.Move 67 | return true 68 | } 69 | } 70 | i.over = true 71 | if i.move.Type != 0 { 72 | return i.apply() 73 | } 74 | return true 75 | } 76 | 77 | func (i *Iterator) Position() *tak.Position { 78 | return i.position 79 | } 80 | 81 | func (i *Iterator) PTNMove() int { 82 | return i.ptnMove 83 | } 84 | 85 | func (i *Iterator) Move() tak.Move { 86 | return i.lastMove 87 | } 88 | 89 | func (i *Iterator) PeekMove() tak.Move { 90 | return i.move 91 | } 92 | -------------------------------------------------------------------------------- /ptn/iterator_test.go: 
-------------------------------------------------------------------------------- 1 | package ptn 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/nelhage/taktician/tak" 8 | ) 9 | 10 | func TestIterator(t *testing.T) { 11 | type step struct { 12 | ply int 13 | ptnMove int 14 | color tak.Color 15 | } 16 | cases := []struct { 17 | ptn string 18 | iters []step 19 | }{ 20 | { 21 | ` 22 | [Size "5"] 23 | [TPS "2,x2,121C,1/x2,2,12,1/x2,2,12S,2/x3,1,1/x4,1 1 2"] 24 | 25 | 1. 26 | `, 27 | []step{ 28 | {2, 1, tak.White}, 29 | }, 30 | }, 31 | {` 32 | [Size "5"] 33 | 34 | 1. a1 e5 35 | 2. e1 a2 36 | `, 37 | []step{ 38 | {0, 1, tak.White}, 39 | {1, 1, tak.Black}, 40 | {2, 2, tak.White}, 41 | {3, 2, tak.Black}, 42 | {4, 2, tak.White}, 43 | }, 44 | }, 45 | {` 46 | [Size "5"] 47 | 48 | `, 49 | []step{ 50 | {0, 0, tak.White}, 51 | }, 52 | }, 53 | } 54 | for i, tc := range cases { 55 | ptn, e := ParsePTN(bytes.NewBufferString(tc.ptn)) 56 | if e != nil { 57 | t.Errorf("[%d] %v", i, e) 58 | continue 59 | } 60 | it := ptn.Iterator() 61 | ct := 0 62 | for it.Next() { 63 | if ct >= len(tc.iters) { 64 | t.Errorf("[%d] too many results ply=%d", 65 | i, it.Position().MoveNumber()) 66 | break 67 | } 68 | expect := tc.iters[ct] 69 | ct++ 70 | if c := it.Position().ToMove(); c != expect.color { 71 | t.Errorf("[%d] .%d: wrong color %s != %s", 72 | i, ct, c, expect.color, 73 | ) 74 | } 75 | if m := it.PTNMove(); m != expect.ptnMove { 76 | t.Errorf("[%d] .%d: wrong PTN %d != %d", 77 | i, ct, m, expect.ptnMove, 78 | ) 79 | } 80 | if ply := it.Position().MoveNumber(); ply != expect.ply { 81 | t.Errorf("[%d] .%d: wrong ply %d != %d", 82 | i, ct, ply, expect.ply, 83 | ) 84 | } 85 | } 86 | if ct < len(tc.iters) { 87 | t.Errorf("[%d] too few results %d < %d", i, ct, len(tc.iters)) 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /ai/moves_test.go: -------------------------------------------------------------------------------- 1 | 
package ai 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | 7 | "github.com/nelhage/taktician/ptn" 8 | "github.com/nelhage/taktician/tak" 9 | "github.com/nelhage/taktician/taktest" 10 | ) 11 | 12 | func TestMoveGenerator(t *testing.T) { 13 | p, _ := ptn.ParseTPS("1,1,x3/x,1,x,2,x/x,2,1C,x2/x,2,1,x2/2,2,1,x2 2 6") 14 | pvm := taktest.Move("Cc4") 15 | tem := taktest.Move("c4") 16 | cm := taktest.Move("b1>") 17 | te := tableEntry{ 18 | m: tem, 19 | } 20 | 21 | ai := NewMinimax(MinimaxConfig{Size: 5}) 22 | ai.rand = rand.New(rand.NewSource(7)) 23 | ai.history[cm] = 100 24 | 25 | mg := &ai.stack[1].mg 26 | *mg = moveGenerator{ 27 | p: p, 28 | f: &ai.stack[1], 29 | ai: ai, 30 | ply: 1, 31 | depth: 5, 32 | 33 | te: &te, 34 | pv: []tak.Move{pvm}, 35 | } 36 | 37 | allS := make(map[string]struct{}) 38 | all := p.AllMoves(nil) 39 | for _, a := range all { 40 | if _, e := p.Move(a); e == nil { 41 | allS[ptn.FormatMove(a)] = struct{}{} 42 | } 43 | } 44 | 45 | var generated []tak.Move 46 | genS := make(map[string]struct{}) 47 | for { 48 | m, c := mg.Next() 49 | if c == nil { 50 | break 51 | } 52 | generated = append(generated, m) 53 | genS[ptn.FormatMove(m)] = struct{}{} 54 | } 55 | 56 | if g := generated[0]; !g.Equal(tem) { 57 | t.Errorf("move[0]=%s != %s", 58 | ptn.FormatMove(g), ptn.FormatMove(tem)) 59 | } 60 | if g := generated[1]; !g.Equal(pvm) { 61 | t.Errorf("move[1]=%s != %s", 62 | ptn.FormatMove(g), ptn.FormatMove(pvm)) 63 | } 64 | if g := generated[2]; !g.Equal(cm) { 65 | t.Errorf("move[2]=%s != %s", 66 | ptn.FormatMove(g), ptn.FormatMove(cm)) 67 | } 68 | 69 | for g := range genS { 70 | if _, ok := allS[g]; !ok { 71 | t.Errorf("generated additional move %s", g) 72 | } 73 | } 74 | for a := range allS { 75 | if _, ok := genS[a]; !ok { 76 | t.Errorf("generate missed move %s", a) 77 | } 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /bitboard/bits.go: 
-------------------------------------------------------------------------------- 1 | package bitboard 2 | 3 | type Constants struct { 4 | Size uint 5 | L, R, T, B uint64 6 | Edge uint64 7 | Mask uint64 8 | } 9 | 10 | func Precompute(size uint) Constants { 11 | var c Constants 12 | for i := uint(0); i < size; i++ { 13 | c.R |= 1 << (i * size) 14 | } 15 | c.Size = size 16 | c.L = c.R << (size - 1) 17 | c.T = ((1 << size) - 1) << (size * (size - 1)) 18 | c.B = (1 << size) - 1 19 | c.Mask = 1<<(size*size) - 1 20 | c.Edge = c.L | c.R | c.B | c.T 21 | return c 22 | } 23 | 24 | func Flood(c *Constants, within uint64, seed uint64) uint64 { 25 | for { 26 | next := Grow(c, within, seed) 27 | if next == seed { 28 | return next 29 | } 30 | seed = next 31 | } 32 | } 33 | 34 | func Grow(c *Constants, within uint64, seed uint64) uint64 { 35 | next := seed 36 | next |= (seed << 1) &^ c.R 37 | next |= (seed >> 1) &^ c.L 38 | next |= (seed >> c.Size) 39 | next |= (seed << c.Size) 40 | return next & within 41 | } 42 | 43 | func FloodGroups(c *Constants, bits uint64, out []uint64) []uint64 { 44 | var seen uint64 45 | for bits != 0 { 46 | next := bits & (bits - 1) 47 | bit := bits &^ next 48 | 49 | if seen&bit == 0 { 50 | g := Flood(c, bits, bit) 51 | if g != bit { 52 | out = append(out, g) 53 | } 54 | seen |= g 55 | } 56 | 57 | bits = next 58 | } 59 | return out 60 | } 61 | 62 | func Dimensions(c *Constants, bits uint64) (w, h int) { 63 | if bits == 0 { 64 | return 0, 0 65 | } 66 | b := c.L 67 | for bits&b == 0 { 68 | b >>= 1 69 | } 70 | for b != 0 && bits&b != 0 { 71 | b >>= 1 72 | w++ 73 | } 74 | b = c.T 75 | for bits&b == 0 { 76 | b >>= c.Size 77 | } 78 | for b != 0 && bits&b != 0 { 79 | b >>= c.Size 80 | h++ 81 | } 82 | return w, h 83 | } 84 | 85 | func BitCoords(c *Constants, bits uint64) (x, y uint) { 86 | if bits == 0 || bits&(bits-1) != 0 { 87 | panic("BitCoords: non-singular") 88 | } 89 | n := TrailingZeros(bits) 90 | y = n / c.Size 91 | x = n % c.Size 92 | return x, y 93 | } 
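To make the bit tricks in `bitboard/bits.go` easier to follow, here is a small Python transliteration of `Precompute`, `Grow`, and `Flood`. It is a sketch for illustration rather than part of the repository. The snake_case names and the `Constants` dataclass are assumptions, and Python's unbounded integers stand in for Go's `uint64`. Square `(x, y)` lives at bit `y*size + x`; masking with `R` and `L` stops a one-bit shift from wrapping into the neighbouring row.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Constants:
    size: int
    L: int
    R: int
    T: int
    B: int
    mask: int


def precompute(size: int) -> Constants:
    # R marks bit 0 of every row; L marks the last bit of every row.
    r = 0
    for i in range(size):
        r |= 1 << (i * size)
    return Constants(
        size=size,
        L=r << (size - 1),
        R=r,
        T=((1 << size) - 1) << (size * (size - 1)),
        B=(1 << size) - 1,
        mask=(1 << (size * size)) - 1,
    )


def grow(c: Constants, within: int, seed: int) -> int:
    # Expand `seed` by one step in each of the four directions, then clip
    # to `within`. The &~R / &~L masks drop bits that crossed a row edge.
    nxt = seed
    nxt |= (seed << 1) & ~c.R
    nxt |= (seed >> 1) & ~c.L
    nxt |= seed >> c.size
    nxt |= seed << c.size
    return nxt & within


def flood(c: Constants, within: int, seed: int) -> int:
    # Repeat grow() until nothing changes: the connected group containing seed.
    while True:
        nxt = grow(c, within, seed)
        if nxt == seed:
            return nxt
        seed = nxt


c = precompute(3)
# Occupied squares: a 2x2 block in one corner plus one isolated square.
occupied = 0b100011011
group = flood(c, occupied, 0b1)
```

In this 3x3 example `group` comes out as `0b000011011`: the flood covers the 2x2 block and never reaches the isolated square at bit 8, which is the property `FloodGroups` depends on when it splits a bitboard into connected groups.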
94 | -------------------------------------------------------------------------------- /python/ext/tak.cpp: -------------------------------------------------------------------------------- 1 | #include <algorithm> 2 | #include <cmath> 3 | #include <limits> 4 | #include <stdexcept> 5 | 6 | #include <torch/extension.h> 7 | 8 | using std::min; 9 | using std::max; 10 | using std::abs; 11 | 12 | constexpr float SIGMA_EPSILON = 1e-3; 13 | 14 | torch::Tensor solve_policy(torch::Tensor pi_theta, torch::Tensor q, float lambda_n) { 15 | auto pi_theta_a = pi_theta.accessor<float, 1>(); 16 | auto q_a = q.accessor<float, 1>(); 17 | 18 | auto len = pi_theta.sizes()[0]; 19 | 20 | float alpha_min = -std::numeric_limits<float>::infinity(); 21 | float alpha_max = -std::numeric_limits<float>::infinity(); 22 | for (int i = 0; i < len; i++) { 23 | alpha_min = max(alpha_min, q_a[i] + lambda_n * pi_theta_a[i]); 24 | alpha_max = max(alpha_max, q_a[i] + lambda_n); 25 | } 26 | 27 | float alpha = (alpha_min + alpha_max)/2; 28 | float last_sum = std::numeric_limits<float>::infinity(); 29 | for (int loops = 0; loops < 32; loops++) { 30 | float sum = 0.0; 31 | for (int i = 0; i < len; i++) { 32 | sum += lambda_n * pi_theta_a[i] / (alpha - q_a[i]); 33 | } 34 | /* 35 | printf("c++ i=%d alpha_bounds=%.2f,%.2f alpha=%.2f sigma=%.2f\n", 36 | loops, 37 | alpha_min, 38 | alpha_max, 39 | alpha, 40 | sum); 41 | */ 42 | float error = sum - 1.0; 43 | if (abs(error) <= SIGMA_EPSILON or sum == last_sum) { 44 | return lambda_n * pi_theta / (alpha - q); 45 | } 46 | last_sum = sum; 47 | if (sum > 1) { 48 | alpha_min = alpha; 49 | alpha = (alpha + alpha_max) / 2; 50 | } else { 51 | alpha_max = alpha; 52 | alpha = (alpha + alpha_min) / 2; 53 | } 54 | } 55 | 56 | throw std::runtime_error("alpha search did not converge"); 57 | } 58 | 59 | 60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 61 | m.def("solve_policy", &solve_policy, "Solve for the regularized MCTS policy."); 62 | } 63 | -------------------------------------------------------------------------------- /python/tak/proto/corpus_entry_pb2.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: tak/proto/corpus_entry.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1ctak/proto/corpus_entry.proto\x12\ttak.proto\"\xdc\x01\n\x0b\x43orpusEntry\x12\x0b\n\x03\x64\x61y\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x0b\n\x03ply\x18\x03 \x01(\x05\x12\x0b\n\x03tps\x18\x04 \x01(\t\x12\x0c\n\x04move\x18\x05 \x01(\t\x12\r\n\x05value\x18\x06 \x01(\x02\x12\r\n\x05plies\x18\x07 \x01(\x05\x12\x10\n\x08\x66\x65\x61tures\x18\x08 \x03(\x03\x12,\n\x06in_tak\x18\t \x01(\x0e\x32\x1c.tak.proto.CorpusEntry.InTak\".\n\x05InTak\x12\t\n\x05UNSET\x10\x00\x12\x0e\n\nNOT_IN_TAK\x10\x01\x12\n\n\x06IN_TAK\x10\x02\x42!Z\x1fgithub.com/nelhage/taktician/pbb\x06proto3') 18 | 19 | 20 | 21 | _CORPUSENTRY = DESCRIPTOR.message_types_by_name['CorpusEntry'] 22 | _CORPUSENTRY_INTAK = _CORPUSENTRY.enum_types_by_name['InTak'] 23 | CorpusEntry = _reflection.GeneratedProtocolMessageType('CorpusEntry', (_message.Message,), { 24 | 'DESCRIPTOR' : _CORPUSENTRY, 25 | '__module__' : 'tak.proto.corpus_entry_pb2' 26 | # @@protoc_insertion_point(class_scope:tak.proto.CorpusEntry) 27 | }) 28 | _sym_db.RegisterMessage(CorpusEntry) 29 | 30 | if _descriptor._USE_C_DESCRIPTORS == False: 31 | 32 | DESCRIPTOR._options = None 33 | DESCRIPTOR._serialized_options = b'Z\037github.com/nelhage/taktician/pb' 34 | _CORPUSENTRY._serialized_start=44 35 | 
_CORPUSENTRY._serialized_end=264 36 | _CORPUSENTRY_INTAK._serialized_start=218 37 | _CORPUSENTRY_INTAK._serialized_end=264 38 | # @@protoc_insertion_point(module_scope) 39 | -------------------------------------------------------------------------------- /python/tak/symmetry/symmetry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import tak 4 | 5 | rot = np.array( 6 | [[0, 1, 0], [-1, 0, 1], [0, 0, 1]], 7 | dtype=int, 8 | ) 9 | flip = np.array( 10 | [ 11 | [-1, 0, 1], 12 | [0, 1, 0], 13 | [0, 0, 1], 14 | ], 15 | dtype=int, 16 | ) 17 | 18 | SYMMETRIES = [ 19 | np.matmul(l, r) 20 | for l in [ 21 | np.identity(3, dtype=int), 22 | rot, 23 | np.matmul(rot, rot), 24 | np.matmul(np.matmul(rot, rot), rot), 25 | ] 26 | for r in [np.identity(3, dtype=int), flip] 27 | ] 28 | 29 | assert all(abs(np.linalg.det(m)) == 1 for m in SYMMETRIES) 30 | 31 | 32 | def transform_position(sym, pos): 33 | ix = np.stack( 34 | [ 35 | np.repeat(np.arange(pos.size), pos.size), 36 | np.tile(np.arange(pos.size), pos.size), 37 | (pos.size - 1) * np.ones(pos.size * pos.size), 38 | ], 39 | axis=-1, 40 | ) 41 | ix = np.transpose(np.matmul(sym, np.transpose(ix))).astype(int) 42 | ix = ix.reshape((pos.size, pos.size, 3)) 43 | 44 | sqs = list(pos.board) 45 | for i in range(pos.size): 46 | for j in range(pos.size): 47 | oi, oj, _ = ix[i, j] 48 | sqs[oi + oj * pos.size] = pos[i, j] 49 | 50 | return tak.Position.from_squares( 51 | tak.Config(size=pos.size), 52 | sqs, 53 | pos.ply, 54 | ) 55 | 56 | 57 | def transform_move(sym, move, size): 58 | ox, oy, _ = np.matmul(sym, [move.x, move.y, size - 1]) 59 | type = move.type 60 | if type.is_slide(): 61 | dx, dy, _ = np.matmul(sym, move.type.direction() + (0,)) 62 | type = tak.MoveType.from_direction(dx, dy) 63 | return tak.Move(int(ox), int(oy), type, move.slides) 64 | 65 | 66 | def symmetries(pos): 67 | out = [] 68 | for s in SYMMETRIES: 69 | t = transform_position(s, pos) 70 | if all(t != p 
for _, p in out): 71 | out.append((s, t)) 72 | 73 | return out 74 | 75 | 76 | __all__ = ["SYMMETRIES", "transform_position", "transform_move", "symmetries"] 77 | -------------------------------------------------------------------------------- /python/xformer/train/run.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import torch 4 | from attrs import define, field 5 | from torch import nn 6 | 7 | 8 | class Batch(T.Protocol): 9 | inputs: torch.Tensor 10 | extra_inputs: tuple[torch.Tensor] = () 11 | 12 | 13 | Dataset = T.Iterable[Batch] 14 | 15 | 16 | class LossFunction(T.Protocol): 17 | def loss_and_metrics(self, batch, output) -> tuple[torch.Tensor, dict[str, float]]: 18 | ... 19 | 20 | 21 | @define 22 | class Stats: 23 | step: int = 0 24 | epoch: int = 0 25 | sequences: int = 0 26 | tokens: int = 0 27 | train_loss: float = 0 28 | step_time: float = 0 29 | elapsed_time: float = 0 30 | 31 | metrics: dict[str, object] = field(factory=dict) 32 | 33 | 34 | @define 35 | class Optimizer: 36 | lr: float = 5e-4 37 | lr_schedule: T.Optional[T.Callable[[Stats], float]] = None 38 | 39 | 40 | Trigger = T.Callable[[Stats], bool] 41 | 42 | 43 | @define 44 | class StopTrigger: 45 | steps: T.Optional[int] 46 | sequences: T.Optional[int] 47 | 48 | def __call__(self, stats: Stats): 49 | if self.steps is not None and stats.step >= self.steps: 50 | return True 51 | if self.sequences is not None and stats.sequences >= self.sequences: 52 | return True 53 | return False 54 | 55 | 56 | class Hook: 57 | def before_run(self, run: "Run"): 58 | pass 59 | 60 | def after_run(self, run: "Run", stats: Stats): 61 | pass 62 | 63 | def before_step(self, run: "Run", stats: Stats): 64 | pass 65 | 66 | def after_step(self, run: "Run", stats: Stats): 67 | pass 68 | 69 | 70 | @define(slots=False) 71 | class Run: 72 | model: nn.Module 73 | dataset: Dataset 74 | loss: LossFunction 75 | 76 | stop: Trigger 77 | 78 | optimizer: Optimizer = 
field(factory=Optimizer) 79 | 80 | hooks: list[Hook] = field(factory=list) 81 | 82 | 83 | __all__ = [ 84 | "Batch", 85 | "Dataset", 86 | "Hook", 87 | "LossFunction", 88 | "Optimizer", 89 | "Run", 90 | "Stats", 91 | "StopTrigger", 92 | ] 93 | -------------------------------------------------------------------------------- /tests/play_game_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | "testing" 9 | 10 | "github.com/nelhage/taktician/cli" 11 | "github.com/nelhage/taktician/ptn" 12 | "github.com/nelhage/taktician/tak" 13 | ) 14 | 15 | var games = flag.String("games", "", "Directory of .ptn files to self-check on") 16 | 17 | func TestPlayPTNs(t *testing.T) { 18 | if *games == "" { 19 | t.SkipNow() 20 | } 21 | ptns, err := readPTNs(*games) 22 | if err != nil { 23 | t.Fatalf("read ptns: %v", err) 24 | } 25 | for _, p := range ptns { 26 | playPTN(t, p) 27 | } 28 | } 29 | 30 | func playPTN(t *testing.T, p *ptn.PTN) { 31 | id := p.FindTag("Id") 32 | if id == "" { 33 | return 34 | } 35 | t.Log("playing", id) 36 | size, _ := strconv.Atoi(p.FindTag("Size")) 37 | g := tak.New(tak.Config{Size: size}) 38 | for _, op := range p.Ops { 39 | if m, ok := op.(*ptn.Move); ok { 40 | next, e := g.Move(m.Move) 41 | if e != nil { 42 | fmt.Printf("illegal move: %s\n", ptn.FormatMove(m.Move)) 43 | fmt.Printf("move=%d\n", g.MoveNumber()) 44 | cli.RenderBoard(nil, os.Stdout, g) 45 | t.Fatal("illegal move") 46 | } 47 | g = next 48 | } 49 | } 50 | over, winner := g.GameOver() 51 | var d tak.WinDetails 52 | if over { 53 | d = g.WinDetails() 54 | } 55 | switch p.FindTag("Result") { 56 | case "R-0": 57 | if !over || winner != tak.White || d.Reason != tak.RoadWin { 58 | t.Error("road win for white:", d) 59 | } 60 | case "0-R": 61 | if !over || winner != tak.Black || d.Reason != tak.RoadWin { 62 | t.Error("road win for black:", d) 63 | } 64 | case "F-0": 65 | if !over || winner 
!= tak.White || d.Reason != tak.FlatsWin { 66 | t.Error("flats win for white:", d) 67 | } 68 | case "0-F": 69 | if !over || winner != tak.Black || d.Reason != tak.FlatsWin { 70 | t.Error("flats win for black:", d) 71 | } 72 | case "1/2-1/2": 73 | /* 74 | if over && winner != tak.NoColor { 75 | t.Error("tie", over, d) 76 | } 77 | */ 78 | 79 | // playtak mishandles double-road wins as ties, so we 80 | // can't usefully check here. 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /python/tak/alphazero/hooks/saving.py: -------------------------------------------------------------------------------- 1 | from ..trainer import Hook, TrainState 2 | 3 | from attrs import define, field 4 | import os.path 5 | 6 | from xformer import loading 7 | import torch 8 | import yaml 9 | 10 | 11 | def save_snapshot(state: TrainState, snapshot_path): 12 | os.makedirs(snapshot_path, exist_ok=True) 13 | loading.save_model(state.model, snapshot_path) 14 | torch.save( 15 | state.opt.state_dict(), 16 | os.path.join(snapshot_path, "opt.pt"), 17 | ) 18 | torch.save( 19 | state.replay_buffer, 20 | os.path.join(snapshot_path, "replay_buffer.pt"), 21 | ) 22 | with open(os.path.join(snapshot_path, "elapsed.yaml"), "w") as fh: 23 | yaml.dump(state.elapsed, fh) 24 | 25 | 26 | @define(slots=False) 27 | class SavingHook(Hook): 28 | freq: int 29 | 30 | def before_run(self, state, config): 31 | self.run_dir = config.run_dir 32 | 33 | def check_and_clear_save_request(self) -> bool: 34 | run_dir = self.run_dir 35 | if not run_dir: 36 | return False 37 | flagpath = os.path.join(run_dir, "SAVE_NOW") 38 | if os.path.exists(flagpath): 39 | os.unlink(flagpath) 40 | return True 41 | return False 42 | 43 | def after_step(self, state: TrainState): 44 | if state.elapsed.step % self.freq == 0 or self.check_and_clear_save_request(): 45 | self.save_snapshot(state) 46 | 47 | def after_run(self, state: TrainState): 48 | self.save_snapshot(state) 49 | 50 | def 
save_snapshot(self, state: TrainState): 51 | if self.run_dir is None: 52 | return 53 | 54 | save_dir = os.path.join(self.run_dir, f"step_{state.elapsed.step:06d}") 55 | print(f"Saving snapshot to {save_dir}...") 56 | save_snapshot(state, save_dir) 57 | latest_link = os.path.join(self.run_dir, "latest") 58 | try: 59 | os.unlink(latest_link) 60 | except FileNotFoundError: 61 | pass 62 | os.symlink( 63 | os.path.basename(save_dir), 64 | latest_link, 65 | ) 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Taktician - A Tak Bot 2 | 3 | This repository implements the game of [Tak][tak], including a fairly 4 | strong AI, and support for the playtak.com server. 5 | 6 | # Installation 7 | 8 | Taktician requires `go1.7` or newer. On OS X, try `brew update && brew 9 | install go`. 10 | 11 | Once you have a working `go` installation, you can fetch+install the 12 | below commands using: 13 | 14 | 15 | ``` 16 | go get -u github.com/nelhage/taktician/cmd/... 17 | ``` 18 | 19 | Alternately, if you have a checkout of this repository, build+install 20 | it using 21 | 22 | ``` 23 | go install ./cmd/... 24 | ``` 25 | 26 | to install a `taktician` binary into your `$GOPATH/bin` (`~/go/bin` by 27 | default). 28 | 29 | # Subcommands 30 | 31 | Taktician consists of a single binary, `taktician`, which accepts a 32 | number of subcommands. You can run `taktician -help` to list all 33 | available commands, and `taktician [command] -help` for details on the 34 | options available to an individual command. 35 | 36 | Perhaps the most generally useful subcommand is `taktician analyze`, 37 | which allows you to evaluate a position offline using Taktician's AI: 38 | 39 | ## `taktician analyze` 40 | 41 | A command that reads PTN files and performs AI analysis on the 42 | terminal position.
43 | 44 | By default 45 | 46 | ``` 47 | taktician analyze FILE.ptn 48 | ``` 49 | 50 | will analyze the final position and report Taktician's evaluation and 51 | suggested move. 52 | 53 | You can also analyze e.g. white's 10th move using: 54 | 55 | ``` 56 | taktician analyze -white -move 10 FILE.ptn 57 | ``` 58 | 59 | With `-all`, `taktician analyze` will analyze each position in the PTN 60 | file. 61 | 62 | By default, `taktician analyze` will search for up to 1m before 63 | returning a final assessment. Use `-limit 2m` to give it more time, or 64 | `-depth 5` to search to a fixed depth. 65 | 66 | 67 | ## `taktician play` 68 | 69 | A simple interface to play tak on the command line. Try e.g. 70 | 71 | ``` 72 | taktician play -white=human -black=minimax:5 73 | ``` 74 | 75 | ## `taktician playtak` 76 | 77 | The AI driver for playtak.com. Can be used via 78 | 79 | ``` 80 | taktician playtak -user USERNAME -pass PASSWORD 81 | ``` 82 | 83 | [tak]: https://cheapass.com/tak/ 84 | -------------------------------------------------------------------------------- /ai/feature_string.go: -------------------------------------------------------------------------------- 1 | // Code generated by "stringer -type=Feature"; DO NOT EDIT. 2 | 3 | package ai 4 | 5 | import "strconv" 6 | 7 | func _() { 8 | // An "invalid array index" compiler error signifies that the constant values have changed. 9 | // Re-run the stringer command to generate them again. 
10 | var x [1]struct{} 11 | _ = x[Tempo-0] 12 | _ = x[TopFlat-1] 13 | _ = x[Standing-2] 14 | _ = x[Capstone-3] 15 | _ = x[HardTopCap-4] 16 | _ = x[CapMobility-5] 17 | _ = x[FlatCaptives_Soft-6] 18 | _ = x[FlatCaptives_Hard-7] 19 | _ = x[StandingCaptives_Soft-8] 20 | _ = x[StandingCaptives_Hard-9] 21 | _ = x[CapstoneCaptives_Soft-10] 22 | _ = x[CapstoneCaptives_Hard-11] 23 | _ = x[Liberties-12] 24 | _ = x[GroupLiberties-13] 25 | _ = x[Groups-14] 26 | _ = x[Groups_1-15] 27 | _ = x[Groups_2-16] 28 | _ = x[Groups_3-17] 29 | _ = x[Groups_4-18] 30 | _ = x[Groups_5-19] 31 | _ = x[Groups_6-20] 32 | _ = x[Groups_7-21] 33 | _ = x[Groups_8-22] 34 | _ = x[Potential-23] 35 | _ = x[Threat-24] 36 | _ = x[EmptyControl-25] 37 | _ = x[FlatControl-26] 38 | _ = x[Center-27] 39 | _ = x[CenterControl-28] 40 | _ = x[ThrowMine-29] 41 | _ = x[ThrowTheirs-30] 42 | _ = x[ThrowEmpty-31] 43 | _ = x[Terminal_Plies-32] 44 | _ = x[Terminal_Flats-33] 45 | _ = x[Terminal_Reserves-34] 46 | _ = x[Terminal_OpponentReserves-35] 47 | _ = x[MaxFeature-36] 48 | } 49 | 50 | const _Feature_name = "TempoTopFlatStandingCapstoneHardTopCapCapMobilityFlatCaptives_SoftFlatCaptives_HardStandingCaptives_SoftStandingCaptives_HardCapstoneCaptives_SoftCapstoneCaptives_HardLibertiesGroupLibertiesGroupsGroups_1Groups_2Groups_3Groups_4Groups_5Groups_6Groups_7Groups_8PotentialThreatEmptyControlFlatControlCenterCenterControlThrowMineThrowTheirsThrowEmptyTerminal_PliesTerminal_FlatsTerminal_ReservesTerminal_OpponentReservesMaxFeature" 51 | 52 | var _Feature_index = [...]uint16{0, 5, 12, 20, 28, 38, 49, 66, 83, 104, 125, 146, 167, 176, 190, 196, 204, 212, 220, 228, 236, 244, 252, 260, 269, 275, 287, 298, 304, 317, 326, 337, 347, 361, 375, 392, 417, 427} 53 | 54 | func (i Feature) String() string { 55 | if i < 0 || i >= Feature(len(_Feature_index)-1) { 56 | return "Feature(" + strconv.FormatInt(int64(i), 10) + ")" 57 | } 58 | return _Feature_name[_Feature_index[i]:_Feature_index[i+1]] 59 | } 60 | 
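The generated `String` method above stores every name in one concatenated string and slices it with an offset table. A minimal hand-written sketch of the same technique follows; the `Color` type, its constants, and the `colorName`/`colorIndex` tables are hypothetical and exist only for illustration (real `stringer` output should be regenerated, not edited by hand).

```go
package main

import (
	"fmt"
	"strconv"
)

// Color is a hypothetical enum used only to illustrate the lookup scheme.
type Color int

const (
	Red Color = iota
	Green
	Blue
)

// All names concatenated; colorIndex[i] is the offset where name i starts,
// and the final entry is the total length.
const colorName = "RedGreenBlue"

var colorIndex = [...]uint8{0, 3, 8, 12}

func (c Color) String() string {
	// Out-of-range values fall back to a numeric form, as stringer does.
	if c < 0 || c >= Color(len(colorIndex)-1) {
		return "Color(" + strconv.FormatInt(int64(c), 10) + ")"
	}
	return colorName[colorIndex[c]:colorIndex[c+1]]
}

func main() {
	fmt.Println(Red, Green, Blue, Color(7)) // Red Green Blue Color(7)
}
```

Keeping one string plus an offset array avoids allocating a separate string header per name, which is presumably why the generator prefers it over a plain `[]string`.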
-------------------------------------------------------------------------------- /doc/friendly.md: -------------------------------------------------------------------------------- 1 | # FriendlyBot 2 | 3 | Hello, friend! You've reached the home of FriendlyBot, a Tak bot 4 | designed to be a tool to help you learn the game of Tak. 5 | 6 | # Playing FriendlyBot 7 | 8 | You can play games against FriendlyBot on playtak.com. Log in and find 9 | the bot's name under "Join game". If it's not there, it might be 10 | occupied -- it can only play one person at a time right now. 11 | 12 | # Changing the difficulty level 13 | 14 | FriendlyBot supports playing at a number of different difficulty 15 | levels. 16 | 17 | To change the difficulty level, just say 18 | 19 | FriendlyBot: level LEVEL 20 | 21 | e.g. 22 | 23 | FriendlyBot: level 6 24 | 25 | in chat. It supports a number of levels, from 1 up through 13, at the 26 | moment. 27 | 28 | The levels have somewhat different styles of play, in addition to 29 | being harder or easier, so try a few different levels, even if you're 30 | struggling with one. 31 | 32 | If you're playing the bot, you can even change the difficulty level 33 | mid-game. 34 | 35 | Levels 10 and up are roughly equivalent to Taktician at various stages 36 | of its development, and so should be considered pretty challenging. 37 | 38 | At any given time, the highest level will track approximately the 39 | current play of Taktician. 40 | 41 | FriendlyBot can also play on different sizes of board. Try, for example: 42 | 43 | FriendlyBot: size 5 44 | 45 | # FPABot 46 | 47 | `FPABot` is a variant of `FriendlyBot` that exists to test variations 48 | on the rules or play style to reduce the first-player advantage that 49 | exists in vanilla Tak. 50 | 51 | At present, it follows the rule that the first player must place the 52 | second player's initial stone somewhere in the center of the board. 
It 53 | may be updated with additional variations in the future. 54 | 55 | It obeys the same commands as `FriendlyBot`, but under its own name. 56 | 57 | # Resources 58 | 59 | New to Tak? Here are some links that might help you out. 60 | 61 | - [/r/tak](https://www.reddit.com/r/Tak/) is full of wonderful 62 | friendly people who are happy to give you advice or just chat. 63 | - [Turing's blog](https://taktraveler.wordpress.com/) has a bunch of 64 | good introductory strategy posts. 65 | - NohatCoder also wrote 66 | [a good introductory guide](http://ebusiness.hopto.org/2016-04-27-1.html) 67 | -------------------------------------------------------------------------------- /python/tak/moves.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import typing as T 3 | 4 | from attrs import define, field 5 | 6 | 7 | @enum.unique 8 | class MoveType(enum.Enum): 9 | PLACE_FLAT = 1 10 | PLACE_STANDING = 2 11 | PLACE_CAPSTONE = 3 12 | SLIDE_LEFT = 4 13 | SLIDE_RIGHT = 5 14 | SLIDE_UP = 6 15 | SLIDE_DOWN = 7 16 | 17 | def is_slide(self): 18 | return self.value >= MoveType.SLIDE_LEFT.value 19 | 20 | def direction(self): 21 | assert self.is_slide() 22 | return DIRECTIONS[self] 23 | 24 | @staticmethod 25 | def from_direction(dx, dy): 26 | return RDIRECTIONS[(dx, dy)] 27 | 28 | def __lt__(self, rhs): 29 | return self.value < rhs.value 30 | 31 | 32 | DIRECTIONS = { 33 | MoveType.SLIDE_LEFT: (-1, 0), 34 | MoveType.SLIDE_RIGHT: (1, 0), 35 | MoveType.SLIDE_UP: (0, 1), 36 | MoveType.SLIDE_DOWN: (0, -1), 37 | } 38 | RDIRECTIONS = dict((v, k) for (k, v) in DIRECTIONS.items()) 39 | 40 | 41 | @define(frozen=True) 42 | class Move(object): 43 | x: int 44 | y: int 45 | type: MoveType = field(default=MoveType.PLACE_FLAT) 46 | slides: T.Optional[tuple[int]] = None 47 | 48 | 49 | ALL_SLIDES = [() for i in range(9)] 50 | 51 | 52 | def _compute_slides(size): 53 | slides = [] 54 | for i in range(1, size + 1): 55 | slides.append((i,)) 56 | for 
inner in ALL_SLIDES[size - i]: 57 | slides.append((i,) + inner) 58 | return slides 59 | 60 | 61 | for s in range(1, 9): 62 | ALL_SLIDES[s] = _compute_slides(s) 63 | 64 | 65 | def all_moves_for_size(size): 66 | out = [] 67 | for x in range(size): 68 | for y in range(size): 69 | out.append(Move(x, y, MoveType.PLACE_FLAT)) 70 | out.append(Move(x, y, MoveType.PLACE_STANDING)) 71 | out.append(Move(x, y, MoveType.PLACE_CAPSTONE)) 72 | 73 | dirs = [ 74 | (MoveType.SLIDE_LEFT, x), 75 | (MoveType.SLIDE_RIGHT, size - x - 1), 76 | (MoveType.SLIDE_DOWN, y), 77 | (MoveType.SLIDE_UP, size - y - 1), 78 | ] 79 | for slide in ALL_SLIDES[size]: 80 | for d, l in dirs: 81 | if len(slide) <= l: 82 | out.append(Move(x, y, d, slide)) 83 | return out 84 | 85 | 86 | __all__ = ["MoveType", "Move", "ALL_SLIDES", "all_moves_for_size"] 87 | -------------------------------------------------------------------------------- /taktest/utils.go: -------------------------------------------------------------------------------- 1 | package taktest 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strings" 7 | 8 | "github.com/nelhage/taktician/ptn" 9 | "github.com/nelhage/taktician/tak" 10 | ) 11 | 12 | func Move(s string) tak.Move { 13 | m, e := ptn.ParseMove(s) 14 | if e != nil { 15 | panic(e) 16 | } 17 | return m 18 | } 19 | 20 | func Moves(s string) []tak.Move { 21 | if s == "" { 22 | return nil 23 | } 24 | bits := strings.Split(s, " ") 25 | var ms []tak.Move 26 | for _, b := range bits { 27 | m, e := ptn.ParseMove(b) 28 | if e != nil { 29 | panic(e) 30 | } 31 | ms = append(ms, m) 32 | } 33 | return ms 34 | } 35 | 36 | func FormatMoves(ms []tak.Move) string { 37 | var bits []string 38 | for _, o := range ms { 39 | bits = append(bits, ptn.FormatMove(o)) 40 | } 41 | return strings.Join(bits, " ") 42 | } 43 | 44 | func Position(size int, ms string) *tak.Position { 45 | p := tak.New(tak.Config{Size: size}) 46 | moves := Moves(ms) 47 | var e error 48 | for _, m := range moves { 49 | p, e = p.Move(m) 50 | if e 
!= nil { 51 | panic(e) 52 | } 53 | } 54 | return p 55 | } 56 | 57 | func Board(tpl string, who tak.Color) (*tak.Position, error) { 58 | lines := strings.Split(strings.Trim(tpl, " \n"), "\n") 59 | var pieces [][]tak.Square 60 | for _, l := range lines { 61 | bits := strings.Split(l, " ") 62 | var row []tak.Square 63 | for _, p := range bits { 64 | switch p { 65 | case "W": 66 | row = append(row, tak.Square{tak.MakePiece(tak.White, tak.Flat)}) 67 | case "B": 68 | row = append(row, tak.Square{tak.MakePiece(tak.Black, tak.Flat)}) 69 | case "WC": 70 | row = append(row, tak.Square{tak.MakePiece(tak.White, tak.Capstone)}) 71 | case "BC": 72 | row = append(row, tak.Square{tak.MakePiece(tak.Black, tak.Capstone)}) 73 | case "WS": 74 | row = append(row, tak.Square{tak.MakePiece(tak.White, tak.Standing)}) 75 | case "BS": 76 | row = append(row, tak.Square{tak.MakePiece(tak.Black, tak.Standing)}) 77 | case ".": 78 | row = append(row, tak.Square{}) 79 | case "": 80 | default: 81 | return nil, fmt.Errorf("bad piece: %v", p) 82 | } 83 | } 84 | if len(row) != len(lines) { 85 | return nil, errors.New("size mismatch") 86 | } 87 | pieces = append(pieces, row) 88 | } 89 | ply := 2 90 | if who == tak.Black { 91 | ply = 3 92 | } 93 | return tak.FromSquares(tak.Config{Size: len(pieces)}, pieces, ply) 94 | } 95 | -------------------------------------------------------------------------------- /python/tak/proto/analysis_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: tak/proto/analysis.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18tak/proto/analysis.proto\x12\ttak.proto\"#\n\x0f\x45valuateRequest\x12\x10\n\x08position\x18\x01 \x03(\x05\"O\n\x10\x45valuateResponse\x12\x12\n\nmove_probs\x18\x01 \x03(\x02\x12\r\n\x05value\x18\x02 \x01(\x02\x12\x18\n\x10move_probs_bytes\x18\x03 \x01(\x0c\x32Q\n\x08\x41nalysis\x12\x45\n\x08\x45valuate\x12\x1a.tak.proto.EvaluateRequest\x1a\x1b.tak.proto.EvaluateResponse\"\x00\x42\"Z github.com/nelhage/taktician/pb/b\x06proto3') 18 | 19 | 20 | 21 | _EVALUATEREQUEST = DESCRIPTOR.message_types_by_name['EvaluateRequest'] 22 | _EVALUATERESPONSE = DESCRIPTOR.message_types_by_name['EvaluateResponse'] 23 | EvaluateRequest = _reflection.GeneratedProtocolMessageType('EvaluateRequest', (_message.Message,), { 24 | 'DESCRIPTOR' : _EVALUATEREQUEST, 25 | '__module__' : 'tak.proto.analysis_pb2' 26 | # @@protoc_insertion_point(class_scope:tak.proto.EvaluateRequest) 27 | }) 28 | _sym_db.RegisterMessage(EvaluateRequest) 29 | 30 | EvaluateResponse = _reflection.GeneratedProtocolMessageType('EvaluateResponse', (_message.Message,), { 31 | 'DESCRIPTOR' : _EVALUATERESPONSE, 32 | '__module__' : 'tak.proto.analysis_pb2' 33 | # @@protoc_insertion_point(class_scope:tak.proto.EvaluateResponse) 34 | }) 35 | _sym_db.RegisterMessage(EvaluateResponse) 36 | 37 | _ANALYSIS = DESCRIPTOR.services_by_name['Analysis'] 38 | if _descriptor._USE_C_DESCRIPTORS == False: 39 | 40 | DESCRIPTOR._options = None 41 | DESCRIPTOR._serialized_options = b'Z 
github.com/nelhage/taktician/pb/' 42 | _EVALUATEREQUEST._serialized_start=39 43 | _EVALUATEREQUEST._serialized_end=74 44 | _EVALUATERESPONSE._serialized_start=76 45 | _EVALUATERESPONSE._serialized_end=155 46 | _ANALYSIS._serialized_start=157 47 | _ANALYSIS._serialized_end=238 48 | # @@protoc_insertion_point(module_scope) 49 | -------------------------------------------------------------------------------- /python/tak/model/wrapper.py: -------------------------------------------------------------------------------- 1 | from . import encoding 2 | from attrs import define, field 3 | 4 | import torch 5 | from torch import nn 6 | import typing as T 7 | 8 | 9 | @define 10 | class ModelWrapper: 11 | model: nn.Module 12 | device: T.Optional[str] = None 13 | 14 | def evaluate(self, pos): 15 | with torch.no_grad(): 16 | encoded = torch.tensor( 17 | [encoding.encode(pos)], dtype=torch.long, device=self.device 18 | ) 19 | out = self.model(encoded) 20 | return torch.softmax(out["moves"][0], dim=0).cpu(), out["values"][0].item() 21 | 22 | 23 | @define 24 | class GraphedWrapper: 25 | model: nn.Module 26 | max_length: int = 30 27 | 28 | graph: torch.cuda.CUDAGraph = field(factory=torch.cuda.CUDAGraph) 29 | static_pos: torch.Tensor = field(init=False) 30 | static_mask: torch.Tensor = field(init=False) 31 | static_output: dict[str, torch.Tensor] = field(init=False) 32 | 33 | def __attrs_post_init__(self): 34 | self.static_pos = torch.ones( 35 | ( 36 | 1, 37 | self.max_length, 38 | ), 39 | dtype=torch.long, 40 | device="cuda", 41 | ) 42 | self.static_mask = torch.zeros( 43 | ( 44 | 1, 45 | self.max_length, 46 | ), 47 | dtype=torch.bool, 48 | device="cuda", 49 | ) 50 | 51 | s = torch.cuda.Stream() 52 | s.wait_stream(torch.cuda.current_stream()) 53 | with torch.cuda.stream(s), torch.no_grad(): 54 | for _ in range(3): 55 | self.model(self.static_pos, self.static_mask) 56 | torch.cuda.current_stream().wait_stream(s) 57 | 58 | with torch.cuda.graph(self.graph), torch.no_grad(): 59 | 
self.static_output = self.model(self.static_pos, self.static_mask) 60 | 61 | def evaluate(self, pos): 62 | with torch.no_grad(): 63 | encoded = encoding.encode(pos) 64 | self.static_pos[:, : len(encoded)].copy_( 65 | torch.tensor(encoded, dtype=torch.long) 66 | ) 67 | self.static_mask[:, : len(encoded)].fill_(0) 68 | self.static_mask[:, len(encoded) :].fill_(1) 69 | self.graph.replay() 70 | out = self.static_output 71 | return torch.softmax(out["moves"][0], dim=0).cpu(), out["values"][0].item() 72 | -------------------------------------------------------------------------------- /python/bench/mcts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import tak 5 | from tak import mcts 6 | from xformer import loading 7 | from tak.model import wrapper, grpc 8 | import torch 9 | 10 | import time 11 | 12 | 13 | def main(argv): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--simulations", 17 | dest="simulations", 18 | type=int, 19 | default=100, 20 | metavar="POSITIONS", 21 | ) 22 | parser.add_argument( 23 | "--size", 24 | dest="size", 25 | type=int, 26 | default=3, 27 | metavar="SIZE", 28 | ) 29 | parser.add_argument( 30 | "--graph", 31 | action="store_true", 32 | default=False, 33 | help="Use CUDA graphs to run the network", 34 | ) 35 | parser.add_argument( 36 | "--fp16", 37 | action="store_true", 38 | default=False, 39 | help="Run model in float16", 40 | ) 41 | parser.add_argument( 42 | "--device", 43 | type=str, 44 | default="cpu", 45 | ) 46 | parser.add_argument( 47 | "--model", 48 | type=str, 49 | ) 50 | parser.add_argument( 51 | "--host", 52 | type=str, 53 | ) 54 | parser.add_argument( 55 | "--port", 56 | type=int, 57 | default=5001, 58 | ) 59 | 60 | args = parser.parse_args(argv) 61 | 62 | if (args.model and args.host) or not (args.model or args.host): 63 | raise ValueError("Must specify either --host or --model, not both") 64 | if args.model: 65 | model = 
loading.load_model(args.model, args.device) 66 | if args.fp16: 67 | model = model.to(torch.float16) 68 | 69 | if args.graph: 70 | network = wrapper.GraphedWrapper(model) 71 | else: 72 | network = wrapper.ModelWrapper(model, device=args.device) 73 | else: 74 | network = grpc.GRPCNetwork(host=args.host, port=args.port) 75 | 76 | p = tak.Position.from_config(tak.Config(size=args.size)) 77 | 78 | engine = mcts.MCTS( 79 | mcts.Config( 80 | network=network, 81 | simulation_limit=args.simulations, 82 | time_limit=0, 83 | ) 84 | ) 85 | 86 | start = time.time() 87 | tree = engine.analyze(p) 88 | end = time.time() 89 | 90 | print(f"done simulations={tree.simulations} duration={end-start:.2f}") 91 | 92 | 93 | if __name__ == "__main__": 94 | main(sys.argv[1:]) 95 | -------------------------------------------------------------------------------- /cmd/taktician/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "log" 7 | "os" 8 | "runtime/pprof" 9 | 10 | "github.com/google/subcommands" 11 | "github.com/nelhage/taktician/cmd/internal/analyze" 12 | "github.com/nelhage/taktician/cmd/internal/canonicalize" 13 | "github.com/nelhage/taktician/cmd/internal/gencorpus" 14 | "github.com/nelhage/taktician/cmd/internal/genopenings" 15 | "github.com/nelhage/taktician/cmd/internal/importptn" 16 | "github.com/nelhage/taktician/cmd/internal/openings" 17 | "github.com/nelhage/taktician/cmd/internal/play" 18 | "github.com/nelhage/taktician/cmd/internal/playtak" 19 | "github.com/nelhage/taktician/cmd/internal/selfplay" 20 | "github.com/nelhage/taktician/cmd/internal/serve" 21 | "github.com/nelhage/taktician/cmd/internal/tei" 22 | ) 23 | 24 | func innerMain() int { 25 | subcommands.Register(subcommands.HelpCommand(), "") 26 | subcommands.Register(subcommands.FlagsCommand(), "") 27 | // subcommands.Register(subcommands.CommandsCommand(), "") 28 | 29 | subcommands.Register(&analyze.Command{}, "") 30 
| subcommands.Register(&selfplay.Command{}, "") 31 | subcommands.Register(&playtak.Command{}, "") 32 | subcommands.Register(&serve.Command{}, "") 33 | subcommands.Register(&play.Command{}, "") 34 | subcommands.Register(&tei.Command{}, "") 35 | 36 | subcommands.Register(&genopenings.Command{}, "") 37 | subcommands.Register(&openings.Command{}, "") 38 | subcommands.Register(&canonicalize.Command{}, "") 39 | subcommands.Register(&gencorpus.Command{}, "") 40 | 41 | subcommands.Register(&importptn.Command{}, "") 42 | 43 | var cpuProfile, memProfile string 44 | 45 | flag.StringVar(&cpuProfile, "cpuprofile", "", "write CPU profile") 46 | flag.StringVar(&memProfile, "memprofile", "", "write memory profile") 47 | 48 | flag.Parse() 49 | 50 | if cpuProfile != "" { 51 | f, e := os.OpenFile(cpuProfile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 52 | if e != nil { 53 | log.Fatalf("open cpu-profile: %s: %v", cpuProfile, e) 54 | } 55 | pprof.StartCPUProfile(f) 56 | defer f.Close() 57 | defer pprof.StopCPUProfile() 58 | } 59 | if memProfile != "" { 60 | f, e := os.OpenFile(memProfile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) 61 | if e != nil { 62 | log.Fatalf("open memory profile: %s: %v", memProfile, e) 63 | } 64 | defer func() { 65 | pprof.Lookup("allocs").WriteTo(f, 0) 66 | f.Close() 67 | }() 68 | } 69 | 70 | ctx := context.Background() 71 | return int(subcommands.Execute(ctx)) 72 | } 73 | 74 | func main() { 75 | os.Exit(innerMain()) 76 | } 77 | -------------------------------------------------------------------------------- /python/tak/alphazero/data.py: -------------------------------------------------------------------------------- 1 | from attrs import define, field 2 | import torch 3 | 4 | 5 | @define 6 | class ReplayBufferBatch: 7 | data: dict[str, torch.Tensor] 8 | 9 | @property 10 | def inputs(self): 11 | return self.data["positions"] 12 | 13 | @property 14 | def mask(self): 15 | return self.data["mask"] 16 | 17 | @property 18 | def extra_inputs(self): 19 | return 
(~self.mask,) 20 | 21 | @property 22 | def moves(self): 23 | return self.data["moves"] 24 | 25 | @property 26 | def values(self): 27 | # Experiment with rollout result here instead 28 | return self.data["values"] 29 | 30 | 31 | @define 32 | class ReplayBufferDataset: 33 | replay_buffer: list[dict[str, torch.Tensor]] 34 | batch_size: int 35 | device: str 36 | flat_replay_buffer: dict[str, torch.Tensor] = field(init=False) 37 | 38 | def __attrs_post_init__(self): 39 | self.flat_replay_buffer = self.cat_replay_buffer() 40 | 41 | def cat_replay_buffer(self): 42 | full_replay_buffer = { 43 | k: torch.cat([d[k] for d in self.replay_buffer]) 44 | for k in self.replay_buffer[0] 45 | if k not in ["positions", "mask"] 46 | } 47 | npos = sum(b["positions"].size(0) for b in self.replay_buffer) 48 | maxwidth = max(b["positions"].size(1) for b in self.replay_buffer) 49 | positions = torch.zeros((npos, maxwidth), dtype=torch.long) 50 | mask = torch.zeros((npos, maxwidth), dtype=torch.bool) 51 | 52 | n = 0 53 | for b in self.replay_buffer: 54 | shape = b["positions"].shape 55 | positions[n : n + shape[0], : shape[1]] = b["positions"] 56 | mask[n : n + shape[0], : shape[1]] = b["mask"] 57 | n += shape[0] 58 | 59 | full_replay_buffer["positions"] = positions 60 | full_replay_buffer["mask"] = mask 61 | return full_replay_buffer 62 | 63 | def pin(self, tensor): 64 | if self.device.startswith("cuda"): 65 | return tensor.pin_memory() 66 | return tensor 67 | 68 | def __iter__(self): 69 | npos = len(self.flat_replay_buffer["positions"]) 70 | 71 | perm = torch.randperm(npos) 72 | shuffled = {k: self.pin(v[perm]) for (k, v) in self.flat_replay_buffer.items()} 73 | 74 | for i in range(0, npos, self.batch_size): 75 | yield ReplayBufferBatch( 76 | { 77 | k: v[i : i + self.batch_size].to(self.device) 78 | for (k, v) in shuffled.items() 79 | } 80 | ) 81 | -------------------------------------------------------------------------------- /python/xformer/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | import attrs 4 | import torch 5 | from attrs import define, field 6 | 7 | from xformer import train 8 | 9 | 10 | class BatchProtocol(train.Batch, T.Protocol): 11 | def __init__(self, data: dict[str, torch.Tensor]): 12 | ... 13 | 14 | 15 | @define 16 | class Batch(BatchProtocol): 17 | data: dict[str, torch.Tensor] 18 | 19 | @property 20 | def inputs(self): 21 | return self.data["inputs"] 22 | 23 | 24 | def transient(**kwargs): 25 | kwargs.setdefault("metadata", {})["transient"] = True 26 | kwargs["init"] = False 27 | return field(**kwargs) 28 | 29 | 30 | @define(getstate_setstate=False) 31 | class Dataset: 32 | path: str 33 | batch_size: int 34 | batches: T.Optional[int] = None 35 | device: str = "cpu" 36 | seed: int = 0x12345678 37 | batch_class: type = Batch 38 | 39 | data: dict[str, torch.Tensor] = transient() 40 | generator: torch.Generator = transient() 41 | 42 | def __getstate__(self): 43 | return { 44 | f.name: getattr(self, f.name) 45 | for f in attrs.fields(type(self)) 46 | if not f.metadata.get("transient") 47 | } 48 | 49 | def __setstate__(self, state): 50 | for (k, v) in state.items(): 51 | setattr(self, k, v) 52 | self.__attrs_post_init__() 53 | 54 | def __attrs_post_init__(self): 55 | self.data = torch.load(self.path) 56 | for (k, v) in self.data.items(): 57 | if v.dtype == torch.uint8: 58 | v = v.long() 59 | if self.batches is not None: 60 | v = v[: self.batches * self.batch_size] 61 | self.data[k] = v 62 | 63 | self.generator = torch.Generator().manual_seed(self.seed) 64 | 65 | def __len__(self): 66 | return len(next(iter(self.data.values()))) 67 | 68 | def pin(self, tensor): 69 | if self.device.startswith("cuda"): 70 | return tensor.pin_memory() 71 | return tensor 72 | 73 | def _next_epoch(self): 74 | perm = torch.randperm(len(self), generator=self.generator) 75 | return {k: self.pin(v[perm]) for (k, v) in self.data.items()} 76 | 77 | def 
fastforward_epochs(self, n: int): 78 | for _ in range(n): 79 | self._next_epoch() 80 | 81 | def __iter__(self): 82 | shuffled = self._next_epoch() 83 | for i in range(0, len(self), self.batch_size): 84 | yield self.batch_class( 85 | { 86 | k: v[i : i + self.batch_size].to(self.device) 87 | for (k, v) in shuffled.items() 88 | } 89 | ) 90 | -------------------------------------------------------------------------------- /tests/hash_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "testing" 7 | 8 | "github.com/nelhage/taktician/ai" 9 | "github.com/nelhage/taktician/bitboard" 10 | "github.com/nelhage/taktician/ptn" 11 | "github.com/nelhage/taktician/tak" 12 | ) 13 | 14 | var hashTests = flag.Bool("test-hash", false, "run hash collision tests") 15 | var tps = flag.String("tps", "112S,12,1112S,x2/x2,121C,12S,x/1,21,2,2,2/x,2,1,1,1/2,x3,21 2 24", "run hash collision tests on tps") 16 | var depth = flag.Int("depth", 5, "run hash collision tests to depth") 17 | 18 | func wrapHash(tbl map[uint64][]*tak.Position, eval ai.EvaluationFunc) ai.EvaluationFunc { 19 | return func(c *bitboard.Constants, p *tak.Position) int64 { 20 | tbl[p.Hash()] = append(tbl[p.Hash()], p.Clone()) 21 | return eval(c, p) 22 | } 23 | } 24 | 25 | func equal(a, b *tak.Position) bool { 26 | if a.ToMove() != b.ToMove() { 27 | return false 28 | } 29 | if a.White != b.White { 30 | return false 31 | } 32 | if a.Black != b.Black { 33 | return false 34 | } 35 | if a.Standing != b.Standing { 36 | return false 37 | } 38 | if a.Caps != b.Caps { 39 | return false 40 | } 41 | for i := range a.Height { 42 | if a.Height[i] != b.Height[i] { 43 | return false 44 | } 45 | if a.Stacks[i] != b.Stacks[i] { 46 | return false 47 | } 48 | } 49 | return true 50 | } 51 | 52 | func reportCollisions(t *testing.T, tbl map[uint64][]*tak.Position) { 53 | var n, collisions int 54 | for h, l := range tbl { 55 | n += len(l) 56 | p := 
l[0] 57 | for _, pp := range l[1:] { 58 | if !equal(p, pp) { 59 | t.Logf(" collision h=%x l=%q r=%q", 60 | h, ptn.FormatTPS(p), ptn.FormatTPS(pp), 61 | ) 62 | collisions++ 63 | break 64 | } 65 | } 66 | } 67 | 68 | t.Logf("evaluated %d positions and %d hashes, with %d collisions", 69 | n, len(tbl), collisions) 70 | } 71 | 72 | func TestHash(t *testing.T) { 73 | if !*hashTests { 74 | t.SkipNow() 75 | } 76 | testCollisions(t, tak.New(tak.Config{Size: 5})) 77 | p, e := ptn.ParseTPS(*tps) 78 | if e != nil { 79 | panic("bad tps") 80 | } 81 | testCollisions(t, p) 82 | } 83 | 84 | func testCollisions(t *testing.T, p *tak.Position) { 85 | tbl := make(map[uint64][]*tak.Position) 86 | ai := ai.NewMinimax(ai.MinimaxConfig{ 87 | Size: 5, 88 | Depth: *depth, 89 | Evaluate: wrapHash(tbl, ai.MakeEvaluator(5, nil)), 90 | TableMem: -1, 91 | }) 92 | for i := 0; i < 4; i++ { 93 | m := ai.GetMove(context.Background(), p) 94 | p, _ = p.Move(m) 95 | if ok, _ := p.GameOver(); ok { 96 | break 97 | } 98 | } 99 | reportCollisions(t, tbl) 100 | } 101 | -------------------------------------------------------------------------------- /ai/moves.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "sort" 5 | 6 | "github.com/nelhage/taktician/tak" 7 | ) 8 | 9 | type moveGenerator struct { 10 | ai *MinimaxAI 11 | f *frame 12 | ply int 13 | depth int 14 | p *tak.Position 15 | 16 | te *tableEntry 17 | pv []tak.Move 18 | r tak.Move 19 | 20 | ms []tak.Move 21 | i int 22 | } 23 | 24 | type sortMoves struct { 25 | ms []tak.Move 26 | vs []int 27 | } 28 | 29 | func (s sortMoves) Len() int { return len(s.ms) } 30 | func (s sortMoves) Less(i, j int) bool { 31 | return s.vs[i] > s.vs[j] 32 | } 33 | func (s sortMoves) Swap(i, j int) { 34 | s.ms[i], s.ms[j] = s.ms[j], s.ms[i] 35 | s.vs[i], s.vs[j] = s.vs[j], s.vs[i] 36 | } 37 | 38 | func (mg *moveGenerator) sortMoves() { 39 | vs := mg.f.vals.slice 40 | if vs == nil { 41 | vs = 
mg.f.vals.alloc[:] 42 | } 43 | if len(vs) < len(mg.ms) { 44 | vs = make([]int, len(mg.ms)) 45 | } 46 | s := sortMoves{ 47 | mg.ms, 48 | vs, 49 | } 50 | for i, m := range s.ms { 51 | s.vs[i] = mg.ai.history[m] 52 | } 53 | sort.Sort(s) 54 | } 55 | 56 | func (mg *moveGenerator) Reset() { 57 | mg.i = 0 58 | } 59 | 60 | func (mg *moveGenerator) Next() (m tak.Move, p *tak.Position) { 61 | for { 62 | var m tak.Move 63 | switch mg.i { 64 | case 0: 65 | mg.i++ 66 | if mg.te != nil { 67 | m = mg.te.m 68 | break 69 | } 70 | fallthrough 71 | case 1: 72 | mg.i++ 73 | if len(mg.pv) > 0 { 74 | m = mg.pv[0] 75 | if mg.te != nil && m.Equal(mg.te.m) { 76 | continue 77 | } 78 | break 79 | } 80 | fallthrough 81 | case 2: 82 | mg.i++ 83 | if mg.ply == 0 { 84 | continue 85 | } 86 | var ok bool 87 | if mg.r, ok = mg.ai.response[mg.ai.stack[mg.ply-1].m]; ok { 88 | m = mg.r 89 | break 90 | } 91 | fallthrough 92 | case 3: 93 | mg.i++ 94 | if mg.ms == nil { 95 | ms := mg.f.moves.slice 96 | if ms == nil { 97 | ms = mg.f.moves.alloc[:] 98 | } 99 | mg.ms = mg.p.AllMoves(ms[:0]) 100 | mg.f.moves.slice = ms[:] 101 | } 102 | if mg.depth > 1 && !mg.ai.Cfg.NoSort { 103 | mg.sortMoves() 104 | } 105 | fallthrough 106 | default: 107 | j := mg.i - 4 108 | mg.i++ 109 | if j >= len(mg.ms) { 110 | return tak.Move{}, nil 111 | } 112 | m = mg.ms[j] 113 | if mg.te != nil && mg.te.m.Equal(m) { 114 | continue 115 | } 116 | if len(mg.pv) != 0 && mg.pv[0].Equal(m) { 117 | continue 118 | } 119 | if mg.r.Equal(m) { 120 | continue 121 | } 122 | } 123 | child, e := mg.p.MovePreallocated(m, mg.ai.stack[mg.ply].p) 124 | if e == nil { 125 | return m, child 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /cmd/internal/playtak/taktician.go: -------------------------------------------------------------------------------- 1 | package playtak 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "strconv" 7 | "strings" 8 | "time" 9 | 10 | "github.com/nelhage/taktician/ai" 
11 | "github.com/nelhage/taktician/playtak" 12 | "github.com/nelhage/taktician/playtak/bot" 13 | "github.com/nelhage/taktician/tak" 14 | ) 15 | 16 | type Taktician struct { 17 | cmd *Command 18 | 19 | g *bot.Game 20 | client *playtak.Commands 21 | ai ai.TakPlayer 22 | } 23 | 24 | func (t *Taktician) NewGame(g *bot.Game) { 25 | t.g = g 26 | t.ai = t.cmd.wrapWithBook( 27 | g.Size, 28 | ai.NewMinimax(ai.MinimaxConfig{ 29 | Size: g.Size, 30 | Depth: t.cmd.depth, 31 | Debug: t.cmd.debug, 32 | 33 | NoSort: !t.cmd.sort, 34 | TableMem: t.cmd.tableMem, 35 | MultiCut: t.cmd.multicut, 36 | })) 37 | } 38 | 39 | func (t *Taktician) GetMove( 40 | ctx context.Context, 41 | p *tak.Position, 42 | mine, theirs time.Duration) tak.Move { 43 | if p.ToMove() == t.g.Color { 44 | var cancel context.CancelFunc 45 | timeout := t.timeBound(mine) 46 | if p.MoveNumber() < 2 { 47 | timeout = 20 * time.Second 48 | } 49 | ctx, cancel = context.WithTimeout(ctx, timeout) 50 | defer cancel() 51 | } else if !t.cmd.useOpponentTime { 52 | return tak.Move{} 53 | } 54 | return t.ai.GetMove(ctx, p) 55 | } 56 | 57 | func (t *Taktician) timeBound(remaining time.Duration) time.Duration { 58 | if t.g.Size == 4 { 59 | return t.cmd.limit 60 | } 61 | return t.cmd.limit 62 | } 63 | 64 | func (t *Taktician) GameOver() { 65 | t.ai = nil 66 | t.g = nil 67 | } 68 | 69 | func (t *Taktician) handleCommand(cmd, arg string) { 70 | switch strings.ToLower(cmd) { 71 | case "size": 72 | sz, err := strconv.Atoi(arg) 73 | if err != nil { 74 | log.Printf("bad size size=%q", arg) 75 | return 76 | } 77 | if sz >= 4 && sz <= 6 { 78 | t.cmd.size = sz 79 | if t.g == nil { 80 | t.client.SendCommand("Seek", 81 | strconv.Itoa(t.cmd.size), 82 | strconv.Itoa(int(t.cmd.gameTime.Seconds())), 83 | strconv.Itoa(int(t.cmd.increment.Seconds()))) 84 | } 85 | } 86 | } 87 | } 88 | 89 | func (t *Taktician) HandleTell(who string, msg string) { 90 | bits := strings.SplitN(msg, " ", 2) 91 | cmd := bits[0] 92 | var arg string 93 | if len(bits) == 2 { 
94 | arg = bits[1] 95 | } 96 | t.handleCommand(cmd, arg) 97 | } 98 | 99 | func (t *Taktician) HandleChat(room string, who string, msg string) { 100 | cmd, arg := parseCommand(t.client.User, msg) 101 | if cmd == "" { 102 | return 103 | } 104 | log.Printf("chat room=%q from=%q msg=%q", room, who, msg) 105 | t.handleCommand(cmd, arg) 106 | } 107 | 108 | func (t *Taktician) AcceptUndo() bool { 109 | return false 110 | } 111 | -------------------------------------------------------------------------------- /python/tak/proto/analysis_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | from tak.proto import analysis_pb2 as tak_dot_proto_dot_analysis__pb2 6 | 7 | 8 | class AnalysisStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Evaluate = channel.unary_unary( 18 | '/tak.proto.Analysis/Evaluate', 19 | request_serializer=tak_dot_proto_dot_analysis__pb2.EvaluateRequest.SerializeToString, 20 | response_deserializer=tak_dot_proto_dot_analysis__pb2.EvaluateResponse.FromString, 21 | ) 22 | 23 | 24 | class AnalysisServicer(object): 25 | """Missing associated documentation comment in .proto file.""" 26 | 27 | def Evaluate(self, request, context): 28 | """Missing associated documentation comment in .proto file.""" 29 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 30 | context.set_details('Method not implemented!') 31 | raise NotImplementedError('Method not implemented!') 32 | 33 | 34 | def add_AnalysisServicer_to_server(servicer, server): 35 | rpc_method_handlers = { 36 | 'Evaluate': grpc.unary_unary_rpc_method_handler( 37 | servicer.Evaluate, 38 | request_deserializer=tak_dot_proto_dot_analysis__pb2.EvaluateRequest.FromString, 39 | response_serializer=tak_dot_proto_dot_analysis__pb2.EvaluateResponse.SerializeToString, 40 | ), 41 | } 42 | generic_handler = grpc.method_handlers_generic_handler( 43 | 'tak.proto.Analysis', rpc_method_handlers) 44 | server.add_generic_rpc_handlers((generic_handler,)) 45 | 46 | 47 | # This class is part of an EXPERIMENTAL API. 
48 | class Analysis(object): 49 | """Missing associated documentation comment in .proto file.""" 50 | 51 | @staticmethod 52 | def Evaluate(request, 53 | target, 54 | options=(), 55 | channel_credentials=None, 56 | call_credentials=None, 57 | insecure=False, 58 | compression=None, 59 | wait_for_ready=None, 60 | timeout=None, 61 | metadata=None): 62 | return grpc.experimental.unary_unary(request, target, '/tak.proto.Analysis/Evaluate', 63 | tak_dot_proto_dot_analysis__pb2.EvaluateRequest.SerializeToString, 64 | tak_dot_proto_dot_analysis__pb2.EvaluateResponse.FromString, 65 | options, channel_credentials, 66 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 67 | -------------------------------------------------------------------------------- /cmd/internal/opt/opt.go: -------------------------------------------------------------------------------- 1 | package opt 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "log" 7 | 8 | "github.com/nelhage/taktician/ai" 9 | ) 10 | 11 | type Minimax struct { 12 | Seed int64 13 | Debug int 14 | Depth int 15 | MaxEvals uint64 16 | Sort bool 17 | TableMem int64 18 | NullMove bool 19 | ExtendForces bool 20 | ReduceSlides bool 21 | MultiCut bool 22 | Precise bool 23 | Weights string 24 | ModWeights string 25 | LogCuts string 26 | Symmetry bool 27 | } 28 | 29 | func (o *Minimax) AddFlags(flags *flag.FlagSet) { 30 | flags.IntVar(&o.Debug, "debug", 0, "debug level") 31 | flags.Int64Var(&o.Seed, "seed", 0, "specify a seed") 32 | flags.IntVar(&o.Depth, "depth", 0, "minimax depth") 33 | flags.Uint64Var(&o.MaxEvals, "max-evals", 0, "Limit the search by number of nodes evaluated") 34 | flags.BoolVar(&o.Sort, "sort", true, "sort moves via history heuristic") 35 | flags.Int64Var(&o.TableMem, "table-mem", 0, "set table size") 36 | flags.BoolVar(&o.NullMove, "null-move", true, "use null-move pruning") 37 | flags.BoolVar(&o.ExtendForces, "extend-forces", true, "extend forced moves") 38 | flags.BoolVar(&o.ReduceSlides, 
"reduce-slides", true, "reduce trivial slides") 39 | flags.BoolVar(&o.MultiCut, "multi-cut", false, "use multi-cut pruning") 40 | flags.BoolVar(&o.Precise, "precise", false, "Limit to optimizations that provably preserve the game-theoretic value") 41 | flags.StringVar(&o.Weights, "weights", "", "JSON-encoded evaluation weights") 42 | flags.StringVar(&o.ModWeights, "mod-weights", "", "JSON-encoded evaluation weights applied on top of defaults") 43 | flags.StringVar(&o.LogCuts, "log-cuts", "", "log all cuts") 44 | flags.BoolVar(&o.Symmetry, "symmetry", false, "ignore symmetries") 45 | } 46 | 47 | func (o *Minimax) BuildConfig(size int) ai.MinimaxConfig { 48 | var w ai.Weights 49 | var err error 50 | if o.Weights == "" && o.ModWeights == "" { 51 | w = ai.DefaultWeights[size] 52 | } else if o.Weights != "" && o.ModWeights != "" { 53 | log.Fatalf("Can't combine -mod-weights and -weights") 54 | } else if o.Weights != "" { 55 | err = json.Unmarshal([]byte(o.Weights), &w) 56 | 57 | } else if o.ModWeights != "" { 58 | w = ai.DefaultWeights[size] 59 | err = json.Unmarshal([]byte(o.ModWeights), &w) 60 | } 61 | if err != nil { 62 | log.Fatalf("parse weights: %s", err.Error()) 63 | } 64 | cfg := ai.MinimaxConfig{ 65 | Size: size, 66 | Depth: o.Depth, 67 | MaxEvals: o.MaxEvals, 68 | Seed: o.Seed, 69 | Debug: o.Debug, 70 | 71 | NoSort: !o.Sort, 72 | TableMem: o.TableMem, 73 | NoNullMove: !o.NullMove, 74 | NoExtendForces: !o.ExtendForces, 75 | NoReduceSlides: !o.ReduceSlides, 76 | MultiCut: o.MultiCut, 77 | 78 | CutLog: o.LogCuts, 79 | DedupSymmetry: o.Symmetry, 80 | 81 | Evaluate: ai.MakeEvaluator(size, &w), 82 | } 83 | if o.Precise { 84 | cfg.MakePrecise() 85 | } 86 | return cfg 87 | } 88 | -------------------------------------------------------------------------------- /ai/opening.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "math/rand" 7 | "strings" 8 | "time" 9 | 10 | 
"github.com/nelhage/taktician/ptn" 11 | "github.com/nelhage/taktician/symmetry" 12 | "github.com/nelhage/taktician/tak" 13 | ) 14 | 15 | type OpeningBook struct { 16 | size int 17 | book map[uint64]*openingPosition 18 | } 19 | 20 | type openingPosition struct { 21 | p *tak.Position 22 | moves []child 23 | } 24 | 25 | type child struct { 26 | move tak.Move 27 | weight int 28 | } 29 | 30 | func BuildOpeningBook(size int, lines []string) (*OpeningBook, error) { 31 | ob := &OpeningBook{ 32 | size: size, 33 | book: make(map[uint64]*openingPosition), 34 | } 35 | for lno, line := range lines { 36 | p := tak.New(tak.Config{Size: size}) 37 | bits := strings.Split(line, " ") 38 | 39 | for _, b := range bits { 40 | m, e := ptn.ParseMove(b) 41 | if e != nil { 42 | return nil, fmt.Errorf("line %d: move `%s`: %v", 43 | lno, b, e) 44 | } 45 | 46 | rs, e := symmetry.Symmetries(p) 47 | if e != nil { 48 | return nil, fmt.Errorf("compute symmetries: %v", e) 49 | } 50 | for _, sym := range rs { 51 | pos, ok := ob.book[sym.P.Hash()] 52 | if !ok { 53 | pos = &openingPosition{ 54 | p: sym.P, 55 | } 56 | ob.book[sym.P.Hash()] = pos 57 | } 58 | sm := symmetry.TransformMove(sym.S, m) 59 | var ch *child 60 | for i := range pos.moves { 61 | if pos.moves[i].move.Equal(sm) { 62 | ch = &pos.moves[i] 63 | break 64 | } 65 | } 66 | if ch == nil { 67 | pos.moves = append(pos.moves, 68 | child{ 69 | move: sm, 70 | weight: 0, 71 | }) 72 | ch = &pos.moves[len(pos.moves)-1] 73 | } 74 | ch.weight++ 75 | } 76 | 77 | p, e = p.Move(m) 78 | if e != nil { 79 | return nil, fmt.Errorf("line %d: move `%s`: %v", 80 | lno, b, e) 81 | } 82 | } 83 | } 84 | 85 | return ob, nil 86 | } 87 | 88 | func (ob *OpeningBook) GetMove(p *tak.Position, r *rand.Rand) (tak.Move, bool) { 89 | pos, ok := ob.book[p.Hash()] 90 | if !ok { 91 | return tak.Move{}, false 92 | } 93 | sum := 0 94 | var out tak.Move 95 | for _, ch := range pos.moves { 96 | sum += ch.weight 97 | if r.Int31n(int32(sum)) < int32(ch.weight) { 98 | out = ch.move 
99 | } 100 | } 101 | return out, true 102 | } 103 | 104 | type OpeningPlayer struct { 105 | inner TakPlayer 106 | book *OpeningBook 107 | r *rand.Rand 108 | } 109 | 110 | func (op *OpeningPlayer) GetMove(ctx context.Context, p *tak.Position) tak.Move { 111 | if m, ok := op.book.GetMove(p, op.r); ok { 112 | return m 113 | } 114 | return op.inner.GetMove(ctx, p) 115 | } 116 | 117 | func WithOpeningBook(ai TakPlayer, ob *OpeningBook) TakPlayer { 118 | return &OpeningPlayer{ 119 | inner: ai, 120 | book: ob, 121 | r: rand.New(rand.NewSource(time.Now().UnixNano())), 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /cmd/internal/genopenings/genopenings.go: -------------------------------------------------------------------------------- 1 | package genopenings 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "math/rand" 9 | 10 | "github.com/google/subcommands" 11 | "github.com/nelhage/taktician/ptn" 12 | "github.com/nelhage/taktician/symmetry" 13 | "github.com/nelhage/taktician/tak" 14 | ) 15 | 16 | type Command struct { 17 | seed int64 18 | size int 19 | depth int 20 | n int 21 | 22 | rand *rand.Rand 23 | 24 | placeOnly bool 25 | allowSymmetries bool 26 | } 27 | 28 | func (*Command) Name() string { return "genopenings" } 29 | func (*Command) Synopsis() string { return "Generate a set of opening positions" } 30 | func (*Command) Usage() string { 31 | return `genopenings [flags] 32 | ` 33 | } 34 | 35 | func (c *Command) SetFlags(flags *flag.FlagSet) { 36 | flags.IntVar(&c.size, "size", 5, "what size to analyze") 37 | flags.IntVar(&c.depth, "depth", 2, "generate openings to what depth") 38 | flags.IntVar(&c.n, "n", 100, "generate how many openings") 39 | flags.Int64Var(&c.seed, "seed", 0, "Random seed") 40 | flags.BoolVar(&c.placeOnly, "only-place", true, "Only generate moves that place flats") 41 | flags.BoolVar(&c.allowSymmetries, "allow-symmetries", false, "Allow positions that are symmetries of each 
other") 42 | } 43 | 44 | func (c *Command) Execute(ctx context.Context, flag *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { 45 | c.rand = rand.New(rand.NewSource(c.seed)) 46 | init := tak.New(tak.Config{Size: c.size}) 47 | var positions []*tak.Position 48 | seen := make(map[uint64]*tak.Position) 49 | 50 | generate: 51 | for len(positions) < c.n { 52 | pos := c.generate(init, c.depth) 53 | if c.allowSymmetries { 54 | if got, ok := seen[pos.Hash()]; ok { 55 | if !got.Equal(pos) { 56 | log.Fatalf("hash collision seen=%q new=%q", ptn.FormatTPS(got), ptn.FormatTPS(pos)) 57 | } 58 | continue generate 59 | } 60 | } else { 61 | syms, _ := symmetry.Symmetries(pos) 62 | for _, sym := range syms { 63 | if got, ok := seen[sym.P.Hash()]; ok { 64 | if !got.Equal(sym.P) { 65 | log.Fatalf("hash collision seen=%q new=%q", ptn.FormatTPS(got), ptn.FormatTPS(sym.P)) 66 | } 67 | continue generate 68 | } 69 | } 70 | } 71 | seen[pos.Hash()] = pos 72 | positions = append(positions, pos) 73 | } 74 | 75 | for _, pos := range positions { 76 | fmt.Println(ptn.FormatTPS(pos)) 77 | } 78 | return subcommands.ExitSuccess 79 | } 80 | 81 | func (c *Command) generate(pos *tak.Position, depth int) *tak.Position { 82 | var buf [100]tak.Move 83 | for d := 0; d < depth; d++ { 84 | moves := pos.AllMoves(buf[:0]) 85 | for { 86 | i := c.rand.Intn(len(moves)) 87 | m := moves[i] 88 | if c.placeOnly && m.Type != tak.PlaceFlat { 89 | continue 90 | } 91 | n, e := pos.Move(m) 92 | if e != nil { 93 | continue 94 | } 95 | pos = n 96 | break 97 | } 98 | } 99 | return pos 100 | } 101 | -------------------------------------------------------------------------------- /python/xformer/train/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | from . 
import run 6 | 7 | 8 | class Trainer: 9 | start_time: float 10 | run: run.Run 11 | stats: run.Stats 12 | opt: torch.optim.Optimizer 13 | 14 | def __init__(self, training_run: run.Run): 15 | self.run = training_run 16 | self.stats = run.Stats() 17 | 18 | def train(self): 19 | self.start_time = time.time() 20 | self.epoch = iter(self.run.dataset) 21 | 22 | for hook in self.run.hooks: 23 | hook.before_run(self.run) 24 | 25 | self.opt = torch.optim.AdamW( 26 | self.run.model.parameters(), lr=self.run.optimizer.lr 27 | ) 28 | while True: 29 | self.one_step() 30 | if self.run.stop(self.stats): 31 | break 32 | 33 | for hook in self.run.hooks: 34 | hook.after_run(self.run, self.stats) 35 | 36 | def one_step(self): 37 | step_start = time.time() 38 | self.stats.step += 1 39 | self.stats.metrics.clear() 40 | 41 | for hook in self.run.hooks: 42 | hook.before_step(self.run, self.stats) 43 | 44 | self.opt.zero_grad(set_to_none=True) 45 | try: 46 | batch = next(self.epoch) 47 | except StopIteration: 48 | self.stats.epoch += 1 49 | self.epoch = iter(self.run.dataset) 50 | batch = next(self.epoch) 51 | 52 | inputs = batch.inputs 53 | self.stats.sequences += inputs.size(0) 54 | self.stats.tokens += inputs.numel() 55 | 56 | logits = self.run.model(inputs, *batch.extra_inputs) 57 | loss, metrics = self.run.loss.loss_and_metrics(batch, logits) 58 | self.stats.train_loss = loss.item() 59 | for (k, v) in metrics.items(): 60 | self.stats.metrics[f"train.{k}"] = v 61 | loss.backward() 62 | 63 | if self.run.optimizer.lr_schedule: 64 | new_lr = self.run.optimizer.lr * self.run.optimizer.lr_schedule(self.stats) 65 | for g in self.opt.param_groups: 66 | g["lr"] = new_lr 67 | self.stats.metrics["lr"] = new_lr 68 | 69 | self.opt.step() 70 | 71 | # self.profiler.step() 72 | step_done = time.time() 73 | self.stats.step_time = step_done - step_start 74 | self.stats.elapsed_time = step_done - self.start_time 75 | 76 | for hook in self.run.hooks: 77 | hook.after_step(self.run, self.stats) 78 | 79 | 
self.log_step() 80 | 81 | def log_step(self): 82 | stats = self.stats 83 | 84 | print( 85 | f"[step={stats.step:06d}" 86 | f" t={stats.elapsed_time:.1f}s" 87 | f" sequences={stats.sequences:08d}]" 88 | f" loss={stats.train_loss:2.2f}" 89 | f" ms_per_step={1000*(stats.step_time):.0f}" 90 | ) 91 | if stats.metrics: 92 | for (k, v) in stats.metrics.items(): 93 | print(f" {k}={v}") 94 | 95 | 96 | __all__ = ["Trainer"] 97 | -------------------------------------------------------------------------------- /python/test/test_data.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pickle 3 | import tempfile 4 | 5 | import pytest 6 | import torch 7 | 8 | from xformer import data 9 | 10 | N_TEST = 16 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def encoded_dataset(): 15 | with tempfile.TemporaryDirectory() as tmp: 16 | dataset_path = os.path.join(tmp, "dataset.pt") 17 | torch.save( 18 | { 19 | "ints": torch.arange(N_TEST), 20 | "squares": torch.arange(N_TEST) ** 2, 21 | }, 22 | dataset_path, 23 | ) 24 | yield dataset_path 25 | 26 | 27 | def test_basic(encoded_dataset): 28 | for batch_size in (2, 4): 29 | ds = data.Dataset(encoded_dataset, batch_size=batch_size, seed=0x12345678) 30 | batches = list(ds) 31 | for b in batches: 32 | assert isinstance(b, data.Batch) 33 | assert set(b.data.keys()) == {"ints", "squares"} 34 | assert b.data["ints"].shape == (batch_size,) 35 | assert b.data["squares"].shape == (batch_size,) 36 | assert torch.equal(b.data["ints"] ** 2, b.data["squares"]) 37 | all_ints = torch.cat([b.data["ints"] for b in batches]) 38 | assert torch.equal(torch.sort(all_ints).values, torch.arange(N_TEST)) 39 | 40 | 41 | def test_odd_batch(encoded_dataset): 42 | ds = data.Dataset(encoded_dataset, batch_size=6, seed=0x12345678) 43 | batches = list(ds) 44 | assert [len(b.data["ints"]) for b in batches] == [6, 6, 4] 45 | 46 | 47 | def test_reshuffle(encoded_dataset): 48 | ds = 
data.Dataset(encoded_dataset, batch_size=6, seed=0x12345678) 49 | b1 = list(ds) 50 | b2 = list(ds) 51 | 52 | i1 = torch.cat([b.data["ints"] for b in b1]) 53 | i2 = torch.cat([b.data["ints"] for b in b2]) 54 | assert not torch.equal(i1, i2) 55 | assert torch.equal(torch.sort(i1).values, torch.sort(i2).values) 56 | 57 | 58 | def assert_same_data(ds1, ds2): 59 | for (l, r) in zip(iter(ds1), iter(ds2)): 60 | assert torch.equal(l.data["ints"], r.data["ints"]) 61 | assert torch.equal(l.data["squares"], r.data["squares"]) 62 | 63 | 64 | def test_determinism(encoded_dataset): 65 | ds1 = data.Dataset(encoded_dataset, batch_size=2, seed=0x12345678) 66 | ds2 = data.Dataset(encoded_dataset, batch_size=2, seed=0x12345678) 67 | 68 | assert_same_data(ds1, ds2) 69 | 70 | 71 | def test_fast_forward(encoded_dataset): 72 | ds1 = data.Dataset(encoded_dataset, batch_size=2, seed=0x12345678) 73 | ds2 = data.Dataset(encoded_dataset, batch_size=2, seed=0x12345678) 74 | 75 | for _ in range(4): 76 | list(ds1) 77 | ds2.fastforward_epochs(4) 78 | 79 | assert_same_data(ds1, ds2) 80 | 81 | 82 | def test_serde(encoded_dataset): 83 | ds1 = data.Dataset(encoded_dataset, batch_size=2, seed=0x12345678) 84 | ds2 = data.Dataset(encoded_dataset, batch_size=2, seed=0x12345678) 85 | 86 | list(ds1) 87 | list(ds2) 88 | 89 | ds2 = pickle.loads(pickle.dumps(ds2)) 90 | ds2.fastforward_epochs(1) 91 | 92 | assert_same_data(ds1, ds2) 93 | -------------------------------------------------------------------------------- /ptn/move_test.go: -------------------------------------------------------------------------------- 1 | package ptn 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/nelhage/taktician/tak" 8 | ) 9 | 10 | func TestParseMove(t *testing.T) { 11 | cases := []struct { 12 | in string 13 | out tak.Move 14 | str string 15 | long string 16 | }{ 17 | { 18 | "a1", 19 | tak.Move{X: 0, Y: 0, Type: tak.PlaceFlat}, 20 | "a1", 21 | "Fa1", 22 | }, 23 | { 24 | "Sa4", 25 | tak.Move{X: 0, Y: 3, Type: 
tak.PlaceStanding}, 26 | "Sa4", 27 | "Sa4", 28 | }, 29 | { 30 | "Ch7", 31 | tak.Move{X: 7, Y: 6, Type: tak.PlaceCapstone}, 32 | "Ch7", 33 | "Ch7", 34 | }, 35 | { 36 | "Fh7", 37 | tak.Move{X: 7, Y: 6, Type: tak.PlaceFlat}, 38 | "h7", 39 | "Fh7", 40 | }, 41 | { 42 | "a1>", 43 | tak.Move{X: 0, Y: 0, Type: tak.SlideRight, Slides: tak.MkSlides(1)}, 44 | "a1>", 45 | "1a1>1", 46 | }, 47 | { 48 | "2a2<", 49 | tak.Move{X: 0, Y: 1, Type: tak.SlideLeft, Slides: tak.MkSlides(2)}, 50 | "2a2<", 51 | "2a2<2", 52 | }, 53 | { 54 | "3a1+111", 55 | tak.Move{X: 0, Y: 0, Type: tak.SlideUp, Slides: tak.MkSlides(1, 1, 1)}, 56 | "3a1+111", 57 | "3a1+111", 58 | }, 59 | { 60 | "5d4-22", 61 | tak.Move{X: 3, Y: 3, Type: tak.SlideDown, Slides: tak.MkSlides(2, 2, 1)}, 62 | "5d4-221", 63 | "5d4-221", 64 | }, 65 | { 66 | "a1?", 67 | tak.Move{X: 0, Y: 0, Type: tak.PlaceFlat}, 68 | "a1", 69 | "Fa1", 70 | }, 71 | { 72 | "Ch7!", 73 | tak.Move{X: 7, Y: 6, Type: tak.PlaceCapstone}, 74 | "Ch7", 75 | "Ch7", 76 | }, 77 | { 78 | "b1>*'", 79 | tak.Move{X: 1, Y: 0, Type: tak.SlideRight, Slides: tak.MkSlides(1)}, 80 | "b1>", 81 | "1b1>1", 82 | }, 83 | { 84 | "2a2<*", 85 | tak.Move{X: 0, Y: 1, Type: tak.SlideLeft, Slides: tak.MkSlides(2)}, 86 | "2a2<", 87 | "2a2<2", 88 | }, 89 | { 90 | "3a1+111''!", 91 | tak.Move{X: 0, Y: 0, Type: tak.SlideUp, Slides: tak.MkSlides(1, 1, 1)}, 92 | "3a1+111", 93 | "3a1+111", 94 | }, 95 | } 96 | for _, tc := range cases { 97 | get, err := ParseMove(tc.in) 98 | if err != nil { 99 | t.Errorf("ParseMove(%s): err=%v", tc.in, err) 100 | continue 101 | } 102 | if !reflect.DeepEqual(get, tc.out) { 103 | t.Errorf("ParseMove(%s)=%#v not %#v", tc.in, get, tc.out) 104 | } 105 | rt := FormatMove(tc.out) 106 | if rt != tc.str { 107 | t.Errorf("FormatMove(%s)=%s not %s", tc.in, rt, tc.str) 108 | } 109 | long := FormatMoveLong(tc.out) 110 | if long != tc.long { 111 | t.Errorf("FormatMoveLong(%s)=%s not %s", tc.in, long, tc.long) 112 | } 113 | } 114 | } 115 | 116 | func TestParseMoveErrors(t 
*testing.T) { 117 | bad := []string{ 118 | "", 119 | "a11", 120 | "z3", 121 | "14c4>", 122 | "6a1", 123 | "6a1>2222", 124 | "a", 125 | "3a", 126 | } 127 | for _, b := range bad { 128 | _, e := ParseMove(b) 129 | if e == nil { 130 | t.Errorf("parse(%q): no error", b) 131 | } 132 | } 133 | } 134 | 135 | func BenchmarkParseMove(b *testing.B) { 136 | for i := 0; i < b.N; i++ { 137 | ParseMove("3a1+111") 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /python/scripts/analysis_server: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import logging 4 | 5 | from xformer import loading 6 | 7 | from tak.proto import analysis_pb2_grpc 8 | from tak.proto import analysis_pb2 9 | import tak.model.server 10 | import argparse 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | 16 | from attrs import define, field 17 | 18 | import grpc 19 | import asyncio 20 | import time 21 | 22 | import typing as T 23 | 24 | 25 | _cleanup_coroutines = [] 26 | 27 | 28 | async def main(argv): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--fp16", 32 | action="store_true", 33 | default=None, 34 | help="Run model in float16", 35 | ) 36 | parser.add_argument( 37 | "--no-fp16", 38 | action="store_false", 39 | dest="fp16", 40 | ) 41 | parser.add_argument( 42 | "--no-script", 43 | action="store_false", 44 | dest="script", 45 | default=True, 46 | ) 47 | parser.add_argument( 48 | "--device", 49 | type=str, 50 | default="cuda" if torch.cuda.is_available() else "cpu", 51 | ) 52 | parser.add_argument( 53 | "--host", 54 | type=str, 55 | default="localhost", 56 | ) 57 | parser.add_argument( 58 | "--port", 59 | type=int, 60 | default=5001, 61 | ) 62 | parser.add_argument( 63 | "model", 64 | type=str, 65 | ) 66 | 67 | args = parser.parse_args(argv) 68 | 69 | model = loading.load_model(args.model, args.device) 70 | fp16 = args.fp16 
71 | if fp16 is None: 72 | fp16 = args.device == "cuda" 73 | if fp16: 74 | model = model.to(torch.float16) 75 | if args.script: 76 | model = torch.jit.script(model) 77 | 78 | server = grpc.aio.server() 79 | server.add_insecure_port(f"{args.host}:{args.port}") 80 | 81 | analysis = tak.model.server.Server(model=model, device=args.device) 82 | worker = asyncio.create_task(analysis.worker_loop()) 83 | 84 | analysis_pb2_grpc.add_AnalysisServicer_to_server( 85 | analysis, 86 | server, 87 | ) 88 | await server.start() 89 | 90 | async def server_graceful_shutdown(): 91 | logging.info("Starting graceful shutdown...") 92 | # Shuts down the server with a 2-second grace period. During the 93 | # grace period, the server won't accept new connections but allows 94 | # existing RPCs to continue. 95 | await server.stop(2) 96 | worker.cancel() 97 | 98 | _cleanup_coroutines.append(server_graceful_shutdown()) 99 | 100 | await server.wait_for_termination() 101 | 102 | 103 | if __name__ == "__main__": 104 | logging.basicConfig(level=logging.INFO) 105 | loop = asyncio.get_event_loop() 106 | try: 107 | loop.run_until_complete(main(sys.argv[1:])) 108 | finally: 109 | for co in _cleanup_coroutines: 110 | loop.run_until_complete(co) 111 | loop.close() 112 | -------------------------------------------------------------------------------- /python/tak/ptn/tps.py: -------------------------------------------------------------------------------- 1 | import tak 2 | 3 | 4 | def parse_tps(tps): 5 | bits = tps.split(" ") 6 | if len(bits) != 3: 7 | raise IllegalTPS("need three components") 8 | board, who, move = bits 9 | 10 | if who not in ("1", "2"): 11 | raise IllegalTPS("Current player must be either 1 or 2") 12 | try: 13 | ply = 2 * (int(move) - 1) + int(who) - 1 14 | except ValueError: 15 | raise IllegalTPS("Bad move number: " + move) 16 | 17 | squares = [] 18 | rows = board.split("/") 19 | for row in reversed(rows): 20 | rsq = parse_row(row) 21 | if len(rsq) != len(rows): 22
| raise IllegalTPS("inconsistent size") 23 | squares += rsq 24 | 25 | return tak.Position.from_squares(tak.Config(size=len(rows)), squares, ply) 26 | 27 | 28 | def parse_row(rtext): 29 | squares = [] 30 | bits = rtext.split(",") 31 | for b in bits: 32 | if b[0] == "x": 33 | n = 1 34 | if len(b) > 1: 35 | n = int(b[1:]) 36 | squares += [[]] * n 37 | continue 38 | 39 | stack = [] 40 | for c in b: 41 | if c == "1": 42 | stack.append(tak.Piece.cached(tak.Color.WHITE, tak.Kind.FLAT)) 43 | elif c == "2": 44 | stack.append(tak.Piece.cached(tak.Color.BLACK, tak.Kind.FLAT)) 45 | elif c in ("C", "S"): 46 | if not stack: 47 | raise IllegalTPS("bare capstone or standing") 48 | typ = tak.Kind.CAPSTONE if c == "C" else tak.Kind.STANDING 49 | stack[-1] = tak.Piece.cached(stack[-1].color, typ) 50 | else: 51 | raise IllegalTPS("bad character: " + c) 52 | 53 | squares.append(list(reversed(stack))) 54 | return squares 55 | 56 | 57 | def format_tps(pos): 58 | rows = [] 59 | for row in range(pos.size): 60 | i = row * pos.size 61 | rows.append(_format_row(pos.board[i : i + pos.size])) 62 | 63 | return " ".join( 64 | ["/".join(reversed(rows)), str((pos.ply % 2) + 1), str(pos.ply // 2 + 1)] 65 | ) 66 | 67 | 68 | def _format_row(row): 69 | out = [] 70 | i = 0 71 | while i < len(row): 72 | x = 0 73 | while i + x < len(row) and row[i + x] == []: 74 | x += 1 75 | if x > 0: 76 | out.append("x{0}".format(x if x > 1 else "")) 77 | i += x 78 | else: 79 | out.append(_format_square(row[i])) 80 | i += 1 81 | return ",".join(out) 82 | 83 | 84 | def _format_square(sq): 85 | out = [] 86 | for p in reversed(sq): 87 | if p.color == tak.Color.WHITE: 88 | out.append("1") 89 | else: 90 | out.append("2") 91 | 92 | if sq[0].kind == tak.Kind.STANDING: 93 | out.append("S") 94 | elif sq[0].kind == tak.Kind.CAPSTONE: 95 | out.append("C") 96 | 97 | return "".join(out) 98 | 99 | 100 | class IllegalTPS(Exception): 101 | pass 102 | 103 | 104 | __all__ = ["parse_tps", "format_tps", "IllegalTPS"] 105 | 
-------------------------------------------------------------------------------- /python/scripts/self_play.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import sys 4 | import traceback 5 | 6 | import tak 7 | from tak import mcts, self_play 8 | from tak.model import grpc 9 | import attrs 10 | import tqdm 11 | 12 | import queue 13 | from torch import multiprocessing 14 | 15 | import torch 16 | import numpy as np 17 | 18 | import time 19 | 20 | 21 | def main(argv): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | "--simulations", 25 | dest="simulations", 26 | type=int, 27 | default=100, 28 | metavar="POSITIONS", 29 | ) 30 | parser.add_argument( 31 | "--size", 32 | dest="size", 33 | type=int, 34 | default=3, 35 | metavar="SIZE", 36 | ) 37 | parser.add_argument( 38 | "--host", 39 | type=str, 40 | default="localhost", 41 | ) 42 | parser.add_argument( 43 | "--port", 44 | type=int, 45 | default=5001, 46 | ) 47 | parser.add_argument( 48 | "--games", 49 | type=int, 50 | default=1, 51 | ) 52 | parser.add_argument( 53 | "--noise-alpha", 54 | type=float, 55 | default=None, 56 | ) 57 | parser.add_argument( 58 | "--noise-weight", 59 | type=float, 60 | default=0.25, 61 | ) 62 | parser.add_argument( 63 | "--resign-threshold", 64 | type=float, 65 | default=0.99, 66 | ) 67 | parser.add_argument( 68 | "--threads", 69 | type=int, 70 | default=1, 71 | ) 72 | parser.add_argument( 73 | "-C", 74 | type=float, 75 | default=4, 76 | ) 77 | parser.add_argument("--write-games", type=str, metavar="FILE") 78 | 79 | args = parser.parse_args(argv) 80 | 81 | config = self_play.SelfPlayConfig( 82 | size=args.size, 83 | workers=args.threads, 84 | resignation_threshold=args.resign_threshold, 85 | engine_factory=self_play.BuildRemoteMCTS( 86 | host=args.host, 87 | port=args.port, 88 | config=mcts.Config( 89 | simulation_limit=args.simulations, 90 | root_noise_alpha=args.noise_alpha, 91 | 
root_noise_mix=args.noise_weight, 92 | C=args.C, 93 | ), 94 | ), 95 | ) 96 | 97 | start = time.time() 98 | 99 | logs = self_play.play_many_games(config, args.games, progress=True) 100 | 101 | end = time.time() 102 | 103 | stats = mcts.Stats() 104 | for l in logs: 105 | stats = stats.merge(l.stats) 106 | 107 | print( 108 | f"done games={len(logs)}" 109 | f" plies={sum(len(l.positions) for l in logs)}" 110 | f" threads={args.threads} duration={end-start:.2f}" 111 | f" games/s={args.games/(end-start):.1f}" 112 | " " + " ".join(f"{k}={v}" for (k, v) in attrs.asdict(stats).items()) 113 | ) 114 | 115 | if args.write_games: 116 | torch.save( 117 | self_play.encode_games(logs), 118 | args.write_games, 119 | ) 120 | 121 | pass 122 | 123 | 124 | if __name__ == "__main__": 125 | main(sys.argv[1:]) 126 | -------------------------------------------------------------------------------- /ai/minimax_test.go: -------------------------------------------------------------------------------- 1 | package ai 2 | 3 | import ( 4 | "flag" 5 | "testing" 6 | "time" 7 | 8 | "context" 9 | 10 | "github.com/nelhage/taktician/ptn" 11 | "github.com/nelhage/taktician/tak" 12 | ) 13 | 14 | var size = flag.Int("size", 5, "board size to benchmark") 15 | var depth = flag.Int("depth", 4, "minimax search depth") 16 | var seed = flag.Int64("seed", 4, "random seed") 17 | 18 | func BenchmarkMinimax(b *testing.B) { 19 | var cfg = tak.Config{Size: *size} 20 | p := tak.New(cfg) 21 | p, _ = p.Move(tak.Move{X: 0, Y: 0, Type: tak.PlaceFlat}) 22 | p, _ = p.Move(tak.Move{X: int8(*size - 1), Y: int8(*size - 1), Type: tak.PlaceFlat}) 23 | ai := NewMinimax(MinimaxConfig{ 24 | Size: *size, 25 | Depth: *depth, 26 | Seed: *seed, 27 | }) 28 | 29 | base := p.Clone() 30 | 31 | next := tak.Alloc(*size) 32 | 33 | b.ReportAllocs() 34 | b.ResetTimer() 35 | 36 | for i := 0; i < b.N; i++ { 37 | var e error 38 | m := ai.GetMove(context.Background(), p) 39 | next, e = p.MovePreallocated(m, next) 40 | if e != nil { 41 | b.Fatal("bad 
move", e) 42 | } 43 | p, next = next, p 44 | 45 | if over, _ := p.GameOver(); over { 46 | p = base.Clone() 47 | } 48 | } 49 | } 50 | 51 | func TestRegression(t *testing.T) { 52 | game, err := ptn.ParseTPS( 53 | `2,x4/x2,2,x2/x,2,2,x2/x2,12,2,1/1,1,21,2,1 1 9`, 54 | ) 55 | if err != nil { 56 | panic(err) 57 | } 58 | ai := NewMinimax(MinimaxConfig{Size: game.Size(), Depth: 3}) 59 | m := ai.GetMove(context.Background(), game) 60 | _, e := game.Move(m) 61 | if e != nil { 62 | t.Fatalf("ai returned illegal move: %s: %s", ptn.FormatMove(m), e) 63 | } 64 | } 65 | 66 | func TestCancel(t *testing.T) { 67 | ctx, cancel := context.WithCancel(context.Background()) 68 | ai := NewMinimax(MinimaxConfig{Size: 5, Depth: maxDepth}) 69 | p := tak.New(tak.Config{Size: 5}) 70 | done := make(chan Stats) 71 | go func() { 72 | _, _, st := ai.Analyze(ctx, p) 73 | done <- st 74 | }() 75 | cancel() 76 | st := <-done 77 | if st.Depth == maxDepth { 78 | t.Fatal("wtf too deep") 79 | } 80 | if !st.Canceled { 81 | t.Fatal("didn't cancel") 82 | } 83 | } 84 | 85 | func TestRepeatedCancel(t *testing.T) { 86 | type result struct { 87 | ms []tak.Move 88 | st Stats 89 | } 90 | ctx := context.Background() 91 | ai := NewMinimax(MinimaxConfig{Size: 5, Depth: 6, NoNullMove: true, TableMem: -1}) 92 | p := tak.New(tak.Config{Size: 5}) 93 | for i := 0; i < 5; i++ { 94 | done := make(chan result) 95 | start := make(chan struct{}) 96 | ctx, cancel := context.WithCancel(ctx) 97 | go func() { 98 | start <- struct{}{} 99 | ms, _, st := ai.Analyze(ctx, p) 100 | done <- result{ms, st} 101 | }() 102 | <-start 103 | time.Sleep(time.Millisecond) 104 | cancel() 105 | res := <-done 106 | if res.st.Depth == 6 { 107 | t.Fatalf("[%d] cancel() didn't work", i) 108 | } 109 | if !res.st.Canceled { 110 | t.Fatalf("[%d] not canceled", i) 111 | } 112 | if len(res.ms) == 0 { 113 | t.Fatalf("[%d] canceled search did not return a move", i) 114 | } 115 | } 116 | ms, _, st := ai.Analyze(ctx, p) 117 | if len(ms) == 0 { 118 | 
t.Fatal("did not return a move") 119 | } 120 | if st.Depth != 6 { 121 | t.Fatal("did not do full search") 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /bitboard/bits_test.go: -------------------------------------------------------------------------------- 1 | package bitboard 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | ) 7 | 8 | func TestPrecompute(t *testing.T) { 9 | c := Precompute(5) 10 | if c.B != (1<<5)-1 { 11 | t.Error("c.b(5):", strconv.FormatUint(c.B, 2)) 12 | } 13 | if c.T != ((1<<5)-1)<<(4*5) { 14 | t.Error("c.t(5):", strconv.FormatUint(c.T, 2)) 15 | } 16 | if c.R != 0x0108421 { 17 | t.Error("c.r(5):", strconv.FormatUint(c.R, 2)) 18 | } 19 | if c.L != 0x1084210 { 20 | t.Error("c.l(5):", strconv.FormatUint(c.L, 2)) 21 | } 22 | if c.Mask != 0x1ffffff { 23 | t.Error("c.mask(5):", strconv.FormatUint(c.Mask, 2)) 24 | } 25 | 26 | c = Precompute(8) 27 | if c.B != (1<<8)-1 { 28 | t.Error("c.b(8):", strconv.FormatUint(c.B, 2)) 29 | } 30 | if c.T != ((1<<8)-1)<<(7*8) { 31 | t.Error("c.t(8):", strconv.FormatUint(c.T, 2)) 32 | } 33 | if c.R != 0x101010101010101 { 34 | t.Error("c.r(8):", strconv.FormatUint(c.R, 2)) 35 | } 36 | if c.L != 0x8080808080808080 { 37 | t.Error("c.l(8):", strconv.FormatUint(c.L, 2)) 38 | } 39 | if c.Mask != ^uint64(0) { 40 | t.Error("c.mask(8):", strconv.FormatUint(c.Mask, 2)) 41 | } 42 | } 43 | 44 | func TestFlood(t *testing.T) { 45 | cases := []struct { 46 | size uint 47 | bound uint64 48 | seed uint64 49 | out uint64 50 | }{ 51 | { 52 | 5, 53 | 0x108423c, 54 | 0x4, 55 | 0x108421c, 56 | }, 57 | } 58 | for _, tc := range cases { 59 | c := Precompute(tc.size) 60 | got := Flood(&c, tc.bound, tc.seed) 61 | if got != tc.out { 62 | t.Errorf("Flood[%d](%s, %s)=%s !=%s", 63 | tc.size, 64 | strconv.FormatUint(tc.bound, 2), 65 | strconv.FormatUint(tc.seed, 2), 66 | strconv.FormatUint(got, 2), 67 | strconv.FormatUint(tc.out, 2)) 68 | } 69 | } 70 | } 71 | 72 | func TestDimensions(t 
*testing.T) { 73 | cases := []struct { 74 | size uint 75 | bits uint64 76 | w int 77 | h int 78 | }{ 79 | {5, 0x108421c, 3, 5}, 80 | {5, 0, 0, 0}, 81 | {5, 0x843800, 3, 3}, 82 | {5, 0x08000, 1, 1}, 83 | } 84 | for _, tc := range cases { 85 | c := Precompute(tc.size) 86 | w, h := Dimensions(&c, tc.bits) 87 | if w != tc.w || h != tc.h { 88 | t.Errorf("Dimensions(%d, %x) = (%d,%d) != (%d,%d)", 89 | tc.size, tc.bits, w, h, tc.w, tc.h, 90 | ) 91 | } 92 | } 93 | 94 | } 95 | 96 | func TestBitCoords(t *testing.T) { 97 | cases := []struct { 98 | size uint 99 | x uint 100 | y uint 101 | }{ 102 | {5, 1, 1}, 103 | {3, 1, 1}, 104 | {3, 2, 2}, 105 | {5, 3, 1}, 106 | {5, 0, 1}, 107 | } 108 | for _, tc := range cases { 109 | c := Precompute(tc.size) 110 | bit := uint64(1) << (c.Size*tc.y + tc.x) 111 | x, y := BitCoords(&c, bit) 112 | if x != tc.x || y != tc.y { 113 | t.Errorf("BitCoords(Precompute(%d), (%d,%d)) = (%d, %d)", 114 | c.Size, tc.x, tc.y, x, y, 115 | ) 116 | } 117 | } 118 | } 119 | 120 | func TestTrailingZeros(t *testing.T) { 121 | cases := []struct { 122 | in uint64 123 | out uint 124 | }{ 125 | {0x00, 64}, 126 | {0x01, 0}, 127 | {0x02, 1}, 128 | {0x010, 4}, 129 | } 130 | for _, tc := range cases { 131 | got := TrailingZeros(tc.in) 132 | if got != tc.out { 133 | t.Errorf("TrailingZeros(%x)=%d != %d", 134 | tc.in, got, tc.out, 135 | ) 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /tak/alloc.go: -------------------------------------------------------------------------------- 1 | package tak 2 | 3 | import "fmt" 4 | 5 | type position3 struct { 6 | Position 7 | alloc struct { 8 | Height [3 * 3]uint8 9 | Stacks [3 * 3]uint64 10 | Groups [6]uint64 11 | } 12 | } 13 | 14 | type position4 struct { 15 | Position 16 | alloc struct { 17 | Height [4 * 4]uint8 18 | Stacks [4 * 4]uint64 19 | Groups [8]uint64 20 | } 21 | } 22 | 23 | type position5 struct { 24 | Position 25 | alloc struct { 26 | Height [5 * 5]uint8 27 | 
Stacks [5 * 5]uint64 28 | Groups [10]uint64 29 | } 30 | } 31 | 32 | type position6 struct { 33 | Position 34 | alloc struct { 35 | Height [6 * 6]uint8 36 | Stacks [6 * 6]uint64 37 | Groups [12]uint64 38 | } 39 | } 40 | 41 | type position7 struct { 42 | Position 43 | alloc struct { 44 | Height [7 * 7]uint8 45 | Stacks [7 * 7]uint64 46 | Groups [14]uint64 47 | } 48 | } 49 | 50 | type position8 struct { 51 | Position 52 | alloc struct { 53 | Height [8 * 8]uint8 54 | Stacks [8 * 8]uint64 55 | Groups [16]uint64 56 | } 57 | } 58 | 59 | func alloc(tpl *Position) *Position { 60 | switch tpl.Size() { 61 | case 3: 62 | a := &position3{Position: *tpl} 63 | a.Height = a.alloc.Height[:] 64 | a.Stacks = a.alloc.Stacks[:] 65 | a.analysis.WhiteGroups = a.alloc.Groups[:0] 66 | copy(a.Height, tpl.Height) 67 | copy(a.Stacks, tpl.Stacks) 68 | 69 | return &a.Position 70 | case 4: 71 | a := &position4{Position: *tpl} 72 | a.Height = a.alloc.Height[:] 73 | a.Stacks = a.alloc.Stacks[:] 74 | a.analysis.WhiteGroups = a.alloc.Groups[:0] 75 | copy(a.Height, tpl.Height) 76 | copy(a.Stacks, tpl.Stacks) 77 | 78 | return &a.Position 79 | case 5: 80 | a := &position5{Position: *tpl} 81 | a.Height = a.alloc.Height[:] 82 | a.Stacks = a.alloc.Stacks[:] 83 | a.analysis.WhiteGroups = a.alloc.Groups[:0] 84 | copy(a.Height, tpl.Height) 85 | copy(a.Stacks, tpl.Stacks) 86 | 87 | return &a.Position 88 | case 6: 89 | a := &position6{Position: *tpl} 90 | a.Height = a.alloc.Height[:] 91 | a.Stacks = a.alloc.Stacks[:] 92 | a.analysis.WhiteGroups = a.alloc.Groups[:0] 93 | copy(a.Height, tpl.Height) 94 | copy(a.Stacks, tpl.Stacks) 95 | 96 | return &a.Position 97 | case 7: 98 | a := &position7{Position: *tpl} 99 | a.Height = a.alloc.Height[:] 100 | a.Stacks = a.alloc.Stacks[:] 101 | a.analysis.WhiteGroups = a.alloc.Groups[:0] 102 | copy(a.Height, tpl.Height) 103 | copy(a.Stacks, tpl.Stacks) 104 | 105 | return &a.Position 106 | case 8: 107 | a := &position8{Position: *tpl} 108 | a.Height = a.alloc.Height[:] 109 | 
a.Stacks = a.alloc.Stacks[:] 110 | a.analysis.WhiteGroups = a.alloc.Groups[:0] 111 | copy(a.Height, tpl.Height) 112 | copy(a.Stacks, tpl.Stacks) 113 | 114 | return &a.Position 115 | default: 116 | panic(fmt.Sprintf("illegal size: %d", tpl.Size())) 117 | } 118 | } 119 | 120 | func copyPosition(p *Position, out *Position) { 121 | h := out.Height 122 | s := out.Stacks 123 | g := out.analysis.WhiteGroups 124 | 125 | *out = *p 126 | out.Height = h 127 | out.Stacks = s 128 | out.analysis.WhiteGroups = g[:0] 129 | 130 | copy(out.Height, p.Height) 131 | copy(out.Stacks, p.Stacks) 132 | } 133 | 134 | func Alloc(size int) *Position { 135 | p := Position{cfg: &Config{Size: size}} 136 | return alloc(&p) 137 | } 138 | -------------------------------------------------------------------------------- /python/tak/model/server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from tak.proto import analysis_pb2_grpc 4 | from tak.proto import analysis_pb2 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | from attrs import define, field 11 | 12 | import asyncio 13 | import time 14 | 15 | import typing as T 16 | 17 | 18 | MAX_QUEUE_DEPTH = 80 19 | 20 | 21 | @define 22 | class QueueRequest: 23 | position: torch.Tensor 24 | 25 | ready: asyncio.Event = field(factory=asyncio.Event) 26 | probs: T.Optional[np.ndarray] = None 27 | value: T.Optional[float] = None 28 | 29 | 30 | @define 31 | class Server(analysis_pb2_grpc.AnalysisServicer): 32 | model: nn.Module 33 | device: str = "cpu" 34 | 35 | queue: asyncio.Queue = field(factory=lambda: asyncio.Queue(MAX_QUEUE_DEPTH)) 36 | 37 | async def worker_loop(self): 38 | loop = asyncio.get_running_loop() 39 | 40 | while True: 41 | batch = [] 42 | batch.append(await self.queue.get()) 43 | while True: 44 | if len(batch) >= 8: 45 | try: 46 | batch.append(self.queue.get_nowait()) 47 | except asyncio.QueueEmpty: 48 | break 49 | else: 50 | try: 51 | elem = await 
asyncio.wait_for(self.queue.get(), 1.0 / 1000) 52 | batch.append(elem) 53 | except asyncio.TimeoutError: 54 | break 55 | # now we have a batch 56 | 57 | @torch.inference_mode() 58 | def run_model(batch): 59 | positions = torch.zeros( 60 | (len(batch), max(len(b.position) for b in batch)), dtype=torch.long 61 | ) 62 | mask = torch.zeros_like(positions, dtype=torch.bool) 63 | for (i, b) in enumerate(batch): 64 | positions[i, : len(b.position)] = b.position 65 | mask[i, len(b.position) :].fill_(1) 66 | out = self.model(positions.to(self.device), mask.to(self.device)) 67 | probs = ( 68 | torch.softmax(out["moves"], dim=-1) 69 | .to(device="cpu", dtype=torch.float32) 70 | .numpy() 71 | ) 72 | values = out["values"].to(device="cpu", dtype=torch.float32).numpy() 73 | return ( 74 | probs, 75 | values, 76 | ) 77 | 78 | start = time.perf_counter() 79 | (probs, values) = await loop.run_in_executor(None, run_model, batch) 80 | end = time.perf_counter() 81 | logging.info(f"did batch len={len(batch)} dur={1000*(end-start):0.1f}ms") 82 | for (i, b) in enumerate(batch): 83 | b.probs = probs[i] 84 | b.value = values[i] 85 | b.ready.set() 86 | 87 | async def Evaluate(self, request, context): 88 | position = torch.tensor(request.position, dtype=torch.long) 89 | 90 | req = QueueRequest(position=position) 91 | await self.queue.put(req) 92 | await req.ready.wait() 93 | 94 | return analysis_pb2.EvaluateResponse( 95 | move_probs_bytes=req.probs.tobytes(), value=req.value 96 | ) 97 | -------------------------------------------------------------------------------- /playtak/bot/mock_test.go: -------------------------------------------------------------------------------- 1 | package bot 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "strings" 7 | "sync" 8 | "testing" 9 | "time" 10 | 11 | "github.com/nelhage/taktician/tak" 12 | ) 13 | 14 | type Expectation struct { 15 | send, recv []string 16 | } 17 | 18 | type TestClient struct { 19 | send, recv chan string 20 | 21 | t *testing.T 22 | expect 
[]Expectation 23 | } 24 | 25 | func (c *TestClient) shutdown() { 26 | close(c.send) 27 | } 28 | 29 | func NewTestClient(t *testing.T, expect []Expectation) *TestClient { 30 | c := &TestClient{ 31 | send: make(chan string), 32 | recv: make(chan string), 33 | t: t, 34 | expect: expect, 35 | } 36 | go c.sendRecv() 37 | return c 38 | } 39 | 40 | func (t *TestClient) sendRecv() { 41 | for i, e := range t.expect { 42 | for _, s := range e.send { 43 | t.recv <- s 44 | log.Printf("[srv] -> %s", s) 45 | } 46 | for j, r := range e.recv { 47 | got := <-t.send 48 | log.Printf("[srv] <- %s", got) 49 | if got != r { 50 | t.t.Fatalf("msg %d,%d: got %q != %q", 51 | i, j, got, r) 52 | } 53 | } 54 | } 55 | close(t.recv) 56 | } 57 | 58 | func (t *TestClient) SendCommand(cmd ...string) { 59 | t.send <- strings.Join(cmd, " ") 60 | } 61 | func (t *TestClient) Recv() <-chan string { 62 | return t.recv 63 | } 64 | 65 | type BotBase struct { 66 | game *Game 67 | } 68 | 69 | func (t *BotBase) NewGame(g *Game) { 70 | t.game = g 71 | } 72 | 73 | func (t *BotBase) GameOver() {} 74 | 75 | func (t *BotBase) AcceptUndo() bool { 76 | return false 77 | } 78 | func (t *BotBase) HandleTell(who, msg string) {} 79 | func (t *BotBase) HandleChat(room, who, msg string) {} 80 | 81 | type TestBotStatic struct { 82 | BotBase 83 | moves []tak.Move 84 | } 85 | 86 | func (t *TestBotStatic) GetMove(ctx context.Context, 87 | p *tak.Position, 88 | mine, theirs time.Duration) tak.Move { 89 | log.Printf("(*TestBot).GetMove(ply=%d color=%s)", 90 | p.MoveNumber(), 91 | p.ToMove(), 92 | ) 93 | if p.ToMove() != t.game.Color { 94 | return tak.Move{} 95 | } 96 | i := p.MoveNumber() / 2 97 | return t.moves[i] 98 | } 99 | 100 | type TestBotUndo struct { 101 | TestBotStatic 102 | undoPly int 103 | } 104 | 105 | func (t *TestBotUndo) GetMove(ctx context.Context, 106 | p *tak.Position, 107 | mine, theirs time.Duration) tak.Move { 108 | if p.MoveNumber() == t.undoPly+1 { 109 | select { 110 | case <-ctx.Done(): 111 | case 
<-time.After(10 * time.Millisecond): 112 | } 113 | } 114 | return t.TestBotStatic.GetMove(ctx, p, mine, theirs) 115 | } 116 | 117 | func (t *TestBotUndo) AcceptUndo() bool { 118 | return true 119 | } 120 | 121 | type TestBotThinker struct { 122 | TestBotStatic 123 | wg sync.WaitGroup 124 | } 125 | 126 | func (t *TestBotThinker) GetMove(ctx context.Context, 127 | p *tak.Position, 128 | mine, theirs time.Duration) tak.Move { 129 | defer t.wg.Done() 130 | if p.ToMove() != t.game.Color { 131 | <-ctx.Done() 132 | return tak.Move{} 133 | } 134 | return t.TestBotStatic.GetMove(ctx, p, mine, theirs) 135 | } 136 | 137 | type TestBotResume struct { 138 | TestBotStatic 139 | } 140 | 141 | func (t *TestBotResume) GetMove(ctx context.Context, 142 | p *tak.Position, 143 | mine, theirs time.Duration) tak.Move { 144 | if p.MoveNumber() == 0 { 145 | time.Sleep(10 * time.Millisecond) 146 | } 147 | return t.TestBotStatic.GetMove(ctx, p, mine, theirs) 148 | } 149 | -------------------------------------------------------------------------------- /doc/bitboards.md: -------------------------------------------------------------------------------- 1 | # bitboards for Tak 2 | 3 | It is [well-known][chess] in the field of chess AI that one efficient 4 | representation of a chess board uses 64-bit integers as bit-sets to 5 | represent features of the 8x8 board. In addition to being a very 6 | compact representation, this format allows for all kinds of clever 7 | tricks when computing possible moves and attacks on the board. 8 | 9 | [chess]: https://chessprogramming.wikispaces.com/Bitboards 10 | 11 | Compared to chess, Tak has the additional complication of 12 | 3-dimensional stacks, which are less obviously representable in a 13 | bitwise fashion. This document describes Taktician's approach to 14 | efficiently representing Tak boards using bit-wise representations. 15 | 16 | # Stack tops 17 | 18 | The top piece in each stack has special significance in Tak.
Walls and 19 | Capstones may only be present on top of stacks, and the top piece is 20 | relevant for determining roads, for scoring the game for flat wins, 21 | and for determining control over each stack. 22 | 23 | In light of these special properties of the top piece in each stack, 24 | Taktician represents the top of each stack separately from the rest of 25 | the board. 26 | 27 | Taktician stores the board-state of the stack tops using four 64-bit 28 | bitsets: 29 | 30 | ``` 31 | White uint64 32 | Black uint64 33 | Standing uint64 34 | Caps uint64 35 | ``` 36 | 37 | `White` and `Black` store the color of the topmost piece in each 38 | stack, and are mutually exclusive. 39 | 40 | `Standing` and `Caps` store whether the topmost piece is a standing 41 | stone or capstone, respectively, and are also exclusive. A piece 42 | present in `White` or `Black` but not in `Standing` or `Caps` is a 43 | flat stone. 44 | 45 | This representation affords very efficient calculation of several 46 | valuable board features: 47 | 48 | - `White&Caps` gives the location of white's capstone(s). 49 | - `White&^Standing` gives a bitset containing all positions that may 50 | be part of a road for White. 51 | - `popcount(White&^(Standing|Caps))` gives white's current flat count. 52 | 53 | (and vice-versa for black). 54 | 55 | # Stacks 56 | 57 | We begin by noting that pieces beneath the top of a stack are constrained to be flats, 58 | and therefore a single bit suffices to represent a single piece in a 59 | stack. By convention, we'll assign `1` to black, and `0` to white. 60 | 61 | We can then represent a single stack by defining its height, and 62 | defining its pieces as a set of bits. 63 | 64 | For 6x6 and smaller, a `uint64` suffices to represent the highest 65 | possible stack, even assuming all available pieces were to be stacked 66 | atop each other. 67 | 68 | (efficiently handling the rare but hypothetically possible overflow 69 | case on 8x8 is an open problem).
70 | 71 | We therefore define the stacks by using two parallel arrays, one 72 | holding height, and one holding the stack contents: 73 | 74 | 75 | ``` 76 | Height []uint8 77 | Stacks []uint64 78 | ``` 79 | 80 | By convention, in Taktician, `Height` stores the height *including* 81 | the top layer of the stack, but `Stacks` omits the top piece in the 82 | stack. Thus, the low `Height[i]-1` bits of `Stacks[i]` are 83 | significant. Taktician uses the least-significant bit (lsb) to represent the top of the stack. 84 | 85 | Stacking and unstacking are implemented fairly simply with bit shifts 86 | and masks. Managing the interplay between the board top and the stacks 87 | requires some finesse, but was deemed worth it in light of the 88 | significance of the board top. Further experimentation and 89 | benchmarking may be in order, however. 90 | --------------------------------------------------------------------------------