The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .gitignore
├── LICENSE
├── README.md
├── grok
    ├── __init__.py
    ├── data.py
    ├── measure.py
    ├── metrics.py
    ├── training.py
    ├── transformer.py
    └── visualization.py
├── nbs
    └── flatness.ipynb
├── scripts
    ├── compute_sharpness.py
    ├── create_metric_graphs.py
    ├── create_metrics_for_epochs.py
    ├── create_partial_metrics.py
    ├── make_data.py
    ├── torch-setup.sh
    ├── train.py
    └── visualize_metrics.py
└── setup.py


/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | pip-wheel-metadata/
 24 | share/python-wheels/
 25 | *.egg-info/
 26 | .installed.cfg
 27 | *.egg
 28 | MANIFEST
 29 | 
 30 | # PyInstaller
 31 | #  Usually these files are written by a python script from a template
 32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 | 
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 | 
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .nox/
 44 | .coverage
 45 | .coverage.*
 46 | .cache
 47 | nosetests.xml
 48 | coverage.xml
 49 | *.cover
 50 | *.py,cover
 51 | .hypothesis/
 52 | .pytest_cache/
 53 | 
 54 | # Translations
 55 | *.mo
 56 | *.pot
 57 | 
 58 | # Django stuff:
 59 | *.log
 60 | local_settings.py
 61 | db.sqlite3
 62 | db.sqlite3-journal
 63 | 
 64 | # Flask stuff:
 65 | instance/
 66 | .webassets-cache
 67 | 
 68 | # Scrapy stuff:
 69 | .scrapy
 70 | 
 71 | # Sphinx documentation
 72 | docs/_build/
 73 | 
 74 | # PyBuilder
 75 | target/
 76 | 
 77 | # Jupyter Notebook
 78 | .ipynb_checkpoints
 79 | 
 80 | # IPython
 81 | profile_default/
 82 | ipython_config.py
 83 | 
 84 | # pyenv
 85 | .python-version
 86 | 
 87 | # pipenv
 88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 91 | #   install all needed dependencies.
 92 | #Pipfile.lock
 93 | 
 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 95 | __pypackages__/
 96 | 
 97 | # Celery stuff
 98 | celerybeat-schedule
 99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | default
132 | checkpoints
133 | .vscode
134 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 OpenAI
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # OpenAI Grok Curve Experiments
 2 | 
 3 | ## Paper
 4 | 
 5 | This is the code for the paper [Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets](https://arxiv.org/abs/2201.02177) by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra.
 6 | 
 7 | ## Installation and Training
 8 | 
 9 | ```bash
10 | pip install -e .
11 | ./scripts/train.py
12 | ```
13 | 


--------------------------------------------------------------------------------
/grok/__init__.py:
--------------------------------------------------------------------------------
1 | from . import transformer
2 | from . import data
3 | from . import training
4 | from . import metrics
5 | from . import visualization
6 | 


--------------------------------------------------------------------------------
/grok/data.py:
--------------------------------------------------------------------------------
  1 | import itertools
  2 | import math
  3 | import os
  4 | import sys
  5 | import random
  6 | 
  7 | import torch
  8 | from torch import Tensor, LongTensor
  9 | import numpy as np
 10 | from typing import Tuple, List, Dict, Any, Union, Optional
 11 | from tqdm import tqdm
 12 | 
 13 | from sympy.combinatorics.permutations import Permutation
 14 | from mod import Mod
 15 | 
 16 | import blobfile as bf
 17 | 
 18 | 
# Maps each operator string accepted by ArithmeticDataset to the dataset name
# used for it (see ArithmeticDataset.get_dsname). Operators containing "_mod_"
# are evaluated as a Python expression in x and y over Mod values (see
# ArithmeticDataset._make_binary_operation_data).
VALID_OPERATORS = {
    "+": "addition",
    "-": "subtraction",
    # NOTE(review): "muliplication" is misspelled, but this value becomes part
    # of the dataset name, so correcting it would change dataset names.
    "*": "muliplication",
    "/": "division",
    "**2+": "squarepoly",
    "**3+": "cubepoly",
    "x**2+y**2_mod_97": "quad1",
    "x**2+y**2+x*y_mod_97": "quad2",
    "x**2+y**2+x*y+x_mod_97": "quad3",
    "x**3+x*y_mod_97": "cube1",
    "x**3+x*y**2+y_mod_97": "cube2",
    "(x._value//y)if(y._value%2==1)else(x-y)_mod_97": "mix1",
    "s5": "s5",
    "s5conj": "s5conj",
    "s5aba": "s5aba",
    "+*": "even-addition_odd-multiplication",
    "+-": "even-addition_odd-subtraction",
    "sort": "sort",
    "reverse": "reverse",
    "copy": "copy",
}
# Token placed before and after every equation (see ArithmeticDataset.make_data).
EOS_TOKEN = "<|eos|>"
# Token for the "=" separating an equation's left-hand side from its result.
EQ_TOKEN = "="
# Modulus used by the plain binary operators; also bounds the number tokens.
MODULUS = 97
# The number tokens 0 .. MODULUS - 1.
NUMS = list(range(MODULUS))

# Default data_dir used when building token/data file paths.
DEFAULT_DATA_DIR = "data"
 47 | 
 48 | 
 49 | def render(operand, join_str=""):
 50 |     if (
 51 |         isinstance(operand, list)
 52 |         or isinstance(operand, tuple)
 53 |         or isinstance(operand, np.ndarray)
 54 |     ):
 55 |         return join_str.join(map(render, operand))
 56 |     elif isinstance(operand, Permutation):
 57 |         return "".join(map(str, operand.array_form))
 58 |     elif isinstance(operand, Mod):
 59 |         return str(operand._value)
 60 |     else:
 61 |         return str(operand)
 62 | 
 63 | 
def create_data_files(data_dir: str = DEFAULT_DATA_DIR):
    """
    Judging by the names it calls, intended to write the token file and the
    dataset files into ``data_dir``.

    NOTE(review): neither ``ArithmeticTokenizer.create_token_file`` nor
    ``ArithmeticDataset.create_dataset_files`` is defined in grok/data.py as
    shown (the tokenizer builds its vocabulary in memory via ``get_tokens`` and
    datasets are generated in memory by ``make_data``), so this function looks
    like it raises AttributeError on the first line — confirm before relying
    on it.

    :param data_dir: directory the files would be written to
    """
    ArithmeticTokenizer.create_token_file(data_dir)
    ArithmeticDataset.create_dataset_files(data_dir)
 67 | 
 68 | 
 69 | class ArithmeticTokenizer:
 70 |     """Stores the list of token text to token id mappings and converts between them"""
 71 | 
 72 |     token_file = "tokens.txt"
 73 | 
 74 |     def __init__(self, data_dir=DEFAULT_DATA_DIR) -> None:
 75 |         self.token_file = bf.join(data_dir, self.token_file)
 76 | 
 77 |         self.itos = self.get_tokens()
 78 | 
 79 |         self.stoi: Dict[str, int] = dict([(s, i) for i, s in enumerate(self.itos)])
 80 | 
 81 |     def _encode(self, s: str) -> Tensor:
 82 |         return LongTensor([self.stoi[t] for t in s.split(" ")])
 83 | 
 84 |     def encode(self, obj: Union[str, List]) -> Tensor:
 85 |         """
 86 |         Convert a string of text into a rank-1 tensor of token ids
 87 |         or convert a list of strings of text into a rank-2 tensor of token ids
 88 | 
 89 |         :param obj: the string or list of strings to convert
 90 |         :returns: a tensor of the token ids
 91 |         """
 92 |         if isinstance(obj, str):
 93 |             return self._encode(obj)
 94 |         elif isinstance(obj, list):
 95 |             return torch.stack([self._encode(s) for s in obj], dim=0)
 96 |         else:
 97 |             raise NotImplementedError
 98 | 
 99 |     def decode(self, tensor: Tensor, with_brackets: bool = False) -> str:
100 |         """
101 |         Convert a tensor of token ids into a string of text
102 | 
103 |         :param tensor: a tensor of the token ids
104 |         :param with_brackets: if true, the returned string will include <> brackets
105 |                               around the text corresponding to each token.
106 |         :returns: string of these tokens.
107 |         """
108 |         indices = tensor.long()
109 |         if with_brackets:
110 |             l = "<"
111 |             r = ">"
112 |         else:
113 |             l = ""
114 |             r = ""
115 |         tokens = [l + self.itos[i] + r for i in indices]
116 |         return " ".join(tokens)
117 | 
118 |     def __len__(self) -> int:
119 |         """
120 |         :returns: the number of tokens in this vocabulary
121 |         """
122 |         return len(self.itos)
123 | 
124 |     @classmethod
125 |     def get_tokens(cls):
126 |         tokens = (
127 |             [EOS_TOKEN, EQ_TOKEN]
128 |             + list(sorted(list(VALID_OPERATORS.keys())))
129 |             + list(map(render, NUMS))
130 |             + list(map(render, itertools.permutations(range(5))))  # s5
131 |         )
132 |         return tokens
133 | 
134 | 
class ArithmeticDataset:
    """A Dataset of arithmetic equations, stored as token ids"""

    @classmethod
    def splits(
        cls,
        train_pct: float,
        operator: str,
        operand_length: Optional[int] = None,
        data_dir: str = DEFAULT_DATA_DIR,
    ):
        """
        Creates training and validation datasets

        :param train_pct: percentage of total equations used for training data
        :param operator: The arithmetic operator for this dataset e.g. '+', '-', '*', '/', 'sort'
        :param operand_length: for list based datasets the length of the lists
        :param data_dir: directory handed to each dataset's tokenizer
        :returns: (train_dataset, validation_dataset)
        """

        assert (0 < train_pct) and (train_pct < 100)

        ds_name = cls.get_dsname(operator, operand_length)
        # For the list operators make_data expands operand_length into the
        # operand lists themselves.
        eqs = cls.make_data(operator, operand_length)

        train_rows, _ = cls.calc_split_len(train_pct, len(eqs))

        train_ds = cls(ds_name, eqs[:train_rows], train=True, data_dir=data_dir)
        val_ds = cls(ds_name, eqs[train_rows:], train=False, data_dir=data_dir)

        return train_ds, val_ds

    @classmethod
    def calc_split_len(cls, train_pct, ds_len):
        """
        :returns: (train_rows, val_rows), with train_rows rounded to the nearest int
        """
        train_rows = round(ds_len * (train_pct / 100.0))
        val_rows = ds_len - train_rows
        return train_rows, val_rows

    def __init__(self, name, data: Union[Tensor, List[str]], train, data_dir) -> None:
        """
        :param name: dataset name (see get_dsname)
        :param data: A list of equations strings (each must contain '=') which
                     is tokenized here, or an already-tokenized tensor.
        :param train: whether this is the training split
        :param data_dir: directory handed to the tokenizer
        """
        self.tokenizer = ArithmeticTokenizer(data_dir)
        self.name = name
        self.train = train
        if isinstance(data, list):
            self.data = self.tokenizer.encode(data)
        else:
            self.data = data

    def __len__(self) -> int:
        """
        :returns: total number of equations in this dataset
        """
        return self.data.shape[0]

    @classmethod
    def _make_binary_operation_data(cls, operator: str, operands=None) -> List[str]:
        """
        Build one equation "a op b = c" for every ordered pair of elements.

        :param operator: a binary key of VALID_OPERATORS (without noise suffix).
                         The plain and "_mod_" operators are passed to eval, so
                         never call this with untrusted strings (make_data
                         asserts membership in VALID_OPERATORS).
        :param operands: optional element list overriding the default
                         (range(5) for the s5 operators, NUMS for the plain
                         operators; ignored by the "_mod_" operators)
        :returns: list of space-separated equation strings
        """
        if operator == "s5":
            operands = operands or list(range(5))
            elems = map(np.array, itertools.permutations(operands))
            tuples = itertools.product(elems, repeat=2)
        elif operator in ["s5conj", "s5aba"]:
            operands = operands or list(range(5))
            elems = map(Permutation, itertools.permutations(operands))
            tuples = itertools.product(elems, repeat=2)
        elif "_mod_" in operator:
            modulo = int(operator.split("_mod_")[-1])
            elems = [Mod(i, modulo) for i in range(modulo)]
            tuples = itertools.product(elems, repeat=2)
        else:
            operands = operands or NUMS
            tuples = itertools.product(operands, repeat=2)

        mod_function = None
        if "_mod_" in operator:
            # Compile the expression once instead of once per pair.
            expression = operator.split("_mod_")[0]
            mod_function = eval(f"lambda x, y: ({expression})")

        eqs = []
        for a, b in tuples:
            if operator == "/":
                if b == 0:
                    continue
                # Encode a / b = c by choosing the quotient c first and
                # setting a = b * c (mod MODULUS).
                c = a
                a = (b * c) % MODULUS
            elif operator == "s5":
                # Compose the two permutations as index arrays.
                c = b[a]
            elif operator == "s5conj":
                c = a * b * (~a)
            elif operator == "s5aba":
                c = a * b * a
            elif operator == "+*":
                if a % 2 == 0:
                    c = (a + b) % MODULUS
                else:
                    c = (a * b) % MODULUS
            elif operator == "+-":
                if a % 2 == 0:
                    c = (a + b) % MODULUS
                else:
                    c = (a - b) % MODULUS
            elif mod_function is not None:
                c = mod_function(a, b)
            else:
                c = eval(f"({a} {operator} {b}) % {MODULUS}")
            eq = " ".join(map(render, [a, operator, b, "=", c]))
            eqs.append(eq)

        return eqs

    @staticmethod
    def _make_unary_operation_data(operator: str, operands: Tensor) -> List[str]:
        """
        :param operator: The unary operator to apply to each operand:
                         'sort', 'reverse' or 'copy'
        :param operands: A rank-2 tensor; each row is one operand list
        :returns: list of equations
        :raises ValueError: for any other operator
        """
        if operator == "sort":
            rhs = torch.sort(operands, dim=1)[0]
        elif operator == "reverse":
            rhs = torch.flip(operands, dims=(1,))
        elif operator == "copy":
            rhs = operands
        else:
            raise ValueError("unsupported operator")

        def func(L, R):
            return f"{operator} {' '.join(map(str, L))} = {' '.join(map(str, R))}"

        # Built serially: func is a nested function and cannot be pickled for a
        # process pool, so a ProcessPoolExecutor path cannot work here.
        return [
            func(L, R)
            for L, R in tqdm(
                zip(operands.tolist(), rhs.tolist()), total=len(operands)
            )
        ]

    @classmethod
    def get_dsname(cls, operator, operand_length) -> str:
        """
        :returns: the VALID_OPERATORS name for ``operator`` plus optional
                  ``_length-<n>`` and ``_noise-<n>`` suffixes
        """
        operator, noise_level = cls._get_operator_and_noise_level(operator)
        ds_name = VALID_OPERATORS[operator]
        if operand_length is not None:
            ds_name += f"_length-{operand_length}"
        if noise_level > 0:
            ds_name += f"_noise-{noise_level}"
        return ds_name

    @classmethod
    def get_file_path(cls, operator, operand_length=None, data_dir=DEFAULT_DATA_DIR):
        """
        :returns: (path of "<ds_name>_data.txt" under data_dir, ds_name)
        """
        ds_name = cls.get_dsname(operator, operand_length)
        ds_file = bf.join(data_dir, f"{ds_name}_data.txt")
        return ds_file, ds_name

    @classmethod
    def _get_operator_and_noise_level(cls, operator):
        """
        Split an operator of the form "<op>_noisy_<n>" into (op, n).

        :returns: (operator, noise_level); noise_level is 0 without the suffix
        """
        if "_noisy" in operator:
            operator, noise_level = operator.split("_noisy_")
            return operator, int(noise_level)
        else:
            return operator, 0

    @classmethod
    def make_data(cls, operator, operands=None, shuffle=True, seed=0) -> List[str]:
        """
        Generate every equation for ``operator`` as EOS-delimited strings.

        :param operator: a key of VALID_OPERATORS, optionally suffixed with
                         "_noisy_<n>" to replace the answers of the first n
                         (post-shuffle) equations with randomly drawn answers
        :param operands: only used by the list operators ('sort', 'reverse',
                         'copy'): either a rank-2 tensor of operand lists, or
                         an int list length that is expanded to every ordered
                         list of distinct NUMS of that length (see _make_lists)
        :param shuffle: whether to shuffle the equations
        :param seed: seed for the RandomState used for shuffling and noise
        :returns: list of equation strings
        :raises ValueError: if a list operator is given no operands
        """
        operator, noise_level = cls._get_operator_and_noise_level(operator)
        assert operator in VALID_OPERATORS

        if operator not in ["sort", "reverse", "copy"]:
            data = cls._make_binary_operation_data(operator)
        else:
            if operands is None:
                raise ValueError(
                    f"operator {operator!r} needs operands or an operand length"
                )
            if isinstance(operands, int):
                # splits() passes operand_length straight through; expand it
                # into the operand lists the unary operators act on.
                operands = cls._make_lists(sizes=[operands])[operands]
            data = cls._make_unary_operation_data(operator, operands)

        rng = np.random.RandomState(seed=seed)
        if shuffle:
            rng.shuffle(data)

        if noise_level > 0:
            random_answer_eqns = rng.choice(data, size=noise_level)
            random_answers = [
                random_eq.split(" = ")[1] for random_eq in random_answer_eqns
            ]
            for i in range(noise_level):
                data[i] = data[i].split(" = ")[0] + " = " + random_answers[i]

        data = [EOS_TOKEN + " " + eq + " " + EOS_TOKEN for eq in data]

        return data

    @classmethod
    def _make_lists(cls, sizes=(2, 3), nums=NUMS):
        """
        :param sizes: list lengths to generate
        :param nums: values to draw list elements from
        :returns: dict mapping each size to an int tensor whose rows are all
                  ordered selections of distinct ``nums`` of that size
        """
        lists: dict = {}
        for size in sizes:
            lists[size] = torch.tensor(
                list(itertools.permutations(nums, r=size)),
                dtype=torch.int,
            )
        return lists
387 | 
388 | 
class ArithmeticIterator(torch.utils.data.IterableDataset):
    """
    An iterator over batches of data in an ArithmeticDataset
    """

    def __init__(
        self,
        dataset: "ArithmeticDataset",
        device: torch.device,
        batchsize_hint: float = 0,
        shuffle: bool = True,
    ) -> None:
        """
        :param dataset: the dataset to iterate over
        :param device: the torch device to send batches to
        :param batchsize_hint: * 0 means we use a default batchsize
                               * -1 means the entire dataset
                               * float between 0 and 1 means each batch is
                                 that fraction of the DS
                               * number >= 1 means that specific batch size
        :param shuffle: whether or not to randomly shuffle the dataset
                        (only applied to training datasets); honoured on
                        every epoch, not just the first
        """
        self.dataset = dataset
        self.batchsize = self.calculate_batchsize(
            len(dataset), batchsize_hint=batchsize_hint
        )
        self.device = device
        # Remembered so the automatic reset at the end of each epoch uses the
        # caller's choice instead of always reshuffling.
        self.shuffle = shuffle
        self.reset_iteration(shuffle=shuffle)

    @staticmethod
    def calculate_batchsize(ds_size: int, batchsize_hint: float = 0) -> int:
        """
        Calculates which batch size to use

        :param ds_size: the number of equations in the dataset
        :param batchsize_hint: * 0 means we use a default batchsize
                               * -1 means the entire dataset
                               * float between 0 and 1 means each batch is
                                 that fraction of the DS
                               * number >= 1 means that specific batch size
                                 (truncated to an int, capped at ds_size)
        :returns: the actual batchsize to use
        :raises ValueError: for any other hint
        """
        if batchsize_hint == -1:
            return ds_size
        if batchsize_hint == 0:
            return min(512, math.ceil(ds_size / 2.0))
        if 0 < batchsize_hint < 1:
            return math.ceil(ds_size * batchsize_hint)
        if batchsize_hint >= 1:
            return min(int(batchsize_hint), ds_size)
        raise ValueError("batchsize_hint must be -1, 0, or > 0")

    def reset_iteration(self, shuffle: Optional[bool] = None):
        """
        Start a new epoch.

        :param shuffle: overrides the shuffle setting for this epoch only;
                        None uses the setting given to the constructor
        """
        if shuffle is None:
            shuffle = getattr(self, "shuffle", True)
        self.index = 0
        if shuffle and self.dataset.train:
            self.permutation = torch.randperm(len(self.dataset))
        else:
            self.permutation = torch.arange(len(self.dataset))

    def __iter__(self):
        """
        :returns: this iterator
        """
        return self

    def __next__(self) -> Dict[str, Tensor]:
        """
        Returns one batch of data.

        :raises: StopIteration when we're out of data (after resetting for
                 the next epoch)
        :returns: dict with "text" (each row without its last token) and
                  "target" (each row without its first token), on self.device
        """

        batch_begin = self.index * self.batchsize
        if batch_begin > len(self.dataset) - 1:
            self.reset_iteration()
            raise StopIteration
        indices = self.permutation[batch_begin : batch_begin + self.batchsize]
        text = self.dataset.data[indices, :-1]
        target = self.dataset.data[indices, 1:]
        batch = {"text": text.to(self.device), "target": target.to(self.device)}
        self.index += 1
        return batch

    def __len__(self) -> int:
        """
        :returns: the total number of batches (0 for an empty dataset)
        """
        if self.batchsize == 0:
            # Only happens for an empty dataset; avoid dividing by zero.
            return 0
        return math.ceil(len(self.dataset) / self.batchsize)
480 | 


--------------------------------------------------------------------------------
/grok/measure.py:
--------------------------------------------------------------------------------
  1 | import logging
  2 | import torch
  3 | import numpy as np
  4 | 
  5 | import scipy.optimize
  6 | 
  7 | 
  8 | def get_loss_and_grads(x, model, data_loader):
  9 | 
 10 |     # if type(x).__module__ == np.__name__:
 11 |     #     x = torch.from_numpy(x).float()
 12 |     #     x = x.cuda()
 13 | 
 14 |     model.eval()
 15 | 
 16 |     x_start = 0
 17 |     for p in model.parameters():
 18 |         param_size = p.data.size()
 19 |         param_idx = 1
 20 |         for s in param_size:
 21 |             param_idx *= s
 22 |         x_part = x[x_start : x_start + param_idx]
 23 |         p.data = torch.Tensor(x_part.reshape(param_size))
 24 |         x_start += param_idx
 25 | 
 26 |     batch_losses = []
 27 |     batch_grads = []
 28 |     for it, batch in enumerate(data_loader):
 29 | 
 30 |         # Move data to correct device
 31 |         # inputs = inputs.to(device)
 32 |         # targets = targets.to(device)
 33 | 
 34 |         with torch.set_grad_enabled(True):
 35 |             # loss, grads = model(idx=inputs, targets=targets, grads=True)
 36 |             loss, grads = model._step(batch=batch, batch_idx=1, train=True, grads=True)
 37 | 
 38 |         # Todo: average over dataset
 39 |         batch_losses.append(loss)
 40 |         # batch_grads.append(None if grads is None else grads.cpu().numpy().astype(np.float64))
 41 |         batch_grads.append(None if grads is None else grads)
 42 | 
 43 |     mean_losses = torch.mean(torch.stack(batch_losses))
 44 |     mean_grads = torch.mean(torch.stack(batch_grads), dim=0)
 45 | 
 46 |     return (mean_losses, mean_grads.cpu().numpy().astype(np.float64))
 47 | 
 48 | 
 49 | def get_weights(model):
 50 |     """
 51 |     Given a model, return a vector of weights.
 52 |     """
 53 |     x0 = None
 54 |     for p in model.parameters():
 55 |         if x0 is None:
 56 |             x0 = p.data.view(-1)
 57 |         else:
 58 |             x0 = torch.cat((x0, p.data.view(-1)))
 59 |     return x0.cpu().numpy()
 60 | 
 61 | 
 62 | def get_sharpness(data_loader, model, subspace_dim=10, epsilon=1e-3, maxiter=10):
 63 |     """
 64 |     Compute the sharpness around some point in weight space, as specified
 65 |     in Keskar et. al. (2016) Sec 2.2.2:
 66 |     https://arxiv.org/pdf/1609.04836.pdf
 67 | 
 68 |     See:
 69 |         https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd
 70 |         https://gist.github.com/gngdb/a9f912df362a85b37c730154ef3c294b
 71 |         https://github.com/keskarnitish/large-batch-training
 72 |         https://github.com/wenwei202/smoothout
 73 |         https://github.com/keras-team/keras/pull/3064
 74 |     """
 75 | 
 76 |     x0 = get_weights(model)
 77 | 
 78 |     f_x0, _ = get_loss_and_grads(x0, model, data_loader)
 79 |     f_x0 = -f_x0
 80 |     logging.info("min loss f_x0 = {loss:.4f}".format(loss=f_x0))
 81 | 
 82 |     if 0 == subspace_dim:
 83 |         x_min = np.reshape(x0 - epsilon * (np.abs(x0) + 1), (x0.shape[0], 1))
 84 |         x_max = np.reshape(x0 + epsilon * (np.abs(x0) + 1), (x0.shape[0], 1))
 85 |         bounds = np.concatenate([x_min, x_max], 1)
 86 |         func = lambda x: get_loss_and_grads(x, model, data_loader)
 87 |         init_guess = x0
 88 |     else:
 89 |         assert subspace_dim <= x0.shape[0]
 90 | 
 91 |         # Computed via Keskar, et. al
 92 |         # https://arxiv.org/pdf/1609.04836.pdf
 93 | 
 94 |         A_plus = np.random.rand(subspace_dim, x0.shape[0]) * 2.0 - 1.0
 95 |         A_plus_norm = np.linalg.norm(A_plus, axis=1)
 96 |         A_plus = A_plus / np.reshape(A_plus_norm, (subspace_dim, 1))
 97 |         A = np.linalg.pinv(A_plus)
 98 | 
 99 |         abs_bound = epsilon * (np.abs(np.dot(A_plus, x0)) + 1)
100 |         abs_bound = np.reshape(abs_bound, (abs_bound.shape[0], 1))
101 |         bounds = np.concatenate([-abs_bound, abs_bound], 1)
102 | 
103 |         def func(y):
104 |             f_loss, f_grads = get_loss_and_grads(
105 |                 x0 + np.dot(A, y),
106 |                 model,
107 |                 data_loader,
108 |             )
109 |             return f_loss, np.dot(np.transpose(A), f_grads)
110 | 
111 |         init_guess = np.zeros(subspace_dim)
112 | 
113 |     minimum_x, f_x, d = scipy.optimize.fmin_l_bfgs_b(
114 |         func,
115 |         init_guess,
116 |         maxiter=maxiter,
117 |         bounds=bounds,
118 |         disp=1,
119 |     )
120 |     f_x = -f_x
121 |     logging.info("max loss f_x = {loss:.4f}".format(loss=f_x))
122 | 
123 |     # Eq 4 in Keskar
124 |     phi = (f_x - f_x0) / (1 + f_x0) * 100
125 | 
126 |     # Restore parameter values
127 |     x0 = torch.from_numpy(x0).float()
128 |     # x0 = x0.cuda()
129 |     x_start = 0
130 |     for p in model.parameters():
131 |         param_size = p.data.size()
132 |         param_idx = 1
133 |         for s in param_size:
134 |             param_idx *= s
135 |         x_part = x0[x_start : x_start + param_idx]
136 |         p.data = x_part.view(param_size)
137 |         x_start += param_idx
138 | 
139 |     return phi
140 | 


--------------------------------------------------------------------------------
/grok/metrics.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import math
  3 | import copy
  4 | import torch.nn as nn
  5 | from typing import Callable
  6 | 
  7 | # References:
  8 | # https://github.com/nitarshan/robust-generalization-measures
  9 | # https://github.com/bneyshabur/generalization-bounds
 10 | # https://github.com/bneyshabur/over-parametrization
 11 | 
 12 | 
 13 | def compute_measure(
 14 |     model: nn.Module,
 15 |     init_model: nn.Module,
 16 |     measure_func: Callable,
 17 |     operator: str,
 18 |     kwargs: dict = {},
 19 |     p: int = 1,
 20 | ) -> float:
 21 |     """
 22 |     Computes measure value for each layer given trained network and network at
 23 |     initialization.  Then aggregates values per layer using specified operator.
 24 | 
 25 |     :param model: trained network
 26 |     :param init_model: network at initialization
 27 |     :param measure_func: callable for the measure to compute
 28 |     :param operator: 'log_product', 'sum', 'max', 'product', or 'norm'
 29 |     :param p: p in L^p
 30 |     :return: value of the desired measure
 31 |     """
 32 | 
 33 |     measure_value = 0
 34 |     # weight_modules = ["Linear", "Embedding"]
 35 |     weight_modules = ["Linear"]
 36 | 
 37 |     if operator == "product":
 38 |         measure_value = math.exp(
 39 |             compute_measure(model, init_model, measure_func, "log_product", kwargs, p)
 40 |         )
 41 |     elif operator == "norm":
 42 |         measure_value = (
 43 |             compute_measure(model, init_model, measure_func, "sum", kwargs, p=p)
 44 |         ) ** (1 / p)
 45 |     else:
 46 |         measure_value = 0
 47 |         for child, init_child in zip(model.children(), init_model.children()):
 48 |             module_name = child._get_name()
 49 |             if module_name in weight_modules:
 50 |                 if operator == "log_product":
 51 |                     measure_value += math.log(measure_func(child, init_child, **kwargs))
 52 |                 elif operator == "sum":
 53 |                     measure_value += (measure_func(child, init_child, **kwargs)) ** p
 54 |                 elif operator == "max":
 55 |                     measure_value = max(
 56 |                         measure_value, measure_func(child, init_child, **kwargs)
 57 |                     )
 58 |             else:
 59 |                 measure_value += compute_measure(
 60 |                     child, init_child, measure_func, operator, kwargs, p=p
 61 |                 )
 62 |     return measure_value
 63 | 
 64 | 
 65 | def norm(module, init_module, p=2, q=2):
 66 |     """
 67 |     Calculates l_pq norm of a parameter matrix
 68 |       l_p norm of incoming weights to each hidden unit
 69 |       l_q norm on the hidden units
 70 |     """
 71 |     return module.weight.view(module.weight.size(0), -1).norm(p=p, dim=1).norm(q).item()
 72 | 
 73 | 
 74 | def op_norm(module, init_module, p=float("Inf")):
 75 |     """
 76 |     Calculates l_p norm of eigenvalues of parameter matrix
 77 |     """
 78 |     _, S, _ = module.weight.view(module.weight.size(0), -1).svd()
 79 |     return S.norm(p).item()
 80 | 
 81 | 
 82 | def dist(module, init_module, p=2, q=2):
 83 |     """
 84 |     Calculates l_pq distance of the parameter matrix of a layer from the random
 85 |     initialization:
 86 |         l_p norm of incoming weights to each hidden unit
 87 |         l_q norm on the hidden units
 88 |     """
 89 |     return (
 90 |         (module.weight - init_module.weight)
 91 |         .view(module.weight.size(0), -1)
 92 |         .norm(p=p, dim=1)
 93 |         .norm(q)
 94 |         .item()
 95 |     )
 96 | 
 97 | 
 98 | def h_dist(module, init_module, p=2, q=2):
 99 |     """
100 |     Calculate l_pq distance of parameters of trained network from random init
101 |     Includes extra factor depending on number of hidden units
102 |     """
103 |     return (n_hidden(module, init_module) ** (1 - 1 / q)) * dist(
104 |         module, init_module, p=p, q=q
105 |     )
106 | 
107 | 
108 | def h_dist_op_norm(module, init_module, p=2, q=2, p_op=float("Inf")):
109 |     """
110 |     Calculate ratio of h_dist to operator norm
111 |     """
112 |     return h_dist(module, init_module, p=p, q=q) / op_norm(module, init_module, p=p_op)
113 | 
114 | 
115 | def n_hidden(module, init_module):
116 |     """
117 |     Number of hidden units
118 |     """
119 |     return module.weight.size(0)
120 | 
121 | 
122 | def depth(module, init_module):
123 |     """
124 |     Depth (always == 1 for any linear layer)
125 |     """
126 |     return 1
127 | 
128 | 
129 | def n_param(module, init_module):
130 |     """
131 |     Num parameters
132 |     """
133 |     bparam = 0 if module.bias is None else module.bias.size(0)
134 |     return bparam + module.weight.size(0) * module.weight.view(
135 |         module.weight.size(0), -1
136 |     ).size(1)
137 | 
138 | 
def lp_path_norm(model, device, p=2, input_size=(3, 32, 32)):
    """
    Path norm (Neyshabur 2015).

    Replaces every trainable parameter of a deep copy of ``model`` by
    ``|w|^p``, feeds an all-ones tensor through it and returns
    ``(sum of outputs) ** (1/p)``. ``model`` itself is not modified.

    :param model: network to measure
    :param device: device on which the all-ones input is created
    :param p: exponent of the path norm
    :param input_size: full shape of the all-ones tensor passed to the model
    :return: the L_p path norm as a Python float
    """
    tmp_model = copy.deepcopy(model)
    tmp_model.eval()
    # In-place edits of leaf parameters that require grad are only legal with
    # autograd disabled; the forward pass does not need a graph either.
    with torch.no_grad():
        for param in tmp_model.parameters():
            if param.requires_grad:
                param.abs_().pow_(p)
        data_ones = torch.ones(input_size, device=device)
        return (tmp_model(data_ones).sum() ** (1 / p)).item()
151 | 
152 | 
def calculate(trained_model, init_model, device, dataset_size, margin, input_dim):
    """
    Compute norm-based measures and generalization bounds for a trained model.

    :param trained_model: trained network (deep-copied; not modified here)
    :param init_model: the same network at initialization
    :param device: not used by the active measures (only the disabled
                   path-norm measures would need it)
    :param dataset_size: number of training samples m
    :param margin: value every norm-based measure is divided by
    :param input_dim: input dimension used in the Golowich et al. term
    :returns: ``(measure, bound)`` -- two dicts keyed by readable names;
              bounds omit constants and additive logarithmic factors
    """
    model = copy.deepcopy(trained_model)

    # depth: one per measured layer
    d = compute_measure(model, init_model, depth, "sum", {})
    # number of parameters (not including batch norm)
    nparam = compute_measure(model, init_model, n_param, "sum", {})

    # (name, per-layer function, kwargs) for the margin-normalized products
    product_measures = (
        ("L_{1,inf} norm", norm, {"p": 1, "q": float("Inf")}),
        ("Frobenius norm", norm, {"p": 2, "q": 2}),
        ("L_{3,1.5} norm", norm, {"p": 3, "q": 1.5}),
        ("Spectral norm", op_norm, {"p": float("Inf")}),
        ("L_1.5 operator norm", op_norm, {"p": 1.5}),
        ("Trace norm", op_norm, {"p": 1}),
    )

    measure, bound = {}, {}
    with torch.no_grad():
        for label, layer_fn, layer_kwargs in product_measures:
            measure[label] = (
                compute_measure(model, init_model, layer_fn, "product", layer_kwargs)
                / margin
            )

        # L1 / L1.5 / L2 path-norm measures (lp_path_norm) are currently disabled.

        # Generalization bounds, without constants or additive log factors.
        root_m = math.sqrt(dataset_size)

        # Golowich 2018 -- https://arxiv.org/pdf/1712.06541.pdf
        # NOTE(review): the leading 1 presumably stands for the number of
        # input channels in the reference implementation.
        alpha = math.sqrt(d + math.log(1 * input_dim * input_dim))

        # Bartlett Mendelson 2002
        bound["L1_max Bound"] = alpha * measure["L_{1,inf} norm"] / root_m
        # Neyshabur 2015
        bound["Frobenius Bound"] = alpha * measure["Frobenius norm"] / root_m
        # Neyshabur 2015
        bound["L_{3,1.5} Bound"] = (
            alpha * measure["L_{3,1.5} norm"] / (dataset_size ** (1 / 3))
        )

        beta = math.log(dataset_size) * math.log(nparam)
        ratio_l21 = compute_measure(
            model,
            init_model,
            h_dist_op_norm,
            "norm",
            {"p": 2, "q": 1, "p_op": float("Inf")},
            p=2 / 3,
        )
        # Spectral L_{2,1} bound -- Bartlett 2017
        bound["Spec_L_{2,1} Bound"] = (
            beta * measure["Spectral norm"] * ratio_l21 / root_m
        )

        ratio_fro = compute_measure(
            model,
            init_model,
            h_dist_op_norm,
            "norm",
            {"p": 2, "q": 2, "p_op": float("Inf")},
            p=2,
        )
        # Spectral Frobenius bound -- Neyshabur 2018
        # https://arxiv.org/pdf/1706.08947.pdf
        bound["Spec_Fro Bound"] = d * measure["Spectral norm"] * ratio_fro / root_m

    return measure, bound
272 | 


--------------------------------------------------------------------------------
/grok/training.py:
--------------------------------------------------------------------------------
   1 | #!/usr/bin/env python
   2 | 
   3 | import argparse
   4 | import copy
   5 | import json
   6 | import logging
   7 | import math
   8 | import os
   9 | import sys
  10 | import pickle
  11 | from argparse import ArgumentParser, Namespace
  12 | from functools import reduce
  13 | from typing import Any, Dict, List, Optional, Tuple, Union
  14 | import time
  15 | 
  16 | import numpy as np
  17 | import torch
  18 | import torch.nn.functional as F
  19 | from pytorch_lightning import LightningModule, Trainer
  20 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint
  21 | from pytorch_lightning.loggers import CSVLogger
  22 | from torch import Tensor
  23 | from torch.optim.lr_scheduler import LambdaLR
  24 | 
  25 | import grok.metrics as metrics
  26 | from grok.data import (
  27 |     DEFAULT_DATA_DIR,
  28 |     EOS_TOKEN,
  29 |     VALID_OPERATORS,
  30 |     ArithmeticDataset,
  31 |     ArithmeticIterator,
  32 | )
  33 | from grok.transformer import Transformer
  34 | from grok.measure import get_sharpness
  35 | 
# Default value of the --logdir command-line argument (see add_model_specific_args).
DEFAULT_LOG_DIR = "logs"
  37 | 
  38 | 
  39 | class TrainableTransformer(LightningModule):
  40 |     """
  41 |     Adds training methods to train a generic transformer on arithmetic equations
  42 |     """
  43 | 
  44 |     def __init__(self, hparams: Namespace) -> None:
  45 |         """
  46 |         :param hparams: An argparse.Namespace with parameters defined in
  47 |                         self.add_model_specific_args().
  48 |         """
  49 |         super().__init__()
  50 |         self.hparams = hparams  # type: ignore
  51 |         self.prepare_data()
  52 | 
  53 |         self.transformer = Transformer(
  54 |             hparams.n_layers,
  55 |             hparams.n_heads,
  56 |             hparams.d_model,
  57 |             hparams.dropout,
  58 |             hparams.max_context_len,
  59 |             len(self.train_dataset.tokenizer),
  60 |             hparams.non_linearity,
  61 |             weight_noise=self.hparams.weight_noise,
  62 |         )
  63 | 
  64 |         self.margin = torch.Tensor([0])
  65 |         self.next_epoch_to_eval = -1
  66 |         self.next_train_epoch_to_log = 0
  67 | 
  68 |     @staticmethod
  69 |     def add_model_specific_args(parser: ArgumentParser) -> ArgumentParser:
  70 |         """
  71 |         Defines the hyperparameter arguments needed by instances of this
  72 |         class. This is intended to be called when parsing command line
  73 |         arguments.
  74 | 
  75 |         :param parser: an argparse.ArgumentParser created by the caller
  76 |         :returns: the argument parser with the command line arguments added
  77 |                   for this class.
  78 |         """
  79 |         parser.add_argument(
  80 |             "--batchsize",
  81 |             type=float,
  82 |             # default=0.25,
  83 |             default=0,
  84 |             help="-1 -> entire dataset, 0 -> auto-calculate, 0<N<1 -> fraction of dataset, N>1 -> N",
  85 |         )
  86 | 
  87 |         parser.add_argument("--n_layers", type=int, default=2)
  88 |         parser.add_argument("--n_heads", type=int, default=4)
  89 |         parser.add_argument("--d_model", type=int, default=128)
  90 |         parser.add_argument("--dropout", type=float, default=0.0)
  91 |         parser.add_argument("--weight_noise", type=float, default=0.0)
  92 |         parser.add_argument("--non_linearity", type=str, default="relu")
  93 |         parser.add_argument("--max_context_len", type=int, default=50)
  94 | 
  95 |         parser.add_argument("--math_operator", type=str, default="+")
  96 |         parser.add_argument(
  97 |             "--operand_length",
  98 |             type=int,
  99 |             help="for list operations, the length of the lists",
 100 |         )
 101 | 
 102 |         parser.add_argument("--train_data_pct", type=float, default=5)
 103 |         parser.add_argument("--warmup_steps", type=int, default=10)
 104 |         parser.add_argument("--anneal_lr_steps", type=int, default=100000)
 105 |         parser.add_argument("--anneal_lr", dest="anneal_lr", action="store_true")
 106 |         parser.set_defaults(anneal_lr=False)
 107 | 
 108 |         parser.add_argument("--max_lr", type=float, default=1e-3)
 109 |         parser.add_argument("--weight_decay", type=float, default=0)
 110 |         parser.add_argument("--weight_decay_kind", type=str, default="to_zero")
 111 |         parser.add_argument("--noise_factor", type=float, default=0)
 112 | 
 113 |         parser.add_argument(
 114 |             "--save_activations", dest="save_activations", action="store_true"
 115 |         )
 116 |         parser.set_defaults(save_activations=False)
 117 |         parser.add_argument("--save_outputs", dest="save_outputs", action="store_true")
 118 |         parser.set_defaults(save_outputs=False)
 119 | 
 120 |         parser.add_argument(
 121 |             "--logdir",
 122 |             type=str,
 123 |             default=DEFAULT_LOG_DIR,
 124 |         )
 125 |         parser.add_argument(
 126 |             "--datadir",
 127 |             type=str,
 128 |             default=DEFAULT_DATA_DIR,
 129 |         )
 130 | 
 131 |         return parser
 132 | 
 133 |     def prepare_data(self) -> None:
 134 |         """
 135 |         Used by pytorch_lighting
 136 | 
 137 |         Loads training data to self.train_dataset
 138 |         Loads validation data to self.val_dataset
 139 |         """
 140 |         (self.train_dataset, self.val_dataset,) = ArithmeticDataset.splits(
 141 |             train_pct=self.hparams.train_data_pct,  # type: ignore
 142 |             operator=self.hparams.math_operator,  # type: ignore
 143 |             operand_length=self.hparams.operand_length,  # type: ignore
 144 |             data_dir=self.hparams.datadir,  # type: ignore
 145 |         )
 146 | 
 147 |     def train_dataloader(self) -> ArithmeticIterator:  # type: ignore
 148 |         """
 149 |         Used by pytorch_lighting
 150 | 
 151 |         :returns: an iterator for self.train_dataset
 152 |         """
 153 |         device = self.transformer.embedding.weight.device
 154 |         iterator = ArithmeticIterator(
 155 |             self.train_dataset,
 156 |             device,
 157 |             batchsize_hint=self.hparams.batchsize,  # type: ignore
 158 |         )
 159 |         self.train_batchsize = iterator.batchsize
 160 |         self.batches_per_epoch = len(iterator)
 161 | 
 162 |         return iterator
 163 | 
 164 |     def val_dataloader(self) -> ArithmeticIterator:  # type: ignore
 165 |         """
 166 |         Used by pytorch_lighting
 167 | 
 168 |         :returns: an iterator for self.train_dataset
 169 |         """
 170 |         device = self.transformer.embedding.weight.device
 171 |         iterator = ArithmeticIterator(
 172 |             self.val_dataset,
 173 |             device,
 174 |             batchsize_hint=-1,  # no need to batch validation data
 175 |         )
 176 |         return iterator
 177 | 
 178 |     def test_dataloader(self) -> ArithmeticIterator:  # type: ignore
 179 |         """
 180 |         Used by pytorch_lighting
 181 | 
 182 |         :returns: an iterator for self.train_dataset
 183 |         """
 184 |         device = self.transformer.embedding.weight.device
 185 |         iterator = ArithmeticIterator(
 186 |             self.val_dataset, device, batchsize_hint=-1  # type: ignore
 187 |         )
 188 |         return iterator
 189 | 
 190 |     def _scheduler_lr(self, step: int) -> float:
 191 |         """
 192 |         Used by pytorch_lighting
 193 | 
 194 |         :returns: the learning_rate for this training step
 195 |         """
 196 |         max_lr = self.hparams.max_lr  # type: ignore
 197 |         min_lr = self.hparams.max_lr / 10  # type: ignore
 198 |         warmup_steps = self.hparams.warmup_steps  # type: ignore
 199 |         if not self.hparams.anneal_lr:
 200 |             if step <= warmup_steps:
 201 |                 lr = (float(step) / max(warmup_steps, 1)) * max_lr
 202 |             else:
 203 |                 lr = max_lr
 204 |         else:
 205 |             if step <= warmup_steps:
 206 |                 lr = (float(step) / max(warmup_steps, 1)) * max_lr
 207 |             elif step <= self.hparams.anneal_lr_steps + warmup_steps:
 208 |                 effective_step = step - warmup_steps
 209 |                 t = effective_step / self.hparams.anneal_lr_steps
 210 |                 cos = (1 + np.cos(np.pi * t)) / 2
 211 |                 lr = min_lr + (max_lr - min_lr) * cos
 212 |                 # lr = max_lr - ((effective_step / max_effective_step) * (max_lr - min_lr))
 213 |             else:
 214 |                 lr = min_lr
 215 |         return lr
 216 | 
 217 |     def configure_optimizers(self) -> Tuple[List[Any], List[Dict]]:
 218 |         """
 219 |         Used by pytorch_lighting
 220 | 
 221 |         :returns: optimizers and schedulers.
 222 |         """
 223 |         optimizer = CustomAdamW(
 224 |             self.parameters(),
 225 |             betas=(0.9, 0.98),
 226 |             eps=1e-8,
 227 |             lr=1,
 228 |             weight_decay=self.hparams.weight_decay,
 229 |             noise_factor=self.hparams.noise_factor,
 230 |             weight_decay_form=self.hparams.weight_decay_kind,
 231 |         )
 232 |         # optimizer = SAM(
 233 |         #     self.parameters(),
 234 |         #     base_optimizer=CustomAdamW,
 235 |         #     rho=0.05,
 236 |         #     betas=(0.9, 0.98),
 237 |         #     eps=1e-8,
 238 |         #     lr=1,
 239 |         #     weight_decay=self.hparams.weight_decay,
 240 |         #     noise_factor=self.hparams.noise_factor,
 241 |         # )
 242 |         schedulers = [
 243 |             {
 244 |                 "scheduler": LambdaLR(optimizer, lr_lambda=self._scheduler_lr),
 245 |                 "interval": "step",
 246 |                 "frequency": 1,
 247 |             }
 248 |         ]
 249 |         return [optimizer], schedulers
 250 | 
 251 |     def _accuracy(self, y_hat: Tensor, y: Tensor) -> Tensor:
 252 |         """
 253 |         Takes the most likely solution predicted for each equation and
 254 |         calculates the frac of equations in the batch for which these
 255 |         answers were correct
 256 | 
 257 |         :param y_hat: The softmax tensor output of the transformer
 258 |         :param y: A tensor of the token ids for the correct answers to each
 259 |                   equation in the batch
 260 |         :returns: the fraction of equations correctly answered
 261 |         """
 262 | 
 263 |         # find max prediction from output
 264 |         y_hat = torch.max(y_hat, dim=-2).indices  # batchsize x num_rhs_tokens
 265 |         row_accuracy = torch.min((y_hat == y), dim=-1).values  # shape: batchsize
 266 |         accuracy = row_accuracy.float() * 100  # shape: batchsize
 267 |         return accuracy
 268 | 
 269 |     def _step(
 270 |         self,
 271 |         batch: Dict,
 272 |         batch_idx: int,
 273 |         train: bool = True,
 274 |         reduction: str = "mean",
 275 |         grads: bool = False,
 276 |     ) -> Tuple[Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor]:
 277 |         """
 278 |         Performs one forward pass on a training or validation batch
 279 | 
 280 |         :param batch: The batch of equations to process
 281 |         :param batch_idx: which batch this is in the epoch.
 282 |         :param train: True is this is a training batch, false otherwise
 283 |         :returns: The loss from the predicted solutions to the equation,
 284 |                   The accuracy of the predicted solutions
 285 |                   The fraction of this dataset contained in this batch
 286 |                   The portion of the input equations left of the equal sign
 287 |                   The softmax probilities for the solutions to the equations
 288 |                   A list lists of attention matrices by layer and head
 289 |                   A list lists of value matrices by layer and head
 290 |                   Margin for this batch
 291 |         """
 292 |         x = batch["text"]  # shape = batchsize * context_len
 293 |         y = batch["target"]  # shape = batchsize * context_len
 294 |         y_hat, attentions, values = self(
 295 |             x=x, save_activations=self.hparams.save_activations  # type: ignore
 296 |         )  # shape = batchsize * context_len * vocab_size
 297 |         y_hat = y_hat.transpose(-2, -1)  # shape = batchsize * vocab_size * context_len
 298 | 
 299 |         # Note: each sample must have exactly one '=' and all of them must
 300 |         # have it in the same position.
 301 |         eq_token_index = self.train_dataset.tokenizer.stoi["="]
 302 |         eq_position_t = torch.nonzero(y[0, :] == eq_token_index, as_tuple=False)
 303 |         eq_position = int(eq_position_t.squeeze())
 304 | 
 305 |         # only calculate loss/accuracy on right hand side of the equation
 306 |         y_rhs = y[..., eq_position + 1 :]
 307 |         y_hat_rhs = y_hat[..., eq_position + 1 :]
 308 |         x_lhs = x[..., : eq_position + 1]
 309 | 
 310 |         if train:
 311 |             coeff = float(batch["target"].shape[0]) / len(self.train_dataset)
 312 |         else:
 313 |             coeff = float(batch["target"].shape[0]) / len(self.val_dataset)
 314 |         loss = F.cross_entropy(y_hat_rhs, y_rhs, reduction=reduction)
 315 | 
 316 |         with torch.no_grad():
 317 |             acc = self._accuracy(y_hat_rhs, y_rhs)
 318 |             if reduction == "mean":
 319 |                 acc = acc.mean()
 320 | 
 321 |         """
 322 |         device = self.transformer.embedding.weight.device
 323 |         self.margin = self.margin.to(device)
 324 | 
 325 |         output = y_hat_rhs.clone()  # batchsize, vocabsize, rhs tokens
 326 |         output_m = output.clone()  # batchsize, vocabsize, rhs tokens
 327 |         target = y_rhs.clone()  # batchsize, rhs tokens
 328 | 
 329 |         for i in range(output.size(0)):  # batch
 330 |             for j in range(output.size(2)):  # rhs tokens
 331 |                 output_m[i, target[i, j], j] = output_m[i, :, j].min()
 332 | 
 333 |         for i in range(output.size(2)):  # rhs tokens
 334 |             output_compressed = output[:, target[:, i], i].squeeze().diag()
 335 |             output_m_compressed = (
 336 |                 output_m[:, output_m.max(dim=1).indices[:, i], i].squeeze().diag()
 337 |             )
 338 |             self.margin = torch.cat(
 339 |                 (
 340 |                     self.margin,
 341 |                     (output_compressed - output_m_compressed),
 342 |                 ),
 343 |                 0,
 344 |             )
 345 |         """
 346 |         grad_vec = None
 347 |         if grads:
 348 |             loss.backward()
 349 |             for p in self.parameters():
 350 |                 p.grad.data.div_(batch["text"].shape[0])
 351 |                 if grad_vec is None:
 352 |                     grad_vec = p.grad.data.view(-1)
 353 |                 else:
 354 |                     grad_vec = torch.cat((grad_vec, p.grad.data.view(-1)))
 355 |             return loss, grad_vec
 356 |         return loss, acc, coeff, x_lhs, y_hat_rhs, attentions, values
 357 | 
 358 | 
 359 |     def _save_inputs(self, outputs: Dict, ds: str) -> None:
 360 |         """
 361 |         Saves the input equations to disk for analysis later
 362 | 
 363 |         :param outputs: a list of tuples from self.training_step()
 364 |         :param ds: a string ('train' or 'val') naming which dataset
 365 |                    these inputs are from.
 366 |         :param train: True is this is a training batch, false otherwise
 367 |         """
 368 |         logdir = self.hparams.logdir + "/inputs/" + ds  # type: ignore
 369 |         os.makedirs(logdir, exist_ok=True)
 370 |         pickle_file = logdir + f"/{ds}.pt"
 371 | 
 372 |         x_lhs = torch.cat([x["x_lhs"] for x in outputs])
 373 |         with open(pickle_file, "wb") as fh:
 374 |             torch.save(x_lhs, fh)
 375 | 
 376 |     def _merge_batch_activations(
 377 |         self, partial_activations: List[List[Tensor]]
 378 |     ) -> List[List[Tensor]]:
 379 |         """
 380 |         Merges the head_attentions / head_values from all batches in
 381 |         this epoch.
 382 | 
 383 |         :param partial_activations: A list of
 384 |                                    (lists of lists of activations by layer and head)
 385 |         :returns: A lists of lists of activations by layer and head
 386 |         """
 387 |         # num_batches = len(partial_activations)
 388 |         num_layers = len(partial_activations[0])
 389 |         num_heads = len(partial_activations[0][0])
 390 |         activations: List = []
 391 |         for _ in range(num_layers):
 392 |             activations.append([])
 393 |             for _ in range(num_heads):
 394 |                 activations[-1].append([])
 395 | 
 396 |         for minibatch_activations in partial_activations:
 397 |             for l, layer_activations in enumerate(minibatch_activations):
 398 |                 for h, head_attn in enumerate(layer_activations):
 399 |                     # # print(f"head_attn = {head_attn}")
 400 |                     activations[l][h].append(head_attn)
 401 | 
 402 |         for l in range(num_layers):
 403 |             for h in range(num_heads):
 404 |                 activations[l][h] = torch.cat(activations[l][h])
 405 | 
 406 |         return activations
 407 | 
 408 |     def _save_activations(self, outputs: Dict, ds: str) -> None:
 409 |         """
 410 |         Saves activations out to disk for analysis later
 411 | 
 412 |         :param outputs: a list of tuples from self.training_step()
 413 |         """
 414 | 
 415 |         output: Dict[str, Any] = {}
 416 |         if self.hparams.save_outputs:  # type: ignore
 417 |             y_hat_rhs = torch.cat([x["y_hat_rhs"] for x in outputs])
 418 |             output["y_hat_rhs"] = y_hat_rhs
 419 |         if self.hparams.save_activations:  # type: ignore
 420 |             partial_attentions = list([o["partial_attentions"] for o in outputs])
 421 |             attentions = self._merge_batch_activations(partial_attentions)
 422 |             partial_values = list([o["partial_values"] for o in outputs])
 423 |             values = self._merge_batch_activations(partial_values)
 424 |             output["attentions"] = attentions
 425 |             output["values"] = values
 426 |         if self.hparams.save_outputs or self.hparams.save_activations:  # type: ignore
 427 |             logdir = self.hparams.logdir + "/outputs/" + ds  # type: ignore
 428 |             os.makedirs(logdir, exist_ok=True)
 429 |             pickle_file = logdir + f"/epoch_{self.current_epoch:010}.pt"
 430 |             with open(pickle_file, "wb") as fh:
 431 |                 torch.save(output, fh)
 432 | 
 433 |     def training_step(self, batch, batch_idx):
 434 |         """
 435 |         Used by pytorch_lightning
 436 |         Runs one forward training pass on one batch
 437 | 
 438 |         :param batch: The batch of equations to process
 439 |         :param batch_idx: which batch this is in the epoch.
 440 |         :returns: a dict with loss, accuracy, lr, probabilities of solutions,
 441 |                   attentions, and values
 442 |         """
 443 |         if batch_idx == 0:
 444 |             self.training_epoch_start_time = time.time()
 445 |             self.fwd_time_in_epoch = 0
 446 | 
 447 |         start = time.time()
 448 |         loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step(
 449 |             batch=batch, batch_idx=batch_idx, train=True
 450 |         )
 451 |         self.fwd_time_in_epoch += time.time() - start
 452 | 
 453 |         schedulers = self.trainer.lr_schedulers[0]
 454 |         if self.current_epoch != self.next_train_epoch_to_log:
 455 |             return {"loss": loss}
 456 |         lr = schedulers["scheduler"].optimizer.param_groups[0]["lr"]
 457 |         output = {
 458 |             "loss": loss,
 459 |             "partial_train_loss": coeff * loss,
 460 |             "partial_train_accuracy": coeff * accuracy,
 461 |             "learning_rate": torch.tensor([lr]),
 462 |             "y_hat_rhs": y_hat_rhs,
 463 |             "partial_attentions": attentions,
 464 |             "partial_values": values,
 465 |         }
 466 |         if self.current_epoch == 0:
 467 |             output["x_lhs"] = x_lhs
 468 | 
 469 |         return output
 470 | 
 471 |     def training_epoch_end(self, outputs):
 472 |         """
 473 |         Used by pytorch_lightning
 474 |         Accumulates results of all forward training passes in this epoch
 475 | 
 476 |         :param outputs: a list of dicts from self.training_step()
 477 |         :param batch_idx: which batch this is in the epoch.
 478 |         :returns: a dict with loss, accuracy, lr, probabilities of solutions,
 479 |                   attentions, and values
 480 |         """
 481 |         epoch_is_to_be_logged = self.current_epoch == self.next_train_epoch_to_log
 482 |         if epoch_is_to_be_logged:
 483 |             self.next_train_epoch_to_log = max(
 484 |                 int(1.01 * self.next_train_epoch_to_log),
 485 |                 self.next_train_epoch_to_log + 1,
 486 |             )
 487 |             with torch.no_grad():
 488 |                 try:
 489 |                     loss = torch.stack([x["partial_train_loss"] for x in outputs]).sum()
 490 |                 except Exception as e:
 491 |                     print("!" * 80)
 492 |                     print(outputs)
 493 |                     raise e
 494 |                 perplexity = torch.exp(loss)
 495 |                 accuracy = torch.stack(
 496 |                     [x["partial_train_accuracy"] for x in outputs]
 497 |                 ).sum()
 498 |             # avg_lr = torch.stack([x["learning_rate"] for x in outputs]).mean()
 499 |             # max_lr = torch.stack([x["learning_rate"] for x in outputs]).max()
 500 |             # last_lr = outputs[-1]["learning_rate"]
 501 |             first_lr = outputs[0]["learning_rate"]
 502 | 
 503 |             if self.hparams.save_activations or self.hparams.save_outputs:
 504 |                 if self.current_epoch == 0:
 505 |                     self._save_inputs(outputs, ds="train")
 506 |                 self._save_activations(outputs, ds="train")
 507 | 
 508 |             logs = {
 509 |                 "train_loss": loss,
 510 |                 "train_accuracy": accuracy,
 511 |                 "train_perplexity": perplexity,
 512 |                 "learning_rate": first_lr,
 513 |                 "len_train_ds": len(self.train_dataset),
 514 |                 "len_val_ds": len(self.val_dataset),
 515 |                 "batches_per_epoch": self.batches_per_epoch,
 516 |                 "time_per_epoch": time.time() - self.training_epoch_start_time,
 517 |                 "fwd_time_in_epoch": self.fwd_time_in_epoch,
 518 |             }
 519 |             for k, v in logs.items():
 520 |                 self.log(k, v)
 521 | 
 522 |     def validation_step(self, batch, batch_idx):
 523 |         """
 524 |         Used by pytorch_lightning
 525 |         Runs one forward validation pass on one batch
 526 | 
 527 |         :param batch: The batch of equations to process
 528 |         :param batch_idx: which batch this is in the epoch.
 529 |         :returns: a dict with val_loss, val_accuracy, probabilities of solutions,
 530 |                   attentions, and values
 531 |         """
 532 |         if self.next_epoch_to_eval < self.current_epoch:
 533 |             self.next_epoch_to_eval = self.current_epoch
 534 |         if self.current_epoch != self.next_epoch_to_eval:
 535 |             return {}
 536 |         with torch.no_grad():
 537 |             loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step(
 538 |                 batch=batch, batch_idx=batch_idx, train=False
 539 |             )
 540 |         output = {
 541 |             "partial_val_loss": coeff * loss,
 542 |             "partial_val_accuracy": coeff * accuracy,
 543 |             "y_hat_rhs": y_hat_rhs,
 544 |             "partial_attentions": attentions,
 545 |             "partial_values": values,
 546 |         }
 547 |         if self.current_epoch == 0:
 548 |             output["x_lhs"] = x_lhs
 549 | 
 550 |         return output
 551 | 
 552 |     def validation_epoch_end(self, outputs):
 553 |         """
 554 |         Used by pytorch_lightning
 555 |         Accumulates results of all forward validation passes in this epoch
 556 | 
 557 |         :param outputs: a list of dicts from self.validation_step()
 558 |         :param batch_idx: which batch this is in the epoch.
 559 |         :returns: a dict with val_loss, val_accuracy
 560 |         """
 561 |         validation_is_real = len(outputs[0]) != 0
 562 | 
 563 |         if validation_is_real:
 564 |             self.next_epoch_to_eval = max(
 565 |                 int(1.02 * self.next_epoch_to_eval), self.next_epoch_to_eval + 1
 566 |             )
 567 | 
 568 |             loss = torch.stack([x["partial_val_loss"] for x in outputs]).sum()
 569 |             perplexity = torch.exp(loss)
 570 |             accuracy = torch.stack([x["partial_val_accuracy"] for x in outputs]).sum()
 571 | 
 572 |             if self.hparams.save_activations or self.hparams.save_outputs:
 573 |                 if self.current_epoch == 0:
 574 |                     self._save_inputs(outputs, ds="val")
 575 |                 self._save_activations(outputs, ds="val")
 576 | 
 577 |             logs = {
 578 |                 "val_loss": loss,
 579 |                 "val_accuracy": accuracy,
 580 |                 "val_perplexity": perplexity,
 581 |             }
 582 |             for name, param in self.named_parameters():
 583 |                 # n parameters
 584 |                 n_params = param.numel()
 585 |                 # get the l2 norm of the parameter
 586 |                 logs["paramnorm_" + name] = torch.norm(
 587 |                     param, 2
 588 |                 ).detach().cpu().numpy() / np.sqrt(n_params)
 589 | 
 590 |             # train accuracy
 591 |             device = self.transformer.embedding.weight.device
 592 |             train_data = self.train_dataset.data.to(device)
 593 |             training_data = {"text": train_data[:, :-1], "target": train_data[:, 1:]}
 594 |             with torch.no_grad():
 595 |                 tr_loss, tr_acc, *_ = self._step(training_data, 0)
 596 |                 logs["full_train_loss"] = tr_loss
 597 |                 logs["full_train_acc"] = tr_acc
 598 | 
 599 |             for k, v in logs.items():
 600 |                 self.log(k, v)
 601 |         # save a checkpoint if the epoch is a power of 2
 602 |         if (
 603 |             self.current_epoch > 0
 604 |             and int(2 ** (int(np.log(self.current_epoch) / np.log(2))))
 605 |             == self.current_epoch
 606 |         ):
 607 |             self.trainer.save_checkpoint(
 608 |                 os.path.join(
 609 |                     self.hparams.checkpoint_path,
 610 |                     "epoch_" + str(self.current_epoch) + ".ckpt",
 611 |                 )
 612 |             )
 613 |         if validation_is_real:
 614 |             return logs
 615 | 
 616 |     def test_step(self, batch, batch_idx):
 617 |         """
 618 |         Used by pytorch_lightning
 619 |         Runs one forward validation pass on one batch
 620 | 
 621 |         :param batch: The batch of equations to process
 622 |         :param batch_idx: which batch this is in the epoch.
 623 |         :returns: a dict with val_loss, val_accuracy, probabilities of solutions,
 624 |                   attentions, and values
 625 |         """
 626 | 
 627 |         loss, accuracy, coeff, x_lhs, y_hat_rhs, attentions, values = self._step(
 628 |             batch=batch, batch_idx=batch_idx, train=False, reduction="none"
 629 |         )
 630 |         output = {
 631 |             "partial_test_loss": coeff * loss,
 632 |             "partial_test_accuracy": coeff * accuracy,
 633 |             "y_hat_rhs": y_hat_rhs,
 634 |             "partial_attentions": attentions,
 635 |             "partial_values": values,
 636 |         }
 637 |         if self.current_epoch == 0:
 638 |             output["x_lhs"] = x_lhs
 639 | 
 640 |         return output
 641 | 
 642 |     def test_epoch_end(self, outputs):
 643 |         """
 644 |         Used by pytorch_lightning
 645 |         Accumulates results of all forward validation passes in this epoch
 646 | 
 647 |         :param outputs: a list of dicts from self.validation_step()
 648 |         :param batch_idx: which batch this is in the epoch.
 649 |         :returns: a dict with val_loss, val_accuracy
 650 |         """
 651 |         loss = torch.cat([x["partial_test_loss"] for x in outputs], dim=0)  # .sum()
 652 |         # loss = list([x["partial_test_loss"] for x in outputs])  # .sum()
 653 |         perplexity = torch.exp(loss)
 654 |         accuracy = torch.cat([x["partial_test_accuracy"] for x in outputs], dim=0)
 655 | 
 656 |         logs = {
 657 |             "test_loss": loss,
 658 |             "test_accuracy": accuracy,
 659 |             "test_perplexity": perplexity,
 660 |         }
 661 | 
 662 |         return {"test_loss": loss, "log": logs}
 663 | 
 664 |     def forward(self, *args, **kwargs) -> Any:
 665 |         """Passes all arguments directly to Tranformer.forward()"""
 666 |         return self.transformer(*args, **kwargs)
 667 | 
 668 | 
 669 | def train(hparams: Namespace) -> None:
 670 |     """
 671 |     This is the main trainer_method. This sets up and runs experiment with
 672 |     the defined hyperparameters
 673 | 
 674 |     :param hparams: An argparse.Namespace with all of the relevant hyperparameters
 675 |     """
 676 | 
 677 |     # Process the args
 678 |     if hparams.logdir is None:
 679 |         hparams.logdir = os.environ.get("LOGDIR", ".")
 680 |     hparams.logdir = os.path.abspath(hparams.logdir)
 681 | 
 682 |     # Make sure d_model, heads, and d_key are compatible
 683 |     assert (
 684 |         hparams.d_model % hparams.n_heads == 0
 685 |     ), "n_heads=%s does not evenly divide d_model=%s" % (
 686 |         hparams.n_heads,
 687 |         hparams.d_model,
 688 |     )
 689 |     hparams.d_key = hparams.d_model / hparams.n_heads
 690 | 
 691 |     # Set up the RNGs for repeatability
 692 |     if hparams.random_seed != -1:
 693 |         torch.manual_seed(hparams.random_seed)
 694 |         torch.cuda.manual_seed(hparams.random_seed)
 695 |         torch.backends.cudnn.deterministic = True
 696 |         torch.backends.cudnn.benchmark = False
 697 | 
 698 |     checkpoint_path = hparams.logdir + "/checkpoints"
 699 |     os.makedirs(checkpoint_path, exist_ok=True)
 700 |     hparams.checkpoint_path = checkpoint_path
 701 | 
 702 |     # Create the model
 703 |     model = TrainableTransformer(hparams).float()
 704 | 
 705 |     torch.save(model, os.path.join(checkpoint_path, "init.pt"))
 706 | 
 707 |     logger = CSVLogger(hparams.logdir)
 708 | 
 709 |     # checkpointer = ModelCheckpoint(
 710 |     #     filepath=checkpoint_path,
 711 |     #     monitor="save_ckpt",
 712 |     #     mode="max",
 713 |     #     save_top_k=len(hparams.ckpt_epochs),
 714 |     #     verbose=False,
 715 |     # )
 716 | 
 717 |     trainer_args = {
 718 |         "max_steps": hparams.max_steps,
 719 |         "min_steps": hparams.max_steps,
 720 |         "max_epochs": int(1e8),
 721 |         "val_check_interval": 1,
 722 |         "profiler": False,
 723 |         # "checkpoint_callback": checkpointer,
 724 |         "logger": logger,
 725 |         "log_every_n_steps": 1,
 726 |         "flush_logs_every_n_steps": 1000,
 727 |     }
 728 |     if torch.cuda.is_available() and hparams.gpu >= 0:
 729 |         trainer_args["gpus"] = [hparams.gpu]
 730 | 
 731 |     trainer = Trainer(**trainer_args)
 732 | 
 733 |     trainer.fit(model=model)  # type: ignore
 734 |     """
 735 |     margin = np.percentile(model.margin.detach().cpu().numpy(), 5)
 736 |     device = transformer.embedding.weight.device
 737 |     measures, bounds = metrics.calculate(
 738 |         transformer,
 739 |         transformer_init.to(device),
 740 |         device,
 741 |         dataset_size,
 742 |         margin,
 743 |         input_dim=hparams.d_model,
 744 |     )
 745 | 
 746 |     measures_file = os.path.join(logger.log_dir, "measures.json")
 747 |     bounds_file = os.path.join(logger.log_dir, "bounds.json")
 748 |     with open(measures_file, "w") as fh:
 749 |         json.dump(measures, fh)
 750 |     with open(bounds_file, "w") as fh:
 751 |         json.dump(bounds, fh)
 752 |     """
 753 |     return hparams.logdir
 754 | 
 755 | 
 756 | def compute_sharpness(hparams: Namespace, ckpts) -> None:
 757 |     """
 758 |     This is the compute_sharpness method. This loads a series of checkpoints in
 759 |     the defined hyperparameters
 760 | 
 761 |     :param hparams: An argparse.Namespace with all of the relevant hyperparameters
 762 |     """
 763 | 
 764 |     # Process the args
 765 |     if hparams.logdir is None:
 766 |         hparams.logdir = os.environ.get("LOGDIR", ".")
 767 |     hparams.logdir = os.path.abspath(hparams.logdir)
 768 | 
 769 |     # Make sure d_model, heads, and d_key are compatible
 770 |     assert (
 771 |         hparams.d_model % hparams.n_heads == 0
 772 |     ), "n_heads=%s does not evenly divide d_model=%s" % (
 773 |         hparams.n_heads,
 774 |         hparams.d_model,
 775 |     )
 776 |     hparams.d_key = hparams.d_model / hparams.n_heads
 777 | 
 778 |     # Set up the RNGs for repeatability
 779 |     if hparams.random_seed != -1:
 780 |         torch.manual_seed(hparams.random_seed)
 781 |         torch.cuda.manual_seed(hparams.random_seed)
 782 |         torch.backends.cudnn.deterministic = True
 783 |         torch.backends.cudnn.benchmark = False
 784 | 
 785 |     checkpoint_path = hparams.logdir + "/checkpoints"
 786 |     os.makedirs(checkpoint_path, exist_ok=True)
 787 |     hparams.checkpoint_path = checkpoint_path
 788 | 
 789 |     # Create the model
 790 |     model = TrainableTransformer(hparams).float()
 791 | 
 792 |     torch.save(model, os.path.join(checkpoint_path, "init.pt"))
 793 | 
 794 |     logger = CSVLogger(hparams.logdir)
 795 | 
 796 | 
 797 |     trainer_args = {
 798 |         "max_steps": hparams.max_steps,
 799 |         "min_steps": hparams.max_steps,
 800 |         "max_epochs": int(1e8),
 801 |         "val_check_interval": 1,
 802 |         "profiler": False,
 803 |         # "checkpoint_callback": checkpointer,
 804 |         "logger": logger,
 805 |         "log_every_n_steps": 1,
 806 |         "flush_logs_every_n_steps": 1000,
 807 |     }
 808 |     if torch.cuda.is_available() and hparams.gpu >= 0:
 809 |         trainer_args["gpus"] = [hparams.gpu]
 810 | 
 811 |     trainer = Trainer(**trainer_args)
 812 | 
 813 |     for ckpt in ckpts:
 814 |         print(f"Loading checkpoint {ckpt}")
 815 |         # model = torch.load(ckpt)
 816 |         # model.load_state_dict(torch.load(ckpt))
 817 | 
 818 |         checkpoint = torch.load(ckpt)
 819 |         # print(dir(checkpoint), type(checkpoint), "Ckpt")
 820 |         # for k, v in checkpoint.items():
 821 |         #     print(k)
 822 |         # print(checkpoint["hyper_parameters"])
 823 | 
 824 |         hps = checkpoint["hyper_parameters"]
 825 |         hps = argparse.Namespace(**hps)
 826 |         model = TrainableTransformer(hps).float()
 827 |         model.load_state_dict(checkpoint["state_dict"])
 828 | 
 829 |         phi = get_sharpness(model.train_dataloader(), model)
 830 |         results = {}
 831 |         results[ckpt] = phi
 832 |         pickle.dump(results, open(f"results/results_SD-{i}.pkl", "wb"))
 833 | 
 834 | 
 835 | def add_args(parser=None) -> Namespace:
 836 |     """
 837 |     Parses the command line arguments
 838 | 
 839 |     :returns: an argparse.Namespace with all of the needed arguments
 840 |     """
 841 |     if parser is None:
 842 |         parser = ArgumentParser()
 843 |     parser.add_argument("--random_seed", type=int, default=-1)
 844 |     parser.add_argument("--gpu", type=int, default=0)
 845 |     parser.add_argument("--max_epochs", type=int, default=None)
 846 |     parser.add_argument("--max_steps", type=int, default=100000)
 847 |     # parser.add_argument("--checkpoint_period", type=int, default=1)
 848 |     parser = TrainableTransformer.add_model_specific_args(parser)
 849 |     return parser
 850 | 
 851 | 
 852 | class CustomAdamW(torch.optim.Optimizer):
 853 |     def __init__(
 854 |         self,
 855 |         params,
 856 |         lr=1e-3,
 857 |         betas=(0.9, 0.999),
 858 |         eps=1e-8,
 859 |         weight_decay=1e-2,
 860 |         amsgrad=False,
 861 |         noise_factor=0.0,
 862 |         weight_decay_form="to_zero",
 863 |     ):
 864 |         if not 0.0 <= lr:
 865 |             raise ValueError("Invalid learning rate: {}".format(lr))
 866 |         if not 0.0 <= eps:
 867 |             raise ValueError("Invalid epsilon value: {}".format(eps))
 868 |         if not 0.0 <= betas[0] < 1.0:
 869 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
 870 |         if not 0.0 <= betas[1] < 1.0:
 871 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
 872 |         if not weight_decay_form in ["to_zero", "to_init", "jiggle", "honest"]:
 873 |             raise ValueError(
 874 |                 f"Invalid weight decay form: {weight_decay_form}, should be one of ['to_zero', 'to_init', 'jiggle']"
 875 |             )
 876 |         # if not 0.0 <= weight_decay:
 877 |         #     raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
 878 |         defaults = dict(
 879 |             lr=lr,
 880 |             betas=betas,
 881 |             eps=eps,
 882 |             weight_decay=weight_decay,
 883 |             amsgrad=amsgrad,
 884 |             noise_factor=noise_factor,
 885 |             weight_decay_form=weight_decay_form,
 886 |         )
 887 |         super(CustomAdamW, self).__init__(params, defaults)
 888 | 
 889 |     def __setstate__(self, state):
 890 |         super(CustomAdamW, self).__setstate__(state)
 891 |         for group in self.param_groups:
 892 |             group.setdefault("amsgrad", False)
 893 | 
 894 |     @torch.no_grad()
 895 |     def step(self, closure=None):
 896 |         """Performs a single optimization step.
 897 | 
 898 |         Arguments:
 899 |             closure (callable, optional): A closure that reevaluates the model
 900 |                 and returns the loss.
 901 |         """
 902 |         loss = None
 903 |         if closure is not None:
 904 |             with torch.enable_grad():
 905 |                 loss = closure()
 906 | 
 907 |         for group in self.param_groups:
 908 |             for p in group["params"]:
 909 |                 if p.grad is None:
 910 |                     continue
 911 | 
 912 |                 # Perform optimization step
 913 |                 grad = p.grad
 914 | 
 915 |                 if group["weight_decay"] > 0:
 916 |                     if group["weight_decay_form"] == "honest":
 917 |                         grad = grad + group["weight_decay"] * p.detach()
 918 | 
 919 |                 if grad.is_sparse:
 920 |                     raise RuntimeError(
 921 |                         "Adam does not support sparse gradients, please consider SparseAdam instead"
 922 |                     )
 923 |                 amsgrad = group["amsgrad"]
 924 | 
 925 |                 state = self.state[p]
 926 | 
 927 |                 # State initialization
 928 |                 if len(state) == 0:
 929 |                     state["step"] = 0
 930 |                     # Exponential moving average of gradient values
 931 |                     state["exp_avg"] = torch.zeros_like(
 932 |                         p, memory_format=torch.preserve_format
 933 |                     )
 934 |                     # Exponential moving average of squared gradient values
 935 |                     state["exp_avg_sq"] = torch.zeros_like(
 936 |                         p, memory_format=torch.preserve_format
 937 |                     )
 938 |                     if group["weight_decay_form"] == "to_init":
 939 |                         state["init"] = p.detach().clone()
 940 |                     if amsgrad:
 941 |                         # Maintains max of all exp. moving avg. of sq. grad. values
 942 |                         state["max_exp_avg_sq"] = torch.zeros_like(
 943 |                             p, memory_format=torch.preserve_format
 944 |                         )
 945 | 
 946 |                 if group["weight_decay"] > 0:
 947 |                     if group["weight_decay_form"] == "to_zero":
 948 |                         p.mul_(1 - group["lr"] * group["weight_decay"])
 949 |                     elif group["weight_decay_form"] == "to_init":
 950 |                         p.add_(
 951 |                             (state["init"] - p) * (group["lr"] * group["weight_decay"])
 952 |                         )
 953 |                     elif group["weight_decay_form"] == "jiggle":
 954 |                         p.mul_(
 955 |                             torch.exp(
 956 |                                 torch.randn(1).cuda()
 957 |                                 * (group["lr"] * group["weight_decay"])
 958 |                             )
 959 |                         )
 960 |                     elif group["weight_decay_form"] == "honest":
 961 |                         pass
 962 |                     else:
 963 |                         raise ValueError(
 964 |                             f"Invalid weight decay form: {group['weight_decay_form']}"
 965 |                         )
 966 | 
 967 |                 exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
 968 |                 if amsgrad:
 969 |                     max_exp_avg_sq = state["max_exp_avg_sq"]
 970 |                 beta1, beta2 = group["betas"]
 971 | 
 972 |                 state["step"] += 1
 973 |                 bias_correction1 = 1 - beta1 ** state["step"]
 974 |                 bias_correction2 = 1 - beta2 ** state["step"]
 975 | 
 976 |                 # Decay the first and second moment running average coefficient
 977 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
 978 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 979 |                 if amsgrad:
 980 |                     # Maintains the maximum of all 2nd moment running avg. till now
 981 |                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
 982 |                     # Use the max. for normalizing running avg. of gradient
 983 |                     denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
 984 |                         group["eps"]
 985 |                     )
 986 |                 else:
 987 |                     denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
 988 |                         group["eps"]
 989 |                     )
 990 | 
 991 |                 step_size = group["lr"] / bias_correction1
 992 | 
 993 |                 upd = exp_avg / denom
 994 |                 # add uniform gaussian noise to the update
 995 |                 if group["noise_factor"] > 0:
 996 |                     upd += torch.randn_like(upd) * group["noise_factor"]
 997 |                 # if group['noise_factor'] > 0:
 998 |                 #     upd *= torch.exp(torch.randn_like(upd) * group['noise_factor'])
 999 |                 p.add_(-step_size * upd)
1000 | 
1001 |         return loss
1002 | 
1003 | 
class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimization wrapper around a base optimizer.

    Each step:
      1. Moves every parameter that has a gradient to ``w + rho * g / ||g||``,
         where ``||g||`` is the global L2 norm of all gradients.
      2. Re-evaluates the gradient there through the closure.
      3. Restores ``w`` and lets the base optimizer apply that gradient.

    :param params: iterable of parameters or param-group dicts
    :param base_optimizer: optimizer class (e.g. ``torch.optim.SGD``); it is
        instantiated with this optimizer's param groups and ``**kwargs``
    :param rho: neighbourhood radius, must be non-negative
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Both optimizers share the same param-group dicts.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb parameters to ``w + e(w)`` and remember ``e(w)`` per parameter."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # move to the perturbed point "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation and apply the base optimizer's update."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """
        Run a full SAM step.

        :param closure: required; must run a full forward-backward pass. The
            gradients already present on entry are used for the perturbation.
        :returns: None
        """
        assert (
            closure is not None
        ), "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(
            closure
        )  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the L2 norm over all parameter gradients, on the first parameter's device."""
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        grad_norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not grad_norms:
            # No gradients at all: the norm is zero (torch.stack([]) would raise).
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(grad_norms), p=2)
1072 | 


--------------------------------------------------------------------------------
/grok/transformer.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | from argparse import ArgumentParser, Namespace
  3 | from typing import Tuple, List, Dict, Union
  4 | 
  5 | import numpy as np
  6 | import torch
  7 | import torch.nn as nn
  8 | import torch.nn.functional as F
  9 | from numpy import cos, sin, sqrt
 10 | from torch import tensor, Tensor
 11 | from torch.optim.lr_scheduler import LambdaLR
 12 | import pytorch_lightning as pl
 13 | 
 14 | from argparse import ArgumentParser
 15 | 
 16 | 
 17 | class Linear(nn.Linear):
 18 |     def __init__(self, *args, **kwargs):
 19 |         self.weight_noise = kwargs.pop("weight_noise")
 20 |         super().__init__(*args, **kwargs)
 21 | 
 22 |     def forward(self, input: Tensor) -> Tensor:
 23 |         if self.weight_noise > 0 and self.training:
 24 |             bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise
 25 |             weight = self.weight + torch.randn_like(self.weight) * self.weight_noise
 26 |             # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise)
 27 |         else:
 28 |             bias = self.bias
 29 |             weight = self.weight
 30 |             
 31 |         return F.linear(
 32 |             input,
 33 |             weight,
 34 |             bias,
 35 |         )
 36 | 
 37 | class LayerNorm(nn.LayerNorm):
 38 |     def __init__(self, *args, **kwargs):
 39 |         self.weight_noise = kwargs.pop("weight_noise")
 40 |         super().__init__(*args, **kwargs)
 41 | 
 42 |     def forward(self, input: Tensor) -> Tensor:
 43 |         if self.weight_noise > 0 and self.training:
 44 |             bias = self.bias if self.bias is None else self.bias + torch.randn_like(self.bias) * self.weight_noise
 45 |             weight = self.weight + torch.randn_like(self.weight) * self.weight_noise
 46 |             # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise)
 47 |         else:
 48 |             bias = self.bias
 49 |             weight = self.weight
 50 |         return F.layer_norm(
 51 |             input,
 52 |             self.normalized_shape,
 53 |             weight,
 54 |             bias,
 55 |             self.eps,
 56 |         )
 57 | 
 58 | 
 59 | class Embedding(nn.Embedding):
 60 |     def __init__(self, *args, **kwargs):
 61 |         self.weight_noise = kwargs.pop("weight_noise")
 62 |         super().__init__(*args, **kwargs)
 63 | 
 64 |     def forward(self, input: Tensor) -> Tensor:
 65 |         if self.weight_noise > 0 and self.training:
 66 |             weight = self.weight + torch.randn_like(self.weight) * self.weight_noise
 67 |             # weight = self.weight * torch.exp(torch.randn_like(self.weight) * self.weight_noise)
 68 |         else:
 69 |             weight = self.weight
 70 |         return F.embedding(
 71 |             input,
 72 |             weight,
 73 |             self.padding_idx,
 74 |             self.max_norm,
 75 |             self.norm_type,
 76 |             self.scale_grad_by_freq,
 77 |             self.sparse,
 78 |         )
 79 | 
 80 | 
 81 | class AttentionHead(nn.Module):
 82 |     def __init__(self, d_model: int, d_key: int, weight_noise: float) -> None:
 83 | 
 84 |         super().__init__()
 85 | 
 86 |         self.d_key = d_key
 87 | 
 88 |         # head projections
 89 |         self.Wq = Linear(d_model, d_key, bias=False, weight_noise=weight_noise)
 90 |         self.Wk = Linear(d_model, d_key, bias=False, weight_noise=weight_noise)
 91 |         self.Wv = Linear(d_model, d_key, bias=False, weight_noise=weight_noise)
 92 | 
 93 |         self.softmax = nn.Softmax(dim=-1)
 94 | 
 95 |     def forward(
 96 |         self,
 97 |         queries: Tensor,
 98 |         keys: Tensor,
 99 |         values: Tensor,
100 |         mask: Union[Tensor, None] = None,
101 |         save_activations: bool = False,
102 |     ) -> Tuple[Tensor, Union[Tensor, None], Union[Tensor, None]]:
103 | 
104 |         # project queries, keys, values
105 |         queries = self.Wq(queries)
106 |         keys = self.Wk(keys)
107 |         values = self.Wv(values)
108 | 
109 |         # calculate compatibility function
110 |         attn = torch.matmul(queries, torch.transpose(keys, -2, -1))
111 |         attn = attn / sqrt(self.d_key)
112 | 
113 |         # Filter out attention to future positions
114 |         if mask is not None:
115 |             attn.masked_fill_(mask == 0, float("-inf"))
116 | 
117 |         # softmax
118 |         attn = self.softmax(attn)
119 | 
120 |         # sum the weighted value vectors
121 |         result: Tensor = torch.matmul(attn, values)  # shape = (max_context_len, d_key)
122 |         if save_activations:
123 |             leaf_attn = attn.clone().detach()  # type: ignore
124 |             leaf_values = values.clone().detach()  # type: ignore
125 |         else:
126 |             leaf_attn = None  # type: ignore
127 |             leaf_values = None  # type: ignore
128 | 
129 |         return result, leaf_attn, leaf_values
130 | 
131 | 
class MultiHeadAttention(nn.Module):
    """
    ``heads`` independent AttentionHeads of width ``int(d_model / heads)``.

    Their outputs are concatenated on the last axis and mixed by a bias-free
    ``d_model x d_model`` output projection ``Wo``.
    """

    def __init__(self, d_model: int, heads: int, weight_noise: float = 0.0) -> None:
        super().__init__()
        head_dim = int(d_model / heads)

        self.attn_heads = nn.ModuleList(
            AttentionHead(d_model, head_dim, weight_noise=weight_noise)
            for _ in range(heads)
        )
        self.Wo = Linear(d_model, d_model, bias=False, weight_noise=weight_noise)

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        values: Tensor,
        mask: Tensor = None,
        save_activations=False,
    ) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
        """
        Return ``(output, attns, values)``. The per-head lists are empty unless
        ``save_activations`` is set.
        """
        per_head_out, per_head_attn, per_head_vals = [], [], []
        for head in self.attn_heads:
            out, attn, vals = head(
                queries=queries,
                keys=keys,
                values=values,
                mask=mask,
                save_activations=save_activations,
            )
            per_head_out.append(out)
            per_head_attn.append(attn)
            per_head_vals.append(vals)

        merged = self.Wo(torch.cat(per_head_out, dim=-1))
        if not save_activations:
            return merged, [], []
        return merged, per_head_attn, per_head_vals
175 | 
176 | 
class FFN(nn.Module):
    """
    Position-wise feed-forward network.

    A bias-free Linear from ``d_model`` to ``int(multiplier * d_model)``, a ReLU
    or GELU (chosen by ``non_linearity``), then a bias-free Linear back to
    ``d_model``. An unknown ``non_linearity`` raises KeyError.
    """

    def __init__(
        self,
        d_model: int,
        multiplier: int = 4,
        non_linearity: str = "relu",
        weight_noise: float = 0.0,
    ) -> None:
        super().__init__()

        hidden = int(multiplier * d_model)

        expand = Linear(d_model, hidden, bias=False, weight_noise=weight_noise)
        activation = {"relu": nn.ReLU, "gelu": nn.GELU}[non_linearity]()
        contract = Linear(hidden, d_model, bias=False, weight_noise=weight_noise)
        self.ffn = nn.Sequential(expand, activation, contract)

    def forward(self, x: Tensor) -> Tensor:
        out = self.ffn(x)
        return out
199 | 
200 | 
class DecoderBlock(nn.Module):
    """
    One post-norm decoder layer.

    Self-attention with a residual connection and LayerNorm, then an FFN whose
    output passes through dropout before its own residual connection and
    LayerNorm. The attention output has no dropout.
    """

    def __init__(
        self,
        d_model: int,
        heads: int,
        dropout: float,
        non_linearity: str = "relu",
        weight_noise: float = 0.0,
    ) -> None:
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, heads, weight_noise=weight_noise)
        self.self_attn_norm = LayerNorm(d_model, weight_noise=weight_noise)

        self.ffn = FFN(d_model, non_linearity=non_linearity, weight_noise=weight_noise)
        self.ffn_drop = nn.Dropout(p=dropout)
        self.ffn_norm = LayerNorm(d_model, weight_noise=weight_noise)

    def forward(
        self,
        x: Tensor,
        self_attn_mask: Tensor = None,
        save_activations: bool = False,
    ) -> Tuple[Tensor, List[Tensor], List[Tensor]]:
        """Return ``(output, layer_attns, layer_values)`` from the attention sub-layer."""
        attn_out, layer_attns, layer_values = self.self_attn(
            x, x, x, self_attn_mask, save_activations
        )
        hidden = self.self_attn_norm(x + attn_out)

        ff_out = self.ffn_drop(self.ffn(hidden))
        output = self.ffn_norm(hidden + ff_out)

        return output, layer_attns, layer_values
237 | 
238 | 
class Decoder(nn.Module):
    """A stack of ``num_blocks`` identically configured ``DecoderBlock``s."""

    def __init__(
        self,
        d_model: int,
        heads: int,
        num_blocks: int,
        dropout: float,
        non_linearity: str = "relu",
        weight_noise: float = 0.0,
    ) -> None:
        super().__init__()

        layers = [
            DecoderBlock(
                d_model, heads, dropout, non_linearity, weight_noise=weight_noise
            )
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(
        self,
        x: Tensor,
        self_attn_mask: Tensor = None,
        save_activations=False,
    ) -> Tuple[Tensor, List[List[Tensor]], List[List[Tensor]]]:
        """Run ``x`` through every block in order.

        Returns ``(output, attentions, values)``; the two lists hold one entry
        per block when ``save_activations`` is true and are empty otherwise.
        """
        collected_attns = []
        collected_values = []
        hidden = x
        for layer in self.blocks:
            hidden, block_attns, block_values = layer(
                hidden, self_attn_mask, save_activations=save_activations
            )
            if not save_activations:
                continue
            collected_attns.append(block_attns)
            collected_values.append(block_values)
        return hidden, collected_attns, collected_values
278 | 
279 | 
class Transformer(nn.Module):
    """Decoder-only transformer.

    Pipeline: ``Embedding`` lookup plus a fixed sinusoidal position table, a
    ``Decoder`` stack of ``n_layers`` blocks, and a bias-free ``Linear``
    read-out to ``vocab_len`` logits.
    """

    def __init__(
        self,
        n_layers: int = 4,
        n_heads: int = 4,
        d_model: int = 256,
        dropout: float = 0.1,
        max_context_len: int = 1024,
        vocab_len: int = 2000,
        non_linearity: str = "relu",
        weight_noise: float = 0.0,
    ) -> None:
        super().__init__()

        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_model = d_model
        self.dropout = dropout
        self.max_context_len = max_context_len
        self.non_linearity = non_linearity

        self.vocab_len = vocab_len

        self.embedding = Embedding(vocab_len, d_model, weight_noise=weight_noise)  # type: ignore
        # Buffers, not parameters: saved in the state_dict and moved by .to(),
        # but never optimised.
        self.register_buffer(
            "position_encoding", self._position_encoding(max_context_len, d_model)
        )
        self.register_buffer("self_attn_mask", self.make_mask(max_context_len))

        self.decoder = Decoder(
            d_model,
            n_heads,
            n_layers,
            dropout,
            self.non_linearity,
            weight_noise=weight_noise,
        )

        self.linear = Linear(d_model, vocab_len, bias=False, weight_noise=weight_noise)

    @staticmethod
    def make_mask(context_len: int) -> Tensor:
        """Return a (context_len, context_len) lower-triangular matrix of ones.

        NOTE(review): presumably a causal mask (1 = may attend) — the exact
        convention is defined by MultiHeadAttention; confirm there.
        """
        return torch.ones([context_len, context_len]).tril()

    @classmethod
    def _position_encoding(cls, context_len: int, d_model: int) -> Tensor:
        """Sinusoidal position table of shape (context_len, d_model).

        Column ``i`` of row ``pos`` is ``sin(pos / 10000 ** (i / d_model))`` for
        even ``i`` and ``cos(pos / 10000 ** ((i - 1) / d_model))`` for odd ``i``.
        Computed vectorised in float64 (previously one Python ``sin``/``cos``
        call per entry) and returned in torch's default dtype.
        """
        positions = torch.arange(context_len, dtype=torch.float64).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float64)
        angles = positions / (10000 ** (even_dims / d_model))  # (context_len, ceil(d/2))

        encoding = torch.empty(context_len, d_model, dtype=torch.float64)
        encoding[:, 0::2] = torch.sin(angles)
        # Odd column i reuses the frequency of even column i - 1.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return encoding.to(torch.get_default_dtype())

    def embed(self, indices: Tensor) -> Tensor:
        """Embed token ``indices`` and add the first ``indices.shape[-1]`` position rows.

        The position slice is (seq_len, d_model) and relies on broadcasting
        against the embedding output (assumed (..., seq_len, d_model) — the
        Embedding class is defined elsewhere; confirm).
        """
        context_len = indices.shape[-1]
        pe = self.position_encoding[:context_len, :]  # type: ignore

        embedded = self.embedding(indices)

        return pe + embedded

    def forward(
        self,
        x: Tensor,
        pos: Union[int, None] = None,
        save_activations: bool = False,
    ) -> Tuple[Tensor, List[List[Tensor]], List[List[Tensor]]]:
        """Decode ``x`` and project to vocabulary logits.

        Args:
            x: token indices whose last dimension is the sequence length, which
                must not exceed ``max_context_len`` (the mask and position table
                are sliced to it). Expected shape (batch, seq_len) — required
                when ``pos`` is given, because ``decoded[:, pos, :]`` indexes dim 1.
            pos: if given, only the logits for this sequence position are returned.
            save_activations: forwarded to the decoder to collect per-layer
                attention/value activations.

        Returns:
            ``(y_hat, attentions, values)``: ``y_hat`` are logits over
            ``vocab_len``; ``attentions``/``values`` hold one entry per layer
            when ``save_activations`` is true and are empty lists otherwise.
        """
        # Make sure sampling inputs are on the correct device
        x = x.to(self.embedding.weight.device)

        # Slice the precomputed mask down to the current sequence length.
        this_max_context_len = x.shape[-1]
        self_attn_mask = self.self_attn_mask[  # type: ignore
            :this_max_context_len, :this_max_context_len
        ]

        # Decode
        x = self.embed(x)
        decoded, attentions, values = self.decoder(
            x, self_attn_mask, save_activations=save_activations
        )

        # Return predictions for specific token
        if pos is not None:
            decoded = decoded[:, pos, :]

        y_hat = self.linear(decoded)
        return y_hat, attentions, values
380 | 


--------------------------------------------------------------------------------
/grok/visualization.py:
--------------------------------------------------------------------------------
  1 | import csv
  2 | import logging
  3 | import os
  4 | import math
  5 | import socket
  6 | 
  7 | from collections import defaultdict
  8 | from copy import deepcopy
  9 | 
 10 | import matplotlib.pyplot as plt
 11 | import matplotlib.ticker as mtick
 12 | import numpy as np
 13 | import torch
 14 | 
 15 | from mpl_toolkits.axes_grid1 import make_axes_locatable
 16 | from tqdm import tqdm
 17 | 
 18 | from grok.data import ArithmeticDataset
 19 | 
 20 | logging.basicConfig(level=logging.ERROR)
 21 | logger = logging.getLogger("grok.view_metrics")
 22 | logger.setLevel(logging.ERROR)
 23 | 
 24 | GROK_DIR = os.path.expanduser("~/data/grok")
 25 | IMAGE_DIR = f"{GROK_DIR}/images"
 26 | DATA_DIR = f"{GROK_DIR}/data"
 27 | 
 28 | 
 29 | DEFAULT_CMAP = "viridis"
 30 | 
 31 | default_metric_limits = {
 32 |     "min_val_accuracy": 0,
 33 |     "max_val_accuracy": 100,
 34 |     "min_T": 0,  # 0
 35 |     "max_T": 100,  # 87.5
 36 |     "min_D": 0,  # 8
 37 |     "max_D": 2048,  # 256
 38 |     "min_H": 0,  # 1
 39 |     "max_H": 1204,  # 8
 40 |     "min_L": 0,  # 1
 41 |     "max_L": 1024,  # 4
 42 |     "min_accuracy": 0,
 43 |     "max_accuracy": 100,
 44 | }
 45 | 
 46 | default_axis_scales = {"x": "linear", "y": "linear"}
 47 | 
 48 | 
 49 | ## Data Loading Functions
 50 | 
 51 | 
 52 | def factor_expts(expts):
 53 |     result = {}
 54 |     for expt in expts:
 55 |         expt_s = expt.split("_")
 56 |         arch = "_".join(expt_s[:3])
 57 |         t = int(float(expt_s[3].split("-")[1]))
 58 |         result.setdefault(arch, {})
 59 |         result[arch][t] = expt
 60 |     return result
 61 | 
 62 | 
 63 | def load_metric_data(data_dir, epochs=100000, load_partial_data=True):
 64 |     # layers x heads x d_model x train_pct
 65 |     data = {}
 66 |     expts = os.listdir(data_dir)
 67 |     archs = factor_expts(expts)
 68 |     logger.debug(archs)
 69 |     for arch in archs:
 70 |         T = sorted(archs[arch].keys())
 71 |         data[arch] = {
 72 |             "T": torch.LongTensor(T),
 73 |             "metrics": torch.zeros((max(T), 5, epochs)),
 74 |         }
 75 |         # print(f"metrics_shape = {data[arch]['metrics'].shape}")
 76 |         for i, t in tqdm(list(enumerate(T))):
 77 |             expt = archs[arch][t]
 78 |             logger.debug(expt)
 79 |             log_dir = data_dir + "/" + expt
 80 | 
 81 |             # print("log_dir", log_dir)
 82 |             try:
 83 |                 with open(log_dir + "/default/version_0/metrics.csv", "r") as fh:
 84 |                     logger.debug(f"loading {log_dir}")
 85 |                     reader = list(csv.DictReader(fh))
 86 |                     val_t = torch.FloatTensor(
 87 |                         [
 88 |                             [
 89 |                                 float(r["val_loss"]),
 90 |                                 float(r["val_accuracy"]),
 91 |                             ]
 92 |                             for r in reader
 93 |                             if r["val_loss"]
 94 |                         ]
 95 |                     ).T
 96 |                     train_t = torch.FloatTensor(
 97 |                         [
 98 |                             [
 99 |                                 float(r["learning_rate"]),
100 |                                 float(r["train_loss"]),
101 |                                 float(r["train_accuracy"]),
102 |                             ]
103 |                             for r in reader
104 |                             if r["train_loss"]
105 |                         ]
106 |                     ).T
107 |                     # logger.debug(val_t.shape)
108 |                     # logger.debug(train_t[0, -3:])
109 |                     if load_partial_data:
110 |                         raise Exception("Not implemented")
111 |                     elif (val_t.shape[-1] >= epochs) and (train_t.shape[-1] >= epochs):
112 |                         data[arch]["metrics"][i] = torch.cat(
113 |                             [train_t[..., :epochs], val_t[..., :epochs]], dim=0
114 |                         )
115 |                     else:
116 |                         data[arch]["T"][i] = 0
117 |             # except FileNotFoundError:
118 |             except:
119 |                 data[arch]["T"][i] = 0
120 |         indices = torch.nonzero(data[arch]["T"]).squeeze()
121 |         if len(indices.shape) == 0:
122 |             indices = indices.unsqueeze(0)
123 |         # print(f"indices.shape = {indices.shape}")
124 |         data[arch]["T"] = data[arch]["T"][indices]
125 |         # print(f"data[arch]['T'].shape = {data[arch]['T'].shape}")
126 |         data[arch]["metrics"] = data[arch]["metrics"][indices]
127 |         # print(f"data[arch]['metrics'].shape = {data[arch]['metrics'].shape}")
128 |         data[arch]["metrics"] = torch.transpose(data[arch]["metrics"], 0, 1)
129 |         # print(f"data[arch]['metrics'].shape = {data[arch]['metrics'].shape}")
130 |     return data
131 | 
132 | 
def most_interesting(metric_data):
    """Pick, for every architecture, the single most "interesting" run.

    Candidates are the runs whose peak validation accuracy reaches 95; if no
    run does, the runs tied for the highest peak validation accuracy are used
    instead. Among the candidates, the run with the largest peak validation
    loss wins (the first one on ties).

    Args:
        metric_data: arch -> dict of tensors indexed by run along dim 0 (e.g.
            the output of ``get_metric_data``); "val_accuracy" and "val_loss"
            must be (n_runs, epochs).

    Returns:
        dict with the same arch keys; every tensor holds only the selected run
        with a leading dimension of size 1.
    """
    interesting_metric_data = {}
    for arch, arch_data in metric_data.items():
        max_acc_by_t = arch_data["val_accuracy"].max(dim=1).values
        max_loss_by_t = arch_data["val_loss"].max(dim=1).values

        candidates = torch.nonzero(max_acc_by_t >= 95).squeeze(-1)
        if candidates.numel() == 0:
            candidates = torch.nonzero(max_acc_by_t == max_acc_by_t.max()).squeeze(-1)

        # argmax returns the first maximal index, so ties resolve deterministically.
        best = candidates[torch.argmax(max_loss_by_t[candidates])]

        interesting_metric_data[arch] = {
            k: v[best].unsqueeze(0) for k, v in arch_data.items()
        }

    # Returned after ALL archs (the original returned inside the loop).
    return interesting_metric_data
159 | 
160 | 
161 | # ## Graph Drawing Functions
162 | 
163 | 
def moving_avg(Y, steps):
    """Return the simple moving average of ``Y`` over windows of ``steps`` points.

    Only fully overlapping windows are kept, so the result has
    ``len(Y) - steps + 1`` entries.
    """
    window = np.ones(steps)
    window_sums = np.convolve(Y, window, "valid")
    return window_sums / steps
166 | 
167 | 
def find_inflections(Y, smoothing_steps=100):
    """Return indices of the local extrema of ``Y`` between trend reversals.

    ``Y`` is smoothed with a ``smoothing_steps``-wide moving average. The points
    where the sign of the smoothed first difference changes split the series
    into segments. For each window spanning two consecutive segments, the
    argmin of the raw ``Y`` is reported if the smoothed trend is falling at the
    window start, otherwise the argmax.

    ``Y`` must be longer than ``smoothing_steps``. Reversal positions are
    indices into the smoothed series, which is ``smoothing_steps - 1`` shorter
    than ``Y``, and are applied to ``Y`` directly.

    Returns:
        LongTensor of indices into ``Y``, one per window.
    """
    avg_Y = moving_avg(Y, smoothing_steps)
    avg_direction = torch.FloatTensor(np.sign(avg_Y[1:] - avg_Y[:-1]))
    # Repeat the first direction so there is one entry per smoothed point.
    avg_direction = torch.cat([avg_direction[0].unsqueeze(0), avg_direction])
    # squeeze(-1), not squeeze(): with exactly one reversal a full squeeze gives
    # a 0-d tensor whose .tolist() is an int, breaking the list concatenation.
    changes = torch.nonzero(avg_direction[1:] - avg_direction[:-1]).squeeze(-1)
    avg_inflections = [0] + (changes + 1).tolist() + [len(Y) - 1]
    logger.debug(f"avg_inflections = {avg_inflections}")
    inflections = []
    for i in range(2, len(avg_inflections)):
        low = avg_inflections[i - 2]
        high = avg_inflections[i]
        logger.debug(f"low={low}")
        logger.debug(f"high={high}")
        if avg_direction[low + 1] < 0:
            indices = Y[low:high].argmin() + low
            logger.debug(f"min = (Y[{indices}] = {Y[int(indices)]}")
        else:
            indices = Y[low:high].argmax() + low
            logger.debug(f"max = (Y[{indices}] = {Y[int(indices)]}")
        inflections.append(int(indices))
    return torch.LongTensor(inflections)
189 | 
190 | 
def check_limits(arch_name, limits):
    """Return True when the arch's L, H and D values all lie within ``limits``.

    ``arch_name`` must consist of exactly three "_"-separated "<key>-<value>"
    fields (e.g. "L-2_H-4_D-128"); the values are read positionally as L, H, D.
    Bounds ("min_L", "max_L", ...) are inclusive. T bounds are not checked here.
    """
    parsed = [float(field.split("-")[1]) for field in arch_name.split("_")]
    n_layers, n_heads, width = parsed
    for key, value in zip(("L", "H", "D"), (n_layers, n_heads, width)):
        too_large = value > limits["max_" + key]
        if too_large or value < limits["min_" + key]:
            return False
    return True
202 | 
203 | 
def filter_archs(data, limits=None):
    """Return the sorted architecture keys of ``data`` that pass ``check_limits``.

    Args:
        data: mapping keyed by architecture name (e.g. "L-2_H-4_D-128").
        limits: optional overrides merged over ``default_metric_limits``.
            Defaults to None rather than a shared mutable ``{}``.
    """
    # Shallow merge is enough: the default limits are flat numbers.
    merged_limits = {**default_metric_limits, **(limits or {})}
    # Dict keys are already unique, so no set() is needed before sorting.
    archs = sorted(a for a in data if check_limits(a, merged_limits))
    logger.debug(f"archs = {archs}")
    return archs
211 | 
212 | 
def get_metric_data(data, limits=None):
    """Split loaded metrics into named per-arch tensors, filtered by ``limits``.

    Args:
        data: arch -> {"T": 1-D tensor, "metrics": (5, n_runs, epochs) tensor}
            (e.g. the output of ``load_metric_data``).
        limits: optional overrides merged over ``default_metric_limits``.

    Returns:
        arch -> {"T", "learning_rate", "train_loss", "train_accuracy",
        "val_loss", "val_accuracy"} for archs passing ``filter_archs``,
        restricted to runs with min_T <= T <= max_T.

    Raises:
        AssertionError: if any max_* limit is below its min_* counterpart.
    """
    limits = {**default_metric_limits, **(limits or {})}

    for k in limits.keys():
        metric = k.replace("min_", "").replace("max_", "")
        assert (
            limits["max_" + metric] >= limits["min_" + metric]
        ), f"invalid {metric} limits"

    d = {}
    for arch in filter_archs(data, limits):
        logger.debug(arch)
        T = data[arch]["T"]
        indices = torch.nonzero(
            torch.logical_and(T >= limits["min_T"], T <= limits["max_T"])
        ).squeeze(dim=-1)
        logger.debug(f"indices={indices}")
        # Index once and unpack the 5 metric rows (previously unpacked, then
        # ignored while each row was re-indexed separately).
        learning_rate, train_loss, train_accuracy, val_loss, val_accuracy = data[arch][
            "metrics"
        ][:, indices, :]
        d[arch] = {
            "T": T[indices],
            "learning_rate": learning_rate,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
        }
    return d
245 | 
246 | 
def add_metric_graph(
    fig,
    ax,
    metric,
    metric_data,
    scales=default_axis_scales,
    cmap=DEFAULT_CMAP,
    inflection_hline=False,
    ds_len=None,
    batchsize=97,
):
    """Plot ``metric`` for every (arch, T) run in ``metric_data`` on ``ax``.

    Each run becomes one line coloured by its T value through ``cmap``. If the
    last plotted arch's T range spans at most 10 a legend is drawn, otherwise a
    colorbar labelled "% training data".

    Args:
        fig: figure owning ``ax`` (used for the colorbar).
        ax: matplotlib axes to draw on.
        metric: key into each ``metric_data[arch]`` (e.g. "val_loss").
        metric_data: arch -> dict with a "T" tensor and a (n_runs, epochs)
            tensor under ``metric``.
        scales: {"x": ..., "y": ...} axis scale names.
        cmap: colormap used to colour runs by T.
        inflection_hline: for "val_loss", draw an orange horizontal line at the
            value of the first index returned by ``find_inflections``.
        ds_len: if given, the x axis counts optimizer updates instead of
            epochs, via ``ArithmeticDataset.calc_split_len`` and ``batchsize``.
        batchsize: batch size used for the epoch -> update conversion.

    Side effect: every ``metric_data[arch]["T"]`` is replaced by a 1-D tensor.
    Archs without runs are skipped; if nothing is plotted, no legend or
    colorbar is added (this previously raised).
    """
    ax.set_title(metric)
    ax.set_xscale(scales["x"])
    ax.set_yscale(scales["y"])
    ax.set_xlabel("epochs" if ds_len is None else "updates")

    if "accuracy" in metric:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())
        ax.axis(ymin=1e-16, ymax=101)
    if "loss" in metric:
        ax.axis(ymin=1e-16, ymax=15)

    logger.debug(f"processing {metric}")
    sm = None  # mappable of the last arch drawn; drives the colorbar
    for arch in metric_data:
        T = metric_data[arch]["T"].squeeze()
        if T.dim() == 0:
            T = T.unsqueeze(0)
        metric_data[arch]["T"] = T
        logger.debug((" " * 4) + f"arch = {arch}")
        if T.numel() == 0:
            continue
        T_min = int(T[0])
        T_max = int(T[-1])
        sm = plt.cm.ScalarMappable(
            cmap=cmap, norm=plt.Normalize(vmin=T_min, vmax=T_max)
        )
        colors = sm.to_rgba(T)
        for i, t in enumerate(T):
            if ds_len is None:
                steps_per_epoch = 1
            else:
                train_rows, _ = ArithmeticDataset.calc_split_len(t.item(), ds_len)
                steps_per_epoch = math.ceil(train_rows / batchsize)

            logger.debug((" " * 4) + f"t = {t}")
            Y = metric_data[arch][metric][i]
            assert len(Y.shape) == 1, f"Y.shape = {Y.shape} is invalid"
            X = torch.arange(1, Y.shape[0] + 1) * steps_per_epoch

            label = arch + f" t={t}"
            if metric == "val_loss" and inflection_hline:
                Y_infs = find_inflections(Y)
                ax.axhline(y=Y[Y_infs[0]], color="orange")
            if metric == "val_accuracy":
                # Y.max() reduces natively; builtin max() iterated the tensor in Python.
                label += " (max = %.2f)" % float(Y.max())
            ax.plot(X, Y, label=label, color=colors[i])

    if sm is None:
        return
    if T_max - T_min <= 10:
        ax.legend()
    else:
        fig.colorbar(
            sm,
            ax=ax,
            label="% training data",
            ticks=range(T_min, T_max, int((T_max - T_min) / 5)),
        )
334 | 
335 | 
def add_comm_graph(
    ax, metric, kind, comm_data, arch, scales=default_axis_scales, cmap=DEFAULT_CMAP
):
    """Stack-plot ``metric`` for commutative vs. non-commutative examples over epochs.

    Args:
        ax: matplotlib axes to draw on.
        metric: one of "loss", "accuracy", "perplexity".
        kind: checked against a fixed set of names but not otherwise used.
        comm_data: sequence of row mappings providing "epoch",
            "comm_<metric>" and "non_comm_<metric>" values.
        arch: accepted but unused.
        scales: {"x": ..., "y": ...} axis scale names.
        cmap: accepted but unused; the bands always use the "cividis" colormap.
    """
    allowed_metrics = ("loss", "accuracy", "perplexity")
    allowed_kinds = (
        "comm",
        "non_comm",
        "modulo",
        "non_modulo",
        "assoc",
        "non_assoc",
        "zero",
        "non_zero",
    )
    assert metric in allowed_metrics
    assert kind in allowed_kinds

    ax.set_title(metric)
    ax.set_xscale(scales["x"])
    ax.set_yscale(scales["y"])
    ax.set_xlabel("epochs")
    if "accuracy" in metric:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    epoch_axis = [int(row["epoch"]) for row in comm_data]
    bands = torch.tensor(
        [
            [float(row["comm" + "_" + metric]) for row in comm_data],
            [float(row["non_comm" + "_" + metric]) for row in comm_data],
        ]
    )
    band_labels = ["commutative", "non-commutative"]

    n_bands = len(bands)
    mapper = plt.cm.ScalarMappable(
        cmap="cividis", norm=plt.Normalize(vmin=0, vmax=n_bands)
    )
    band_colors = mapper.to_rgba(range(n_bands))
    ax.stackplot(
        epoch_axis, bands, baseline="zero", labels=band_labels, colors=band_colors
    )
    ax.legend()
384 | 
385 | 
def add_extremum_graph(
    ax,
    metric,
    kind,
    metric_data,
    scales=default_axis_scales,
    epochs=[-1],
    show_legend=True,
):
    """Plot each arch's per-run maximum (or minimum) of ``metric`` against T.

    Args:
        ax: matplotlib axes to draw on.
        metric: key into each ``metric_data[arch]`` holding (n_runs, epochs).
        kind: "max" or "min".
        metric_data: arch -> dict with "T" and ``metric`` tensors.
        scales: {"x": ..., "y": ...} axis scale names.
        epochs: accepted but unused.
        show_legend: draw a legend when at most 12 archs are plotted.
    """
    assert kind in ("max", "min")
    ax.set_title(f"{kind} {metric}")
    ax.set_xlabel("training data")
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
    ax.axis(xmin=0, xmax=100)

    ax.set_xscale(scales["x"])
    ax.set_yscale(scales["y"])
    if "accuracy" in metric:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())
        ax.axis(ymin=-1, ymax=105)

    reducer = torch.max if kind == "max" else torch.min
    n_lines = 0
    for arch, arch_data in metric_data.items():
        extremes = reducer(arch_data[metric], dim=1, keepdim=True).values.squeeze()
        ax.set_xticks(np.arange(0, 100, 5))
        ax.plot(arch_data["T"], extremes, label=f"{kind} {metric} {arch}")
        n_lines += 1

    if show_legend and n_lines <= 12:
        ax.legend()
439 |         pass
440 | 
441 | 
def add_inflection_graphs(
    ax, metric, metric_data, scales=default_axis_scales, smoothing_steps=100
):
    """Plot the ``find_inflections`` points of ``metric`` for every (arch, T) run.

    Each run is drawn exactly once (a stray ``for num in range(5)`` loop used to
    draw every curve five times and inflate the legend count).

    Args:
        ax: matplotlib axes to draw on.
        metric: key into each ``metric_data[arch]`` holding (n_runs, epochs).
        metric_data: arch -> dict with "T" and ``metric`` tensors.
        scales: {"x": ..., "y": ...} axis scale names.
        smoothing_steps: moving-average width passed to ``find_inflections``.
    """
    ax.set_title(f"{metric} inflections by train_data_pct")
    ax.set_xlabel("train_data_pct")
    ax.set_ylabel(f"{metric} inflections")
    ax.set_xscale(scales["x"])
    ax.set_yscale(scales["y"])
    # NOTE(review): X below is an epoch index, while the x label and xmax=87.5
    # describe a training-data percentage — confirm the intended x axis.
    if "accuracy" in metric:
        ax.axis(xmin=0, xmax=87.5, ymin=0, ymax=100)
    if "learning" in metric:
        ax.axis(xmin=0, xmax=87.5, ymin=0, ymax=0.002)

    total_plots = 0
    for arch in metric_data:
        for i, t in enumerate(metric_data[arch]["T"]):
            Y = metric_data[arch][metric][i]
            X = torch.arange(Y.shape[-1])
            inflections = find_inflections(Y, smoothing_steps=smoothing_steps)
            ax.plot(X[inflections], Y[inflections], label=f"{arch} t={t}")
            total_plots += 1

    if total_plots <= 12:
        ax.legend()
472 | 
473 | 
def colorbar(mappable, ticks=None, labels=None):
    """Attach a colorbar on a new axis to the right of ``mappable``'s axes.

    Args:
        mappable: artist (e.g. from ``matshow``) whose axes host the colorbar.
        ticks: optional tick positions for the colorbar.
        labels: optional tick labels (applied to the vertical colorbar axis).

    The current pyplot axes are restored afterwards; the colorbar is returned.
    """
    previous_axes = plt.gca()
    host_ax = mappable.axes
    cbar_ax = make_axes_locatable(host_ax).append_axes("right", size="5%", pad=0.1)
    cbar = host_ax.figure.colorbar(mappable, cax=cbar_ax, ticks=ticks)
    if labels is not None:
        cbar.ax.set_yticklabels(labels)
    plt.sca(previous_axes)
    return cbar
485 | 
486 | 
def add_matshow(
    fig, ax, t, name, vmin=0, vmax=100, cmap=DEFAULT_CMAP, show_colorbar=True
):
    """Render 2-D tensor ``t`` as a heat map on ``ax``.

    Columns are labelled "A" (x axis) and rows "B" (y axis, origin at the
    bottom), with ticks every 10 cells. ``fig`` is accepted but unused.

    Args:
        fig: unused.
        ax: matplotlib axes to draw on.
        t: 2-D tensor; moved to CPU and detached before plotting.
        name: plot title.
        vmin, vmax: colour scale limits.
        cmap: colormap name.
        show_colorbar: attach a colorbar via ``colorbar``.
    """
    image = ax.matshow(
        t.cpu().detach().numpy(), vmin=vmin, vmax=vmax, origin="lower", cmap=cmap
    )
    ax.set_title(name)
    ax.set_xlabel("A")
    ax.set_ylabel("B")
    n_rows, n_cols = t.shape[0], t.shape[1]
    ax.set_xticks(np.arange(0, n_cols, 10))
    ax.set_yticks(np.arange(0, n_rows, 10))
    # Ticks and tick labels only on the left and bottom edges.
    ax.tick_params(
        axis="both",
        which="both",
        left=True,
        right=False,
        top=False,
        bottom=True,
        labelleft=True,
        labelright=False,
        labeltop=False,
        labelbottom=True,
    )
    if show_colorbar:
        colorbar(image)
517 | 


--------------------------------------------------------------------------------
/nbs/flatness.ipynb:
--------------------------------------------------------------------------------
 1 | {
 2 |  "cells": [
 3 |   {
 4 |    "cell_type": "code",
 5 |    "execution_count": null,
 6 |    "id": "geographic-personal",
 7 |    "metadata": {},
 8 |    "outputs": [],
 9 |    "source": [
10 |     "import os\n",
11 |     "import pickle\n",
12 |     "import matplotlib.pyplot as plt\n",
13 |     "\n",
14 |     "def paired_sort(list1, list2):\n",
15 |     "    list1, list2 = zip(*sorted(zip(list1, list2)))\n",
16 |     "    return list1, list2\n",
17 |     "\n",
18 |     "def plot_phi_by_ckpt():\n",
19 |     "\n",
20 |     "    nums = []\n",
21 |     "    flatness = []\n",
22 |     "\n",
23 |     "    for f in sorted(os.listdir(\"../results/\")):\n",
24 |     "        if \"pkl\" in f:\n",
25 |     "            num = int(f.split(\"-\")[1].split(\".pkl\")[0])\n",
26 |     "            dat = pickle.load(open(os.path.join(\"../results/\", f), \"rb\"))\n",
27 |     "            nums.append(num)\n",
28 |     "            flatness.append(dat[list(dat.keys())[0]].item())\n",
29 |     "\n",
30 |     "            \n",
31 |     "    nums, flatness = paired_sort(nums, flatness)\n",
32 |     "    plt.plot(nums, flatness)\n",
33 |     "    plt.xticks(range(len(nums)))\n",
34 |     "    plt.ylabel(\"phi\")\n",
35 |     "    plt.xlabel(\"SD-{n} checkpoint\")\n",
36 |     "    plt.savefig(\"phi-by-ckpt.png\")"
37 |    ]
38 |   },
39 |   {
40 |    "cell_type": "code",
41 |    "execution_count": null,
42 |    "id": "mineral-assembly",
43 |    "metadata": {},
44 |    "outputs": [],
45 |    "source": [
46 |     "plot_phi_by_ckpt()"
47 |    ]
48 |   }
49 |  ],
50 |  "metadata": {
51 |   "kernelspec": {
52 |    "display_name": "Python 3",
53 |    "language": "python",
54 |    "name": "python3"
55 |   },
56 |   "language_info": {
57 |    "codemirror_mode": {
58 |     "name": "ipython",
59 |     "version": 3
60 |    },
61 |    "file_extension": ".py",
62 |    "mimetype": "text/x-python",
63 |    "name": "python",
64 |    "nbconvert_exporter": "python",
65 |    "pygments_lexer": "ipython3",
66 |    "version": "3.6.13"
67 |   }
68 |  },
69 |  "nbformat": 4,
70 |  "nbformat_minor": 5
71 | }
72 | 


--------------------------------------------------------------------------------
/scripts/compute_sharpness.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | 
 3 | import os
 4 | import grok
 5 | 
 6 | parser = grok.training.add_args()
 7 | parser.set_defaults(logdir=os.environ.get("LOGDIR", "."))
 8 | hparams = parser.parse_args()
 9 | hparams.datadir = os.path.abspath(hparams.datadir)
10 | hparams.logdir = os.path.abspath(hparams.logdir)
11 | 
12 | 
13 | print(hparams)
14 | 
15 | ckpts = [f"./ckpts/L-2_H-4_D-128_T-70_DROP-0_SD-{i}_WU-10_LR-1p0.ckpt" for i in range(20)]
16 | print(grok.training.compute_sharpness(hparams, ckpts))
17 | 


--------------------------------------------------------------------------------
/scripts/create_metric_graphs.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | # coding: utf-8
  3 | 
  4 | # Render metrics graphs
  5 | 
  6 | import csv
  7 | import logging
  8 | import os
  9 | import glob
 10 | import socket
 11 | from argparse import ArgumentParser
 12 | 
 13 | from collections import defaultdict
 14 | 
 15 | import matplotlib.pyplot as plt
 16 | import matplotlib.ticker as mtick
 17 | import numpy as np
 18 | import torch
 19 | 
 20 | from mpl_toolkits.axes_grid1 import make_axes_locatable
 21 | from tqdm import tqdm
 22 | 
 23 | from sklearn.manifold import TSNE
 24 | 
 25 | import grok
 26 | from grok.visualization import *
 27 | 
 28 | # from grok_runs import RUNS
 29 | 
# Same logger name as grok.visualization; errors only.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("grok.view_metrics")
logger.setLevel(logging.ERROR)

# Known runs: operation name -> (int, run directory).
# NOTE(review): the int is presumably the dataset size (9409 == 97 ** 2) — confirm.
RUNS = {
    "subtraction": (
        9409,
        "subtraction/2021-02-05-03-33-56-alethea-sjjf",
    ),
}


# Inclusive min/max filters for this script (tighter than the library
# defaults for D, H and L). The trailing numeric comments appear to record
# earlier values.
limits = {
    "min_val_accuracy": 0,
    "max_val_accuracy": 100,
    "min_T": 0,  # 0
    "max_T": 100,  # 87.5
    "min_D": 0,  # 8
    "max_D": 256,  # 256
    "min_H": 0,  # 1
    "max_H": 4,  # 8
    "min_L": 0,  # 1
    "max_L": 4,  # 4
    "min_accuracy": 0,
    "max_accuracy": 100,
}

# Sanity-check every min/max pair. Note: leaves `k` and `metric` bound at
# module scope, and the assert is skipped under `python -O`.
for k in limits.keys():
    metric = k.replace("min_", "").replace("max_", "")
    assert (
        limits["max_" + metric] >= limits["min_" + metric]
    ), f"invalid {metric} limits"


# Parsed at import time: `args.image_dir` is used as a default argument value
# by the graph functions defined below.
parser = ArgumentParser()
parser.add_argument("-i", "--image_dir", type=str, default=IMAGE_DIR)
args = parser.parse_args()
 67 | 
 68 | 
 69 | def create_loss_curves(
 70 |     metric_data,
 71 |     epochs,
 72 |     run,
 73 |     most_interesting_only=False,
 74 |     image_dir=args.image_dir,
 75 |     ds_len=None,
 76 |     cmap=DEFAULT_CMAP,
 77 | ):
 78 |     scales = {
 79 |         "x": "log",
 80 |         "y": "linear",
 81 |     }
 82 | 
 83 | 
 84 |     arch = list(metric_data.keys())[0]
 85 | 
 86 |     ncols = 2
 87 |     nrows = 3
 88 |     fig_width = ncols * 8
 89 |     fig_height = nrows * 5
 90 |     fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))
 91 | 
 92 |     add_metric_graph(
 93 |         fig, axs[0, 0], "val_loss", metric_data, scales, cmap=cmap, ds_len=ds_len
 94 |     )
 95 |     add_metric_graph(
 96 |         fig, axs[0, 1], "val_accuracy", metric_data, scales, cmap, ds_len=ds_len
 97 |     )
 98 |     add_metric_graph(
 99 |         fig, axs[1, 0], "train_loss", metric_data, scales, cmap, ds_len=ds_len
100 |     )
101 |     add_metric_graph(
102 |         fig, axs[1, 1], "train_accuracy", metric_data, scales, cmap, ds_len=ds_len
103 |     )
104 |     add_metric_graph(
105 |         fig,
106 |         axs[2, 0],
107 |         "learning_rate",
108 |         metric_data,
109 |         scales,
110 |         cmap,  # ds_len=ds_len
111 |     )
112 |     fig.suptitle(f"{operation} {list(data.keys())[0]}")
113 |     fig.tight_layout()
114 | 
115 |     img_file = f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}"
116 |     if ds_len is not None:
117 |         img_file += "_by_update"
118 |     if most_interesting_only:
119 |         img_file += "_most_interesting"
120 |     img_file += ".png"
121 |     d = os.path.split(img_file)[0]
122 |     os.makedirs(d, exist_ok=True)
123 |     print(f"Writing {img_file}")
124 |     fig.savefig(img_file)
125 |     plt.close(fig)
126 | 
127 | 
def create_max_accuracy_curves(
    metric_data, epochs, run, image_dir=args.image_dir, ds_len=None
):
    """Plot max validation and max training accuracy vs. T and save as a PNG.

    The file is ``{image_dir}/max_accuracy/{operation}_max_accuracy_{expt}.png``,
    where ``expt`` is the first key of ``metric_data``. ``operation`` is read
    from module scope (presumably set by this script's main section — not
    visible here; confirm). ``epochs``, ``run`` and ``ds_len`` are unused.
    """
    ncols = 1
    nrows = 2
    fig_width = ncols * 8
    fig_height = nrows * 5
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))

    def get_ax(row=0, col=0, nrows=nrows, ncols=ncols, axs=axs):
        # plt.subplots squeezes singleton dimensions: a single row or column
        # yields a 1-D array, and a 1x1 grid yields a bare Axes.
        if nrows == 1:
            return axs if ncols == 1 else axs[col]
        return axs[row] if ncols == 1 else axs[row, col]

    add_extremum_graph(
        get_ax(0, 0), "val_accuracy", "max", metric_data, show_legend=False
    )
    add_extremum_graph(
        get_ax(1, 0), "train_accuracy", "max", metric_data, show_legend=False
    )

    # Name both title and file after this call's data; previously the title used
    # the global `data` and the file name an undefined/global `arch`.
    expt = list(metric_data.keys())[0]
    fig.suptitle(f"{operation} {expt}")
    fig.tight_layout()

    img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{expt}.png"
    d = os.path.split(img_file)[0]
    os.makedirs(d, exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)
170 | 
171 | 
def create_tsne_graphs(operation, expt, run_dir, image_dir=args.image_dir):
    """Run a 2-D t-SNE on saved validation losses and save a scatter plot.

    Loads every ``{run_dir}/activations/activations_*.pt`` file (sorted by
    name) and averages each file's ``val_loss`` over its last dimension. The
    results are concatenated along dim 0 and transposed, so each t-SNE point
    is one row of that transposed matrix. The plot is written to
    ``{image_dir}/tsne/{operation}_{expt}.png``.

    If no activation files exist, a message is printed and nothing is written.
    """
    saved_pt_dir = f"{run_dir}/activations"
    pattern = saved_pt_dir + "/activations_*.pt"
    print(f"glob = {pattern}")
    files = sorted(glob.glob(pattern))
    print(f"files = {files}")
    if not files:
        # torch.cat([]) below would raise; there is simply nothing to embed.
        print(f"No activation files found for {expt}; skipping T-SNE")
        return

    loss_ts = []
    accuracy_ts = []
    epochs_ts = []
    for file in files:
        print(f"Loading {file}")
        # NOTE(review): these files also pickle a dataloader ("dl"); torch>=2.6
        # defaults to weights_only=True and may refuse to load them.
        saved_pt = torch.load(file)
        loss_ts.append(saved_pt["val_loss"].mean(dim=-1))
        accuracy_ts.append(saved_pt["val_accuracy"])
        epochs_ts.append(saved_pt["epochs"].squeeze())

    loss_t = torch.cat(loss_ts, dim=0).T.detach()
    accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach()
    epochs_t = torch.cat(epochs_ts, dim=0).detach()
    print(loss_t.shape)
    print(accuracy_t.shape)
    print(epochs_t.shape)

    print("Doing T-SNE..")
    loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t)
    print("...done T-SNE.")

    ncols = 1
    nrows = 1
    fig_width = ncols * 8
    fig_height = nrows * 5
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))

    axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1])

    img_file = f"{image_dir}/tsne/{operation}_{expt}.png"
    d = os.path.split(img_file)[0]
    os.makedirs(d, exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)
221 | 
222 | 
for operation in RUNS:
    print("")
    print("")
    print(f"Processing {operation}", flush=True)

    # Runs named ".../<N>-epochs" encode their epoch count in the last path
    # component; everything else is assumed to have 5000 epochs.
    if operation.endswith("-epochs"):
        epochs = int(operation.split("/")[-1].split("-")[0])
    else:
        epochs = 5000

    ds_len, run = RUNS[operation]

    data = load_metric_data(f"{DATA_DIR}/{run}", epochs=epochs, load_partial_data=False)

    # Sanity-check the loaded metrics. Unpack into n_* names so the requested
    # `epochs` is not overwritten (otherwise the epoch check compares a value
    # with itself).
    for arch in data:
        n_metrics, n_expts, n_epochs = data[arch]["metrics"].shape
        message = (
            f"{arch} : loaded {n_metrics} metrics, {n_expts} experiments, {n_epochs} epochs"
        )
        assert n_metrics == 5, "INVALID metrics count: " + message
        assert n_expts < 88, "INVALID experiments count: " + message
        if n_epochs != epochs:
            # Reported rather than asserted: loaders may legitimately differ
            # from the requested count; surface it instead of hiding it.
            print(f"WARNING: expected {epochs} epochs: " + message)
        print(message)

    # Select the subset of the data to view
    metric_data = get_metric_data(data, limits)

    # Draw max accuracy, loss and accuracy curves
    create_max_accuracy_curves(metric_data, epochs, run)

    create_loss_curves(metric_data, epochs, run)
    create_loss_curves(metric_data, epochs, run, ds_len=ds_len)

    most_interesting_metric_data = most_interesting(metric_data)

    create_loss_curves(
        most_interesting_metric_data, epochs, run, most_interesting_only=True
    )
    create_loss_curves(
        most_interesting_metric_data,
        epochs,
        run,
        most_interesting_only=True,
        ds_len=ds_len,
    )

    # T-SNE of loss curves. Best effort: a failure is reported but does not
    # stop the remaining operations (Ctrl-C still propagates).
    try:
        for arch in most_interesting_metric_data:
            t = int(most_interesting_metric_data[arch]["T"][0].item())
            expt = f"{arch}_T-{t}_DROP-0.0"
            create_tsne_graphs(operation, expt, run_dir=f"{DATA_DIR}/{run}/{expt}")
    except Exception as e:
        print(f"TSNE failed: {e}")
286 | 


--------------------------------------------------------------------------------
/scripts/create_metrics_for_epochs.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | 
 3 | import logging
 4 | 
 5 | logging.basicConfig(level=logging.ERROR)
 6 | import csv
 7 | import copy
 8 | import os
 9 | import grok
10 | import numpy as np
11 | import sys
12 | import subprocess
13 | import torch
14 | from torch.multiprocessing import Process
15 | from grok import trainer
16 | from tqdm import tqdm
17 | from argparse import ArgumentParser
18 | from collections import Counter
19 | 
20 | 
# Child processes are started with "spawn". A RuntimeError from
# set_start_method means a start method was already chosen; that is accepted.
torch.multiprocessing.freeze_support()
try:
    torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
    pass

# Epoch schedule to evaluate: every epoch at first, then every 2nd, 10th and
# 50th epoch, finishing with epoch 10000.
EPOCHS = [
    *range(10),
    *range(10, 200, 2),
    *range(200, 5000, 10),
    *range(5000, 10000, 50),
    10000,
]

# Options specific to this script.
parser = ArgumentParser()
parser.add_argument(
    "--data_dir", required=True, type=str, help="where to find the runs"
)
parser.add_argument("--expt", default=None, type=str)
parser.add_argument("--epochs_per_run", default=40, type=int)
42 | 
43 | 
def parent(expts):
    """Evaluate each experiment's checkpoints in chunks via create_partial_metrics.py.

    For every experiment name in ``expts``, the global ``EPOCHS`` schedule is
    split into chunks of ``hparams.epochs_per_run`` epochs. Each chunk is
    handed to a separate ``create_partial_metrics.py`` subprocess with
    ``--gpu``, ``--expt_dir`` and ``--epochs`` arguments. Reads the
    module-level ``hparams`` and ``data_dir``.

    Exits the whole process with the child's return code as soon as one
    subprocess fails.
    """
    # Resolve the helper script next to this file, so this works from any cwd.
    script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "create_partial_metrics.py"
    )
    chunk = hparams.epochs_per_run
    for expt in expts:
        print(f"Processing {expt}")
        expt_dir = data_dir + "/" + expt
        for first_epoch in range(0, len(EPOCHS), chunk):
            these_epochs = ",".join(
                str(e) for e in EPOCHS[first_epoch : first_epoch + chunk]
            )
            cmd = [
                script,
                f"--gpu={hparams.gpu}",
                f"--expt_dir={expt_dir}",
                f"--epochs={these_epochs}",
            ]
            result = subprocess.run(cmd, capture_output=False, shell=False)
            if result.returncode != 0:
                sys.exit(result.returncode)
63 | 
64 | 
hparams = trainer.get_args(parser)

data_dir = hparams.data_dir

# Process the single experiment given by --expt, or else every experiment
# directory under data_dir. Only directories are kept (a stray file would be
# handed to the child script, fail, and abort the whole run), and the list is
# sorted for a deterministic processing order.
if hparams.expt is not None:
    expts = [hparams.expt]
else:
    expts = sorted(
        name
        for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

parent(expts)
75 | 


--------------------------------------------------------------------------------
/scripts/create_partial_metrics.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | 
  3 | import logging
  4 | 
  5 | logging.basicConfig(level=logging.ERROR)
  6 | import csv
  7 | import copy
  8 | import glob
  9 | import os
 10 | import grok
 11 | import numpy as np
 12 | import subprocess
 13 | import torch
 14 | import sys
 15 | from torch.multiprocessing import Process
 16 | from grok import trainer
 17 | from tqdm import tqdm
 18 | from argparse import ArgumentParser
 19 | from collections import Counter
 20 | from grok_runs import RUNS
 21 | from grok_metrics_lib import (
 22 |     DATA_DIR,
 23 |     load_metric_data,
 24 |     get_metric_data,
 25 |     most_interesting,
 26 | )
 27 | 
 28 | 
 29 | # Make N_EPOCHS exponentially spaced sets of epochs from 1 to 10,000
 30 | N_EPOCHS = 32
 31 | BASE = 9999 ** (1.0 / (N_EPOCHS - 1))
 32 | epochs = (BASE ** torch.arange(1, N_EPOCHS).float()).long().tolist()
 33 | DEFAULT_EPOCHS = ",".join([str(i) for i in epochs])
 34 | 
 35 | parser = ArgumentParser()
 36 | parser.add_argument("--expt_dir", type=str, help="where to find the runs")
 37 | parser.add_argument("--epochs", type=str, default=DEFAULT_EPOCHS)
 38 | 
 39 | 
 40 | def child(hparams):
 41 |     expt_dir = hparams.expt_dir
 42 |     epochs = [int(e) for e in hparams.epochs.split(",")]
 43 |     # print("epochs = ", epochs)
 44 |     device = torch.device(f"cuda:{hparams.gpu}")
 45 |     ckpt_dir = expt_dir + "/" + "checkpoints"
 46 |     # ckpt_files = [ckpt_dir + f"/epoch={epoch}.ckpt" for epoch in epochs]
 47 |     hparams.logdir = expt_dir
 48 | 
 49 |     results = {
 50 |         "val_loss": None,
 51 |         "val_accuracy": None,
 52 |     }
 53 | 
 54 |     processed_epochs = []
 55 |     # with tqdm(epochs, unit="epochs", initial=epochs[0], total=epochs[-1]) as pbar:
 56 |     #    last_epoch = epochs[0]
 57 |     for idx, epoch in tqdm(list(enumerate(epochs))):
 58 |         # pbar.update(epoch - last_epoch)
 59 |         # last_epoch = epoch
 60 |         ckpt_files = glob.glob(ckpt_dir + f"/epoch={epoch}-step=*.ckpt")
 61 |         ckpt_files += glob.glob(ckpt_dir + f"/epoch={epoch}.ckpt")
 62 |         try:
 63 |             ckpt_file = ckpt_files[-1]
 64 |             ckpt = torch.load(
 65 |                 ckpt_file,
 66 |                 map_location=f"cuda:{0}",  # FIXME
 67 |             )
 68 |             processed_epochs.append(epoch)
 69 |         except FileNotFoundError:
 70 |             continue
 71 | 
 72 |         for k, v in ckpt["hyper_parameters"].items():
 73 |             setattr(hparams, k, v)
 74 | 
 75 |         new_state_dict = {}
 76 |         for k, v in ckpt["state_dict"].items():
 77 |             if k.startswith("transformer."):
 78 |                 new_state_dict[k] = v
 79 |             else:
 80 |                 new_state_dict["transformer." + k] = v
 81 |         ckpt["state_dict"] = new_state_dict
 82 | 
 83 |         model = trainer.TrainableTransformer(hparams).float()
 84 |         model.load_state_dict(ckpt["state_dict"])
 85 |         model = model.to(device).eval()
 86 |         dl = model.test_dataloader()
 87 |         dl.reset_iteration(shuffle=False)
 88 | 
 89 |         outputs = [model.test_step(batch, idx) for (idx, batch) in enumerate(dl)]
 90 |         r = model.test_epoch_end(outputs)["log"]
 91 |         if results["val_loss"] is None:
 92 |             results["val_loss"] = r["test_loss"].squeeze().unsqueeze(0)
 93 |             results["val_accuracy"] = r["test_accuracy"].squeeze().unsqueeze(0)
 94 |         else:
 95 |             results["val_loss"] = torch.cat(
 96 |                 [results["val_loss"], r["test_loss"].squeeze().unsqueeze(0)], dim=0
 97 |             )
 98 |             results["val_accuracy"] = torch.cat(
 99 |                 [
100 |                     results["val_accuracy"],
101 |                     r["test_accuracy"].squeeze().unsqueeze(0),
102 |                 ],
103 |                 dim=0,
104 |             )
105 | 
106 |     for k, v in results.items():
107 |         results[k] = v.to("cpu")
108 |     results["epochs"] = torch.LongTensor(processed_epochs, device="cpu")
109 |     results["dl"] = dl
110 | 
111 |     os.makedirs(expt_dir + "/activations", exist_ok=True)
112 |     ptfile = (
113 |         expt_dir + f"/activations/activations_{epochs[0]:010d}_{epochs[-1]:010d}.pt"
114 |     )
115 |     torch.save(results, ptfile)
116 | 
117 | 
if __name__ == "__main__":
    hparams = trainer.get_args(parser)
    if hparams.expt_dir is not None:
        child(hparams)
    else:
        # No --expt_dir: for every run in RUNS, evaluate the "most interesting"
        # experiment of each architecture, each in its own subprocess.
        for operation in RUNS:
            print(f"running {operation}")
            ds_len, run = RUNS[operation]
            data = load_metric_data(
                f"{DATA_DIR}/{run}", epochs=10000, load_partial_data=False
            )
            metric_data = most_interesting(get_metric_data(data))
            for arch in metric_data:
                interesting_t = int(metric_data[arch]["T"][0].item())
                expt = f"{arch}_T-{interesting_t}"
                print(f"--> expt {expt}")
                matches = sorted(glob.glob(f"{DATA_DIR}/{run}/{expt}_*"))
                if not matches:
                    # Previously an IndexError that aborted every remaining run.
                    print(f"No directory matching {expt}_* in {DATA_DIR}/{run}; skipping")
                    continue
                cmd = [sys.argv[0], "--expt_dir", matches[0]]
                # Best effort: a failing child does not stop the other experiments.
                subprocess.run(cmd, check=False, shell=False)


--------------------------------------------------------------------------------
/scripts/make_data.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | 
 3 | from argparse import ArgumentParser
 4 | from grok.data import create_data_files, DEFAULT_DATA_DIR
 5 | 
 6 | 
 7 | parser = ArgumentParser()
 8 | parser.add_argument("-d", "--data_directory", type=str, default=DEFAULT_DATA_DIR)
 9 | args = parser.parse_args()
10 | create_data_files(args.data_directory)


--------------------------------------------------------------------------------
/scripts/torch-setup.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Set up torch with magma support
 4 | 
 5 | DIR="`mktemp -d`"
 6 | cd $DIR
 7 | 
 8 | # Install deps
 9 | conda install -y numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
10 | 
11 | # Install torch, magma
12 | conda install -y -c pytorch magma-cuda110
13 | 
14 | # Build torch from scratch
15 | git clone --recursive https://github.com/pytorch/pytorch
16 | cd pytorch
17 | 
18 | # If updating an existing checkout
19 | # git submodule sync
20 | # git submodule update --init --recursive
21 | 
22 | export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
23 | python setup.py install


--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | 
 3 | import grok
 4 | import os
 5 | 
 6 | parser = grok.training.add_args()
 7 | parser.set_defaults(logdir=os.environ.get("GROK_LOGDIR", "."))
 8 | hparams = parser.parse_args()
 9 | hparams.datadir = os.path.abspath(hparams.datadir)
10 | hparams.logdir = os.path.abspath(hparams.logdir)
11 | 
12 | 
13 | print(hparams)
14 | print(grok.training.train(hparams))
15 | 


--------------------------------------------------------------------------------
/scripts/visualize_metrics.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | # coding: utf-8
  3 | import csv
  4 | import json
  5 | import logging
  6 | import os
  7 | import subprocess
  8 | from argparse import ArgumentParser
  9 | from copy import deepcopy
 10 | from glob import glob
 11 | from pprint import pprint
 12 | 
 13 | import blobfile as bf
 14 | import grok
 15 | import matplotlib.pyplot as plt
 16 | import matplotlib.ticker as mtick
 17 | import numpy as np
 18 | import torch
 19 | import yaml
 20 | from tqdm import tqdm
 21 | 
 22 | logger = logging.getLogger(__name__)
 23 | 
 24 | # take args: input_dir output_dir
 25 | parser = ArgumentParser()
 26 | parser.add_argument(
 27 |     "-i",
 28 |     "--input_dir",
 29 |     type=str,
 30 |     required=True,
 31 | )
 32 | parser.add_argument(
 33 |     "-o",
 34 |     "--output_dir",
 35 |     type=str,
 36 |     required=True,
 37 | )
 38 | parser = grok.training.add_args(parser)
 39 | args = parser.parse_args()
 40 | print(args, flush=True)
 41 | 
 42 | if torch.cuda.is_available():
 43 |     device = "cuda"
 44 | else:
 45 |     device = "cpu"
 46 | 
 47 | 
 48 | def load_expt_metrics(
 49 |     expt_dir,
 50 |     args,
 51 | ):
 52 |     """load the metrics for one experiment"""
 53 |     args = deepcopy(args)
 54 | 
 55 |     # load the hparams for this experiment
 56 |     with open(f"{expt_dir}/default/version_0/hparams.yaml", "r") as fh:
 57 |         hparams_dict = yaml.safe_load(fh)
 58 | 
 59 |     for k, v in hparams_dict.items():
 60 |         setattr(args, k, v)
 61 | 
 62 |     # load the summarized validation and training data for every epoch
 63 |     val_data = {
 64 |         "step": [],
 65 |         "epoch": [],
 66 |         "val_loss": [],
 67 |         "val_accuracy": [],
 68 |     }
 69 |     train_data = {
 70 |         "step": [],
 71 |         "epoch": [],
 72 |         "train_loss": [],
 73 |         "train_accuracy": [],
 74 |         "learning_rate": [],
 75 |     }
 76 | 
 77 |     with open(f"{expt_dir}/default/version_0/metrics.csv", "r") as fh:
 78 |         for row in csv.DictReader(fh):
 79 |             if row["train_loss"] != "":
 80 |                 for k in train_data:
 81 |                     if k in ["step", "epoch"]:
 82 |                         v = int(row[k])
 83 |                     else:
 84 |                         v = float(row[k])
 85 |                     train_data[k].append(v)
 86 |             else:
 87 |                 for k in val_data:
 88 |                     if k in ["step", "epoch"]:
 89 |                         v = int(row[k])
 90 |                     else:
 91 |                         v = float(row[k])
 92 |                     val_data[k].append(v)
 93 | 
 94 |     return {
 95 |         "hparams": hparams_dict,
 96 |         "train": train_data,
 97 |         "val": val_data,
 98 |         # "raw": raw_data,
 99 |     }
100 | 
101 | 
def load_run_metrics(
    run_dir,
    args=args,
):
    """Load the metrics of every experiment directly under ``run_dir``.

    All experiments in a run share an architecture and differ in the amount of
    training data. Each immediate sub-directory is loaded with
    ``load_expt_metrics``. Sub-directories lacking ``hparams.yaml`` or
    ``metrics.csv`` are skipped.

    Returns:
        dict mapping each experiment's ``train_data_pct`` hparam to its metrics.
        If two experiments share a value, the later one in sorted directory
        order wins.

    Raises:
        FileNotFoundError: if ``run_dir`` is not an existing directory.
    """
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"run directory not found: {run_dir}")

    # Immediate sub-directories only, sorted for a deterministic load order.
    _, expt_dirs, _ = next(os.walk(run_dir), (run_dir, [], []))
    metric_data = {}
    for expt_dir in tqdm(sorted(expt_dirs), unit="expt"):
        try:
            expt_data = load_expt_metrics(f"{run_dir}/{expt_dir}", args)
        except FileNotFoundError:
            continue  # missing hparams.yaml / metrics.csv
        metric_data[expt_data["hparams"]["train_data_pct"]] = expt_data
    return metric_data
120 | 
121 | 
def add_metric_graph(
    fig,
    ax,
    arch,
    metric,
    metric_data,
    scales,
    cmap="viridis",
    by="step",  # step or epoch
    max_increment=0,
):
    """Plot ``metric`` against ``by`` on ``ax``, one line per training-data fraction.

    ``metric_data`` maps each train_data_pct ``t`` to an experiment's metrics,
    as built by ``load_run_metrics``. Metrics whose name contains "val" are
    read from ``metric_data[t]["val"]``; all others come from
    ``metric_data[t]["train"]``. Lines are coloured by ``t`` using ``cmap``.
    When the ``t`` values span at most 10, the axis gets a legend; otherwise
    ``fig`` gets a colorbar.

    Args:
        fig: figure that receives the colorbar.
        ax: axes to draw on.
        arch: architecture tag used as the line-label prefix.
        metric: metric key, e.g. "val_loss" or "train_accuracy".
        metric_data: see above.
        scales: dict with "x" and "y" axis scale names.
        cmap: matplotlib colormap name.
        by: x-axis key, "step" or "epoch".
        max_increment: if > 0, only points whose ``by`` value is
            <= max_increment are plotted.
    """
    ax.set_title(metric)
    ax.set_xscale(scales["x"])
    ax.set_yscale(scales["y"])
    ax.set_xlabel(by)

    if "accuracy" in metric:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())
        ax.axis(ymin=1e-16, ymax=101)
    if "loss" in metric:
        ax.axis(ymin=1e-16, ymax=15)

    logger.debug(f"processing {metric}")
    T = sorted(metric_data.keys())
    T_max = int(T[-1])
    T_min = int(T[0])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=T[0], vmax=T[-1]))
    colors = sm.to_rgba(T)
    split = "val" if "val" in metric else "train"
    for t, color in zip(T, colors):
        this_data = metric_data[t][split]
        X = this_data[by]
        Y = this_data[metric]
        if max_increment > 0:
            # assumes X is non-decreasing, so the kept points form a prefix
            X = [x for x in X if x <= max_increment]
            Y = Y[: len(X)]

        if len(X) != len(Y):
            logger.warning(f"Mismatched data: {metric} at t={t}")
            continue
        if not Y:
            logger.warning(f"No data for {metric} at t={t}")
            continue

        label = arch + f" t={t}"
        if "accuracy" in metric:
            label += " (max = %.2f)" % max(Y)
        elif "loss" in metric:
            label += " (min = %.2f)" % min(Y)
        ax.plot(X, Y, label=label, color=color)

    if T_max - T_min <= 10:
        ax.legend()
    else:
        # Span > 10 here, so the tick step below is at least 2.
        fig.colorbar(
            sm,
            ax=ax,
            label="% training data",
            ticks=range(T_min, T_max + 1, int((T_max - T_min) / 5)),
        )
192 | 
193 | 
def add_max_accuracy_graph(
    ax,
    arch,
    metric,
    metric_data,
    scales,
    by="step",
    max_increment=0,
):
    """Plot the highest ``metric`` value reached by each experiment against train_data_pct.

    ``metric_data`` maps train_data_pct to experiment metrics. Metrics whose
    name contains "val" are read from the "val" split; all others come from
    "train". Both axes are formatted as percentages over [0, 100] x [1e-16, 101].

    Args:
        ax: axes to draw on.
        arch: architecture tag used in the line label.
        metric: metric key, e.g. "val_accuracy".
        metric_data: see above.
        scales: dict with "x" and "y" axis scale names.
        by: key ("step" or "epoch") used by ``max_increment``.
        max_increment: if > 0, only values whose ``by`` value is
            <= max_increment count toward the maximum. 0 means the whole series.

    Experiments with no values in range are plotted as NaN.
    """
    ax.set_title(f"max {metric}")
    ax.set_xlabel("% of total data trained on")
    ax.axis(xmin=0, xmax=100, ymin=1e-16, ymax=101)
    ax.set_xscale(scales["x"])
    ax.set_yscale(scales["y"])
    # Formatters are set after the scales: changing a scale resets the formatter.
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
    ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    T = sorted(metric_data.keys())
    split = "val" if "val" in metric else "train"
    Y = []
    for t in T:
        this_data = metric_data[t][split]
        if max_increment > 0:
            # assumes the `by` values are non-decreasing, so in-range values form a prefix
            max_idx = len([x for x in this_data[by] if x <= max_increment])
        else:
            # None keeps the whole series (a -1 slice would drop the last value)
            max_idx = None
        try:
            Y.append(max(this_data[metric][:max_idx]))
        except ValueError:
            Y.append(np.nan)  # empty series

    ax.set_xticks(np.arange(0, 100, 5))
    label = f"max {metric} {arch}"
    ax.plot(T, Y, label=label)
239 | 
240 | 
def create_loss_curves(
    metric_data,
    arch,
    operation,
    # epochs,
    most_interesting_only=False,
    image_dir=args.output_dir,
    by="step",
    max_increment=0,
    cmap="viridis",
):
    """Save a 3x2 grid of metric curves for one architecture as a PNG.

    Panels, row by row: val_loss, val_accuracy / train_loss, train_accuracy /
    learning_rate (the bottom-right panel stays empty). Each panel is drawn by
    ``add_metric_graph`` with a log x axis over ``by`` and a linear y axis.
    The file is written to
    ``{image_dir}/loss_curves/{operation}_loss_curves_{arch}__upto_{max_increment:010d}_{by}[_most_interesting].png``.
    """
    scales = dict(x="log", y="linear")
    nrows, ncols = 3, 2
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 5))

    panel_metrics = {
        (0, 0): "val_loss",
        (0, 1): "val_accuracy",
        (1, 0): "train_loss",
        (1, 1): "train_accuracy",
        (2, 0): "learning_rate",
    }
    for (row, col), metric in panel_metrics.items():
        add_metric_graph(
            fig,
            axs[row, col],
            arch,
            metric,
            metric_data,
            scales,
            cmap,
            by,
            max_increment=max_increment,
        )

    fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s")
    fig.tight_layout()

    suffix = "_most_interesting" if most_interesting_only else ""
    img_file = (
        f"{image_dir}/loss_curves/{operation}_loss_curves_{arch}__upto_{max_increment:010d}_{by}"
        + suffix
        + ".png"
    )
    os.makedirs(os.path.split(img_file)[0], exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)
330 | 
331 | 
def create_max_accuracy_curves(
    metric_data,
    arch,
    operation,
    by="step",
    max_increment=0,
    image_dir=args.output_dir,
):
    """Save max val/train accuracy versus training-data fraction as a PNG.

    Two stacked panels (val_accuracy on top, train_accuracy below) are drawn by
    ``add_max_accuracy_graph`` on linear axes, each with a legend. The file is
    written to
    ``{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}_upto_{max_increment:010d}_{by}.png``.
    """
    scales = dict(x="linear", y="linear")
    nrows, ncols = 2, 1
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 5))

    for panel, metric in zip(axs, ("val_accuracy", "train_accuracy")):
        add_max_accuracy_graph(
            panel,
            arch,
            metric,
            metric_data,
            scales,
            by=by,
            max_increment=max_increment,
        )
        panel.legend()

    fig.suptitle(f"{operation} {arch} {max_increment:06d} {by}s")
    fig.tight_layout()

    img_file = f"{image_dir}/max_accuracy/{operation}_max_accuracy_{arch}_upto_{max_increment:010d}_{by}.png"
    os.makedirs(os.path.split(img_file)[0], exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)
380 | 
381 | 
def create_tsne_graphs(
    operation,
    expt,
    run_dir,
    image_dir=args.output_dir,
):
    """Run a 2-D t-SNE on saved validation losses and save a scatter plot.

    Loads every ``{run_dir}/activations/activations_*.pt`` file (sorted by
    name) and averages each file's ``val_loss`` over its last dimension. The
    results are concatenated along dim 0 and transposed, so each t-SNE point
    is one row of that transposed matrix. The plot is written to
    ``{image_dir}/tsne/{operation}_{expt}.png``.

    If no activation files exist, a message is printed and nothing is written.

    NOTE(review): ``TSNE`` is not imported in this module (presumably
    ``sklearn.manifold.TSNE``); until it is, this function raises NameError.
    """
    saved_pt_dir = f"{run_dir}/activations"
    pattern = saved_pt_dir + "/activations_*.pt"
    print(f"glob = {pattern}")
    # In this module `glob` is the function (`from glob import glob`), not the module.
    files = sorted(glob(pattern))
    print(f"files = {files}")
    if not files:
        # torch.cat([]) below would raise; there is simply nothing to embed.
        print(f"No activation files found for {expt}; skipping T-SNE")
        return

    loss_ts = []
    accuracy_ts = []
    epochs_ts = []
    for file in files:
        print(f"Loading {file}")
        # NOTE(review): these files also pickle a dataloader ("dl"); torch>=2.6
        # defaults to weights_only=True and may refuse to load them.
        saved_pt = torch.load(file)
        loss_ts.append(saved_pt["val_loss"].mean(dim=-1))
        accuracy_ts.append(saved_pt["val_accuracy"])
        epochs_ts.append(saved_pt["epochs"].squeeze())

    loss_t = torch.cat(loss_ts, dim=0).T.detach()
    accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach()
    epochs_t = torch.cat(epochs_ts, dim=0).detach()
    print(loss_t.shape)
    print(accuracy_t.shape)
    print(epochs_t.shape)

    print("Doing T-SNE..")
    loss_tsne = TSNE(n_components=2, init="pca").fit_transform(loss_t)
    print("...done T-SNE.")

    ncols = 1
    nrows = 1
    fig_width = ncols * 8
    fig_height = nrows * 5
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_width, fig_height))

    axs.scatter(loss_tsne[:, 0], loss_tsne[:, 1])

    img_file = f"{image_dir}/tsne/{operation}_{expt}.png"
    d = os.path.split(img_file)[0]
    os.makedirs(d, exist_ok=True)
    print(f"Writing {img_file}")
    fig.savefig(img_file)
    plt.close(fig)
436 | 
437 | 
def get_arch(metric_data):
    """Return an architecture tag built from the first experiment's hparams.

    Format: ``L-{n_layers}_H-{n_heads}_D-{d_model}_B-{batchsize}_S-{random_seed}_DR-{dropout}``.
    Raises IndexError when ``metric_data`` is empty.
    """
    first_key = list(metric_data)[0]
    hp = metric_data[first_key]["hparams"]
    fields = (
        ("L", "n_layers"),
        ("H", "n_heads"),
        ("D", "d_model"),
        ("B", "batchsize"),
        ("S", "random_seed"),
        ("DR", "dropout"),
    )
    return "_".join(f"{tag}-{hp[name]}" for tag, name in fields)
443 | 
444 | 
def get_operation(metric_data):
    """Return the operation name for the first experiment in ``metric_data``.

    The name is the second value returned by
    ``grok.data.ArithmeticDataset.get_file_path``, called with the
    experiment's ``math_operator`` and ``operand_length`` hparams.
    """
    hp = metric_data[list(metric_data)[0]]["hparams"]
    _, operation = grok.data.ArithmeticDataset.get_file_path(
        hp["math_operator"], hp["operand_length"]
    )
    return operation
452 | 
453 | 
def get_max_epochs(metric_data):
    """Return the ``max_epochs`` hparam of the first experiment in ``metric_data``."""
    first_expt = list(metric_data.values())[0]
    return first_expt["hparams"]["max_epochs"]
458 | 
459 | 
rundir = args.input_dir

try:
    metric_data = load_run_metrics(rundir, args)
    arch = get_arch(metric_data)
    operation = get_operation(metric_data)
    max_epochs = get_max_epochs(metric_data)

    for by in ["step", "epoch"]:
        create_loss_curves(metric_data, arch, operation, by=by)

    # One max-accuracy frame per cutoff epoch; cutoffs are spaced as 2**(k/10).
    # Several cutoffs round to the same integer epoch, so repeats are skipped
    # instead of re-rendering (and overwriting) the same image.
    by = "epoch"
    last_i = -1
    for cutoff in sorted(set(2 ** (np.arange(167) / 10))):
        if cutoff > max_epochs:
            break
        i = int(round(cutoff))
        if i == last_i:
            continue
        last_i = i
        create_max_accuracy_curves(
            metric_data,
            arch,
            operation,
            by=by,
            max_increment=i,
        )

    # Stitch the frames into a video. The zero-padded frame numbers are not
    # contiguous, so use ffmpeg's glob pattern type rather than a %d sequence;
    # the padding keeps lexicographic order equal to numeric order.
    in_files = os.path.join(
        args.output_dir,
        "max_accuracy",
        f"{operation}_max_accuracy_{arch}_upto_*_{by}.png",
    )
    out_file = os.path.join(args.output_dir, f"{operation}_{arch}_max_accuracy.mp4")
    cmd = [
        "ffmpeg",
        "-y",
        "-r",
        "16",
        "-pattern_type",
        "glob",
        "-i",
        in_files,
        "-vcodec",
        "libx264",
        "-crf",
        "25",
        "-pix_fmt",
        "yuv420p",
        out_file,
    ]
    subprocess.check_call(cmd)

except Exception as e:
    # Report and move on; Ctrl-C / SystemExit are no longer swallowed.
    print(f"{rundir} failed: {e}")
511 | 


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import find_packages, setup
 2 | 
 3 | setup(
 4 |     name="grok",
 5 |     packages=find_packages(),
 6 |     version="0.0.1",
 7 |     install_requires=[
 8 |         "pytorch_lightning",
 9 |         "blobfile",
10 |         "numpy",
11 |         "torch",
12 |         "tqdm",
13 |         "scipy",
14 |         "mod",
15 |         "matplotlib",
16 |     ],
17 | )
18 | 


--------------------------------------------------------------------------------