├── .gitignore
├── LICENSE
├── README.md
├── grok
│   ├── __init__.py
│   ├── data.py
│   ├── measure.py
│   ├── metrics.py
│   ├── training.py
│   ├── transformer.py
│   └── visualization.py
├── nbs
│   └── flatness.ipynb
├── scripts
│   ├── compute_sharpness.py
│   ├── create_metric_graphs.py
│   ├── create_metrics_for_epochs.py
│   ├── create_partial_metrics.py
│   ├── make_data.py
│   ├── torch-setup.sh
│   ├── train.py
│   └── visualize_metrics.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | default 132 | checkpoints 133 | .vscode 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenAI Grok Curve Experiments 2 | 3 | ## Paper 4 | 5 | This is the code for the paper [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra 6 | 7 | ## Installation and Training 8 | 9 | ```bash 10 | pip install -e . 11 | ./scripts/train.py 12 | ``` 13 | -------------------------------------------------------------------------------- /grok/__init__.py: -------------------------------------------------------------------------------- 1 | from . import transformer 2 | from . import data 3 | from . import training 4 | from . import metrics 5 | from . 
import visualization 6 | -------------------------------------------------------------------------------- /grok/data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | import os 4 | import sys 5 | import random 6 | 7 | import torch 8 | from torch import Tensor, LongTensor 9 | import numpy as np 10 | from typing import Tuple, List, Dict, Any, Union, Optional 11 | from tqdm import tqdm 12 | 13 | from sympy.combinatorics.permutations import Permutation 14 | from mod import Mod 15 | 16 | import blobfile as bf 17 | 18 | 19 | VALID_OPERATORS = { 20 | "+": "addition", 21 | "-": "subtraction", 22 | "*": "muliplication", 23 | "/": "division", 24 | "**2+": "squarepoly", 25 | "**3+": "cubepoly", 26 | "x**2+y**2_mod_97": "quad1", 27 | "x**2+y**2+x*y_mod_97": "quad2", 28 | "x**2+y**2+x*y+x_mod_97": "quad3", 29 | "x**3+x*y_mod_97": "cube1", 30 | "x**3+x*y**2+y_mod_97": "cube2", 31 | "(x._value//y)if(y._value%2==1)else(x-y)_mod_97": "mix1", 32 | "s5": "s5", 33 | "s5conj": "s5conj", 34 | "s5aba": "s5aba", 35 | "+*": "even-addition_odd-multiplication", 36 | "+-": "even-addition_odd-subtraction", 37 | "sort": "sort", 38 | "reverse": "reverse", 39 | "copy": "copy", 40 | } 41 | EOS_TOKEN = "<|eos|>" 42 | EQ_TOKEN = "=" 43 | MODULUS = 97 44 | NUMS = list(range(MODULUS)) 45 | 46 | DEFAULT_DATA_DIR = "data" 47 | 48 | 49 | def render(operand, join_str=""): 50 | if ( 51 | isinstance(operand, list) 52 | or isinstance(operand, tuple) 53 | or isinstance(operand, np.ndarray) 54 | ): 55 | return join_str.join(map(render, operand)) 56 | elif isinstance(operand, Permutation): 57 | return "".join(map(str, operand.array_form)) 58 | elif isinstance(operand, Mod): 59 | return str(operand._value) 60 | else: 61 | return str(operand) 62 | 63 | 64 | def create_data_files(data_dir: str = DEFAULT_DATA_DIR): 65 | ArithmeticTokenizer.create_token_file(data_dir) 66 | ArithmeticDataset.create_dataset_files(data_dir) 67 | 68 | 69 | class ArithmeticTokenizer: 70 | """Stores the list of token text to token id mappings and converts between them""" 71 | 72 | token_file = "tokens.txt" 73 | 74 | def __init__(self, data_dir=DEFAULT_DATA_DIR) -> None: 75 | self.token_file = bf.join(data_dir, self.token_file) 76 | 77 | self.itos = self.get_tokens() 78 | 79 | self.stoi: Dict[str, int] = dict([(s, i) for i, s in enumerate(self.itos)]) 80 | 81 | def _encode(self, s: str) -> Tensor: 82 | return LongTensor([self.stoi[t] for t in s.split(" ")]) 83 | 84 | def encode(self, obj: Union[str, List]) -> Tensor: 85 | """ 86 | Convert a string of text into a rank-1 tensor of token ids 87 | or convert a list of strings of text into a rank-2 tensor of token ids 88 | 89 | :param obj: the string or list of strings to convert 90 | :returns: a tensor of the token ids 91 | """ 92 | if isinstance(obj, str): 93 | return self._encode(obj) 94 | elif isinstance(obj, list): 95 | return torch.stack([self._encode(s) for s in obj], dim=0) 96 | else: 97 | raise NotImplementedError 98 | 99 | def decode(self, tensor: Tensor, with_brackets: bool = False) -> str: 100 | """ 101 | Convert a tensor of token ids into a string of text 102 | 103 | :param tensor: a tensor of the token ids 104 | :param with_brackets: if true, the returned string will include <> brackets 105 | around the text corresponding to each token. 106 | :returns: string of these tokens. 
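A round-trip sketch (illustrative values; the vocabulary is built in memory
by get_tokens(), so no token file needs to exist):

    tok = ArithmeticTokenizer()
    ids = tok.encode("1 + 2 = 3")
    tok.decode(ids)                      # '1 + 2 = 3'
    tok.decode(ids, with_brackets=True)  # '<1> <+> <2> <=> <3>'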
107 | """ 108 | indices = tensor.long() 109 | if with_brackets: 110 | l = "<" 111 | r = ">" 112 | else: 113 | l = "" 114 | r = "" 115 | tokens = [l + self.itos[i] + r for i in indices] 116 | return " ".join(tokens) 117 | 118 | def __len__(self) -> int: 119 | """ 120 | :returns: the number of tokens in this vocabulary 121 | """ 122 | return len(self.itos) 123 | 124 | @classmethod 125 | def get_tokens(cls): 126 | tokens = ( 127 | [EOS_TOKEN, EQ_TOKEN] 128 | + list(sorted(list(VALID_OPERATORS.keys()))) 129 | + list(map(render, NUMS)) 130 | + list(map(render, itertools.permutations(range(5)))) # s5 131 | ) 132 | return tokens 133 | 134 | 135 | class ArithmeticDataset: 136 | """A Dataset of arithmetic equations""" 137 | 138 | @classmethod 139 | def splits( 140 | cls, 141 | train_pct: float, 142 | operator: str, 143 | operand_length: Optional[int] = None, 144 | data_dir: str = DEFAULT_DATA_DIR, 145 | ): 146 | """ 147 | Creates training and validation datasets 148 | 149 | :param train_pct: percentage of total equations used for training data 150 | :param operator: The arithmetic operator for this dataset e.g. '+', '-', '*', '/', 'sort' 151 | :param operand_length: for list based datasets the length of the lists 152 | :returns: (train_dataset, validation_dataset) 153 | """ 154 | 155 | assert (0 < train_pct) and (train_pct < 100) 156 | 157 | ds_name = cls.get_dsname(operator, operand_length) 158 | eqs = cls.make_data(operator, operand_length) 159 | 160 | train_rows, _ = cls.calc_split_len(train_pct, len(eqs)) 161 | 162 | train_ds = cls(ds_name, eqs[:train_rows], train=True, data_dir=data_dir) 163 | val_ds = cls(ds_name, eqs[train_rows:], train=False, data_dir=data_dir) 164 | 165 | return train_ds, val_ds 166 | 167 | @classmethod 168 | def calc_split_len(cls, train_pct, ds_len): 169 | train_rows = round(ds_len * (train_pct / 100.0)) 170 | val_rows = ds_len - train_rows 171 | return train_rows, val_rows 172 | 173 | def __init__(self, name, data: Union[Tensor, List[str]], train, data_dir) -> None: 174 | """ 175 | :param data: A list of equations strings. Each equation must have an '=' in it. 
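Instances are normally built via ArithmeticDataset.splits() rather than
constructed directly; a minimal sketch (operator and split percentage are
illustrative):

    train_ds, val_ds = ArithmeticDataset.splits(train_pct=50, operator="+")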
176 | """ 177 | self.tokenizer = ArithmeticTokenizer(data_dir) 178 | self.name = name 179 | self.train = train 180 | if isinstance(data, list): 181 | self.data = self.tokenizer.encode(data) 182 | else: 183 | self.data = data 184 | 185 | def __len__(self) -> int: 186 | """ 187 | :returns: total number of equations in this dataset 188 | """ 189 | return self.data.shape[0] 190 | 191 | # @classmethod 192 | # def _render(cls, operand): 193 | # return render(operand, join_str=" ") 194 | # 195 | # @classmethod 196 | # def _render_eq(parts): 197 | # return " ".join(map(render, parts)) 198 | 199 | @classmethod 200 | def _make_binary_operation_data(cls, operator: str, operands=None) -> List[str]: 201 | if operator == "s5": 202 | operands = operands or list(range(5)) 203 | elems = map(np.array, itertools.permutations(operands)) 204 | tuples = itertools.product(elems, repeat=2) 205 | elif operator in ["s5conj", "s5aba"]: 206 | operands = operands or list(range(5)) 207 | elems = map(Permutation, itertools.permutations(operands)) 208 | tuples = itertools.product(elems, repeat=2) 209 | elif "_mod_" in operator: 210 | modulo = int(operator.split("_mod_")[-1]) 211 | elems = [Mod(i, modulo) for i in range(modulo)] 212 | tuples = itertools.product(elems, repeat=2) 213 | else: 214 | operands = operands or NUMS 215 | tuples = itertools.product(operands, repeat=2) 216 | 217 | # if operator == "s5": 218 | # print("elems", list(elems)) 219 | # print("tuples", list(tuples)) 220 | eqs = [] 221 | for a, b in tuples: 222 | if operator == "/": 223 | if b == 0: 224 | continue 225 | else: 226 | c = a 227 | a = (b * c) % MODULUS 228 | elif operator == "s5": 229 | c = b[a] 230 | elif operator == "s5conj": 231 | c = a * b * (a.__invert__()) 232 | elif operator == "s5aba": 233 | c = a * b * a 234 | elif operator == "+*": 235 | if a % 2 == 0: 236 | c = (a + b) % MODULUS 237 | else: 238 | c = (a * b) % MODULUS 239 | elif operator == "+-": 240 | if a % 2 == 0: 241 | c = (a + b) % MODULUS 242 | else: 243 | c = (a - b) % MODULUS 244 | elif "_mod_" in operator: 245 | expression = operator.split("_mod_")[0] 246 | function = eval(f"lambda x, y: ({expression})") 247 | c = function(a, b) 248 | else: 249 | c = eval(f"({a} {operator} {b}) % {MODULUS}") 250 | eq = " ".join(map(render, [a, operator, b, "=", c])) 251 | eqs.append(eq) 252 | 253 | # if operator == "s5": 254 | # print("eqs", eqs) 255 | return eqs 256 | 257 | # @staticmethod 258 | # def _render_unop_example(operator, lhs, rhs): 259 | # return " ".join([operator, render(lhs), "=", render(rhs)]) 260 | 261 | @staticmethod 262 | def _make_unary_operation_data(operator: str, operands: Tensor) -> List[str]: 263 | """ 264 | :param operator: The unary operator to apply to each operand e.g. 
'+' 265 | :param operands: A tensor of operands 266 | :returns: list of equations""" 267 | num_examples = len(operands) 268 | 269 | if operator == "sort": 270 | rhs = torch.sort(operands, dim=1)[0] 271 | elif operator == "reverse": 272 | rhs = torch.flip(operands, dims=(1,)) 273 | elif operator == "copy": 274 | rhs = operands 275 | else: 276 | raise Exception("unsupported operator") 277 | 278 | def func(L, R): 279 | L = map(str, L) 280 | R = map(str, R) 281 | return f"{operator} {' '.join(L)} = {' '.join(R)}" 282 | 283 | if num_examples < 1000000000: 284 | eqs = [ 285 | func(L, R) 286 | for L, R in tqdm( 287 | zip(operands.tolist(), rhs.tolist()), total=num_examples 288 | ) 289 | ] 290 | else: 291 | with ProcessPoolExecutor() as executor: 292 | eqs = executor.map(func, tqdm(zip(operands, rhs), total=num_examples)) 293 | 294 | return eqs 295 | 296 | # @staticmethod 297 | # def _make_s5_data(abstract=False) -> List[str]: 298 | # elems = itertools.permutations([0, 1, 2, 3, 4]) 299 | # pairs = itertools.product(elems, repeat=2) 300 | # eqs = [] 301 | # for a, b in pairs: 302 | # a = np.array(a) 303 | # b = np.array(b) 304 | # c = b[a] 305 | # eq = " ".join(map(render, (a, "s5", b, "=", c))) 306 | # eq = cls._render_eq([a, , b, "=", c]) 307 | # eqs.append(eq) 308 | # 309 | # return eqs 310 | 311 | @classmethod 312 | def get_dsname(cls, operator, operand_length) -> str: 313 | operator, noise_level = cls._get_operator_and_noise_level(operator) 314 | ds_name = VALID_OPERATORS[operator] 315 | if operand_length is not None: 316 | ds_name += f"_length-{operand_length}" 317 | if noise_level > 0: 318 | ds_name += f"_noise-{noise_level}" 319 | return ds_name 320 | 321 | @classmethod 322 | def get_file_path(cls, operator, operand_length=None, data_dir=DEFAULT_DATA_DIR): 323 | ds_name = cls.get_dsname(operator, operand_length) 324 | ds_file = bf.join(data_dir, f"{ds_name}_data.txt") 325 | return ds_file, ds_name 326 | 327 | @classmethod 328 | def _get_operator_and_noise_level(cls, operator): 329 | if "_noisy" in operator: 330 | operator, noise_level = operator.split("_noisy_") 331 | return operator, int(noise_level) 332 | else: 333 | return operator, 0 334 | 335 | @classmethod 336 | def make_data(cls, operator, operands=None, shuffle=True, seed=0) -> List[str]: 337 | operator, noise_level = cls._get_operator_and_noise_level(operator) 338 | assert operator in VALID_OPERATORS 339 | 340 | if operator not in ["sort", "reverse", "copy"]: 341 | data = cls._make_binary_operation_data(operator) 342 | else: 343 | data = cls._make_unary_operation_data(operator, operands) 344 | 345 | rng = np.random.RandomState(seed=seed) 346 | if shuffle: 347 | rng.shuffle(data) 348 | 349 | if noise_level > 0: 350 | random_answer_eqns = rng.choice(data, size=noise_level) 351 | random_answers = [ 352 | random_eq.split(" = ")[1] for random_eq in random_answer_eqns 353 | ] 354 | for i in range(noise_level): 355 | data[i] = data[i].split(" = ")[0] + " = " + random_answers[i] 356 | 357 | data = [EOS_TOKEN + " " + eq + " " + EOS_TOKEN for eq in data] 358 | 359 | return data 360 | 361 | # @classmethod 362 | # def create_data_file( 363 | # cls, operator, operand_length=None, shuffle=True, data_dir=DEFAULT_DATA_DIR 364 | # ): 365 | # if VALID_OPERATORS[operator]["binary_eval"]: 366 | # cls.write_dataset( 367 | # cls.make_binary_operation_data(operator), paths["ds_file"] 368 | # ) 369 | # 370 | # pass 371 | 372 | # @classmethod 373 | # def write_dataset(eqs: List[str], ds_file: str): 374 | # print(f"-> writing {ds_file}", flush=True) 375 | # 
with open(ds_file, "w") as fh: 376 | # fh.writelines([EOS_TOKEN + " " + eq + " " + EOS_TOKEN + "\n" for eq in eqs]) 377 | 378 | @classmethod 379 | def _make_lists(cls, sizes=[2, 3], nums=NUMS): 380 | lists: dict = {} 381 | for size in sizes: 382 | lists[size] = torch.tensor( 383 | list(itertools.permutations(nums, r=size)), 384 | dtype=torch.int, 385 | ) 386 | return lists 387 | 388 | 389 | class ArithmeticIterator(torch.utils.data.IterableDataset): 390 | """ 391 | An iterator over batches of data in an ArithmeticDataset 392 | """ 393 | 394 | def __init__( 395 | self, 396 | dataset: ArithmeticDataset, 397 | device: torch.device, 398 | batchsize_hint: float = 0, 399 | shuffle: bool = True, 400 | ) -> None: 401 | """ 402 | :param dataset: the dataset to iterate over 403 | :param device: the torch device to send batches to 404 | :param batchsize_hint: * 0 means we use a default batchsize 405 | * -1 means the entire dataset 406 | * float between 0 and 1 means each batch is 407 | that fraction of the DS 408 | * int > 1 means that specific batch size 409 | :param shuffle: whether or not to randomly shuffle the dataset 410 | """ 411 | self.dataset = dataset 412 | self.batchsize = self.calculate_batchsize( 413 | len(dataset), batchsize_hint=batchsize_hint 414 | ) 415 | self.device = device 416 | self.reset_iteration(shuffle=shuffle) 417 | 418 | @staticmethod 419 | def calculate_batchsize(ds_size: int, batchsize_hint: int = 0) -> int: 420 | """ 421 | Calculates which batch size to use 422 | 423 | :param ds_size: the number of equations in the dataset 424 | :param batchsize_hint: * 0 means we use a default batchsize 425 | * -1 means the entire dataset 426 | * float between 0 and 1 means each batch is 427 | that fraction of the DS 428 | * int > 1 means that specific batch size 429 | :returns: the actual batchsize to use 430 | """ 431 | 432 | if batchsize_hint == -1: 433 | return ds_size 434 | elif batchsize_hint == 0: 435 | return min(512, math.ceil(ds_size / 2.0)) 436 | elif (batchsize_hint > 0) and (batchsize_hint < 1): 437 | return math.ceil(ds_size * batchsize_hint) 438 | elif batchsize_hint > 1: 439 | return min(batchsize_hint, ds_size) 440 | else: 441 | raise ValueError("batchsize_hint must be >= -1") 442 | 443 | def reset_iteration(self, shuffle=True): 444 | self.index = 0 445 | if shuffle and self.dataset.train: 446 | self.permutation = torch.randperm(len(self.dataset)) 447 | else: 448 | self.permutation = torch.arange(len(self.dataset)) 449 | 450 | def __iter__(self): 451 | """ 452 | :returns: this iterator 453 | """ 454 | return self 455 | 456 | def __next__(self) -> Dict[str, Tensor]: 457 | """ 458 | Returns one batch of data. 
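Each batch is a dict of input tokens and next-token targets; a usage
sketch (device choice illustrative; train_ds as returned by
ArithmeticDataset.splits()):

    it = ArithmeticIterator(train_ds, torch.device("cpu"))
    for batch in it:
        x, y = batch["text"], batch["target"]  # y is x shifted by one token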
459 | 460 | :raises: StopIteration when we're out of data 461 | :returns: batch tensor of shape (self.batchsize, tokens_per_eq) 462 | """ 463 | 464 | batch_begin = self.index * self.batchsize 465 | if batch_begin > len(self.dataset) - 1: 466 | self.reset_iteration() 467 | raise StopIteration 468 | indices = self.permutation[batch_begin : batch_begin + self.batchsize] 469 | text = self.dataset.data[indices, :-1] 470 | target = self.dataset.data[indices, 1:] 471 | batch = {"text": text.to(self.device), "target": target.to(self.device)} 472 | self.index += 1 473 | return batch 474 | 475 | def __len__(self) -> int: 476 | """ 477 | :returns: the total number of batches 478 | """ 479 | return math.ceil(len(self.dataset) / self.batchsize) 480 | -------------------------------------------------------------------------------- /grok/measure.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import numpy as np 4 | 5 | import scipy.optimize 6 | 7 | 8 | def get_loss_and_grads(x, model, data_loader): 9 | 10 | # if type(x).__module__ == np.__name__: 11 | # x = torch.from_numpy(x).float() 12 | # x = x.cuda() 13 | 14 | model.eval() 15 | 16 | x_start = 0 17 | for p in model.parameters(): 18 | param_size = p.data.size() 19 | param_idx = 1 20 | for s in param_size: 21 | param_idx *= s 22 | x_part = x[x_start : x_start + param_idx] 23 | p.data = torch.Tensor(x_part.reshape(param_size)) 24 | x_start += param_idx 25 | 26 | batch_losses = [] 27 | batch_grads = [] 28 | for it, batch in enumerate(data_loader): 29 | 30 | # Move data to correct device 31 | # inputs = inputs.to(device) 32 | # targets = targets.to(device) 33 | 34 | with torch.set_grad_enabled(True): 35 | # loss, grads = model(idx=inputs, targets=targets, grads=True) 36 | loss, grads = model._step(batch=batch, batch_idx=1, train=True, grads=True) 37 | 38 | # Todo: average over dataset 39 | batch_losses.append(loss) 40 | # batch_grads.append(None if grads is None else grads.cpu().numpy().astype(np.float64)) 41 | batch_grads.append(None if grads is None else grads) 42 | 43 | mean_losses = torch.mean(torch.stack(batch_losses)) 44 | mean_grads = torch.mean(torch.stack(batch_grads), dim=0) 45 | 46 | return (mean_losses, mean_grads.cpu().numpy().astype(np.float64)) 47 | 48 | 49 | def get_weights(model): 50 | """ 51 | Given a model, return a vector of weights. 52 | """ 53 | x0 = None 54 | for p in model.parameters(): 55 | if x0 is None: 56 | x0 = p.data.view(-1) 57 | else: 58 | x0 = torch.cat((x0, p.data.view(-1))) 59 | return x0.cpu().numpy() 60 | 61 | 62 | def get_sharpness(data_loader, model, subspace_dim=10, epsilon=1e-3, maxiter=10): 63 | """ 64 | Compute the sharpness around some point in weight space, as specified 65 | in Keskar et. al. 
(2016) Sec 2.2.2: 66 | https://arxiv.org/pdf/1609.04836.pdf 67 | 68 | See: 69 | https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd 70 | https://gist.github.com/gngdb/a9f912df362a85b37c730154ef3c294b 71 | https://github.com/keskarnitish/large-batch-training 72 | https://github.com/wenwei202/smoothout 73 | https://github.com/keras-team/keras/pull/3064 74 | """ 75 | 76 | x0 = get_weights(model) 77 | 78 | f_x0, _ = get_loss_and_grads(x0, model, data_loader) 79 | f_x0 = -f_x0 80 | logging.info("min loss f_x0 = {loss:.4f}".format(loss=f_x0)) 81 | 82 | if 0 == subspace_dim: 83 | x_min = np.reshape(x0 - epsilon * (np.abs(x0) + 1), (x0.shape[0], 1)) 84 | x_max = np.reshape(x0 + epsilon * (np.abs(x0) + 1), (x0.shape[0], 1)) 85 | bounds = np.concatenate([x_min, x_max], 1) 86 | func = lambda x: get_loss_and_grads(x, model, data_loader) 87 | init_guess = x0 88 | else: 89 | assert subspace_dim <= x0.shape[0] 90 | 91 | # Computed via Keskar, et. al 92 | # https://arxiv.org/pdf/1609.04836.pdf 93 | 94 | A_plus = np.random.rand(subspace_dim, x0.shape[0]) * 2.0 - 1.0 95 | A_plus_norm = np.linalg.norm(A_plus, axis=1) 96 | A_plus = A_plus / np.reshape(A_plus_norm, (subspace_dim, 1)) 97 | A = np.linalg.pinv(A_plus) 98 | 99 | abs_bound = epsilon * (np.abs(np.dot(A_plus, x0)) + 1) 100 | abs_bound = np.reshape(abs_bound, (abs_bound.shape[0], 1)) 101 | bounds = np.concatenate([-abs_bound, abs_bound], 1) 102 | 103 | def func(y): 104 | f_loss, f_grads = get_loss_and_grads( 105 | x0 + np.dot(A, y), 106 | model, 107 | data_loader, 108 | ) 109 | return f_loss, np.dot(np.transpose(A), f_grads) 110 | 111 | init_guess = np.zeros(subspace_dim) 112 | 113 | minimum_x, f_x, d = scipy.optimize.fmin_l_bfgs_b( 114 | func, 115 | init_guess, 116 | maxiter=maxiter, 117 | bounds=bounds, 118 | disp=1, 119 | ) 120 | f_x = -f_x 121 | logging.info("max loss f_x = {loss:.4f}".format(loss=f_x)) 122 | 123 | # Eq 4 in Keskar 124 | phi = (f_x - f_x0) / (1 + f_x0) * 100 125 | 126 | # Restore parameter values 127 | x0 = torch.from_numpy(x0).float() 128 | # x0 = x0.cuda() 129 | x_start = 0 130 | for p in model.parameters(): 131 | param_size = p.data.size() 132 | param_idx = 1 133 | for s in param_size: 134 | param_idx *= s 135 | x_part = x0[x_start : x_start + param_idx] 136 | p.data = x_part.view(param_size) 137 | x_start += param_idx 138 | 139 | return phi 140 | -------------------------------------------------------------------------------- /grok/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import copy 4 | import torch.nn as nn 5 | from typing import Callable 6 | 7 | # References: 8 | # https://github.com/nitarshan/robust-generalization-measures 9 | # https://github.com/bneyshabur/generalization-bounds 10 | # https://github.com/bneyshabur/over-parametrization 11 | 12 | 13 | def compute_measure( 14 | model: nn.Module, 15 | init_model: nn.Module, 16 | measure_func: Callable, 17 | operator: str, 18 | kwargs: dict = {}, 19 | p: int = 1, 20 | ) -> float: 21 | """ 22 | Computes measure value for each layer given trained network and network at 23 | initialization. Then aggregates values per layer using specified operator. 
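For example (illustrative), the product over Linear layers of the
Frobenius norm, as used by calculate() below:

    fro = compute_measure(model, init_model, norm, "product", {"p": 2, "q": 2})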
24 | 25 | :param model: trained network 26 | :param init_model: network at initialization 27 | :param measure_func: callable for the measure to compute 28 | :param operator: 'log_product', 'sum', 'max', 'product', or 'norm' 29 | :param p: p in L^p 30 | :return: value of the desired measure 31 | """ 32 | 33 | measure_value = 0 34 | # weight_modules = ["Linear", "Embedding"] 35 | weight_modules = ["Linear"] 36 | 37 | if operator == "product": 38 | measure_value = math.exp( 39 | compute_measure(model, init_model, measure_func, "log_product", kwargs, p) 40 | ) 41 | elif operator == "norm": 42 | measure_value = ( 43 | compute_measure(model, init_model, measure_func, "sum", kwargs, p=p) 44 | ) ** (1 / p) 45 | else: 46 | measure_value = 0 47 | for child, init_child in zip(model.children(), init_model.children()): 48 | module_name = child._get_name() 49 | if module_name in weight_modules: 50 | if operator == "log_product": 51 | measure_value += math.log(measure_func(child, init_child, **kwargs)) 52 | elif operator == "sum": 53 | measure_value += (measure_func(child, init_child, **kwargs)) ** p 54 | elif operator == "max": 55 | measure_value = max( 56 | measure_value, measure_func(child, init_child, **kwargs) 57 | ) 58 | else: 59 | measure_value += compute_measure( 60 | child, init_child, measure_func, operator, kwargs, p=p 61 | ) 62 | return measure_value 63 | 64 | 65 | def norm(module, init_module, p=2, q=2): 66 | """ 67 | Calculates l_pq norm of a parameter matrix 68 | l_p norm of incoming weights to each hidden unit 69 | l_q norm on the hidden units 70 | """ 71 | return module.weight.view(module.weight.size(0), -1).norm(p=p, dim=1).norm(q).item() 72 | 73 | 74 | def op_norm(module, init_module, p=float("Inf")): 75 | """ 76 | Calculates l_p norm of eigenvalues of parameter matrix 77 | """ 78 | _, S, _ = module.weight.view(module.weight.size(0), -1).svd() 79 | return S.norm(p).item() 80 | 81 | 82 | def dist(module, init_module, p=2, q=2): 83 | """ 84 | Calculates l_pq distance of the parameter matrix of a layer from the random 85 | initialization: 86 | l_p norm of incoming weights to each hidden unit 87 | l_q norm on the hidden units 88 | """ 89 | return ( 90 | (module.weight - init_module.weight) 91 | .view(module.weight.size(0), -1) 92 | .norm(p=p, dim=1) 93 | .norm(q) 94 | .item() 95 | ) 96 | 97 | 98 | def h_dist(module, init_module, p=2, q=2): 99 | """ 100 | Calculate l_pq distance of parameters of trained network from random init 101 | Includes extra factor depending on number of hidden units 102 | """ 103 | return (n_hidden(module, init_module) ** (1 - 1 / q)) * dist( 104 | module, init_module, p=p, q=q 105 | ) 106 | 107 | 108 | def h_dist_op_norm(module, init_module, p=2, q=2, p_op=float("Inf")): 109 | """ 110 | Calculate ratio of h_dist to operator norm 111 | """ 112 | return h_dist(module, init_module, p=p, q=q) / op_norm(module, init_module, p=p_op) 113 | 114 | 115 | def n_hidden(module, init_module): 116 | """ 117 | Number of hidden units 118 | """ 119 | return module.weight.size(0) 120 | 121 | 122 | def depth(module, init_module): 123 | """ 124 | Depth (always == 1 for any linear layer) 125 | """ 126 | return 1 127 | 128 | 129 | def n_param(module, init_module): 130 | """ 131 | Num parameters 132 | """ 133 | bparam = 0 if module.bias is None else module.bias.size(0) 134 | return bparam + module.weight.size(0) * module.weight.view( 135 | module.weight.size(0), -1 136 | ).size(1) 137 | 138 | 139 | def lp_path_norm(model, device, p=2, input_size=[3, 32, 32]): 140 | """ 141 | Path norm 
(Neyshabur 2015) 142 | """ 143 | 144 | tmp_model = copy.deepcopy(model) 145 | tmp_model.eval() 146 | for param in tmp_model.parameters(): 147 | if param.requires_grad: 148 | param.abs_().pow_(p) 149 | data_ones = torch.ones(input_size).to(device) 150 | return (tmp_model(data_ones).sum() ** (1 / p)).item() 151 | 152 | 153 | def calculate(trained_model, init_model, device, dataset_size, margin, input_dim): 154 | """ 155 | Calculates various measures given trained model and model at init 156 | Computes: 157 | measures: norm based measures on the model 158 | bounds: generalization bounds on the model 159 | """ 160 | 161 | model = copy.deepcopy(trained_model) 162 | 163 | # depth 164 | d = compute_measure(model, init_model, depth, "sum", {}) 165 | 166 | # number of parameters (not including batch norm) 167 | nparam = compute_measure(model, init_model, n_param, "sum", {}) 168 | 169 | measure, bound = {}, {} 170 | with torch.no_grad(): 171 | 172 | # Compute measures 173 | measure["L_{1,inf} norm"] = ( 174 | compute_measure( 175 | model, init_model, norm, "product", {"p": 1, "q": float("Inf")} 176 | ) 177 | / margin 178 | ) 179 | measure["Frobenius norm"] = ( 180 | compute_measure(model, init_model, norm, "product", {"p": 2, "q": 2}) 181 | / margin 182 | ) 183 | measure["L_{3,1.5} norm"] = ( 184 | compute_measure(model, init_model, norm, "product", {"p": 3, "q": 1.5}) 185 | / margin 186 | ) 187 | measure["Spectral norm"] = ( 188 | compute_measure(model, init_model, op_norm, "product", {"p": float("Inf")}) 189 | / margin 190 | ) 191 | measure["L_1.5 operator norm"] = ( 192 | compute_measure(model, init_model, op_norm, "product", {"p": 1.5}) / margin 193 | ) 194 | measure["Trace norm"] = ( 195 | compute_measure(model, init_model, op_norm, "product", {"p": 1}) / margin 196 | ) 197 | 198 | # input_size = [context_len, emb_dim] 199 | # measure["L1_path norm"] = ( 200 | # lp_path_norm( 201 | # model, device, p=1, input_size=input_size 202 | # ) 203 | # / margin 204 | # ) 205 | # measure["L1.5_path norm"] = ( 206 | # lp_path_norm( 207 | # model, device, p=1.5, input_size=input_size 208 | # ) 209 | # / margin 210 | # ) 211 | # measure["L2_path norm"] = ( 212 | # lp_path_norm( 213 | # model, device, p=2, input_size=input_size 214 | # ) 215 | # / margin 216 | # ) 217 | 218 | # Compute generalization bounds without constant or additive logarithmic factors 219 | 220 | # Golowich 2018 221 | # https://arxiv.org/pdf/1712.06541.pdf 222 | alpha = math.sqrt(d + math.log(1 * input_dim * input_dim)) 223 | 224 | # Bartlett Mendelson 2002 225 | bound["L1_max Bound"] = ( 226 | alpha * measure["L_{1,inf} norm"] / math.sqrt(dataset_size) 227 | ) 228 | 229 | # Neyshabur 2015 230 | bound["Frobenius Bound"] = ( 231 | alpha * measure["Frobenius norm"] / math.sqrt(dataset_size) 232 | ) 233 | 234 | # Neyshabur 2015 235 | bound["L_{3,1.5} Bound"] = ( 236 | alpha * measure["L_{3,1.5} norm"] / (dataset_size ** (1 / 3)) 237 | ) 238 | 239 | beta = math.log(dataset_size) * math.log(nparam) 240 | ratio = compute_measure( 241 | model, 242 | init_model, 243 | h_dist_op_norm, 244 | "norm", 245 | {"p": 2, "q": 1, "p_op": float("Inf")}, 246 | p=2 / 3, 247 | ) 248 | 249 | # Spectral L_{2, 1} Bound 250 | # Bartlett 2017 251 | bound["Spec_L_{2,1} Bound"] = ( 252 | beta * measure["Spectral norm"] * ratio / math.sqrt(dataset_size) 253 | ) 254 | 255 | ratio = compute_measure( 256 | model, 257 | init_model, 258 | h_dist_op_norm, 259 | "norm", 260 | {"p": 2, "q": 2, "p_op": float("Inf")}, 261 | p=2, 262 | ) 263 | 264 | # Spectral Frobenius 265 | # 
Neyshabur 2018 266 | # https://arxiv.org/pdf/1706.08947.pdf 267 | bound["Spec_Fro Bound"] = ( 268 | d * measure["Spectral norm"] * ratio / math.sqrt(dataset_size) 269 | ) 270 | 271 | return measure, bound 272 | -------------------------------------------------------------------------------- /grok/training.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import copy 5 | import json 6 | import logging 7 | import math 8 | import os 9 | import sys 10 | import pickle 11 | from argparse import ArgumentParser, Namespace 12 | from functools import reduce 13 | from typing import Any, Dict, List, Optional, Tuple, Union 14 | import time 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from pytorch_lightning import LightningModule, Trainer 20 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint 21 | from pytorch_lightning.loggers import CSVLogger 22 | from torch import Tensor 23 | from torch.optim.lr_scheduler import LambdaLR 24 | 25 | import grok.metrics as metrics 26 | from grok.data import ( 27 | DEFAULT_DATA_DIR, 28 | EOS_TOKEN, 29 | VALID_OPERATORS, 30 | ArithmeticDataset, 31 | ArithmeticIterator, 32 | ) 33 | from grok.transformer import Transformer 34 | from grok.measure import get_sharpness 35 | 36 | DEFAULT_LOG_DIR = "logs" 37 | 38 | 39 | class TrainableTransformer(LightningModule): 40 | """ 41 | Adds training methods to train a generic transformer on arithmetic equations 42 | """ 43 | 44 | def __init__(self, hparams: Namespace) -> None: 45 | """ 46 | :param hparams: An argparse.Namespace with parameters defined in 47 | self.add_model_specific_args(). 48 | """ 49 | super().__init__() 50 | self.hparams = hparams # type: ignore 51 | self.prepare_data() 52 | 53 | self.transformer = Transformer( 54 | hparams.n_layers, 55 | hparams.n_heads, 56 | hparams.d_model, 57 | hparams.dropout, 58 | hparams.max_context_len, 59 | len(self.train_dataset.tokenizer), 60 | hparams.non_linearity, 61 | weight_noise=self.hparams.weight_noise, 62 | ) 63 | 64 | self.margin = torch.Tensor([0]) 65 | self.next_epoch_to_eval = -1 66 | self.next_train_epoch_to_log = 0 67 | 68 | @staticmethod 69 | def add_model_specific_args(parser: ArgumentParser) -> ArgumentParser: 70 | """ 71 | Defines the hyperparameter arguments needed by instances of this 72 | class. This is intended to be called when parsing command line 73 | arguments. 74 | 75 | :param parser: an argparse.ArgumentParser created by the caller 76 | :returns: the argument parser with the command line arguments added 77 | for this class. 
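The full training parser is assembled by add_args() in this module; a
minimal sketch of direct use (flag values are illustrative):

    parser = TrainableTransformer.add_model_specific_args(ArgumentParser())
    hparams = parser.parse_args(["--math_operator", "/", "--train_data_pct", "50"])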
78 | """ 79 | parser.add_argument( 80 | "--batchsize", 81 | type=float, 82 | # default=0.25, 83 | default=0, 84 | help="-1 -> entire dataset, 0 -> auto-calculate, 0 fraction of dataset, N>1 -> N", 85 | ) 86 | 87 | parser.add_argument("--n_layers", type=int, default=2) 88 | parser.add_argument("--n_heads", type=int, default=4) 89 | parser.add_argument("--d_model", type=int, default=128) 90 | parser.add_argument("--dropout", type=float, default=0.0) 91 | parser.add_argument("--weight_noise", type=float, default=0.0) 92 | parser.add_argument("--non_linearity", type=str, default="relu") 93 | parser.add_argument("--max_context_len", type=int, default=50) 94 | 95 | parser.add_argument("--math_operator", type=str, default="+") 96 | parser.add_argument( 97 | "--operand_length", 98 | type=int, 99 | help="for list operations, the length of the lists", 100 | ) 101 | 102 | parser.add_argument("--train_data_pct", type=float, default=5) 103 | parser.add_argument("--warmup_steps", type=int, default=10) 104 | parser.add_argument("--anneal_lr_steps", type=int, default=100000) 105 | parser.add_argument("--anneal_lr", dest="anneal_lr", action="store_true") 106 | parser.set_defaults(anneal_lr=False) 107 | 108 | parser.add_argument("--max_lr", type=float, default=1e-3) 109 | parser.add_argument("--weight_decay", type=float, default=0) 110 | parser.add_argument("--weight_decay_kind", type=str, default="to_zero") 111 | parser.add_argument("--noise_factor", type=float, default=0) 112 | 113 | parser.add_argument( 114 | "--save_activations", dest="save_activations", action="store_true" 115 | ) 116 | parser.set_defaults(save_activations=False) 117 | parser.add_argument("--save_outputs", dest="save_outputs", action="store_true") 118 | parser.set_defaults(save_outputs=False) 119 | 120 | parser.add_argument( 121 | "--logdir", 122 | type=str, 123 | default=DEFAULT_LOG_DIR, 124 | ) 125 | parser.add_argument( 126 | "--datadir", 127 | type=str, 128 | default=DEFAULT_DATA_DIR, 129 | ) 130 | 131 | return parser 132 | 133 | def prepare_data(self) -> None: 134 | """ 135 | Used by pytorch_lighting 136 | 137 | Loads training data to self.train_dataset 138 | Loads validation data to self.val_dataset 139 | """ 140 | (self.train_dataset, self.val_dataset,) = ArithmeticDataset.splits( 141 | train_pct=self.hparams.train_data_pct, # type: ignore 142 | operator=self.hparams.math_operator, # type: ignore 143 | operand_length=self.hparams.operand_length, # type: ignore 144 | data_dir=self.hparams.datadir, # type: ignore 145 | ) 146 | 147 | def train_dataloader(self) -> ArithmeticIterator: # type: ignore 148 | """ 149 | Used by pytorch_lighting 150 | 151 | :returns: an iterator for self.train_dataset 152 | """ 153 | device = self.transformer.embedding.weight.device 154 | iterator = ArithmeticIterator( 155 | self.train_dataset, 156 | device, 157 | batchsize_hint=self.hparams.batchsize, # type: ignore 158 | ) 159 | self.train_batchsize = iterator.batchsize 160 | self.batches_per_epoch = len(iterator) 161 | 162 | return iterator 163 | 164 | def val_dataloader(self) -> ArithmeticIterator: # type: ignore 165 | """ 166 | Used by pytorch_lighting 167 | 168 | :returns: an iterator for self.train_dataset 169 | """ 170 | device = self.transformer.embedding.weight.device 171 | iterator = ArithmeticIterator( 172 | self.val_dataset, 173 | device, 174 | batchsize_hint=-1, # no need to batch validation data 175 | ) 176 | return iterator 177 | 178 | def test_dataloader(self) -> ArithmeticIterator: # type: ignore 179 | """ 180 | Used by pytorch_lighting 
181 | 182 | :returns: an iterator for self.train_dataset 183 | """ 184 | device = self.transformer.embedding.weight.device 185 | iterator = ArithmeticIterator( 186 | self.val_dataset, device, batchsize_hint=-1 # type: ignore 187 | ) 188 | return iterator 189 | 190 | def _scheduler_lr(self, step: int) -> float: 191 | """ 192 | Used by pytorch_lighting 193 | 194 | :returns: the learning_rate for this training step 195 | """ 196 | max_lr = self.hparams.max_lr # type: ignore 197 | min_lr = self.hparams.max_lr / 10 # type: ignore 198 | warmup_steps = self.hparams.warmup_steps # type: ignore 199 | if not self.hparams.anneal_lr: 200 | if step <= warmup_steps: 201 | lr = (float(step) / max(warmup_steps, 1)) * max_lr 202 | else: 203 | lr = max_lr 204 | else: 205 | if step <= warmup_steps: 206 | lr = (float(step) / max(warmup_steps, 1)) * max_lr 207 | elif step <= self.hparams.anneal_lr_steps + warmup_steps: 208 | effective_step = step - warmup_steps 209 | t = effective_step / self.hparams.anneal_lr_steps 210 | cos = (1 + np.cos(np.pi * t)) / 2 211 | lr = min_lr + (max_lr - min_lr) * cos 212 | # lr = max_lr - ((effective_step / max_effective_step) * (max_lr - min_lr)) 213 | else: 214 | lr = min_lr 215 | return lr 216 | 217 | def configure_optimizers(self) -> Tuple[List[Any], List[Dict]]: 218 | """ 219 | Used by pytorch_lighting 220 | 221 | :returns: optimizers and schedulers. 222 | """ 223 | optimizer = CustomAdamW( 224 | self.parameters(), 225 | betas=(0.9, 0.98), 226 | eps=1e-8, 227 | lr=1, 228 | weight_decay=self.hparams.weight_decay, 229 | noise_factor=self.hparams.noise_factor, 230 | weight_decay_form=self.hparams.weight_decay_kind, 231 | ) 232 | # optimizer = SAM( 233 | # self.parameters(), 234 | # base_optimizer=CustomAdamW, 235 | # rho=0.05, 236 | # betas=(0.9, 0.98), 237 | # eps=1e-8, 238 | # lr=1, 239 | # weight_decay=self.hparams.weight_decay, 240 | # noise_factor=self.hparams.noise_factor, 241 | # ) 242 | schedulers = [ 243 | { 244 | "scheduler": LambdaLR(optimizer, lr_lambda=self._scheduler_lr), 245 | "interval": "step", 246 | "frequency": 1, 247 | } 248 | ] 249 | return [optimizer], schedulers 250 | 251 | def _accuracy(self, y_hat: Tensor, y: Tensor) -> Tensor: 252 | """ 253 | Takes the most likely solution predicted for each equation and 254 | calculates the frac of equations in the batch for which these 255 | answers were correct 256 | 257 | :param y_hat: The softmax tensor output of the transformer 258 | :param y: A tensor of the token ids for the correct answers to each 259 | equation in the batch 260 | :returns: the fraction of equations correctly answered 261 | """ 262 | 263 | # find max prediction from output 264 | y_hat = torch.max(y_hat, dim=-2).indices # batchsize x num_rhs_tokens 265 | row_accuracy = torch.min((y_hat == y), dim=-1).values # shape: batchsize 266 | accuracy = row_accuracy.float() * 100 # shape: batchsize 267 | return accuracy 268 | 269 | def _step( 270 | self, 271 | batch: Dict, 272 | batch_idx: int, 273 | train: bool = True, 274 | reduction: str = "mean", 275 | grads: bool = False, 276 | ) -> Tuple[Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor]: 277 | """ 278 | Performs one forward pass on a training or validation batch 279 | 280 | :param batch: The batch of equations to process 281 | :param batch_idx: which batch this is in the epoch. 
282 | :param train: True is this is a training batch, false otherwise 283 | :returns: The loss from the predicted solutions to the equation, 284 | The accuracy of the predicted solutions 285 | The fraction of this dataset contained in this batch 286 | The portion of the input equations left of the equal sign 287 | The softmax probilities for the solutions to the equations 288 | A list lists of attention matrices by layer and head 289 | A list lists of value matrices by layer and head 290 | Margin for this batch 291 | """ 292 | x = batch["text"] # shape = batchsize * context_len 293 | y = batch["target"] # shape = batchsize * context_len 294 | y_hat, attentions, values = self( 295 | x=x, save_activations=self.hparams.save_activations # type: ignore 296 | ) # shape = batchsize * context_len * vocab_size 297 | y_hat = y_hat.transpose(-2, -1) # shape = batchsize * vocab_size * context_len 298 | 299 | # Note: each sample must have exactly one '=' and all of them must 300 | # have it in the same position. 301 | eq_token_index = self.train_dataset.tokenizer.stoi["="] 302 | eq_position_t = torch.nonzero(y[0, :] == eq_token_index, as_tuple=False) 303 | eq_position = int(eq_position_t.squeeze()) 304 | 305 | # only calculate loss/accuracy on right hand side of the equation 306 | y_rhs = y[..., eq_position + 1 :] 307 | y_hat_rhs = y_hat[..., eq_position + 1 :] 308 | x_lhs = x[..., : eq_position + 1] 309 | 310 | if train: 311 | coeff = float(batch["target"].shape[0]) / len(self.train_dataset) 312 | else: 313 | coeff = float(batch["target"].shape[0]) / len(self.val_dataset) 314 | loss = F.cross_entropy(y_hat_rhs, y_rhs, reduction=reduction) 315 | 316 | with torch.no_grad(): 317 | acc = self._accuracy(y_hat_rhs, y_rhs) 318 | if reduction == "mean": 319 | acc = acc.mean() 320 | 321 | """ 322 | device = self.transformer.embedding.weight.device 323 | self.margin = self.margin.to(device) 324 | 325 | output = y_hat_rhs.clone() # batchsize, vocabsize, rhs tokens 326 | output_m = output.clone() # batchsize, vocabsize, rhs tokens 327 | target = y_rhs.clone() # batchsize, rhs tokens 328 | 329 | for i in range(output.size(0)): # batch 330 | for j in range(output.size(2)): # rhs tokens 331 | output_m[i, target[i, j], j] = output_m[i, :, j].min() 332 | 333 | for i in range(output.size(2)): # rhs tokens 334 | output_compressed = output[:, target[:, i], i].squeeze().diag() 335 | output_m_compressed = ( 336 | output_m[:, output_m.max(dim=1).indices[:, i], i].squeeze().diag() 337 | ) 338 | self.margin = torch.cat( 339 | ( 340 | self.margin, 341 | (output_compressed - output_m_compressed), 342 | ), 343 | 0, 344 | ) 345 | """ 346 | grad_vec = None 347 | if grads: 348 | loss.backward() 349 | for p in self.parameters(): 350 | p.grad.data.div_(batch["text"].shape[0]) 351 | if grad_vec is None: 352 | grad_vec = p.grad.data.view(-1) 353 | else: 354 | grad_vec = torch.cat((grad_vec, p.grad.data.view(-1))) 355 | return loss, grad_vec 356 | return loss, acc, coeff, x_lhs, y_hat_rhs, attentions, values 357 | 358 | 359 | def _save_inputs(self, outputs: Dict, ds: str) -> None: 360 | """ 361 | Saves the input equations to disk for analysis later 362 | 363 | :param outputs: a list of tuples from self.training_step() 364 | :param ds: a string ('train' or 'val') naming which dataset 365 | these inputs are from. 
366 | :param train: True is this is a training batch, false otherwise 367 | """ 368 | logdir = self.hparams.logdir + "/inputs/" + ds # type: ignore 369 | os.makedirs(logdir, exist_ok=True) 370 | pickle_file = logdir + f"/{ds}.pt" 371 | 372 | x_lhs = torch.cat([x["x_lhs"] for x in outputs]) 373 | with open(pickle_file, "wb") as fh: 374 | torch.save(x_lhs, fh) 375 | 376 | def _merge_batch_activations( 377 | self, partial_activations: List[List[Tensor]] 378 | ) -> List[List[Tensor]]: 379 | """ 380 | Merges the head_attentions / head_values from all batches in 381 | this epoch. 382 | 383 | :param partial_activations: A list of 384 | (lists of lists of activations by layer and head) 385 | :returns: A lists of lists of activations by layer and head 386 | """ 387 | # num_batches = len(partial_activations) 388 | num_layers = len(partial_activations[0]) 389 | num_heads = len(partial_activations[0][0]) 390 | activations: List = [] 391 | for _ in range(num_layers): 392 | activations.append([]) 393 | for _ in range(num_heads): 394 | activations[-1].append([]) 395 | 396 | for minibatch_activations in partial_activations: 397 | for l, layer_activations in enumerate(minibatch_activations): 398 | for h, head_attn in enumerate(layer_activations): 399 | # # print(f"head_attn = {head_attn}") 400 | activations[l][h].append(head_attn) 401 | 402 | for l in range(num_layers): 403 | for h in range(num_heads): 404 | activations[l][h] = torch.cat(activations[l][h]) 405 | 406 | return activations 407 | 408 | def _save_activations(self, outputs: Dict, ds: str) -> None: 409 | """ 410 | Saves activations out to disk for analysis later 411 | 412 | :param outputs: a list of tuples from self.training_step() 413 | """ 414 | 415 | output: Dict[str, Any] = {} 416 | if self.hparams.save_outputs: # type: ignore 417 | y_hat_rhs = torch.cat([x["y_hat_rhs"] for x in outputs]) 418 | output["y_hat_rhs"] = y_hat_rhs 419 | if self.hparams.save_activations: # type: ignore 420 | partial_attentions = list([o["partial_attentions"] for o in outputs]) 421 | attentions = self._merge_batch_activations(partial_attentions) 422 | partial_values = list([o["partial_values"] for o in outputs]) 423 | values = self._merge_batch_activations(partial_values) 424 | output["attentions"] = attentions 425 | output["values"] = values 426 | if self.hparams.save_outputs or self.hparams.save_activations: # type: ignore 427 | logdir = self.hparams.logdir + "/outputs/" + ds # type: ignore 428 | os.makedirs(logdir, exist_ok=True) 429 | pickle_file = logdir + f"/epoch_{self.current_epoch:010}.pt" 430 | with open(pickle_file, "wb") as fh: 431 | torch.save(output, fh) 432 | 433 | def training_step(self, batch, batch_idx): 434 | """ 435 | Used by pytorch_lightning 436 | Runs one forward training pass on one batch 437 | 438 | :param batch: The batch of equations to process 439 | :param batch_idx: which batch this is in the epoch. 
440 | :returns: a dict with loss, accuracy, lr, probabilities of solutions, 441 | attentions, and values 442 | """ 443 | if batch_idx == 0: 444 | self.training_epoch_start_time = time.time() 445 | self.fwd_time_in_epoch = 0 446 | 447 | start = time.time() 448 | loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step( 449 | batch=batch, batch_idx=batch_idx, train=True 450 | ) 451 | self.fwd_time_in_epoch += time.time() - start 452 | 453 | schedulers = self.trainer.lr_schedulers[0] 454 | if self.current_epoch != self.next_train_epoch_to_log: 455 | return {"loss": loss} 456 | lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"] 457 | output = { 458 | "loss": loss, 459 | "partial_train_loss": coeff * loss, 460 | "partial_train_accuracy": coeff * accuracy, 461 | "learning_rate": torch.tensor([lr]), 462 | "y_hat_rhs": y_hat_rhs, 463 | "partial_attentions": attentions, 464 | "partial_values": values, 465 | } 466 | if self.current_epoch == 0: 467 | output["x_lhs"] = x_lhs 468 | 469 | return output 470 | 471 | def training_epoch_end(self, outputs): 472 | """ 473 | Used by pytorch_lightning 474 | Accumulates results of all forward training passes in this epoch 475 | 476 | :param outputs: a list of dicts from self.training_step() 477 | :param batch_idx: which batch this is in the epoch. 478 | :returns: a dict with loss, accuracy, lr, probabilities of solutions, 479 | attentions, and values 480 | """ 481 | epoch_is_to_be_logged = self.current_epoch == self.next_train_epoch_to_log 482 | if epoch_is_to_be_logged: 483 | self.next_train_epoch_to_log = max( 484 | int(1.01 * self.next_train_epoch_to_log), 485 | self.next_train_epoch_to_log + 1, 486 | ) 487 | with torch.no_grad(): 488 | try: 489 | loss = torch.stack([x["partial_train_loss"] for x in outputs]).sum() 490 | except Exception as e: 491 | print("!" * 80) 492 | print(outputs) 493 | raise e 494 | perplexity = torch.exp(loss) 495 | accuracy = torch.stack( 496 | [x["partial_train_accuracy"] for x in outputs] 497 | ).sum() 498 | # avg_lr = torch.stack([x["learning_rate"] for x in outputs]).mean() 499 | # max_lr = torch.stack([x["learning_rate"] for x in outputs]).max() 500 | # last_lr = outputs[-1]["learning_rate"] 501 | first_lr = outputs[0]["learning_rate"] 502 | 503 | if self.hparams.save_activations or self.hparams.save_outputs: 504 | if self.current_epoch == 0: 505 | self._save_inputs(outputs, ds="train") 506 | self._save_activations(outputs, ds="train") 507 | 508 | logs = { 509 | "train_loss": loss, 510 | "train_accuracy": accuracy, 511 | "train_perplexity": perplexity, 512 | "learning_rate": first_lr, 513 | "len_train_ds": len(self.train_dataset), 514 | "len_val_ds": len(self.val_dataset), 515 | "batches_per_epoch": self.batches_per_epoch, 516 | "time_per_epoch": time.time() - self.training_epoch_start_time, 517 | "fwd_time_in_epoch": self.fwd_time_in_epoch, 518 | } 519 | for k, v in logs.items(): 520 | self.log(k, v) 521 | 522 | def validation_step(self, batch, batch_idx): 523 | """ 524 | Used by pytorch_lightning 525 | Runs one forward validation pass on one batch 526 | 527 | :param batch: The batch of equations to process 528 | :param batch_idx: which batch this is in the epoch. 
529 | :returns: a dict with val_loss, val_accuracy, probabilities of solutions, 530 | attentions, and values 531 | """ 532 | if self.next_epoch_to_eval < self.current_epoch: 533 | self.next_epoch_to_eval = self.current_epoch 534 | if self.current_epoch != self.next_epoch_to_eval: 535 | return {} 536 | with torch.no_grad(): 537 | loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step( 538 | batch=batch, batch_idx=batch_idx, train=False 539 | ) 540 | output = { 541 | "partial_val_loss": coeff * loss, 542 | "partial_val_accuracy": coeff * accuracy, 543 | "y_hat_rhs": y_hat_rhs, 544 | "partial_attentions": attentions, 545 | "partial_values": values, 546 | } 547 | if self.current_epoch == 0: 548 | output["x_lhs"] = x_lhs 549 | 550 | return output 551 | 552 | def validation_epoch_end(self, outputs): 553 | """ 554 | Used by pytorch_lightning 555 | Accumulates results of all forward validation passes in this epoch 556 | 557 | :param outputs: a list of dicts from self.validation_step() 558 | :param batch_idx: which batch this is in the epoch. 559 | :returns: a dict with val_loss, val_accuracy 560 | """ 561 | validation_is_real = len(outputs[0]) != 0 562 | 563 | if validation_is_real: 564 | self.next_epoch_to_eval = max( 565 | int(1.02 * self.next_epoch_to_eval), self.next_epoch_to_eval + 1 566 | ) 567 | 568 | loss = torch.stack([x["partial_val_loss"] for x in outputs]).sum() 569 | perplexity = torch.exp(loss) 570 | accuracy = torch.stack([x["partial_val_accuracy"] for x in outputs]).sum() 571 | 572 | if self.hparams.save_activations or self.hparams.save_outputs: 573 | if self.current_epoch == 0: 574 | self._save_inputs(outputs, ds="val") 575 | self._save_activations(outputs, ds="val") 576 | 577 | logs = { 578 | "val_loss": loss, 579 | "val_accuracy": accuracy, 580 | "val_perplexity": perplexity, 581 | } 582 | for name, param in self.named_parameters(): 583 | # n parameters 584 | n_params = param.numel() 585 | # get the l2 norm of the parameter 586 | logs["paramnorm_" + name] = torch.norm( 587 | param, 2 588 | ).detach().cpu().numpy() / np.sqrt(n_params) 589 | 590 | # train accuracy 591 | device = self.transformer.embedding.weight.device 592 | train_data = self.train_dataset.data.to(device) 593 | training_data = {"text": train_data[:, :-1], "target": train_data[:, 1:]} 594 | with torch.no_grad(): 595 | tr_loss, tr_acc, *_ = self._step(training_data, 0) 596 | logs["full_train_loss"] = tr_loss 597 | logs["full_train_acc"] = tr_acc 598 | 599 | for k, v in logs.items(): 600 | self.log(k, v) 601 | # save a checkpoint if the epoch is a power of 2 602 | if ( 603 | self.current_epoch > 0 604 | and int(2 ** (int(np.log(self.current_epoch) / np.log(2)))) 605 | == self.current_epoch 606 | ): 607 | self.trainer.save_checkpoint( 608 | os.path.join( 609 | self.hparams.checkpoint_path, 610 | "epoch_" + str(self.current_epoch) + ".ckpt", 611 | ) 612 | ) 613 | if validation_is_real: 614 | return logs 615 | 616 | def test_step(self, batch, batch_idx): 617 | """ 618 | Used by pytorch_lightning 619 | Runs one forward validation pass on one batch 620 | 621 | :param batch: The batch of equations to process 622 | :param batch_idx: which batch this is in the epoch. 
623 | :returns: a dict with val_loss, val_accuracy, probabilities of solutions, 624 | attentions, and values 625 | """ 626 | 627 | loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step( 628 | batch=batch, batch_idx=batch_idx, train=False, reduction="none" 629 | ) 630 | output = { 631 | "partial_test_loss": coeff * loss, 632 | "partial_test_accuracy": coeff * accuracy, 633 | "y_hat_rhs": y_hat_rhs, 634 | "partial_attentions": attentions, 635 | "partial_values": values, 636 | } 637 | if self.current_epoch == 0: 638 | output["x_lhs"] = x_lhs 639 | 640 | return output 641 | 642 | def test_epoch_end(self, outputs): 643 | """ 644 | Used by pytorch_lightning 645 | Accumulates results of all forward validation passes in this epoch 646 | 647 | :param outputs: a list of dicts from self.validation_step() 648 | :param batch_idx: which batch this is in the epoch. 649 | :returns: a dict with val_loss, val_accuracy 650 | """ 651 | loss = torch.cat([x["partial_test_loss"] for x in outputs], dim=0) # .sum() 652 | # loss = list([x["partial_test_loss"] for x in outputs]) # .sum() 653 | perplexity = torch.exp(loss) 654 | accuracy = torch.cat([x["partial_test_accuracy"] for x in outputs], dim=0) 655 | 656 | logs = { 657 | "test_loss": loss, 658 | "test_accuracy": accuracy, 659 | "test_perplexity": perplexity, 660 | } 661 | 662 | return {"test_loss": loss, "log": logs} 663 | 664 | def forward(self, *args, **kwargs) -> Any: 665 | """Passes all arguments directly to Tranformer.forward()""" 666 | return self.transformer(*args, **kwargs) 667 | 668 | 669 | def train(hparams: Namespace) -> None: 670 | """ 671 | This is the main trainer_method. This sets up and runs experiment with 672 | the defined hyperparameters 673 | 674 | :param hparams: An argparse.Namespace with all of the relevant hyperparameters 675 | """ 676 | 677 | # Process the args 678 | if hparams.logdir is None: 679 | hparams.logdir = os.environ.get("LOGDIR", ".") 680 | hparams.logdir = os.path.abspath(hparams.logdir) 681 | 682 | # Make sure d_model, heads, and d_key are compatible 683 | assert ( 684 | hparams.d_model % hparams.n_heads == 0 685 | ), "n_heads=%s does not evenly divide d_model=%s" % ( 686 | hparams.n_heads, 687 | hparams.d_model, 688 | ) 689 | hparams.d_key = hparams.d_model / hparams.n_heads 690 | 691 | # Set up the RNGs for repeatability 692 | if hparams.random_seed != -1: 693 | torch.manual_seed(hparams.random_seed) 694 | torch.cuda.manual_seed(hparams.random_seed) 695 | torch.backends.cudnn.deterministic = True 696 | torch.backends.cudnn.benchmark = False 697 | 698 | checkpoint_path = hparams.logdir + "/checkpoints" 699 | os.makedirs(checkpoint_path, exist_ok=True) 700 | hparams.checkpoint_path = checkpoint_path 701 | 702 | # Create the model 703 | model = TrainableTransformer(hparams).float() 704 | 705 | torch.save(model, os.path.join(checkpoint_path, "init.pt")) 706 | 707 | logger = CSVLogger(hparams.logdir) 708 | 709 | # checkpointer = ModelCheckpoint( 710 | # filepath=checkpoint_path, 711 | # monitor="save_ckpt", 712 | # mode="max", 713 | # save_top_k=len(hparams.ckpt_epochs), 714 | # verbose=False, 715 | # ) 716 | 717 | trainer_args = { 718 | "max_steps": hparams.max_steps, 719 | "min_steps": hparams.max_steps, 720 | "max_epochs": int(1e8), 721 | "val_check_interval": 1, 722 | "profiler": False, 723 | # "checkpoint_callback": checkpointer, 724 | "logger": logger, 725 | "log_every_n_steps": 1, 726 | "flush_logs_every_n_steps": 1000, 727 | } 728 | if torch.cuda.is_available() and hparams.gpu >= 0: 729 | 
trainer_args["gpus"] = [hparams.gpu] 730 | 731 | trainer = Trainer(**trainer_args) 732 | 733 | trainer.fit(model=model) # type: ignore 734 | """ 735 | margin = np.percentile(model.margin.detach().cpu().numpy(), 5) 736 | device = transformer.embedding.weight.device 737 | measures, bounds = metrics.calculate( 738 | transformer, 739 | transformer_init.to(device), 740 | device, 741 | dataset_size, 742 | margin, 743 | input_dim=hparams.d_model, 744 | ) 745 | 746 | measures_file = os.path.join(logger.log_dir, "measures.json") 747 | bounds_file = os.path.join(logger.log_dir, "bounds.json") 748 | with open(measures_file, "w") as fh: 749 | json.dump(measures, fh) 750 | with open(bounds_file, "w") as fh: 751 | json.dump(bounds, fh) 752 | """ 753 | return hparams.logdir 754 | 755 | 756 | def compute_sharpness(hparams: Namespace, ckpts) -> None: 757 | """ 758 | This is the compute_sharpness method. This loads a series of checkpoints in 759 | the defined hyperparameters 760 | 761 | :param hparams: An argparse.Namespace with all of the relevant hyperparameters 762 | """ 763 | 764 | # Process the args 765 | if hparams.logdir is None: 766 | hparams.logdir = os.environ.get("LOGDIR", ".") 767 | hparams.logdir = os.path.abspath(hparams.logdir) 768 | 769 | # Make sure d_model, heads, and d_key are compatible 770 | assert ( 771 | hparams.d_model % hparams.n_heads == 0 772 | ), "n_heads=%s does not evenly divide d_model=%s" % ( 773 | hparams.n_heads, 774 | hparams.d_model, 775 | ) 776 | hparams.d_key = hparams.d_model / hparams.n_heads 777 | 778 | # Set up the RNGs for repeatability 779 | if hparams.random_seed != -1: 780 | torch.manual_seed(hparams.random_seed) 781 | torch.cuda.manual_seed(hparams.random_seed) 782 | torch.backends.cudnn.deterministic = True 783 | torch.backends.cudnn.benchmark = False 784 | 785 | checkpoint_path = hparams.logdir + "/checkpoints" 786 | os.makedirs(checkpoint_path, exist_ok=True) 787 | hparams.checkpoint_path = checkpoint_path 788 | 789 | # Create the model 790 | model = TrainableTransformer(hparams).float() 791 | 792 | torch.save(model, os.path.join(checkpoint_path, "init.pt")) 793 | 794 | logger = CSVLogger(hparams.logdir) 795 | 796 | 797 | trainer_args = { 798 | "max_steps": hparams.max_steps, 799 | "min_steps": hparams.max_steps, 800 | "max_epochs": int(1e8), 801 | "val_check_interval": 1, 802 | "profiler": False, 803 | # "checkpoint_callback": checkpointer, 804 | "logger": logger, 805 | "log_every_n_steps": 1, 806 | "flush_logs_every_n_steps": 1000, 807 | } 808 | if torch.cuda.is_available() and hparams.gpu >= 0: 809 | trainer_args["gpus"] = [hparams.gpu] 810 | 811 | trainer = Trainer(**trainer_args) 812 | 813 | for ckpt in ckpts: 814 | print(f"Loading checkpoint {ckpt}") 815 | # model = torch.load(ckpt) 816 | # model.load_state_dict(torch.load(ckpt)) 817 | 818 | checkpoint = torch.load(ckpt) 819 | # print(dir(checkpoint), type(checkpoint), "Ckpt") 820 | # for k, v in checkpoint.items(): 821 | # print(k) 822 | # print(checkpoint["hyper_parameters"]) 823 | 824 | hps = checkpoint["hyper_parameters"] 825 | hps = argparse.Namespace(**hps) 826 | model = TrainableTransformer(hps).float() 827 | model.load_state_dict(checkpoint["state_dict"]) 828 | 829 | phi = get_sharpness(model.train_dataloader(), model) 830 | results = {} 831 | results[ckpt] = phi 832 | pickle.dump(results, open(f"results/results_SD-{i}.pkl", "wb")) 833 | 834 | 835 | def add_args(parser=None) -> Namespace: 836 | """ 837 | Parses the command line arguments 838 | 839 | :returns: an argparse.Namespace with all of 
the needed arguments 840 | """ 841 | if parser is None: 842 | parser = ArgumentParser() 843 | parser.add_argument("--random_seed", type=int, default=-1) 844 | parser.add_argument("--gpu", type=int, default=0) 845 | parser.add_argument("--max_epochs", type=int, default=None) 846 | parser.add_argument("--max_steps", type=int, default=100000) 847 | # parser.add_argument("--checkpoint_period", type=int, default=1) 848 | parser = TrainableTransformer.add_model_specific_args(parser) 849 | return parser 850 | 851 | 852 | class CustomAdamW(torch.optim.Optimizer): 853 | def __init__( 854 | self, 855 | params, 856 | lr=1e-3, 857 | betas=(0.9, 0.999), 858 | eps=1e-8, 859 | weight_decay=1e-2, 860 | amsgrad=False, 861 | noise_factor=0.0, 862 | weight_decay_form="to_zero", 863 | ): 864 | if not 0.0 <= lr: 865 | raise ValueError("Invalid learning rate: {}".format(lr)) 866 | if not 0.0 <= eps: 867 | raise ValueError("Invalid epsilon value: {}".format(eps)) 868 | if not 0.0 <= betas[0] < 1.0: 869 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 870 | if not 0.0 <= betas[1] < 1.0: 871 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 872 | if not weight_decay_form in ["to_zero", "to_init", "jiggle", "honest"]: 873 | raise ValueError( 874 | f"Invalid weight decay form: {weight_decay_form}, should be one of ['to_zero', 'to_init', 'jiggle']" 875 | ) 876 | # if not 0.0 <= weight_decay: 877 | # raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 878 | defaults = dict( 879 | lr=lr, 880 | betas=betas, 881 | eps=eps, 882 | weight_decay=weight_decay, 883 | amsgrad=amsgrad, 884 | noise_factor=noise_factor, 885 | weight_decay_form=weight_decay_form, 886 | ) 887 | super(CustomAdamW, self).__init__(params, defaults) 888 | 889 | def __setstate__(self, state): 890 | super(CustomAdamW, self).__setstate__(state) 891 | for group in self.param_groups: 892 | group.setdefault("amsgrad", False) 893 | 894 | @torch.no_grad() 895 | def step(self, closure=None): 896 | """Performs a single optimization step. 897 | 898 | Arguments: 899 | closure (callable, optional): A closure that reevaluates the model 900 | and returns the loss. 901 | """ 902 | loss = None 903 | if closure is not None: 904 | with torch.enable_grad(): 905 | loss = closure() 906 | 907 | for group in self.param_groups: 908 | for p in group["params"]: 909 | if p.grad is None: 910 | continue 911 | 912 | # Perform optimization step 913 | grad = p.grad 914 | 915 | if group["weight_decay"] > 0: 916 | if group["weight_decay_form"] == "honest": 917 | grad = grad + group["weight_decay"] * p.detach() 918 | 919 | if grad.is_sparse: 920 | raise RuntimeError( 921 | "Adam does not support sparse gradients, please consider SparseAdam instead" 922 | ) 923 | amsgrad = group["amsgrad"] 924 | 925 | state = self.state[p] 926 | 927 | # State initialization 928 | if len(state) == 0: 929 | state["step"] = 0 930 | # Exponential moving average of gradient values 931 | state["exp_avg"] = torch.zeros_like( 932 | p, memory_format=torch.preserve_format 933 | ) 934 | # Exponential moving average of squared gradient values 935 | state["exp_avg_sq"] = torch.zeros_like( 936 | p, memory_format=torch.preserve_format 937 | ) 938 | if group["weight_decay_form"] == "to_init": 939 | state["init"] = p.detach().clone() 940 | if amsgrad: 941 | # Maintains max of all exp. moving avg. of sq. grad. 
values 942 | state["max_exp_avg_sq"] = torch.zeros_like( 943 | p, memory_format=torch.preserve_format 944 | ) 945 | 946 | if group["weight_decay"] > 0: 947 | if group["weight_decay_form"] == "to_zero": 948 | p.mul_(1 - group["lr"] * group["weight_decay"]) 949 | elif group["weight_decay_form"] == "to_init": 950 | p.add_( 951 | (state["init"] - p) * (group["lr"] * group["weight_decay"]) 952 | ) 953 | elif group["weight_decay_form"] == "jiggle": 954 | p.mul_( 955 | torch.exp( 956 | torch.randn(1).cuda() 957 | * (group["lr"] * group["weight_decay"]) 958 | ) 959 | ) 960 | elif group["weight_decay_form"] == "honest": 961 | pass 962 | else: 963 | raise ValueError( 964 | f"Invalid weight decay form: {group['weight_decay_form']}" 965 | ) 966 | 967 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 968 | if amsgrad: 969 | max_exp_avg_sq = state["max_exp_avg_sq"] 970 | beta1, beta2 = group["betas"] 971 | 972 | state["step"] += 1 973 | bias_correction1 = 1 - beta1 ** state["step"] 974 | bias_correction2 = 1 - beta2 ** state["step"] 975 | 976 | # Decay the first and second moment running average coefficient 977 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 978 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 979 | if amsgrad: 980 | # Maintains the maximum of all 2nd moment running avg. till now 981 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 982 | # Use the max. for normalizing running avg. of gradient 983 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 984 | group["eps"] 985 | ) 986 | else: 987 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 988 | group["eps"] 989 | ) 990 | 991 | step_size = group["lr"] / bias_correction1 992 | 993 | upd = exp_avg / denom 994 | # add uniform gaussian noise to the update 995 | if group["noise_factor"] > 0: 996 | upd += torch.randn_like(upd) * group["noise_factor"] 997 | # if group['noise_factor'] > 0: 998 | # upd *= torch.exp(torch.randn_like(upd) * group['noise_factor']) 999 | p.add_(-step_size * upd) 1000 | 1001 | return loss 1002 | 1003 | 1004 | class SAM(torch.optim.Optimizer): 1005 | def __init__(self, params, base_optimizer, rho=0.05, **kwargs): 1006 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 1007 | 1008 | defaults = dict(rho=rho, **kwargs) 1009 | super(SAM, self).__init__(params, defaults) 1010 | 1011 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 1012 | self.param_groups = self.base_optimizer.param_groups 1013 | 1014 | @torch.no_grad() 1015 | def first_step(self, zero_grad=False): 1016 | grad_norm = self._grad_norm() 1017 | for group in self.param_groups: 1018 | scale = group["rho"] / (grad_norm + 1e-12) 1019 | 1020 | for p in group["params"]: 1021 | if p.grad is None: 1022 | continue 1023 | e_w = p.grad * scale.to(p) 1024 | p.add_(e_w) # climb to the local maximum "w + e(w)" 1025 | self.state[p]["e_w"] = e_w 1026 | 1027 | if zero_grad: 1028 | self.zero_grad() 1029 | 1030 | @torch.no_grad() 1031 | def second_step(self, zero_grad=False): 1032 | for group in self.param_groups: 1033 | for p in group["params"]: 1034 | if p.grad is None: 1035 | continue 1036 | p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)" 1037 | 1038 | self.base_optimizer.step() # do the actual "sharpness-aware" update 1039 | 1040 | if zero_grad: 1041 | self.zero_grad() 1042 | 1043 | @torch.no_grad() 1044 | def step(self, closure=None): 1045 | assert ( 1046 | closure is not None 1047 | ), "Sharpness Aware Minimization requires closure, but it was not 
provided" 1048 | closure = torch.enable_grad()( 1049 | closure 1050 | ) # the closure should do a full forward-backward pass 1051 | 1052 | self.first_step(zero_grad=True) 1053 | closure() 1054 | self.second_step() 1055 | 1056 | def _grad_norm(self): 1057 | shared_device = self.param_groups[0]["params"][ 1058 | 0 1059 | ].device # put everything on the same device, in case of model parallelism 1060 | grad_norms = [ 1061 | p.grad.norm(p=2).to(shared_device) 1062 | for group in self.param_groups 1063 | for p in group["params"] 1064 | if p.grad is not None 1065 | ] 1066 | print("grad norms is ", grad_norms, "!" * 1000) 1067 | norm = torch.norm( 1068 | torch.stack(grad_norms), 1069 | p=2, 1070 | ) 1071 | return norm 1072 | -------------------------------------------------------------------------------- /grok/transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from argparse import ArgumentParser, Namespace 3 | from typing import Tuple, List, Dict, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from numpy import cos, sin, sqrt 10 | from torch import tensor, Tensor 11 | from torch.optim.lr_scheduler import LambdaLR 12 | import pytorch_lightning as pl 13 | 14 | from argparse import ArgumentParser 15 | 16 | 17 | class Linear(nn.Linear): 18 | def __init__(self, *args, **kwargs): 19 | self.weight_noise = kwargs.pop("weight_noise") 20 | super().__init__(*args, **kwargs) 21 | 22 | def forward(self, input: Tensor) -> Tensor: 23 | if self.weight_noise > 0 and self.training: 24 | bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise 25 | weight = self.weight + torch.randn_like(self.weight) * self.weight_noise 26 | # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise) 27 | else: 28 | bias = self.bias 29 | weight = self.weight 30 | 31 | return F.linear( 32 | input, 33 | weight, 34 | bias, 35 | ) 36 | 37 | class LayerNorm(nn.LayerNorm): 38 | def __init__(self, *args, **kwargs): 39 | self.weight_noise = kwargs.pop("weight_noise") 40 | super().__init__(*args, **kwargs) 41 | 42 | def forward(self, input: Tensor) -> Tensor: 43 | if self.weight_noise > 0 and self.training: 44 | bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise 45 | weight = self.weight + torch.randn_like(self.weight) * self.weight_noise 46 | # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise) 47 | else: 48 | bias = self.bias 49 | weight = self.weight 50 | return F.layer_norm( 51 | input, 52 | self.normalized_shape, 53 | weight, 54 | bias, 55 | self.eps, 56 | ) 57 | 58 | 59 | class Embedding(nn.Embedding): 60 | def __init__(self, *args, **kwargs): 61 | self.weight_noise = kwargs.pop("weight_noise") 62 | super().__init__(*args, **kwargs) 63 | 64 | def forward(self, input: Tensor) -> Tensor: 65 | if self.weight_noise > 0 and self.training: 66 | weight = self.weight + torch.randn_like(self.weight) * self.weight_noise 67 | # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise) 68 | else: 69 | weight = self.weight 70 | return F.embedding( 71 | input, 72 | weight, 73 | self.padding_idx, 74 | self.max_norm, 75 | self.norm_type, 76 | self.scale_grad_by_freq, 77 | self.sparse, 78 | ) 79 | 80 | 81 | class AttentionHead(nn.Module): 82 | def __init__(self, d_model: int, d_key: int, weight_noise: float) -> None: 83 | 84 | 
super().__init__() 85 | 86 | self.d_key = d_key 87 | 88 | # head projections 89 | self.Wq = Linear(d_model, d_key, bias=False, weight_noise=weight_noise) 90 | self.Wk = Linear(d_model, d_key, bias=False, weight_noise=weight_noise) 91 | self.Wv = Linear(d_model, d_key, bias=False, weight_noise=weight_noise) 92 | 93 | self.softmax = nn.Softmax(dim=-1) 94 | 95 | def forward( 96 | self, 97 | queries: Tensor, 98 | keys: Tensor, 99 | values: Tensor, 100 | mask: Union[Tensor, None] = None, 101 | save_activations: bool = False, 102 | ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]: 103 | 104 | # project queries, keys, values 105 | queries = self.Wq(queries) 106 | keys = self.Wk(keys) 107 | values = self.Wv(values) 108 | 109 | # calculate compatibility function 110 | attn = torch.matmul(queries, torch.transpose(keys, -2, -1)) 111 | attn = attn / sqrt(self.d_key) 112 | 113 | # Filter out attention to future positions 114 | if mask is not None: 115 | attn.masked_fill_(mask == 0, float("-inf")) 116 | 117 | # softmax 118 | attn = self.softmax(attn) 119 | 120 | # sum the weighted value vectors 121 | result: Tensor = torch.matmul(attn, values) # shape = (max_context_len, d_key) 122 | if save_activations: 123 | leaf_attn = attn.clone().detach() # type: ignore 124 | leaf_values = values.clone().detach() # type: ignore 125 | else: 126 | leaf_attn = None # type: ignore 127 | leaf_values = None # type: ignore 128 | 129 | return result, leaf_attn, leaf_values 130 | 131 | 132 | class MultiHeadAttention(nn.Module): 133 | def __init__(self, d_model: int, heads: int, weight_noise: float = 0.0) -> None: 134 | super().__init__() 135 | d_key = int(d_model / heads) 136 | 137 | attn_heads = [ 138 | AttentionHead(d_model, d_key, weight_noise=weight_noise) 139 | for _ in range(heads) 140 | ] 141 | self.attn_heads = nn.ModuleList(attn_heads) 142 | self.Wo = Linear(d_model, d_model, bias=False, weight_noise=weight_noise) 143 | 144 | def forward( 145 | self, 146 | queries: Tensor, 147 | keys: Tensor, 148 | values: Tensor, 149 | mask: Tensor = None, 150 | save_activations=False, 151 | ) -> Tuple[Tensor, List[Tensor], List[Tensor]]: 152 | 153 | head_outputs = [ 154 | h( 155 | queries=queries, 156 | keys=keys, 157 | values=values, 158 | mask=mask, 159 | save_activations=save_activations, 160 | ) 161 | for h in self.attn_heads 162 | ] 163 | head_results = [output[0] for output in head_outputs] 164 | 165 | if save_activations: 166 | layer_attns = list([output[1] for output in head_outputs]) 167 | layer_values = list([output[2] for output in head_outputs]) 168 | else: 169 | layer_attns = [] 170 | layer_values = [] 171 | 172 | multihead_result = torch.cat(head_results, dim=-1) 173 | multihead_result = self.Wo(multihead_result) 174 | return multihead_result, layer_attns, layer_values 175 | 176 | 177 | class FFN(nn.Module): 178 | def __init__( 179 | self, 180 | d_model: int, 181 | multiplier: int = 4, 182 | non_linearity: str = "relu", 183 | weight_noise: float = 0.0, 184 | ) -> None: 185 | super().__init__() 186 | 187 | d_ff = int(multiplier * d_model) 188 | 189 | non_linearities = {"relu": nn.ReLU, "gelu": nn.GELU} 190 | 191 | self.ffn = nn.Sequential( 192 | Linear(d_model, d_ff, bias=False, weight_noise=weight_noise), 193 | non_linearities[non_linearity](), 194 | Linear(d_ff, d_model, bias=False, weight_noise=weight_noise), 195 | ) 196 | 197 | def forward(self, x: Tensor) -> Tensor: 198 | return self.ffn(x) 199 | 200 | 201 | class DecoderBlock(nn.Module): 202 | def __init__( 203 | self, 204 | d_model: int, 205 | 
heads: int, 206 | dropout: float, 207 | non_linearity: str = "relu", 208 | weight_noise: float = 0.0, 209 | ) -> None: 210 | super().__init__() 211 | 212 | self.self_attn = MultiHeadAttention(d_model, heads, weight_noise=weight_noise) 213 | # self.self_attn_drop = nn.Dropout(p=dropout) 214 | self.self_attn_norm = LayerNorm(d_model, weight_noise=weight_noise) 215 | 216 | self.ffn = FFN(d_model, non_linearity=non_linearity, weight_noise=weight_noise) 217 | self.ffn_drop = nn.Dropout(p=dropout) 218 | self.ffn_norm = LayerNorm(d_model, weight_noise=weight_noise) 219 | 220 | def forward( 221 | self, 222 | x: Tensor, 223 | self_attn_mask: Tensor = None, 224 | save_activations: bool = False, 225 | ) -> Tuple[Tensor, List[Tensor], List[Tensor]]: 226 | a1, layer_attns, layer_values = self.self_attn( 227 | x, x, x, self_attn_mask, save_activations 228 | ) 229 | # a1 = self.self_attn_drop(a1) 230 | a1 = self.self_attn_norm(x + a1) 231 | 232 | a2 = self.ffn(a1) 233 | a2 = self.ffn_drop(a2) 234 | a2 = self.ffn_norm(a1 + a2) 235 | 236 | return a2, layer_attns, layer_values 237 | 238 | 239 | class Decoder(nn.Module): 240 | def __init__( 241 | self, 242 | d_model: int, 243 | heads: int, 244 | num_blocks: int, 245 | dropout: float, 246 | non_linearity: str = "relu", 247 | weight_noise: float = 0.0, 248 | ) -> None: 249 | super().__init__() 250 | 251 | self.blocks = nn.ModuleList( 252 | [ 253 | DecoderBlock( 254 | d_model, heads, dropout, non_linearity, weight_noise=weight_noise 255 | ) 256 | for _ in range(num_blocks) 257 | ] 258 | ) 259 | 260 | def forward( 261 | self, 262 | x: Tensor, 263 | self_attn_mask: Tensor = None, 264 | save_activations=False, 265 | ) -> Tuple[Tensor, List[List[Tensor]], List[List[Tensor]]]: 266 | 267 | a = x 268 | attentions = [] 269 | values = [] 270 | for block in self.blocks: 271 | a, layer_attentions, layer_values = block( 272 | a, self_attn_mask, save_activations=save_activations 273 | ) 274 | if save_activations: 275 | attentions.append(layer_attentions) 276 | values.append(layer_values) 277 | return a, attentions, values 278 | 279 | 280 | class Transformer(nn.Module): 281 | def __init__( 282 | self, 283 | n_layers: int = 4, 284 | n_heads: int = 4, 285 | d_model: int = 256, 286 | dropout: float = 0.1, 287 | max_context_len: int = 1024, 288 | vocab_len: int = 2000, 289 | non_linearity: str = "relu", 290 | weight_noise: float = 0.0, 291 | ) -> None: 292 | super().__init__() 293 | 294 | self.n_layers = n_layers 295 | self.n_heads = n_heads 296 | self.d_model = d_model 297 | self.dropout = dropout 298 | self.max_context_len = max_context_len 299 | self.non_linearity = non_linearity 300 | 301 | self.vocab_len = vocab_len 302 | 303 | self.embedding = Embedding(vocab_len, d_model, weight_noise=weight_noise) # type: ignore 304 | self.register_buffer( 305 | "position_encoding", self._position_encoding(max_context_len, d_model) 306 | ) 307 | self.register_buffer("self_attn_mask", self.make_mask(max_context_len)) 308 | 309 | self.decoder = Decoder( 310 | d_model, 311 | n_heads, 312 | n_layers, 313 | dropout, 314 | self.non_linearity, 315 | weight_noise=weight_noise, 316 | ) 317 | 318 | self.linear = Linear(d_model, vocab_len, bias=False, weight_noise=weight_noise) 319 | 320 | @staticmethod 321 | def make_mask(context_len: int) -> Tensor: 322 | return torch.ones([context_len, context_len]).tril() 323 | 324 | @classmethod 325 | def _position_encoding(cls, context_len: int, d_model: int) -> Tensor: 326 | rows = [ 327 | tensor( 328 | [ 329 | sin(pos / (10000 ** (i / d_model))) 330 | if i 
% 2 == 0 331 | else cos(pos / (10000 ** ((i - 1) / d_model))) 332 | for i in range(d_model) 333 | ] 334 | ) 335 | for pos in range(context_len) 336 | ] 337 | stack = torch.stack(rows, dim=1) 338 | 339 | return stack.T # type: ignore 340 | 341 | def embed(self, indices: Tensor) -> Tensor: 342 | context_len = indices.shape[-1] 343 | pe = self.position_encoding[:context_len, :] # type: ignore 344 | 345 | embedded = self.embedding(indices) 346 | 347 | return pe + embedded 348 | 349 | def forward( 350 | self, 351 | x: Tensor, 352 | pos: int = None, 353 | save_activations: bool = False, 354 | ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]: 355 | """parameters: 356 | x: (rank-1 tensor) vocab indices of decoder input token 357 | sequence""" 358 | 359 | # Make sure sampling inputs are on the correct device 360 | x = x.to(self.embedding.weight.device) 361 | 362 | # make_attention mask 363 | this_max_context_len = x.shape[-1] 364 | self_attn_mask = self.self_attn_mask[ # type: ignore 365 | :this_max_context_len, :this_max_context_len 366 | ] 367 | 368 | # Decode 369 | x = self.embed(x) 370 | decoded, attentions, values = self.decoder( 371 | x, self_attn_mask, save_activations=save_activations 372 | ) 373 | 374 | # Return predictions for specific token 375 | if pos is not None: 376 | decoded = decoded[:, pos, :] 377 | 378 | y_hat = self.linear(decoded) 379 | return y_hat, attentions, values 380 | -------------------------------------------------------------------------------- /grok/visualization.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | import os 4 | import math 5 | import socket 6 | 7 | from collections import defaultdict 8 | from copy import deepcopy 9 | 10 | import matplotlib.pyplot as plt 11 | import matplotlib.ticker as mtick 12 | import numpy as np 13 | import torch 14 | 15 | from mpl_toolkits.axes_grid1 import make_axes_locatable 16 | from tqdm import tqdm 17 | 18 | from grok.data import ArithmeticDataset 19 | 20 | logging.basicConfig(level=logging.ERROR) 21 | logger = logging.getLogger("grok.view_metrics") 22 | logger.setLevel(logging.ERROR) 23 | 24 | GROK_DIR = os.path.expanduser("~/data/grok") 25 | IMAGE_DIR = f"{GROK_DIR}/images" 26 | DATA_DIR = f"{GROK_DIR}/data" 27 | 28 | 29 | DEFAULT_CMAP = "viridis" 30 | 31 | default_metric_limits = { 32 | "min_val_accuracy": 0, 33 | "max_val_accuracy": 100, 34 | "min_T": 0, # 0 35 | "max_T": 100, # 87.5 36 | "min_D": 0, # 8 37 | "max_D": 2048, # 256 38 | "min_H": 0, # 1 39 | "max_H": 1204, # 8 40 | "min_L": 0, # 1 41 | "max_L": 1024, # 4 42 | "min_accuracy": 0, 43 | "max_accuracy": 100, 44 | } 45 | 46 | default_axis_scales = {"x": "linear", "y": "linear"} 47 | 48 | 49 | ## Data Loading Functions 50 | 51 | 52 | def factor_expts(expts): 53 | result = {} 54 | for expt in expts: 55 | expt_s = expt.split("_") 56 | arch = "_".join(expt_s[:3]) 57 | t = int(float(expt_s[3].split("-")[1])) 58 | result.setdefault(arch, {}) 59 | result[arch][t] = expt 60 | return result 61 | 62 | 63 | def load_metric_data(data_dir, epochs=100000, load_partial_data=True): 64 | # layers x heads x d_model x train_pct 65 | data = {} 66 | expts = os.listdir(data_dir) 67 | archs = factor_expts(expts) 68 | logger.debug(archs) 69 | for arch in archs: 70 | T = sorted(archs[arch].keys()) 71 | data[arch] = { 72 | "T": torch.LongTensor(T), 73 | "metrics": torch.zeros((max(T), 5, epochs)), 74 | } 75 | # print(f"metrics_shape = {data[arch]['metrics'].shape}") 76 | for i, t in 
tqdm(list(enumerate(T))): 77 | expt = archs[arch][t] 78 | logger.debug(expt) 79 | log_dir = data_dir + "/" + expt 80 | 81 | # print("log_dir", log_dir) 82 | try: 83 | with open(log_dir + "/default/version_0/metrics.csv", "r") as fh: 84 | logger.debug(f"loading {log_dir}") 85 | reader = list(csv.DictReader(fh)) 86 | val_t = torch.FloatTensor( 87 | [ 88 | [ 89 | float(r["val_loss"]), 90 | float(r["val_accuracy"]), 91 | ] 92 | for r in reader 93 | if r["val_loss"] 94 | ] 95 | ).T 96 | train_t = torch.FloatTensor( 97 | [ 98 | [ 99 | float(r["learning_rate"]), 100 | float(r["train_loss"]), 101 | float(r["train_accuracy"]), 102 | ] 103 | for r in reader 104 | if r["train_loss"] 105 | ] 106 | ).T 107 | # logger.debug(val_t.shape) 108 | # logger.debug(train_t[0, -3:]) 109 | if load_partial_data: 110 | raise Exception("Not implemented") 111 | elif (val_t.shape[-1] >= epochs) and (train_t.shape[-1] >= epochs): 112 | data[arch]["metrics"][i] = torch.cat( 113 | [train_t[..., :epochs], val_t[..., :epochs]], dim=0 114 | ) 115 | else: 116 | data[arch]["T"][i] = 0 117 | # except FileNotFoundError: 118 | except: 119 | data[arch]["T"][i] = 0 120 | indices = torch.nonzero(data[arch]["T"]).squeeze() 121 | if len(indices.shape) == 0: 122 | indices = indices.unsqueeze(0) 123 | # print(f"indices.shape = {indices.shape}") 124 | data[arch]["T"] = data[arch]["T"][indices] 125 | # print(f"data[arch]['T'].shape = {data[arch]['T'].shape}") 126 | data[arch]["metrics"] = data[arch]["metrics"][indices] 127 | # print(f"data[arch]['metrics'].shape = {data[arch]['metrics'].shape}") 128 | data[arch]["metrics"] = torch.transpose(data[arch]["metrics"], 0, 1) 129 | # print(f"data[arch]['metrics'].shape = {data[arch]['metrics'].shape}") 130 | return data 131 | 132 | 133 | def most_interesting(metric_data): 134 | interesting_metric_data = {} 135 | for arch in metric_data: 136 | T = metric_data[arch]["T"] 137 | max_acc_by_t = torch.max( 138 | metric_data[arch]["val_accuracy"], dim=1, keepdim=True 139 | ).values.squeeze() 140 | max_loss_by_t = torch.max( 141 | metric_data[arch]["val_loss"], dim=1, keepdim=True 142 | ).values.squeeze() 143 | acc_idx = torch.nonzero(max_acc_by_t >= 95).squeeze() 144 | if acc_idx.shape == torch.Size([0]): 145 | acc_idx = torch.nonzero(max_acc_by_t == max_acc_by_t.max()).squeeze() 146 | if acc_idx.shape == torch.Size([]): 147 | acc_idx = acc_idx.unsqueeze(0) 148 | max_loss = torch.max(max_loss_by_t[acc_idx]) 149 | loss_idx = torch.nonzero(max_loss_by_t[acc_idx] == max_loss) 150 | interesting_idx = acc_idx[loss_idx].squeeze() 151 | 152 | interesting_metric_data[arch] = {} 153 | for k in metric_data[arch]: 154 | interesting_metric_data[arch][k] = metric_data[arch][k][ 155 | interesting_idx 156 | ].unsqueeze(0) 157 | 158 | return interesting_metric_data 159 | 160 | 161 | # ## Graph Drawing Functions 162 | 163 | 164 | def moving_avg(Y, steps): 165 | return np.convolve(Y, np.ones(steps), "valid") / steps 166 | 167 | 168 | def find_inflections(Y, smoothing_steps=100): 169 | avg_Y = moving_avg(Y, smoothing_steps) 170 | avg_direction = torch.FloatTensor(np.sign(avg_Y[1:] - avg_Y[:-1])) 171 | avg_direction = torch.cat([avg_direction[0].unsqueeze(0), avg_direction]) 172 | avg_inflections = torch.nonzero(avg_direction[1:] - avg_direction[:-1]).squeeze() 173 | avg_inflections = [0] + (avg_inflections + 1).tolist() + [len(Y) - 1] 174 | logger.debug(f"avg_inflections = {avg_inflections}") 175 | inflections = [] 176 | for i in range(2, len(avg_inflections)): 177 | low = avg_inflections[i - 2] 178 | high = 
avg_inflections[i] 179 | logger.debug(f"low={low}") 180 | logger.debug(f"high={high}") 181 | if avg_direction[low + 1] < 0: 182 | indices = Y[low:high].argmin() + low 183 | logger.debug(f"min = (Y[{indices}] = {Y[int(indices)]}") 184 | else: 185 | indices = Y[low:high].argmax() + low 186 | logger.debug(f"max = (Y[{indices}] = {Y[int(indices)]}") 187 | inflections.append(indices) 188 | return torch.LongTensor(inflections) 189 | 190 | 191 | def check_limits(arch_name, limits): 192 | L, H, D = [float(v.split("-")[1]) for v in arch_name.split("_")] 193 | if (L > limits["max_L"]) or (L < limits["min_L"]): 194 | return False 195 | if (H > limits["max_H"]) or (H < limits["min_H"]): 196 | return False 197 | if (D > limits["max_D"]) or (D < limits["min_D"]): 198 | return False 199 | # if (T > limits['max_T']) or (T < limits['min_T']): 200 | # return False 201 | return True 202 | 203 | 204 | def filter_archs(data, limits={}): 205 | my_limits = deepcopy(default_metric_limits) 206 | my_limits.update(limits) 207 | limits = my_limits 208 | archs = sorted(list(set([a for a in data.keys() if check_limits(a, limits)]))) 209 | logger.debug(f"archs = {archs}") 210 | return archs 211 | 212 | 213 | def get_metric_data(data, limits={}): 214 | my_limits = deepcopy(default_metric_limits) 215 | my_limits.update(limits) 216 | limits = my_limits 217 | 218 | for k in limits.keys(): 219 | metric = k.replace("min_", "").replace("max_", "") 220 | assert ( 221 | limits["max_" + metric] >= limits["min_" + metric] 222 | ), f"invalid {metric} limits" 223 | 224 | d = {} 225 | for arch in filter_archs(data, limits): 226 | logger.debug(arch) 227 | indices = torch.nonzero( 228 | torch.logical_and( 229 | data[arch]["T"] >= limits["min_T"], data[arch]["T"] <= limits["max_T"] 230 | ) 231 | ).squeeze(dim=-1) 232 | logger.debug(f"indices={indices}") 233 | learning_rate, train_loss, train_accuracy, val_loss, val_accuracy = data[arch][ 234 | "metrics" 235 | ][:, indices, :] 236 | d[arch] = { 237 | "T": data[arch]["T"][indices], 238 | "learning_rate": data[arch]["metrics"][0, indices, :], 239 | "train_loss": data[arch]["metrics"][1, indices, :], 240 | "train_accuracy": data[arch]["metrics"][2, indices, :], 241 | "val_loss": data[arch]["metrics"][3, indices, :], 242 | "val_accuracy": data[arch]["metrics"][4, indices, :], 243 | } 244 | return d 245 | 246 | 247 | def add_metric_graph( 248 | fig, 249 | ax, 250 | metric, 251 | metric_data, 252 | scales=default_axis_scales, 253 | cmap=DEFAULT_CMAP, 254 | inflection_hline=False, 255 | ds_len=None, 256 | batchsize=97, 257 | ): 258 | ax.set_title(metric) 259 | ax.set_xscale(scales["x"]) 260 | ax.set_yscale(scales["y"]) 261 | if ds_len is None: 262 | ax.set_xlabel("epochs") 263 | else: 264 | ax.set_xlabel("updates") 265 | 266 | # if 'loss' in metric: 267 | # ymin=0 268 | # ax.axis(ymin=ymin) 269 | if "accuracy" in metric: 270 | ax.yaxis.set_major_formatter(mtick.PercentFormatter()) 271 | ymin = 1e-16 272 | ymax = 101 273 | ax.axis(ymin=ymin, ymax=ymax) 274 | if "loss" in metric: 275 | ymin = 1e-16 276 | ymax = 15 277 | ax.axis(ymin=ymin, ymax=ymax) 278 | 279 | total_plots = 0 280 | logger.debug(f"processing {metric}") 281 | plots = [] 282 | for arch in metric_data: 283 | metric_data[arch]["T"] = metric_data[arch]["T"].squeeze() 284 | logger.debug((" " * 4) + f"arch = {arch}") 285 | if len(metric_data[arch]["T"].shape) == 0: 286 | metric_data[arch]["T"] = metric_data[arch]["T"].unsqueeze(0) 287 | T_min = int(metric_data[arch]["T"][0]) 288 | T_max = int(metric_data[arch]["T"][-1]) 289 | # T_min = 
0 290 | # T_max = 88 291 | sm = plt.cm.ScalarMappable( 292 | cmap=cmap, norm=plt.Normalize(vmin=T_min, vmax=T_max) 293 | ) 294 | colors = sm.to_rgba(metric_data[arch]["T"]) 295 | for i, t in enumerate(metric_data[arch]["T"]): 296 | if ds_len is None: 297 | steps_per_epoch = 1 298 | else: 299 | train_rows, val_rows = ArithmeticDataset.calc_split_len( 300 | t.item(), ds_len 301 | ) 302 | steps_per_epoch = math.ceil(train_rows / batchsize) 303 | 304 | logger.debug((" " * 4) + f"t = {t}") 305 | # print( 306 | # f"metric_data[arch][metric].shape = {metric_data[arch][metric].shape}" 307 | # ) 308 | Y = metric_data[arch][metric][i] 309 | # print(f"Y = {Y}") 310 | assert len(Y.shape) == 1, f"Y.shape = {Y.shape} is invalid" 311 | X = torch.arange(1, Y.shape[0] + 1) * steps_per_epoch 312 | assert len(X.shape) == 1, f"X.shape = {X.shape} is invalid" 313 | 314 | label = arch + f" t={t}" 315 | 316 | # ax.set_xlim(left=X[0], right=X[-1] + 1) 317 | if metric == "val_loss" and inflection_hline: 318 | Y_infs = find_inflections(Y) 319 | ax.axhline(y=Y[Y_infs[0]], color="orange") 320 | if metric == "val_accuracy": 321 | label += " (max = %.2f)" % max(Y) 322 | total_plots += 1 323 | ax.plot(X, Y, label=label, color=colors[i]) 324 | if T_max - T_min <= 10: 325 | pass 326 | ax.legend() 327 | else: 328 | fig.colorbar( 329 | sm, 330 | ax=ax, 331 | label="% training data", 332 | ticks=range(T_min, T_max, int((T_max - T_min) / 5)), 333 | ) 334 | 335 | 336 | def add_comm_graph( 337 | ax, metric, kind, comm_data, arch, scales=default_axis_scales, cmap=DEFAULT_CMAP 338 | ): 339 | assert metric in ( 340 | "loss", 341 | "accuracy", 342 | "perplexity", 343 | ) 344 | assert kind in ( 345 | "comm", 346 | "non_comm", 347 | "modulo", 348 | "non_modulo", 349 | "assoc", 350 | "non_assoc", 351 | "zero", 352 | "non_zero", 353 | ) 354 | ax.set_title(metric) 355 | ax.set_xscale(scales["x"]) 356 | ax.set_yscale(scales["y"]) 357 | ax.set_xlabel("epochs") 358 | if "accuracy" in metric: 359 | ax.yaxis.set_major_formatter(mtick.PercentFormatter()) 360 | X = [int(r["epoch"]) for r in comm_data] 361 | Y = torch.tensor( 362 | ( 363 | [float(r["comm" + "_" + metric]) for r in comm_data], 364 | [float(r["non_comm" + "_" + metric]) for r in comm_data], 365 | # [float(r["assoc" + "_" + metric]) for r in comm_data], 366 | # [float(r["non_assoc" + "_" + metric]) for r in comm_data], 367 | # [float(r["zero" + "_" + metric]) for r in comm_data], 368 | # [float(r["non_zero" + "_" + metric]) for r in comm_data], 369 | ) 370 | ) 371 | # label = kind 372 | # if kind.endswith("comm"): 373 | # label += "utative" 374 | 375 | labels = ["commutative", "non-commutative"] 376 | # labels = ["zero", "non_zero"] 377 | # labels = ["associative", "non_associative"] 378 | # label = f"{arch} {kind}_{metric}" 379 | # ax.plot(X, Y, label=label) 380 | sm = plt.cm.ScalarMappable(cmap="cividis", norm=plt.Normalize(vmin=0, vmax=len(Y))) 381 | colors = sm.to_rgba(range(len(Y))) 382 | ax.stackplot(X, Y, baseline="zero", labels=labels, colors=colors) 383 | ax.legend() 384 | 385 | 386 | def add_extremum_graph( 387 | ax, 388 | metric, 389 | kind, 390 | metric_data, 391 | scales=default_axis_scales, 392 | epochs=[-1], 393 | show_legend=True, 394 | ): 395 | assert kind in ("max", "min") 396 | ax.set_title(f"{kind} {metric}") 397 | ax.set_xlabel("training data") 398 | ax.xaxis.set_major_formatter(mtick.PercentFormatter()) 399 | xmin = 0 400 | xmax = 100 401 | ax.axis(xmin=xmin, xmax=xmax) 402 | 403 | # ax.set_ylabel(metric) 404 | ax.set_xscale(scales["x"]) 405 | 
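# One curve per architecture: the max (or min) of `metric` across all epochs,
# plotted against the training-data percentage T.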
ax.set_yscale(scales["y"]) 406 | if "accuracy" in metric: 407 | ax.yaxis.set_major_formatter(mtick.PercentFormatter()) 408 | ymin = -1 409 | ymax = 105 410 | ax.axis(ymin=ymin, ymax=ymax) 411 | 412 | # if 'learning' in metric: 413 | # ymin=0 414 | # ymax=0.002 415 | # ax.axis(ymin=ymin, ymax=ymax) 416 | 417 | plots = {} 418 | 419 | total_plots = 0 420 | for arch in metric_data: 421 | X = metric_data[arch]["T"] 422 | if kind == "max": 423 | Y = torch.max( 424 | metric_data[arch][metric], dim=1, keepdim=True 425 | ).values.squeeze() 426 | elif kind == "min": 427 | Y = torch.min( 428 | metric_data[arch][metric], dim=1, keepdim=True 429 | ).values.squeeze() 430 | 431 | # ax.set_xlim(0, 100) 432 | ax.set_xticks(np.arange(0, 100, 5)) 433 | label = f"{kind} {metric} {arch}" 434 | ax.plot(X, Y, label=label) 435 | total_plots += 1 436 | 437 | if show_legend and total_plots <= 12: 438 | ax.legend() 439 | pass 440 | 441 | 442 | def add_inflection_graphs( 443 | ax, metric, metric_data, scales=default_axis_scales, smoothing_steps=100 444 | ): 445 | ax.set_title(f"{metric} inflections by train_data_pct") 446 | ax.set_xlabel("train_data_pct") 447 | ax.set_ylabel(f"{metric} inflections") 448 | ax.set_xscale(scales["x"]) 449 | ax.set_yscale(scales["y"]) 450 | if "accuracy" in metric: 451 | ymin = 0 452 | ymax = 100 453 | ax.axis(xmin=0, xmax=87.5, ymin=ymin, ymax=ymax) 454 | if "learning" in metric: 455 | ymin = 0 456 | ymax = 0.002 457 | ax.axis(xmin=0, xmax=87.5, ymin=ymin, ymax=ymax) 458 | 459 | total_plots = 0 460 | for arch in metric_data: 461 | for num in range(5): 462 | for i, t in enumerate(metric_data[arch]["T"]): 463 | Y = metric_data[arch][metric][i] 464 | X = torch.arange(Y.shape[-1]) 465 | inflections = find_inflections(Y, smoothing_steps=smoothing_steps) 466 | ax.plot(X[inflections], Y[inflections], label=f"{arch} t={t}") 467 | total_plots += 1 468 | 469 | if total_plots <= 12: 470 | ax.legend() 471 | pass 472 | 473 | 474 | def colorbar(mappable, ticks=None, labels=None): 475 | last_axes = plt.gca() 476 | ax = mappable.axes 477 | fig = ax.figure 478 | divider = make_axes_locatable(ax) 479 | cax = divider.append_axes("right", size="5%", pad=0.1) 480 | cbar = fig.colorbar(mappable, cax=cax, ticks=ticks) 481 | if labels is not None: 482 | cbar.ax.set_yticklabels(labels) # vertically oriented colorbar 483 | plt.sca(last_axes) 484 | return cbar 485 | 486 | 487 | def add_matshow( 488 | fig, ax, t, name, vmin=0, vmax=100, cmap=DEFAULT_CMAP, show_colorbar=True 489 | ): 490 | sides = ("left", "right", "top", "bottom") 491 | labels = { 492 | "left": True, 493 | "right": False, 494 | "top": False, 495 | "bottom": True, 496 | "labelleft": True, 497 | "labelright": False, 498 | "labeltop": False, 499 | "labelbottom": True, 500 | } 501 | m = ax.matshow( 502 | t.cpu().detach().numpy(), vmin=vmin, vmax=vmax, origin="lower", cmap=cmap 503 | ) 504 | # c = ax.pcolor(t.cpu(), vmin=vmin, vmax=vmax, cmap=cmap) 505 | ax.set_title(name) 506 | ax.set_xlabel("A") 507 | ax.set_ylabel("B") 508 | ax.set_xticks(np.arange(0, t.shape[1], 10)) 509 | # ax.set_xticklabels(np.arange(1, t.shape[1]+1)) 510 | # ax.set_yticks(np.arange(0.5, t.shape[0] + .5, 1)) 511 | ax.set_yticks(np.arange(0, t.shape[0], 10)) 512 | # ax.set_yticks(np.arange(t.shape[0])) 513 | # ax.set_yticklabels(np.arange(1, t.shape[0]+1)) 514 | ax.tick_params(axis="both", which="both", **labels) 515 | if show_colorbar: 516 | colorbar(m) 517 | -------------------------------------------------------------------------------- /nbs/flatness.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "geographic-personal", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import pickle\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "\n", 14 | "def paired_sort(list1, list2):\n", 15 | " list1, list2 = zip(*sorted(zip(list1, list2)))\n", 16 | " return list1, list2\n", 17 | "\n", 18 | "def plot_phi_by_ckpt():\n", 19 | "\n", 20 | " nums = []\n", 21 | " flatness = []\n", 22 | "\n", 23 | " for f in sorted(os.listdir(\"../results/\")):\n", 24 | " if \"pkl\" in f:\n", 25 | " num = int(f.split(\"-\")[1].split(\".pkl\")[0])\n", 26 | " dat = pickle.load(open(os.path.join(\"../results/\", f), \"rb\"))\n", 27 | " nums.append(num)\n", 28 | " flatness.append(dat[list(dat.keys())[0]].item())\n", 29 | "\n", 30 | " \n", 31 | " nums, flatness = paired_sort(nums, flatness)\n", 32 | " plt.plot(nums, flatness)\n", 33 | " plt.xticks(range(len(nums)))\n", 34 | " plt.ylabel(\"phi\")\n", 35 | " plt.xlabel(\"SD-{n} checkpoint\")\n", 36 | " plt.savefig(\"phi-by-ckpt.png\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "mineral-assembly", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "plot_phi_by_ckpt()" 47 | ] 48 | } 49 | ], 50 | "metadata": { 51 | "kernelspec": { 52 | "display_name": "Python 3", 53 | "language": "python", 54 | "name": "python3" 55 | }, 56 | "language_info": { 57 | "codemirror_mode": { 58 | "name": "ipython", 59 | "version": 3 60 | }, 61 | "file_extension": ".py", 62 | "mimetype": "text/x-python", 63 | "name": "python", 64 | "nbconvert_exporter": "python", 65 | "pygments_lexer": "ipython3", 66 | "version": "3.6.13" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 5 71 | } 72 | -------------------------------------------------------------------------------- /scripts/compute_sharpness.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import grok 5 | 6 | parser = grok.training.add_args() 7 | parser.set_defaults(logdir=os.environ.get("LOGDIR", ".")) 8 | hparams = parser.parse_args() 9 | hparams.datadir = os.path.abspath(hparams.datadir) 10 | hparams.logdir = os.path.abspath(hparams.logdir) 11 | 12 | 13 | print(hparams) 14 | 15 | ckpts = [f"./ckpts/L-2_H-4_D-128_T-70_DROP-0_SD-{i}_WU-10_LR-1p0.ckpt" for i in range(20)] 16 | print(grok.training.compute_sharpness(hparams, ckpts)) 17 | -------------------------------------------------------------------------------- /scripts/create_metric_graphs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Render metrics graphs 5 | 6 | import csv 7 | import logging 8 | import os 9 | import glob 10 | import socket 11 | from argparse import ArgumentParser 12 | 13 | from collections import defaultdict 14 | 15 | import matplotlib.pyplot as plt 16 | import matplotlib.ticker as mtick 17 | import numpy as np 18 | import torch 19 | 20 | from mpl_toolkits.axes_grid1 import make_axes_locatable 21 | from tqdm import tqdm 22 | 23 | from sklearn.manifold import TSNE 24 | 25 | import grok 26 | from grok.visualization import * 27 | 28 | # from grok_runs import RUNS 29 | 30 | logging.basicConfig(level=logging.ERROR) 31 | logger = logging.getLogger("grok.view_metrics") 32 | logger.setLevel(logging.ERROR) 33 | 34 | RUNS = { 35 | "subtraction": 
( 36 | 9409, 37 | "subtraction/2021-02-05-03-33-56-alethea-sjjf", 38 | ), 39 | } 40 | 41 | 42 | limits = { 43 | "min_val_accuracy": 0, 44 | "max_val_accuracy": 100, 45 | "min_T": 0, # 0 46 | "max_T": 100, # 87.5 47 | "min_D": 0, # 8 48 | "max_D": 256, # 256 49 | "min_H": 0, # 1 50 | "max_H": 4, # 8 51 | "min_L": 0, # 1 52 | "max_L": 4, # 4 53 | "min_accuracy": 0, 54 | "max_accuracy": 100, 55 | } 56 | 57 | for k in limits.keys(): 58 | metric = k.replace("min_", "").replace("max_", "") 59 | assert ( 60 | limits["max_" + metric] >= limits["min_" + metric] 61 | ), f"invalid {metric} limits" 62 | 63 | 64 | parser = ArgumentParser() 65 | parser.add_argument("-i", "--image_dir", type=str, default=IMAGE_DIR) 66 | args = parser.parse_args() 67 | 68 | 69 | def create_loss_curves( 70 | metric_data, 71 | epochs, 72 | run, 73 | most_interesting_only=False, 74 | image_dir=args.image_dir, 75 | ds_len=None, 76 | cmap=DEFAULT_CMAP, 77 | ): 78 | scales = { 79 | "x": "log", 80 | "y": "linear", 81 | } 82 | 83 | 84 | arch = list(metric_data.keys())[0] 85 | 86 | ncols = 2 87 | nrows = 3 88 | fig_width = ncols * 8 89 | fig_height = nrows * 5 90 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) 91 | 92 | add_metric_graph( 93 | fig, axs[0, 0], "val_loss", metric_data, scales, cmap=cmap, ds_len=ds_len 94 | ) 95 | add_metric_graph( 96 | fig, axs[0, 1], "val_accuracy", metric_data, scales, cmap, ds_len=ds_len 97 | ) 98 | add_metric_graph( 99 | fig, axs[1, 0], "train_loss", metric_data, scales, cmap, ds_len=ds_len 100 | ) 101 | add_metric_graph( 102 | fig, axs[1, 1], "train_accuracy", metric_data, scales, cmap, ds_len=ds_len 103 | ) 104 | add_metric_graph( 105 | fig, 106 | axs[2, 0], 107 | "learning_rate", 108 | metric_data, 109 | scales, 110 | cmap, # ds_len=ds_len 111 | ) 112 | fig.suptitle(f"{operation} {list(data.keys())[0]}") 113 | fig.tight_layout() 114 | 115 | img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}" 116 | if ds_len is not None: 117 | img_file += "_by_update" 118 | if most_interesting_only: 119 | img_file += "_most_interesting" 120 | img_file += ".png" 121 | d = os.path.split(img_file)[0] 122 | os.makedirs(d, exist_ok=True) 123 | print(f"Writing {img_file}") 124 | fig.savefig(img_file) 125 | plt.close(fig) 126 | 127 | 128 | def create_max_accuracy_curves( 129 | metric_data, epochs, run, image_dir=args.image_dir, ds_len=None 130 | ): 131 | scales = { 132 | "x": "linear", 133 | "y": "linear", 134 | } 135 | 136 | ncols = 1 137 | nrows = 2 138 | fig_width = ncols * 8 139 | fig_height = nrows * 5 140 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) 141 | 142 | def get_ax(row=0, col=0, nrows=nrows, ncols=ncols, axs=axs): 143 | if nrows == 0: 144 | if ncols == 1: 145 | return axs 146 | else: 147 | return axs[col] 148 | else: 149 | if ncols == 1: 150 | return axs[row] 151 | else: 152 | return axs[row, col] 153 | 154 | add_extremum_graph( 155 | get_ax(0, 0), "val_accuracy", "max", metric_data, show_legend=False 156 | ) 157 | add_extremum_graph( 158 | get_ax(1, 0), "train_accuracy", "max", metric_data, show_legend=False 159 | ) 160 | fig.suptitle(f"{operation} {list(data.keys())[0]}") 161 | fig.tight_layout() 162 | 163 | expt = list(metric_data.keys())[0] 164 | img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}.png" 165 | d = os.path.split(img_file)[0] 166 | os.makedirs(d, exist_ok=True) 167 | print(f"Writing {img_file}") 168 | fig.savefig(img_file) 169 | plt.close(fig) 170 | 171 | 172 | def 
create_tsne_graphs(operation, expt, run_dir, image_dir=args.image_dir): 173 | 174 | saved_pt_dir = f"{run_dir}/activations" 175 | saved_pts = [] 176 | 177 | loss_ts = [] 178 | accuracy_ts = [] 179 | epochs_ts = [] 180 | print(f'glob = {saved_pt_dir + "/activations_*.pt"}') 181 | files = sorted(glob.glob(saved_pt_dir + "/activations_*.pt")) 182 | print(f"files = {files}") 183 | 184 | for file in files: 185 | print(f"Loading {file}") 186 | saved_pt = torch.load(file) 187 | saved_pts.append(saved_pt) 188 | loss_ts.append(saved_pt["val_loss"].mean(dim=-1)) 189 | accuracy_ts.append(saved_pt["val_accuracy"]) 190 | epochs_ts.append(saved_pt["epochs"].squeeze()) 191 | 192 | loss_t = torch.cat(loss_ts, dim=0).T.detach() 193 | accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach() 194 | epochs_t = torch.cat(epochs_ts, dim=0).detach() 195 | print(loss_t.shape) 196 | print(accuracy_t.shape) 197 | print(epochs_t.shape) 198 | ###### 199 | a = 0 200 | num_eqs = len(loss_t) 201 | b = a + num_eqs 202 | 203 | print("Doing T-SNE..") 204 | loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t) 205 | print("...done T-SNE.") 206 | 207 | ncols = 1 208 | nrows = 1 209 | fig_width = ncols * 8 210 | fig_height = nrows * 5 211 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) 212 | 213 | axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1]) 214 | 215 | img_file = f"{image_dir}/tsne/{operation}_{expt}.png" 216 | d = os.path.split(img_file)[0] 217 | os.makedirs(d, exist_ok=True) 218 | print(f"Writing {img_file}") 219 | fig.savefig(img_file) 220 | plt.close(fig) 221 | 222 | 223 | for operation in RUNS: 224 | print("") 225 | print("") 226 | print(f"Processing {operation}", flush=True) 227 | 228 | if operation.endswith("-epochs"): 229 | epochs = int(operation.split("/")[-1].split("-")[0]) 230 | else: 231 | epochs = 5000 232 | 233 | #### 234 | 235 | ds_len, run = RUNS[operation] 236 | 237 | 238 | data = load_metric_data(f"{DATA_DIR}/{run}", epochs=epochs, load_partial_data=False) 239 | 240 | # check it 241 | for arch in data: 242 | # print(data[arch]["metrics"].shape) 243 | metrics, expts, epochs = data[arch]["metrics"].shape 244 | message = ( 245 | f"{arch} : loaded {metrics} metrics, {expts} experiments, {epochs} epochs" 246 | ) 247 | assert metrics == 5, "INVALID metrics count: " + message 248 | assert expts < 88, "INVALID experiments count: " + message 249 | assert epochs == epochs, f"INVALID epochs count: " + message 250 | print(message) 251 | 252 | # ## Set filters on the data to view 253 | 254 | metric_data = get_metric_data(data, limits) 255 | 256 | # Draw loss and accuracy curves 257 | 258 | create_max_accuracy_curves(metric_data, epochs, run) 259 | 260 | create_loss_curves(metric_data, epochs, run) 261 | create_loss_curves(metric_data, epochs, run, ds_len=ds_len) 262 | 263 | most_interesting_metric_data = most_interesting(metric_data) 264 | 265 | create_loss_curves( 266 | most_interesting_metric_data, epochs, run, most_interesting_only=True 267 | ) 268 | create_loss_curves( 269 | most_interesting_metric_data, 270 | epochs, 271 | run, 272 | most_interesting_only=True, 273 | ds_len=ds_len, 274 | ) 275 | 276 | # Draw max accuracy curves 277 | 278 | # T-SNE of loss curves: 279 | try: 280 | for arch in most_interesting_metric_data: 281 | t = int(most_interesting_metric_data[arch]["T"][0].item()) 282 | expt = f"{arch}_T-{t}_DROP-0.0" 283 | create_tsne_graphs(operation, expt, run_dir=f"{DATA_DIR}/{run}/{expt}") 284 | except: 285 | print("TSNE failed") 286 | 
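# Example invocation (illustrative; paths depend on where the run data lives).
# The script reads metrics for every run listed in RUNS above and writes the
# loss/accuracy/T-SNE figures beneath --image_dir (default: IMAGE_DIR):
#
#   ./scripts/create_metric_graphs.py -i ~/data/grok/images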
-------------------------------------------------------------------------------- /scripts/create_metrics_for_epochs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | 5 | logging.basicConfig(level=logging.ERROR) 6 | import csv 7 | import copy 8 | import os 9 | import grok 10 | import numpy as np 11 | import sys 12 | import subprocess 13 | import torch 14 | from torch.multiprocessing import Process 15 | from grok import trainer 16 | from tqdm import tqdm 17 | from argparse import ArgumentParser 18 | from collections import Counter 19 | 20 | 21 | torch.multiprocessing.freeze_support() 22 | try: 23 | torch.multiprocessing.set_start_method("spawn") 24 | except RuntimeError: 25 | pass 26 | 27 | # Get args 28 | EPOCHS = ( 29 | list(range(10)) 30 | + list(range(10, 200, 2)) 31 | + list(range(200, 5000, 10)) 32 | + list(range(5000, 10000, 50)) 33 | + [10000] 34 | ) 35 | 36 | parser = ArgumentParser() 37 | parser.add_argument( 38 | "--data_dir", type=str, help="where to find the runs", required=True 39 | ) 40 | parser.add_argument("--expt", type=str, default=None) 41 | parser.add_argument("--epochs_per_run", type=int, default=40) 42 | 43 | 44 | def parent(expts): 45 | for expt in expts: 46 | print(f"Processing {expt}") 47 | all_results = {} 48 | for first_epoch in range(0, len(EPOCHS), hparams.epochs_per_run): 49 | these_epochs = [ 50 | str(e) 51 | for e in EPOCHS[first_epoch : first_epoch + hparams.epochs_per_run] 52 | ] 53 | expt_dir = data_dir + "/" + expt 54 | cmd = [ 55 | "./create_partial_metrics.py", 56 | f"--gpu={hparams.gpu}", 57 | f"--expt_dir={expt_dir}", 58 | f'--epochs={",".join(these_epochs)}', 59 | ] 60 | result = subprocess.run(cmd, capture_output=False, shell=False) 61 | if result.returncode != 0: 62 | sys.exit(result.returncode) 63 | 64 | 65 | hparams = trainer.get_args(parser) 66 | 67 | data_dir = hparams.data_dir 68 | 69 | if hparams.expt is not None: 70 | expts = [hparams.expt] 71 | else: 72 | expts = os.listdir(data_dir) 73 | 74 | parent(expts) 75 | -------------------------------------------------------------------------------- /scripts/create_partial_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | 5 | logging.basicConfig(level=logging.ERROR) 6 | import csv 7 | import copy 8 | import glob 9 | import os 10 | import grok 11 | import numpy as np 12 | import subprocess 13 | import torch 14 | import sys 15 | from torch.multiprocessing import Process 16 | from grok import trainer 17 | from tqdm import tqdm 18 | from argparse import ArgumentParser 19 | from collections import Counter 20 | from grok_runs import RUNS 21 | from grok_metrics_lib import ( 22 | DATA_DIR, 23 | load_metric_data, 24 | get_metric_data, 25 | most_interesting, 26 | ) 27 | 28 | 29 | # Make N_EPOCHS exponentially spaced sets of epochs from 1 to 10,000 30 | N_EPOCHS = 32 31 | BASE = 9999 ** (1.0 / (N_EPOCHS - 1)) 32 | epochs = (BASE ** torch.arange(1, N_EPOCHS).float()).long().tolist() 33 | DEFAULT_EPOCHS = ",".join([str(i) for i in epochs]) 34 | 35 | parser = ArgumentParser() 36 | parser.add_argument("--expt_dir", type=str, help="where to find the runs") 37 | parser.add_argument("--epochs", type=str, default=DEFAULT_EPOCHS) 38 | 39 | 40 | def child(hparams): 41 | expt_dir = hparams.expt_dir 42 | epochs = [int(e) for e in hparams.epochs.split(",")] 43 | # print("epochs = ", epochs) 44 | device = torch.device(f"cuda:{hparams.gpu}") 
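# For each requested epoch: find its checkpoint (saved as epoch=N.ckpt or
# epoch=N-step=S.ckpt), rebuild a TrainableTransformer from the stored
# hyper_parameters and state_dict, run test_step over the test dataloader, and
# accumulate per-epoch val_loss / val_accuracy tensors; the results are saved to
# <expt_dir>/activations/activations_<first>_<last>.pt at the end.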
45 | ckpt_dir = expt_dir + "/" + "checkpoints" 46 | # ckpt_files = [ckpt_dir + f"/epoch={epoch}.ckpt" for epoch in epochs] 47 | hparams.logdir = expt_dir 48 | 49 | results = { 50 | "val_loss": None, 51 | "val_accuracy": None, 52 | } 53 | 54 | processed_epochs = [] 55 | # with tqdm(epochs, unit="epochs", initial=epochs[0], total=epochs[-1]) as pbar: 56 | # last_epoch = epochs[0] 57 | for idx, epoch in tqdm(list(enumerate(epochs))): 58 | # pbar.update(epoch - last_epoch) 59 | # last_epoch = epoch 60 | ckpt_files = glob.glob(ckpt_dir + f"/epoch={epoch}-step=*.ckpt") 61 | ckpt_files += glob.glob(ckpt_dir + f"/epoch={epoch}.ckpt") 62 | try: 63 | ckpt_file = ckpt_files[-1] 64 | ckpt = torch.load( 65 | ckpt_file, 66 | map_location=f"cuda:{0}", # FIXME 67 | ) 68 | processed_epochs.append(epoch) 69 | except FileNotFoundError: 70 | continue 71 | 72 | for k, v in ckpt["hyper_parameters"].items(): 73 | setattr(hparams, k, v) 74 | 75 | new_state_dict = {} 76 | for k, v in ckpt["state_dict"].items(): 77 | if k.startswith("transformer."): 78 | new_state_dict[k] = v 79 | else: 80 | new_state_dict["transformer." + k] = v 81 | ckpt["state_dict"] = new_state_dict 82 | 83 | model = trainer.TrainableTransformer(hparams).float() 84 | model.load_state_dict(ckpt["state_dict"]) 85 | model = model.to(device).eval() 86 | dl = model.test_dataloader() 87 | dl.reset_iteration(shuffle=False) 88 | 89 | outputs = [model.test_step(batch, idx) for (idx, batch) in enumerate(dl)] 90 | r = model.test_epoch_end(outputs)["log"] 91 | if results["val_loss"] is None: 92 | results["val_loss"] = r["test_loss"].squeeze().unsqueeze(0) 93 | results["val_accuracy"] = r["test_accuracy"].squeeze().unsqueeze(0) 94 | else: 95 | results["val_loss"] = torch.cat( 96 | [results["val_loss"], r["test_loss"].squeeze().unsqueeze(0)], dim=0 97 | ) 98 | results["val_accuracy"] = torch.cat( 99 | [ 100 | results["val_accuracy"], 101 | r["test_accuracy"].squeeze().unsqueeze(0), 102 | ], 103 | dim=0, 104 | ) 105 | 106 | for k, v in results.items(): 107 | results[k] = v.to("cpu") 108 | results["epochs"] = torch.LongTensor(processed_epochs, device="cpu") 109 | results["dl"] = dl 110 | 111 | os.makedirs(expt_dir + "/activations", exist_ok=True) 112 | ptfile = ( 113 | expt_dir + f"/activations/activations_{epochs[0]:010d}_{epochs[-1]:010d}.pt" 114 | ) 115 | torch.save(results, ptfile) 116 | 117 | 118 | if __name__ == "__main__": 119 | hparams = trainer.get_args(parser) 120 | if hparams.expt_dir is not None: 121 | child(hparams) 122 | else: 123 | for operation in RUNS: 124 | print(f"running {operation}") 125 | ds_len, run = RUNS[operation] 126 | data = load_metric_data( 127 | f"{DATA_DIR}/{run}", epochs=10000, load_partial_data=False 128 | ) 129 | metric_data = get_metric_data(data) 130 | metric_data = most_interesting(metric_data) 131 | for arch in metric_data: 132 | interesting_t = int(metric_data[arch]["T"][0].item()) 133 | expt = f"{arch}_T-{interesting_t}" 134 | print(f"--> expt {expt}") 135 | glb = f"{DATA_DIR}/{run}/{expt}_*" 136 | # print(f"glb {glb}") 137 | expt_dir = glob.glob(glb)[0] 138 | cmd = [sys.argv[0], "--expt_dir", expt_dir] 139 | subprocess.run(cmd, check=False, shell=False) 140 | # child(hparams) 141 | -------------------------------------------------------------------------------- /scripts/make_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from argparse import ArgumentParser 4 | from grok.data import create_data_files, DEFAULT_DATA_DIR 5 | 6 | 7 | parser = 
ArgumentParser() 8 | parser.add_argument("-d", "--data_directory", type=str, default=DEFAULT_DATA_DIR) 9 | args = parser.parse_args() 10 | create_data_files(args.data_directory) -------------------------------------------------------------------------------- /scripts/torch-setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set up torch with magma support 4 | 5 | DIR="`mktemp -d`" 6 | cd $DIR 7 | 8 | # Install deps 9 | conda install -y numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses 10 | 11 | # Install torch, magma 12 | conda install -y -c pytorch magma-cuda110 13 | 14 | # Build torch from scratch 15 | git clone --recursive https://github.com/pytorch/pytorch 16 | cd pytorch 17 | 18 | # If updating an existing checkout 19 | # git submodule sync 20 | # git submodule update --init --recursive 21 | 22 | export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} 23 | python setup.py install -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import grok 4 | import os 5 | 6 | parser = grok.training.add_args() 7 | parser.set_defaults(logdir=os.environ.get("GROK_LOGDIR", ".")) 8 | hparams = parser.parse_args() 9 | hparams.datadir = os.path.abspath(hparams.datadir) 10 | hparams.logdir = os.path.abspath(hparams.logdir) 11 | 12 | 13 | print(hparams) 14 | print(grok.training.train(hparams)) 15 | -------------------------------------------------------------------------------- /scripts/visualize_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import csv 4 | import json 5 | import logging 6 | import os 7 | import subprocess 8 | from argparse import ArgumentParser 9 | from copy import deepcopy 10 | from glob import glob 11 | from pprint import pprint 12 | 13 | import blobfile as bf 14 | import grok 15 | import matplotlib.pyplot as plt 16 | import matplotlib.ticker as mtick 17 | import numpy as np 18 | import torch 19 | import yaml 20 | from tqdm import tqdm 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | # take args: input_dir output_dir 25 | parser = ArgumentParser() 26 | parser.add_argument( 27 | "-i", 28 | "--input_dir", 29 | type=str, 30 | required=True, 31 | ) 32 | parser.add_argument( 33 | "-o", 34 | "--output_dir", 35 | type=str, 36 | required=True, 37 | ) 38 | parser = grok.training.add_args(parser) 39 | args = parser.parse_args() 40 | print(args, flush=True) 41 | 42 | if torch.cuda.is_available(): 43 | device = "cuda" 44 | else: 45 | device = "cpu" 46 | 47 | 48 | def load_expt_metrics( 49 | expt_dir, 50 | args, 51 | ): 52 | """load the metrics for one experiment""" 53 | args = deepcopy(args) 54 | 55 | # load the hparams for this experiment 56 | with open(f"{expt_dir}/default/version_0/hparams.yaml", "r") as fh: 57 | hparams_dict = yaml.safe_load(fh) 58 | 59 | for k, v in hparams_dict.items(): 60 | setattr(args, k, v) 61 | 62 | # load the summarized validation and training data for every epoch 63 | val_data = { 64 | "step": [], 65 | "epoch": [], 66 | "val_loss": [], 67 | "val_accuracy": [], 68 | } 69 | train_data = { 70 | "step": [], 71 | "epoch": [], 72 | "train_loss": [], 73 | "train_accuracy": [], 74 | "learning_rate": [], 75 | } 76 | 77 | with open(f"{expt_dir}/default/version_0/metrics.csv", 
"r") as fh: 78 | for row in csv.DictReader(fh): 79 | if row["train_loss"] != "": 80 | for k in train_data: 81 | if k in ["step", "epoch"]: 82 | v = int(row[k]) 83 | else: 84 | v = float(row[k]) 85 | train_data[k].append(v) 86 | else: 87 | for k in val_data: 88 | if k in ["step", "epoch"]: 89 | v = int(row[k]) 90 | else: 91 | v = float(row[k]) 92 | val_data[k].append(v) 93 | 94 | return { 95 | "hparams": hparams_dict, 96 | "train": train_data, 97 | "val": val_data, 98 | # "raw": raw_data, 99 | } 100 | 101 | 102 | def load_run_metrics( 103 | run_dir, 104 | args=args, 105 | ): 106 | """load all the metrics for a collection of experiments with the same architecture 107 | across various amounts of training data""" 108 | metric_data = {} 109 | from os import walk 110 | 111 | _, expt_dirs, _ = next(os.walk(run_dir)) 112 | for expt_dir in tqdm(expt_dirs, unit="expt"): 113 | try: 114 | expt_data = load_expt_metrics(f"{run_dir}/{expt_dir}", args) 115 | train_data_pct = expt_data["hparams"]["train_data_pct"] 116 | metric_data[train_data_pct] = expt_data 117 | except FileNotFoundError: 118 | pass 119 | return metric_data 120 | 121 | 122 | def add_metric_graph( 123 | fig, 124 | ax, 125 | arch, 126 | metric, 127 | metric_data, 128 | scales, 129 | cmap="viridis", 130 | by="step", # step or epoch 131 | max_increment=0, 132 | ): 133 | ax.set_title(metric) 134 | ax.set_xscale(scales["x"]) 135 | ax.set_yscale(scales["y"]) 136 | ax.set_xlabel(by) 137 | 138 | if "accuracy" in metric: 139 | ax.yaxis.set_major_formatter(mtick.PercentFormatter()) 140 | ymin = 1e-16 141 | ymax = 101 142 | ax.axis(ymin=ymin, ymax=ymax) 143 | if "loss" in metric: 144 | ymin = 1e-16 145 | ymax = 15 146 | ax.axis(ymin=ymin, ymax=ymax) 147 | 148 | total_plots = 0 149 | logger.debug(f"processing {metric}") 150 | plots = [] 151 | T = list(sorted(metric_data.keys())) 152 | T_max = int(T[-1]) 153 | T_min = int(T[0]) 154 | sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=T[0], vmax=T[-1])) 155 | colors = sm.to_rgba(T) 156 | for i, t in enumerate(T): 157 | if "val" in metric: 158 | this_data = metric_data[t]["val"] 159 | else: 160 | this_data = metric_data[t]["train"] 161 | 162 | X = this_data[by] 163 | Y = this_data[metric] 164 | if max_increment > 0: 165 | X = [x for x in X if x <= max_increment] 166 | Y = Y[: len(X)] 167 | 168 | if len(X) != len(Y): 169 | logger.warning(f"Mismatched data: {metric} at t={t}") 170 | continue 171 | if not Y: 172 | logger.warning(f"No data for {metric}i at t={t}") 173 | continue 174 | 175 | label = arch + f" t={t}" 176 | 177 | if "accuracy" in metric: 178 | label += " (max = %.2f)" % max(Y) 179 | elif "loss" in metric: 180 | label += " (min = %.2f)" % min(Y) 181 | total_plots += 1 182 | ax.plot(X, Y, label=label, color=colors[i]) 183 | if T_max - T_min <= 10: 184 | ax.legend() 185 | else: 186 | fig.colorbar( 187 | sm, 188 | ax=ax, 189 | label="% training data", 190 | ticks=range(T_min, T_max + 1, int((T_max - T_min) / 5)), 191 | ) 192 | 193 | 194 | def add_max_accuracy_graph( 195 | ax, 196 | arch, 197 | metric, 198 | metric_data, 199 | scales, 200 | by="step", 201 | max_increment=0, 202 | ): 203 | ax.set_title(f"max {metric}") 204 | ax.set_xlabel("% of total data trained on") 205 | ax.xaxis.set_major_formatter(mtick.PercentFormatter()) 206 | xmin = 0 207 | xmax = 100 208 | ymin = 1e-16 209 | ymax = 101 210 | ax.axis(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax) 211 | ax.set_xscale(scales["x"]) 212 | ax.set_yscale(scales["y"]) 213 | ax.yaxis.set_major_formatter(mtick.PercentFormatter()) 214 | 
ax.xaxis.set_major_formatter(mtick.PercentFormatter()) 215 | 216 | T = list(sorted(metric_data.keys())) 217 | T_max = int(T[-1]) 218 | T_min = int(T[0]) 219 | Y = [] 220 | for i, t in enumerate(T): 221 | if "val" in metric: 222 | this_data = metric_data[t]["val"] 223 | else: 224 | this_data = metric_data[t]["train"] 225 | X = this_data[by] 226 | if max_increment > 0: 227 | X = [x for x in X if x <= max_increment] 228 | max_idx = len(X) 229 | else: 230 | max_idx = -1 231 | try: 232 | Y.append(max(this_data[metric][:max_idx])) 233 | except ValueError: 234 | Y.append(np.nan) 235 | 236 | ax.set_xticks(np.arange(0, 100, 5)) 237 | label = f"max {metric} {arch}" 238 | ax.plot(T, Y, label=label) 239 | 240 | 241 | def create_loss_curves( 242 | metric_data, 243 | arch, 244 | operation, 245 | # epochs, 246 | most_interesting_only=False, 247 | image_dir=args.output_dir, 248 | by="step", 249 | max_increment=0, 250 | cmap="viridis", 251 | ): 252 | scales = { 253 | "x": "log", 254 | "y": "linear", 255 | } 256 | 257 | ncols = 2 258 | nrows = 3 259 | fig_width = ncols * 8 260 | fig_height = nrows * 5 261 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) 262 | 263 | add_metric_graph( 264 | fig, 265 | axs[0, 0], 266 | arch, 267 | "val_loss", 268 | metric_data, 269 | scales, 270 | cmap, 271 | by, 272 | max_increment=max_increment, 273 | ) 274 | add_metric_graph( 275 | fig, 276 | axs[0, 1], 277 | arch, 278 | "val_accuracy", 279 | metric_data, 280 | scales, 281 | cmap, 282 | by, 283 | max_increment=max_increment, 284 | ) 285 | add_metric_graph( 286 | fig, 287 | axs[1, 0], 288 | arch, 289 | "train_loss", 290 | metric_data, 291 | scales, 292 | cmap, 293 | by, 294 | max_increment=max_increment, 295 | ) 296 | add_metric_graph( 297 | fig, 298 | axs[1, 1], 299 | arch, 300 | "train_accuracy", 301 | metric_data, 302 | scales, 303 | cmap, 304 | by, 305 | max_increment=max_increment, 306 | ) 307 | add_metric_graph( 308 | fig, 309 | axs[2, 0], 310 | arch, 311 | "learning_rate", 312 | metric_data, 313 | scales, 314 | cmap, 315 | by, 316 | max_increment=max_increment, 317 | ) 318 | fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s") 319 | fig.tight_layout() 320 | 321 | img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}__upto_{max_increment:010d}_{by}" 322 | if most_interesting_only: 323 | img_file += "_most_interesting" 324 | img_file += ".png" 325 | d = os.path.split(img_file)[0] 326 | os.makedirs(d, exist_ok=True) 327 | print(f"Writing {img_file}") 328 | fig.savefig(img_file) 329 | plt.close(fig) 330 | 331 | 332 | def create_max_accuracy_curves( 333 | metric_data, 334 | arch, 335 | operation, 336 | by="step", 337 | max_increment=0, 338 | image_dir=args.output_dir, 339 | ): 340 | scales = { 341 | "x": "linear", 342 | "y": "linear", 343 | } 344 | 345 | ncols = 1 346 | nrows = 2 347 | fig_width = ncols * 8 348 | fig_height = nrows * 5 349 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) 350 | 351 | add_max_accuracy_graph( 352 | axs[0], 353 | arch, 354 | "val_accuracy", 355 | metric_data, 356 | scales, 357 | by=by, 358 | max_increment=max_increment, 359 | ) 360 | axs[0].legend() 361 | add_max_accuracy_graph( 362 | axs[1], 363 | arch, 364 | "train_accuracy", 365 | metric_data, 366 | scales, 367 | by=by, 368 | max_increment=max_increment, 369 | ) 370 | axs[1].legend() 371 | fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s") 372 | fig.tight_layout() 373 | 374 | img_file = 
f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}_upto_{max_increment:010d}_{by}.png" 375 | d = os.path.split(img_file)[0] 376 | os.makedirs(d, exist_ok=True) 377 | print(f"Writing {img_file}") 378 | fig.savefig(img_file) 379 | plt.close(fig) 380 | 381 | 382 | def create_tsne_graphs( 383 | operation, 384 | expt, 385 | run_dir, 386 | image_dir=args.output_dir, 387 | ): 388 | from sklearn.manifold import TSNE # needed for the t-SNE projection below 389 | saved_pt_dir = f"{run_dir}/activations" 390 | saved_pts = [] 391 | 392 | loss_ts = [] 393 | accuracy_ts = [] 394 | epochs_ts = [] 395 | print(f'glob = {saved_pt_dir + "/activations_*.pt"}') 396 | files = sorted(glob(saved_pt_dir + "/activations_*.pt")) 397 | print(f"files = {files}") 398 | 399 | for file in files: 400 | print(f"Loading {file}") 401 | saved_pt = torch.load(file) 402 | saved_pts.append(saved_pt) 403 | loss_ts.append(saved_pt["val_loss"].mean(dim=-1)) 404 | accuracy_ts.append(saved_pt["val_accuracy"]) 405 | epochs_ts.append(saved_pt["epochs"].squeeze()) 406 | 407 | loss_t = torch.cat(loss_ts, dim=0).T.detach() 408 | accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach() 409 | epochs_t = torch.cat(epochs_ts, dim=0).detach() 410 | print(loss_t.shape) 411 | print(accuracy_t.shape) 412 | print(epochs_t.shape) 413 | ###### 414 | a = 0 415 | num_eqs = len(loss_t) 416 | b = a + num_eqs 417 | 418 | print("Doing T-SNE..") 419 | loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t) 420 | print("...done T-SNE.") 421 | 422 | ncols = 1 423 | nrows = 1 424 | fig_width = ncols * 8 425 | fig_height = nrows * 5 426 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height)) 427 | 428 | axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1]) 429 | 430 | img_file = f"{image_dir}/tsne/{operation}_{expt}.png" 431 | d = os.path.split(img_file)[0] 432 | os.makedirs(d, exist_ok=True) 433 | print(f"Writing {img_file}") 434 | fig.savefig(img_file) 435 | plt.close(fig) 436 | 437 | 438 | def get_arch(metric_data): 439 | k = list(metric_data.keys())[0] 440 | hparams = metric_data[k]["hparams"] 441 | arch = f'L-{hparams["n_layers"]}_H-{hparams["n_heads"]}_D-{hparams["d_model"]}_B-{hparams["batchsize"]}_S-{hparams["random_seed"]}_DR-{hparams["dropout"]}' 442 | return arch 443 | 444 | 445 | def get_operation(metric_data): 446 | k = list(metric_data.keys())[0] 447 | hparams = metric_data[k]["hparams"] 448 | operator = hparams["math_operator"] 449 | operand_length = hparams["operand_length"] 450 | _, operation = grok.data.ArithmeticDataset.get_file_path(operator, operand_length) 451 | return operation 452 | 453 | 454 | def get_max_epochs(metric_data): 455 | k = list(metric_data.keys())[0] 456 | hparams = metric_data[k]["hparams"] 457 | return hparams["max_epochs"] 458 | 459 | 460 | rundir = args.input_dir 461 | 462 | try: 463 | metric_data = load_run_metrics(rundir, args) 464 | arch = get_arch(metric_data) 465 | operation = get_operation(metric_data) 466 | max_epochs = get_max_epochs(metric_data) 467 | 468 | for by in ["step", "epoch"]: 469 | create_loss_curves(metric_data, arch, operation, by=by) 470 | 471 | by = "epoch" 472 | last_i = -1 473 | for i in sorted(list(set(2 ** (np.arange(167) / 10)))): 474 | if i > max_epochs: 475 | break 476 | i = int(round(i)) 477 | create_max_accuracy_curves( 478 | metric_data, 479 | arch, 480 | operation, 481 | by=by, 482 | max_increment=i, 483 | ) 484 | 485 | # make a video 486 | in_files = os.path.join( 487 | args.output_dir, 488 | "max_accuracy", 489 | f"{operation}_max_accuracy_{arch}_upto_%*.png", 490 | ) 491 | out_file = os.path.join(args.output_dir, 
f"{operation}_{arch}_max_accuracy.mp4") 492 | cmd = [ 493 | "ffmpeg", 494 | "-y", 495 | "-r", 496 | "16", 497 | "-i", 498 | in_files, 499 | "-vcodec", 500 | "libx264", 501 | "-crf", 502 | "25", 503 | "-pix_fmt", 504 | "yuv420p", 505 | out_file, 506 | ] 507 | subprocess.check_call(cmd) 508 | 509 | except BaseException as e: 510 | print(f"{rundir} failed: {e}") 511 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="grok", 5 | packages=find_packages(), 6 | version="0.0.1", 7 | install_requires=[ 8 | "pytorch_lightning", 9 | "blobfile", 10 | "numpy", 11 | "torch", 12 | "tqdm", 13 | "scipy", 14 | "mod", 15 | "matplotlib", 16 | ], 17 | ) 18 | --------------------------------------------------------------------------------
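For orientation, the scripts above chain together into a simple workflow: make_data.py writes the algorithmic dataset files, train.py runs a training job (its logdir defaults to the GROK_LOGDIR environment variable), and visualize_metrics.py turns each run's `default/version_0/metrics.csv` into loss-curve and max-accuracy plots plus a video. The sketch below is illustrative only: the `./data`, `./logs`, and `./viz` paths are assumptions, not repository defaults, and only the flags defined in the scripts above are used.

```bash
# Hypothetical end-to-end sketch (paths are placeholders, not defaults)
./scripts/make_data.py -d ./data                    # calls create_data_files on the given directory
GROK_LOGDIR=./logs ./scripts/train.py               # train.py reads GROK_LOGDIR for its default logdir
./scripts/visualize_metrics.py -i ./logs -o ./viz   # expects experiment dirs containing default/version_0/{hparams.yaml,metrics.csv}
```

Note that visualize_metrics.py also shells out to ffmpeg to assemble the max-accuracy frames into an mp4, so ffmpeg must be on the PATH for that final step.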