├── .gitignore ├── LICENSE ├── README.md ├── primitives └── functional.py ├── rasp ├── __init__.py ├── core.py ├── daily.py ├── manual.py ├── model.py └── parser.py ├── setup.py └── test ├── core_test.py ├── manual_test.py ├── model_test.py └── transformer_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | daily.py 132 | .rasp_cache 133 | .vscode 134 | notebooks 135 | rasp_repl/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yash Bonde 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rasp 2 | 3 | Implementation of the Restricted Access Sequence Processing (RASP) transformer language from the ["Thinking like Transformers"](https://arxiv.org/pdf/2106.06981.pdf) paper. From the paper: 4 | 5 | > RASP can be used to program solutions to tasks that could conceivably be learned by a Transformer, and how a Transformer can be trained to mimic a RASP solution. 6 | 7 | ## Installation 8 | 9 | Simple pip install: `pip install git+https://github.com/yashbonde/rasp` 10 | 11 | ### What is RASP? 12 | 13 | - `select` (creating selection matrices called selectors): This is similar to building the attention matrix `q x k.T => attn`. 14 | - `aggregate` (collapsing selectors and s-ops into a new s-op): This is like the value multiplication `attn x v`. 15 | - `selector_width` (creating an s-op from a selector): returns, for each position, how many positions its selector selects, e.g. `selector_width(select(tokens,tokens,==))("hello")=[1,1,2,2,1]`. 16 | 
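A minimal sketch of how `select` and `aggregate` compose, using the reference implementations in `rasp/manual.py` (the values mirror `test/manual_test.py`; `selector_width` is still a commented-out TODO there):

```python
import torch
from rasp.manual import select, aggregate, tokens, flip

# a selector is a boolean matrix built by comparing two 1-D tensors
s = select(torch.Tensor([1, 2, 2]), torch.Tensor([0, 1, 2]), "==")
# s == [[F, F, F], [T, F, F], [F, T, T]]

# aggregate averages the selected values row-wise (rows selecting nothing give 0)
out = aggregate(s, torch.Tensor([4, 6, 8]))
# out == [0., 4., 7.]

# composing flip (a select over reversed indices) with aggregate reverses a string
assert tokens(aggregate(flip("hey"), tokens("hey"))) == "yeh"
```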
17 | ### Structure 18 | 19 | - `rasp/`: main code library 20 | - `test/`: tests 21 | - `primitives`: neural functions and primitives built using the manual code, which can be used for training 22 | 23 | ## Code 24 | 25 | ### Example: Reverse [WIP] 26 | 27 | So you can build complex flows directly in terms of the architecture, e.g. building a reverse function: 28 | ```python 29 | from rasp import RaspModule 30 | 31 | reverse = RaspModule(''' 32 | def reverse(tokens): 33 | opp_idx = length - indices - 1; 34 | flip = select (indices ,opp_idx ,==) ; 35 | return aggregate (flip, tokens); 36 | ''') 37 | assert reverse("hey") == "yeh" 38 | ``` 39 | 40 | This would create a neural network as follows: 41 | ```python 42 | class Flip(nn.Module): 43 | # flip = select (indices ,opp_idx ,==) ; 44 | def __init__(self): 45 | self.n_head = 1; 46 | 47 | def forward(self): 48 | pass 49 | ``` 50 | 51 | ### Tests 52 | 53 | All the code for tests is given in `test/`; run `pytest -v`. 54 | ## Experiments 55 | 56 | 1. Reverse e.g.: `reverse("abc")="cba"` 57 | 2. Histograms, with a unique beginning-of-sequence (BOS) token `$` (e.g., `hist_bos("$aba")=[$,2,1,2]`) and without it (e.g., `hist_nobos("aba")=[2,1,2]`) 58 | 3. Double-Histograms, with BOS: for each token, the number of unique tokens with the same histogram value as itself. E.g.: `hist2("$abbc")=[$,2,1,1,2]` 59 | 4. Sort, with BOS: ordering the input tokens lexicographically. e.g.: `sort("$cba")="$abc"` 60 | 5. Most-Freq, with BOS: returning the unique input tokens in order of decreasing frequency, with original position as a tie-breaker and the BOS token for padding. E.g.: `most_freq("$abbccddd")="$dbca$$$$"` 61 | 6. Dyck-i PTF, for `i = 1, 2`: the task of returning, at each output position, whether the input prefix up to and including that position is a legal Dyck-i sequence (`T`), and if not, whether it can (`P`) or cannot (`F`) be continued into a legal Dyck-i sequence. E.g: `Dyck1_ptf("()())")="PTPTF"` 62 | 63 | ### References: 64 | 65 | - [Tokens List](https://github.com/tech-srl/RASP/blob/main/RASP_support/zzantlr/RASP.tokens) 66 | - [RASP Cheatsheet](https://github.com/tech-srl/RASP/blob/main/cheat_sheet.pdf) 67 | -------------------------------------------------------------------------------- /primitives/functional.py: -------------------------------------------------------------------------------- 1 | from rasp.manual import * 2 | import numpy as np 3 | 4 | def identity(input): 5 | attn = select(indices(input), indices(input), "==") 6 | attn = [[attn]] # shape: [number of blocks, number of heads in each block] 7 | return input, attn 8 | 9 | def reverse(input): 10 | # flip = select ( indices , length - indices - 1 ,==) 11 | # aggregate (flip , tokens ) 12 | i = indices(input) 13 | l = length(input) 14 | attn = select(i, l - i - 1, "==") 15 | attn = [[attn]] # shape: [number of blocks, number of heads in each block] 16 | 17 | if isinstance(input, str): 18 | out = input[::-1] 19 | elif isinstance(input, (list, torch.Tensor, np.ndarray)): 20 | out = [x[::-1] for x in input] 21 | return out, attn 22 | --------------------------------------------------------------------------------
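The `(output, attention_masks)` pair returned by these primitives is exactly the target format the transformer in `rasp/model.py` expects, so a manual primitive can supervise the model directly. A minimal sketch, mirroring `test/core_test.py` (assumes it is run from the repo root so that `primitives/` is importable):

```python
from rasp.model import get_model
from primitives import functional as F

model = get_model()

# plain forward pass: no target given, so loss is None
logits, loss = model("hello")

# a manual primitive returns (expected output, per-block attention masks);
# passing it as the target adds cross-entropy on tokens plus MSE on the attention maps
target = F.identity("hello")
logits, loss = model("hello", target)
```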
get_gist("62df9d16858a43775c22a6af00a8d707") 18 | -------------------------------------------------------------------------------- /rasp/core.py: -------------------------------------------------------------------------------- 1 | # Since building simple primitives is the primary task of this language, 2 | # training the model should have first class support. 3 | # now you can directly load a primitive as follows: 4 | # 5 | # >>> from rasp import Primitive 6 | # >>> reverse = Primitive("reverse") 7 | # >>> reverse("hey") 8 | # ... "yeh" 9 | 10 | import json 11 | import numpy as np 12 | from tqdm import trange 13 | 14 | from rasp.model import * 15 | from rasp.manual import ivocab, vocab, tokens 16 | from rasp.daily import Hashlib 17 | 18 | def set_seed(seed): 19 | if seed is not None: 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | 25 | class Primitive: 26 | # primtive class is a Transformer neural network whose objective 27 | # is to perform that particular task. 28 | def __init__(self, name, code = None, seed = 4, **model_kwargs): 29 | set_seed(4) 30 | if code is not None: 31 | raise NotImplementedError("code parsing is still not implemented, hold your horses!") 32 | 33 | self.model = get_model(**model_kwargs) 34 | self.name = name 35 | 36 | str_ = f"{name}-" + json.dumps(self.model.config.get_json()) 37 | self._hash = Hashlib.sha256(str_) 38 | 39 | def get_parameters(self): 40 | return self.model.parameters() 41 | 42 | def __call__(self, *args, **kwargs): 43 | return self.model(*args, **kwargs) 44 | 45 | def viz(self, x): 46 | # this is not the best visualisation of attention since the values are 47 | # in float. But this is good enough to see what's up 48 | print("-+-" + "-" * len(x) * 2) 49 | print(" | " + " ".join(x)) 50 | print("-+-" + "-" * len(x) * 2) 51 | r = self(x, output_dict = True) 52 | a = r.attns[0] * 10 53 | a = a.long() 54 | a = a.tolist()[0] 55 | for i in range(len(a)): 56 | print(f"{x[i]}|", " ".join([str(b) for b in a[i]])) 57 | print("-+-" + "-" * len(x) * 2) 58 | 59 | def train(self, ds, man_fn, optim_name = "Adam", n_epochs = 5, pbar = False, **optimiser_params): 60 | """training any primitive has first class support since this is what each primitive is""" 61 | optim = getattr(torch.optim, optim_name)(self.get_parameters(), **optimiser_params) 62 | for i in range(n_epochs): 63 | bar = trange(len(ds)) if pbar else range(len(ds)) 64 | for x, j in zip(ds, bar): 65 | t = man_fn(x) 66 | out, loss = self(idx = x, targets = t) 67 | 68 | optim.zero_grad() 69 | loss.backward() 70 | optim.step() 71 | 72 | if j % 50 == 0: 73 | print(loss) 74 | self.viz(ds[0]) 75 | 76 | def get_vocab(): 77 | return vocab, ivocab 78 | -------------------------------------------------------------------------------- /rasp/daily.py: -------------------------------------------------------------------------------- 1 | # whole bunch of utility functions I use in day to day 2 | # - @yashbonde / https://github.com/yashbonde 3 | # 4 | # Now this is a simple script and cannot be loaded like a package 5 | # so you'll need to import it. 
This is how you can do it 6 | """ 7 | try: 8 | from daily import * 9 | except ImportError as e: 10 | import requests 11 | x = requests.get().content 12 | with open("daily.py", "wb") as f: 13 | f.write(x) 14 | from daily import * 15 | """ 16 | 17 | import logging 18 | logging.basicConfig(level="INFO") 19 | def log(x: str, *args): 20 | x = str(x) 21 | for y in args: 22 | x += " " + str(y) 23 | logging.info(x) 24 | 25 | 26 | def fetch(url): 27 | # efficient loading of URLS 28 | import os, tempfile, hashlib, requests 29 | fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest()) 30 | if os.path.isfile(fp) and os.stat(fp).st_size > 0: 31 | with open(fp, "rb") as f: 32 | dat = f.read() 33 | else: 34 | print("fetching", url) 35 | dat = requests.get(url).content 36 | with open(fp+".tmp", "wb") as f: 37 | f.write(dat) 38 | os.rename(fp+".tmp", fp) 39 | return dat 40 | 41 | 42 | def get_files_in_folder(folder, ext = [".txt"]): 43 | # this method is faster than glob 44 | import os 45 | all_paths = [] 46 | for root,_,files in os.walk(folder): 47 | for f in files: 48 | for e in ext: 49 | if f.endswith(e): 50 | all_paths.append(os.path.join(root,f)) 51 | return all_paths 52 | 53 | 54 | def json_load(path): 55 | # load any JSON like file with comment strings '//' 56 | # json files should have description strings so it's more friendly 57 | # but it won't load with json.load(f) so read the file, remove the 58 | # comments and json.loads(text) 59 | import json, re 60 | with open(path, 'r') as f: 61 | text = f.read() 62 | text = re.sub(r"\s*(\/{2}.*)\n", "\n", text) 63 | config = json.loads(text) 64 | return config 65 | 66 | 67 | def folder(x): 68 | # get the folder of this file path 69 | import os 70 | return os.path.split(os.path.abspath(x))[0] 71 | 72 | 73 | # class to handle hashing methods 74 | class Hashlib: 75 | # local imports don't work! 76 | # import hashlib 77 | 78 | def sha256(x): 79 | import hashlib 80 | x = x if isinstance(x, bytes) else x.encode("utf-8") 81 | return hashlib.sha256(x).hexdigest() 82 | 83 | def md5(x): 84 | import hashlib 85 | x = x if isinstance(x, bytes) else x.encode("utf-8") 86 | return hashlib.md5(x).hexdigest() 87 | -------------------------------------------------------------------------------- /rasp/manual.py: -------------------------------------------------------------------------------- 1 | # Implementation of "Thinking Like Transformers" (https://arxiv.org/pdf/2106.06981.pdf) 2 | # full repo: https://github.com/tech-srl/RASP 3 | # @yashbonde - 18.06.2021 4 | # MIT License 5 | # 6 | # Why build this? 7 | # - learning how to write languages is the best way to learn how to minimise useless shit 8 | # and maximise simplicity of code + was fun to code whilst in Deep Thoughts! 9 | # 10 | # Where can I use this? 11 | # - See the examples, if it's not there then will add it later. 12 | # 13 | # Things that are different from the paper 14 | # - 15 | # TODO: 16 | # - implement conditionals 17 | # - additional operators such as `in`, `sort`, `count` 18 | 19 | import string 20 | import numpy as np 21 | import torch 22 | import einops as ein 23 | 24 | # NOTE: This is a demo code and not meant to be a production thing 25 | # changing the vocab will break some tests, so for the sake of your 26 | # and my sanity don't change this. 
27 | vocab = {k:i for i,k in enumerate(string.ascii_lowercase + "$")} 28 | ivocab = {i:k for k,i in vocab.items()} 29 | 30 | # ---- built in 31 | def tokens(x, bos = False): 32 | """ 33 | if bos == True, then output has bos tag added 34 | 35 | ### Always PAD ### 36 | 37 | # tokens("hello") = [ 7, 4, 11, 11, 14] 38 | # tokens(tokens("hello")) = "hello" 39 | # tokens(["hello", "hello"]) = [[7, 4, 11, 11, 14], [7, 4, 11, 11, 14]] 40 | # tokens([[7, 4, 11, 11, 14], [7, 4, 11, 11, 14]]) = ["hello", "hello"] 41 | 42 | Logic Flow: 43 | # Case A (str only): "hello" 44 | # Case B (list of str): ["hello", "hello"] 45 | # Case C (tensor 1D): [ 7, 4, 11, 11, 14] 46 | # Case D (tensor 2D): [[ 7, 4, 11, 11, 14], [ 7, 4, 11, 11, 14]] 47 | 48 | can consume strings, lists, arrays and tensors 49 | """ 50 | 51 | if isinstance(x, str): 52 | # Case A (str only): "hello" 53 | out = torch.Tensor([vocab["$"]] + [vocab[t] for t in x.lower()]).long() 54 | if not bos: 55 | out = out[1:] 56 | return out 57 | elif isinstance(x, list) and isinstance(x[0], str): 58 | # Case B (list of str): ["hello", "hello"] 59 | m = max([len(y) for y in x]) 60 | for i,y in enumerate(x): 61 | x[i] = x[i] + "".join(["$" for _ in range(m - len(x[i]))]) 62 | return torch.cat([tokens(s, bos).unsqueeze(0) for s in x]).long() 63 | else: 64 | assert isinstance(x, (torch.Tensor, np.ndarray)), "Can consume only strings and torch.Tensors / np.ndarrays" 65 | # input is likely a tensor 66 | if len(x.shape) == 1: 67 | # Case C (tensor 1D): [ 7, 4, 11, 11, 14] 68 | out = "".join([ivocab[t] for t in x.tolist()]) 69 | # FORCE REMOVE PADDING: drop everything from the first pad "$" after position 0 70 | if "$" in out[1:]: 71 | out = out[:out[1:].index("$") + 1] 72 | if not bos and out[0] == "$": 73 | out = out[1:] 74 | 75 | return out 76 | else: 77 | # Case D (tensor 2D): [ [ 7, 4, 11, 11, 14], [ 7, 4, 11, 11, 14]] 78 | return [tokens(s, bos) for s in x] 79 | 80 | def indices(x): 81 | # indices("hello") = [0,1,2,3,4] 82 | return torch.arange(len(x)).float() 83 | 84 | def length(x): 85 | # length("hello") = [5,5,5,5,5] 86 | return torch.ones(len(x)) * len(x) 87 | 88 | 89 | # --- element wise 90 | def logical(x, op, y = None): 91 | # logical(x, "and", y) 92 | def _or(x, y): 93 | return torch.logical_or(x.contiguous().view(-1), y.contiguous().view(-1)).view(x.shape) 94 | def _and(x, y): 95 | return torch.logical_and(x.contiguous().view(-1), y.contiguous().view(-1)).view(x.shape) 96 | def _not(x, y): 97 | return torch.logical_not(x.contiguous().view(-1)).view(x.shape) 98 | def _xor(x, y): 99 | return torch.logical_xor(x.contiguous().view(-1), y.contiguous().view(-1)).view(x.shape) 100 | 101 | assert op in ["or", "and", "not", "xor"], f"`{op}` not supported" 102 | if op != "not": 103 | assert x.shape == y.shape, f"Shapes must be same, got {x.shape}, {y.shape}" 104 | out = {"or": _or, "and": _and, "not": _not, "xor": _xor}[op](x, y) 105 | return out 106 | 107 | def elementwise(x, op, y): 108 | # elementwise(x, "-", y) 109 | if op in ["or", "and", "not", "xor"]: 110 | return logical(x, op, y) 111 | 112 | def _add(x, y): return x + y 113 | def _mul(x, y): return x * y 114 | def _sub(x, y): return x - y 115 | def _div(x, y): 116 | out = torch.div(x, y) 117 | out[out == float("inf")] = 0 118 | out = torch.nan_to_num(out, 0) 119 | return out 120 | 121 | assert x.shape == y.shape, f"Shapes must be same, got {x.shape}, {y.shape}" 122 | assert op in ["+", "-", "*", "/"], f"`{op}` not supported" 123 | 124 | out = {"+":_add, "-":_sub, "*":_mul, "/":_div}[op](x, y) 125 | return out 126 | 127 | 128 | # --- select 129 | 
def select(m1: torch.Tensor, m2, op): 130 | # creating boolean matrices called "selectors" 131 | if isinstance(m2, (bool, int)): 132 | m2 = torch.ones(m1.shape) * m2 133 | 134 | assert len(m1.shape) == 1 135 | assert len(m2.shape) == 1 136 | 137 | rows = ein.repeat(m1, "w -> n w", n = m2.shape[0]) 138 | cols = ein.repeat(m2, "h -> h n", n = m1.shape[0]) 139 | 140 | init_shape = rows.shape 141 | out = { 142 | "==": torch.eq, 143 | "!=": lambda *x: ~torch.eq(*x), 144 | "<=": torch.less_equal, 145 | "<": torch.less, 146 | ">": torch.greater, 147 | ">=": torch.greater_equal, 148 | }[op](rows.contiguous().view(-1), cols.contiguous().view(-1)) 149 | out = out.view(*init_shape) 150 | 151 | return out 152 | 153 | # --- aggregate 154 | def aggregate(s, x, agg = "mean"): 155 | # collapsing selectors and s-ops into new s-ops 156 | x = ein.repeat(x, "w -> n w", n = s.shape[0]) 157 | sf = s.float() 158 | y = x * sf 159 | 160 | if agg == "mean": 161 | ym = y.sum(1) / sf.sum(1) 162 | else: 163 | raise ValueError(f"agg: `{agg}` not found") 164 | 165 | return torch.nan_to_num(ym, 0) 166 | 167 | # --- simple select aggregate 168 | def flip(x): 169 | i = indices(x); l = length(x) 170 | return select(i, l-i-1, "==") 171 | 172 | # --- selector_width 173 | 174 | # def selector_width(x): 175 | # pass 176 | 177 | # def selector_width (sel ,assume_bos = False): 178 | # light0 = indicator ( indices == 0) 179 | # or0 = sel or select_eq ( indices ,0) 180 | # and0 =sel and select_eq ( indices ,0) 181 | # or0_0_frac =aggregate (or0 , light0 ) 182 | # or0_width = 1 / or0_0_frac 183 | # and0_width = aggregate (and0 ,light0 ,0) 184 | # 185 | # # if has bos , remove bos from width 186 | # # (doesn ’t count , even if chosen by 187 | # # sel) and return . 188 | # bos_res = or0_width - 1 189 | # 190 | # # else , remove 0 - position from or0 , 191 | # # and re -add according to and0 : 192 | # nobos_res = bos_res + and0_width 193 | # 194 | # return bos_res if assume_bos else 195 | # nobos_res -------------------------------------------------------------------------------- /rasp/model.py: -------------------------------------------------------------------------------- 1 | # give proper credits, stole from: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 2 | # but he won't mind! 
3 | # modified for yashbonde/rasp 4 | 5 | import math 6 | import json 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | import einops as ein 11 | 12 | from rasp.manual import vocab, tokens 13 | 14 | # ------ configurations ------ # 15 | 16 | class Config: 17 | def __init__(self, **kwargs): 18 | self.vocab_size = len(vocab) 19 | self.n_embd = 18 20 | self.block_size = 32 21 | self.dropout = 0.0 22 | self.n_layer = 1 23 | self.n_head = 1 24 | for k,v in kwargs.items(): 25 | setattr(self, k, v) 26 | 27 | def get_json(self): 28 | return json.dumps(vars(self)) 29 | 30 | # ------ response ------ # 31 | 32 | class Response: 33 | def __init__(self, logits, loss, attns): 34 | self.logits = logits 35 | self.loss = loss 36 | self.attns = attns 37 | self.tokens = tokens(logits.argmax(-1)) 38 | 39 | # ------ model ------ # 40 | 41 | class SelfAttention(nn.Module): 42 | def __init__(self, config): 43 | super().__init__() 44 | assert config.n_embd % config.n_head == 0 45 | # single qkv like GPT 46 | self.qkv = nn.Linear(config.n_embd, config.n_embd * 3) 47 | self.split_size = config.n_embd 48 | 49 | # output projection 50 | self.proj = nn.Linear(config.n_embd, config.n_embd) 51 | 52 | # causal mask to ensure that attention is only applied to the left in the input sequence 53 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 54 | .view(1, 1, config.block_size, config.block_size)) 55 | self.n_head = config.n_head 56 | 57 | def forward(self, x, attn_mask = None): 58 | B, T, C = x.size() 59 | 60 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 61 | q,k,v = self.qkv(x).split(self.split_size, 2) 62 | 63 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 64 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 65 | if attn_mask is not None: 66 | att = att + attn_mask 67 | att = F.softmax(att, dim=-1) 68 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 69 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 70 | 71 | # output projection 72 | y = self.proj(y) 73 | return y, att 74 | 75 | 76 | class Block(nn.Module): 77 | """ an unassuming Transformer block """ 78 | 79 | def __init__(self, config): 80 | super().__init__() 81 | self.ln1 = nn.LayerNorm(config.n_embd) 82 | self.ln2 = nn.LayerNorm(config.n_embd) 83 | self.split_size = config.n_embd 84 | self.attn = SelfAttention(config) 85 | self.mlp = nn.Sequential( 86 | nn.Linear(config.n_embd, 4 * config.n_embd), 87 | nn.GELU(), 88 | nn.Linear(4 * config.n_embd, config.n_embd), 89 | nn.Dropout(config.dropout), 90 | ) 91 | 92 | self.n_head = config.n_head 93 | 94 | def forward(self, x): 95 | x, attn_mask = x 96 | y = self.ln1(x) 97 | y, att = self.attn(y, attn_mask) 98 | x = x + y 99 | x = x + self.mlp(self.ln2(x)) 100 | return [x, att] 101 | 102 | 103 | class FullTransformer(nn.Module): 104 | """A full transformer model with focus on data and I/O. 
105 | Can consume strings, lists, arrays and torch-tensors.""" 106 | 107 | def __init__(self, config): 108 | super().__init__() 109 | 110 | # input embedding stem 111 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 112 | self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) 113 | self.drop = nn.Dropout(config.dropout) 114 | 115 | # transformer 116 | self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) 117 | 118 | # decoder head 119 | self.ln_f = nn.LayerNorm(config.n_embd) 120 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 121 | self.block_size = config.block_size 122 | 123 | self.config = config 124 | 125 | @property 126 | def num_parameters(self): 127 | return sum(p.numel() for p in self.parameters()) 128 | 129 | def get_device(self): 130 | return next(self.parameters()).device 131 | 132 | def format_inputs_and_tokens(self, idx, targets): 133 | d = self.get_device() 134 | P_token = "$"; P_id = vocab[P_token]; 135 | 136 | # format input and convert to tokens 137 | if not isinstance(idx, torch.Tensor): 138 | idx = tokens(idx) 139 | if len(idx.shape) == 1: 140 | idx = idx.unsqueeze(0) 141 | 142 | # create attention masks as follows: 143 | # 144 | # input --> "wd" 145 | # 146 | # [[ 0., 0., 0., 0., 0.], 147 | # [ 0., 0., 0., 0., 0.], 148 | # [ 0., 0., -1000000., -1000000., -1000000.], 149 | # [ 0., 0., -1000000., -1000000., -1000000.], 150 | # [ 0., 0., -1000000., -1000000., -1000000.]] 151 | # 152 | # this is not the fastest method out there, but hey gets the job done. 153 | m = torch.zeros((len(idx), idx.shape[1], idx.shape[1])) 154 | for _i,t in enumerate(idx): 155 | if P_id in t[1:]: 156 | l_ = (t[1:] == P_id).long().argmax(-1) 157 | l_ = min(l_ + 1, t.shape[0]) 158 | m[_i, l_:, l_:] = -1e6 159 | 160 | if targets is not None: 161 | assert isinstance(targets, (list, tuple)), \ 162 | "target needs to have a LongTensor and a list/tuple of attn for MSE" 163 | 164 | targets, attn_masks = targets 165 | 166 | # convert the target to proper tokens 167 | if not isinstance(targets, torch.Tensor): 168 | targets = tokens(targets) 169 | if len(targets.shape) == 1: 170 | targets = targets.unsqueeze(0) 171 | 172 | # verify the shapes of attention masks 173 | assert len(attn_masks) == len(self.blocks), \ 174 | "Number of attentions should be same as number of blocks. " +\ 175 | f"{len(attn_masks)} != {len(self.blocks)}" 176 | assert isinstance(attn_masks[0], (list, tuple, torch.Tensor)), \ 177 | "Each sequence in the attention should be a tuple/list/tensor" 178 | 179 | for i,(b,a) in enumerate(zip(self.blocks, attn_masks)): 180 | assert isinstance(a[0], torch.Tensor) 181 | assert len(a) == b.n_head, \ 182 | f"Number of attn != number of heads in a block. 
Got: {len(a), b.n_head}" 183 | attn_masks[i] = torch.cat([aa.float().unsqueeze(0) for aa in a], 0) 184 | 185 | # for i,a in enumerate(attn_masks): 186 | # print(":--:", i, a.shape, self.blocks[i].n_head) 187 | 188 | # convert to the final tuple 189 | targets = (targets, attn_masks) 190 | 191 | return idx, m, targets 192 | 193 | def forward(self, idx, targets=None, output_dict = False): 194 | """ 195 | Args: 196 | idx: Can take in following objects: 197 | - string 198 | - list[string] 199 | - torch.LongTensor (1D) 200 | - torch.LongTensor (2D) 201 | targets (optional): Since in rasp you calculate losses for attention matrix 202 | as well, this targets is a list: 203 | - torch.LongTensor(): with the cross entropy for entire input tokens, 204 | just like a normal transformer (GPT/BERT) 205 | - target_attn_masks: this is the target matrices for all the attentions in the network. 206 | ensure that the number of heads and values are common. 207 | 208 | Returns: 209 | [type]: [description] 210 | """ 211 | idx, attn_mask, targets = self.format_inputs_and_tokens(idx, targets) 212 | 213 | b, t = idx.size() 214 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 215 | 216 | # forward the GPT model 217 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 218 | position_embeddings = self.pos_emb[:, :t] # each position maps to a (learnable) vector 219 | x = self.drop(token_embeddings + position_embeddings) 220 | all_attn = [] 221 | for b in self.blocks: 222 | x, att = b([x, attn_mask]) 223 | all_attn.append(att) 224 | x = self.ln_f(x) 225 | logits = self.head(x) 226 | 227 | # if we are given some desired targets also calculate the loss 228 | loss = None 229 | if targets is not None: 230 | targets, attn_masks = targets 231 | 232 | # Cross Entropy loss 233 | ce_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 234 | 235 | # MSE-loss, for each head, manually 236 | mse_loss = 0 237 | for a,t in zip(all_attn, attn_masks): 238 | t = torch.tile(t, [a.shape[0], 1, 1]) # proper batch-ise 239 | # print(a.shape, t.shape) # [b, s, s] 240 | 241 | mse_loss += F.mse_loss(a, t) 242 | loss = ce_loss + mse_loss 243 | 244 | if not output_dict: 245 | return logits, loss 246 | return Response(logits, loss, all_attn) 247 | 248 | # ------ model function ------ # 249 | 250 | def get_model(**kwargs): 251 | config = Config(**kwargs) 252 | return FullTransformer(config) 253 | -------------------------------------------------------------------------------- /rasp/parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Parser 4 | ====== 5 | 6 | This code will parse the input string code and convert to Module 7 | object that can contain transformer layers as well. 8 | 9 | The primary objective is to keep this thing as simple as possible 10 | and ensuring it is easily tokenized for GPTs and parsed by the 11 | code as well. 
12 | """ 13 | 14 | import re 15 | import os 16 | import importlib 17 | from ast import literal_eval 18 | from rasp.daily import * 19 | 20 | import torch 21 | from torch import nn 22 | 23 | mod_template = '''class RaspMod(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | {init} 27 | {forward} 28 | ''' 29 | 30 | def load(code): 31 | # manage the cache folder 32 | r_cache = os.path.join(folder(__file__), ".rasp_cache") 33 | os.makedirs(r_cache, exist_ok = True) 34 | _h = Hashlib.md5(code) 35 | 36 | fpath = os.path.join(r_cache, f"{_h}.py") 37 | print("-->", fpath) 38 | with open(fpath, "w") as f: 39 | f.write(code) 40 | 41 | spec = importlib.util.spec_from_file_location("RaspMod", fpath) 42 | foo = importlib.util.module_from_spec(spec) 43 | print(foo, dir(foo)) 44 | return foo 45 | 46 | 47 | 48 | 49 | def get_rsp( 50 | code, 51 | *variables, 52 | ): 53 | # args = ", ".join(*variables) 54 | forward = "\n".join([f" {x}" for x in code.strip().split("\n")]) 55 | mod = mod_template.format(init = "self.w = nn.Linear(34, 123);", forward = forward) 56 | print(mod) 57 | cls = load(mod)() 58 | return cls 59 | 60 | if __name__ == "__main__": 61 | code_casual_attn = ''' 62 | def forward(self, x, y): 63 | idx = indices(x); 64 | selectors = select(idx, idx + 1, "<"); 65 | return y + ein.repeat(selectors, 'h w -> b h w n', n = 1, b = y.shape[0]); 66 | ''' 67 | casual_attention = get_rsp(code_casual_attn) 68 | 69 | print(casual_attention) 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | import os 4 | from rasp.daily import folder 5 | 6 | from distutils.core import setup 7 | 8 | setup( 9 | name='rasp', 10 | version='n', 11 | description='Restricted Access Sequence Processing (RASP) Language', 12 | long_description=open(os.path.join(folder(__file__), "README.md")).read(), 13 | author='Yash Bonde', 14 | author_email='bonde.yash97@gmail.com', 15 | url='https://github.com/yashbonde/rasp', 16 | packages=['rasp'], 17 | ) 18 | -------------------------------------------------------------------------------- /test/core_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | from rasp.core import * 5 | from rasp.model import get_model 6 | from rasp.daily import folder 7 | import sys 8 | import os 9 | 10 | here = folder(folder(__file__)) 11 | sys.path.append(os.path.join(here, "primitives")) 12 | 13 | from primitives import functional as F 14 | 15 | def set_seed(seed): 16 | if seed is not None: 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | class TestCore(unittest.TestCase): 21 | 22 | def test_identity(self): 23 | # just need to check if it works 24 | string = "hello" 25 | 26 | # first pass just to check if everything works 27 | model = get_model() 28 | logits, loss = model(string) 29 | 30 | # second pass with loss 31 | target = F.identity(string) 32 | logits, loss = model(string, target) 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /test/manual_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from rasp.manual import * 4 | 5 | F = False 6 | T = True 7 | 8 | def set_seed(seed): 9 | if seed is not None: 10 | torch.manual_seed(seed) 11 | 
torch.cuda.manual_seed_all(seed) 12 | 13 | class TestPrimaryOps(unittest.TestCase): 14 | 15 | def test_tokens(self): 16 | # single string encoding -> decoding 17 | self.assertTrue(isinstance(tokens("hello"), torch.Tensor)) 18 | self.assertEqual(tokens("hello").tolist(), [ 7, 4, 11, 11, 14]) 19 | self.assertEqual(tokens(tokens("hey")), "hey") 20 | 21 | # single string encoding -> decoding (bos tag) 22 | self.assertEqual(tokens("hey", True).tolist(), [26, 7, 4, 24]) 23 | self.assertEqual(tokens(tokens("hey", True)), "hey") 24 | self.assertEqual(tokens(tokens("hey", True), True), "$hey") 25 | 26 | # list of strings 27 | self.assertTrue(isinstance(tokens(["hello", "hello"]), torch.Tensor)) 28 | self.assertEqual(tokens(["hello", "hello"]).tolist(), [[ 7, 4, 11, 11, 14], [ 7, 4, 11, 11, 14]]) 29 | self.assertEqual(tokens(tokens(["hey", "hey"])), ["hey", "hey"]) 30 | 31 | # list of strings (bos tag) 32 | self.assertEqual(tokens(["hello", "hello"], True).tolist(), [[26, 7, 4, 11, 11, 14], [26, 7, 4, 11, 11, 14]]) 33 | self.assertEqual(tokens(tokens(["hey", "hey"], True)), ["hey", "hey"]) 34 | self.assertEqual(tokens(tokens(["hey", "hey"], True), True), ["$hey", "$hey"]) 35 | 36 | def test_logical_1d(self): 37 | x = torch.Tensor([F, F, T]) 38 | y = torch.Tensor([F, T, T]) 39 | self.assertEqual(logical(x, "not", y).tolist(), [T, T, F]) 40 | self.assertEqual(logical(x, "or", y).tolist(), [F, T, T]) 41 | self.assertEqual(logical(x, "and", y).tolist(), [F, F, T]) 42 | self.assertEqual(logical(x, "xor", y).tolist(), [F, T, F]) 43 | 44 | def test_logical_2d(self): 45 | x = torch.Tensor([[F, F, T], [T, T, F], [T, F, T]]).bool() 46 | y = torch.Tensor([[T, F, F], [T, T, F], [T, T, T]]).bool() 47 | 48 | self.assertEqual(logical(x, "not", y).tolist(), [[T, T, F], [F, F, T], [F, T, F]]) 49 | self.assertEqual(logical(x, "or", y).tolist(), [[T, F, T], [T, T, F], [T, T, T]]) 50 | self.assertEqual(logical(x, "and", y).tolist(), [[F, F, F], [T, T, F], [T, F, T]]) 51 | self.assertEqual(logical(x, "xor", y).tolist(), [[T, F, T], [F, F, F], [F, T, F]]) 52 | 53 | def test_elementwise(self): 54 | x = torch.Tensor([[0, 0, 2], [2, 2, 0], [2, 0, 2]]) 55 | y = torch.Tensor([[3, 0, 0], [3, 3, 0], [3, 3, 3]]) 56 | 57 | self.assertEqual(elementwise(x, "+", y).tolist(), [[ 3, 0, 2], [ 5, 5, 0], [ 5, 3, 5]]) 58 | self.assertEqual(elementwise(x, "-", y).tolist(), [[-3, 0, 2], [-1, -1, 0], [-1, -3, -1]]) 59 | self.assertEqual(elementwise(x, "*", y).tolist(), [[ 0, 0, 0], [ 6, 6, 0], [ 6, 0, 6]]) 60 | self.assertTrue(np.allclose( 61 | np.array(elementwise(x, "/", y).tolist()).reshape(-1), 62 | np.array([[ 0, 0, 0], [2/3,2/3,0], [2/3, 0, 2/3]]).reshape(-1) 63 | )) 64 | 65 | def test_select(self): 66 | x = torch.Tensor([1, 2, 2]) 67 | y = torch.Tensor([0, 1, 2]) 68 | s = select(x, y, "==") 69 | self.assertEqual(s.tolist(), [[F, F, F], [T, F, F], [F, T, T]]) 70 | 71 | def test_aggregate(self): 72 | x = torch.Tensor([4, 6, 8]) 73 | s = torch.Tensor([[F, F, F], [T, F, F], [F, T, T]]).bool() 74 | self.assertEqual(aggregate(s, x).tolist(), [0, 4, 7]) 75 | 76 | def test_indices(self): 77 | self.assertEqual(indices("hi").tolist(), [0, 1]) 78 | 79 | def test_length(self): 80 | self.assertEqual(length("yoo").tolist(), [3, 3, 3]) 81 | self.assertEqual(length("hi").tolist(), [2, 2]) 82 | 83 | def test_flip(self): 84 | self.assertEqual(flip("hey").tolist(), [[F, F, T], [F, T, F], [T, F, F]]) 85 | 86 | 87 | class PaperBuiltOps(unittest.TestCase): 88 | 89 | def test_reverse(self): 90 | x = "hey" 91 | reverse = tokens(aggregate(flip(x), tokens(x))) 
92 | self.assertEqual(reverse, "yeh") 93 | 94 | def test_select_1(self): 95 | i = indices("hey") 96 | out = select(i, i, "<").tolist() 97 | self.assertEqual(out, [[F,F,F],[T,F,F],[T,T,F]]) 98 | 99 | def test_select_aggregate_1(self): 100 | def a(x): 101 | return select(indices(x), indices(x), "<") 102 | i = indices("hey"); a = a("hey") 103 | aggregate(a, i + 1, "mean") 104 | 105 | def test_select_aggregate_2(self): 106 | def load1(x): 107 | return select(indices(x), 1, "==") 108 | out = logical(load1("hey"), "or", flip("hey")).tolist() 109 | self.assertEqual(out, [[F,T,T],[F,T,F],[T,T,F]]) 110 | 111 | 112 | if __name__ == "__main__": 113 | unittest.main() 114 | -------------------------------------------------------------------------------- /test/model_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from rasp.model import * 3 | from rasp.core import Primitive, get_vocab 4 | 5 | def set_seed(seed): 6 | if seed is not None: 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed_all(seed) 9 | 10 | 11 | class TestTransformer(unittest.TestCase): 12 | 13 | def test_model_string_input(self): 14 | set_seed(4) 15 | model = get_model() 16 | out, loss = model("hello") 17 | self.assertEqual(tokens(out.argmax(-1)), ['etffn']) 18 | out, loss = model(["hello", "world"]) 19 | self.assertEqual(tokens(out.argmax(-1)), ['etffn', 'zkofk']) 20 | out, loss = model(["hello", "wd", "sdfg"]) 21 | self.assertEqual(tokens(out.argmax(-1)), ['etffn', 'zkxxx', 'jkkdx']) 22 | 23 | def test_train(self): 24 | import os, sys 25 | import random 26 | import torch 27 | 28 | import numpy as np 29 | from rasp.daily import folder 30 | sys.path.append(os.path.join(folder(folder(__file__)), "primitives")) 31 | from primitives import functional as F 32 | 33 | # check if the loading is working 34 | # print(F.identity("foo")) 35 | vocab, ivocab = get_vocab() 36 | 37 | def identity_dataset(n = 200, m = 32): 38 | # since our manual primitives take care of the input output 39 | # we can batch the dataset into buckets of similar lengths 40 | set_seed(4) 41 | ds = [] 42 | for _ in range(n): # generate samples 43 | x = "".join([ 44 | ivocab[_i] for _i in np.random.randint(0, len(vocab) - 1, size = (np.random.randint(m) + 1,)) 45 | ]) 46 | ds.append(x) 47 | 48 | # create the dataset 49 | m = max([len(x) for x in ds]) 50 | for i,s in enumerate(ds): 51 | s = s[:m] 52 | if np.random.random() > 0.6: 53 | _i = np.random.randint(len(s)) 54 | _j = _i + np.random.randint(5) 55 | _v = ivocab[np.random.randint(25)] 56 | s = s[_i] + "".join([_v for _ in range(_i, _j, 1)]) + s[_j:] 57 | ds[i] = s[:m] 58 | return ds 59 | 60 | # create dataset 61 | ds = identity_dataset() 62 | 63 | # define the primitive 64 | p = Primitive("identity") 65 | # print("Test 1D:", tokens(p(ds[0])[0].argmax(-1))) 66 | # print("Test (batch):", tokens(p(ds[:2])[0].argmax(-1))) 67 | 68 | # train the network things 69 | p.train(ds, F.identity) 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /test/transformer_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from rasp.model import * 4 | from rasp.manual import tokens 5 | 6 | def set_seed(seed): 7 | if seed is not None: 8 | torch.manual_seed(seed) 9 | torch.cuda.manual_seed_all(seed) 10 | 11 | # assertion values 12 | 13 | # for numpy and tensor 14 | FIRST_PASS_TARGET_TENSOR = [[13, 23, 3, 
10, 14, 13], [18, 5, 25, 3, 4, 10]] 15 | FIRST_PASS_LOSS_TENSOR = 4.4479 16 | SECOND_PASS_TARGET_TENSOR = [[19, 23, 3, 10, 14, 13], [18, 5, 25, 14, 4, 10]] 17 | SECOND_PASS_LOSS_TENSOR = 4.3977 18 | 19 | # for string computation 20 | STRING_PREDICTION = ['fme'] 21 | FIRST_PASS_LOSS_STRING = 4.5970 22 | SECOND_PASS_LOSS_STRING = 4.5050 23 | 24 | # test class 25 | 26 | class TestTransformer(unittest.TestCase): 27 | 28 | # test if the model is even initialized correctly 29 | def test_initialize(self): 30 | config = Config() 31 | model = FullTransformer(config) 32 | self.assertEqual(model.num_parameters, 5706) 33 | 34 | # forward + backward testing with tensors 35 | def test_forward(self): 36 | set_seed(4) 37 | config = Config() 38 | model = FullTransformer(config) 39 | x = torch.randint(0, config.vocab_size, size = (2, 6)) 40 | logits, loss = model(x) 41 | self.assertEqual( logits.argmax(-1).tolist(), FIRST_PASS_TARGET_TENSOR ) 42 | self.assertEqual(loss, None) 43 | 44 | def test_forward_with_loss(self): 45 | set_seed(4) 46 | config = Config() 47 | model = FullTransformer(config) 48 | x = torch.randint(0, config.vocab_size, size = (2, 6)) 49 | target = torch.randint(0, config.vocab_size, size = (2, 6)) 50 | target = (target, [[torch.randn(6, 6)]]) 51 | 52 | logits, loss = model(x, target) 53 | self.assertEqual( logits.argmax(-1).tolist(), FIRST_PASS_TARGET_TENSOR ) 54 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_TENSOR) 55 | self.assertTrue(out) 56 | 57 | def test_backward(self): 58 | set_seed(4) 59 | config = Config() 60 | model = FullTransformer(config) 61 | optim = torch.optim.Adam(model.parameters()) 62 | 63 | x = torch.randint(0, config.vocab_size, size = (2, 6)) 64 | target = torch.randint(0, config.vocab_size, size = (2, 6)) 65 | target = (target, [[torch.randn(6, 6)]]) 66 | 67 | logits, loss = model(x, target) 68 | self.assertEqual( logits.argmax(-1).tolist(), FIRST_PASS_TARGET_TENSOR ) 69 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_TENSOR) 70 | self.assertTrue(out) 71 | 72 | optim.zero_grad() 73 | loss.backward() 74 | optim.step() 75 | 76 | logits, loss = model(x, target) 77 | self.assertEqual( logits.argmax(-1).tolist(), SECOND_PASS_TARGET_TENSOR) 78 | out = np.isclose(loss.item(), SECOND_PASS_LOSS_TENSOR) 79 | self.assertTrue(out) 80 | 81 | # forward + backward testing with strings 82 | def test_forward_str(self): 83 | set_seed(4) 84 | config = Config() 85 | model = FullTransformer(config) 86 | x = "hey" 87 | logits, loss = model(x) 88 | p = [tokens(x) for x in logits.argmax(-1)] 89 | self.assertEqual(p, STRING_PREDICTION) 90 | self.assertEqual(loss, None) 91 | 92 | def test_forward_with_loss_str(self): 93 | set_seed(4) 94 | config = Config() 95 | model = FullTransformer(config) 96 | x = "hey"; target = "hey" 97 | target = (target, [[torch.randn(3, 3)]]) 98 | 99 | logits, loss = model(x, target) 100 | 101 | p = [tokens(x) for x in logits.argmax(-1)] 102 | 103 | self.assertEqual(p, STRING_PREDICTION) 104 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_STRING) 105 | self.assertTrue(out) 106 | 107 | def test_backward_str(self): 108 | set_seed(4) 109 | config = Config() 110 | model = FullTransformer(config) 111 | optim = torch.optim.Adam(model.parameters()) 112 | 113 | # first pass 114 | x = "hey"; target = "hey" 115 | target = (target, [[torch.randn(3, 3)]]) 116 | 117 | logits, loss = model(x, target) 118 | p = [tokens(x) for x in logits.argmax(-1)] 119 | self.assertEqual(p, STRING_PREDICTION) 120 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_STRING) 121 | 
self.assertTrue(out) 122 | 123 | # backprop 124 | optim.zero_grad() 125 | loss.backward() 126 | optim.step() 127 | 128 | # second pass 129 | logits, loss = model(x, target) 130 | p = [tokens(x) for x in logits.argmax(-1)] 131 | self.assertEqual(p, STRING_PREDICTION) 132 | out = np.isclose(loss.item(), SECOND_PASS_LOSS_STRING) 133 | self.assertTrue(out) 134 | 135 | # NOTE: this assumes that the entire model resides on a single card, ie. there is 136 | # no model distributed. 137 | 138 | # forward + backward testing with tensors CUDA 139 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 140 | def test_initialize_cuda(self): 141 | # test if the model is even initialized correctly 142 | config = Config() 143 | model = FullTransformer(config).cuda() 144 | self.assertEqual(model.num_parameters, 5706) 145 | del model 146 | 147 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 148 | def test_forward_cuda(self): 149 | set_seed(4) 150 | config = Config() 151 | model = FullTransformer(config).cuda() 152 | x = torch.randint(0, config.vocab_size, size = (2, 6)).cuda() 153 | logits, loss = model(x) 154 | self.assertEqual( logits.argmax(-1).detach().cpu().tolist(), FIRST_PASS_TARGET_TENSOR ) 155 | self.assertEqual(loss, None) 156 | 157 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 158 | def test_forward_with_loss_cuda(self): 159 | set_seed(4) 160 | config = Config() 161 | model = FullTransformer(config).cuda() 162 | x = torch.randint(0, config.vocab_size, size = (2, 6)).cuda() 163 | target = torch.randint(0, config.vocab_size, size = (2, 6)).cuda() 164 | target = (target, [[torch.randn(6, 6).cuda()]]) 165 | 166 | logits, loss = model(x, target) 167 | self.assertEqual( logits.argmax(-1).detach().cpu().tolist(), FIRST_PASS_TARGET_TENSOR ) 168 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_TENSOR) 169 | self.assertTrue(out) 170 | 171 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 172 | def test_backward_cuda(self): 173 | set_seed(4) 174 | config = Config() 175 | model = FullTransformer(config).cuda() 176 | optim = torch.optim.Adam(model.parameters()) 177 | 178 | x = torch.randint(0, config.vocab_size, size = (2, 6)).cuda() 179 | target = torch.randint(0, config.vocab_size, size = (2, 6)).cuda() 180 | target = (target, [[torch.randn(6, 6).cuda()]]) 181 | 182 | logits, loss = model(x, target) 183 | self.assertEqual( logits.argmax(-1).detach().cpu().tolist(), FIRST_PASS_TARGET_TENSOR ) 184 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_TENSOR) 185 | self.assertTrue(out) 186 | 187 | optim.zero_grad() 188 | loss.backward() 189 | optim.step() 190 | 191 | logits, loss = model(x, target) 192 | self.assertEqual( logits.argmax(-1).detach().cpu().tolist(), SECOND_PASS_TARGET_TENSOR) 193 | out = np.isclose(loss.item(), SECOND_PASS_LOSS_TENSOR) 194 | self.assertTrue(out) 195 | 196 | # forward + backward testing with strings CUDA 197 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 198 | def test_forward_str_cuda(self): 199 | set_seed(4) 200 | config = Config() 201 | model = FullTransformer(config).cuda() 202 | x = "hey" 203 | logits, loss = model(x) 204 | p = [tokens(x) for x in logits.argmax(-1)] 205 | self.assertEqual(p, STRING_PREDICTION) 206 | self.assertEqual(loss, None) 207 | 208 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 209 | def 
test_forward_with_loss_str_cuda(self): 210 | set_seed(4) 211 | config = Config() 212 | model = FullTransformer(config).cuda() 213 | x = "hey"; target = "hey" 214 | target = (target, [[torch.randn(3, 3).cuda()]]) 215 | 216 | logits, loss = model(x, target) 217 | p = [tokens(x) for x in logits.argmax(-1)] 218 | self.assertEqual(p, STRING_PREDICTION) 219 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_STRING) 220 | self.assertTrue(out) 221 | 222 | @unittest.skipUnless(torch.cuda.is_available(), "CUDA not found, skipping these tests") 223 | def test_backward_str_cuda(self): 224 | set_seed(4) 225 | config = Config() 226 | model = FullTransformer(config).cuda() 227 | optim = torch.optim.Adam(model.parameters()) 228 | 229 | # first pass 230 | x = "hey"; target = "hey" 231 | target = (target, [[torch.randn(3, 3).cuda()]]) 232 | 233 | logits, loss = model(x, target) 234 | p = [tokens(x) for x in logits.argmax(-1)] 235 | self.assertEqual(p, STRING_PREDICTION) 236 | out = np.isclose(loss.item(), FIRST_PASS_LOSS_STRING) 237 | self.assertTrue(out) 238 | 239 | # backprop 240 | optim.zero_grad() 241 | loss.backward() 242 | optim.step() 243 | 244 | # second pass 245 | logits, loss = model(x, target) 246 | p = [tokens(x) for x in logits.argmax(-1)] 247 | self.assertEqual(p, STRING_PREDICTION) 248 | out = np.isclose(loss.item(), SECOND_PASS_LOSS_STRING) 249 | self.assertTrue(out) 250 | 251 | if __name__ == "__main__": 252 | unittest.main() 253 | --------------------------------------------------------------------------------