├── .gitignore ├── README.md ├── board_state_interventions.py ├── caa.py ├── chess_utils.py ├── contrastive_activations └── lichess_train_layer_12_pos_start_25_activations.pt ├── data └── .gitignore ├── images ├── pawn_probe.png └── probe_acc_markers_graph.png ├── lichess_data_filtering.ipynb ├── linear_probes ├── .gitignore ├── analyze_test_results.ipynb ├── saved_probes │ ├── tf_lens_lichess_16layers_ckpt_no_optimizer_chess_piece_probe_layer_11.pth │ ├── tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth │ ├── tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_5.pth │ ├── tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth │ ├── tf_lens_randominit_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth │ └── tf_lens_randominit_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth ├── test_data │ └── .gitignore └── view_probe.ipynb ├── model_setup.py ├── models ├── meta.pkl └── view_model.ipynb ├── othello_engine_utils.py ├── othello_utils.py ├── probe_output_visualization.ipynb ├── requirements.txt ├── tests ├── test_board_interventions.py ├── test_caa.py ├── test_chess_utils.py └── test_probe_training_and_eval.py ├── train_test_chess.py └── utils ├── board_grid_search_analysis.ipynb ├── chess_gpt_eval_data_filtering.ipynb ├── create_skill_intervention_from_skill_probe.ipynb ├── custom_functions_guide.md ├── nanogpt_to_transformer_lens.ipynb ├── othello_data_filtering.ipynb ├── unique_checks.ipynb └── view_caa.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | **/.vscode 3 | wandb/ 4 | 5 | data/* 6 | !data/train.csv 7 | !data/test.csv 8 | !data/train_test_splitter.ipynb 9 | 10 | models/* 11 | !models/meta.pkl 12 | !models/view_model.ipynb 13 | 14 | contrastive_activations/* 15 | intervention_logs/* 16 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # chess_llm_interpretability 2 | This repo evaluates LLMs trained on PGN format chess games through the use of linear probes. We can check the LLM's internal understanding of board state and its ability to estimate the skill level of the players involved. We can also perform interventions on the model's internal board state by deleting pieces from its internal world model. 3 | 4 | This repo can train, evaluate, and visualize linear probes on LLMs that have been trained to play chess with PGN strings. For example, we can visualize where the model "thinks" the white pawns are. On the left, we have the actual white pawn locations. In the middle, we clip the probe outputs to turn the heatmap into a more binary visualization. On the right, we have the full gradient of model beliefs. 5 | 6 | ![](/images/pawn_probe.png) 7 | 8 | I trained linear probes on both the model's ability to compute board state and its ability to estimate player Elo as it's predicting the next character. Here we can see a per-layer graph of board state and Elo classification accuracy across a range of LLMs. 9 | 10 | ![](/images/probe_acc_markers_graph.png) 11 | 12 | For more information, refer to this [post](https://adamkarvonen.github.io/machine_learning/2024/01/03/chess-world-models.html). 13 | 14 | # Setup 15 | 16 | Create a Python environment with Python 3.10 or 3.11 (I'm using 3.11). 17 | ``` 18 | pip install -r requirements.txt 19 | python model_setup.py 20 | ``` 21 | 22 | Then click "Run All" on `lichess_data_filtering.ipynb` (I'm filtering data in a notebook instead of a script because I use a series of graphs to illustrate what the data filtering is doing). 23 | To visualize probe outputs or better understand my work, check out `probe_output_visualization.ipynb`. It has commentary and many print statements to walk you through using a single probe and performing a single intervention. 
24 | 25 | The `train_test_chess.py` script can be used to either train new linear probes or test a saved probe on the test set. 26 | 27 | Command line arguments: 28 | 29 | --mode: Specifies `train` or `test`. Optional, defaults to `train`. 30 | 31 | --probe: Determines the type of probe to be used. `piece` probes for the piece type on each square; `skill` probes for the skill level of the White player. Optional, defaults to `piece`. 32 | 33 | 34 | Examples: 35 | 36 | Train piece board state probes: 37 | `python train_test_chess.py` 38 | 39 | Test skill probe: 40 | `python train_test_chess.py --mode test --probe skill` 41 | 42 | See all options: `python train_test_chess.py -h` 43 | 44 | To add new functions, refer to `utils/custom_functions_guide.md`. 45 | 46 | All experiments in this repo can be done with less than 1 GB of VRAM. Training probes on the 8-layer model takes about 10 minutes on my RTX 3050. 47 | 48 | # OthelloGPT 49 | 50 | This repo can also be used for training linear probes on OthelloGPT. Refer to `utils/othello_data_filtering.ipynb`. 51 | 52 | # Interventions 53 | 54 | To perform board state interventions on one layer, run `python board_state_interventions.py`. It will record JSON results in `intervention_logs/`. To get better results, train a set of 8 (one per layer) board state probes using `train_test_chess.py` and rerun. 55 | 56 | To perform skill interventions, you can train a set of 8 skill probes using `train_test_chess.py` or generate a set of 8 contrastive activations using `caa.py`. Note that contrastive activations tend to work a little better. If you want to use probe-derived interventions, use this notebook to create activation files from the probes: `utils/create_skill_intervention_from_skill_probe.ipynb`. 
57 | 58 | Then, follow these directions to use them to perform skill interventions: https://github.com/adamkarvonen/chess_gpt_eval/tree/master/nanogpt 59 | 60 | # Shape Annotations 61 | 62 | I've been using this tip from Noam Shazeer: 63 | 64 | Dimension key (from https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd): 65 | 66 | M = modes 67 | 68 | l = seq length before indexing 69 | 70 | L = seq length after indexing 71 | 72 | B = batch_size 73 | 74 | R = rows (or cols) 75 | 76 | C = classes for one hot encoding 77 | 78 | D = d_model of the GPT (512) 79 | 80 | For example 81 | 82 | ``` 83 | probe_out_MBLRRC = einsum( 84 | "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options", 85 | resid_post_BLD, 86 | linear_probe_MDRRC, 87 | ) 88 | ``` 89 | 90 | # Useful links 91 | 92 | All code, models, and datasets are open source. 93 | 94 | To play the nanoGPT model against Stockfish, please visit: https://github.com/adamkarvonen/chess_gpt_eval/tree/master/nanogpt 95 | 96 | To train a Chess-GPT from scratch, please visit: https://github.com/adamkarvonen/nanoGPT 97 | 98 | All pretrained models are available here: https://huggingface.co/adamkarvonen/chess_llms 99 | 100 | All datasets are available here: https://huggingface.co/datasets/adamkarvonen/chess_games 101 | 102 | Wandb training loss curves and model configs can be viewed here: https://api.wandb.ai/links/adam-karvonen/u783xspb 103 | 104 | # Testing 105 | 106 | To run the end to end test suite, run `pytest -s` from the root directory. This will first train and test probes end to end on the 8 layer model, including comparing expected accuracy to actual accuracy within some tolerance. Then it will test out board state interventions and caa creation. It takes around 14 minutes. The `-s` flag is so you can see the training updates and gauge progress. 
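
As a small illustration of the shape-suffix convention from the Shape Annotations section, the sketch below turns the suffix of a variable name into the shape that tensor is expected to have. The helper names `expected_shape` and `matches_suffix` are hypothetical and not part of this repo, and every size except `D = 512` is an assumed example value.

```python
# Hypothetical helpers for illustration; they are not part of this repo.

# Example sizes. D = 512 comes from the dimension key in Shape Annotations;
# the other values are assumptions chosen only to make the shapes concrete.
DIM_SIZES = {"M": 1, "B": 2, "L": 30, "R": 8, "C": 13, "D": 512}


def expected_shape(name: str, dim_sizes: dict[str, int]) -> tuple[int, ...]:
    """Turn the suffix after the last underscore into a shape tuple."""
    suffix = name.rsplit("_", 1)[-1]
    return tuple(dim_sizes[letter] for letter in suffix)


def matches_suffix(name: str, shape: tuple[int, ...], dim_sizes: dict[str, int]) -> bool:
    """Return True if `shape` agrees with the suffix of `name`."""
    return tuple(shape) == expected_shape(name, dim_sizes)


# The three tensors from the einsum example in Shape Annotations.
for tensor_name in ("resid_post_BLD", "linear_probe_MDRRC", "probe_out_MBLRRC"):
    print(tensor_name, expected_shape(tensor_name, DIM_SIZES))
```

A check like `matches_suffix("probe_out_MBLRRC", tuple(probe_out_MBLRRC.shape), DIM_SIZES)` could sit next to an einsum call as a cheap guard against transposed dimensions, although the suffix alone cannot tell the two `R` axes apart.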
107 | 108 | # References 109 | 110 | Much of my linear probing was developed using Neel Nanda's linear probing code as a reference. Here are the main references I used: 111 | 112 | https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb 113 | https://colab.research.google.com/github/likenneth/othello_world/blob/master/Othello_GPT_Circuits.ipynb 114 | https://www.neelnanda.io/mechanistic-interpretability/othello 115 | https://github.com/likenneth/othello_world/tree/master/mechanistic_interpretability -------------------------------------------------------------------------------- /board_state_interventions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fancy_einsum import einsum 3 | import chess 4 | from dataclasses import dataclass, field, fields 5 | import pickle 6 | import logging 7 | from functools import partial 8 | from enum import Enum 9 | import json 10 | 11 | import chess_utils 12 | import train_test_chess 13 | 14 | from jaxtyping import Int, Float, jaxtyped 15 | from torch import Tensor 16 | from beartype import beartype 17 | 18 | import cProfile 19 | import pstats 20 | import io 21 | 22 | torch.set_grad_enabled(False) 23 | 24 | # Flags to control logging 25 | debug_mode = False 26 | info_mode = True 27 | 28 | if debug_mode: 29 | log_level = logging.DEBUG 30 | elif info_mode: 31 | log_level = logging.INFO 32 | else: 33 | log_level = logging.WARNING 34 | 35 | # Configure logging 36 | logging.basicConfig(level=log_level) 37 | logger = logging.getLogger(__name__) 38 | 39 | GPT_LAYER_COUNT = 8 40 | DATA_DIR = "data/" 41 | SAVED_PROBE_DIR = f"linear_probes/saved_probes/" 42 | RECORDING_DIR = "intervention_logs/" 43 | SPLIT = "test" 44 | MODES = 1 # Currently only supporting 1 mode so this is fairly unnecessary 45 | START_POS = 0 46 | END_POS = 30 47 | BLANK_INDEX = chess_utils.PIECE_TO_ONE_HOT_MAPPING[0] 48 | SAMPLING_MOVES = 5 49 | TEMPERATURE = 1.0 
50 | MAX_GAMES = 5000 51 | 52 | DEVICE = ( 53 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 54 | ) 55 | logger.info(f"Using device: {DEVICE}") 56 | 57 | with open("models/meta.pkl", "rb") as f: 58 | META = pickle.load(f) 59 | 60 | 61 | class InterventionType(Enum): 62 | SINGLE_SCALE = "single_scale" 63 | SINGLE_TARGET = "single_target" 64 | AVERAGE_TARGET = "average_target" 65 | 66 | 67 | class ModelType(Enum): 68 | ORIGINAL = "original" 69 | MODIFIED = "modified" 70 | 71 | 72 | @dataclass 73 | class MoveTracker: 74 | orig_board_argmax_legal_total: int = 0 75 | orig_board_sampled_legal_total: int = 0 76 | orig_board_sampled_legal_unique: int = 0 77 | mod_board_argmax_legal_total: int = 0 78 | mod_board_sampled_legal_total: int = 0 79 | mod_board_sampled_legal_unique: int = 0 80 | unique_moves: int = 0 81 | 82 | def update_with(self, other: "MoveTracker"): 83 | """Updates the attributes of this MoveTracker with values from another.""" 84 | self.orig_board_argmax_legal_total += other.orig_board_argmax_legal_total 85 | self.orig_board_sampled_legal_total += other.orig_board_sampled_legal_total 86 | self.orig_board_sampled_legal_unique += other.orig_board_sampled_legal_unique 87 | self.mod_board_argmax_legal_total += other.mod_board_argmax_legal_total 88 | self.mod_board_sampled_legal_total += other.mod_board_sampled_legal_total 89 | self.mod_board_sampled_legal_unique += other.mod_board_sampled_legal_unique 90 | self.unique_moves += other.unique_moves 91 | 92 | 93 | @dataclass 94 | class MoveCounters: 95 | total_moves: int = 0 96 | possible_moves: int = 0 97 | orig_model_tracker: MoveTracker = field(default_factory=MoveTracker) 98 | mod_model_tracker: MoveTracker = field(default_factory=MoveTracker) 99 | 100 | 101 | def get_probe_data(probe_name: str, num_games: int) -> train_test_chess.LinearProbeData: 102 | probe_file_location = f"{SAVED_PROBE_DIR}{probe_name}" 103 | with open(probe_file_location, "rb") as f: 104 | 
state_dict = torch.load(f, map_location=torch.device(DEVICE)) 105 | print(state_dict.keys()) 106 | for key in state_dict.keys(): 107 | if key != "linear_probe": 108 | print(key, state_dict[key]) 109 | 110 | config = chess_utils.find_config_by_name(state_dict["config_name"]) 111 | layer = state_dict["layer"] 112 | model_name = state_dict["model_name"] 113 | dataset_prefix = state_dict["dataset_prefix"] 114 | config.pos_start = state_dict["pos_start"] 115 | levels_of_interest = None 116 | if "levels_of_interest" in state_dict.keys(): 117 | levels_of_interest = state_dict["levels_of_interest"] 118 | config.levels_of_interest = levels_of_interest 119 | n_layers = state_dict["n_layers"] 120 | 121 | split = SPLIT 122 | input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv" 123 | config = chess_utils.set_config_min_max_vals_and_column_name( 124 | config, input_dataframe_file, dataset_prefix 125 | ) 126 | 127 | probe_data = train_test_chess.construct_linear_probe_data( 128 | input_dataframe_file, 129 | dataset_prefix, 130 | n_layers, 131 | model_name, 132 | config, 133 | num_games, 134 | DEVICE, 135 | ) 136 | return probe_data 137 | 138 | 139 | @jaxtyped(typechecker=beartype) 140 | def prepare_intervention_data( 141 | probe_names: dict[int, str], 142 | probe_data: train_test_chess.LinearProbeData, 143 | num_games: int, 144 | ) -> tuple[ 145 | dict[int, Float[Tensor, "modes d_model rows cols options"]], 146 | Int[Tensor, "modes num_games pgn_str_length rows cols"], 147 | Int[Tensor, "num_games num_white_moves"], 148 | ]: 149 | 150 | probes = {} 151 | checkpoint = None # Going to retain the last checkpoint for the config 152 | 153 | for layer, probe_name in probe_names.items(): 154 | probe_file_location = f"{SAVED_PROBE_DIR}{probe_name}" 155 | checkpoint = torch.load(probe_file_location, map_location=torch.device(DEVICE)) 156 | linear_probe = checkpoint["linear_probe"] 157 | probes[layer] = linear_probe 158 | 159 | config = 
chess_utils.find_config_by_name(checkpoint["config_name"]) 160 | state_stacks_all_chars = chess_utils.create_state_stacks( 161 | probe_data.board_seqs_string[:num_games], config.custom_board_state_function 162 | ) 163 | logger.info(f"state_stack shape: {state_stacks_all_chars.shape}") 164 | pgn_str_length = len(probe_data.board_seqs_string[0]) 165 | 166 | assert (state_stacks_all_chars.shape) == ( 167 | MODES, 168 | num_games, 169 | pgn_str_length, 170 | config.num_rows, 171 | config.num_cols, 172 | ) 173 | 174 | white_move_indices = probe_data.custom_indices[:num_games] 175 | num_white_moves = white_move_indices.shape[1] 176 | assert (white_move_indices.shape) == (num_games, num_white_moves) 177 | 178 | return probes, state_stacks_all_chars, white_move_indices 179 | 180 | 181 | def initialize_output_tracker(probes: dict[int, str]) -> dict: 182 | output_tracker: dict[int, dict] = {} 183 | for layer in probes: 184 | output_tracker[layer] = { 185 | "original_blank_probe": [], 186 | "modified_blank_probe": [], 187 | "original_piece_probe": [], 188 | "modified_piece_probe": [], 189 | "original_blank_grid": [], 190 | "modified_blank_grid": [], 191 | "original_piece_grid": [], 192 | "modified_piece_grid": [], 193 | "average_original_blank_grid": [], 194 | "average_modified_blank_grid": [], 195 | "average_original_piece_grid": [], 196 | "average_modified_piece_grid": [], 197 | "scales": [], 198 | "successes": [], 199 | "cells": [], 200 | "pieces": [], 201 | "modified_move": [], 202 | "original_move": [], 203 | } 204 | return output_tracker 205 | 206 | 207 | def initialize_scale_tracker(scales: list[float]) -> dict[float, MoveTracker]: 208 | scale_tracker: dict[float, MoveTracker] = {} 209 | for scale in scales: 210 | scale_tracker[scale] = MoveTracker() 211 | return scale_tracker 212 | 213 | 214 | @jaxtyped(typechecker=beartype) 215 | def update_output_tracker_grids( 216 | probes: dict[int, Float[Tensor, "modes d_model rows cols options"]], 217 | probe_data: 
train_test_chess.LinearProbeData, 218 | model_input: Int[Tensor, "num_games pgn_str_length"], 219 | state_stacks_all_chars: Int[Tensor, "modes num_games pgn_str_length rows cols"], 220 | output_tracker: dict, 221 | move_of_interest_index: int, 222 | sample_index: int, 223 | r: int, 224 | c: int, 225 | moved_piece_probe_index: int, 226 | moved_piece_int: int, 227 | model_move: str, 228 | model_type: ModelType, 229 | ) -> dict: 230 | 231 | model_type_str = model_type.value 232 | 233 | _, cache = probe_data.model.run_with_cache(model_input) 234 | 235 | for layer in output_tracker: 236 | probe_outputs = calculate_probe_outputs(probes, cache) 237 | probe_out = probe_outputs[layer] 238 | 239 | blank_probe_grid = probe_out[0, 0, move_of_interest_index, :, :, BLANK_INDEX] 240 | piece_probe_grid = probe_out[0, 0, move_of_interest_index, :, :, moved_piece_probe_index] 241 | 242 | blank_probe_out = blank_probe_grid[r, c] 243 | piece_probe_out = piece_probe_grid[r, c] 244 | output_tracker[layer][f"{model_type_str}_blank_grid"].append(blank_probe_grid.to("cpu")) 245 | output_tracker[layer][f"{model_type_str}_piece_grid"].append(piece_probe_grid.to("cpu")) 246 | output_tracker[layer][f"{model_type_str}_blank_probe"].append(blank_probe_out) 247 | output_tracker[layer][f"{model_type_str}_piece_probe"].append(piece_probe_out) 248 | output_tracker[layer][f"{model_type_str}_move"].append(model_move) 249 | 250 | average_blank_values = average_probe_empty_cell_value( 251 | state_stacks_all_chars, 252 | probe_outputs, 253 | BLANK_INDEX, 254 | move_of_interest_index, 255 | sample_index, 256 | ) 257 | average_piece_values = average_probe_empty_cell_value( 258 | state_stacks_all_chars, 259 | probe_outputs, 260 | moved_piece_probe_index, 261 | move_of_interest_index, 262 | sample_index, 263 | ) 264 | 265 | output_tracker[layer][f"average_{model_type_str}_blank_grid"].append( 266 | average_blank_values[layer] 267 | ) 268 | output_tracker[layer][f"average_{model_type_str}_piece_grid"].append( 
269 | average_piece_values[layer] 270 | ) 271 | 272 | if model_type == ModelType.MODIFIED: 273 | for layer in output_tracker: 274 | # Duplicating some metadata for each layer but it's a small amount 275 | output_tracker[layer]["scales"].append(scale) 276 | output_tracker[layer]["successes"].append(False) 277 | output_tracker[layer]["cells"].append((r, c)) 278 | output_tracker[layer]["pieces"].append(moved_piece_int) 279 | 280 | return output_tracker 281 | 282 | 283 | def create_recording_data( 284 | move_counters: MoveCounters, scale_tracker: dict[float, MoveTracker] 285 | ) -> dict: 286 | records = {} 287 | records["orig_model_tracker"] = {} 288 | records["mod_model_tracker"] = {} 289 | for field in fields(move_counters.orig_model_tracker): 290 | records["orig_model_tracker"][field.name] = getattr( 291 | move_counters.orig_model_tracker, field.name 292 | ) 293 | for field in fields(move_counters.mod_model_tracker): 294 | records["mod_model_tracker"][field.name] = getattr( 295 | move_counters.mod_model_tracker, field.name 296 | ) 297 | for field in fields(move_counters): 298 | if field.name == "orig_model_tracker" or field.name == "mod_model_tracker": 299 | continue 300 | records[field.name] = getattr(move_counters, field.name) 301 | for scale in scale_tracker: 302 | records[scale] = {} 303 | for field in fields(scale_tracker[scale]): 304 | records[scale][field.name] = getattr(scale_tracker[scale], field.name) 305 | records["possible_sampled_moves"] = records["possible_moves"] * SAMPLING_MOVES 306 | return records 307 | 308 | 309 | def update_move_counters_best_per_move( 310 | move_counters: MoveCounters, 311 | per_move_scale_tracker: dict[float, MoveTracker], 312 | ) -> MoveCounters: 313 | """For each move, we find the best performing scale parameter. We then increment the move counter trackers with these values. 314 | The purpose is to get an upper bound on effectiveness if we could dynamically select a good scale parameter. 
315 | """ 316 | scales = list(per_move_scale_tracker.keys()) 317 | for field in fields(per_move_scale_tracker[scales[0]]): 318 | best_scale_value = max( 319 | getattr(per_move_scale_tracker[scale], field.name) for scale in scales 320 | ) 321 | current_scale_value = getattr(move_counters.mod_model_tracker, field.name) 322 | setattr( 323 | move_counters.mod_model_tracker, 324 | field.name, 325 | best_scale_value + current_scale_value, 326 | ) 327 | 328 | return move_counters 329 | 330 | 331 | def sample_moves_from_model( 332 | model, 333 | model_input: Int[Tensor, "num_games pgn_str_length"], 334 | original_board: chess.Board, 335 | modified_board: chess.Board, 336 | ) -> MoveTracker: 337 | """Samples moves from a model and updates the provided list of boards with the 338 | total number of legal moves and unique legal moves for each board.""" 339 | unique_moves = set() 340 | move_tracker = MoveTracker() 341 | for _ in range(SAMPLING_MOVES): 342 | sampled_model_move = chess_utils.get_model_move( 343 | model, META, model_input, temperature=TEMPERATURE 344 | ) 345 | try: 346 | original_board.parse_san(sampled_model_move) 347 | # print(f"Model original move: {sampled_model_move}") 348 | move_tracker.orig_board_sampled_legal_total += 1 349 | if sampled_model_move not in unique_moves: 350 | move_tracker.orig_board_sampled_legal_unique += 1 351 | except: 352 | # print(f"Invalid original move: {sampled_model_move}") 353 | pass 354 | try: 355 | modified_board.parse_san(sampled_model_move) 356 | print(f"Model modified move: {sampled_model_move}") 357 | move_tracker.mod_board_sampled_legal_total += 1 358 | if sampled_model_move not in unique_moves: 359 | move_tracker.mod_board_sampled_legal_unique += 1 360 | except: 361 | print(f"Invalid modified move: {sampled_model_move}") 362 | pass 363 | unique_moves.add(sampled_model_move) 364 | 365 | move_tracker.unique_moves += len(unique_moves) 366 | 367 | return move_tracker 368 | 369 | 370 | def check_if_legal_move(board: chess.Board, 
move: str) -> bool: 371 | try: 372 | board.parse_san(move) 373 | return True 374 | except: 375 | return False 376 | 377 | 378 | @jaxtyped(typechecker=beartype) 379 | def calculate_probe_outputs( 380 | probes: dict[int, Float[Tensor, "modes d_model rows cols options"]], cache 381 | ) -> dict[int, Float[Tensor, "modes batch num_white_moves rows cols options"]]: 382 | probe_outputs = {} 383 | for layer in probes: 384 | resid_post = cache["resid_post", layer][:, :] # shape is (batch, pos, d_model) 385 | linear_probe = probes[layer] 386 | probe_outputs[layer] = einsum( 387 | "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options", 388 | resid_post, 389 | linear_probe, 390 | ) 391 | return probe_outputs 392 | 393 | 394 | @jaxtyped(typechecker=beartype) 395 | def calculate_scale_coefficient( 396 | model_activations: Float[Tensor, "d_model"], 397 | flip_dir: Float[Tensor, "d_model"], 398 | probe: Float[Tensor, "d_model"], 399 | target: float, 400 | ) -> Tensor: 401 | """Find the scale coefficient that will result in the linear probe output being equal to the target value. Solves dot(model_activations - scale * flip_dir, probe) == target for scale.""" 402 | left_side = torch.dot(model_activations, probe) - target 403 | right_side = torch.dot(flip_dir, probe) 404 | scale = left_side / right_side 405 | return scale 406 | 407 | 408 | def get_average_outputs(output_tracker: dict) -> tuple[float, float]: 409 | sum_first_elements = sum(item[0].item() for item in output_tracker) 410 | sum_second_elements = sum(item[1].item() for item in output_tracker) 411 | average_first = sum_first_elements / len(output_tracker) 412 | average_second = sum_second_elements / len(output_tracker) 413 | return average_first, average_second 414 | 415 | 416 | def average_probe_empty_cell_value( 417 | state_stacks: torch.Tensor, 418 | probe_outputs: dict[int, torch.Tensor], 419 | piece_index: int, 420 | move_of_interest_index, 421 | sample_index: int, 422 | ) -> dict[int, float]: 423 | """ 424 | If piece_index is a queen: 425 | For an 
8 x 8 board with 1 queen, find the average probe output pre softmax for all cells that are not a queen. 426 | 427 | Returns: 428 | - dict[int, float]: Maps each layer to the average of the specified values in the tensor. The value is 0.0 if there are no relevant values. 429 | """ 430 | average_cell_values = {} 431 | for layer in probe_outputs: 432 | probe_output = probe_outputs[layer] 433 | target_val = chess_utils.ONE_HOT_TO_PIECE_MAPPING[piece_index] 434 | probe_state = probe_output[0, 0, move_of_interest_index, :, :, piece_index] 435 | value_mask = state_stacks[0, sample_index, move_of_interest_index, :, :] != target_val 436 | value_mask = value_mask.to(DEVICE) 437 | 438 | # Select the relevant values based on the mask 439 | relevant_values = torch.masked_select(probe_state, value_mask) 440 | 441 | # Compute the mean of relevant values if there are any, otherwise return 0 442 | if relevant_values.nelement() == 0: 443 | average_cell_values[layer] = 0.0 444 | else: 445 | average_cell_values[layer] = relevant_values.mean().item() 446 | return average_cell_values 447 | 448 | 449 | # This is a 250 line function, which I'm not thrilled about. However, every sequential step is only used once in this function. 450 | # I made an initial attempt to break it up into smaller functions, but I found that it made the code harder to follow. 451 | # I also have limited time to refactor this function, so I'm leaving it as is for now. 452 | # There is a lot going on here, but it's all necessary. 
453 | def perform_board_interventions( 454 | probe_names: dict[int, str], 455 | probe_data: train_test_chess.LinearProbeData, 456 | num_games: int, 457 | intervention_type: InterventionType, 458 | recording_name: str, 459 | piece_coefficient: float = 1.0, 460 | blank_coefficient: float = 1.0, 461 | track_outputs: bool = False, 462 | scales: list[float] = [0.1], 463 | ) -> float: 464 | probes, state_stacks_all_chars, white_move_indices = prepare_intervention_data( 465 | probe_names, probe_data, num_games 466 | ) 467 | # probes is a dict of [int: torch.Tensor] 468 | # probe is a tensor of shape (modes, d_model, rows, cols, options) 469 | # state_stacks_all_chars is a tensor of shape (modes, num_games, pgn_str_length, rows, cols) 470 | # white_move_indices is a tensor of shape (num_games, num_white_moves) 471 | scale_tracker = initialize_scale_tracker(scales) 472 | move_counters = MoveCounters() 473 | 474 | # Output tracker stores metadata and the original and modified probe outputs for the entire board per move per game for each layer 475 | # The results can be viewed as heatmaps per probe output in probe_output_data_exploration.ipynb 476 | # CAUTION: This can quickly grow to gigabytes of data 477 | if track_outputs: 478 | output_tracker = initialize_output_tracker(probes) 479 | 480 | average_piece_values = {} 481 | 482 | for sample_index in range(num_games): 483 | for scale in scales: 484 | print( 485 | f"Scale: {scale}, deterministic count: {scale_tracker[scale].mod_board_argmax_legal_total}, sampled count: {scale_tracker[scale].mod_board_sampled_legal_total}" 486 | ) 487 | 488 | for move_of_interest in range(START_POS, END_POS): 489 | print( 490 | f"Sample index: {sample_index}, total moves: {move_counters.total_moves}, possible moves: {move_counters.possible_moves}, legal intervention moves: {move_counters.mod_model_tracker.mod_board_argmax_legal_total}" 491 | ) 492 | move_counters.total_moves += 1 493 | 494 | # Step 1: Get the board state at move_of_interest 495 
| move_of_interest_index = white_move_indices[sample_index][move_of_interest] 496 | pgn_string = probe_data.board_seqs_string[sample_index][: move_of_interest_index + 1] 497 | orig_board = chess_utils.pgn_string_to_board(pgn_string) 498 | 499 | # Step 2: Get the model move at move_of_interest 500 | # model_input.shape is (1, move_of_interest_index + 1) 501 | encoded_input = chess_utils.encode_string(META, pgn_string) 502 | # model input shape: (1, pgn_str_length) 503 | model_input = torch.tensor(encoded_input).unsqueeze(0).to(DEVICE) 504 | argmax_model_move = chess_utils.get_model_move( 505 | probe_data.model, META, model_input, temperature=0.0 506 | ) 507 | 508 | # Step 3: Check if the model move is legal. parse_san will throw an exception if the move is illegal 509 | try: 510 | model_move_san = orig_board.parse_san(argmax_model_move) 511 | except: 512 | continue 513 | 514 | move_counters.orig_model_tracker.orig_board_argmax_legal_total += 1 515 | 516 | print(f"\nargmax_model_move: {argmax_model_move}\n") 517 | 518 | # Step 4: Determine which piece was moved from which source square 519 | moved_piece = orig_board.piece_at(model_move_san.from_square) 520 | if moved_piece is None: 521 | raise Exception("No piece found at source square") 522 | moved_piece_int = chess_utils.PIECE_TO_INT[moved_piece.piece_type] 523 | moved_piece_probe_index = chess_utils.PIECE_TO_ONE_HOT_MAPPING[moved_piece_int] 524 | r, c = chess_utils.square_to_coordinate(model_move_san.from_square) 525 | 526 | # If the piece is a king, we skip the intervention as a legal chess game must have a king. 527 | if moved_piece.piece_type == chess.KING: 528 | continue 529 | 530 | # Step 5: Make a modified board where source square is now empty. Verify that it has legal moves available 531 | modified_board = orig_board.copy() 532 | modified_board.set_piece_at(model_move_san.from_square, None) 533 | 534 | if not any(modified_board.legal_moves): 535 | print("No legal moves available for the modified board. 
Skipping...") 536 | continue 537 | 538 | move_counters.possible_moves += 1 539 | 540 | # Step 5.1: Sample n moves from the unmodified model 541 | # Track how many moves were legal on both the original and modified boards 542 | move_tracker = sample_moves_from_model( 543 | probe_data.model, model_input, orig_board, modified_board 544 | ) 545 | move_counters.orig_model_tracker.update_with(move_tracker) 546 | 547 | # If we are targeting probe output values, collect the average probe output values. 548 | if intervention_type == InterventionType.AVERAGE_TARGET: 549 | _, cache = probe_data.model.run_with_cache(model_input) 550 | probe_outputs = calculate_probe_outputs(probes, cache) 551 | average_piece_values = average_probe_empty_cell_value( 552 | state_stacks_all_chars, 553 | probe_outputs, 554 | moved_piece_probe_index, 555 | move_of_interest_index, 556 | sample_index, 557 | ) 558 | 559 | # Initialize some legal move trackers. Note that these get reset every move. At the end of the move, 560 | # we find the maximum value of these trackers to get the maximum possible legal moves for each move 561 | per_move_scale_tracker = initialize_scale_tracker(scales) 562 | for scale in scales: 563 | print(f"Scale: {scale}") 564 | 565 | if track_outputs: 566 | output_tracker = update_output_tracker_grids( 567 | probes, 568 | probe_data, 569 | model_input, 570 | state_stacks_all_chars, 571 | output_tracker, 572 | move_of_interest_index, 573 | sample_index, 574 | r, 575 | c, 576 | moved_piece_probe_index, 577 | moved_piece_int, 578 | argmax_model_move, 579 | ModelType.ORIGINAL, 580 | ) 581 | 582 | # This is the intervention function. 
In it, we obtain a vector to flip the square to blank in the model's activations at a given layer 583 | # Multiply it by some scale factor, then subtract it from the model's activations 584 | # If we make this function more modular and pass all variables in (probes, r, c, etc), it is much slower 585 | def flip_hook( 586 | resid, # shape is (1, num_white_moves, d_model) 587 | hook, 588 | layer: int, 589 | scale: float = 0.1, 590 | ): 591 | target = 0.0 592 | blank_probe = probes[layer][:, :, r, c, BLANK_INDEX].squeeze() 593 | piece_probe = probes[layer][:, :, r, c, moved_piece_probe_index].squeeze() 594 | 595 | flip_dir = (piece_probe * piece_coefficient) - (blank_probe * blank_coefficient) 596 | flip_dir = flip_dir / flip_dir.norm() 597 | 598 | if ( 599 | intervention_type == InterventionType.AVERAGE_TARGET 600 | or intervention_type == InterventionType.SINGLE_TARGET 601 | ): 602 | if intervention_type == InterventionType.AVERAGE_TARGET: 603 | target = average_piece_values[layer] + scale 604 | else: 605 | target = scale 606 | scale = calculate_scale_coefficient( 607 | resid[0, move_of_interest_index, :], 608 | flip_dir, 609 | piece_probe, 610 | float(target), 611 | ) 612 | # scale = min(0.3, scale) 613 | # print(target, scale) 614 | 615 | resid[0, :] -= scale * flip_dir 616 | 617 | # For experimentation with dynamic scale setting 618 | # coeff = resid[0, move_of_interest_index] @ flip_dir / flip_dir.norm() 619 | 620 | # So we only print once during inference 621 | # if resid.shape[1] <= move_of_interest_index + 1: 622 | # print( 623 | # f"Layer: {layer}, coeff: {coeff:10.3f}, scale: {scale:10.3f}, target: {target:10.3f}" 624 | # ) 625 | 626 | # Step 6: Intervene on the model's activations and get the model move under the modified board state 627 | probe_data.model.reset_hooks() 628 | for layer in probes: 629 | temp_hook_fn = partial(flip_hook, layer=layer, scale=scale) 630 | hook_name = f"blocks.{layer}.hook_resid_post" 631 | probe_data.model.add_hook(hook_name, 
temp_hook_fn) 632 | 633 | modified_board_argmax_model_move = chess_utils.get_model_move( 634 | probe_data.model, META, model_input, temperature=0.0 635 | ) 636 | 637 | print(f"\nModified board argmax model move: {modified_board_argmax_model_move}\n") 638 | 639 | # Step 6.1: Sample n moves from the modified model 640 | # Track how many moves were legal on the modified board 641 | # Note that we are tracking this for each scale 642 | move_tracker = sample_moves_from_model( 643 | probe_data.model, model_input, orig_board, modified_board 644 | ) 645 | per_move_scale_tracker[scale].update_with(move_tracker) 646 | 647 | # Step 6.2: If we are tracking outputs, update the output tracker with the modified outputs 648 | if track_outputs: 649 | output_tracker = update_output_tracker_grids( 650 | probes, 651 | probe_data, 652 | model_input, 653 | state_stacks_all_chars, 654 | output_tracker, 655 | move_of_interest_index, 656 | sample_index, 657 | r, 658 | c, 659 | moved_piece_probe_index, 660 | moved_piece_int, 661 | argmax_model_move, 662 | ModelType.MODIFIED, 663 | ) 664 | 665 | probe_data.model.reset_hooks() 666 | 667 | if check_if_legal_move(modified_board, modified_board_argmax_model_move): 668 | # Step 8: The move is legal. 
Update the legal move trackers 669 | if track_outputs: 670 | for layer in output_tracker: 671 | output_tracker[layer]["successes"][-1] = True 672 | per_move_scale_tracker[scale].mod_board_argmax_legal_total += 1 673 | 674 | if check_if_legal_move(orig_board, modified_board_argmax_model_move): 675 | per_move_scale_tracker[scale].orig_board_argmax_legal_total += 1 676 | 677 | scale_tracker[scale].update_with(per_move_scale_tracker[scale]) 678 | 679 | # Update move_counters with best result per move at end of turn 680 | move_counters = update_move_counters_best_per_move( 681 | move_counters, per_move_scale_tracker 682 | ) 683 | if move_counters.possible_moves > MAX_GAMES: 684 | break 685 | 686 | # After intervening on all moves in all games, save output_tracker and move_counters to disk 687 | if track_outputs: 688 | file_path = "output_tracker.pkl" 689 | with open(file_path, "wb") as file: 690 | pickle.dump(output_tracker, file) 691 | print(f"File saved to {file_path}") 692 | print( 693 | f"Sample index: {sample_index}, total moves: {move_counters.total_moves}, possible moves: {move_counters.possible_moves}, legal intervention moves: {move_counters.mod_model_tracker.mod_board_argmax_legal_total}" 694 | ) 695 | for scale in scales: 696 | print( 697 | f"Scale: {scale}, deterministic count: {scale_tracker[scale].mod_board_argmax_legal_total}, sampled count: {scale_tracker[scale].mod_board_sampled_legal_total}" 698 | ) 699 | recording_name = RECORDING_DIR + "/" + recording_name + ".json" 700 | with open(recording_name, "w") as file: 701 | records = create_recording_data(move_counters, scale_tracker) 702 | file.write(json.dumps(records)) 703 | 704 | return ( 705 | move_counters.mod_model_tracker.mod_board_argmax_legal_total / move_counters.possible_moves 706 | ) 707 | 708 | 709 | if __name__ == "__main__": 710 | 711 | scales_lookup: dict[InterventionType, list[float]] = { 712 | InterventionType.SINGLE_SCALE: [1.5], 713 | InterventionType.AVERAGE_TARGET: [9.0], 714 | 
InterventionType.SINGLE_TARGET: [-9], 715 | } 716 | 717 | intervention_types = [ 718 | InterventionType.SINGLE_SCALE, 719 | ] 720 | 721 | num_games = 200 722 | 723 | for intervention_type in intervention_types: 724 | 725 | probe_names = {} 726 | first_layer = 5 727 | last_layer = 5 728 | 729 | for i in range(first_layer, last_layer + 1, 1): 730 | probe_names[i] = ( 731 | f"tf_lens_lichess_{GPT_LAYER_COUNT}layers_ckpt_no_optimizer_chess_piece_probe_layer_{i}.pth" 732 | ) 733 | probe_data = get_probe_data(probe_names[first_layer], num_games) 734 | 735 | piece_coe = 1.0 736 | blank_coe = 0.0 737 | 738 | scales = scales_lookup[intervention_type] 739 | 740 | recording_name = f"n_layers={GPT_LAYER_COUNT}_intervention_type={intervention_type.value}_first_layer={first_layer}_last_layer={last_layer}_p={piece_coe}_b={blank_coe}_scales=" 741 | for scale in scales: 742 | recording_name += f"{str(scale).replace('.', '')[:5]}_" 743 | 744 | print(f"Recording name: {recording_name}") 745 | 746 | perform_board_interventions( 747 | probe_names, 748 | probe_data, 749 | num_games, 750 | intervention_type, 751 | recording_name, 752 | track_outputs=False, 753 | scales=scales, 754 | ) 755 | 756 | # For profiling, most cumulative time appears to be in forward pass in chess_utils.get_model_move() 757 | # def run_profile(): 758 | # pr = cProfile.Profile() 759 | # pr.enable() 760 | 761 | # perform_board_interventions( 762 | # probe_names, 763 | # probe_data, 764 | # 1, 765 | # intervention_type, 766 | # recording_name, 767 | # track_outputs=False, 768 | # scales=scales, 769 | # ) 770 | 771 | # pr.disable() 772 | # s = io.StringIO() 773 | # ps = pstats.Stats(pr, stream=s).sort_stats("cumulative") 774 | # ps.print_stats() 775 | # print(s.getvalue()) 776 | 777 | 778 | # run_profile() 779 | -------------------------------------------------------------------------------- /caa.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 
| from tqdm import tqdm 4 | import logging 5 | import itertools 6 | from transformer_lens import HookedTransformer 7 | from functools import partial 8 | 9 | import train_test_chess 10 | from train_test_chess import LinearProbeData 11 | import chess_utils 12 | from chess_utils import Config 13 | 14 | torch.set_grad_enabled(False) 15 | 16 | # log_level = logging.DEBUG 17 | log_level = logging.INFO 18 | # log_level = logging.WARNING 19 | 20 | # Configure logging 21 | logging.basicConfig(level=log_level) 22 | logger = logging.getLogger(__name__) 23 | 24 | BATCH_SIZE = 1 25 | MAXIMUM_TRAINING_GAMES = 2000 26 | 27 | 28 | def check_tensor_values(tensor, tensor_name="Tensor"): 29 | """Check if a tensor contains NaN, inf, or -inf values because we are summing 30k+ activations together.""" 30 | # isneginf is currently not implemented for mps tensors 31 | original_device_type = tensor.device.type 32 | if original_device_type == "mps": 33 | tensor = tensor.cpu() 34 | 35 | if torch.any(torch.isinf(tensor)): 36 | raise ValueError(f"Overflow detected: {tensor_name} contains inf") 37 | if torch.any(torch.isneginf(tensor)): 38 | raise ValueError(f"Overflow detected: {tensor_name} contains -inf") 39 | if torch.any(torch.isnan(tensor)): 40 | raise ValueError(f"Invalid value detected: {tensor_name} contains NaN") 41 | 42 | if original_device_type == "mps": 43 | tensor = tensor.to("mps") 44 | 45 | 46 | def add_hook_interventions( 47 | model: HookedTransformer, previous_activations: dict[int, torch.Tensor], scale: float = 0.25 48 | ) -> HookedTransformer: 49 | """Add hooks to the model to intervene in the forward pass.""" 50 | 51 | model.reset_hooks() 52 | 53 | def flip_hook(resid, hook, flip_dir: torch.Tensor): 54 | resid[:, :] += scale * flip_dir 55 | 56 | for layer, activation in previous_activations.items(): 57 | temp_hook_fn = partial(flip_hook, flip_dir=activation) 58 | hook_name = f"blocks.{layer}.hook_resid_post" 59 | model.add_hook(hook_name, temp_hook_fn) 60 | 61 | return 
model 62 | 63 | 64 | @torch.no_grad() 65 | def create_contrastive_activations( 66 | activation_name: str, 67 | probe_data: LinearProbeData, 68 | config: Config, 69 | logging_dict: dict, 70 | layer: int, 71 | max_games: int, 72 | ) -> torch.Tensor: 73 | """Creates a contrastive activation for a given layer and saves it to disk. 74 | We could do this for all layers at once for simple CAA, but it breaks the abstraction I was using for cascading CAA. 75 | """ 76 | assert logging_dict["split"] == "train", "Don't train on the test set" 77 | 78 | num_games = (max_games // BATCH_SIZE) * BATCH_SIZE 79 | 80 | if num_games < len(probe_data.board_seqs_int): 81 | raise ValueError( 82 | f"Number of games ({num_games}) is less than the number of games in the dataset ({len(probe_data.board_seqs_int)})" 83 | ) 84 | 85 | current_iter = 0 86 | full_train_indices = torch.arange(0, num_games) 87 | sum_high_elo = torch.zeros((512), device=device) 88 | sum_low_elo = torch.zeros((512), device=device) 89 | count_high_elo = 0 90 | count_low_elo = 0 91 | for i in tqdm(range(0, num_games, BATCH_SIZE)): 92 | indices = full_train_indices[i : i + BATCH_SIZE] 93 | games_int = probe_data.board_seqs_int[indices] # shape (batch, pgn_str_length) 94 | games_dots = probe_data.custom_indices[indices] # shape (batch, num_white_moves) 95 | games_dots = games_dots[:, config.pos_start :] 96 | 97 | if config.probing_for_skill: 98 | games_skill = probe_data.skill_stack[indices] 99 | logger.debug(f"games_skill shape: {games_skill.shape}") 100 | else: 101 | raise Exception("CAA currently only supports skill vectors") 102 | 103 | _, cache = probe_data.model.run_with_cache(games_int.to(device)[:, :-1], return_type=None) 104 | resid_post = cache["resid_post", layer][:, :] # shape (batch, pgn_str_length - 1, d_model) 105 | 106 | indexed_resid_posts = [] 107 | 108 | for batch_idx in range(games_dots.size(0)): 109 | # Get the indices for the current batch 110 | dots_indices_for_batch = games_dots[batch_idx] 111 | 112 
| # Index the state_stack for the current batch 113 | indexed_resid_post = resid_post[batch_idx, dots_indices_for_batch] 114 | 115 | # Append the result to the list 116 | indexed_resid_posts.append(indexed_resid_post) 117 | 118 | resid_post = torch.stack(indexed_resid_posts) # shape (batch, num_white_moves, d_model) 119 | summed_resid_post = einops.reduce( 120 | resid_post, "batch indices model_dim -> batch model_dim", "sum" 121 | ) # shape (batch, d_model) 122 | 123 | for batch_idx in range(BATCH_SIZE): 124 | if games_skill[batch_idx] == config.levels_of_interest[1]: 125 | sum_high_elo += summed_resid_post[batch_idx] # shape (d_model) 126 | count_high_elo += 1 127 | elif games_skill[batch_idx] == config.levels_of_interest[0]: 128 | sum_low_elo += summed_resid_post[batch_idx] # shape (d_model) 129 | count_low_elo += 1 130 | else: 131 | raise Exception("Invalid skill level") 132 | 133 | logger.debug( 134 | f"count_high_elo: {count_high_elo}, count_low_elo: {count_low_elo}, games_skill: {games_skill}" 135 | ) 136 | 137 | if i % 100 == 0: 138 | logger.info( 139 | f"batch {i}, count_high_elo: {count_high_elo}, count_low_elo: {count_low_elo}" 140 | ) 141 | 142 | current_iter += BATCH_SIZE 143 | 144 | check_tensor_values(sum_high_elo, "sum_high_elo") 145 | check_tensor_values(sum_low_elo, "sum_low_elo") 146 | 147 | average_high_elo_activation = sum_high_elo / count_high_elo # shape (d_model) 148 | average_low_elo_activation = sum_low_elo / count_low_elo # shape (d_model) 149 | 150 | difference_vector = average_high_elo_activation - average_low_elo_activation 151 | 152 | logging_dict["average_high_elo_activation"] = average_high_elo_activation 153 | logging_dict["average_low_elo_activation"] = average_low_elo_activation 154 | logging_dict["difference_vector"] = difference_vector 155 | logging_dict["count_high_elo"] = count_high_elo 156 | logging_dict["count_low_elo"] = count_low_elo 157 | 158 | output_location = f"{CAA_DIR}{activation_name}.pt" 159 | 160 | 
logger.info(f"Saving activations to {output_location}") 161 | torch.save(logging_dict, output_location) 162 | 163 | return difference_vector 164 | 165 | 166 | MODEL_DIR = "models/" 167 | DATA_DIR = "data/" 168 | CAA_DIR = "contrastive_activations/" 169 | 170 | device = ( 171 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 172 | ) 173 | logger.info(f"Using device: {device}") 174 | 175 | 176 | if __name__ == "__main__": 177 | config = chess_utils.skill_config 178 | # Sweep over layers, levels of interest, pos_start, and dataset_prefix 179 | layers = range(5, 7, 1) 180 | levels_of_interest = [[0, 5]] 181 | pos_starts = [25] 182 | 183 | caa_type = "simple" 184 | # caa_type = "cascade" 185 | 186 | cascade_layers = "" 187 | 188 | if caa_type == "cascade": 189 | cascade_layers += "".join([f"{layer}_" for layer in layers]) 190 | 191 | previous_layer_activations = {} 192 | 193 | for ( 194 | layer, 195 | level, 196 | pos_start, 197 | ) in itertools.product(layers, levels_of_interest, pos_starts): 198 | dataset_prefix = "lichess_" 199 | layer = layer 200 | split = "train" 201 | n_layers = 8 202 | model_name = f"tf_lens_{dataset_prefix}{n_layers}layers_ckpt_no_optimizer" 203 | config.levels_of_interest = level 204 | input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv" 205 | config = chess_utils.set_config_min_max_vals_and_column_name( 206 | config, input_dataframe_file, dataset_prefix 207 | ) 208 | config.pos_start = pos_start 209 | 210 | probe_data = train_test_chess.construct_linear_probe_data( 211 | input_dataframe_file, 212 | dataset_prefix, 213 | n_layers, 214 | model_name, 215 | config, 216 | MAXIMUM_TRAINING_GAMES, 217 | device, 218 | ) 219 | 220 | levels_str = "".join([str(i) for i in level]) 221 | 222 | activation_name = ( 223 | f"type=caa_{caa_type}{cascade_layers}_model={n_layers}layers_layer={layer}_activations" 224 | ) 225 | 226 | logging_dict = train_test_chess.init_logging_dict( 227 | layer, 228 | 
config, 229 | split, 230 | dataset_prefix, 231 | model_name, 232 | n_layers, 233 | train_test_chess.TRAIN_PARAMS, 234 | ) 235 | 236 | if caa_type == "cascade": 237 | probe_data.model = add_hook_interventions( 238 | probe_data.model, previous_layer_activations, scale=0.15 239 | ) 240 | 241 | previous_layer_activations[layer] = create_contrastive_activations( 242 | activation_name, probe_data, config, logging_dict, layer, MAXIMUM_TRAINING_GAMES 243 | ) 244 | -------------------------------------------------------------------------------- /chess_utils.py: -------------------------------------------------------------------------------- 1 | import chess 2 | import pandas as pd 3 | import torch 4 | from torch.nn import functional as F 5 | from typing import Callable, Optional 6 | from dataclasses import dataclass 7 | from jaxtyping import Int, Float, jaxtyped 8 | from torch import Tensor 9 | from enum import Enum 10 | import othello_utils 11 | 12 | # Mapping of chess pieces to integers 13 | PIECE_TO_INT = { 14 | chess.PAWN: 1, 15 | chess.KNIGHT: 2, 16 | chess.BISHOP: 3, 17 | chess.ROOK: 4, 18 | chess.QUEEN: 5, 19 | chess.KING: 6, 20 | } 21 | 22 | INT_TO_PIECE = {value: key for key, value in PIECE_TO_INT.items()} 23 | PIECE_TO_ONE_HOT_MAPPING = { 24 | -6: 0, 25 | -5: 1, 26 | -4: 2, 27 | -3: 3, 28 | -2: 4, 29 | -1: 5, 30 | 0: 6, 31 | 1: 7, 32 | 2: 8, 33 | 3: 9, 34 | 4: 10, 35 | 5: 11, 36 | 6: 12, 37 | } 38 | BLANK_INDEX = PIECE_TO_ONE_HOT_MAPPING[0] 39 | ONE_HOT_TO_PIECE_MAPPING = {value: key for key, value in PIECE_TO_ONE_HOT_MAPPING.items()} 40 | 41 | 42 | def board_to_random_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 43 | """Given a chess board object, return a 8x8 torch.Tensor. 44 | Every square should be randomly assigned to 1, -1, or 0. 45 | This is to sanity check the linear probe. 
46 | In the 8x8 array, row 0 is A1-H1 (White), row 1 is A2-H2, etc.""" 47 | state_RR = torch.zeros((8, 8), dtype=torch.int) 48 | for i in range(64): 49 | state_RR[i // 8, i % 8] = torch.randint(-1, 2, (1,)) 50 | 51 | return state_RR 52 | 53 | 54 | def board_to_skill_state(board: chess.Board, skill: float) -> torch.Tensor: 55 | """Given a chess board object, return a 1x1 torch.Tensor. 56 | The 1x1 array should tell what skill level the player is.""" 57 | state_RR = torch.zeros((1, 1), dtype=torch.int) 58 | state_RR[0][0] = skill 59 | 60 | return state_RR 61 | 62 | 63 | # import chess.engine 64 | 65 | # stockfish_path = "/usr/games/stockfish" 66 | # engine = chess.engine.SimpleEngine.popen_uci(stockfish_path) 67 | 68 | 69 | def board_to_eval_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 70 | """Given a chess board object, return a 1x1 torch.Tensor. 71 | The 1x1 array should tell which player is winning. 72 | -1 = Black has > 100 centipawns advantage, 0 = Draw, 1 = White has > 100 centipawns advantage. 73 | This is horribly inefficient and takes ~0.75 seconds per game. However, I'm just doing exploratory analysis. 74 | If we wanted efficiency, we could use a bunch of parallel CPU workers to evaluate the board state and store it 75 | in a lookup table. But, then we couldn't cleanly use this with the existing abstractions. 76 | To use this function, uncomment the import chess.engine through engine = above, and the internal code below. 
77 | """ 78 | state_RR = torch.zeros((1, 1), dtype=torch.int) 79 | 80 | # info = engine.analyse(board, chess.engine.Limit(time=0.01)) 81 | # score = info["score"].white().score(mate_score=10000) 82 | 83 | # # Set the state based on the score (from White's perspective) 84 | # if score < -100: 85 | # state_RR[0][0] = -1 86 | # elif score > 100: 87 | # state_RR[0][0] = 1 88 | # else: 89 | # state_RR[0][0] = 0 90 | 91 | return state_RR 92 | 93 | 94 | def board_to_piece_color_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 95 | """Given a chess board object, return an 8x8 torch.Tensor. 96 | The 8x8 array should tell if each square is black, white, or blank. 97 | White is 1, black is -1, and blank is 0. 98 | In the 8x8 array, row 0 is A1-H1 (White), row 1 is A2-H2, etc.""" 99 | state_RR = torch.zeros((8, 8), dtype=torch.int) 100 | for i in range(64): 101 | piece = board.piece_at(i) 102 | if piece: 103 | # Assign 1 for white pieces and -1 for black pieces 104 | state_RR[i // 8, i % 8] = 1 if piece.color == chess.WHITE else -1 105 | 106 | return state_RR 107 | 108 | 109 | def board_to_piece_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 110 | """Given a chess board object, return an 8x8 torch.Tensor. 111 | The 8x8 array should tell what piece is on each square. A white pawn could be 1, a black pawn could be -1, etc. 112 | Blank squares should be 0.
113 | In the 8x8 array, row 0 is A1-H1 (White), row 1 is A2-H2, etc.""" 114 | 115 | # Because state_RR is initialized to all 0s, we only need to change the values of the pieces 116 | state_RR = torch.zeros((8, 8), dtype=torch.int) 117 | for i in range(64): 118 | piece = board.piece_at(i) 119 | if piece: 120 | piece_value = PIECE_TO_INT[piece.piece_type] 121 | # Multiply by -1 if the piece is black 122 | if piece.color == chess.BLACK: 123 | piece_value *= -1 124 | state_RR[i // 8, i % 8] = piece_value 125 | 126 | return state_RR 127 | 128 | 129 | def board_to_pin_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 130 | """Given a chess board object, return a 1x1 torch.Tensor. 131 | The 1x1 array indicates if there are any pins on the board (1 = yes, 0 = no).""" 132 | 133 | state_RR = torch.zeros((1, 1), dtype=torch.int) 134 | 135 | # NOTE: Due to the model's MINE / YOURS / BLANK ontology, we should check for White XOR Black pins 136 | for color in [chess.WHITE]: 137 | for i in range(64): 138 | piece = board.piece_at(i) 139 | if piece and piece.color == color: 140 | if board.is_pinned(color, i): 141 | state_RR[0, 0] = 1 142 | return state_RR 143 | 144 | return state_RR 145 | 146 | 147 | def board_to_threat_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 148 | """Given a chess board object, return an 8x8 torch.Tensor. 149 | The 8x8 array should tell if each square is being attacked by the opponent.""" 150 | 151 | ATTACKING_COLOR = chess.BLACK 152 | # Because state is initialized to all 0s, we only need to change the values of the pieces 153 | state_RR = torch.zeros((8, 8), dtype=torch.int) 154 | for i in range(64): 155 | if board.is_attacked_by(ATTACKING_COLOR, i): 156 | state_RR[i // 8, i % 8] = 1 157 | 158 | return state_RR 159 | 160 | 161 | def board_to_prev_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 162 | """Given a chess board object, return an 8x8 torch.Tensor. 
163 | The 8x8 array should tell what piece is on each square at a previous board state.""" 164 | 165 | PREVIOUS_TURNS = 25 166 | state_RR = torch.zeros((8, 8), dtype=torch.int) 167 | 168 | # If we cannot roll back PREVIOUS_TURNS, return a blank state 169 | # Predicting blank states is trivial, so be careful and change pos_start to not index into the blank states 170 | if len(board.move_stack) < PREVIOUS_TURNS: 171 | return state_RR 172 | 173 | new_board = board.copy() 174 | 175 | for _ in range(PREVIOUS_TURNS): 176 | new_board.pop() 177 | 178 | for i in range(64): 179 | piece = new_board.piece_at(i) 180 | if piece: 181 | piece_value = PIECE_TO_INT[piece.piece_type] 182 | # Multiply by -1 if the piece is black 183 | if piece.color == chess.BLACK: 184 | piece_value *= -1 185 | state_RR[i // 8, i % 8] = piece_value 186 | 187 | return state_RR 188 | 189 | 190 | def board_to_legal_moves_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 191 | """Return an 8x8 torch.Tensor indicating squares where White has legal moves. 192 | 193 | Each square in the array is 1 if White can legally move a piece to that square, otherwise 0. 194 | In the 8x8 array, row 0 corresponds to A1-H1 (from White's perspective), row 1 to A2-H2, etc. 195 | """ 196 | MOVING_COLOR = chess.WHITE 197 | # Initialize the state array with all zeros 198 | state_RR = torch.zeros((8, 8), dtype=torch.int) 199 | 200 | # Iterate through all legal moves for White 201 | for move in board.legal_moves: 202 | # Check if the move is for a White piece 203 | if board.color_at(move.from_square) == MOVING_COLOR: 204 | # Update the state_RR array for the destination square of the move 205 | to_square = move.to_square 206 | state_RR[to_square // 8, to_square % 8] = 1 207 | 208 | return state_RR 209 | 210 | 211 | def board_to_last_self_move_state(board: chess.Board, skill: Optional[int] = None) -> torch.Tensor: 212 | """Given a chess board object, return an 8x8 torch.Tensor. 
213 | All squares will be 0 except for the square where the last white move was made. 214 | In the 8x8 array, row 0 is A1-H1 (White), row 1 is A2-H2, etc. 215 | The purpose of this is to see if the linear probe can determine the next move of the GPT. 216 | To get next move instead of last move, we offset the state stack by 1 in linear_probe_forward_pass(): 217 | resid_post = resid_post[:, :-1, :] 218 | state_stack_one_hot = state_stack_one_hot[:, :, 1:, :, :, :] 219 | """ 220 | 221 | state_RR = torch.zeros((8, 8), dtype=torch.int) 222 | 223 | # If offset is 2, we are predicting the LLM's next move 224 | # If offset is 1, we are predicting the opponent's response to the LLM's next move 225 | offset = 2 226 | 227 | # If there is no last move (such as beginning of game), return the state as is 228 | if len(board.move_stack) < offset: 229 | return state_RR 230 | 231 | last_last_move = board.move_stack[-offset] 232 | destination_square = last_last_move.to_square 233 | moved_piece = board.piece_at(destination_square) 234 | if moved_piece is None: 235 | raise ValueError("Piece type is None") 236 | piece_value = PIECE_TO_INT[moved_piece.piece_type] 237 | if moved_piece.color == chess.BLACK: 238 | piece_value *= -1 239 | state_RR[destination_square // 8, destination_square % 8] = piece_value 240 | 241 | return state_RR 242 | 243 | 244 | def state_stack_to_chess_board(state_RR: torch.Tensor) -> chess.Board: 245 | """Given a state stack, return a chess.Board object. 246 | WARNING: The board will not include any information about whose turn it is, castling rights, en passant, etc. 
247 | For this reason, pgn_string_to_board is preferred.""" 248 | board = chess.Board(fen=None) 249 | for row_idx, row in enumerate(state_RR): 250 | for col_idx, piece in enumerate(row): 251 | if piece != 0: 252 | piece_type = abs(piece) 253 | color = chess.WHITE if piece > 0 else chess.BLACK 254 | board.set_piece_at(chess.square(col_idx, row_idx), chess.Piece(piece_type, color)) 255 | return board 256 | 257 | 258 | def pgn_string_to_board(pgn_string: str) -> chess.Board: 259 | """Convert a PGN string to a chess.Board object. 260 | We are making an assumption that the PGN string is in this format: 261 | ;1.e4 e5 2. or ;1.e4 e5 2.Nf3""" 262 | board = chess.Board() 263 | for move in pgn_string.split(): 264 | if "." in move: 265 | move = move.split(".")[1] 266 | if move == "": 267 | continue 268 | board.push_san(move) 269 | return board 270 | 271 | 272 | def create_state_stack( 273 | moves_string: str, 274 | custom_board_to_state_fn: Callable[[chess.Board], torch.Tensor], 275 | skill: Optional[torch.Tensor] = None, 276 | ) -> torch.Tensor: 277 | """Given a string of PGN format moves, create an 8x8 torch.Tensor for every character in the string.""" 278 | 279 | board = chess.Board() 280 | initial_states_lRR = [] 281 | count = 1 282 | 283 | # Scan 1: Creates states, with length = number of moves in the game 284 | initial_states_lRR.append(custom_board_to_state_fn(board, skill).to(dtype=torch.int8)) 285 | # Apply each move to the board 286 | for move in moves_string.split(): 287 | try: 288 | count += 1 289 | # Skip move numbers 290 | if "." 
in move: 291 | board.push_san(move.split(".")[1]) 292 | else: 293 | board.push_san(move) 294 | 295 | initial_states_lRR.append(custom_board_to_state_fn(board, skill).to(dtype=torch.int8)) 296 | except ValueError: 297 | # Because all games are truncated to len 680, the last move is often partial and invalid, 298 | # so we don't need to log this, as it will happen on most games 299 | break 300 | 301 | # if count % 100 == 0: 302 | # pretty_print_state_stack(state) 303 | # print("_" * 50) 304 | # print(board) 305 | 306 | # Scan 2: Expand states to match the length of moves_string 307 | # For ;1.e4 e5 2.Nf3, ";1.e4" = idx 0, " e5" = idx 1, " 2.Nf3" = idx 2 308 | expanded_states_lRR = [] 309 | move_index = 0 310 | for char in moves_string: 311 | if char == " ": 312 | move_index += 1 313 | expanded_states_lRR.append(initial_states_lRR[min(move_index, len(initial_states_lRR) - 1)]) 314 | 315 | # expanded_states.append(initial_states[-1]) # The last element in expanded_states is the final position of the board. 316 | # Currently not using this as len(expanded_states) would be 1 greater than len(moves_string) and that would be confusing. 317 | return torch.stack(expanded_states_lRR) 318 | 319 | 320 | def create_state_stacks( 321 | moves_strings: list[str], 322 | custom_board_to_state_fn: Callable[[chess.Board], torch.Tensor], 323 | skill_array: Optional[torch.Tensor] = None, 324 | ) -> Float[Tensor, "modes sample_size pgn_str_length rows cols"]: 325 | """Given a list of strings of PGN format moves, create a tensor of shape (1, len(moves_strings), pgn_str_length, rows, cols).
326 | custom_board_to_state is a function that takes a chess.Board object and returns a 8x8 torch.Tensor for 327 | board state, or 1x1 for centipawn advantage.""" 328 | state_stacks_BlRR = [] 329 | skill = None 330 | 331 | for idx, pgn_string in enumerate(moves_strings): 332 | if skill_array is not None: 333 | skill = skill_array[idx] 334 | state_stack_lRR = create_state_stack(pgn_string, custom_board_to_state_fn, skill) 335 | 336 | state_stacks_BlRR.append(state_stack_lRR) 337 | 338 | # Convert the list of tensors to a single tensor 339 | final_state_stack_BlRR = torch.stack(state_stacks_BlRR) 340 | final_state_stack_MBlRR = final_state_stack_BlRR.unsqueeze(0) # Add a dimension for the modes 341 | # Currently, there is just one mode and it isn't necessary. For now, I'm maintaining the dimension for future use. 342 | return final_state_stack_MBlRR 343 | 344 | 345 | def state_stack_to_one_hot( 346 | num_modes: int, 347 | num_rows: int, 348 | num_cols: int, 349 | min_val: int, 350 | max_val: int, 351 | device: torch.device, 352 | state_stack_MBLRR: torch.Tensor, 353 | user_mapping: Optional[dict[int, int]] = None, 354 | ) -> Int[Tensor, "modes sample_size num_white_moves rows cols one_hot_range"]: 355 | """Input shape: assert(state_stacks_all_chars.shape) == (modes, sample_size, game_length, rows, cols) 356 | Output shape: assert(state_stacks_one_hot.shape) == (modes, sample_size, game_length, rows, cols, one_hot_range) 357 | """ 358 | range_size = max_val - min_val + 1 359 | 360 | mapping = {} 361 | if user_mapping: 362 | mapping = user_mapping 363 | min_val = min(mapping.values()) 364 | max_val = max(mapping.values()) 365 | range_size = max_val - min_val + 1 366 | else: 367 | for val in range(min_val, max_val + 1): 368 | mapping[val] = val - min_val 369 | 370 | # Initialize the one-hot tensor 371 | one_hot_MBLRRC = torch.zeros( 372 | state_stack_MBLRR.shape[0], # num modes 373 | state_stack_MBLRR.shape[1], # num games 374 | state_stack_MBLRR.shape[2], # num moves 
375 | num_rows, 376 | num_cols, 377 | range_size, 378 | device=device, 379 | dtype=torch.int8, 380 | ) 381 | 382 | for val in mapping: 383 | one_hot_MBLRRC[..., mapping[val]] = state_stack_MBLRR == val 384 | 385 | return one_hot_MBLRRC 386 | 387 | 388 | def one_hot_to_state_stack(one_hot_MBLRRC: torch.Tensor, min_val: int) -> torch.Tensor: 389 | """Takes the argmax over the last (one-hot) dimension and adds min_val. We assume input shape MBLRRC, but it could work with other shapes.""" 390 | indices = torch.argmax(one_hot_MBLRRC, dim=-1) 391 | state_stack_MBLRR = indices + min_val 392 | return state_stack_MBLRR 393 | 394 | 395 | def square_to_coordinate(square: chess.Square) -> tuple[int, int]: 396 | row = chess.square_rank(square) 397 | column = chess.square_file(square) 398 | return (row, column) 399 | 400 | 401 | def find_dots_indices(moves_string: str) -> list[int]: 402 | """Returns a list of ints of indices of every '.' in the string. 403 | This will hopefully provide a reasonable starting point for training a linear probe. 404 | """ 405 | indices = [index for index, char in enumerate(moves_string) if char == "."] 406 | return indices 407 | 408 | 409 | def find_spaces_indices(moves_string: str) -> list[int]: 410 | """Returns a list of ints of indices of every ' ' in the string.""" 411 | indices = [index for index, char in enumerate(moves_string) if char == " "] 412 | return indices 413 | 414 | 415 | def get_all_white_pos_indices(moves_string: str) -> list[list[int]]: 416 | """From this pgn string: ;1.e4 c5 2.Nf3 d6 3.d4 cxd4 4.Qxd4 a6 5.Bc4 Nc6 6.Qd1...
417 | Return a list of lists of indices that correspond to the chars in parentheses: 418 | (;1.e4)< c5>( 2.Nf3)< d6>( 3.d4)< cxd4>( 4.Qxd4)< a6>( 5.Bc4)< Nc6>( 6.Qd1)""" 419 | space_indices = find_spaces_indices(moves_string) 420 | white_move_indices: list[list[int]] = [] 421 | start_index = 0 422 | 423 | if len(space_indices) == 0: 424 | return [list(range(0, len(moves_string)))] 425 | 426 | for i, space in enumerate(space_indices): 427 | if i % 2 == 1: 428 | start_index = space 429 | if i == len(space_indices) - 1: 430 | white_move_indices.append(list(range(start_index, len(moves_string)))) 431 | break 432 | continue 433 | white_move_indices.append(list(range(start_index, space))) 434 | return white_move_indices 435 | 436 | 437 | def get_all_black_pos_indices(moves_string: str) -> list[list[int]]: 438 | """From this pgn string: ;1.e4 c5 2.Nf3 d6 3.d4 cxd4 4.Qxd4 a6 5.Bc4 Nc6 6.Qd1... 439 | Return a list of lists of indices that correspond to the chars in brackets: 440 | (;1.e4)< c5>( 2.Nf3)< d6>( 3.d4)< cxd4>( 4.Qxd4)< a6>( 5.Bc4)< Nc6>( 6.Qd1)""" 441 | space_indices = find_spaces_indices(moves_string) 442 | black_move_indices: list[list[int]] = [] 443 | 444 | if len(space_indices) == 0: 445 | return [] 446 | 447 | start_index = space_indices[0] 448 | 449 | for i, space in enumerate(space_indices): 450 | if i % 2 == 0: 451 | start_index = space 452 | if i == len(space_indices) - 1: 453 | black_move_indices.append(list(range(start_index, len(moves_string)))) 454 | break 455 | continue 456 | black_move_indices.append(list(range(start_index, space))) 457 | return black_move_indices 458 | 459 | 460 | def find_odd_spaces_indices(moves_string: str) -> list[int]: 461 | """Returns a list of ints of odd indices of every ' ' in the string. 
462 | There is some duplicated logic but it simplifies using the Callable function.""" 463 | indices = [index for index, char in enumerate(moves_string) if char == " "] 464 | # Select only the odd indices: start from index 1, go till the end, step by 2 465 | odd_indices = indices[1::2] 466 | return odd_indices 467 | 468 | 469 | def find_even_spaces_indices(moves_string: str) -> list[int]: 470 | """Returns a list of ints of indices of every other ' ' in the string, starting from the first space (positions 0, 2, 4, ... in the list of spaces). 471 | There is some duplicated logic but it simplifies using the Callable function.""" 472 | indices = [index for index, char in enumerate(moves_string) if char == " "] 473 | # Select only the even indices: start from index 0, go till the end, step by 2 474 | even_indices = indices[::2] 475 | return even_indices 476 | 477 | 478 | def find_dots_indices_offset_one(moves_string: str) -> list[int]: 479 | """Returns a list of ints of indices of every '.' in the string, each incremented by one. 480 | If the incremented index would fall outside the string, it is not included. 481 | """ 482 | indices = [index for index, char in enumerate(moves_string) if char == "."] 483 | 484 | incremented_indices = [index + 1 for index in indices if index + 1 < len(moves_string)] 485 | 486 | return incremented_indices 487 | 488 | 489 | def find_even_indices_offset_one(moves_string: str) -> list[int]: 490 | """ 491 | Returns a list of ints of even indices of every ' ' in the string, each incremented by one. 492 | If the incremented index would fall outside the string, it is not included.
493 | """ 494 | indices = [index for index, char in enumerate(moves_string) if char == " "] 495 | even_indices = indices[::2] 496 | 497 | # Increment each even index by one, ensuring it doesn't exceed the string length 498 | incremented_indices = [index + 1 for index in even_indices if index + 1 < len(moves_string)] 499 | 500 | return incremented_indices 501 | 502 | 503 | def find_odd_indices_offset_one(moves_string: str) -> list[int]: 504 | """ 505 | Returns a list of ints of odd indices of every ' ' in the string, each incremented by one. 506 | If the incremented index would be greater than the length of the string, it is not included. 507 | """ 508 | indices = [index for index, char in enumerate(moves_string) if char == " "] 509 | odd_indices = indices[1::2] 510 | 511 | # Increment each odd index by one, ensuring it doesn't exceed the string length 512 | incremented_indices = [index + 1 for index in odd_indices if index + 1 < len(moves_string)] 513 | 514 | return incremented_indices 515 | 516 | 517 | def find_custom_indices(custom_indexing_fn: Callable, games_strs_Bl: list) -> torch.Tensor: 518 | 519 | shortest_length = 1e6 520 | custom_indices = [] 521 | for pgn in games_strs_Bl: 522 | indices = custom_indexing_fn(pgn) 523 | shortest_length = min(shortest_length, len(indices)) 524 | custom_indices.append(indices) 525 | print("Shortest length:", shortest_length) 526 | 527 | for i, indices in enumerate(custom_indices): 528 | custom_indices[i] = indices[:shortest_length] 529 | 530 | indices = torch.tensor(custom_indices, dtype=torch.int) 531 | 532 | return indices 533 | 534 | 535 | def encode_string(meta: dict, s: str) -> list[int]: 536 | """Encode a string into a list of integers.""" 537 | stoi = meta["stoi"] 538 | return [stoi[c] for c in s] 539 | 540 | 541 | def decode_list(meta: dict, l: list[int]) -> str: 542 | """Decode a list of integers into a string.""" 543 | itos = meta["itos"] 544 | return "".join([itos[i] for i in l]) 545 | 546 | 547 | # Adapted from 
nanogpt 548 | def get_model_move( 549 | model, 550 | meta: dict, 551 | idx: torch.Tensor, 552 | max_new_tokens: int = 7, 553 | temperature=1.0, 554 | block_size=1023, 555 | ): 556 | """Generate new tokens from a trained language model. If temperature is 0.0, greedy decoding is used. 557 | Otherwise, standard temperature based sampling is used.""" 558 | 559 | if temperature < 0: 560 | raise ValueError("temperature has to be non-negative") 561 | 562 | input_length = len(idx[0]) 563 | space_idx = encode_string(meta, " ")[0] 564 | with torch.inference_mode(): 565 | for _ in range(max_new_tokens): 566 | # if the sequence context is growing too long we must crop it at block_size 567 | idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:] 568 | if temperature == 0.0: 569 | # greedy decoding 570 | # model(idx_cond) is a tensor of shape (batch_size, sequence_length, vocab_size) 571 | # logits is a tensor of shape (batch_size, vocab_size) 572 | # idx_next is a tensor of shape (batch_size, 1) 573 | logits = model(idx_cond)[:, -1, :] 574 | idx_next = torch.argmax(logits, dim=-1).unsqueeze(-1) 575 | else: 576 | # forward the model to get the logits for the index in the sequence 577 | logits = model(idx_cond) 578 | # pluck the logits at the final step and scale by desired temperature 579 | logits = logits[:, -1, :] / temperature 580 | # apply softmax to convert logits to (normalized) probabilities 581 | probs = F.softmax(logits, dim=-1) 582 | # sample from the distribution 583 | idx_next = torch.multinomial(probs, num_samples=1) 584 | if idx_next[0] == space_idx: 585 | break 586 | # append sampled index to the running sequence and continue 587 | idx = torch.cat((idx, idx_next), dim=1) 588 | 589 | model_response = idx[:, input_length:] 590 | model_move = decode_list(meta, model_response[0].tolist()) 591 | return model_move 592 | 593 | 594 | class PlayerColor(Enum): 595 | WHITE = "White" 596 | BLACK = "Black" 597 | 598 | 599 | @dataclass 600 | class Config: 601 | 
min_val: int 602 | max_val: int 603 | custom_board_state_function: callable 604 | linear_probe_name: str 605 | custom_indexing_function: callable = find_dots_indices 606 | num_rows: int = 8 607 | num_cols: int = 8 608 | levels_of_interest: Optional[list[int]] = None 609 | column_name: str = None 610 | probing_for_skill: bool = False 611 | # pos_start indexes into custom_indexing_function. Example: if pos_start = 25, for find_dots_indices, selects everything after the first 25 moves 612 | pos_start: int = 0 613 | # If pos_end is None, it's set to the length of the shortest game in construct_linear_probe_data() 614 | pos_end: Optional[int] = None 615 | player_color: PlayerColor = PlayerColor.WHITE 616 | othello: bool = False 617 | 618 | 619 | piece_config = Config( 620 | min_val=-6, 621 | max_val=6, 622 | custom_board_state_function=board_to_piece_state, 623 | linear_probe_name="chess_piece_probe", 624 | ) 625 | 626 | pin_config = Config( 627 | min_val=0, 628 | max_val=1, 629 | custom_board_state_function=board_to_pin_state, 630 | num_rows=1, 631 | num_cols=1, 632 | linear_probe_name="chess_pin_probe", 633 | ) 634 | 635 | color_config = Config( 636 | min_val=-1, 637 | max_val=1, 638 | custom_board_state_function=board_to_piece_color_state, 639 | linear_probe_name="chess_color_probe", 640 | ) 641 | 642 | threat_config = Config( 643 | min_val=0, 644 | max_val=1, 645 | custom_board_state_function=board_to_threat_state, 646 | linear_probe_name="chess_threat_probe", 647 | ) 648 | 649 | legal_move_config = Config( 650 | min_val=0, 651 | max_val=1, 652 | custom_board_state_function=board_to_legal_moves_state, 653 | linear_probe_name="chess_legal_move_probe", 654 | ) 655 | 656 | prev_move_config = Config( 657 | min_val=-6, 658 | max_val=6, 659 | custom_board_state_function=board_to_prev_state, 660 | linear_probe_name="chess_prev_move_probe", 661 | pos_start=15, 662 | pos_end=16, 663 | ) 664 | 665 | random_config = Config( 666 | min_val=-1, 667 | max_val=1, 668 | 
custom_board_state_function=board_to_random_state, 669 | linear_probe_name="chess_random_probe", 670 | ) 671 | 672 | eval_config = Config( 673 | min_val=-1, 674 | max_val=1, 675 | custom_board_state_function=board_to_eval_state, 676 | linear_probe_name="chess_eval_probe", 677 | num_rows=1, 678 | num_cols=1, 679 | ) 680 | 681 | skill_config = Config( 682 | min_val=-2, 683 | max_val=20, 684 | custom_board_state_function=board_to_skill_state, 685 | linear_probe_name="chess_skill_probe", 686 | num_rows=1, 687 | num_cols=1, 688 | levels_of_interest=[0, 5], 689 | probing_for_skill=True, 690 | pos_start=25, 691 | ) 692 | 693 | othello_config = Config( 694 | min_val=-1, 695 | max_val=1, 696 | custom_board_state_function=othello_utils.games_batch_to_state_stack_mine_yours_BLRRC, 697 | linear_probe_name="othello_mine_yours_probe", 698 | othello=True, 699 | ) 700 | 701 | othello_valid_moves_config = Config( 702 | min_val=0, 703 | max_val=1, 704 | custom_board_state_function=othello_utils.games_batch_to_valid_moves_BLRRC, 705 | linear_probe_name="othello_valid_moves_probe", 706 | othello=True, 707 | ) 708 | 709 | 710 | def find_config_by_name(config_name: str) -> Config: 711 | """ 712 | Finds and returns the Config instance with a matching linear_probe_name. 713 | """ 714 | all_configs = [piece_config, color_config, random_config, skill_config, othello_config] 715 | for config in all_configs: 716 | if config.linear_probe_name == config_name: 717 | return config 718 | raise ValueError(f"Config with name {config_name} not found") 719 | 720 | 721 | def update_config_using_player_color( 722 | player_color: PlayerColor, config: Config, custom_function: Optional[Callable] = None 723 | ) -> Config: 724 | """Player color will determine which indexing function we use. In addition, we set player to white by default. 
725 | If player is black, then we update the probe name as well.""" 726 | 727 | if custom_function: 728 | config.custom_indexing_function = custom_function 729 | config.player_color = player_color 730 | return config 731 | 732 | if player_color == PlayerColor.WHITE: 733 | config.custom_indexing_function = find_dots_indices 734 | config.player_color = player_color 735 | 736 | if player_color == PlayerColor.BLACK: 737 | config.linear_probe_name = config.linear_probe_name.replace("probe", "black_player_probe") 738 | config.custom_indexing_function = find_even_spaces_indices 739 | config.player_color = player_color 740 | 741 | return config 742 | 743 | 744 | def set_config_min_max_vals_and_column_name( 745 | config: Config, 746 | input_dataframe_file: str, 747 | dataset_prefix: str, 748 | ) -> Config: 749 | if config.levels_of_interest is not None or config.probing_for_skill: 750 | if dataset_prefix == "stockfish_": 751 | config.column_name = "player_two" 752 | elif "lichess_" in dataset_prefix: 753 | config.column_name = "WhiteEloBinIndex" 754 | else: 755 | return config 756 | df = pd.read_csv(input_dataframe_file) 757 | config.min_val = df[config.column_name].min() 758 | config.max_val = df[config.column_name].max() 759 | 760 | return config 761 | -------------------------------------------------------------------------------- /contrastive_activations/lichess_train_layer_12_pos_start_25_activations.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/contrastive_activations/lichess_train_layer_12_pos_start_25_activations.pt -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | 4 | # Except this file 5 | !.gitignore 6 | 
-------------------------------------------------------------------------------- /images/pawn_probe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/images/pawn_probe.png -------------------------------------------------------------------------------- /images/probe_acc_markers_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/images/probe_acc_markers_graph.png -------------------------------------------------------------------------------- /lichess_data_filtering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from datasets import load_dataset\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "Our data begins as a bunch of PGN transcripts. However, to work in tensors we need all transcripts to be the same length. So, this file takes our PGNs and performs some filtering.\n", 19 | "\n", 20 | "This notebook has a very similar counterpart, `utils\\chess_gpt_eval_data_filtering.ipynb`. The lichess and chess_gpt_eval datasets have a different structure and different column names. 
For most people's needs, the lichess dataset alone should suffice, so I made two separate notebooks to keep this one simple.\n", 21 | "\n", 22 | "The output of this file is 4 different CSVs:\n", 23 | "\n", 24 | "`lichess_100mb.csv`: 100 MB of lichess PGN games, with every game also containing player Elo information.\n", 25 | "\n", 26 | "`lichess_100mb_filtered.csv`: We perform some filtering for game length, add a player Elo bucket, and do some manipulation of the PGN string.\n", 27 | "\n", 28 | "`lichess_train.csv` and `lichess_test.csv`: a 50 / 50 train / test split of `lichess_100mb_filtered.csv`, used for training and testing linear probes." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "DATA_DIR = \"data/\"\n", 45 | "prefix = \"lichess_\"\n", 46 | "\n", 47 | "\n", 48 | "input_file = f'{DATA_DIR}{prefix}100mb.csv'\n", 49 | "output_file = input_file.replace(\".csv\", \"_filtered.csv\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "First, we download the dataset if not present." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "if not os.path.exists(input_file):\n", 66 | " dataset_path = \"adamkarvonen/chess_games\"\n", 67 | " file_path = f\"{prefix}100mb.zip\"\n", 68 | " # No idea why streaming=True is required to avoid an error here.
Huggingface ¯\\_(ツ)_/¯\n", 69 | " dataset = load_dataset(dataset_path, data_files=file_path, streaming=True)\n", 70 | " df = pd.DataFrame(dataset['train'])\n", 71 | " df.to_csv(input_file, index=False)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "Our LLMs need a delimiter token \";\" at the beginning of every PGN string or they won't work as well." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "df = pd.read_csv(input_file)\n", 88 | "\n", 89 | "def format_transcript(game: str) -> str:\n", 90 | " new_game = ';' + game\n", 91 | " return new_game\n", 92 | "\n", 93 | "df['transcript'] = df['transcript'].apply(format_transcript)\n", 94 | "\n", 95 | "for game in df.head()['transcript']:\n", 96 | " print(game)\n", 97 | " print()" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "Filter all games to a length of 365 characters. This means we discard anything shorter than that and truncate the rest. I chose 365 because that's the 50% value (the median) from df.describe(). I also count the number of moves (with x.split()) and discard anything below the 25th percentile. This makes it easier if I want to do any move-based indexing." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "len_df = df['transcript'].apply(lambda x: len(x))\n", 114 | "print(len_df.describe())\n", 115 | "\n", 116 | "game_length_in_chars = 365\n", 117 | "\n", 118 | "# Data setup. All games must have the same length. 50% are >= 365 characters.
I will discard all games shorter than 365 characters, and truncate the rest to 365.\n", 119 | "filtered_df = df[df['transcript'].apply(lambda x: len(x) >= game_length_in_chars)].copy()\n", 120 | "filtered_df.loc[:, 'transcript'] = filtered_df['transcript'].apply(lambda x: x[:game_length_in_chars])\n", 121 | "\n", 122 | "len_df = filtered_df['transcript'].apply(lambda x: len(x))\n", 123 | "print(len_df.describe())\n", 124 | "\n", 125 | "move_count_df = filtered_df['transcript'].apply(lambda x: len(x.split()))\n", 126 | "move_count = move_count_df.describe()\n", 127 | "print(\"move count\", move_count_df.describe())\n", 128 | "quarter_percentile = move_count['25%']\n", 129 | "print(\"quarter percentile\", quarter_percentile)\n", 130 | "\n", 131 | "# Now I need to filter out games that have too few moves. I will discard all games with fewer moves than the 25th percentile.\n", 132 | "filtered_df = filtered_df[filtered_df['transcript'].apply(lambda x: len(x.split()) >= quarter_percentile)]\n", 133 | "print(filtered_df.describe())\n", 134 | "print(filtered_df.head())\n", 135 | "\n", 136 | "filtered_df.to_csv(output_file, index=False)\n", 137 | "\n", 138 | "move_count_df = filtered_df['transcript'].apply(lambda x: len(x.split()))\n", 139 | "print(move_count_df.describe())" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "print(len(filtered_df))\n", 149 | "print(filtered_df['WhiteElo'].describe())" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "For the classification task, I wanted some Elo bins for the probe to classify. This somewhat arbitrarily creates 6 different Elo bins."
157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "import pandas as pd\n", 166 | "import numpy as np\n", 167 | "import matplotlib.pyplot as plt\n", 168 | "\n", 169 | "np.random.seed(0)\n", 170 | "\n", 171 | "# Function to create binned columns and bin index columns\n", 172 | "def create_binned_columns(df, column_name):\n", 173 | "\n", 174 | " # Ensure column is numeric and handle NaN values. Here, we choose to drop them, but you might fill them instead.\n", 175 | " if df[column_name].dtype.kind not in 'biufc' or pd.isnull(df[column_name]).any():\n", 176 | " df = df.dropna(subset=[column_name])\n", 177 | " df[column_name] = pd.to_numeric(df[column_name], errors='coerce')\n", 178 | "\n", 179 | " binned_column_name = f'{column_name}Binned'\n", 180 | " bin_index_column_name = f'{column_name}BinIndex'\n", 181 | " \n", 182 | " # Create quantile-based bins\n", 183 | " num_bins = 6\n", 184 | " # Create quantile-based bins with range labels, dropping duplicates if necessary\n", 185 | " df[binned_column_name], bins = pd.qcut(df[column_name], q=num_bins, retbins=True, duplicates='drop')\n", 186 | "\n", 187 | " # Convert bin labels to strings and assign to the column\n", 188 | " df[binned_column_name] = df[binned_column_name].apply(lambda x: f'({x.left}, {x.right}]')\n", 189 | "\n", 190 | " # Create bin index column\n", 191 | " df[bin_index_column_name] = pd.qcut(df[column_name], q=num_bins, labels=False, duplicates='drop')\n", 192 | "\n", 193 | "# Apply the function to both WhiteElo and BlackElo\n", 194 | "create_binned_columns(filtered_df, 'WhiteElo')\n", 195 | "create_binned_columns(filtered_df, 'BlackElo')\n", 196 | "\n", 197 | "filtered_df.to_csv(output_file, index=False)\n", 198 | "\n", 199 | "# Plotting\n", 200 | "fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 8))\n", 201 | "\n", 202 | "# Histogram for WhiteElo\n", 203 | "axes[0].hist(filtered_df['WhiteElo'], 
bins=30, color='blue', alpha=0.7)\n", 204 | "axes[0].set_title('WhiteElo Distribution')\n", 205 | "axes[0].set_xlabel('WhiteElo')\n", 206 | "axes[0].set_ylabel('Frequency')\n", 207 | "\n", 208 | "# Bar chart for WhiteEloBinned\n", 209 | "bin_counts = filtered_df['WhiteEloBinned'].value_counts()\n", 210 | "axes[1].bar(bin_counts.index.astype(str), bin_counts.values, color='green', alpha=0.7)\n", 211 | "axes[1].set_title('WhiteElo Binned Distribution')\n", 212 | "axes[1].set_xlabel('WhiteElo Bins')\n", 213 | "axes[1].set_ylabel('Count')\n", 214 | "plt.xticks(rotation=45)\n", 215 | "\n", 216 | "plt.tight_layout()\n", 217 | "plt.show()\n", 218 | "\n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "print(filtered_df['WhiteEloBinned'].value_counts())" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "print(filtered_df.head())" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "# shuffle all rows of the dataset\n", 246 | "\n", 247 | "df = pd.read_csv(output_file)\n", 248 | "df = df.sample(frac=1, random_state=200).reset_index(drop=True)\n", 249 | "df.to_csv(output_file, index=False)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "import pandas as pd\n", 259 | "df = pd.read_csv(output_file)\n", 260 | "\n", 261 | "print(len(df))\n", 262 | "\n", 263 | "# Split df into a train and test split\n", 264 | "train = df.sample(frac=0.5, random_state=200)\n", 265 | "test = df.drop(train.index)\n", 266 | "\n", 267 | "print(len(train))\n", 268 | "print(len(test))\n", 269 | "\n", 270 | "# Save the train and test splits to csv\n", 271 | "train.to_csv(f'{DATA_DIR}{prefix}train.csv', 
index=False)\n", 272 | "test.to_csv(f'{DATA_DIR}{prefix}test.csv', index=False)" 273 | ] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "othello", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.11.7" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 2 297 | } 298 | -------------------------------------------------------------------------------- /linear_probes/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | 4 | # But not these files/directories 5 | !/saved_probes/ 6 | !/view_probe.ipynb 7 | !/analyze_test_results.ipynb 8 | !/test_data/ 9 | !.gitignore 10 | -------------------------------------------------------------------------------- /linear_probes/analyze_test_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "\n", 13 | "PROBE_DIR = \"\"\n", 14 | "\n", 15 | "\n", 16 | "skill_files = [\n", 17 | " \"16layer_skill_probe_sweep_results\",\n", 18 | " \"8layer_skill_probe_sweep_results\",\n", 19 | " \"6layer_skill_probe_sweep_results\",\n", 20 | " \"8layer_15mbs200_skill_probe_sweep_results\",\n", 21 | " \"randominit_8layer_skill_probe_sweep_results\",\n", 22 | " \"randominit_16layer_skill_probe_sweep_results\",\n", 23 | "]\n", 24 | "piece_files = [\n", 25 | " \"16layer_piece_probe_sweep_results\",\n", 26 | " \"8layer_piece_probe_sweep_results\",\n", 27 | " 
\"6layer_piece_probe_sweep_results\",\n", 28 | " \"8layer_15m_piece_probe_sweep_results\",\n", 29 | " \"8layer_15mbs200_piece_probe_sweep_results\",\n", 30 | " \"randominit_8layer_piece_probe_sweep_results\",\n", 31 | " \"randominit_16layer_piece_probe_sweep_results\",\n", 32 | "]\n", 33 | "\n", 34 | "test_data_dir = os.path.join(PROBE_DIR, \"test_data\")\n", 35 | "\n", 36 | "\n", 37 | "skill_file_layer_data = {}\n", 38 | "piece_file_layer_data = {}\n", 39 | "\n", 40 | "\n", 41 | "def get_layer_data(folder_name):\n", 42 | " file_dir = os.path.join(test_data_dir, folder_name)\n", 43 | " # Step 1: List all pickle files\n", 44 | " pickle_files = [f for f in os.listdir(file_dir) if f.endswith(\".pkl\")]\n", 45 | " # Step 2 and 3: Open each file and calculate average accuracy\n", 46 | " average_accuracies = {}\n", 47 | " for file in pickle_files:\n", 48 | " with open(os.path.join(file_dir, file), \"rb\") as f:\n", 49 | " data = pickle.load(f)\n", 50 | " average_accuracy = sum(data[\"accuracy\"]) / len(data[\"accuracy\"])\n", 51 | " layer_num = int(file.split(\".\")[0].split(\"_\")[-1])\n", 52 | " average_accuracies[layer_num] = float(average_accuracy)\n", 53 | " return average_accuracies\n", 54 | "\n", 55 | "\n", 56 | "for skill_file in skill_files:\n", 57 | " skill_file_layer_data[skill_file] = get_layer_data(skill_file)\n", 58 | "for piece_file in piece_files:\n", 59 | " piece_file_layer_data[piece_file] = get_layer_data(piece_file)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "for skill_file in skill_file_layer_data:\n", 69 | " print(skill_file)\n", 70 | " # print the layer with the highest accuracy and its value\n", 71 | " max_layer = max(skill_file_layer_data[skill_file], key=skill_file_layer_data[skill_file].get)\n", 72 | " min_layer = min(skill_file_layer_data[skill_file], key=skill_file_layer_data[skill_file].get)\n", 73 | " print(max_layer, 
skill_file_layer_data[skill_file][max_layer])\n", 74 | " print(min_layer, skill_file_layer_data[skill_file][min_layer])\n", 75 | "\n", 76 | "for piece_file in piece_file_layer_data:\n", 77 | " print(piece_file)\n", 78 | " # print the layer with the highest accuracy and its value\n", 79 | " max_layer = max(piece_file_layer_data[piece_file], key=piece_file_layer_data[piece_file].get)\n", 80 | " min_layer = min(piece_file_layer_data[piece_file], key=piece_file_layer_data[piece_file].get)\n", 81 | " print(max_layer, piece_file_layer_data[piece_file][max_layer])\n", 82 | " print(min_layer, piece_file_layer_data[piece_file][min_layer])" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "for layer in skill_file_layer_data[skill_files[2]]:\n", 92 | " print(layer, skill_file_layer_data[skill_files[2]][layer])" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "import matplotlib.pyplot as plt\n", 102 | "\n", 103 | "plt.rcParams.update({\"font.size\": 16, \"axes.labelsize\": 16})\n", 104 | "plt.figure(figsize=(10, 8)) # Adjust the figure size as needed\n", 105 | "\n", 106 | "label_dict = {\n", 107 | " \"16layer_skill_probe_sweep_results\": \"16 layer model\",\n", 108 | " \"8layer_skill_probe_sweep_results\": \"8 layer model\",\n", 109 | " \"randominit_8layer_skill_probe_sweep_results\": \"8 layer randomized\",\n", 110 | " \"randominit_16layer_skill_probe_sweep_results\": \"16 layer randomized\",\n", 111 | " \"16layer_piece_probe_sweep_results\": \"16 layer model\",\n", 112 | " \"8layer_piece_probe_sweep_results\": \"8 layer model\",\n", 113 | " \"randominit_8layer_piece_probe_sweep_results\": \"8 layer randomized\",\n", 114 | " \"randominit_16layer_piece_probe_sweep_results\": \"16 layer randomized\",\n", 115 | "}\n", 116 | "\n", 117 | "for piece_file, layer_data in 
piece_file_layer_data.items():\n", 118 | " # Ensure the keys are sorted if they're not numeric or already in desired order\n", 119 | " keys = sorted(layer_data.keys())\n", 120 | " values = [layer_data[key] for key in keys]\n", 121 | "\n", 122 | " # increment keys by 1 to match the layer number\n", 123 | " keys = [key + 1 for key in keys]\n", 124 | "\n", 125 | " plt.plot(\n", 126 | " keys, values, label=label_dict[piece_file], linewidth=3, marker=\"o\"\n", 127 | " ) # Plot each line with label as piece_file\n", 128 | " plt.ylim(0.54, 1.0)\n", 129 | "\n", 130 | "plt.xlabel(\"Layer\", fontsize=20)\n", 131 | "plt.ylabel(\"Accuracy\", fontsize=20)\n", 132 | "plt.title(\"Probe Square Classification Accuracy per Layer\", fontsize=22)\n", 133 | "\n", 134 | "# Increase the legend font size\n", 135 | "plt.legend(fontsize=20) # Adjust fontsize as needed\n", 136 | "\n", 137 | "ax = plt.gca() # Get current axes\n", 138 | "ax.spines[\"top\"].set_linewidth(2)\n", 139 | "ax.spines[\"bottom\"].set_linewidth(2)\n", 140 | "ax.spines[\"left\"].set_linewidth(2)\n", 141 | "ax.spines[\"right\"].set_linewidth(2)\n", 142 | "\n", 143 | "# plt.show()\n", 144 | "plt.savefig(\"board_state_line_plot.png\")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "import matplotlib.pyplot as plt\n", 154 | "import numpy as np\n", 155 | "\n", 156 | "# Number of dictionaries\n", 157 | "n = len(piece_file_layer_data)\n", 158 | "# All unique keys sorted, assuming all sub-dicts have the same keys for simplicity\n", 159 | "keys = sorted({key for d in piece_file_layer_data.values() for key in d})\n", 160 | "# Total number of unique keys\n", 161 | "total_keys = len(keys)\n", 162 | "# Width of each bar\n", 163 | "width = 1 / (n + 1)\n", 164 | "# Creating a figure and axis\n", 165 | "plt.figure(figsize=(10, 8))\n", 166 | "\n", 167 | "# Enumerating over each piece_file and its corresponding dictionary\n", 168 | 
"for i, (piece_file, layer_data) in enumerate(piece_file_layer_data.items(), start=1):\n", 169 | " # Calculating offsets for each bar to not overlap\n", 170 | " offsets = np.arange(len(keys)) + width * i\n", 171 | " # Getting values in the order of sorted keys\n", 172 | " values = [layer_data.get(key, 0) for key in keys] # Default to 0 if key not found\n", 173 | "\n", 174 | " # Plotting the bars with an offset\n", 175 | " plt.bar(offsets, values, width=width, label=piece_file.split('_')[0].replace('layer', ' layers'))\n", 176 | "\n", 177 | " # Set min and max y limits\n", 178 | " plt.ylim(0.6, 1.0)\n", 179 | "\n", 180 | "# Adjusting the x-ticks to be in the middle of the groups and setting the keys as labels\n", 181 | "plt.xticks(np.arange(len(keys)) + width * n / 2, keys)\n", 182 | "plt.xlabel('Layer')\n", 183 | "plt.ylabel('Accuracy')\n", 184 | "plt.title('Probe Board State Accuracy per Layer')\n", 185 | "plt.legend()\n", 186 | "# plt.show()\n", 187 | "plt.savefig('board_state_bar_graph.png')" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "import matplotlib.pyplot as plt\n", 197 | "\n", 198 | "plt.rcParams.update({'font.size': 16, 'axes.labelsize': 16})\n", 199 | "plt.figure(figsize=(10, 8)) # Adjust the figure size as needed\n", 200 | "\n", 201 | "plt.figure(figsize=(10, 8)) # Adjust the figure size as needed\n", 202 | "\n", 203 | "for skill_file, layer_data in skill_file_layer_data.items():\n", 204 | " # Ensure the keys are sorted if they're not numeric or already in desired order\n", 205 | " keys = sorted(layer_data.keys())\n", 206 | " values = [layer_data[key] for key in keys]\n", 207 | "\n", 208 | " keys = [key + 1 for key in keys]\n", 209 | "\n", 210 | " plt.plot(keys, values, label=label_dict[skill_file], linewidth=3, marker=\"o\")\n", 211 | " plt.ylim(0.54, 1.0)\n", 212 | "\n", 213 | "plt.xlabel('Layer', fontsize=20)\n", 214 | "plt.ylabel('Accuracy', 
fontsize=20)\n", 215 | "plt.title('Probe Elo Classification Accuracy per Layer', fontsize=22)\n", 216 | "\n", 217 | "# Increase the legend font size\n", 218 | "# plt.legend(fontsize=20) # Adjust fontsize as needed\n", 219 | "\n", 220 | "ax = plt.gca() # Get current axes\n", 221 | "ax.spines['top'].set_linewidth(2)\n", 222 | "ax.spines['bottom'].set_linewidth(2)\n", 223 | "ax.spines['left'].set_linewidth(2)\n", 224 | "ax.spines['right'].set_linewidth(2)\n", 225 | "\n", 226 | "# plt.show()\n", 227 | "plt.savefig('elo_line_plot.png')" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "import matplotlib.pyplot as plt\n", 237 | "import numpy as np\n", 238 | "\n", 239 | "# Number of dictionaries\n", 240 | "n = len(skill_file_layer_data)\n", 241 | "# All unique keys sorted, assuming all sub-dicts have the same keys for simplicity\n", 242 | "keys = sorted({key for d in skill_file_layer_data.values() for key in d})\n", 243 | "# Total number of unique keys\n", 244 | "total_keys = len(keys)\n", 245 | "# Width of each bar\n", 246 | "width = 1 / (n + 1)\n", 247 | "# Creating a figure and axis\n", 248 | "plt.figure(figsize=(10, 8))\n", 249 | "\n", 250 | "# Enumerating over each skill_file and its corresponding dictionary\n", 251 | "for i, (skill_file, layer_data) in enumerate(skill_file_layer_data.items(), start=1):\n", 252 | " # Calculating offsets for each bar to not overlap\n", 253 | " offsets = np.arange(len(keys)) + width * i\n", 254 | " # Getting values in the order of sorted keys\n", 255 | " values = [layer_data.get(key, 0) for key in keys] # Default to 0 if key not found\n", 256 | "\n", 257 | " # Plotting the bars with an offset\n", 258 | " plt.bar(offsets, values, width=width, label=skill_file.split('_')[0].replace('layer', ' layers'))\n", 259 | "\n", 260 | " # Set min and max y limits\n", 261 | " plt.ylim(0.6, 1.0)\n", 262 | "\n", 263 | "# Adjusting the x-ticks to be in 
the middle of the groups and setting the keys as labels\n", 264 | "plt.xticks(np.arange(len(keys)) + width * n / 2, keys)\n", 265 | "plt.xlabel('Layer')\n", 266 | "plt.ylabel('Accuracy')\n", 267 | "plt.title('Probe Elo Classification Accuracy per Layer')\n", 268 | "plt.legend()\n", 269 | "# plt.show()\n", 270 | "plt.savefig('elo_bar_graph.png')" 271 | ] 272 | } 273 | ], 274 | "metadata": { 275 | "kernelspec": { 276 | "display_name": "othello", 277 | "language": "python", 278 | "name": "python3" 279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": "ipython3", 290 | "version": "3.11.7" 291 | } 292 | }, 293 | "nbformat": 4, 294 | "nbformat_minor": 2 295 | } 296 | -------------------------------------------------------------------------------- /linear_probes/saved_probes/tf_lens_lichess_16layers_ckpt_no_optimizer_chess_piece_probe_layer_11.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/linear_probes/saved_probes/tf_lens_lichess_16layers_ckpt_no_optimizer_chess_piece_probe_layer_11.pth -------------------------------------------------------------------------------- /linear_probes/saved_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/linear_probes/saved_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth -------------------------------------------------------------------------------- 
/linear_probes/saved_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/linear_probes/saved_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_5.pth -------------------------------------------------------------------------------- /linear_probes/saved_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/linear_probes/saved_probes/tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth -------------------------------------------------------------------------------- /linear_probes/saved_probes/tf_lens_randominit_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/linear_probes/saved_probes/tf_lens_randominit_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth -------------------------------------------------------------------------------- /linear_probes/saved_probes/tf_lens_randominit_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/linear_probes/saved_probes/tf_lens_randominit_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth -------------------------------------------------------------------------------- /linear_probes/test_data/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | 4 | # Except this file 5 | !.gitignore 6 | !saved_figures/ -------------------------------------------------------------------------------- /linear_probes/view_probe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "probe_name = \"tf_lens_lichess_16layers_ckpt_no_optimizer_chess_skill_probe_layer_11.pth\"\n", 11 | "with open(probe_name, 'rb') as f:\n", 12 | " state_dict = torch.load(f, map_location=torch.device('cpu'))\n", 13 | " print(state_dict.keys())\n", 14 | " for key in state_dict.keys():\n", 15 | " if key != \"linear_probe\":\n", 16 | " print(key, state_dict[key])" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# This is used to find contrastive activations from a given linear probe\n", 26 | "\n", 27 | "print(state_dict['linear_probe'].shape)\n", 28 | "low_activation = state_dict['linear_probe'][..., 0].squeeze()\n", 29 | "high_activation = state_dict['linear_probe'][..., 1].squeeze()\n", 30 | "print(low_activation.shape)\n", 31 | "print(high_activation.shape)\n", 32 | "difference_vector = high_activation - low_activation\n", 33 | "state_dict['average_high_elo_activation'] = high_activation\n", 34 | "state_dict['average_low_elo_activation'] = low_activation\n", 35 | "state_dict['difference_vector'] = difference_vector\n", 36 | "\n", 37 | "torch.save(state_dict, probe_name)" 38 | ] 39 | } 40 | ], 41 | "metadata": { 42 | "kernelspec": { 43 | "display_name": "othello", 44 | "language": "python", 45 | "name": "python3" 46 | }, 47 | "language_info": { 48 | "codemirror_mode": { 49 | "name": "ipython", 50 | "version": 3 51 | }, 52 | 
"file_extension": ".py", 53 | "mimetype": "text/x-python", 54 | "name": "python", 55 | "nbconvert_exporter": "python", 56 | "pygments_lexer": "ipython3", 57 | "version": "3.10.13" 58 | } 59 | }, 60 | "nbformat": 4, 61 | "nbformat_minor": 2 62 | } 63 | -------------------------------------------------------------------------------- /model_setup.py: -------------------------------------------------------------------------------- 1 | # For nanogpt to transformer lens conversion 2 | import torch 3 | import einops 4 | 5 | import transformer_lens.utils as utils 6 | from transformer_lens import ( 7 | HookedTransformer, 8 | HookedTransformerConfig, 9 | ) 10 | 11 | import os 12 | 13 | # Our pytorch model is in the nanogpt format. For easy linear probing of the residual stream, we want to convert 14 | # it to the transformer lens format. This is done in the following code block. 15 | # This code was developed using Neel Nanda's othello_reference/Othello_GPT.ipynb as a reference. 16 | 17 | torch.set_grad_enabled(False) 18 | 19 | LOAD_AND_CONVERT_CHECKPOINT = True 20 | 21 | device = "cpu" 22 | 23 | MODEL_DIR = "models/" 24 | 25 | n_heads = 8 26 | n_layers = 8 27 | d_model = 512 28 | 29 | model_name = f"lichess_{n_layers}layers_ckpt_no_optimizer.pt" 30 | 31 | 32 | assert str(n_layers) in model_name 33 | 34 | if not os.path.exists(f"{MODEL_DIR}{model_name}"): 35 | state_dict = utils.download_file_from_hf("adamkarvonen/chess_llms", model_name) 36 | model = torch.load(state_dict, map_location=device) 37 | torch.save(model, f"{MODEL_DIR}{model_name}") 38 | 39 | 40 | checkpoint = torch.load(f"{MODEL_DIR}{model_name}", map_location=device) 41 | 42 | # Print the keys of the checkpoint dictionary 43 | print(checkpoint.keys()) 44 | model_state = checkpoint["model"] 45 | # for key, value in model_state.items(): 46 | # print(key, value.shape) 47 | 48 | 49 | def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig, bias: bool = False): 50 | """For 
https://github.com/karpathy/nanoGPT 51 | There are two complications with converting nanoGPT models: 52 | The first is that some state dicts have an unwanted prefix on keys that needs to be removed. 53 | The second is that the models can be saved with or without bias. By default, there 54 | is no bias. This function can handle both cases.""" 55 | # nanoGPT models saved after torch.compile() have this unwanted prefix on every key 56 | # Strip it in place (note that this mutates old_state_dict) 57 | unwanted_prefix = "_orig_mod." 58 | for k in list(old_state_dict.keys()): 59 | if k.startswith(unwanted_prefix): 60 | old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k) 61 | 62 | new_state_dict = {} 63 | new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"] 64 | new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"] 65 | 66 | new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"] 67 | new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"]) 68 | new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T 69 | 70 | if bias: 71 | new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"] 72 | 73 | for layer in range(cfg.n_layers): 74 | layer_key = f"transformer.h.{layer}" 75 | 76 | new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"] 77 | # A bias of zeros is required for folding layer norm 78 | new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like( 79 | old_state_dict[f"{layer_key}.ln_1.weight"] 80 | ) 81 | new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"] 82 | new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like( 83 | old_state_dict[f"{layer_key}.ln_2.weight"] 84 | ) 85 | 86 | W = old_state_dict[f"{layer_key}.attn.c_attn.weight"] 87 | W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) 88 | W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 89 | W_K = einops.rearrange(W_K, "(i h) m->i m h", 
i=cfg.n_heads) 90 | W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 91 | new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q 92 | new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K 93 | new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V 94 | 95 | W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"] 96 | W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 97 | new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O 98 | 99 | new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[ 100 | f"{layer_key}.mlp.c_fc.weight" 101 | ].T 102 | new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[ 103 | f"{layer_key}.mlp.c_proj.weight" 104 | ].T 105 | 106 | if bias: 107 | new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"] 108 | new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"] 109 | new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ 110 | f"{layer_key}.mlp.c_fc.bias" 111 | ] 112 | new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[ 113 | f"{layer_key}.mlp.c_proj.bias" 114 | ] 115 | 116 | B = old_state_dict[f"{layer_key}.attn.c_attn.bias"] 117 | B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0) 118 | B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads) 119 | B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads) 120 | B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads) 121 | new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q 122 | new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K 123 | new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V 124 | new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[ 125 | f"{layer_key}.attn.c_proj.bias" 126 | ] 127 | 128 | return new_state_dict 129 | 130 | 131 | if LOAD_AND_CONVERT_CHECKPOINT: 132 | synthetic_checkpoint = model_state 133 | for name, param in synthetic_checkpoint.items(): 134 | if name.startswith("_orig_mod.transformer.h.0") or not name.startswith( 135 | "_orig_mod.transformer.h" 136 | ): 137 | print(name, 
param.shape) 138 | 139 | cfg = HookedTransformerConfig( 140 | n_layers=n_layers, 141 | d_model=d_model, 142 | d_head=d_model // n_heads, 143 | n_heads=n_heads, 144 | d_mlp=d_model * 4, 145 | d_vocab=32, 146 | n_ctx=1023, 147 | act_fn="gelu", 148 | normalization_type="LNPre", 149 | ) 150 | model = HookedTransformer(cfg) 151 | model.to(device) 152 | 153 | model.load_and_process_state_dict(convert_nanogpt_weights(synthetic_checkpoint, cfg)) 154 | recorded_model_name = model_name.split(".")[0] 155 | torch.save(model.state_dict(), f"{MODEL_DIR}tf_lens_{recorded_model_name}.pth") 156 | 157 | # An example input 158 | sample_input = torch.tensor([[15, 6, 4, 27, 9, 0, 25, 10, 0, 7, 4, 19]]).to(device) 159 | # sample_input = torch.tensor([[15, 6, 4, 27, 9]]) 160 | # The expected argmax of the output (i.e. the most likely next token at each position) 161 | sample_output = torch.tensor([[6, 4, 27, 9, 0, 27, 10, 0, 7, 4, 19, 28]]).to(device) 162 | model_output = model(sample_input).argmax(dim=-1) 163 | print(model_output) 164 | print(sample_output == model_output) 165 | 166 | # For this particular sample_input, any model with decent chess skill should output sample_output. 167 | # So this assert will definitely fail for a randomly initialized model, and may fail for models with low skill. 168 | # I've never seen that happen, so I'm keeping it simple for now. For a more robust test, use the nanogpt_to_transformer_lens.ipynb notebook. 169 | # That notebook runs the sample input through the original nanoGPT model and then through the converted TransformerLens model. 
170 | assert torch.all(sample_output == model_output) 171 | -------------------------------------------------------------------------------- /models/meta.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adamkarvonen/chess_llm_interpretability/0f61e667fb8a809deda29e5db6c113a0a88f9998/models/meta.pkl -------------------------------------------------------------------------------- /models/view_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "with open('lichess_stockfish_mix_8layers_ckpt_with_optimizer.pt', 'rb') as f:\n", 11 | " state_dict = torch.load(f, map_location=torch.device('cpu'))\n", 12 | " print(state_dict.keys())\n", 13 | " for key in state_dict.keys():\n", 14 | " if key != \"model\" and key != \"optimizer\":\n", 15 | " print(key, state_dict[key])" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import pickle\n", 25 | "# meta holds the character <-> integer mappings used to encode PGN strings into integer sequences\n", 26 | "with open(\"meta.pkl\", \"rb\") as f:\n", 27 | " meta = pickle.load(f)\n", 28 | "\n", 29 | "print(meta)\n", 30 | "\n", 31 | "stoi, itos = meta[\"stoi\"], meta[\"itos\"]\n", 32 | "encode = lambda s: [stoi[c] for c in s]\n", 33 | "decode = lambda l: \"\".join([itos[i] for i in l])\n", 34 | "\n", 35 | "print(encode(\"1.e4 e6 2.Nf3 d5 3.Nc3 d4 4.Ne2 c5 5.c3 d3 6.Nf4 c4 7.Qa4+ Bd7 8.Qxc4 Nf6 9.e5 Ng4 10.h3 Nxf2 11.Kxf2 Qb6+ 12.Ke1 Bb5 13.Qc8+ Ke7 14.Bxd3 Bd7 15.Qc4 Nc6 16.Be4 Rc8 17.Qb3 Qc7 18.d4 Rb8 19.Be3 Na5 20.Qd1 g6 21.Bd3 Bg7 22.Rf1 Nc6 23.Kf2 Rhe8 24.Kg1 h6 25.Rc1 g5 26.Nh5 Bh8 27.Nd2 Qb6 28.Nf6 Red8 29.Nxd7 Rxd7 30.Qf3 Qxb2 31.Qxf7+ Kd8 32.Qf8+\"))\n", 36 | "print(decode(encode(\";1.e4 \")))" 
37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "The following cells can be used to modify a checkpoint, such as by removing the optimizer or adding dataset metadata." 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "import torch\n", 53 | "\n", 54 | "model_name = \"lichess_8layers_results_ckpt_with_optimizer.pt\"\n", 55 | "checkpoint = torch.load(model_name, map_location='cpu')\n", 56 | "print(checkpoint.keys())" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "\n", 66 | "# Add the new key-value pair to the checkpoint\n", 67 | "checkpoint[\"dataset\"] = \"lichess_6gb_results_blocks.zip\"\n", 68 | "\n", 69 | "# Save the modified checkpoint\n", 70 | "torch.save(checkpoint, model_name)\n" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "state_dict = torch.load(model_name, map_location='cpu')\n", 80 | "print(state_dict.keys())\n", 81 | "\n", 82 | "del state_dict['optimizer']\n", 83 | "print(state_dict.keys())\n", 84 | "\n", 85 | "new_model_name = model_name.replace('with_optimizer.pt', 'no_optimizer.pt')\n", 86 | "torch.save(state_dict, new_model_name)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "print(checkpoint['dataset'])\n", 96 | "print(checkpoint['best_val_loss'])\n", 97 | "print(checkpoint['iter_num'])" 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "othello", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": 
"python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.10.13" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 122 | } 123 | -------------------------------------------------------------------------------- /othello_engine_utils.py: -------------------------------------------------------------------------------- 1 | # Copy of https://github.com/likenneth/othello_world/blob/master/mechanistic_interpretability/mech_interp_othello_utils.py 2 | 3 | # %% 4 | import os 5 | import math 6 | import time 7 | from tqdm import tqdm 8 | import numpy as np 9 | from copy import deepcopy 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | from functools import partial 14 | from matplotlib.pyplot import imshow 15 | 16 | torch.set_grad_enabled(True) 17 | # A class to calculate the Othello Board State, shamelessly ripped from Kenneth Li's code base 18 | 19 | rows = list("abcdefgh") 20 | columns = [str(_) for _ in range(1, 9)] 21 | 22 | 23 | def permit(s): 24 | s = s.lower() 25 | if len(s) != 2: 26 | return -1 27 | if s[0] not in rows or s[1] not in columns: 28 | return -1 29 | return rows.index(s[0]) * 8 + columns.index(s[1]) 30 | 31 | 32 | def permit_reverse(integer): 33 | r, c = integer // 8, integer % 8 34 | return "".join([rows[r], columns[c]]) 35 | 36 | 37 | start_hands = [permit(_) for _ in ["d5", "d4", "e4", "e5"]] 38 | eights = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]] 39 | 40 | 41 | class OthelloBoardState: 42 | # 1 is black, -1 is white 43 | def __init__(self, board_size=8): 44 | self.board_size = board_size * board_size 45 | board = np.zeros((8, 8)) 46 | board[3, 4] = 1 47 | board[3, 3] = -1 48 | board[4, 3] = 1 49 | board[4, 4] = -1 50 | self.initial_state = board 51 | self.state = self.initial_state 52 | self.age = np.zeros((8, 8)) 53 | self.next_hand_color = 1 54 | self.history = [] 55 | 56 | def get_occupied( 57 | self, 58 | ): 59 | board 
= self.state 60 | tbr = board.flatten() != 0 61 | return tbr.tolist() 62 | 63 | def get_state( 64 | self, 65 | ): 66 | board = self.state + 1 # white 0, blank 1, black 2 67 | tbr = board.flatten() 68 | return tbr.tolist() 69 | 70 | def get_age( 71 | self, 72 | ): 73 | return self.age.flatten().tolist() 74 | 75 | def get_next_hand_color( 76 | self, 77 | ): 78 | return (self.next_hand_color + 1) // 2 79 | 80 | def update(self, moves, prt=False): 81 | # takes a new move or new moves and update state 82 | if prt: 83 | self.__print__() 84 | for _, move in enumerate(moves): 85 | self.umpire(move) 86 | if prt: 87 | self.__print__() 88 | 89 | def umpire(self, move): 90 | r, c = move // 8, move % 8 91 | assert self.state[r, c] == 0, f"{r}-{c} is already occupied!" 92 | occupied = np.sum(self.state != 0) 93 | color = self.next_hand_color 94 | tbf = [] 95 | for direction in eights: 96 | buffer = [] 97 | cur_r, cur_c = r, c 98 | while 1: 99 | cur_r, cur_c = cur_r + direction[0], cur_c + direction[1] 100 | if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7: 101 | break 102 | if self.state[cur_r, cur_c] == 0: 103 | break 104 | elif self.state[cur_r, cur_c] == color: 105 | tbf.extend(buffer) 106 | break 107 | else: 108 | buffer.append([cur_r, cur_c]) 109 | if len(tbf) == 0: # means one hand is forfeited 110 | # print(f"One {color} move forfeited") 111 | color *= -1 112 | self.next_hand_color *= -1 113 | for direction in eights: 114 | buffer = [] 115 | cur_r, cur_c = r, c 116 | while 1: 117 | cur_r, cur_c = cur_r + direction[0], cur_c + direction[1] 118 | if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7: 119 | break 120 | if self.state[cur_r, cur_c] == 0: 121 | break 122 | elif self.state[cur_r, cur_c] == color: 123 | tbf.extend(buffer) 124 | break 125 | else: 126 | buffer.append([cur_r, cur_c]) 127 | if len(tbf) == 0: 128 | valids = self.get_valid_moves() 129 | if len(valids) == 0: 130 | assert 0, "Both color cannot put piece, game should have ended!" 
131 | else: 132 | assert 0, "Illegal move!" 133 | 134 | self.age += 1 135 | for ff in tbf: 136 | self.state[ff[0], ff[1]] *= -1 137 | self.age[ff[0], ff[1]] = 0 138 | self.state[r, c] = color 139 | self.age[r, c] = 0 140 | self.next_hand_color *= -1 141 | self.history.append(move) 142 | 143 | def __print__( 144 | self, 145 | ): 146 | print("-" * 20) 147 | print([permit_reverse(_) for _ in self.history]) 148 | a = "abcdefgh" 149 | for k, row in enumerate(self.state.tolist()): 150 | tbp = [] 151 | for ele in row: 152 | if ele == -1: 153 | tbp.append("O") 154 | elif ele == 0: 155 | tbp.append(" ") 156 | else: 157 | tbp.append("X") 158 | # tbp.append("\n") 159 | print(" ".join([a[k]] + tbp)) 160 | tbp = [str(k) for k in range(1, 9)] 161 | print(" ".join([" "] + tbp)) 162 | print("-" * 20) 163 | 164 | def tentative_move(self, move): 165 | # tentatively put a piece, do nothing to state 166 | # returns 0 if this is not a move at all: occupied or both player have to forfeit 167 | # return 1 if regular move 168 | # return 2 if forfeit happens but the opponent can drop piece at this place 169 | r, c = move // 8, move % 8 170 | if not self.state[r, c] == 0: 171 | return 0 172 | occupied = np.sum(self.state != 0) 173 | color = self.next_hand_color 174 | tbf = [] 175 | for direction in eights: 176 | buffer = [] 177 | cur_r, cur_c = r, c 178 | while 1: 179 | cur_r, cur_c = cur_r + direction[0], cur_c + direction[1] 180 | if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7: 181 | break 182 | if self.state[cur_r, cur_c] == 0: 183 | break 184 | elif self.state[cur_r, cur_c] == color: 185 | tbf.extend(buffer) 186 | break 187 | else: 188 | buffer.append([cur_r, cur_c]) 189 | if len(tbf) != 0: 190 | return 1 191 | else: # means one hand is forfeited 192 | # print(f"One {color} move forfeited") 193 | color *= -1 194 | # self.next_hand_color *= -1 195 | for direction in eights: 196 | buffer = [] 197 | cur_r, cur_c = r, c 198 | while 1: 199 | cur_r, cur_c = cur_r + direction[0], cur_c + 
direction[1] 200 | if cur_r < 0 or cur_r > 7 or cur_c < 0 or cur_c > 7: 201 | break 202 | if self.state[cur_r, cur_c] == 0: 203 | break 204 | elif self.state[cur_r, cur_c] == color: 205 | tbf.extend(buffer) 206 | break 207 | else: 208 | buffer.append([cur_r, cur_c]) 209 | if len(tbf) == 0: 210 | return 0 211 | else: 212 | return 2 213 | 214 | def get_valid_moves( 215 | self, 216 | ): 217 | regular_moves = [] 218 | forfeit_moves = [] 219 | for move in range(64): 220 | x = self.tentative_move(move) 221 | if x == 1: 222 | regular_moves.append(move) 223 | elif x == 2: 224 | forfeit_moves.append(move) 225 | else: 226 | pass 227 | if len(regular_moves): 228 | return regular_moves 229 | elif len(forfeit_moves): 230 | return forfeit_moves 231 | else: 232 | return [] 233 | 234 | def get_gt(self, moves, func, prt=False): 235 | # takes a new move or new moves and update state 236 | container = [] 237 | if prt: 238 | self.__print__() 239 | for _, move in enumerate(moves): 240 | self.umpire(move) 241 | container.append(getattr(self, func)()) 242 | # to predict first y, we need already know the first x 243 | if prt: 244 | self.__print__() 245 | return container 246 | 247 | 248 | # # %% 249 | # try: 250 | # othello 251 | # print("Othello dataset exists") 252 | 253 | # except: 254 | # print("Making dataset") 255 | # othello = get_othello(ood_num=-1, data_root=None, wthor=True) 256 | # train_dataset = CharDataset(othello) 257 | 258 | # # made_othello=True 259 | # # %% 260 | # full_seqs = list(filter(lambda x: len(x)==60, train_dataset.data.sequences)) 261 | # print(len(full_seqs)) 262 | # board_seqs = torch.tensor(full_seqs) 263 | # print(board_seqs.numel()) 264 | # # %% 265 | # # n = 50000 266 | # # board_seqs = torch.zeros((n, 60), dtype=int) 267 | # # for c, seq in enumerate(tqdm(othello.sequences)): 268 | # # board_seqs[c, :len(seq)] = torch.tensor(seq) 269 | # # if c == n-1: 270 | # # break 271 | # # %% 272 | # board_seqs_string = board_seqs 273 | # 
print(board_seqs_string.numel()) 274 | # # %% 275 | # board_seqs_int = board_seqs_string.clone() 276 | # board_seqs_int[board_seqs_string < 29] += 1 277 | # board_seqs_int[(board_seqs_string >= 29) & (board_seqs_string <= 34)] -= 1 278 | # board_seqs_int[(board_seqs_string > 34)] -= 3 279 | # rand = torch.randint(0, 1000000, (20,)) 280 | # print(board_seqs_int.flatten()[rand]) 281 | # print(board_seqs_string.flatten()[rand]) 282 | # # torch.save(board_seqs, "board_seqs.pt") 283 | # # %% 284 | # indices = torch.randperm(len(board_seqs_int)) 285 | # board_seqs_int = board_seqs_int[indices] 286 | # board_seqs_string = board_seqs_string[indices] 287 | # torch.save(board_seqs_int, "board_seqs_int.pth") 288 | # torch.save(board_seqs_string, "board_seqs_string.pth") 289 | # %% 290 | # board_seqs_int = torch.load("board_seqs_int.pth") 291 | # board_seqs_string = torch.load("board_seqs_string.pth") 292 | # print(board_seqs_int.shape) 293 | # imshow(board_seqs_int[:5], title="Board Seqs Int Test") 294 | # imshow(board_seqs_string[:5], title="Board Seqs String Test") 295 | # %% 296 | itos = { 297 | 0: -100, 298 | 1: 0, 299 | 2: 1, 300 | 3: 2, 301 | 4: 3, 302 | 5: 4, 303 | 6: 5, 304 | 7: 6, 305 | 8: 7, 306 | 9: 8, 307 | 10: 9, 308 | 11: 10, 309 | 12: 11, 310 | 13: 12, 311 | 14: 13, 312 | 15: 14, 313 | 16: 15, 314 | 17: 16, 315 | 18: 17, 316 | 19: 18, 317 | 20: 19, 318 | 21: 20, 319 | 22: 21, 320 | 23: 22, 321 | 24: 23, 322 | 25: 24, 323 | 26: 25, 324 | 27: 26, 325 | 28: 29, 326 | 29: 30, 327 | 30: 31, 328 | 31: 32, 329 | 32: 33, 330 | 33: 34, 331 | 34: 37, 332 | 35: 38, 333 | 36: 39, 334 | 37: 40, 335 | 38: 41, 336 | 39: 42, 337 | 40: 43, 338 | 41: 44, 339 | 42: 45, 340 | 43: 46, 341 | 44: 47, 342 | 45: 48, 343 | 46: 49, 344 | 47: 50, 345 | 48: 51, 346 | 49: 52, 347 | 50: 53, 348 | 51: 54, 349 | 52: 55, 350 | 53: 56, 351 | 54: 57, 352 | 55: 58, 353 | 56: 59, 354 | 57: 60, 355 | 58: 61, 356 | 59: 62, 357 | 60: 63, 358 | } 359 | 360 | stoi = { 361 | -100: 0, 362 | -1: 0, 363 | 
0: 1, 364 | 1: 2, 365 | 2: 3, 366 | 3: 4, 367 | 4: 5, 368 | 5: 6, 369 | 6: 7, 370 | 7: 8, 371 | 8: 9, 372 | 9: 10, 373 | 10: 11, 374 | 11: 12, 375 | 12: 13, 376 | 13: 14, 377 | 14: 15, 378 | 15: 16, 379 | 16: 17, 380 | 17: 18, 381 | 18: 19, 382 | 19: 20, 383 | 20: 21, 384 | 21: 22, 385 | 22: 23, 386 | 23: 24, 387 | 24: 25, 388 | 25: 26, 389 | 26: 27, 390 | 29: 28, 391 | 30: 29, 392 | 31: 30, 393 | 32: 31, 394 | 33: 32, 395 | 34: 33, 396 | 37: 34, 397 | 38: 35, 398 | 39: 36, 399 | 40: 37, 400 | 41: 38, 401 | 42: 39, 402 | 43: 40, 403 | 44: 41, 404 | 45: 42, 405 | 46: 43, 406 | 47: 44, 407 | 48: 45, 408 | 49: 46, 409 | 50: 47, 410 | 51: 48, 411 | 52: 49, 412 | 53: 50, 413 | 54: 51, 414 | 55: 52, 415 | 56: 53, 416 | 57: 54, 417 | 58: 55, 418 | 59: 56, 419 | 60: 57, 420 | 61: 58, 421 | 62: 59, 422 | 63: 60, 423 | } 424 | # %% 425 | stoi_indices = [ 426 | 0, 427 | 1, 428 | 2, 429 | 3, 430 | 4, 431 | 5, 432 | 6, 433 | 7, 434 | 8, 435 | 9, 436 | 10, 437 | 11, 438 | 12, 439 | 13, 440 | 14, 441 | 15, 442 | 16, 443 | 17, 444 | 18, 445 | 19, 446 | 20, 447 | 21, 448 | 22, 449 | 23, 450 | 24, 451 | 25, 452 | 26, 453 | 29, 454 | 30, 455 | 31, 456 | 32, 457 | 33, 458 | 34, 459 | 37, 460 | 38, 461 | 39, 462 | 40, 463 | 41, 464 | 42, 465 | 43, 466 | 44, 467 | 45, 468 | 46, 469 | 47, 470 | 48, 471 | 49, 472 | 50, 473 | 51, 474 | 52, 475 | 53, 476 | 54, 477 | 55, 478 | 56, 479 | 57, 480 | 58, 481 | 59, 482 | 60, 483 | 61, 484 | 62, 485 | 63, 486 | ] 487 | alpha = "ABCDEFGH" 488 | 489 | 490 | def to_board_label(i): 491 | return f"{alpha[i//8]}{i%8}" 492 | 493 | 494 | board_labels = list(map(to_board_label, stoi_indices)) 495 | 496 | 497 | # %% 498 | def str_to_int(s): 499 | return stoi[s] - 1 500 | 501 | 502 | def to_int(x): 503 | # print("\t", x) 504 | if isinstance(x, torch.Tensor) and x.numel() == 1: 505 | return to_int(x.item()) 506 | elif isinstance(x, list) or isinstance(x, torch.Tensor) or isinstance(x, np.ndarray): 507 | return [to_int(i) for i in x] 508 | elif isinstance(x, 
int): 509 | return stoi[x] 510 | elif isinstance(x, str): 511 | x = x.upper() 512 | return to_int(to_string(x)) 513 | 514 | 515 | def to_string(x): 516 | """Confusingly, maps it to an int, but a board pos label not a token label (token labels have 0 == pass, and middle board cells don't exist)""" 517 | # print("\t", x) 518 | if isinstance(x, torch.Tensor) and x.numel() == 1: 519 | return to_string(x.item()) 520 | elif isinstance(x, list) or isinstance(x, torch.Tensor) or isinstance(x, np.ndarray): 521 | return [to_string(i) for i in x] 522 | elif isinstance(x, int): 523 | return itos[x] 524 | elif isinstance(x, str): 525 | x = x.upper() 526 | return 8 * alpha.index(x[0]) + int(x[1]) 527 | 528 | 529 | def to_label(x, from_int=True): 530 | # print("\t", x) 531 | if isinstance(x, torch.Tensor) and x.numel() == 1: 532 | return to_label(x.item(), from_int=from_int) 533 | elif isinstance(x, list) or isinstance(x, torch.Tensor) or isinstance(x, np.ndarray): 534 | return [to_label(i, from_int=from_int) for i in x] 535 | elif isinstance(x, int): 536 | if from_int: 537 | return to_board_label(to_string(x)) 538 | else: 539 | return to_board_label(x) 540 | elif isinstance(x, str): 541 | return x 542 | 543 | 544 | int_to_label = to_label 545 | string_to_label = partial(to_label, from_int=False) 546 | str_to_label = string_to_label 547 | 548 | 549 | def moves_to_state(moves): 550 | # moves is a list of string entries (ints) 551 | state = np.zeros((8, 8), dtype=bool) 552 | for move in moves: 553 | state[move // 8, move % 8] = 1.0 554 | return state 555 | 556 | 557 | int_labels = ( 558 | list(range(1, 28)) + ["X", "X"] + list(range(28, 34)) + ["X", "X"] + list(range(34, 61)) 559 | ) 560 | 561 | # %% 562 | 563 | 564 | def get_valid_moves(sequence): 565 | if isinstance(sequence, torch.Tensor): 566 | sequence = sequence.tolist() 567 | board = OthelloBoardState() 568 | return board.get_gt(sequence, "get_valid_moves") 569 | 570 | 571 | # get_valid_moves(board_seqs_string[0]) 572 | # %% 
573 | def make_plot_state(board): 574 | state = np.copy(board.state).flatten() 575 | valid_moves = board.get_valid_moves() 576 | next_move = board.get_next_hand_color() 577 | # print(next_move, valid_moves) 578 | for move in valid_moves: 579 | state[move] = next_move - 0.5 580 | return state 581 | 582 | 583 | def add_counter(fig, position, color): 584 | is_black = color > 0 585 | row = position // 8 586 | col = position % 8 587 | fig.layout.shapes += ( 588 | dict( 589 | type="circle", 590 | x0=col - 0.2, 591 | y0=row - 0.2, 592 | x1=col + 0.2, 593 | y1=row + 0.2, 594 | fillcolor="black" if is_black else "white", 595 | line_color="green", 596 | line_width=0.5, 597 | ), 598 | ) 599 | return fig 600 | 601 | 602 | def counter_shape(position, color, mode="normal"): 603 | is_black = color > 0 604 | row = position // 8 605 | col = position % 8 606 | shape = dict( 607 | type="circle", 608 | fillcolor="black" if is_black else "white", 609 | ) 610 | if mode == "normal": 611 | shape.update( 612 | x0=col - 0.2, 613 | y0=row - 0.2, 614 | x1=col + 0.2, 615 | y1=row + 0.2, 616 | line_color="green", 617 | line_width=0.5, 618 | ) 619 | elif mode == "flipped": 620 | shape.update( 621 | x0=col - 0.22, 622 | y0=row - 0.22, 623 | x1=col + 0.22, 624 | y1=row + 0.22, 625 | line_color="purple", 626 | line_width=3, 627 | ) 628 | elif mode == "new": 629 | shape.update( 630 | line_color="red", 631 | line_width=4, 632 | x0=col - 0.25, 633 | y0=row - 0.25, 634 | x1=col + 0.25, 635 | y1=row + 0.25, 636 | ) 637 | return shape 638 | 639 | 640 | def plot_board(moves, return_fig=False): 641 | if isinstance(moves, torch.Tensor): 642 | moves = moves.tolist() 643 | if isinstance(moves[0], str): 644 | moves = to_string(moves) 645 | board = OthelloBoardState() 646 | states = [] 647 | states.append(make_plot_state(board)) 648 | for move in moves: 649 | board.umpire(move) 650 | states.append(make_plot_state(board)) 651 | states = np.stack(states, axis=0) 652 | fig = imshow( 653 | states.reshape(-1, 8, 8), 
654 | color_continuous_scale="Geyser", 655 | aspect="equal", 656 | return_fig=True, 657 | animation_frame=0, 658 | y=["a", "b", "c", "d", "e", "f", "g", "h"], 659 | x=["0", "1", "2", "3", "4", "5", "6", "7"], 660 | animation_index=[ 661 | f"{i+1} ({'W' if i%2==0 else 'B'}) [{to_board_label(moves[i]) if i>=0 else 'X'} -> {to_board_label(moves[i+1]) if i 0 688 | row = position // 8 689 | col = position % 8 690 | offset = 0.3 691 | fig.layout.shapes += ( 692 | dict( 693 | type="rect", 694 | x0=col - offset, 695 | y0=row - offset, 696 | x1=col + offset, 697 | y1=row + offset, 698 | line_color="black" if is_black else "red", 699 | line_width=5, 700 | fillcolor=None, 701 | ), 702 | ) 703 | return fig 704 | 705 | 706 | def plot_board_log_probs(moves, logits, return_fig=False, use_counters=False): 707 | logits = logits.squeeze(0) 708 | if isinstance(moves, torch.Tensor): 709 | moves = moves.tolist() 710 | if isinstance(moves[0], str): 711 | moves = to_string(moves) 712 | # print(moves) 713 | assert len(moves) == len(logits) 714 | board = OthelloBoardState() 715 | states = [] 716 | # states.append(make_plot_state(board)) 717 | for move in moves: 718 | board.umpire(move) 719 | states.append(make_plot_state(board)) 720 | states = np.stack(states, axis=0) 721 | 722 | log_probs = logits.log_softmax(dim=-1) 723 | log_probs_template = torch.zeros((len(moves), 64)).cuda() - 100 724 | if log_probs.shape[-1] == 61: 725 | log_probs_template[:, stoi_indices] = log_probs[:, 1:] 726 | else: 727 | log_probs_template[:, stoi_indices] = log_probs[:, :] 728 | log_probs_template = log_probs_template.reshape(-1, 8, 8) 729 | 730 | fig = imshow( 731 | log_probs_template, 732 | color_continuous_scale="Blues", 733 | zmin=-6.0, 734 | zmax=0.0, 735 | aspect="equal", 736 | return_fig=True, 737 | animation_frame=0, 738 | y=["a", "b", "c", "d", "e", "f", "g", "h"], 739 | x=["0", "1", "2", "3", "4", "5", "6", "7"], 740 | animation_index=[ 741 | f"{i+1} ({'W' if i%2==0 else 'B'}) 
[{to_board_label(moves[i])} -> {to_board_label(moves[i+1]) if i{counter_text}" 760 | elif states[c].flatten()[i] == -1: 761 | if use_counters: 762 | shapes.append(counter_shape(i, False)) 763 | else: 764 | # white = green 765 | text[-1] = f"{counter_text}" 766 | else: 767 | if states[c].flatten()[i] > 0.2: 768 | text[-1] = f"{to_board_label(i)}" 769 | # print(i, c, "b") 770 | # frame = add_ring(frame, i, True) 771 | elif states[c].flatten()[i] < -0.2: 772 | text[-1] = ( 773 | f"{to_board_label(i)}" 774 | ) 775 | # print(i, c, "w") 776 | # frame = add_ring(frame, i, False) 777 | frame.layout.shapes = tuple(shapes) 778 | frame.data[0]["text"] = np.array(text).reshape(8, 8) 779 | frame.data[0]["texttemplate"] = "%{text}" 780 | frame.data[0][ 781 | "hovertemplate" 782 | ] = "%{y}%{x}
<br>log prob: %{z}<br>
prob=%{customdata}" 783 | frame.data[0]["customdata"] = torch.to_numpy(log_probs_template[c].exp()) 784 | # print(states) 785 | fig.layout.shapes = fig.frames[0].layout.shapes 786 | fig.data[0]["text"] = fig.frames[0].data[0]["text"] 787 | fig.data[0]["texttemplate"] = fig.frames[0].data[0]["texttemplate"] 788 | fig.data[0]["customdata"] = fig.frames[0].data[0]["customdata"] 789 | fig.data[0]["hovertemplate"] = fig.frames[0].data[0]["hovertemplate"] 790 | if return_fig: 791 | return fig 792 | else: 793 | fig.show() 794 | 795 | 796 | def plot_single_board(moves, model=None, return_fig=False, title=None): 797 | # moves is a list of string entries (ints) 798 | if isinstance(moves, torch.Tensor): 799 | moves = moves.tolist() 800 | if isinstance(moves[0], str): 801 | moves = to_string(moves) 802 | board = OthelloBoardState() 803 | if len(moves) > 1: 804 | board.update(moves[:-1]) 805 | 806 | prev_state = np.copy(board.state) 807 | prev_player = board.next_hand_color 808 | prev_valid_moves = board.get_valid_moves() 809 | board.umpire(moves[-1]) 810 | next_state = np.copy(board.state) 811 | next_player = board.next_hand_color 812 | next_valid_moves = board.get_valid_moves() 813 | 814 | empty = (prev_state == 0) & (next_state == 0) 815 | new = (prev_state == 0) & (next_state != 0) 816 | flipped = (prev_state != 0) & (next_state != prev_state) & (~new) 817 | prev_valid = moves_to_state(prev_valid_moves) 818 | next_valid = moves_to_state(next_valid_moves) 819 | 820 | state = np.copy(next_state) 821 | state[flipped] *= 0.9 822 | state[prev_valid] = 0.1 * prev_player 823 | state[next_valid] = 0.5 * next_player 824 | state[new] = 0.9 * prev_player 825 | if model is not None: 826 | logits = model(torch.tensor(to_int(moves)).cuda().unsqueeze(0)).cpu() 827 | log_probs = logits.log_softmax(-1) 828 | lps = torch.zeros(64) - 15.0 829 | lps[stoi_indices] = log_probs[0, -1, 1:] 830 | 831 | if title is None: 832 | title = f"{'Black' if prev_player!=1 else 'White'} To Play. 
Board State After {'Black' if prev_player==1 else 'White'} Plays {to_label(moves[-1], from_int=False)} " 833 | 834 | fig = imshow( 835 | state, 836 | color_continuous_scale="Geyser", 837 | title=title, 838 | y=[i for i in alpha], 839 | x=[str(i) for i in range(8)], 840 | aspect="equal", 841 | return_fig=True, 842 | ) 843 | fig = fig.update_layout(title_x=0.5) 844 | fig.data[0]["hovertemplate"] = "%{y}%{x}
%{customdata}" 845 | 846 | shapes = [] 847 | texts = [] 848 | for i in range(64): 849 | texts.append("") 850 | if empty.flatten()[i]: 851 | texts[-1] = to_label(i, from_int=False) 852 | elif flipped.flatten()[i]: 853 | shapes.append(counter_shape(i, prev_player == 1, mode="flipped")) 854 | elif new.flatten()[i]: 855 | shapes.append(counter_shape(i, prev_player == 1, mode="new")) 856 | elif prev_state.flatten()[i] != 0: 857 | shapes.append(counter_shape(i, prev_state.flatten()[i] == 1, mode="normal")) 858 | else: 859 | raise ValueError(i) 860 | fig.layout.shapes = tuple(shapes) 861 | fig.data[0]["text"] = np.array(texts).reshape(8, 8) 862 | fig.data[0]["texttemplate"] = "%{text}" 863 | if model is not None: 864 | fig.data[0]["customdata"] = np.array( 865 | [f"LP:{lps[i].item():.4f}
<br>I:{int_labels[i]}<br>S:{i}" for i in range(64)] 866 | ).reshape(8, 8) 867 | else: 868 | fig.data[0]["customdata"] = np.array( 869 | [f"I:{int_labels[i]}<br>
S:{i}" for i in range(64)] 870 | ).reshape(8, 8) 871 | 872 | if return_fig: 873 | return fig 874 | else: 875 | fig.show() 876 | return 877 | -------------------------------------------------------------------------------- /othello_utils.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | from datasets import load_dataset 3 | from othello_engine_utils import OthelloBoardState, stoi, itos 4 | 5 | 6 | def board_state_to_RRC(board_state, flip: int = 1): 7 | board_state = t.tensor(board_state, dtype=t.int8) 8 | board_state *= flip 9 | one_hot = t.zeros((8, 8, 3), dtype=t.int8) 10 | one_hot[..., 0] = (board_state == -1).int() 11 | one_hot[..., 1] = (board_state == 0).int() 12 | one_hot[..., 2] = (board_state == 1).int() 13 | return one_hot 14 | 15 | 16 | # TODO Remove duplicated logic from these functions 17 | def games_batch_to_state_stack_BLRRC(batch_str_moves): 18 | """Sequences of moves (dataset format) to state stack (one-hot) of shape (seq_len, 8, 8, 3)""" 19 | game_stack = [] 20 | for game in batch_str_moves: 21 | if isinstance(game, t.Tensor): 22 | game = game.flatten() 23 | 24 | board = OthelloBoardState() 25 | states = [] 26 | for move in game: 27 | board.umpire(move) 28 | one_hot = board_state_to_RRC(board.state) 29 | states.append(one_hot) 30 | states = t.stack(states, axis=0) 31 | game_stack.append(states) 32 | return t.stack(game_stack, axis=0) 33 | 34 | 35 | def games_batch_to_valid_moves_BLRRC(batch_str_moves): 36 | """Sequences of moves (dataset format) to state stack of valid moves""" 37 | game_stack = [] 38 | for game in batch_str_moves: 39 | if isinstance(game, t.Tensor): 40 | game = game.flatten() 41 | 42 | board = OthelloBoardState() 43 | states = [] 44 | for i, move in enumerate(game): 45 | moves_board = t.zeros(8, 8, 1, dtype=t.int8) 46 | board.umpire(move) 47 | valid_moves_list = board.get_valid_moves() 48 | for move in valid_moves_list: 49 | moves_board[move // 8, move % 8] = 1 50 | 
states.append(moves_board) 51 | states = t.stack(states, axis=0) 52 | game_stack.append(states) 53 | return t.stack(game_stack, axis=0) 54 | 55 | 56 | def games_batch_to_state_stack_mine_yours_BLRRC(batch_str_moves): 57 | """Sequences of moves (dataset format) to a one-hot state stack of shape (batch, seq_len, 8, 8, 3), with the board sign-flipped after every odd-indexed move (the "mine / yours" encoding)""" 58 | game_stack = [] 59 | for game in batch_str_moves: 60 | if isinstance(game, t.Tensor): 61 | game = game.flatten() 62 | 63 | board = OthelloBoardState() 64 | states = [] 65 | for i, move in enumerate(game): 66 | flip = 1 67 | if i % 2 == 1: 68 | flip = -1 69 | board.umpire(move) 70 | one_hot = board_state_to_RRC(board.state, flip) 71 | states.append(one_hot) 72 | states = t.stack(states, axis=0) 73 | game_stack.append(states) 74 | return t.stack(game_stack, axis=0) 75 | 76 | 77 | othello_functions = [ 78 | games_batch_to_state_stack_BLRRC.__name__, 79 | games_batch_to_state_stack_mine_yours_BLRRC.__name__, 80 | games_batch_to_valid_moves_BLRRC.__name__, 81 | ] 82 | 83 | 84 | def get_othello_even_list_indices(tokens_list: list[int]) -> list[int]: 85 | """Return the even indices (0, 2, 4, ...) of tokens_list""" 86 | max_len = len(tokens_list) 87 | return [i for i in range(max_len) if i % 2 == 0] 88 | 89 | 90 | def get_othello_all_list_indices(tokens_list: list[int]) -> list[int]: 91 | """Return every index of tokens_list""" 92 | max_len = len(tokens_list) 93 | return [i for i in range(max_len)] 94 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformer_lens == 1.10.0 2 | tqdm == 4.66.1 3 | torch == 2.2.0 # transformer lens cannot use 2.0.0 or 2.1.0 4 | # I have successfully used both torch 2.1.1 and 2.2.0 5 | jaxtyping == 0.2.25 6 | beartype == 0.14.1 7 | wandb == 0.16.0 8 | fancy_einsum == 0.0.3 9 | einops == 0.7.0 10 | numpy == 1.26.0 11 | python-chess == 1.999 12 | pandas == 2.1.1 13 | plotly == 5.18.0 14 | matplotlib == 3.8.0 15 | nbformat == 5.9.2 16 | pytest == 8.1.1 
-------------------------------------------------------------------------------- /tests/test_board_interventions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Not the ideal way of doing things, but it works. This way all test functions can pull models / probes / data from the expected location 5 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(parent_dir) 7 | 8 | import board_state_interventions 9 | 10 | # TOLERANCE is fairly high because we are testing on a small amount of games for efficiency, won't converge to the expected value 11 | # This is more of a smoke test than anything 12 | TOLERANCE = 0.2 13 | 14 | 15 | def test_single_layer_interventions(): 16 | intervention_type = board_state_interventions.InterventionType.SINGLE_SCALE 17 | 18 | num_games = 2 19 | 20 | probe_names = {} 21 | first_layer = 5 22 | last_layer = 5 23 | GPT_LAYER_COUNT = 8 24 | 25 | for i in range(first_layer, last_layer + 1, 1): 26 | probe_names[i] = ( 27 | f"tf_lens_lichess_{GPT_LAYER_COUNT}layers_ckpt_no_optimizer_chess_piece_probe_layer_{i}.pth" 28 | ) 29 | probe_data = board_state_interventions.get_probe_data(probe_names[first_layer], num_games) 30 | 31 | piece_coe = 1.0 32 | blank_coe = 0.0 33 | 34 | scales = [3.0] 35 | 36 | recording_name = f"TEST_ONLY_n_layers={GPT_LAYER_COUNT}_intervention_type={intervention_type.value}_first_layer={first_layer}_last_layer={last_layer}_p={piece_coe}_b={blank_coe}_scales=" 37 | for scale in scales: 38 | recording_name += f"{str(scale).replace('.', '')[:5]}_" 39 | 40 | print(f"Recording name: {recording_name}") 41 | 42 | success_rate = board_state_interventions.perform_board_interventions( 43 | probe_names, 44 | probe_data, 45 | num_games, 46 | intervention_type, 47 | recording_name, 48 | track_outputs=False, 49 | scales=scales, 50 | ) 51 | 52 | expected_success_rate = 0.8 53 | 54 | assert abs(success_rate - 
expected_success_rate) < TOLERANCE, f"Success rate mismatch" 55 | -------------------------------------------------------------------------------- /tests/test_caa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Not the ideal way of doing things, but it works. This way all test functions can pull models / probes / data from the expected location 5 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(parent_dir) 7 | 8 | import caa 9 | import train_test_chess 10 | import chess_utils 11 | import torch 12 | 13 | MAXIMUM_TRAINING_GAMES = 500 14 | DATA_DIR = "data/" 15 | 16 | device = ( 17 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 18 | ) 19 | 20 | 21 | # Just a smoke test 22 | def test_caa(): 23 | 24 | config = chess_utils.skill_config 25 | 26 | caa_type = "simple" 27 | 28 | previous_layer_activations = {} 29 | 30 | dataset_prefix = "lichess_" 31 | layer = 5 32 | split = "train" 33 | n_layers = 16 34 | model_name = f"tf_lens_{dataset_prefix}{n_layers}layers_ckpt_no_optimizer" 35 | config.levels_of_interest = [0, 5] 36 | input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv" 37 | config = chess_utils.set_config_min_max_vals_and_column_name( 38 | config, input_dataframe_file, dataset_prefix 39 | ) 40 | config.pos_start = 25 41 | 42 | probe_data = train_test_chess.construct_linear_probe_data( 43 | input_dataframe_file, 44 | dataset_prefix, 45 | n_layers, 46 | model_name, 47 | config, 48 | MAXIMUM_TRAINING_GAMES, 49 | device, 50 | ) 51 | 52 | cascade_layers = "" 53 | 54 | activation_name = f"TEST_ONLY_type=caa_{caa_type}{cascade_layers}_model={n_layers}layers_layer={layer}_activations" 55 | 56 | logging_dict = train_test_chess.init_logging_dict( 57 | layer, 58 | config, 59 | split, 60 | dataset_prefix, 61 | model_name, 62 | n_layers, 63 | train_test_chess.TRAIN_PARAMS, 64 | ) 65 | 66 | 
previous_layer_activations[layer] = caa.create_contrastive_activations( 67 | activation_name, probe_data, config, logging_dict, layer, MAXIMUM_TRAINING_GAMES 68 | ) 69 | -------------------------------------------------------------------------------- /tests/test_chess_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Not the ideal way of doing things, but it works. This way all test functions can pull models / probes / data from the expected location 5 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(parent_dir) 7 | 8 | import chess_utils 9 | import chess 10 | import torch 11 | 12 | 13 | def test_white_pos_indices(): 14 | test1 = ";1.e4 c5 2.Nf3 d6 3" 15 | test2 = ";1.e4 c5 2.Nf3 d" 16 | test3 = ";1." 17 | 18 | ans1 = [[0, 1, 2, 3, 4], [8, 9, 10, 11, 12, 13], [17, 18]] 19 | ans2 = [[0, 1, 2, 3, 4], [8, 9, 10, 11, 12, 13]] 20 | ans3 = [[0, 1, 2]] 21 | 22 | assert chess_utils.get_all_white_pos_indices(test1) == ans1 23 | assert chess_utils.get_all_white_pos_indices(test2) == ans2 24 | assert chess_utils.get_all_white_pos_indices(test3) == ans3 25 | 26 | 27 | def test_black_pos_indices(): 28 | test1 = ";1.e4 c5 2.Nf3 d6 3" 29 | test2 = ";1.e4 c5 2.Nf3 d" 30 | test3 = ";1." 
31 | 32 | ans1 = [[5, 6, 7], [14, 15, 16]] 33 | ans2 = [[5, 6, 7], [14, 15]] 34 | ans3 = [] 35 | 36 | assert chess_utils.get_all_black_pos_indices(test1) == ans1 37 | assert chess_utils.get_all_black_pos_indices(test2) == ans2 38 | assert chess_utils.get_all_black_pos_indices(test3) == ans3 39 | 40 | 41 | def test_board_to_piece_state(): 42 | 43 | test_str = ";1.e4 e5 2.Nf3" 44 | board = chess_utils.pgn_string_to_board(test_str) 45 | state = chess_utils.board_to_piece_state(board) 46 | 47 | expected_state = torch.tensor( 48 | [ 49 | [4, 2, 3, 5, 6, 3, 0, 4], 50 | [1, 1, 1, 1, 0, 1, 1, 1], 51 | [0, 0, 0, 0, 0, 2, 0, 0], 52 | [0, 0, 0, 0, 1, 0, 0, 0], 53 | [0, 0, 0, 0, -1, 0, 0, 0], 54 | [0, 0, 0, 0, 0, 0, 0, 0], 55 | [-1, -1, -1, -1, 0, -1, -1, -1], 56 | [-4, -2, -3, -5, -6, -3, -2, -4], 57 | ], 58 | dtype=torch.int, 59 | ) 60 | 61 | assert torch.equal(state, expected_state) 62 | -------------------------------------------------------------------------------- /tests/test_probe_training_and_eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Not the ideal way of doing things, but it works. 
This way all test functions can pull models / probes / data from the expected location 5 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(parent_dir) 7 | 8 | import train_test_chess 9 | import torch 10 | import chess_utils 11 | 12 | DATA_DIR = "data/" 13 | 14 | DEVICE = ( 15 | "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 16 | ) 17 | 18 | TRAIN_PARAMS = train_test_chess.TrainingParams() 19 | TRAIN_PARAMS.max_iters = 10000 20 | TRAIN_PARAMS.num_epochs = 2 21 | TRAIN_PARAMS.max_train_games = 5000 22 | TRAIN_PARAMS.max_test_games = 2000 23 | 24 | # TRAIN_EPSILON is fairly large because the probes haven't converged and are trained on random permutations of the training data 25 | TRAIN_EPSILON = 0.03 26 | TEST_EPSILON = 0.002 27 | # IMPORTANT: You must train probes on this model or the asserts will fail 28 | # https://huggingface.co/adamkarvonen/chess_llms/blob/main/lichess_8layers_ckpt_no_optimizer.pt 29 | # At a small epsilon and only 10k iters, the probes haven't converged so it will be sensitive to changes in the model 30 | # But at least this way testing the training code only takes 10 minutes 31 | 32 | 33 | def test_piece_train_linear_probe_cross_entropy(): 34 | torch.set_grad_enabled(True) 35 | config = chess_utils.piece_config 36 | first_layer = 0 37 | last_layer = 7 38 | 39 | dataset_prefix = "lichess_" 40 | split = "train" 41 | n_layers = 8 42 | model_name = f"tf_lens_{dataset_prefix}{n_layers}layers_ckpt_no_optimizer" 43 | 44 | input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv" 45 | config = chess_utils.set_config_min_max_vals_and_column_name( 46 | config, input_dataframe_file, dataset_prefix 47 | ) 48 | 49 | max_games = TRAIN_PARAMS.max_train_games + TRAIN_PARAMS.max_val_games 50 | 51 | probe_data = train_test_chess.construct_linear_probe_data( 52 | input_dataframe_file, 53 | dataset_prefix, 54 | n_layers, 55 | model_name, 56 | config, 57 | max_games, 
58 | DEVICE, 59 | ) 60 | 61 | probes = train_test_chess.populate_probes_dict( 62 | list(range(first_layer, last_layer + 1)), 63 | config, 64 | TRAIN_PARAMS, 65 | split, 66 | dataset_prefix, 67 | model_name, 68 | n_layers, 69 | ) 70 | 71 | for probe in probes: 72 | probes[probe].probe_name = probes[probe].probe_name.replace("tf_lens", "TEST_ONLY_tf_lens") 73 | 74 | final_accs = train_test_chess.train_linear_probe_cross_entropy( 75 | probes, probe_data, config, TRAIN_PARAMS 76 | ) 77 | 78 | expected_final_accs = { 79 | 0: 0.745, 80 | 1: 0.746, 81 | 2: 0.768, 82 | 3: 0.803, 83 | 4: 0.863, 84 | 5: 0.981, 85 | 6: 0.978, 86 | 7: 0.961, 87 | } 88 | 89 | for layer in range(first_layer, last_layer + 1): 90 | assert ( 91 | abs(final_accs[layer] - expected_final_accs[layer]) < TRAIN_EPSILON 92 | ), f"Accuracy mismatch for layer {layer}" 93 | 94 | 95 | def test_skill_train_linear_probe_cross_entropy(): 96 | torch.set_grad_enabled(True) 97 | config = chess_utils.skill_config 98 | first_layer = 0 99 | last_layer = 7 100 | 101 | dataset_prefix = "lichess_" 102 | split = "train" 103 | n_layers = 8 104 | model_name = f"tf_lens_{dataset_prefix}{n_layers}layers_ckpt_no_optimizer" 105 | 106 | input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv" 107 | config = chess_utils.set_config_min_max_vals_and_column_name( 108 | config, input_dataframe_file, dataset_prefix 109 | ) 110 | 111 | max_games = TRAIN_PARAMS.max_train_games + TRAIN_PARAMS.max_val_games 112 | 113 | probe_data = train_test_chess.construct_linear_probe_data( 114 | input_dataframe_file, 115 | dataset_prefix, 116 | n_layers, 117 | model_name, 118 | config, 119 | max_games, 120 | DEVICE, 121 | ) 122 | 123 | probes = train_test_chess.populate_probes_dict( 124 | list(range(first_layer, last_layer + 1)), 125 | config, 126 | TRAIN_PARAMS, 127 | split, 128 | dataset_prefix, 129 | model_name, 130 | n_layers, 131 | ) 132 | 133 | for probe in probes: 134 | probes[probe].probe_name = 
probes[probe].probe_name.replace("tf_lens", "TEST_ONLY_tf_lens") 135 | 136 | final_accs = train_test_chess.train_linear_probe_cross_entropy( 137 | probes, probe_data, config, TRAIN_PARAMS 138 | ) 139 | 140 | expected_final_accs = { 141 | 0: 0.645, 142 | 1: 0.663, 143 | 2: 0.641, 144 | 3: 0.709, 145 | 4: 0.820, 146 | 5: 0.838, 147 | 6: 0.876, 148 | 7: 0.878, 149 | } 150 | 151 | for layer in range(first_layer, last_layer + 1): 152 | assert ( 153 | abs(final_accs[layer] - expected_final_accs[layer]) < TRAIN_EPSILON 154 | ), f"Accuracy mismatch for layer {layer}" 155 | 156 | 157 | def test_linear_probe_cross_entropy_test(): 158 | 159 | expected_results = { 160 | "tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_5.pth": 0.9907, 161 | "tf_lens_lichess_8layers_ckpt_no_optimizer_chess_skill_probe_layer_7.pth": 0.8856, 162 | } 163 | 164 | print(expected_results) 165 | for probe_to_test in expected_results.keys(): 166 | probe_file_location = f"{train_test_chess.SAVED_PROBE_DIR}{probe_to_test}" 167 | # We will populate all parameters using information in the probe state dict 168 | with open(probe_file_location, "rb") as f: 169 | state_dict = torch.load(f, map_location=torch.device(DEVICE)) 170 | print(state_dict.keys()) 171 | for key in state_dict.keys(): 172 | if key != "linear_probe": 173 | print(key, state_dict[key]) 174 | 175 | config = chess_utils.find_config_by_name(state_dict["config_name"]) 176 | layer = state_dict["layer"] 177 | model_name = state_dict["model_name"] 178 | dataset_prefix = state_dict["dataset_prefix"] 179 | config.pos_start = state_dict["pos_start"] 180 | levels_of_interest = None 181 | if "levels_of_interest" in state_dict.keys(): 182 | levels_of_interest = state_dict["levels_of_interest"] 183 | config.levels_of_interest = levels_of_interest 184 | n_layers = state_dict["n_layers"] 185 | split = "test" 186 | 187 | input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv" 188 | config = 
chess_utils.set_config_min_max_vals_and_column_name( 189 | config, input_dataframe_file, dataset_prefix 190 | ) 191 | 192 | probe_data = train_test_chess.construct_linear_probe_data( 193 | input_dataframe_file, 194 | dataset_prefix, 195 | n_layers, 196 | model_name, 197 | config, 198 | TRAIN_PARAMS.max_test_games, 199 | DEVICE, 200 | ) 201 | 202 | logging_dict = train_test_chess.init_logging_dict( 203 | layer, config, split, dataset_prefix, model_name, n_layers, TRAIN_PARAMS 204 | ) 205 | 206 | result = train_test_chess.test_linear_probe_cross_entropy( 207 | probe_file_location, probe_data, config, logging_dict, TRAIN_PARAMS 208 | ) 209 | 210 | assert ( 211 | abs(result - expected_results[probe_to_test]) < TEST_EPSILON 212 | ), f"Accuracy mismatch for probe {probe_to_test}" 213 | -------------------------------------------------------------------------------- /utils/board_grid_search_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [] 7 | }, 8 | { 9 | "cell_type": "code", 10 | "execution_count": null, 11 | "metadata": {}, 12 | "outputs": [], 13 | "source": [ 14 | "import os\n", 15 | "GRID_SEARCH_DIR = \"intervention_logs\"\n", 16 | "\n", 17 | "for file in os.listdir(GRID_SEARCH_DIR):\n", 18 | " print(file)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import pandas as pd\n", 28 | "import os\n", 29 | "import numpy as np\n", 30 | "import json\n", 31 | "import re\n", 32 | "\n", 33 | "GRID_SEARCH_DIR = \"intervention_logs\"\n", 34 | "\n", 35 | "def extract_info_from_filename(filename: str) -> dict:\n", 36 | " \"\"\"\n", 37 | " Extracts information from the filename using regular expressions.\n", 38 | "\n", 39 | " Parameters:\n", 40 | " filename (str): The filename from which to extract the information.\n", 41 | "\n", 42 | " Returns:\n", 
43 | " dict: A dictionary containing the extracted information.\n", 44 | " \"\"\"\n", 45 | " # Define the regex pattern to extract the needed information\n", 46 | " pattern = r\"sampling=both_intervention_type=(?P\\w+)_first_layer=(?P\\d+)_last_layer=(?P\\d+)_p=(?P
<p>\\d+\\.\\d+)_b=(?P<b>\\d+\\.\\d+)_scales=\"\n", 47 | " \n", 48 | " # Use the regex search function to find matches\n", 49 | " match = re.search(pattern, filename)\n", 50 | "\n", 51 | " # If a match is found, extract the information\n", 52 | " if match:\n", 53 | " info = match.groupdict()\n", 54 | " # Convert numeric values from strings to their appropriate types\n", 55 | " info['first_layer'] = int(info['first_layer'])\n", 56 | " info['last_layer'] = int(info['last_layer'])\n", 57 | " info['p'] = float(info['p'])\n", 58 | " info['b'] = float(info['b'])\n", 59 | " return info\n", 60 | " else:\n", 61 | " return {\"error\": \"No match found\"}\n", 62 | "\n", 63 | "# List all JSON files\n", 64 | "json_files = [f for f in os.listdir(GRID_SEARCH_DIR) if f.endswith('.json')]\n", 65 | "\n", 66 | "# Dictionary to hold file names and their average scores\n", 67 | "average_scores_dict = {}\n", 68 | "\n", 69 | "for file in json_files:\n", 70 | " file_info = extract_info_from_filename(file)\n", 71 | " if \"error\" in file_info:\n", 72 | " continue\n", 73 | " # if file_info[\"first_layer\"] != file_info[\"last_layer\"]:\n", 74 | " # continue\n", 75 | "\n", 76 | " with open(f\"{GRID_SEARCH_DIR}/{file}\", \"r\") as f:\n", 77 | " data = json.load(f)\n", 78 | " if \"possible_sampled_moves\" not in data:\n", 79 | " continue\n", 80 | " average_scores_dict[file] = {}\n", 81 | " average_scores_dict[file][\"possible_sampled_moves\"] = data[\"possible_sampled_moves\"]\n", 82 | " for key in data:\n", 83 | " try:\n", 84 | " # Try to convert the key to float\n", 85 | " float_key = float(key)\n", 86 | " # If conversion is successful, add the key to the dictionary\n", 87 | " average_scores_dict[file][key] = data[key]\n", 88 | " except ValueError:\n", 89 | " # If conversion fails, the key is not a float\n", 90 | " pass\n", 91 | "\n", 92 | "sampled_ratio_list = []\n", 93 | "for file in average_scores_dict:\n", 94 | " max_possible = average_scores_dict[file][\"possible_sampled_moves\"]\n", 95 | " for 
scale in average_scores_dict[file]:\n", 96 | " if scale == \"possible_sampled_moves\":\n", 97 | " continue\n", 98 | " legal_sampled = average_scores_dict[file][scale][\"mod_board_sampled_legal_total\"]\n", 99 | " sampled_ratio = legal_sampled / max_possible\n", 100 | " average_scores_dict[file][scale][\"sampled_ratio\"] = sampled_ratio\n", 101 | " sampled_ratio_list.append((file, scale, sampled_ratio))\n", 102 | "\n", 103 | "\n", 104 | "\n", 105 | "\n", 106 | "# Sort the files by average score\n", 107 | "sorted_sampled_ratio_list = sorted(sampled_ratio_list, key=lambda x: x[2], reverse=True)\n", 108 | "\n", 109 | "# Print sorted files and their scores\n", 110 | "for file, scale, score in sorted_sampled_ratio_list:\n", 111 | " # if \"gradient\" not in file:\n", 112 | " # continue\n", 113 | " # if scale != \"3.0\":\n", 114 | " # continue\n", 115 | " print(f\"{file}, {scale}, {score}\")\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "import pandas as pd\n", 125 | "import os\n", 126 | "import numpy as np\n", 127 | "import json\n", 128 | "import re\n", 129 | "\n", 130 | "GRID_SEARCH_DIR = \"intervention_logs\"\n", 131 | "\n", 132 | "def extract_info_from_filename(filename: str) -> dict:\n", 133 | " \"\"\"\n", 134 | " Extracts information from the filename using regular expressions.\n", 135 | "\n", 136 | " Parameters:\n", 137 | " filename (str): The filename from which to extract the information.\n", 138 | "\n", 139 | " Returns:\n", 140 | " dict: A dictionary containing the extracted information.\n", 141 | " \"\"\"\n", 142 | " # Define the regex pattern to extract the needed information\n", 143 | " pattern = r\"sampling=both_n_layers=8_intervention_type=(?P\\w+)_first_layer=(?P\\d+)_last_layer=(?P\\d+)_p=(?P
<p>\\d+\\.\\d+)_b=(?P<b>\\d+\\.\\d+)_iters=(?P\\d+)_scales=\"\n", 144 | " \n", 145 | " # Use the regex search function to find matches\n", 146 | " match = re.search(pattern, filename)\n", 147 | "\n", 148 | " # If a match is found, extract the information\n", 149 | " if match:\n", 150 | " info = match.groupdict()\n", 151 | " # Convert numeric values from strings to their appropriate types\n", 152 | " info['first_layer'] = int(info['first_layer'])\n", 153 | " info['last_layer'] = int(info['last_layer'])\n", 154 | " info['p'] = float(info['p'])\n", 155 | " info['b'] = float(info['b'])\n", 156 | " return info\n", 157 | " else:\n", 158 | " return {\"error\": \"No match found\"}\n", 159 | "\n", 160 | "# List all JSON files\n", 161 | "json_files = [f for f in os.listdir(GRID_SEARCH_DIR) if f.endswith('.json')]\n", 162 | "\n", 163 | "# Dictionary to hold file names and their average scores\n", 164 | "average_scores_dict = {}\n", 165 | "\n", 166 | "for file in json_files:\n", 167 | " file_info = extract_info_from_filename(file)\n", 168 | " if \"error\" in file_info:\n", 169 | " continue\n", 170 | " # if file_info[\"first_layer\"] != file_info[\"last_layer\"]:\n", 171 | " # continue\n", 172 | "\n", 173 | " with open(f\"{GRID_SEARCH_DIR}/{file}\", \"r\") as f:\n", 174 | " data = json.load(f)\n", 175 | " if \"possible_sampled_moves\" not in data:\n", 176 | " continue\n", 177 | " average_scores_dict[file] = {}\n", 178 | " average_scores_dict[file][\"possible_sampled_moves\"] = data[\"possible_sampled_moves\"]\n", 179 | " for key in data:\n", 180 | " try:\n", 181 | " # Try to convert the key to float\n", 182 | " float_key = float(key)\n", 183 | " # If conversion is successful, add the key to the dictionary\n", 184 | " average_scores_dict[file][key] = data[key]\n", 185 | " except ValueError:\n", 186 | " # If conversion fails, the key is not a float\n", 187 | " pass\n", 188 | "\n", 189 | "sampled_ratio_list = []\n", 190 | "for file in average_scores_dict:\n", 191 | " max_possible = 
average_scores_dict[file][\"possible_sampled_moves\"]\n", 192 | " for scale in average_scores_dict[file]:\n", 193 | " if scale == \"possible_sampled_moves\":\n", 194 | " continue\n", 195 | " legal_sampled = average_scores_dict[file][scale][\"mod_board_sampled_legal_total\"]\n", 196 | " sampled_ratio = legal_sampled / max_possible\n", 197 | " average_scores_dict[file][scale][\"sampled_ratio\"] = sampled_ratio\n", 198 | " sampled_ratio_list.append((file, scale, sampled_ratio))\n", 199 | "\n", 200 | "\n", 201 | "\n", 202 | "\n", 203 | "# Sort the files by average score\n", 204 | "sorted_sampled_ratio_list = sorted(sampled_ratio_list, key=lambda x: x[2], reverse=True)\n", 205 | "\n", 206 | "# Print sorted files and their scores\n", 207 | "for file, scale, score in sorted_sampled_ratio_list:\n", 208 | " # if \"gradient\" not in file:\n", 209 | " # continue\n", 210 | " # if scale != \"3.0\":\n", 211 | " # continue\n", 212 | " print(f\"{file}, {scale}, {score}\")\n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "import pandas as pd\n", 222 | "import os\n", 223 | "import numpy as np\n", 224 | "from scipy.stats import sem, t\n", 225 | "\n", 226 | "def compute_average_score_and_confidence_interval(file_path: str, player: str = \"player_one\", confidence=0.95) -> dict:\n", 227 | " file_path = os.path.join(GRID_SEARCH_DIR, file_path)\n", 228 | " df = pd.read_csv(file_path)\n", 229 | "\n", 230 | " # Filter rows where 'Stockfish -1' is in 'game_title'\n", 231 | " # df = df[df['game_title'].str.contains(\"Stockfish 0\")]\n", 232 | " df = df[~df['game_title'].str.contains(\"Stockfish -1\")]\n", 233 | "\n", 234 | " # df = df[df['player_one_failed_to_find_legal_move'] == False]\n", 235 | "\n", 236 | " total_moves = len(df)\n", 237 | " successful_moves = df[df[f'{player}_failed_to_find_legal_move'] == False].shape[0]\n", 238 | " percentage_successful = (successful_moves / 
total_moves) * 100\n", 239 | "\n", 240 | " df[f\"{player}_score\"] = pd.to_numeric(df[f\"{player}_score\"], errors=\"coerce\")\n", 241 | "\n", 242 | " # Compute overall average score and confidence interval\n", 243 | " scores = df[f\"{player}_score\"].dropna()\n", 244 | " if len(scores) > 1:\n", 245 | " mean_score = np.mean(scores)\n", 246 | " std_err = sem(scores)\n", 247 | " h = std_err * t.ppf((1 + confidence) / 2, len(scores) - 1)\n", 248 | " result = {\"mean_score\": mean_score, \"confidence_interval\": (mean_score - h, mean_score + h)}\n", 249 | " else:\n", 250 | " result = {\"mean_score\": scores.iloc[0], \"confidence_interval\": (np.nan, np.nan)}\n", 251 | "\n", 252 | " result[\"percentage_successful\"] = percentage_successful\n", 253 | " result[\"total_moves\"] = total_moves\n", 254 | "\n", 255 | " # print(result)\n", 256 | "\n", 257 | " return result\n", 258 | "\n", 259 | " # # Compute average scores and confidence intervals\n", 260 | " # results = {}\n", 261 | " # for game_title, group in df.groupby(\"game_title\"):\n", 262 | " # scores = group[f\"{player}_score\"].dropna()\n", 263 | " # if len(scores) > 1:\n", 264 | " # mean_score = np.mean(scores)\n", 265 | " # std_err = sem(scores)\n", 266 | " # h = std_err * t.ppf((1 + confidence) / 2, len(scores) - 1)\n", 267 | " # results[game_title] = {\"mean_score\": mean_score, \"confidence_interval\": (mean_score - h, mean_score + h)}\n", 268 | " # else:\n", 269 | " # results[game_title] = {\"mean_score\": scores.iloc[0], \"confidence_interval\": (np.nan, np.nan)}\n", 270 | "\n", 271 | " # print(results)\n", 272 | " # return results\n", 273 | "\n", 274 | "# List all CSV files\n", 275 | "json_files = [f for f in os.listdir(GRID_SEARCH_DIR) if f.endswith('.csv')]\n", 276 | "\n", 277 | "# Dictionary to hold file names and their average scores\n", 278 | "results_dict = {}\n", 279 | "average_scores_dict = {}\n", 280 | "\n", 281 | "for file in json_files:\n", 282 | " try:\n", 283 | " # if \"10_random_moves\" not in 
file:\n", 284 | " # continue\n", 285 | " # if \"0_1_coefficient\" in file:\n", 286 | " # continue\n", 287 | " if \"levels_14\" in file or \"levels_15\" in file or \"levels_04\" in file:\n", 288 | " continue\n", 289 | " if \"200k\" in file:\n", 290 | " continue\n", 291 | " if \"pos_start_32\" in file:\n", 292 | " continue\n", 293 | " if \"20000_moves\" in file:\n", 294 | " continue\n", 295 | " # if \"layer_8\" in file:\n", 296 | " # continue\n", 297 | " # if \"8layers\" not in file and \"layers_8\" not in file:\n", 298 | " if \"8layers\" in file or \"layers_8\" in file:\n", 299 | " continue\n", 300 | " result = compute_average_score_and_confidence_interval(file)\n", 301 | " average_scores_dict[file] = result[\"mean_score\"]\n", 302 | " results_dict[file] = result\n", 303 | " except Exception as e:\n", 304 | " print(f\"Error processing file {file}: {e}\")\n", 305 | "\n", 306 | "# Sort the files by average score\n", 307 | "sorted_files = sorted(average_scores_dict.items(), key=lambda x: x[1], reverse=True)\n", 308 | "\n", 309 | "# Print sorted files and their scores\n", 310 | "for file, score in sorted_files:\n", 311 | " if \"10_random_moves\" not in file:\n", 312 | " continue\n", 313 | " confidence_interval = results_dict[file][\"confidence_interval\"]\n", 314 | " total_moves = results_dict[file][\"total_moves\"]\n", 315 | " percentage_successful = results_dict[file][\"percentage_successful\"]\n", 316 | " # if file.startswith(\"lichess_train\"):\n", 317 | " # file = file.replace(\"lichess_train\", \"lichess_000k_bins_train\")\n", 318 | " # if total_moves <= 100:\n", 319 | " # continue\n", 320 | "\n", 321 | " print(f\"{file}: {score}, {total_moves}, {percentage_successful}, {confidence_interval}\")\n", 322 | "\n", 323 | "print(\"\\n\\nBegin 0 random moves\\n\\n\")\n", 324 | "# Print sorted files and their scores\n", 325 | "for file, score in sorted_files:\n", 326 | " if \"_0_random_moves\" not in file:\n", 327 | " continue\n", 328 | " confidence_interval = 
results_dict[file][\"confidence_interval\"]\n", 329 | " total_moves = results_dict[file][\"total_moves\"]\n", 330 | " percentage_successful = results_dict[file][\"percentage_successful\"]\n", 331 | " # if file.startswith(\"lichess_train\"):\n", 332 | " # file = file.replace(\"lichess_train\", \"lichess_000k_bins_train\")\n", 333 | " # if total_moves <= 100:\n", 334 | " # continue\n", 335 | "\n", 336 | " print(f\"{file}: {score}, {total_moves}, {percentage_successful}, {confidence_interval}\")\n" 337 | ] 338 | } 339 | ], 340 | "metadata": { 341 | "kernelspec": { 342 | "display_name": "openai", 343 | "language": "python", 344 | "name": "python3" 345 | }, 346 | "language_info": { 347 | "codemirror_mode": { 348 | "name": "ipython", 349 | "version": 3 350 | }, 351 | "file_extension": ".py", 352 | "mimetype": "text/x-python", 353 | "name": "python", 354 | "nbconvert_exporter": "python", 355 | "pygments_lexer": "ipython3", 356 | "version": "3.11.7" 357 | } 358 | }, 359 | "nbformat": 4, 360 | "nbformat_minor": 2 361 | } 362 | -------------------------------------------------------------------------------- /utils/chess_gpt_eval_data_filtering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import re" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "I use this notebook for manipulating the outputs of my chess_gpt_eval repository and doing various experiments with it. These games are generally outputs of playing stockfish vs stockfish or Chess-GPT against stockfish. For standard uses, you shouldn't need to use this notebook." 
18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "DATA_DIR = \"data/\"\n", 27 | "prefix = \"rand_test_2\"\n", 28 | "\n", 29 | "input_file = f'{DATA_DIR}{prefix}.csv'\n", 30 | "output_file = f'{DATA_DIR}filtered_{prefix}.csv'" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "df = pd.read_csv(input_file)\n", 40 | "grouped = df.groupby('player_two')\n", 41 | "\n", 42 | "\n", 43 | "def format_transcript(game: str) -> str:\n", 44 | " new_game = ';' + game.split('\\n\\n')[1]\n", 45 | " new_game = re.sub(r\"(\\d+\\.) \", r\"\\1\", new_game)\n", 46 | " return new_game\n", 47 | "\n", 48 | "def format_player_name(name: str) -> str:\n", 49 | " \"\"\"This will go from e.g. \"Stockfish 0\" to \"0\".\"\"\"\n", 50 | " return name.split(' ')[1]\n", 51 | "\n", 52 | "\n", 53 | "df['transcript'] = df['transcript'].apply(format_transcript)\n", 54 | "df['player_two'] = df['player_two'].apply(format_player_name)\n", 55 | "\n", 56 | "for game in df.head()['transcript']:\n", 57 | " print(game)\n", 58 | " print()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "len_df = df['transcript'].apply(lambda x: len(x))\n", 68 | "print(len_df.describe())\n", 69 | "\n", 70 | "game_length_in_chars = 356\n", 71 | "\n", 72 | "# Data setup. All games must have same length. 50% are >= 690 moves. 
I will discard all games less than 680, and truncate the rest to 680.\n", 73 | "filtered_df = df[df['transcript'].apply(lambda x: len(x) >= game_length_in_chars)].copy()\n", 74 | "filtered_df.loc[:, 'transcript'] = filtered_df['transcript'].apply(lambda x: x[:game_length_in_chars])\n", 75 | "\n", 76 | "len_df = filtered_df['transcript'].apply(lambda x: len(x))\n", 77 | "print(len_df.describe())\n", 78 | "\n", 79 | "move_count_df = filtered_df['transcript'].apply(lambda x: len(x.split()))\n", 80 | "move_count = move_count_df.describe()\n", 81 | "print(\"move count\", move_count_df.describe())\n", 82 | "quarter_percentile = move_count['25%']\n", 83 | "print(\"quarter percentile\", quarter_percentile)\n", 84 | "\n", 85 | "# Now I need to filter out games that are too short. I will discard all games less than 25th percentile moves.\n", 86 | "filtered_df = filtered_df[filtered_df['transcript'].apply(lambda x: len(x.split()) >= quarter_percentile)]\n", 87 | "print(filtered_df.describe())\n", 88 | "print(filtered_df.head())\n", 89 | "\n", 90 | "filtered_df.to_csv(output_file, index=False)\n", 91 | "\n", 92 | "move_count_df = filtered_df['transcript'].apply(lambda x: len(x.split()))\n", 93 | "print(move_count_df.describe())" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "print(len(filtered_df))\n", 103 | "player_two_group_sizes = filtered_df.groupby('player_two').size()\n", 104 | "print(player_two_group_sizes)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# shuffle all rows of the dataset\n", 114 | "\n", 115 | "df = pd.read_csv(output_file)\n", 116 | "df = df.sample(frac=1, random_state=200).reset_index(drop=True)\n", 117 | "df.to_csv(output_file, index=False)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | 
"outputs": [], 125 | "source": [ 126 | "import pandas as pd\n", 127 | "df = pd.read_csv(output_file)\n", 128 | "\n", 129 | "print(len(df))\n", 130 | "\n", 131 | "# Split df into a train and test split\n", 132 | "train = df.sample(frac=0.5, random_state=200)\n", 133 | "test = df.drop(train.index)\n", 134 | "\n", 135 | "print(len(train))\n", 136 | "print(len(test))\n", 137 | "\n", 138 | "# Save the train and test splits to csv\n", 139 | "train.to_csv(f'{DATA_DIR}{prefix}train.csv', index=False)\n", 140 | "test.to_csv(f'{DATA_DIR}{prefix}test.csv', index=False)" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "othello", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.11.7" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /utils/create_skill_intervention_from_skill_probe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "torch.set_grad_enabled(False)" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "for layer in range(8):\n", 20 | " state_dict_name = f\"tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_{layer}.pth\"\n", 21 | " state_dict = torch.load(state_dict_name, map_location=torch.device('cpu'))\n", 22 | " average_high_elo = state_dict[\"average_high_elo_activation\"]\n", 23 | " average_low_elo = 
state_dict[\"average_low_elo_activation\"]\n", 24 | " difference_vector = state_dict[\"difference_vector\"]\n", 25 | " state_dict[\"difference_vector\"] = difference_vector\n", 26 | "\n", 27 | " new_state_dict_name = f\"type=probe_model=8layers_layer={layer}.pt\"\n", 28 | " torch.save(state_dict, new_state_dict_name)" 29 | ] 30 | } 31 | ], 32 | "metadata": { 33 | "kernelspec": { 34 | "display_name": "chess", 35 | "language": "python", 36 | "name": "python3" 37 | }, 38 | "language_info": { 39 | "codemirror_mode": { 40 | "name": "ipython", 41 | "version": 3 42 | }, 43 | "file_extension": ".py", 44 | "mimetype": "text/x-python", 45 | "name": "python", 46 | "nbconvert_exporter": "python", 47 | "pygments_lexer": "ipython3", 48 | "version": "3.11.7" 49 | } 50 | }, 51 | "nbformat": 4, 52 | "nbformat_minor": 2 53 | } 54 | -------------------------------------------------------------------------------- /utils/custom_functions_guide.md: -------------------------------------------------------------------------------- 1 | If you wish to add a new Othello or Chess board to state function, there are three steps. I will use `board_to_pin_state` and `games_batch_to_state_stack_mine_yours_BLRRC` as my examples. 2 | 3 | First, write a function that converts a chess board or othello board to a pytorch tensor. Refer to `chess_utils.py/board_to_pin_state()` and `othello_utils.py/games_batch_to_state_stack_mine_yours_BLRRC()` for how to do this. 4 | 5 | Next, add a `Config` for this function. To do this, refer to `chess_utils.py/othello_config` and `chess_utils.py/pin_config`. This config object is primarily used to indicate the expected size of the one hot output tensor, the name for the linear probe, if the function is used with othello, and if the board to state function should get additional data such as skill. I'm not very happy about this implementation, and it really should get refactored so board to state functions return one hot tensors instead. 
6 | 7 | Currently, board to state functions return a row by column tensor. For board state, this would be 8x8. In Othello, if an element is -1, then it's white on that square. If it's 1, then it's black on that square. So the tensor is 8x8, and all elements are -1, 0, or 1. 8 | 9 | I had done this so I could easily plot and visualize board states. This was a mistake. Everything should return one hot tensors, and it should convert one hot to standard tensors for plotting. 10 | 11 | Here's `othello_config`: 12 | 13 | ``` 14 | othello_config = Config( 15 | min_val=-1, 16 | max_val=1, 17 | custom_board_state_function=othello_utils.games_batch_to_state_stack_mine_yours_BLRRC, 18 | linear_probe_name="othello_mine_yours_probe", 19 | othello=True, 20 | ) 21 | ``` 22 | 23 | The minimum and maximum values present in the board state are -1 and 1, so those are `min_val` and `max_val`. After one hot encoding, it will be shape (8, 8, 3). The 3 is because any square can be -1, 0, or 1. Here's `pin_config`: 24 | 25 | ``` 26 | pin_config = Config( 27 | min_val=0, 28 | max_val=1, 29 | custom_board_state_function=board_to_pin_state, 30 | num_rows=1, 31 | num_cols=1, 32 | linear_probe_name="chess_pin_probe", 33 | ) 34 | ``` 35 | 36 | Because pin config is a 1x1 binary variable, its shape will be (1, 1, 1). It still has rows and columns so the shapes work with einops. 37 | 38 | Finally, in `train_test_chess.py`, add `config = chess_utils.my_config` before this line: `input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv"`. 39 | 40 | This is all you need to train a linear probe. To test a linear probe, add your config to `chess_utils.py/all_configs` in `chess_utils.py/find_config_by_name()`. 41 | 42 | If I ever get around to it, this could be significantly cleaned up. The codebase started evolving and I never got around to doing a refactor to clean it up. It's not terrible as it's just three steps. For now, it is what it is. 
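
To make the Othello one hot shape described above concrete, here is a minimal, dependency-free sketch of the conversion from a row by column board of integers to a one hot grid. It is an illustration only, not the code this repo uses (the repo works with PyTorch tensors). The function name `one_hot_board` and the example board are hypothetical, and the class count `max_val - min_val + 1` is an assumption inferred from the (8, 8, 3) Othello example.

```python
def one_hot_board(board, min_val, max_val):
    """Convert a rows x cols grid of ints into a rows x cols x classes one hot grid.

    Hypothetical helper for illustration; assumes every value lies in
    [min_val, max_val] and that there is one class per integer in that range.
    """
    num_classes = max_val - min_val + 1
    one_hot = []
    for row in board:
        one_hot_row = []
        for value in row:
            if not min_val <= value <= max_val:
                raise ValueError(f"value {value} outside [{min_val}, {max_val}]")
            cell = [0] * num_classes
            # Shift by min_val so that min_val maps to index 0.
            cell[value - min_val] = 1
            one_hot_row.append(cell)
        one_hot.append(one_hot_row)
    return one_hot


# Hypothetical 8x8 Othello-style board: -1 is white, 0 is empty, 1 is black.
board = [[0] * 8 for _ in range(8)]
board[3][3], board[4][4] = -1, -1
board[3][4], board[4][3] = 1, 1

encoded = one_hot_board(board, min_val=-1, max_val=1)
shape = (len(encoded), len(encoded[0]), len(encoded[0][0]))
print(shape)
```

With `min_val=-1` and `max_val=1`, index 0 of the last dimension marks white, index 1 marks empty, and index 2 marks black, which matches the three possible values per square.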
-------------------------------------------------------------------------------- /utils/othello_data_filtering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "from datasets import load_dataset\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "I use this notebook for manipulating the outputs of my chess_gpt_eval repository and doing various experiments with it. These games are generally outputs of playing stockfish vs stockfish or Chess-GPT against stockfish. For standard uses, you shouldn't need to use this notebook.\n", 19 | "\n", 20 | "To train or test on OthelloGPT, move this notebook into the root `chess_llm_interpretability`. Click run all.\n", 21 | "\n", 22 | "At the bottom of `train_test_chess.py`, set `othello=True`.\n", 23 | "\n", 24 | "You can also set the config to `othello_valid_moves_config`." 
25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "DATA_DIR = \"data/\"\n", 34 | "prefix = \"othello_\"\n", 35 | "\n", 36 | "\n", 37 | "input_file = f'{DATA_DIR}{prefix}100mb.csv'\n", 38 | "output_file = input_file.replace(\".csv\", \"_filtered.csv\")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "if not os.path.exists(input_file):\n", 48 | " dataset_path = \"adamkarvonen/othello_45MB_games\"\n", 49 | " file_path = f\"{prefix}100mb.zip\"\n", 50 | " dataset = load_dataset(dataset_path)\n", 51 | " df = pd.DataFrame(dataset['train'])\n", 52 | " df.to_csv(input_file, index=False)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# shuffle all rows of the dataset\n", 62 | "\n", 63 | "df = pd.read_csv(output_file)\n", 64 | "df = df.sample(frac=1, random_state=200).reset_index(drop=True)\n", 65 | "df.to_csv(output_file, index=False)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "import pandas as pd\n", 75 | "df = pd.read_csv(output_file)\n", 76 | "\n", 77 | "print(len(df))\n", 78 | "\n", 79 | "# Split df into a train and test split\n", 80 | "train = df.sample(frac=0.5, random_state=200)\n", 81 | "test = df.drop(train.index)\n", 82 | "\n", 83 | "print(len(train))\n", 84 | "print(len(test))\n", 85 | "\n", 86 | "# Save the train and test splits to csv\n", 87 | "train.to_csv(f'{DATA_DIR}{prefix}train.csv', index=False)\n", 88 | "test.to_csv(f'{DATA_DIR}{prefix}test.csv', index=False)" 89 | ] 90 | } 91 | ], 92 | "metadata": { 93 | "kernelspec": { 94 | "display_name": "othello", 95 | "language": "python", 96 | "name": "python3" 97 | }, 98 | "language_info": { 99 | "codemirror_mode": { 100 | "name": "ipython", 
101 | "version": 3 102 | }, 103 | "file_extension": ".py", 104 | "mimetype": "text/x-python", 105 | "name": "python", 106 | "nbconvert_exporter": "python", 107 | "pygments_lexer": "ipython3", 108 | "version": "3.11.7" 109 | } 110 | }, 111 | "nbformat": 4, 112 | "nbformat_minor": 2 113 | } 114 | -------------------------------------------------------------------------------- /utils/unique_checks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Place this in the data/ folder in chess_llm_interpretability to perform various uniqueness checks on datasets" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "\n", 18 | "# Function to process chunks of the CSV file\n", 19 | "def process_chunks(chunk_iter):\n", 20 | " transcripts = set() # Initialize an empty set for transcripts\n", 21 | " for chunk in chunk_iter:\n", 22 | " # Update the set with transcripts from the current chunk\n", 23 | " transcripts.update(chunk['transcript'])\n", 24 | " return transcripts\n", 25 | "\n", 26 | "# Specify the path to your CSV file\n", 27 | "file_path = 'lichess_6gb.csv'\n", 28 | "\n", 29 | "# Create a chunk iterator with a reasonable chunk size\n", 30 | "chunk_size = 10**5 # Adjust this based on your system's performance and memory usage\n", 31 | "\n", 32 | "# Create an iterator object for chunks of the DataFrame\n", 33 | "chunk_iter = pd.read_csv(file_path, chunksize=chunk_size, usecols=['transcript'])\n", 34 | "\n", 35 | "# Process the chunks and get the set of transcripts\n", 36 | "transcripts_set = process_chunks(chunk_iter)\n", 37 | "\n", 38 | "print(f\"Total unique transcripts: {len(transcripts_set)}\")\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | 
"source": [ 47 | "testing_file_path = 'lichess_100mb.csv'\n", 48 | "df = pd.read_csv(testing_file_path, usecols=['transcript'])\n", 49 | "print(f\"Total rows in the DataFrame: {len(df)}\")\n", 50 | "\n", 51 | "# Initialize a counter for overlaps\n", 52 | "overlap_count = 0\n", 53 | "\n", 54 | "# Process each transcript in the DataFrame\n", 55 | "for transcript in df['transcript']:\n", 56 | " # Check if the transcript is already in the set\n", 57 | " if transcript in transcripts_set:\n", 58 | " overlap_count += 1\n", 59 | " else:\n", 60 | " # Add the new transcript to the set\n", 61 | " transcripts_set.add(transcript)\n", 62 | "\n", 63 | "print(f\"Total unique transcripts now: {len(transcripts_set)}\")\n", 64 | "print(f\"Number of overlaps found: {overlap_count}\")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "import csv\n", 74 | "\n", 75 | "# Initialize counters\n", 76 | "total_rows = 0\n", 77 | "total_characters = 0\n", 78 | "\n", 79 | "# Specify the path to your CSV file\n", 80 | "file_path = 'lichess_6gb.csv'\n", 81 | "\n", 82 | "# Open the file and use csv.reader to handle potential complexities in the CSV format\n", 83 | "with open(file_path, 'r', encoding='utf-8') as csvfile:\n", 84 | " reader = csv.reader(csvfile)\n", 85 | " # Skip the header\n", 86 | " next(reader)\n", 87 | " for row in reader:\n", 88 | " total_rows += 1\n", 89 | " # Assuming transcript is the last column\n", 90 | " transcript = row[-1]\n", 91 | " total_characters += len(transcript)\n", 92 | "\n", 93 | "print(f\"Total number of rows: {total_rows}\")\n", 94 | "print(f\"Total number of characters in transcripts: {total_characters}\")\n" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "import pandas as pd\n", 104 | "\n", 105 | "file_path = 'lichess_6gb.csv'\n", 106 | "lichess_df = 
pd.read_csv(file_path)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "llm_file_path = '8layer_llm_games.csv'\n", 116 | "llm_df = pd.read_csv(llm_file_path)\n", 117 | "llm_df = llm_df[:100]\n", 118 | "total_games = len(df)\n", 119 | "print(f\"Total number of games: {total_games}\")\n", 120 | "print(f\"Total rows in the LLM DataFrame: {len(llm_df)}\")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import re\n", 130 | "\n", 131 | "def unique_transcripts_by_move_df(move_number: int, lichess_df: pd.DataFrame, llm_df: pd.DataFrame) -> int:\n", 132 | " \"\"\"\n", 133 | " Find how many games in `llm_df` are unique by move number compared to games in `lichess_df`.\n", 134 | " \n", 135 | " Parameters:\n", 136 | " - move_number: The move number to compare uniqueness by.\n", 137 | " - lichess_df: DataFrame containing the Lichess game transcripts.\n", 138 | " - llm_df: DataFrame containing the LLM game transcripts.\n", 139 | " \n", 140 | " Returns:\n", 141 | " - The number of unique games in `llm_df` by move number.\n", 142 | " \"\"\"\n", 143 | " lichess_set = set()\n", 144 | " \n", 145 | " # Process lichess_df to extract unique transcripts by move number\n", 146 | " for i, transcript in enumerate(lichess_df['transcript']):\n", 147 | " shortened_transcript = \" \".join(transcript.split(' ', move_number)[:move_number])\n", 148 | " lichess_set.add(shortened_transcript)\n", 149 | " \n", 150 | " unique_count = 0\n", 151 | " \n", 152 | " # Process llm_df to find unique transcripts by move number\n", 153 | " for i, transcript in enumerate(llm_df['transcript']):\n", 154 | " transcript = transcript.split(\"\\n\\n\")[1].strip()\n", 155 | " transcript = re.sub(r\"(\\d+\\.) 
\", r\"\\1\", transcript)\n", 156 | " shortened_transcript = \" \".join(transcript.split(' ', move_number)[:move_number])\n", 157 | " if shortened_transcript not in lichess_set:\n", 158 | " unique_count += 1\n", 159 | " \n", 160 | " return unique_count\n" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "move_number = 20\n", 170 | "unique_games_count = unique_transcripts_by_move_df(move_number, lichess_df, llm_df)\n", 171 | "print(f\"Unique games by move {move_number}: {unique_games_count}\")" 172 | ] 173 | } 174 | ], 175 | "metadata": { 176 | "kernelspec": { 177 | "display_name": "chess", 178 | "language": "python", 179 | "name": "python3" 180 | }, 181 | "language_info": { 182 | "codemirror_mode": { 183 | "name": "ipython", 184 | "version": 3 185 | }, 186 | "file_extension": ".py", 187 | "mimetype": "text/x-python", 188 | "name": "python", 189 | "nbconvert_exporter": "python", 190 | "pygments_lexer": "ipython3", 191 | "version": "3.11.7" 192 | } 193 | }, 194 | "nbformat": 4, 195 | "nbformat_minor": 2 196 | } 197 | -------------------------------------------------------------------------------- /utils/view_caa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import transformer_lens.utils as utils\n", 10 | "from transformer_lens import HookedTransformer, HookedTransformerConfig\n", 11 | "import einops\n", 12 | "import torch\n", 13 | "from tqdm import tqdm\n", 14 | "import numpy as np\n", 15 | "from fancy_einsum import einsum\n", 16 | "import chess\n", 17 | "import numpy as np\n", 18 | "import csv\n", 19 | "from dataclasses import dataclass\n", 20 | "from torch.nn import MSELoss, L1Loss\n", 21 | "import pandas as pd\n", 22 | "import pickle\n", 23 | "import os\n", 24 | "import logging\n", 25 
| "\n", 26 | "import chess_utils\n", 27 | "import train_test_chess\n", 28 | "from train_test_chess import Config, LinearProbeData" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "torch.set_grad_enabled(False)\n", 38 | "CAA_DIR = \"contrastive_activations/\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "state_dict_name = f\"{CAA_DIR}lichess_train_layer_12_pos_start_25_activations.pt\"\n", 48 | "state_dict = torch.load(state_dict_name, map_location=torch.device('cpu'))\n", 49 | "print(state_dict.keys())\n", 50 | "average_high_elo = state_dict[\"average_high_elo_activation\"]\n", 51 | "average_low_elo = state_dict[\"average_low_elo_activation\"]\n", 52 | "difference_vector = average_high_elo - average_low_elo\n", 53 | "state_dict[\"difference_vector\"] = difference_vector\n", 54 | "torch.save(state_dict, state_dict_name)\n", 55 | "\n", 56 | "state_dict2_name = f\"{CAA_DIR}lichess_train_layer_12_pos_start_25_num_games_20000_activations.pt\"\n", 57 | "state_dict2 = torch.load(state_dict2_name, map_location=torch.device('cpu'))\n", 58 | "print(state_dict2.keys())\n", 59 | "average_high_elo2 = state_dict2[\"average_high_elo_activation\"]\n", 60 | "average_low_elo2 = state_dict2[\"average_low_elo_activation\"]\n", 61 | "difference_vector2 = average_high_elo2 - average_low_elo2\n", 62 | "state_dict2[\"difference_vector\"] = difference_vector2\n", 63 | "torch.save(state_dict2, state_dict2_name)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "print(difference_vector.shape)\n", 73 | "print(state_dict[\"layer\"])\n", 74 | "print(state_dict[\"pos_start\"])\n", 75 | "print(state_dict2[\"layer\"])\n", 76 | "print(state_dict2[\"pos_start\"])" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | 
"execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from torch.nn.functional import cosine_similarity\n", 86 | "# Calculating average values of each tensor\n", 87 | "avg_value_high_elo = torch.mean(average_high_elo)\n", 88 | "avg_value_low_elo = torch.mean(average_low_elo)\n", 89 | "avg_value_difference = torch.mean(difference_vector)\n", 90 | "\n", 91 | "avg_value_high_elo2 = torch.mean(average_high_elo2)\n", 92 | "avg_value_low_elo2 = torch.mean(average_low_elo2)\n", 93 | "avg_value_difference2 = torch.mean(difference_vector2)\n", 94 | "\n", 95 | "# Calculating cosine similarity between all pairs\n", 96 | "cos_sim_high_low = cosine_similarity(average_high_elo.unsqueeze(0), average_low_elo.unsqueeze(0)).item()\n", 97 | "cos_sim_high_diff = cosine_similarity(average_high_elo.unsqueeze(0), difference_vector.unsqueeze(0)).item()\n", 98 | "cos_sim_low_diff = cosine_similarity(average_low_elo.unsqueeze(0), difference_vector.unsqueeze(0)).item()\n", 99 | "\n", 100 | "cos_sim_high_low2 = cosine_similarity(average_high_elo2.unsqueeze(0), average_low_elo2.unsqueeze(0)).item()\n", 101 | "cos_sim_high_diff2 = cosine_similarity(average_high_elo2.unsqueeze(0), difference_vector2.unsqueeze(0)).item()\n", 102 | "cos_sim_low_diff2 = cosine_similarity(average_low_elo2.unsqueeze(0), difference_vector2.unsqueeze(0)).item()\n", 103 | "\n", 104 | "print(avg_value_high_elo, avg_value_low_elo, avg_value_difference, cos_sim_high_low, cos_sim_high_diff, cos_sim_low_diff)\n", 105 | "print(avg_value_high_elo2, avg_value_low_elo2, avg_value_difference2, cos_sim_high_low2, cos_sim_high_diff2, cos_sim_low_diff2)\n" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "cos_sim_high_high = cosine_similarity(average_high_elo.unsqueeze(0), average_high_elo2.unsqueeze(0)).item()\n", 115 | "cos_sim_low_low = cosine_similarity(average_low_elo.unsqueeze(0), 
average_low_elo2.unsqueeze(0)).item()\n", 116 | "cos_sim_diff_diff = cosine_similarity(difference_vector.unsqueeze(0), difference_vector2.unsqueeze(0)).item()\n", 117 | "\n", 118 | "print(cos_sim_high_high, cos_sim_low_low, cos_sim_diff_diff)\n", 119 | "\n", 120 | "cos_sim_high_low2 = cosine_similarity(average_high_elo.unsqueeze(0), average_low_elo2.unsqueeze(0)).item()\n", 121 | "cos_sim_high_diff2 = cosine_similarity(average_high_elo.unsqueeze(0), difference_vector2.unsqueeze(0)).item()\n", 122 | "cos_sim_low_diff2 = cosine_similarity(average_low_elo.unsqueeze(0), difference_vector2.unsqueeze(0)).item()\n", 123 | "cos_sim_low_high2 = cosine_similarity(average_low_elo.unsqueeze(0), average_high_elo2.unsqueeze(0)).item()\n", 124 | "\n", 125 | "print(cos_sim_high_low2, cos_sim_high_diff2, cos_sim_low_diff2, cos_sim_low_high2)" 126 | ] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "othello", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.10.13" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 2 150 | } 151 | --------------------------------------------------------------------------------