├── .gitignore
├── LICENSE
├── Othello_GPT_Circuits.ipynb
├── README.md
├── data
│   ├── __init__.py
│   └── othello.py
├── environment.yml
├── intervening_probe_interact_column.ipynb
├── intervention_benchmark.pkl
├── mechanistic_interpretability
│   ├── board_seqs_int_small.npy
│   ├── board_seqs_string_small.npy
│   ├── main_linear_probe.pth
│   ├── mech_interp_othello_utils.py
│   ├── tl_exploration.py
│   ├── tl_initial_exploration.py
│   └── tl_probing_v1.py
├── mingpt
│   ├── __init__.py
│   ├── dataset.py
│   ├── model.py
│   ├── probe_model.py
│   ├── probe_trainer.py
│   ├── trainer.py
│   └── utils.py
├── plot_attribution_via_intervention_othello.ipynb
├── produce_probes.sh
├── togglable
│   ├── linear_champ.html
│   ├── linear_random.html
│   ├── linear_sync.html
│   ├── nonlinear_champ.html
│   ├── nonlinear_random.html
│   └── nonlinear_sync.html
├── train_gpt_othello.ipynb
└── train_probe_othello.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# customized for files too large to get onto github
data/othello_synthetic
data/othello_championship
ckpts/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Kenneth Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### Update 02/13/2023 :fire::fire::fire:

Neel Nanda just released a [TransformerLens](https://github.com/neelnanda-io/TransformerLens) version of Othello-GPT ([Colab](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb), [Repo Notebook](https://github.com/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb)), boosting mechanistic interpretability research on it. Based on his work, a tool was built to inspect each MLP neuron in Othello-GPT; see, e.g., the differing activations of [neuron 255 in layer 3](https://kran.ai/othelloscope/L2/N255) and [neuron 250 in layer 8](https://kran.ai/othelloscope/L7/N250).

# Othello World

This repository provides the code for training, probing, and intervening on Othello-GPT from [Emergent World Representations: Exploring a Sequence Model Trained on a Synthetic Task](https://arxiv.org/abs/2210.13382), to be presented at ICLR 2023.
The implementation is based on [minGPT](https://github.com/karpathy/minGPT); thanks to Andrej Karpathy.

## Abstract

> Language models show a surprising range of capabilities, but the source of their apparent competence is unclear. Do these networks just memorize a collection of surface statistics, or do they rely on internal representations of the process that generates the sequences they see? We investigate this question by applying a variant of the GPT model to the task of predicting legal moves in a simple board game, Othello. Although the network has no a priori knowledge of the game or its rules, we uncover evidence of an emergent nonlinear internal representation of the board state. Interventional experiments indicate this representation can be used to control the output of the network and create "latent saliency maps" that can help explain predictions in human terms.

## Table of Contents

1. [Installation](#installation)
2. [Training Othello-GPT](#training-othello-gpt)
3. [Probing Othello-GPT](#probing-othello-gpt)
4. [Intervening Othello-GPT](#intervening-othello-gpt)
5. [Attribution via Intervention Plots](#attribution-via-intervention-plots)
6. [How to Cite](#how-to-cite)

## Installation

Some plotting functions require LaTeX on your machine: check [this FAQ](https://github.com/garrettj403/SciencePlots/wiki/FAQ#installing-latex) for how to install it.
Then use these commands to set up:
```
conda env create -f environment.yml
conda activate othello
python -m ipykernel install --user --name othello --display-name "othello"
mkdir -p ckpts/battery_othello
```

## Training Othello-GPT

Download the [championship dataset](https://drive.google.com/drive/folders/1KFtP7gfrjmaoCV-WFC4XrdVeOxy1KmXe?usp=sharing) and the [synthetic dataset](https://drive.google.com/drive/folders/1pDMdMrnxMRiDnUd-CNfRNvZCi7VXFRtv?usp=sharing) and save them in the `data` subfolder.
Then see `train_gpt_othello.ipynb` for training and validation. Alternatively, checkpoints can be downloaded from [here](https://drive.google.com/drive/folders/1bpnwJnccpr9W-N_hzXSm59hT7Lij4HxZ?usp=sharing) to skip this step.
The default experiment setting requires $8$ GPUs and takes up roughly $12$ gigabytes of memory on each. Once the code is set up, you can use `jupyter nbconvert --execute --to notebook --allow-errors --ExecutePreprocessor.timeout=-1 train_gpt_othello.ipynb --inplace --output ckpts/checkpoint.ipynb` to run it in the background.

## Probing Othello-GPT

Next, we use `train_probe_othello.py` to train probes.
For example, to train a nonlinear probe with hidden size $64$ on internal representations extracted from layer $6$ of the Othello-GPT trained on the championship dataset, use the command `python train_probe_othello.py --layer 6 --twolayer --mid_dim 64 --championship`.
Checkpoints will be saved to `ckpts/battery_othello`, or can alternatively be downloaded from [here](https://drive.google.com/drive/folders/1uvj_M9ekHDJVdVOvMq828Z23AE7jZ01H?usp=sharing). These checkpoints are produced by `produce_probes.sh`.

## Intervening Othello-GPT

See `intervening_probe_interact_column.ipynb` for the intervention experiment, where we can customize (1) which model to intervene on, (2) the pre-intervention board state, and (3) which square(s) to intervene on.
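
A minimal sketch of the two-stage forward pass these interventions build on (using `GPTforIntervention` from `mingpt/model.py`) looks like this; the checkpoint path is a placeholder, and the random edit stands in for the probe-guided activation updates performed in the actual notebook:
```
import torch
from mingpt.model import GPTConfig, GPTforIntervention

# Othello-GPT: 8 layers, 8 heads, 512-dim residual stream, 61-word vocabulary
mconf = GPTConfig(vocab_size=61, block_size=59, n_layer=8, n_head=8, n_embd=512)
model = GPTforIntervention(mconf, probe_layer=4)  # split the network after layer 4
# model.load_state_dict(torch.load("ckpts/gpt_synthetic.ckpt"))  # placeholder path
model.eval()

x = torch.randint(1, 61, (1, 20))  # a (batch, time) tensor of move indices
with torch.no_grad():
    h = model.forward_1st_stage(x)          # [1, 20, 512] activations after layer 4
    h[0, -1] += 0.1 * torch.randn(512)      # edit the residual stream at the last move
    logits, _ = model.forward_2nd_stage(h)  # resume through layers 5-8 and the head
print(logits[0, -1].softmax(-1))  # post-intervention distribution over the 61 words
```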

## Attribution via Intervention Plots

See `plot_attribution_via_intervention_othello.ipynb` for the attribution via intervention experiment, where we can also customize (1) which model to intervene on, (2) the pre-intervention board state, and (3) which square(s) to attribute.

## How to Cite
```
@inproceedings{
li2023emergent,
title={Emergent World Representations: Exploring a Sequence Model Trained on a Synthetic Task},
author={Kenneth Li and Aspen K Hopkins and David Bau and Fernanda Vi{\'e}gas and Hanspeter Pfister and Martin Wattenberg},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=DeG07_TcZvT}
}
```
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from .othello import get as get_othello
import seaborn as sns
import numpy as np
import torch

vv = .2

def plot_probs(ax, probs, valids):
    assert probs.numel() == 64
    probs = probs.detach().cpu().numpy().reshape(8, 8)
    annot = [f"{_:.2f}" for _ in probs.flatten().tolist()]
    for valid_index in valids:
        annot[valid_index] = ("\\underline{" + annot[valid_index] + "}")
    # print(annot)
    sns.heatmap(probs, ax=ax, vmin=0, vmax=vv,
        yticklabels=list("ABCDEFGH"), xticklabels=list(range(1,9)), square=True,
        annot=np.array(annot).reshape(8, 8), cmap=sns.color_palette("Blues", as_cmap=True), fmt="", cbar=False)
    return ax

def plot_mentals(ax, logits):
    assert logits.shape[0] == 64
    assert logits.shape[1] == 3
    probs = torch.softmax(logits, dim=-1)  # [64, 3]
    probs, preds = torch.max(probs, dim=-1)  # [64, ], [64, ]
    probs = probs.detach().cpu().numpy().reshape(8, 8)
    preds = preds.detach().cpu().numpy().reshape(8, 8)
    annot = []
    for ele in preds.flatten().tolist():
        if ele == 0:
            annot.append("O")
        elif ele == 1:
            annot.append(" ")
        else:
            annot.append("X")
    sns.heatmap(probs, ax=ax, vmin=0, vmax=1.,
        yticklabels=list("ABCDEFGH"), xticklabels=list(range(1,9)), square=True,
        annot=np.array(annot).reshape(8, 8), cmap=sns.color_palette("Blues", as_cmap=True), fmt="", cbar=False)
    return ax
--------------------------------------------------------------------------------
/data/othello.py:
--------------------------------------------------------------------------------
import os
import pgn
import numpy as np
import random
from tqdm import tqdm
import time
import multiprocessing
import pickle
import psutil
import seaborn as sns
import itertools
from copy import copy, deepcopy
from matplotlib.patches import Rectangle, Circle
from matplotlib.collections import PatchCollection
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap

rows = list("abcdefgh")
columns = [str(_) for _ in range(1, 9)]

mask = np.zeros(64).reshape(8, 8)
mask[3, 3] = 1
mask[3, 4] = 1
mask[4, 3] = 1
mask[4, 4] = 1
mask = mask.astype(bool)

class color:
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

# Othello is a strategy board game for two players (Black and White), played on an 8 by 8 board.
# The game traditionally begins with four discs placed in the middle of the board as shown below. Black moves first.
# W (27) B (28)
# B (35) W (36)

def permit(s):
    s = s.lower()
    if len(s) != 2:
        return -1
    if s[0] not in rows or s[1] not in columns:
        return -1
    return rows.index(s[0]) * 8 + columns.index(s[1])

def permit_reverse(integer):
    r, c = integer // 8, integer % 8
    return "".join([rows[r], columns[c]])

start_hands = [permit(_) for _ in ["d5", "d4", "e4", "e5"]]
eights = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]

wanna_use = "othello_synthetic"

class Othello:
    def __init__(self, ood_perc=0., data_root=None, wthor=False, ood_num=1000):
        # ood_perc: probability of swapping an in-distribution game (real championship game)
        #     with a generated legit but stupid game; should be set to 0 when data_root is None
        # data_root: if provided, will load pgn files there, else load from data/gen10e5
        # ood_num: how many simulated games to use; if -1, load 200 * 1e5 games = 20 million
        self.ood_perc = ood_perc
        self.sequences = []
        self.results = []
        self.board_size = 8 * 8
        criteria = lambda fn: fn.endswith("pgn") if wthor else fn.startswith("liveothello")
        if data_root is None:
            if ood_num == 0:
                return
            else:
                if ood_num != -1:  # this setting is used for generating the synthetic dataset
                    num_proc = multiprocessing.cpu_count()  # use all processors
                    p = multiprocessing.Pool(num_proc)
                    for can in tqdm(p.imap(get_ood_game, range(ood_num)), total=ood_num):
                        if can not in self.sequences:
                            self.sequences.append(can)
                    p.close()
                    t_start = time.strftime("_%Y%m%d_%H%M%S")
                    if ood_num > 1000:
                        with open(f'./data/{wanna_use}/gen10e5_{t_start}.pickle', 'wb') as handle:
                            pickle.dump(self.sequences, handle, protocol=pickle.HIGHEST_PROTOCOL)
                else:
                    bar = tqdm(os.listdir(f"./data/{wanna_use}"))
                    trash = []
                    cnt = 0
                    for f in bar:
                        if not f.endswith(".pickle"):
                            continue
                        with open(os.path.join(f"./data/{wanna_use}", f), 'rb') as handle:
                            cnt += 1
                            if cnt > 250:
                                break
                            b = pickle.load(handle)
                            if len(b) < 9e4:  # should be 1e5 each
                                trash.append(f)
                                continue
                            self.sequences.extend(b)
                        process = psutil.Process(os.getpid())
                        mem_gb = process.memory_info().rss / 2 ** 30
                        bar.set_description(f"Mem Used: {mem_gb:.4} GB")
                    print("Deduplicating...")
                    seq = self.sequences
                    seq.sort()
                    self.sequences = [k for k, _ in itertools.groupby(seq)]
                    for t in trash:
                        os.remove(os.path.join(f"./data/{wanna_use}", t))
                    print(f"Deduplicating finished with {len(self.sequences)} games left")
                    self.val = self.sequences[20000000:]
                    self.sequences = self.sequences[:20000000]
                    print(f"Using 20 million for training, {len(self.val)} for validation")
        else:
            for fn in os.listdir(data_root):
                if criteria(fn):
                    with open(os.path.join(data_root, fn), "r") as f:
                        pgn_text = f.read()
                    games = pgn.loads(pgn_text)
                    num_ldd = len(games)
                    processed = []
                    res = []
                    for game in games:
                        tba = []
                        for move in game.moves:
                            x = permit(move)
                            if x != -1:
                                tba.append(x)
                            else:
                                break
                        if len(tba) != 0:
                            try:
                                rr = [int(s) for s in game.result.split("-")]
                            except:
                                # print(game.result)
                                # break
                                rr = [0, 0]
                            res.append(rr)
                            processed.append(tba)

                    num_psd = len(processed)
                    print(f"Loaded {num_psd}/{num_ldd} (qualified/total) sequences from {fn}")
                    self.sequences.extend(processed)
                    self.results.extend(res)

    def __len__(self, ):
        return len(self.sequences)

    def __getitem__(self, i):
        if random.random() < self.ood_perc:
            tbr = get_ood_game(0)
        else:
            tbr = self.sequences[i]
        return tbr

def get_ood_game(_):
    tbr = []
    ab = OthelloBoardState()
    possible_next_steps = ab.get_valid_moves()
    while possible_next_steps:
        next_step = random.choice(possible_next_steps)
        tbr.append(next_step)
        ab.update([next_step, ])
        possible_next_steps = ab.get_valid_moves()
    return tbr

def get(ood_perc=0., data_root=None, wthor=False, ood_num=1000):
    return Othello(ood_perc, data_root, wthor, ood_num)

class OthelloBoardState():
    # 1 is black, -1 is white
    def __init__(self, board_size = 8):
        self.board_size = board_size * board_size
        board = np.zeros((8, 8))
        board[3, 4] = 1
        board[3, 3] = -1
        board[4, 3] = 1
        board[4, 4] = -1
        self.initial_state = board
        self.state = self.initial_state
        self.age = np.zeros((8, 8))
        self.next_hand_color = 1
        self.history = []

    def get_occupied(self, ):
        board = self.state
        tbr = board.flatten() != 0
        return tbr.tolist()

    def get_state(self, ):
        board = self.state + 1  # white 0, blank 1, black 2
        tbr = board.flatten()
        return tbr.tolist()

    def get_age(self, ):
        return self.age.flatten().tolist()

    def get_next_hand_color(self, ):
        return (self.next_hand_color + 1) // 2

    def update(self, moves, prt=False):
        # takes a new move or a list of moves and updates the state
        if prt:
            self.__print__()
        for _, move in enumerate(moves):
            self.umpire(move)
        if prt:
            self.__print__()

    def umpire(self, move):
        r, c = move // 8, move % 8
        assert self.state[r, c] == 0, f"{r}-{c} is already occupied!"
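        # Scan outward in each of the eight directions from (r, c): collect the
        # opponent's discs in `buffer` until we hit one of our own discs (then
        # the buffered discs are marked for flipping, via `tbf`) or a blank
        # square / the board edge (then nothing in this direction is flipped).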
        occupied = np.sum(self.state != 0)
        color = self.next_hand_color
        tbf = []
        for direction in eights:
            buffer = []
            cur_r, cur_c = r, c
            while 1:
                cur_r, cur_c = cur_r + direction[0], cur_c + direction[1]
                if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7:
                    break
                if self.state[cur_r, cur_c] == 0:
                    break
                elif self.state[cur_r, cur_c] == color:
                    tbf.extend(buffer)
                    break
                else:
                    buffer.append([cur_r, cur_c])
        if len(tbf) == 0:  # means one hand is forfeited
            # print(f"One {color} move forfeited")
            color *= -1
            self.next_hand_color *= -1
            for direction in eights:
                buffer = []
                cur_r, cur_c = r, c
                while 1:
                    cur_r, cur_c = cur_r + direction[0], cur_c + direction[1]
                    if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7:
                        break
                    if self.state[cur_r, cur_c] == 0:
                        break
                    elif self.state[cur_r, cur_c] == color:
                        tbf.extend(buffer)
                        break
                    else:
                        buffer.append([cur_r, cur_c])
        if len(tbf) == 0:
            valids = self.get_valid_moves()
            if len(valids) == 0:
                assert 0, "Neither color can place a piece; the game should have ended!"
            else:
                assert 0, "Illegal move!"

        self.age += 1
        for ff in tbf:
            self.state[ff[0], ff[1]] *= -1
            self.age[ff[0], ff[1]] = 0
        self.state[r, c] = color
        self.age[r, c] = 0
        self.next_hand_color *= -1
        self.history.append(move)

    def __print__(self, ):
        print("-" * 20)
        print([permit_reverse(_) for _ in self.history])
        a = "abcdefgh"
        for k, row in enumerate(self.state.tolist()):
            tbp = []
            for ele in row:
                if ele == -1:
                    tbp.append("O")
                elif ele == 0:
                    tbp.append(" ")
                else:
                    tbp.append("X")
            # tbp.append("\n")
            print(" ".join([a[k]] + tbp))
        tbp = [str(k) for k in range(1, 9)]
        print(" ".join([" "] + tbp))
        print("-" * 20)

    def plot_hm(self, ax, heatmap, pdmove, logit=False):
        padding = np.array([0., 0.])
        trs = {-1: r'O', 0: " ", 1: r'X'}
        if len(heatmap) == 60:
            heatmap = [heatmap[:27], padding, heatmap[27:33], padding, heatmap[33:]]
            heatmap = np.concatenate(heatmap)
        assert len(heatmap) == 64
        heatmap = np.array(heatmap).reshape(8, 8)
        annot = [trs[_] for _ in self.state.flatten().tolist()]
        cloned = deepcopy(self)
        cloned.update([pdmove, ])

        next_color = 1 - cloned.get_next_hand_color()
        annot[pdmove] = ("\\underline{" + (trs[next_color * 2 - 1]) + "}")[-13:]

        color = {-1: 'white', 0: 'grey', 1: 'black'}
        ann_col = [color[_] for _ in self.state.flatten().tolist()]
        # ann_col[pdmove] = color[next_color * 2 - 1]
        text_for_next_color = color[next_color * 2 - 1].capitalize()

        del cloned
        if logit:
            max_logit = np.max(np.abs(heatmap))
            sns.heatmap(data=heatmap, cbar=False, xticklabels=list(range(1, 9)),
                # cmap=LinearSegmentedColormap.from_list("custom_cmap", ["#D3D3D3", "#3349F2"]),
                cmap=sns.color_palette("vlag", as_cmap=True),
                yticklabels=list("ABCDEFGH"), ax=ax, fmt="", square=True, linewidths=.5, vmin=-max_logit, vmax=max_logit, center=0)
        else:
            sns.heatmap(data=heatmap, cbar=False, xticklabels=list(range(1, 9)),
                # cmap=LinearSegmentedColormap.from_list("custom_cmap", ["#D3D3D3", "#B90E0A"]),
                cmap=sns.color_palette("vlag", as_cmap=True),
                yticklabels=list("ABCDEFGH"), ax=ax, fmt="",
                square=True, linewidths=.5, vmin=-1, vmax=1, center=0)
        ax.set_title(f"Prediction: {text_for_next_color} at " + permit_reverse(pdmove).upper())
        ax.add_patch(Rectangle((pdmove % 8, pdmove // 8), 1, 1, fill=False, edgecolor='black', lw=2))

        patchList = []
        for loca, col in enumerate(ann_col):
            if col != 'grey':
                patchList.append(PatchCollection([mpatches.Circle((loca % 8 + 0.5, loca // 8 + 0.5), .25, facecolor=col)], match_original=True))
        for i in patchList:
            ax.add_collection(i)
        return ax

    def tentative_move(self, move):
        # tentatively put a piece; does nothing to the state
        # returns 0 if this is not a move at all: the square is occupied or both players would have to forfeit
        # returns 1 if it is a regular move
        # returns 2 if a forfeit happens but the opponent can then drop a piece at this place
        r, c = move // 8, move % 8
        if not self.state[r, c] == 0:
            return 0
        occupied = np.sum(self.state != 0)
        color = self.next_hand_color
        tbf = []
        for direction in eights:
            buffer = []
            cur_r, cur_c = r, c
            while 1:
                cur_r, cur_c = cur_r + direction[0], cur_c + direction[1]
                if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7:
                    break
                if self.state[cur_r, cur_c] == 0:
                    break
                elif self.state[cur_r, cur_c] == color:
                    tbf.extend(buffer)
                    break
                else:
                    buffer.append([cur_r, cur_c])
        if len(tbf) != 0:
            return 1
        else:  # means one hand is forfeited
            # print(f"One {color} move forfeited")
            color *= -1
            # self.next_hand_color *= -1
            for direction in eights:
                buffer = []
                cur_r, cur_c = r, c
                while 1:
                    cur_r, cur_c = cur_r + direction[0], cur_c + direction[1]
                    if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7:
                        break
                    if self.state[cur_r, cur_c] == 0:
                        break
                    elif self.state[cur_r, cur_c] == color:
                        tbf.extend(buffer)
                        break
                    else:
                        buffer.append([cur_r, cur_c])
            if len(tbf) == 0:
                return 0
            else:
                return 2

    def get_valid_moves(self, ):
        regular_moves = []
        forfeit_moves = []
        for move in range(64):
            x = self.tentative_move(move)
            if x == 1:
                regular_moves.append(move)
            elif x == 2:
                forfeit_moves.append(move)
            else:
                pass
        if len(regular_moves):
            return regular_moves
        elif len(forfeit_moves):
            return forfeit_moves
        else:
            return []

    def get_gt(self, moves, func, prt=False):
        # takes a new move or a list of moves, updates the state, and records the output of `func` after each move
        container = []
        if prt:
            self.__print__()
        for _, move in enumerate(moves):
            self.umpire(move)
            container.append(getattr(self, func)())
            # to predict the first y, we already need to know the first x
        if prt:
            self.__print__()
        return container

if __name__ == "__main__":
    pass
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: othello
channels:
  - pytorch-lts
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - alabaster=0.7.12=pyhd3eb1b0_0
  - anaconda-client=1.7.2=py38_0
  - anaconda-project=0.9.1=pyhd3eb1b0_1
  - anyio=2.2.0=py38h06a4308_1
  - appdirs=1.4.4=py_0
  - argh=0.26.2=py38_0
  - argon2-cffi=20.1.0=py38h27cfd23_1
  - asn1crypto=1.4.0=py_0
  - astroid=2.5=py38h06a4308_1
  - astropy=4.2.1=py38h27cfd23_1
  - async_generator=1.10=pyhd3eb1b0_0
  - atomicwrites=1.4.0=py_0
  - attrs=20.3.0=pyhd3eb1b0_0
  - autopep8=1.5.6=pyhd3eb1b0_0
  - babel=2.9.0=pyhd3eb1b0_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - backports=1.0=pyhd3eb1b0_2
  - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
  - backports.shutil_get_terminal_size=1.0.0=pyhd3eb1b0_3
  - backports.tempfile=1.0=pyhd3eb1b0_1
  - backports.weakref=1.0.post1=py_1
  - beautifulsoup4=4.9.3=pyha847dfd_0
  - bitarray=2.1.0=py38h27cfd23_1
  - bkcharts=0.2=py38_0
  - black=19.10b0=py_0
  - blas=1.0=mkl
  - bleach=3.3.0=pyhd3eb1b0_0
  - blosc=1.21.0=h8c45485_0
  - bokeh=2.3.2=py38h06a4308_0
  - boto=2.49.0=py38_0
  - bottleneck=1.3.2=py38heb32a55_1
  - brotlipy=0.7.0=py38h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2021.10.26=h06a4308_2
  - cairo=1.16.0=hf32fb01_1
  - certifi=2021.10.8=py38h06a4308_0
  - cffi=1.14.5=py38h261ae71_0
  - chardet=4.0.0=py38h06a4308_1003
  - click=7.1.2=pyhd3eb1b0_0
  - cloudpickle=1.6.0=py_0
  - clyent=1.2.2=py38_1
  - colorama=0.4.4=pyhd3eb1b0_0
  - conda-content-trust=0.1.1=pyhd3eb1b0_0
  - conda-package-handling=1.7.3=py38h27cfd23_1
  - conda-repo-cli=1.0.4=pyhd3eb1b0_0
  - conda-verify=3.4.2=py_1
  - contextlib2=0.6.0.post1=py_0
  - cryptography=3.4.7=py38hd23ed53_0
  - cudatoolkit=11.1.74=h6bb024c_0
  - curl=7.71.1=hbc83047_1
  - cycler=0.10.0=py38_0
  - cython=0.29.23=py38h2531618_0
  - cytoolz=0.11.0=py38h7b6447c_0
  - dask=2021.4.0=pyhd3eb1b0_0
  - dask-core=2021.4.0=pyhd3eb1b0_0
  - dbus=1.13.18=hb2f20db_0
  - decorator=5.0.6=pyhd3eb1b0_0
  - defusedxml=0.7.1=pyhd3eb1b0_0
  - diff-match-patch=20200713=py_0
  - distributed=2021.4.1=py38h06a4308_0
  - docutils=0.17.1=py38h06a4308_1
  - entrypoints=0.3=py38_0
  - et_xmlfile=1.0.1=py_1001
  - expat=2.3.0=h2531618_2
  - fastcache=1.1.0=py38h7b6447c_0
  - ffmpeg=4.2.2=h20bf706_0
  - filelock=3.0.12=pyhd3eb1b0_1
  - flake8=3.9.0=pyhd3eb1b0_0
  - flask=1.1.2=pyhd3eb1b0_0
  - fontconfig=2.13.1=h6c09931_0
  - freetype=2.10.4=h5ab3b9f_0
  - fribidi=1.0.10=h7b6447c_0
  - fsspec=0.9.0=pyhd3eb1b0_0
  - future=0.18.2=py38_1
  - get_terminal_size=1.0.0=haa9412d_0
  - gevent=21.1.2=py38h27cfd23_1
  - glib=2.68.1=h36276a3_0
  - glob2=0.7=pyhd3eb1b0_0
  - gmp=6.2.1=h2531618_2
  - gmpy2=2.0.8=py38hd5f6e3b_3
  - gnutls=3.6.15=he1e5248_0
  - graphite2=1.3.14=h23475e2_0
  - greenlet=1.0.0=py38h2531618_2
  - gst-plugins-base=1.14.0=h8213a91_2
  - gstreamer=1.14.0=h28cd5cc_2
  - h5py=2.10.0=py38h7918eee_0
  - harfbuzz=2.8.0=h6f93f22_0
  - hdf5=1.10.4=hb1b8bf9_0
  - heapdict=1.0.1=py_0
  - html5lib=1.1=py_0
  - icu=58.2=he6710b0_3
  - idna=2.10=pyhd3eb1b0_0
  - imageio=2.9.0=pyhd3eb1b0_0
  - imagesize=1.2.0=pyhd3eb1b0_0
  - importlib-metadata=3.10.0=py38h06a4308_0
  - importlib_metadata=3.10.0=hd3eb1b0_0
  - iniconfig=1.1.1=pyhd3eb1b0_0
  - intel-openmp=2021.2.0=h06a4308_610
  - intervaltree=3.1.0=py_0
  - ipykernel=5.3.4=py38h5ca1d4c_0
  - ipython=7.22.0=py38hb070fc8_0
  - ipython_genutils=0.2.0=pyhd3eb1b0_1
  - ipywidgets=7.6.3=pyhd3eb1b0_1
  - isort=5.8.0=pyhd3eb1b0_0
  - itsdangerous=1.1.0=pyhd3eb1b0_0
  - jbig=2.1=hdba287a_0
  - jdcal=1.4.1=py_0
  - jedi=0.17.2=py38h06a4308_1
  - jeepney=0.6.0=pyhd3eb1b0_0
  - jinja2=2.11.3=pyhd3eb1b0_0
  - joblib=1.0.1=pyhd3eb1b0_0
  - jpeg=9b=h024ee3a_2
  - json5=0.9.5=py_0
  - jsonschema=3.2.0=py_2
  - jupyter=1.0.0=py38_7
  - jupyter-packaging=0.7.12=pyhd3eb1b0_0
  - jupyter_client=6.1.12=pyhd3eb1b0_0
  - jupyter_console=6.4.0=pyhd3eb1b0_0
  - jupyter_core=4.7.1=py38h06a4308_0
  - jupyter_server=1.4.1=py38h06a4308_0
  - jupyterlab=3.2.1=pyhd3eb1b0_1
  - jupyterlab_pygments=0.1.2=py_0
  - jupyterlab_server=2.4.0=pyhd3eb1b0_0
  - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
  - keyring=22.3.0=py38h06a4308_0
  - kiwisolver=1.3.1=py38h2531618_0
  - krb5=1.18.2=h173b8e3_0
  - lame=3.100=h7b6447c_0
  - lazy-object-proxy=1.6.0=py38h27cfd23_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libarchive=3.4.2=h62408e4_0
  - libcurl=7.71.1=h20c2e04_1
  - libedit=3.1.20210216=h27cfd23_1
  - libev=4.33=h7b6447c_0
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libidn2=2.3.2=h7f8727e_0
  - liblief=0.10.1=he6710b0_0
  - libllvm10=10.0.1=hbcb73fb_5
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.18=h7b6447c_0
  - libspatialindex=1.9.3=h2531618_0
  - libssh2=1.9.0=h1ba5d50_1
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.2.0=h85742a9_0
  - libtool=2.4.6=h7b6447c_1005
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.0.3=h1bed415_2
  - libuv=1.40.0=h7b6447c_0
  - libvpx=1.7.0=h439df22_0
  - libwebp-base=1.2.0=h27cfd23_0
  - libxcb=1.14=h7b6447c_0
  - libxml2=2.9.10=hb55368b_3
  - libxslt=1.1.34=hc22bd24_0
  - llvmlite=0.36.0=py38h612dafd_4
  - locket=0.2.1=py38h06a4308_1
  - lxml=4.6.3=py38h9120a33_0
  - lz4-c=1.9.3=h2531618_0
  - lzo=2.10=h7b6447c_2
  - markupsafe=1.1.1=py38h7b6447c_0
  - matplotlib=3.3.4=py38h06a4308_0
  - matplotlib-base=3.3.4=py38h62a2d02_0
  - mccabe=0.6.1=py38_1
  - mistune=0.8.4=py38h7b6447c_1000
  - mkl=2021.2.0=h06a4308_296
  - mkl-service=2.3.0=py38h27cfd23_1
  - mkl_fft=1.3.0=py38h42c9631_2
  - mkl_random=1.2.1=py38ha9443f7_2
  - mock=4.0.3=pyhd3eb1b0_0
  - more-itertools=8.7.0=pyhd3eb1b0_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.2.1=py38h06a4308_0
  - msgpack-python=1.0.2=py38hff7bd54_1
  - multipledispatch=0.6.0=py38_0
  - mypy_extensions=0.4.3=py38_0
  - navigator-updater=0.2.1=py38_0
  - nbclassic=0.2.6=pyhd3eb1b0_0
  - nbclient=0.5.3=pyhd3eb1b0_0
  - nbconvert=6.0.7=py38_0
  - nbformat=5.1.3=pyhd3eb1b0_0
  - ncurses=6.2=he6710b0_1
  - nest-asyncio=1.5.1=pyhd3eb1b0_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=2.5=py_0
  - ninja=1.10.2=hff7bd54_1
  - nltk=3.6.1=pyhd3eb1b0_0
  - nose=1.3.7=pyhd3eb1b0_1006
  - notebook=6.3.0=py38h06a4308_0
  - numba=0.53.1=py38ha9443f7_0
  - numexpr=2.7.3=py38h22e1b3c_1
  - numpy=1.20.1=py38h93e21f0_0
  - numpy-base=1.20.1=py38h7d8b39e_0
  - numpydoc=1.1.0=pyhd3eb1b0_1
  - olefile=0.46=py_0
  - openh264=2.1.0=hd408876_0
  - openpyxl=3.0.7=pyhd3eb1b0_0
  - openssl=1.1.1l=h7f8727e_0
  - packaging=20.9=pyhd3eb1b0_0
  - pandas=1.2.4=py38h2531618_0
  - pandoc=2.12=h06a4308_0
  - pandocfilters=1.4.3=py38h06a4308_1
  - pango=1.45.3=hd140c19_0
  - parso=0.7.0=py_0
  - partd=1.2.0=pyhd3eb1b0_0
  - patchelf=0.12=h2531618_1
  - path=15.1.2=py38h06a4308_0
  - path.py=12.5.0=0
  - pathlib2=2.3.5=py38h06a4308_2
  - pathspec=0.7.0=py_0
  - patsy=0.5.1=py38_0
  - pcre=8.44=he6710b0_0
  - pep8=1.7.1=py38_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=8.2.0=py38he98fc37_0
  - pip=21.0.1=py38h06a4308_0
  - pixman=0.40.0=h7b6447c_0
  - pkginfo=1.7.0=py38h06a4308_0
  - pluggy=0.13.1=py38h06a4308_0
  - ply=3.11=py38_0
  - prometheus_client=0.10.1=pyhd3eb1b0_0
  - prompt-toolkit=3.0.17=pyh06a4308_0
  - prompt_toolkit=3.0.17=hd3eb1b0_0
  - psutil=5.8.0=py38h27cfd23_1
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - py=1.10.0=pyhd3eb1b0_0
  - py-lief=0.10.1=py38h403a769_0
  - pycodestyle=2.6.0=pyhd3eb1b0_0
  - pycosat=0.6.3=py38h7b6447c_1
  - pycparser=2.20=py_2
  - pycurl=7.43.0.6=py38h1ba5d50_0
  - pydocstyle=6.0.0=pyhd3eb1b0_0
  - pyerfa=1.7.3=py38h27cfd23_0
  - pyflakes=2.2.0=pyhd3eb1b0_0
  - pygments=2.8.1=pyhd3eb1b0_0
  - pylint=2.7.4=py38h06a4308_1
  - pyls-black=0.4.6=hd3eb1b0_0
  - pyls-spyder=0.3.2=pyhd3eb1b0_0
  - pyodbc=4.0.30=py38he6710b0_0
  - pyopenssl=20.0.1=pyhd3eb1b0_1
  - pyparsing=2.4.7=pyhd3eb1b0_0
  - pyqt=5.9.2=py38h05f1152_4
  - pyrsistent=0.17.3=py38h7b6447c_0
  - pysocks=1.7.1=py38h06a4308_0
  - pytables=3.6.1=py38h9fd0a39_0
  - pytest=6.2.3=py38h06a4308_2
  - python=3.8.8=hdb3f193_5
  - python-dateutil=2.8.1=pyhd3eb1b0_0
  - python-jsonrpc-server=0.4.0=py_0
  - python-language-server=0.36.2=pyhd3eb1b0_0
  - python-libarchive-c=2.9=pyhd3eb1b0_1
  - pytorch=1.8.2=py3.8_cuda11.1_cudnn8.0.5_0
  - pytz=2021.1=pyhd3eb1b0_0
  - pywavelets=1.1.1=py38h7b6447c_2
  - pyxdg=0.27=pyhd3eb1b0_0
  - pyyaml=5.4.1=py38h27cfd23_1
  - pyzmq=20.0.0=py38h2531618_1
  - qdarkstyle=2.8.1=py_0
  - qt=5.9.7=h5867ecd_1
  - qtawesome=1.0.2=pyhd3eb1b0_0
  - qtconsole=5.0.3=pyhd3eb1b0_0
  - qtpy=1.9.0=py_0
  - readline=8.1=h27cfd23_0
  - regex=2021.4.4=py38h27cfd23_0
  - requests=2.25.1=pyhd3eb1b0_0
  - ripgrep=12.1.1=0
  - rope=0.18.0=py_0
  - rtree=0.9.7=py38h06a4308_1
  - ruamel_yaml=0.15.100=py38h27cfd23_0
  - scikit-image=0.18.1=py38ha9443f7_0
  - scikit-learn=0.24.1=py38ha9443f7_0
  - scipy=1.6.2=py38had2a1c9_1
  - seaborn=0.11.1=pyhd3eb1b0_0
  - secretstorage=3.3.1=py38h06a4308_0
  - send2trash=1.5.0=pyhd3eb1b0_1
  - setuptools=52.0.0=py38h06a4308_0
  - simplegeneric=0.8.1=py38_2
  - singledispatch=3.6.1=pyhd3eb1b0_1001
  - sip=4.19.13=py38he6710b0_0
  - six=1.15.0=py38h06a4308_0
  - sniffio=1.2.0=py38h06a4308_1
  - snowballstemmer=2.1.0=pyhd3eb1b0_0
  - sortedcollections=2.1.0=pyhd3eb1b0_0
  - sortedcontainers=2.3.0=pyhd3eb1b0_0
  - soupsieve=2.2.1=pyhd3eb1b0_0
  - sphinx=4.0.1=pyhd3eb1b0_0
  - sphinxcontrib=1.0=py38_1
  - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
  - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
  - sphinxcontrib-htmlhelp=1.0.3=pyhd3eb1b0_0
  - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
  - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
  - sphinxcontrib-serializinghtml=1.1.4=pyhd3eb1b0_0
  - sphinxcontrib-websupport=1.2.4=py_0
  - spyder=4.2.5=py38h06a4308_0
  - spyder-kernels=1.10.2=py38h06a4308_0
  - sqlalchemy=1.4.15=py38h27cfd23_0
  - sqlite=3.35.4=hdfb4753_0
  - statsmodels=0.12.2=py38h27cfd23_0
  - sympy=1.8=py38h06a4308_0
  - tbb=2020.3=hfd86e86_0
  - tblib=1.7.0=py_0
  - terminado=0.9.4=py38h06a4308_0
  - testpath=0.4.4=pyhd3eb1b0_0
  - textdistance=4.2.1=pyhd3eb1b0_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - three-merge=0.1.1=pyhd3eb1b0_0
  - tifffile=2020.10.1=py38hdd07704_2
  - tk=8.6.10=hbc83047_0
  - toml=0.10.2=pyhd3eb1b0_0
  - toolz=0.11.1=pyhd3eb1b0_0
  - torchaudio=0.8.2=py38
  - torchvision=0.9.2=py38_cu111
  - tornado=6.1=py38h27cfd23_0
  - tqdm=4.59.0=pyhd3eb1b0_1
  - traitlets=5.0.5=pyhd3eb1b0_0
  - typed-ast=1.4.2=py38h27cfd23_1
  - typing_extensions=3.7.4.3=pyha847dfd_0
  - ujson=4.0.2=py38h2531618_0
  - unicodecsv=0.14.1=py38_0
  - unixodbc=2.3.9=h7b6447c_0
  - urllib3=1.26.4=pyhd3eb1b0_0
  - watchdog=1.0.2=py38h06a4308_1
  - wcwidth=0.2.5=py_0
  - webencodings=0.5.1=py38_1
  - werkzeug=1.0.1=pyhd3eb1b0_0
  - wheel=0.36.2=pyhd3eb1b0_0
  - widgetsnbextension=3.5.1=py38_0
  - wrapt=1.12.1=py38h7b6447c_1
  - wurlitzer=2.1.0=py38h06a4308_0
  - x264=1!157.20191217=h7b6447c_0
  - xlrd=2.0.1=pyhd3eb1b0_0
  - xlsxwriter=1.3.8=pyhd3eb1b0_0
  - xlwt=1.3.0=py38_0
  - xmltodict=0.12.0=py_0
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - yapf=0.31.0=pyhd3eb1b0_0
  - zeromq=4.3.4=h2531618_0
  - zict=2.0.0=pyhd3eb1b0_0
  - zipp=3.4.1=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zope=1.0=py38_1
  - zope.event=4.5.0=py38_0
  - zope.interface=5.3.0=py38h27cfd23_0
  - zstd=1.4.5=h9ceee32_0
  - pip:
    - pgnparser==1.0
--------------------------------------------------------------------------------
/intervention_benchmark.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/likenneth/othello_world/f23bb5696cf30b93bd8af8a391ee33fc3aac417e/intervention_benchmark.pkl
--------------------------------------------------------------------------------
/mechanistic_interpretability/board_seqs_int_small.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/likenneth/othello_world/f23bb5696cf30b93bd8af8a391ee33fc3aac417e/mechanistic_interpretability/board_seqs_int_small.npy
--------------------------------------------------------------------------------
/mechanistic_interpretability/board_seqs_string_small.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/likenneth/othello_world/f23bb5696cf30b93bd8af8a391ee33fc3aac417e/mechanistic_interpretability/board_seqs_string_small.npy
--------------------------------------------------------------------------------
/mechanistic_interpretability/main_linear_probe.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/likenneth/othello_world/f23bb5696cf30b93bd8af8a391ee33fc3aac417e/mechanistic_interpretability/main_linear_probe.pth
--------------------------------------------------------------------------------
/mechanistic_interpretability/tl_probing_v1.py:
--------------------------------------------------------------------------------
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig
from mech_interp_othello_utils import OthelloBoardState
import einops
import torch
from tqdm import tqdm
import numpy as np
from fancy_einsum import einsum

cfg = HookedTransformerConfig(
    n_layers=8,
    d_model=512,
    d_head=64,
    n_heads=8,
    d_mlp=2048,
    d_vocab=61,
    n_ctx=59,
    act_fn="gelu",
    normalization_type="LNPre",
)
model = HookedTransformer(cfg)


sd = utils.download_file_from_hf(
    "NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth"
)
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

# %%
board_seqs_int = torch.tensor(np.load("board_seqs_int_small.npy")).long()
board_seqs_string = torch.tensor(np.load("board_seqs_string_small.npy"))
# %%
def seq_to_state_stack(str_moves):
    if isinstance(str_moves, torch.Tensor):
        str_moves = str_moves.tolist()
    board = OthelloBoardState()
    states = []
    for move in str_moves:
        board.umpire(move)
        states.append(np.copy(board.state))
    states = np.stack(states, axis=0)
    return states


state_stack = torch.tensor(
    np.stack([seq_to_state_stack(seq) for seq in board_seqs_string[:50, :-1]])
)
print(state_stack.shape)
# %%

# %%
layer = 6
batch_size = 100
lr = 1e-4
wd = 0.01
pos_start = 5
pos_end = model.cfg.n_ctx - 5
length = pos_end - pos_start
options = 3
rows = 8
cols = 8
num_epochs = 2
num_games = 100000
x = 0
y = 2
probe_name = "main_linear_probe"
# The first mode is blank or not, the second mode is next or prev GIVEN that it is not blank
modes = 3
alternating = torch.tensor([1 if i % 2 == 0 else -1 for i in range(length)], device="cuda")


def state_stack_to_one_hot(state_stack):
    one_hot = torch.zeros(
        modes,  # blank vs color (mode)
        state_stack.shape[0],  # num games
        state_stack.shape[1],  # num moves
        rows,  # rows
        cols,  # cols
        options,  # the three options
        device=state_stack.device,
        dtype=torch.int,
    )
    one_hot[:, ..., 0] = state_stack == 0
    one_hot[:, ..., 1] = state_stack == -1
    one_hot[:, ..., 2] = state_stack == 1
    return one_hot


state_stack_one_hot = state_stack_to_one_hot(state_stack)
print(state_stack_one_hot.shape)
print((state_stack_one_hot[:, 0, 17, 4:9, 2:5]))
print((state_stack[0, 17, 4:9, 2:5]))
# %%
linear_probe = torch.randn(
    modes, model.cfg.d_model, rows, cols, options, requires_grad=False, device="cuda"
) / np.sqrt(model.cfg.d_model)
linear_probe.requires_grad = True
optimiser = torch.optim.AdamW([linear_probe], lr=lr, betas=(0.9, 0.99), weight_decay=wd)


for epoch in range(num_epochs):
    full_train_indices = torch.randperm(num_games)
    for i in tqdm(range(0, num_games, batch_size)):
        indices = full_train_indices[i:i+batch_size]
        games_int = board_seqs_int[indices]
        games_str = board_seqs_string[indices]
        state_stack = torch.stack(
            [torch.tensor(seq_to_state_stack(games_str[i])) for i in range(batch_size)]
        )
        state_stack = state_stack[:, pos_start:pos_end, :, :]

        state_stack_one_hot = state_stack_to_one_hot(state_stack).cuda()
        with torch.inference_mode():
            _, cache = model.run_with_cache(games_int.cuda()[:, :-1], return_type=None)
            resid_post = cache["resid_post", layer][:, pos_start:pos_end]
        probe_out = einsum(
            "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
            resid_post,
            linear_probe,
        )
        # print(probe_out.shape)

        # acc_blank = (probe_out[0].argmax(-1) == state_stack_one_hot[0].argmax(-1)).float().mean()
        # acc_color = ((probe_out[1].argmax(-1) == state_stack_one_hot[1].argmax(-1)) * state_stack_one_hot[1].sum(-1)).float().sum()/(state_stack_one_hot[1]).float().sum()

        probe_log_probs = probe_out.log_softmax(-1)
        probe_correct_log_probs = einops.reduce(
            probe_log_probs * state_stack_one_hot,
            "modes batch pos rows cols options -> modes pos rows cols",
            "mean"
        ) * options  # Multiply to correct for the mean over options
        loss_even = -probe_correct_log_probs[0, 0::2].mean(0).sum()  # note that "even" means odd in the game framing, since we offset by 5 moves lol
        loss_odd = -probe_correct_log_probs[1, 1::2].mean(0).sum()
        loss_all = -probe_correct_log_probs[2, :].mean(0).sum()

        loss = loss_even + loss_odd + loss_all
        loss.backward()  # it's important to do a single backward pass for mysterious PyTorch reasons, so we add up the losses - it's per mode and per square.

        optimiser.step()
        optimiser.zero_grad()
    torch.save(linear_probe, f"{probe_name}.pth")
# %%
# %%
--------------------------------------------------------------------------------
/mingpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/likenneth/othello_world/f23bb5696cf30b93bd8af8a391ee33fc3aac417e/mingpt/__init__.py
--------------------------------------------------------------------------------
/mingpt/dataset.py:
--------------------------------------------------------------------------------
import itertools
import torch
from torch.utils.data import Dataset

class CharDataset(Dataset):
    def __init__(self, data):
        if hasattr(data, "ood_perc"):
            ood_perc = data.ood_perc
            data.ood_perc = 0  # shut down the randomness
        chars = sorted(list(set(list(itertools.chain.from_iterable(data)))) + [-100, ])
        data_size, vocab_size = len(data), len(chars)  # vocab size 61, with -100 sorted to the front
        max_len = max([len(data[_]) for _ in range(len(data))])  # should be 60 in Othello
        print('Dataset created has %d sequences, %d unique words.' % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.max_len = max_len
        self.block_size = max_len - 1  # for autoregressive training
        self.vocab_size = vocab_size
        if hasattr(data, "ood_perc"):
            data.ood_perc = ood_perc  # turn the randomness back on
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx]
        if len(chunk) != self.max_len:
            chunk += [-100, ] * (self.max_len - len(chunk))  # -100 can be ignored in CE
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        """
        arrange data and targets so that the first i elements of x
        will be asked to predict the i-th element of y.
        Notice that the eventual language model will actually make block_size
        individual predictions at the same time based on this data,
        so we are being clever and amortizing the cost of the forward
        pass of the network. So for example if block_size is 4, then
        we could e.g. sample a chunk of text "hello", the integers in
        x will correspond to "hell" and in y will be "ello". This will
        then actually "multitask" 4 separate examples at the same time
        in the language model:
        - given just "h", please predict "e" as next
        - given "he" please predict "l" next
        - given "hel" predict "l" next
        - given "hell" predict "o" next

        In addition, because the DataLoader will create batches of examples,
        every forward/backward pass during training will simultaneously train
        a LOT of predictions, amortizing a lot of computation. In particular,
        for a batched input of integers X (B, T) where B is batch size and
        T is block_size and Y (B, T), the network will during training be
        simultaneously training to make B*T predictions, all at once! Of course,
        at test time we can parallelize across batch B, but unlike during training
        we cannot parallelize across the time dimension T - we have to run
        a forward pass of the network to recover the next single character of the
        sequence along each batch dimension, and repeatedly always feed in a next
        character to get the next one.

        So yes there is a big asymmetry between train/test time of autoregressive
        models. During training we can go B*T at a time with every forward pass,
        but during test time we can only go B at a time, T times, with T forward
        passes.
        """
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
--------------------------------------------------------------------------------
/mingpt/model.py:
--------------------------------------------------------------------------------
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F

logger = logging.getLogger(__name__)

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
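
    Shape walkthrough: for a batch of B sequences of length T with C = n_embd
    and h = n_head, q, k and v are each reshaped to (B, h, T, C//h); the
    attention matrix q @ k^T / sqrt(C//h) is (B, h, T, T), masked so position t
    only attends to positions <= t; att @ v is (B, h, T, C//h), re-assembled
    back into (B, T, C) before the output projection.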
36 | """ 37 | 38 | def __init__(self, config): 39 | super().__init__() 40 | assert config.n_embd % config.n_head == 0 41 | # key, query, value projections for all heads 42 | self.key = nn.Linear(config.n_embd, config.n_embd) 43 | self.query = nn.Linear(config.n_embd, config.n_embd) 44 | self.value = nn.Linear(config.n_embd, config.n_embd) 45 | # regularization 46 | self.attn_drop = nn.Dropout(config.attn_pdrop) 47 | self.resid_drop = nn.Dropout(config.resid_pdrop) 48 | # output projection 49 | self.proj = nn.Linear(config.n_embd, config.n_embd) 50 | # causal mask to ensure that attention is only applied to the left in the input sequence 51 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 52 | .view(1, 1, config.block_size, config.block_size)) 53 | self.n_head = config.n_head 54 | 55 | def forward(self, x, layer_past=None, only_last=-1): 56 | B, T, C = x.size() 57 | 58 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 59 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 61 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 62 | 63 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 64 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 65 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 66 | if only_last != -1: 67 | att[:, :, -only_last:, :-only_last] = float('-inf') 68 | att = F.softmax(att, dim=-1) 69 | y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 70 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 71 | 72 | # output projection 73 | y = self.resid_drop(self.proj(y)) 74 | return y, att 75 | 76 | class Block(nn.Module): 77 | """ an unassuming Transformer block """ 78 | 79 | def __init__(self, config): 80 | super().__init__() 81 | self.ln1 = nn.LayerNorm(config.n_embd) 82 | self.ln2 = nn.LayerNorm(config.n_embd) 83 | self.attn = CausalSelfAttention(config) 84 | self.mlp = nn.Sequential( 85 | nn.Linear(config.n_embd, 4 * config.n_embd), 86 | nn.GELU(), 87 | nn.Linear(4 * config.n_embd, config.n_embd), 88 | nn.Dropout(config.resid_pdrop), 89 | ) 90 | 91 | def forward(self, x, return_att=False, only_last=-1): 92 | updt, att = self.attn(self.ln1(x), only_last=only_last) 93 | x = x + updt 94 | x = x + self.mlp(self.ln2(x)) 95 | if return_att: 96 | return x, att 97 | else: 98 | return x 99 | 100 | class GPT(nn.Module): 101 | """ the full GPT language model, with a context size of block_size """ 102 | 103 | def __init__(self, config): 104 | super().__init__() 105 | 106 | # input embedding stem 107 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 108 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 109 | self.drop = nn.Dropout(config.embd_pdrop) 110 | # transformer 111 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 112 | self.n_layer = config.n_layer 113 | # decoder head 114 | self.ln_f = nn.LayerNorm(config.n_embd) 115 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 116 | 117 | self.block_size = config.block_size 118 | self.apply(self._init_weights) 119 | 120 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 121 | 122 | def get_block_size(self): 123 
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        b, t = idx.size()  # both of shape [B, T]
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
183 | 184 | # forward the GPT model 185 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 186 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 187 | x = self.drop(token_embeddings + position_embeddings) 188 | x = self.blocks(x) 189 | x = self.ln_f(x) # [B, T, f] 190 | logits = self.head(x) # [B, T, # Words] 191 | # if we are given some desired targets also calculate the loss 192 | loss = None 193 | if targets is not None: 194 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0) # -100 in the string space is mapped to 0 in the index space 195 | return logits, loss 196 | 197 | class GPTforProbing(GPT): 198 | def __init__(self, config, probe_layer=-1, ln=False): 199 | super(GPTforProbing, self).__init__(config) 200 | # we probe the activation after the self.probe_layer-th layer 201 | self.probe_layer = self.n_layer if probe_layer == -1 else probe_layer 202 | assert self.probe_layer <= self.n_layer and self.probe_layer >= 0, "Invalid layer index to probe" 203 | self.ln = ln 204 | 205 | def forward(self, idx, return_att=False): 206 | b, t = idx.size() # both of shape [B, T] 207 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 208 | 209 | # forward the GPT model 210 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 211 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 212 | x = self.drop(token_embeddings + position_embeddings) 213 | 214 | for b in self.blocks[:self.probe_layer]: 215 | if return_att: 216 | x, att = b(x, return_att=return_att) 217 | else: 218 | x = b(x, return_att=return_att) 219 | 220 | # x = self.blocks(x) 221 | if self.ln: 222 | x = self.ln_f(x) # [B, T, f] 223 | # logits = self.head(x) # [B, T, # Words] 224 | if return_att: 225 | return x, att 226 | else: 227 | return x 228 | 229 | class GPTforIntervention(GPT): 230 | def __init__(self, config, probe_layer=-1): 231 | super(GPTforIntervention, self).__init__(config) 232 | # we probe the activation after the self.probe_layer-th layer 233 | self.probe_layer = self.n_layer if probe_layer == -1 else probe_layer 234 | assert self.probe_layer <= self.n_layer and self.probe_layer >= 1, "Invalid layer index to probe" 235 | 236 | def forward_1st_stage(self, idx): 237 | b, t = idx.size() # both of shape [B, T] 238 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
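        # forward_1st_stage runs only the input stem plus the first probe_layer
        # blocks; the caller may edit the returned activation (e.g. with
        # mingpt.utils.intervene) before resuming with forward_2nd_stage below.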
239 | 240 | # forward the GPT model 241 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 242 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 243 | x = self.drop(token_embeddings + position_embeddings) 244 | 245 | for b in self.blocks[:self.probe_layer]: 246 | x = b(x) 247 | 248 | # x = self.blocks(x) 249 | # x = self.ln_f(x) # [B, T, f] 250 | # logits = self.head(x) # [B, T, # Words] 251 | return x 252 | 253 | def forward_2nd_stage(self, x, targets=None, only_last=-1): 254 | for b in self.blocks[self.probe_layer:]: 255 | x = b(x, only_last=only_last) 256 | x = self.ln_f(x) # [B, T, f] 257 | logits = self.head(x) # [B, T, # Words] 258 | # if we are given some desired targets also calculate the loss 259 | loss = None 260 | if targets is not None: 261 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) 262 | return logits, loss 263 | 264 | class GPTforProbeIA(GPT): 265 | # probe interaction 266 | # how probes between different layers interact with each other 267 | def __init__(self, config, probe_layer=-1): 268 | super(GPTforProbeIA, self).__init__(config) 269 | # we probe the activation after the self.probe_layer-th layer 270 | self.probe_layer = self.n_layer if probe_layer == -1 else probe_layer 271 | assert self.probe_layer <= self.n_layer and self.probe_layer >= 0, "Invalid layer index to probe" 272 | 273 | def forward_1st_stage(self, idx): 274 | b, t = idx.size() # both of shape [B, T] 275 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 276 | 277 | # forward the GPT model 278 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 279 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 280 | x = self.drop(token_embeddings + position_embeddings) 281 | 282 | for b in self.blocks[:self.probe_layer]: 283 | x = b(x) 284 | 285 | # x = self.blocks(x) 286 | # x = self.ln_f(x) # [B, T, f] 287 | # logits = self.head(x) # [B, T, # Words] 288 | return x 289 | 290 | def forward_2nd_stage(self, x, start_layer, end_layer=-1): 291 | tbr = [] 292 | if end_layer == -1: 293 | end_layer = self.n_layer + 1 294 | for b in self.blocks[start_layer: end_layer]: 295 | x = b(x) 296 | tbr.append(x) 297 | # x = self.ln_f(x) # [B, T, f] 298 | return tbr 299 | 300 | def predict(self, x, targets=None): 301 | x = self.ln_f(x) # [B, T, f] 302 | logits = self.head(x) # [B, T, # Words] 303 | # if we are given some desired targets also calculate the loss 304 | loss = None 305 | if targets is not None: 306 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) 307 | return logits, loss -------------------------------------------------------------------------------- /mingpt/probe_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | class BatteryProbeClassification(nn.Module): 9 | # combines 64 classification problem for the case of Othello 10 | def __init__(self, device, probe_class, num_task, input_dim=512): # from 0 to 15 11 | super().__init__() 12 | self.input_dim = input_dim 13 | self.probe_class = probe_class 14 | self.num_task = num_task 15 | self.proj = nn.Linear(self.input_dim, self.probe_class * self.num_task, bias=True) 16 | self.apply(self._init_weights) 17 | self.to(device) 18 | def 
forward(self, act, y=None): 19 | # [B, f], [B, #task] 20 | logits = self.proj(act).reshape(-1, self.num_task, self.probe_class) # [B, #task, C] 21 | if y is None: 22 | return logits, None 23 | else: 24 | targets = y.to(torch.long) 25 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) 26 | return logits, loss 27 | 28 | def _init_weights(self, module): 29 | if isinstance(module, (nn.Linear, nn.Embedding)): 30 | module.weight.data.normal_(mean=0.0, std=0.02) 31 | if isinstance(module, nn.Linear) and module.bias is not None: 32 | module.bias.data.zero_() 33 | elif isinstance(module, nn.LayerNorm): 34 | module.bias.data.zero_() 35 | module.weight.data.fill_(1.0) 36 | def configure_optimizers(self, train_config): 37 | """ 38 | This long function is unfortunately doing something very simple and is being very defensive: 39 | We are separating out all parameters of the model into two buckets: those that will experience 40 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 41 | We are then returning the PyTorch optimizer object. 42 | """ 43 | # separate out all parameters to those that will and won't experience regularizing weight decay 44 | decay = set() 45 | no_decay = set() 46 | whitelist_weight_modules = (torch.nn.Linear, ) 47 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 48 | for mn, m in self.named_modules(): 49 | for pn, p in m.named_parameters(): 50 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 51 | if pn.endswith('bias'): 52 | # biases of whitelist modules will be weight decayed 53 | decay.add(fpn) 54 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 55 | # weights of whitelist modules will be weight decayed 56 | decay.add(fpn) 57 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 58 | # weights of blacklist modules will NOT be weight decayed 59 | no_decay.add(fpn) 60 | 61 | # special case the position embedding parameter in the root GPT module as not decayed 62 | # no_decay.add('pos_emb') 63 | 64 | # validate that we considered every parameter 65 | param_dict = {pn: p for pn, p in self.named_parameters()} 66 | inter_params = decay & no_decay 67 | union_params = decay | no_decay 68 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 69 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 70 | % (str(param_dict.keys() - union_params), ) 71 | print("Decayed:", decay) 72 | # create the pytorch optimizer object 73 | optim_groups = [ 74 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 75 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 76 | ] 77 | optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 78 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0) 79 | return optimizer, scheduler 80 | 81 | class BatteryProbeClassificationTwoLayer(nn.Module): 82 | # combines 64 classification problem for the case of Othello 83 | def __init__(self, device, probe_class, num_task, mid_dim, input_dim=512): # from 0 to 15 84 | super().__init__() 85 | self.input_dim = input_dim 86 | self.probe_class = probe_class 87 | self.num_task = num_task 88 | self.mid_dim = mid_dim 89 | self.proj = nn.Sequential( 90 | nn.Linear(self.input_dim, self.mid_dim, bias=True), 91 | nn.ReLU(True), 92 | nn.Linear(self.mid_dim, self.probe_class * self.num_task, bias=True), 93 | ) 94 | self.apply(self._init_weights) 95 | self.to(device) 96 | def forward(self, act, y=None): 97 | # [B, f], [B, #task] 98 | logits = self.proj(act).reshape(-1, self.num_task, self.probe_class) # [B, #task, C] 99 | if y is None: 100 | return logits, None 101 | else: 102 | targets = y.to(torch.long) 103 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) 104 | return logits, loss 105 | 106 | def _init_weights(self, module): 107 | if isinstance(module, (nn.Linear, nn.Embedding)): 108 | module.weight.data.normal_(mean=0.0, std=0.02) 109 | if isinstance(module, nn.Linear) and module.bias is not None: 110 | module.bias.data.zero_() 111 | elif isinstance(module, nn.LayerNorm): 112 | module.bias.data.zero_() 113 | module.weight.data.fill_(1.0) 114 | def configure_optimizers(self, train_config): 115 | """ 116 | This long function is unfortunately doing something very simple and is being very defensive: 117 | We are separating out all parameters of the model into two buckets: those that will experience 118 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 119 | We are then returning the PyTorch optimizer object. 
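        Note: unlike the GPT model above, these probes deliberately weight-decay
        their biases as well, and they return an (optimizer, scheduler) pair --
        Adam plus ReduceLROnPlateau -- which the probe trainer steps on test loss.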
120 | """ 121 | # separate out all parameters to those that will and won't experience regularizing weight decay 122 | decay = set() 123 | no_decay = set() 124 | whitelist_weight_modules = (torch.nn.Linear, ) 125 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 126 | for mn, m in self.named_modules(): 127 | for pn, p in m.named_parameters(): 128 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 129 | if pn.endswith('bias'): 130 | # biases of whitelist modules will be weight decayed 131 | decay.add(fpn) 132 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 133 | # weights of whitelist modules will be weight decayed 134 | decay.add(fpn) 135 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 136 | # weights of blacklist modules will NOT be weight decayed 137 | no_decay.add(fpn) 138 | 139 | # special case the position embedding parameter in the root GPT module as not decayed 140 | # no_decay.add('pos_emb') 141 | 142 | # validate that we considered every parameter 143 | param_dict = {pn: p for pn, p in self.named_parameters()} 144 | inter_params = decay & no_decay 145 | union_params = decay | no_decay 146 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 147 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 148 | % (str(param_dict.keys() - union_params), ) 149 | print("Decayed:", decay) 150 | # create the pytorch optimizer object 151 | optim_groups = [ 152 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 153 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 154 | ] 155 | optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 156 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.75, patience=0) 157 | return optimizer, scheduler -------------------------------------------------------------------------------- /mingpt/probe_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple training loop; Boilerplate that could apply to any arbitrary neural network, 3 | so nothing in this file really has anything to do with GPT specifically. 
4 | """ 5 | import os 6 | import math 7 | import logging 8 | 9 | from tqdm import tqdm 10 | import numpy as np 11 | import json 12 | import torch 13 | import torch.optim as optim 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data.dataloader import DataLoader 16 | from matplotlib import pyplot as plt 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | class TrainerConfig: 21 | # optimization parameters 22 | max_epochs = 10 23 | batch_size = 64 24 | learning_rate = 3e-4 25 | betas = (0.9, 0.95) 26 | grad_norm_clip = 1.0 27 | weight_decay = 0.1 # only applied on matmul weights 28 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 29 | lr_decay = False 30 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere 31 | final_tokens = 260e9 # (at what point we reach 10% of original LR) 32 | # checkpoint settings 33 | ckpt_path = None 34 | num_workers = 0 # for DataLoader 35 | 36 | def __init__(self, **kwargs): 37 | for k,v in kwargs.items(): 38 | setattr(self, k, v) 39 | 40 | class Trainer: 41 | def __init__(self, model, train_dataset, test_dataset, config): 42 | self.model = model 43 | self.train_dataset = train_dataset 44 | self.test_dataset = test_dataset 45 | self.config = config 46 | 47 | # take over whatever gpus are on the system 48 | self.device = 'cpu' 49 | if torch.cuda.is_available(): 50 | self.device = torch.cuda.current_device() 51 | self.model = torch.nn.DataParallel(self.model).to(self.device) 52 | 53 | # log something for plotting 54 | self.train_loss_cont = [] 55 | self.test_loss_cont = [] 56 | self.train_acc_cont = [] 57 | self.test_acc_cont = [] 58 | # would be a list of T-long, each is a lits of 60-long, for stratified accuracies 59 | self.train_strat_acc_cont = [] 60 | self.test_strat_acc_cont = [] 61 | 62 | def flush_plot(self, ): 63 | # plt.close() 64 | fig, axs = plt.subplots(1, 2, figsize=(20, 10), dpi= 80, facecolor='w', edgecolor='k') 65 | axs = axs.flat 66 | axs[0].plot(self.train_loss_cont, label="train") 67 | axs[0].plot(self.test_loss_cont, label="test") 68 | axs[0].set_title("Loss") 69 | axs[0].legend() 70 | axs[1].plot(self.train_acc_cont, label="train") 71 | axs[1].plot(self.test_acc_cont, label="test") 72 | axs[1].set_title("Accuracy") 73 | axs[1].legend() 74 | plt.show() 75 | 76 | def save_traces(self, ): 77 | tbd = { 78 | "train_loss_cont": self.train_loss_cont, "test_loss_cont" :self.test_loss_cont, 79 | "train_acc_cont": self.train_acc_cont, "test_acc_cont": self.test_acc_cont, 80 | "train_strat_acc_cont": self.train_strat_acc_cont, "test_strat_acc_cont": self.test_strat_acc_cont, 81 | } 82 | with open(os.path.join(self.config.ckpt_path, "tensorboard.txt"), "w") as f: 83 | f.write(json.dumps(tbd) + "\n") 84 | 85 | def save_checkpoint(self): 86 | # DataParallel wrappers keep raw model object in .module attribute 87 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 88 | if not os.path.exists(self.config.ckpt_path): 89 | os.makedirs(self.config.ckpt_path) 90 | torch.save(raw_model.state_dict(), os.path.join(self.config.ckpt_path, "checkpoint.ckpt")) 91 | 92 | def train(self, prt=True): 93 | model, config = self.model, self.config 94 | raw_model = model.module if hasattr(self.model, "module") else model 95 | optimizer, scheduler = raw_model.configure_optimizers(config) 96 | 97 | def run_epoch(split): 98 | is_train = split == 'train' 99 | model.train(is_train) 100 | data = self.train_dataset if is_train else 
self.test_dataset 101 | loader = DataLoader(data, shuffle=True, pin_memory=True, 102 | batch_size=config.batch_size, 103 | num_workers=config.num_workers) 104 | 105 | losses = [] 106 | totals_epoch = np.zeros(60, dtype=float) # np.array of shape [60], for positions of age 0 to 59 107 | hits_epoch = np.zeros(60, dtype=float) # np.array of shape [60], for positions of age 0 to 59 108 | pbar = tqdm(enumerate(loader), total=len(loader), disable=not prt) if is_train else enumerate(loader) 109 | for it, (x, y, age) in pbar: 110 | x = x.to(self.device) # [B, f] 111 | y = y.to(self.device) # [B, #task=64] 112 | age = age.to(self.device) # [B, #task=64], in 0--59 113 | 114 | with torch.set_grad_enabled(is_train): 115 | logits, loss = model(x, y) 116 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus 117 | losses.append(loss.item()) 118 | totals_epoch += np.array([torch.sum(age == i).item() for i in range(60)]).astype(float) 119 | y_hat = torch.argmax(logits, dim=-1, keepdim=False) # [B, #task] 120 | hits = y_hat == y # [B, #task] 121 | hits_epoch += np.array([torch.sum(hits * (age == i)).item() for i in range(60)]).astype(float) 122 | 123 | if is_train: 124 | # backprop and update the parameters 125 | model.zero_grad() 126 | loss.backward() 127 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 128 | optimizer.step() 129 | mean_loss = float(np.mean(losses)) 130 | mean_acc = np.sum(hits_epoch).item() / np.sum(totals_epoch).item() 131 | lr = optimizer.param_groups[0]['lr'] 132 | pbar.set_description(f"epoch {epoch+1}: train loss {mean_loss:.5f}; lr {lr:.2e}; train acc {mean_acc*100:.2f}%") 133 | if is_train: 134 | self.train_loss_cont.append(mean_loss) 135 | self.train_acc_cont.append(mean_acc) 136 | self.train_strat_acc_cont.append((hits_epoch / totals_epoch).tolist()) 137 | 138 | if not is_train: 139 | test_loss = float(np.mean(losses)) 140 | scheduler.step(test_loss) 141 | test_acc = np.sum(hits_epoch).item() / np.sum(totals_epoch).item() 142 | if prt: 143 | logger.info(f"test loss {test_loss:.5f}; test acc {test_acc*100:.2f}%") 144 | self.test_loss_cont.append(test_loss) 145 | self.test_acc_cont.append(test_acc) 146 | self.test_strat_acc_cont.append((hits_epoch / totals_epoch).tolist()) 147 | return test_loss 148 | 149 | best_loss = float('inf') 150 | self.tokens = 0 # counter used for learning rate decay 151 | 152 | for epoch in range(config.max_epochs): 153 | run_epoch('train') 154 | if self.test_dataset is not None: 155 | test_loss = run_epoch('test') 156 | if test_loss < best_loss: 157 | best_loss = test_loss 158 | self.save_checkpoint() 159 | -------------------------------------------------------------------------------- /mingpt/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple training loop; Boilerplate that could apply to any arbitrary neural network, 3 | so nothing in this file really has anything to do with GPT specifically. 
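A minimal usage sketch (the dataset objects and config values here are
illustrative, not prescribed by this file):

    from mingpt.trainer import Trainer, TrainerConfig

    tconf = TrainerConfig(max_epochs=250, batch_size=512, learning_rate=5e-4,
                          lr_decay=True, ckpt_path="ckpts/gpt_synthetic.ckpt")
    Trainer(model, train_dataset, test_dataset, tconf).train()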
4 | """ 5 | 6 | import math 7 | import logging 8 | 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from torch.utils.data.dataloader import DataLoader 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | class TrainerConfig: 20 | # optimization parameters 21 | max_epochs = 10 22 | batch_size = 64 23 | learning_rate = 3e-4 24 | betas = (0.9, 0.95) 25 | grad_norm_clip = 1.0 26 | weight_decay = 0.1 # only applied on matmul weights 27 | # learning rate decay params: linear warmup followed by cosine decay to 10% of original 28 | lr_decay = False 29 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere 30 | final_tokens = 260e9 # (at what point we reach 10% of original LR) 31 | # checkpoint settings 32 | ckpt_path = None 33 | num_workers = 0 # for DataLoader 34 | 35 | def __init__(self, **kwargs): 36 | for k,v in kwargs.items(): 37 | setattr(self, k, v) 38 | 39 | class Trainer: 40 | def __init__(self, model, train_dataset, test_dataset, config): 41 | self.model = model 42 | self.train_dataset = train_dataset 43 | self.test_dataset = test_dataset 44 | self.config = config 45 | 46 | # take over whatever gpus are on the system 47 | self.device = 'cpu' 48 | if torch.cuda.is_available(): 49 | self.device = torch.cuda.current_device() 50 | self.model = torch.nn.DataParallel(self.model).to(self.device) 51 | 52 | def save_checkpoint(self): 53 | # DataParallel wrappers keep raw model object in .module attribute 54 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 55 | logger.info("saving %s", self.config.ckpt_path) 56 | torch.save(raw_model.state_dict(), self.config.ckpt_path) 57 | 58 | def train(self): 59 | model, config = self.model, self.config 60 | raw_model = model.module if hasattr(self.model, "module") else model 61 | optimizer = raw_model.configure_optimizers(config) 62 | 63 | def run_epoch(split): 64 | is_train = split == 'train' 65 | model.train(is_train) 66 | data = self.train_dataset if is_train else self.test_dataset 67 | loader = DataLoader(data, shuffle=True, pin_memory=True, 68 | batch_size=config.batch_size, 69 | num_workers=config.num_workers) 70 | 71 | losses = [] 72 | pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader) 73 | for it, (x, y) in pbar: 74 | 75 | # place data on the correct device 76 | x = x.to(self.device) # [B, T] 77 | y = y.to(self.device) # [B, T] 78 | 79 | # forward the model 80 | with torch.set_grad_enabled(is_train): 81 | logits, loss = model(x, y) 82 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus 83 | losses.append(loss.item()) 84 | 85 | if is_train: 86 | # backprop and update the parameters 87 | model.zero_grad() 88 | loss.backward() 89 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 90 | optimizer.step() 91 | 92 | # decay the learning rate based on our progress 93 | if config.lr_decay: 94 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. 
label is not -100) 95 | if self.tokens < config.warmup_tokens: 96 | # linear warmup 97 | lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens)) 98 | else: 99 | # cosine learning rate decay 100 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) 101 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 102 | lr = config.learning_rate * lr_mult 103 | for param_group in optimizer.param_groups: 104 | param_group['lr'] = lr 105 | else: 106 | lr = config.learning_rate 107 | 108 | # report progress 109 | pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}") 110 | 111 | if not is_train: 112 | test_loss = float(np.mean(losses)) 113 | logger.info("test loss: %f", test_loss) 114 | return test_loss 115 | 116 | best_loss = float('inf') 117 | self.tokens = 0 # counter used for learning rate decay 118 | for epoch in range(config.max_epochs): 119 | 120 | run_epoch('train') 121 | if self.test_dataset is not None: 122 | test_loss = run_epoch('test') 123 | 124 | # supports early stopping based on the test loss, or just save always if no test set is provided 125 | if self.config.ckpt_path is not None: 126 | if self.test_dataset is None: 127 | self.save_checkpoint() 128 | continue 129 | if test_loss < best_loss: 130 | best_loss = test_loss 131 | self.save_checkpoint() 132 | 133 | -------------------------------------------------------------------------------- /mingpt/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from matplotlib import pyplot as plt 7 | 8 | from data.othello import permit, start_hands, OthelloBoardState, permit_reverse 9 | 10 | def set_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | 16 | def top_k_logits(logits, k): 17 | v, ix = torch.topk(logits, k) 18 | out = logits.clone() 19 | out[out < v[:, [-1]]] = -float('Inf') 20 | return out 21 | 22 | @torch.no_grad() 23 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 24 | """ 25 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 26 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 27 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 28 | of block_size, unlike an RNN that has an infinite context window. 
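    temperature rescales the logits before the softmax, and top_k (if given)
    keeps only the k most likely tokens at each step. Illustrative calls,
    assuming x already holds tokenized move prefixes of shape (b, t):

        y = sample(model, x, steps=10)                                  # greedy
        y = sample(model, x, steps=10, temperature=0.8, sample=True, top_k=5)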
29 | """ 30 | block_size = model.get_block_size() 31 | model.eval() 32 | for k in range(steps): 33 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 34 | logits, _ = model(x_cond) 35 | # pluck the logits at the final step and scale by temperature 36 | logits = logits[:, -1, :] / temperature 37 | # optionally crop probabilities to only the top k options 38 | if top_k is not None: 39 | logits = top_k_logits(logits, top_k) 40 | # apply softmax to convert to probabilities 41 | probs = F.softmax(logits, dim=-1) 42 | # sample from the distribution or take the most likely 43 | if sample: 44 | ix = torch.multinomial(probs, num_samples=1) 45 | else: 46 | _, ix = torch.topk(probs, k=1, dim=-1) 47 | # append to the sequence and continue 48 | x = torch.cat((x, ix), dim=1) 49 | 50 | return x 51 | 52 | def print_board(labels): 53 | # torch tensor, [64], in 0--2 54 | bs_in_probe_mind = labels -1 55 | anob = OthelloBoardState() 56 | anob.state = bs_in_probe_mind.detach().cpu().numpy().reshape(8, 8) 57 | anob.__print__() 58 | 59 | def intervene(p, mid_act, labels_pre_intv, wtd, htd, plot=False): 60 | # p: probe model 61 | # mid_act: [512, ], the intervened, might not be at the lastest temporal position 62 | # labels_pre_intv, 63 | # wtd: a dict of intervention_position, intervention_from, intervention_to 64 | # htd: a dict of some intervention parameters 65 | # plot: not supported yet 66 | # return a new_mid_act 67 | new_mid_act = torch.tensor(mid_act.detach().cpu().numpy()).cuda() 68 | new_mid_act.requires_grad = True 69 | opt = torch.optim.Adam([new_mid_act], lr=htd["lr"]) 70 | 71 | labels_post_intv = labels_pre_intv.clone() 72 | weight_mask = htd["reg_strg"] * torch.ones(64).cuda() 73 | 74 | labels_post_intv[permit(wtd["intervention_position"])] = wtd["intervention_to"] 75 | weight_mask[permit(wtd["intervention_position"])] = 1 76 | 77 | logit_container = [] 78 | loss_container = [] 79 | for i in range(htd["steps"]): 80 | opt.zero_grad() 81 | logits_running = p(new_mid_act[None, :])[0][0] # [64, 3] 82 | logit_container.append(logits_running[permit(wtd["intervention_position"])].detach().cpu().numpy()) 83 | loss = F.cross_entropy(logits_running, labels_post_intv, reduction="none") 84 | loss = torch.mean(weight_mask * loss) 85 | loss.backward() # by torch semantics, loss is to be minimized 86 | loss_container.append(loss.item()) 87 | opt.step() 88 | if 0: 89 | logits = np.stack(logit_container, axis=0) 90 | plt.plot(logits[:, 0], color="r", label="White") 91 | plt.plot(logits[:, 1], color="g", label="Blank") 92 | plt.plot(logits[:, 2], color="b", label="Black") 93 | plt.legend() 94 | labels_post_intv_hat = logits_running.detach().argmax(dim=-1) # [64] 95 | num_error = torch.sum(labels_post_intv_hat - labels_post_intv).item() 96 | 97 | if plot: 98 | if num_error == 0: 99 | print(wtd["intervention_position"] + " Sucessfully intervened!") 100 | else: 101 | print(wtd["intervention_position"] + " Failed intervention! 
See the below two borads:") 102 | print("labels_post_intv_reality") 103 | print_board(labels_post_intv_hat) 104 | print("labels_post_intv_wished") 105 | print_board(labels_post_intv) 106 | 107 | return new_mid_act 108 | -------------------------------------------------------------------------------- /produce_probes.sh: -------------------------------------------------------------------------------- 1 | for X in {0..8} 2 | do 3 | 4 | CUDA_VISIBLE_DEVICES=0 python train_probe_othello.py --layer $X --random 5 | CUDA_VISIBLE_DEVICES=0 python train_probe_othello.py --layer $X --championship 6 | CUDA_VISIBLE_DEVICES=0 python train_probe_othello.py --layer $X 7 | 8 | for Y in {2,4,8,16,32,64,128,256,512} 9 | do 10 | CUDA_VISIBLE_DEVICES=0 python train_probe_othello.py --layer $X --twolayer --mid_dim $layer --random 11 | CUDA_VISIBLE_DEVICES=0 python train_probe_othello.py --layer $X --twolayer --mid_dim $layer --championship 12 | CUDA_VISIBLE_DEVICES=0 python train_probe_othello.py --layer $X --twolayer --mid_dim $layer 13 | done 14 | 15 | done -------------------------------------------------------------------------------- /togglable/linear_champ.html: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |