├── tinymodel ├── tokenization │ ├── neo_tok_ids_to_ts.pt │ ├── ts_tok_ids_to_neo.pt │ └── tokenization.py ├── __init__.py ├── sparse_mlp.py ├── lm_modules.py └── lm.py ├── .gitignore ├── pyproject.toml └── README.md /tinymodel/tokenization/neo_tok_ids_to_ts.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/noanabeshima/tinymodel/HEAD/tinymodel/tokenization/neo_tok_ids_to_ts.pt -------------------------------------------------------------------------------- /tinymodel/tokenization/ts_tok_ids_to_neo.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/noanabeshima/tinymodel/HEAD/tinymodel/tokenization/ts_tok_ids_to_neo.pt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | **/__pycache__ 3 | testing.ipynb 4 | *.ipynb 5 | dist/ 6 | *.gitignore 7 | mlp_map_test/ 8 | attn_test/ 9 | simulation.py 10 | test.pt -------------------------------------------------------------------------------- /tinymodel/__init__.py: -------------------------------------------------------------------------------- 1 | from .lm import TinyModel 2 | from .sparse_mlp import SparseMLP 3 | from .tokenization.tokenization import dec, enc, tok_see, tokenizer, raw_toks, toks -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tinymodel" 3 | version = "0.1.2.2-7" 4 | description = "A small TinyStories LM with SAEs and transcoders" 5 | authors = ["Noa Nabeshima "] 6 | readme = "README.md" 7 | packages = [{include = "tinymodel"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.11" 11 | torch = "^2.3.1" 12 | numpy = "^1.26.4" 13 | tqdm = "^4.66.4" 14 | einops = "^0.8.0" 15 | transformers = "^4.41.2" 16 | unidecode = "^1.3.8" 17 | huggingface-hub = "^0.23.4" 18 | datasets = "^2.20.0" 19 | 20 | 21 | [[tool.poetry.source]] 22 | name = "tinymodel" 23 | url = "https://github.com/noanabeshima/tinymodel" 24 | priority = "explicit" 25 | 26 | [build-system] 27 | requires = ["poetry-core"] 28 | build-backend = "poetry.core.masonry.api" 29 | -------------------------------------------------------------------------------- /tinymodel/sparse_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from huggingface_hub import hf_hub_download 4 | 5 | 6 | class SparseMLP(nn.Module): 7 | def __init__(self, d_model, n_features): 8 | super().__init__() 9 | self.d_model = d_model 10 | self.n_features = n_features 11 | 12 | self.encoder = nn.Linear(d_model, n_features) 13 | self.act = nn.ReLU() 14 | self.decoder = nn.Linear(n_features, d_model) 15 | 16 | def get_acts(self, x, indices=None): 17 | """Indices are either a slice, an int, or a list of ints""" 18 | if indices is None: 19 | return self.act(self.encoder(x)) 20 | preacts = x @ self.encoder.weight.T[:, indices] + self.encoder.bias[indices] 21 | return self.act(preacts) 22 | 23 | def __call__(self, x): 24 | x = self.encoder(x) 25 | x = self.act(x) 26 | x = self.decoder(x) 27 | return x 28 | 29 | @classmethod 30 | def from_pretrained(self, state_dict_path: str, repo_id="noanabeshima/tiny_model"): 31 | """Uses huggingface_hub to download an SAE/sparse MLP.""" 32 | state_dict = torch.load( 
33 | hf_hub_download(repo_id=repo_id, filename=state_dict_path + ".pt") 34 | ) 35 | n_features, d_model = state_dict["encoder.weight"].shape 36 | mlp = SparseMLP(d_model=d_model, n_features=n_features) 37 | mlp.load_state_dict(state_dict) 38 | return mlp 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [TinyModel](https://github.com/noanabeshima/tiny_model) 2 | 3 | # There's currently an issue where Python 3.13 doesn't work with TinyModel; please use Python 3.11 or 3.12 instead. Python 3.11.11 definitely works. 4 | 5 | 6 | TinyModel is a 4-layer, 44M-parameter model trained on [TinyStories V2](https://arxiv.org/abs/2305.07759) for mechanistic interpretability. It uses ReLU activations and no layernorms. It comes with trained SAEs and transcoders. 7 | 8 | It can be installed with `pip install tinymodel` for Python 3.11 and higher (but note the Python 3.13 issue above). 9 | 10 | This library is in an alpha state and probably has some bugs. Please let me know if you find any or have any trouble with the library: I can be emailed at my full name @ gmail.com or messaged on Twitter. You can also open GitHub issues. 11 | 12 | 13 | ``` 14 | from tinymodel import TinyModel, tokenizer 15 | 16 | lm = TinyModel() 17 | 18 | # for inference 19 | tok_ids, padding_mask = tokenizer(['Once upon a time', 'In the forest']) 20 | logprobs = lm(tok_ids) 21 | 22 | # Get SAE/transcoder acts 23 | # See 'SAEs/Transcoders' section for more information. 24 | feature_acts = lm['M1N123'](tok_ids) 25 | all_feat_acts = lm['M2'](tok_ids) 26 | 27 | # Generation 28 | lm.generate('Once upon a time, Ada was happily walking through a magical forest with') 29 | 30 | # To decode tok_ids you can use 31 | tokenizer.decode(tok_ids) 32 | ``` 33 | 34 | The model was trained for 3 epochs on a [preprocessed version of TinyStoriesV2](https://huggingface.co/datasets/noanabeshima/TinyStoriesV2). A pre-tokenized dataset is available [here](https://huggingface.co/datasets/noanabeshima/TinyModelTokIds); I recommend using it for getting SAE/transcoder activations. 35 | 36 | 37 | 38 | # SAEs/transcoders 39 | Some sparse SAEs/transcoders are provided along with the model. 40 | 41 | For example, `acts = lm['M2N100'](tok_ids)`. 42 | 43 | To get sparse acts, choose which part of the transformer block you want to look at (currently [sparse MLP](https://www.lesswrong.com/posts/MXabwqMwo3rkGqEW8/sparse-mlp-distillation)/[transcoder](https://www.alignmentforum.org/posts/YmkjnWtZGLbHRbzrP/transcoders-enable-fine-grained-interpretable-circuit) and SAEs on attention out are available, under the tags `'M'` and `'A'` respectively). Residual-stream and MLP-out SAEs exist; they just haven't been added yet. Bug me on e.g. Twitter if you want this to happen faster. 44 | 45 | Then, add the layer. A sparse MLP at layer 2 would be `'M2'`. 46 | Finally, optionally add a particular neuron, for example `'M0N10000'`. You can also restrict a call to specific feature indices; see the examples at the end of this README. 47 | 48 | # Tokenization 49 | Tokenization is done as follows: 50 | - The top-10K most frequent tokens under the GPT-NeoX tokenizer are selected and sorted by frequency. 51 | - To tokenize a document, first tokenize with the GPT-NeoX tokenizer. Then replace tokens not in the top 10K tokens with a special \[UNK\] token id. All token ids are then mapped to be between 1 and 10K, roughly sorted from most frequent to least. 52 | - Finally, prepend the document with a [BEGIN] token id.
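
# Examples

Two short usage sketches tying the sections above to the helpers this repo exports (`enc`, `dec`, `tok_see`, `tokenizer`, and the `lm[...]` syntax). They are illustrative only: the input strings, feature indices, and variable names are arbitrary placeholders, and the exact token ids you get depend on the vocab files shipped with the package.

```
from tinymodel import enc, dec, tok_see

# Tokenize a document: ids are remapped into the 10K-token vocab,
# out-of-vocab tokens become [UNK], and add_begin=True prepends [BEGIN].
tok_ids = enc('Once upon a time, there was a robot.', add_begin=True)  # shape (1, seq_len)

# Inspect the individual tokens (spaces shown as '⋅', newlines as '↵')
print(tok_see(tok_ids[0]))

# Decode back to text
print(dec(tok_ids[0]))
```

The bracket syntax from the SAEs/transcoders section also accepts an `indices` argument (an int, a list of ints, or a slice), mirroring the docstrings in `lm.py`, so you can compute only the features you care about:

```
from tinymodel import TinyModel, tokenizer

lm = TinyModel()
tok_ids, padding_mask = tokenizer(['Once upon a time'])

# Acts for a few chosen features of the layer-2 sparse MLP/transcoder
some_m2_acts = lm['M2'](tok_ids, indices=[3, 50, 100])

# Acts for the first 100 features of the layer-0 attention-out SAE
some_a0_acts = lm['A0'](tok_ids, indices=slice(0, 100))
```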
53 | 54 | -------------------------------------------------------------------------------- /tinymodel/lm_modules.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class HookPoint(nn.Module): 9 | def __init__(self, name=None): 10 | super().__init__() 11 | self.name = name 12 | 13 | def forward(self, x): 14 | return x 15 | 16 | def __repr__(self): 17 | if self.name is not None: 18 | return f"HookPoint('{self.name}')" 19 | else: 20 | return "HookPoint()" 21 | 22 | 23 | class Attention(nn.Module): 24 | def __init__(self, n_heads, d_model, d_head, max_seq_len): 25 | super().__init__() 26 | self.Q = nn.Linear(d_model, d_head * n_heads, bias=False) 27 | self.K = nn.Linear(d_model, d_head * n_heads, bias=False) 28 | self.V = nn.Linear(d_model, d_head * n_heads, bias=False) 29 | self.O = nn.Linear(d_head * n_heads, d_model) 30 | 31 | nn.init.normal_(self.Q.weight, std=np.sqrt(2 / (d_model + d_head))) 32 | nn.init.normal_(self.K.weight, std=np.sqrt(2 / (d_model + d_head))) 33 | nn.init.zeros_(self.O.bias) 34 | 35 | self.n_heads = n_heads 36 | self.d_model = d_model 37 | self.d_head = d_head 38 | self.max_seq_len = max_seq_len 39 | 40 | self.attn_inp = HookPoint() 41 | self.qs = HookPoint() 42 | self.ks = HookPoint() 43 | self.vs = HookPoint() 44 | self.head_writeouts = HookPoint() 45 | self.catted_head_writeouts = HookPoint() 46 | self.attn_out = HookPoint() 47 | 48 | @property 49 | def Wq(self): 50 | return einops.rearrange( 51 | self.Q.weight.detach(), "d (h k) -> h d k", h=self.n_heads 52 | ) 53 | 54 | @property 55 | def Wk(self): 56 | return einops.rearrange( 57 | self.K.weight.detach(), "d (h k) -> h d k", h=self.n_heads 58 | ) 59 | 60 | @property 61 | def Wv(self): 62 | return einops.rearrange( 63 | self.V.weight.detach(), "d (h k) -> h d k", h=self.n_heads 64 | ) 65 | 66 | @property 67 | def Wo(self): 68 | return self.O.weight.detach() 69 | 70 | def forward(self, x): 71 | x = self.attn_inp(x) # hookpoint 72 | 73 | q, k, v = self.Q(x), self.K(x), self.V(x) 74 | 75 | qs = einops.rearrange(q, "b s (h d) -> b h s d", h=self.n_heads) 76 | qs = self.qs(qs) # hookpoint 77 | 78 | ks = einops.rearrange(k, "b s (h d) -> b h s d", h=self.n_heads) 79 | ks = self.ks(ks) # hookpoint 80 | 81 | vs = einops.rearrange(v, "b s (h d) -> b h s d", h=self.n_heads) 82 | vs = self.vs(vs) # hookpoint 83 | 84 | # force torch to use flash attention 2 85 | if x.dtype == torch.float16 or x.dtype == torch.bfloat16: 86 | with torch.backends.cuda.sdp_kernel( 87 | enable_flash=True, enable_math=False, enable_mem_efficient=False 88 | ): 89 | head_writeouts = F.scaled_dot_product_attention( 90 | qs, ks, vs, is_causal=True 91 | ) 92 | else: 93 | head_writeouts = F.scaled_dot_product_attention(qs, ks, vs, is_causal=True) 94 | head_writeouts = self.head_writeouts(head_writeouts) # hookpoint 95 | 96 | catted_head_writeouts = einops.rearrange(head_writeouts, "b h q d -> b q (h d)") 97 | catted_head_writeouts = self.catted_head_writeouts( 98 | catted_head_writeouts 99 | ) # hookpoint 100 | 101 | attn_out = self.O(catted_head_writeouts) 102 | attn_out = self.attn_out(attn_out) # hookpoint 103 | 104 | return attn_out 105 | 106 | 107 | class MLP(nn.Module): 108 | def __init__(self, d_model, d_mlp): 109 | super().__init__() 110 | self.read_in = nn.Linear(d_model, d_mlp) 111 | self.act = nn.ReLU() 112 | self.write_out = nn.Linear(d_mlp, d_model) 113 | 114 | self.d_model = 
d_model 115 | self.d_mlp = d_mlp 116 | 117 | self.mlp_inp = HookPoint() 118 | self.mlp_out = HookPoint() 119 | 120 | def forward(self, x): 121 | x = self.mlp_inp(x) # hookpoint 122 | 123 | preacts = self.read_in(x) 124 | acts = self.act(preacts) 125 | mlp_out = self.write_out(acts) 126 | 127 | mlp_out = self.mlp_out(mlp_out) # hookpoint 128 | 129 | return mlp_out 130 | 131 | 132 | class TransformerBlock(nn.Module): 133 | def __init__(self, d_model, n_heads, max_seq_len): 134 | super().__init__() 135 | assert d_model % n_heads == 0, "n_heads must divide d_model" 136 | d_head = d_model // n_heads 137 | 138 | self.attn = Attention( 139 | n_heads=n_heads, d_model=d_model, d_head=d_head, max_seq_len=max_seq_len 140 | ) 141 | self.mlp = MLP(d_model=d_model, d_mlp=4 * d_model) 142 | 143 | self.d_model = d_model 144 | self.d_head = d_head 145 | self.n_heads = n_heads 146 | self.max_seq_len = max_seq_len 147 | 148 | self.res_attn = HookPoint() 149 | self.res_mlp = HookPoint() 150 | self.res_final = HookPoint() 151 | 152 | def forward(self, x): 153 | x = self.res_attn(x) # hookpoint 154 | 155 | attn_x = self.attn(x) 156 | x = attn_x + x 157 | 158 | x = self.res_mlp(x) # hookpoint 159 | 160 | mlp_x = self.mlp(x) 161 | x = mlp_x + x 162 | 163 | x = self.res_final(x) # hookpoint 164 | 165 | return x 166 | -------------------------------------------------------------------------------- /tinymodel/tokenization/tokenization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import torch 7 | from transformers import AutoTokenizer 8 | from unidecode import unidecode 9 | 10 | current_file = os.path.abspath(__file__) 11 | current_dir = os.path.dirname(current_file) 12 | neo_tokenizer = AutoTokenizer.from_pretrained( 13 | "roneneldan/TinyStories", 14 | padding=True, 15 | truncation=True, 16 | max_length=2048, 17 | ) 18 | neo_tokenizer.model_max_length = 2048 19 | neo_tokenizer.add_special_tokens( 20 | { 21 | "bos_token": "[BEGIN]", 22 | "eos_token": "[END]", 23 | "pad_token": "[PAD]", 24 | "unk_token": "[UNK]", 25 | }, 26 | ) 27 | neo_tok_ids_to_ts = torch.load(f"{current_dir}/neo_tok_ids_to_ts.pt", weights_only=True) 28 | ts_tok_ids_to_neo = torch.load(f"{current_dir}/ts_tok_ids_to_neo.pt", weights_only=True) 29 | 30 | 31 | def clean_text(text): 32 | # Convert from unicode to ascii to make tokenization better; don't split up quotation marks into multiple tokens e.g. 33 | text = unidecode(text) 34 | 35 | # tabs to spaces 36 | text = re.sub(r"\t", " ", text) 37 | 38 | # remove trailing spaces 39 | text = re.sub(r"[\s]+\n", "\n", text) 40 | 41 | # Replace multiple newlines with single newline 42 | text = re.sub(r"\n\n+", "\n", text) 43 | 44 | # Replace multiple spaces with single space 45 | text = re.sub(r" +", " ", text) 46 | 47 | return text 48 | 49 | 50 | def enc(stories, padding=True, return_attn_mask=False, max_length=256, add_begin=False): 51 | if add_begin is True and isinstance(max_length, int) and max_length > 0: 52 | max_length = max_length - 1 53 | if isinstance(stories, str): 54 | stories = [stories] 55 | stories = [ 56 | story 57 | for story in stories 58 | if ("â" not in story) 59 | and ("€" not in story) 60 | and ("»" not in story) 61 | and ("«" not in story) 62 | ] 63 | stories = [clean_text(story) for story in stories] 64 | 65 | # Start with the TinyStories tokenizer, the GPTNeo tokenizer. 
66 | out = neo_tokenizer( 67 | stories, 68 | max_length=max_length, 69 | return_tensors="pt", 70 | padding=padding, 71 | truncation=True, 72 | ) 73 | input_ids, attn_mask = out["input_ids"], out["attention_mask"] 74 | 75 | # Replace tokens not in the top-10k most frequent tokens in the train dataset with the [UNK] special token. 76 | # Something like 1~6% of documents have at least one [UNK] token 77 | # All non-[UNK] tokens appear at least 100 times in the train dataset 78 | # I think that in the original TinyStories dataset these were just dropped instead of being replaced with an [UNK] token. 79 | unk_mask = ~torch.isin(input_ids, ts_tok_ids_to_neo) * attn_mask.bool() 80 | input_ids[unk_mask] = neo_tokenizer.unk_token_id 81 | 82 | # Replace the first [PAD] token with an [END] token 83 | eos_idx = attn_mask.argmin(dim=1) 84 | for i, eos_i in enumerate(eos_idx): 85 | if eos_i != 0: 86 | input_ids[i, eos_i] = neo_tokenizer.eos_token_id 87 | 88 | # Add a [BEGIN] token to the beginning of every story. The model is trained with this. 89 | if add_begin is True: 90 | input_ids = torch.cat( 91 | ( 92 | torch.full((input_ids.shape[0], 1), neo_tokenizer.bos_token_id), 93 | input_ids, 94 | ), 95 | dim=1, 96 | ) 97 | attn_mask = torch.cat( 98 | (torch.ones(attn_mask.shape[0], 1, dtype=torch.int), attn_mask), dim=1 99 | ) 100 | 101 | # Convert from GPTNeo tok ids to custom tinystories tok ids, the most common Neo tok ids and some special tokens 102 | input_ids = neo_tok_ids_to_ts[input_ids] 103 | 104 | if return_attn_mask: 105 | return input_ids, attn_mask 106 | else: 107 | return input_ids 108 | 109 | 110 | def dec(ts_tok_ids): 111 | if ( 112 | type(ts_tok_ids) in {torch.Tensor, np.ndarray} 113 | and np.prod(ts_tok_ids.shape) == 1 114 | ): 115 | ts_tok_ids = int(ts_tok_ids.item()) 116 | if isinstance(ts_tok_ids, int): 117 | ts_tok_ids = [ts_tok_ids] 118 | if isinstance(ts_tok_ids, list): 119 | ts_tok_ids = torch.tensor(ts_tok_ids) 120 | ts_tok_ids = ts_tok_ids.cpu() 121 | if not isinstance(ts_tok_ids, torch.Tensor): 122 | ts_tok_ids = torch.tensor(ts_tok_ids, dtype=torch.int32) 123 | neo_tok_ids = ts_tok_ids_to_neo[ts_tok_ids] 124 | if len(neo_tok_ids.shape) == 1: 125 | return neo_tokenizer.decode(neo_tok_ids) 126 | else: 127 | return neo_tokenizer.batch_decode(neo_tok_ids) 128 | 129 | 130 | def tok_see( 131 | tok_ids: Union[str, torch.Tensor, list[int]], 132 | printout=False, 133 | symbolic_spaces=True, 134 | symbolic_newlines=True, 135 | ): 136 | if isinstance(tok_ids, str): 137 | tok_ids = enc(tok_ids, add_begin=False, max_length=2048) 138 | if isinstance(tok_ids, np.ndarray): 139 | tok_ids = torch.tensor(tok_ids) 140 | if isinstance(tok_ids, torch.Tensor): 141 | tok_ids = tok_ids.squeeze() 142 | if len(tok_ids.shape) == 0: 143 | tok_ids = tok_ids[None] 144 | assert len(tok_ids.shape) == 1, tok_ids.shape 145 | # toks = [dec(tok_id).replace(' ', '⋅').replace('\n', '↵') for tok_id in tok_ids] 146 | toks = [dec(tok_id) for tok_id in tok_ids] 147 | if symbolic_newlines: 148 | toks = [tok.replace("\n", "↵") for tok in toks] 149 | if symbolic_spaces: 150 | toks = [tok.replace(" ", "⋅") for tok in toks] 151 | if printout: 152 | print(toks) 153 | return toks 154 | 155 | 156 | class Tokenizer: 157 | def __init__(self): 158 | self.vocab_size = 10_000 159 | 160 | def encode(self, s: str): 161 | assert isinstance(s, str) 162 | 163 | return enc(s, add_begin=False)[0].tolist() 164 | 165 | def decode(self, tok_ids: Union[list, torch.Tensor, int]): 166 | if isinstance(tok_ids, int): 167 | tok_ids = [tok_ids] 168 
| assert isinstance(tok_ids, list) or isinstance(tok_ids, torch.Tensor) 169 | 170 | return dec(tok_ids) 171 | 172 | def __call__( 173 | self, 174 | docs: List[str], 175 | padding=True, 176 | return_attn_mask=True, 177 | max_length=256, 178 | add_begin=True, 179 | ): 180 | return enc( 181 | docs, 182 | padding=padding, 183 | return_attn_mask=return_attn_mask, 184 | max_length=max_length, 185 | add_begin=add_begin, 186 | ) 187 | 188 | 189 | tokenizer = Tokenizer() 190 | 191 | 192 | raw_toks = np.array([dec(tok_id) for tok_id in range(10_000)]) 193 | toks = np.array([tok.replace("\n", "↵").replace(" ", "⋅") for tok in raw_toks]) -------------------------------------------------------------------------------- /tinymodel/lm.py: -------------------------------------------------------------------------------- 1 | import re 2 | from textwrap import dedent 3 | 4 | import torch 5 | import torch.distributions as dists 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from huggingface_hub import hf_hub_download 9 | 10 | from .lm_modules import TransformerBlock 11 | from .sparse_mlp import SparseMLP 12 | from .tokenization.tokenization import dec, enc 13 | 14 | DEFAULT_SPARSE_MLPS = { 15 | # "M0": "mlp_map_test/M0_S-6_R2_P2", 16 | # "M1": "mlp_map_test/M1_S-4_R8_P2", 17 | 18 | "M0": "mlp_map_test/M0_S-2_R1_P0", 19 | "M1": "mlp_map_test/M1_S-2_R1_P0", 20 | "M2": "mlp_map_test/M2_S-2_R1_P0", 21 | "M3": "mlp_map_test/M3_S-1_B0_P0", 22 | 23 | "A0": "attn/A0_S-2_R1_P0", 24 | "A1": "attn/A1_S-1_R1_P0", 25 | "A2": "attn/A2_S-2_R1_P0", 26 | "A3": "attn/A3_S-1_R1_P0", 27 | } 28 | 29 | 30 | def parse_mlp_tag(mlp_tag): 31 | defaults_tag_pat = re.compile( 32 | r"(?P(M|Rm|Ra|A|Mo))(?P\d+)(\D(?P\d+))?" 33 | ) 34 | defaults_match = defaults_tag_pat.fullmatch(mlp_tag) 35 | file_tag_pat = re.compile(r'(?P(?P(Mo|M|A|Rm|Ra))(?P\d+)_S[-\d]+.{0,6}_P\d+)([^\d](?P\d+))?') 36 | full_file_match = file_tag_pat.fullmatch(mlp_tag) 37 | 38 | if defaults_match: 39 | match_groups = defaults_match.groupdict() 40 | mlp_type, layer, feature_idx = ( 41 | match_groups["mlp_type"], 42 | int(match_groups["layer"]), 43 | match_groups["feature_idx"] 44 | ) 45 | 46 | 47 | feature_idx = None if feature_idx is None else int(feature_idx) 48 | 49 | assert mlp_type+str(layer) in DEFAULT_SPARSE_MLPS 50 | return DEFAULT_SPARSE_MLPS[mlp_type+str(layer)], mlp_type, layer, feature_idx 51 | elif full_file_match: 52 | # try interpreting the mlp_tag as a filename 53 | 54 | mlp_type_to_file = { 55 | 'Mo': 'mlp_out', 56 | 'A': 'attn_test', 57 | 'M': 'mlp_map_test', # transcoder 58 | 'T': 'mlp_map_test', # transcoder 59 | # 'Ra': 'res_pre_attn', 60 | # 'Rm': 'res_pre_mlp' 61 | } 62 | 63 | match_groups = full_file_match.groupdict() 64 | 65 | full_name, mlp_type, layer, feature_idx = match_groups['full_name'], match_groups['mlp_type'], int(match_groups['layer']), match_groups['feature_idx'] 66 | file = mlp_type_to_file[mlp_type] + '/' + full_name 67 | 68 | feature_idx = None if feature_idx is None else int(feature_idx) 69 | 70 | return file, mlp_type, layer, feature_idx 71 | else: 72 | return False 73 | 74 | 75 | 76 | 77 | 78 | class TinyModel(nn.Module): 79 | def __init__( 80 | self, 81 | d_model=768, 82 | n_layers=4, 83 | n_heads=16, 84 | max_seq_len=256, 85 | vocab_size=10_000, 86 | from_pretrained="tiny_model", 87 | ): 88 | super().__init__() 89 | self.embed = nn.Embedding(vocab_size, d_model) 90 | self.embed.weight = nn.Parameter( 91 | 1e-4 * torch.randn(self.embed.weight.shape, requires_grad=True) 92 | ) 93 | self.pos_embed = 
nn.Parameter( 94 | 1e-4 * torch.randn(1, max_seq_len, d_model, requires_grad=True) 95 | ) 96 | 97 | self.torso = nn.Sequential( 98 | *[ 99 | TransformerBlock( 100 | d_model=d_model, n_heads=n_heads, max_seq_len=max_seq_len 101 | ) 102 | for _ in range(n_layers) 103 | ] 104 | ) 105 | self.lm_head = nn.Linear(d_model, vocab_size) 106 | 107 | self.d_model = d_model 108 | self.n_layers = n_layers 109 | self.n_heads = n_heads 110 | self.vocab_size = vocab_size 111 | self.max_seq_len = max_seq_len 112 | 113 | if isinstance(from_pretrained, str): 114 | self.load_state_dict(get_state_dict(from_pretrained)) 115 | else: 116 | assert ( 117 | from_pretrained is False 118 | ), "from_pretrained kwarg must be False or a string specifying model" 119 | 120 | # Dict from mlp_tag to sparse mlp 121 | self.sparse_mlps = nn.ModuleDict() 122 | 123 | @property 124 | def dtype(self): 125 | return self.embed.weight.dtype 126 | 127 | @property 128 | def device(self): 129 | return self.embed.weight.device 130 | 131 | def forward(self, tok_ids, return_idx=None): 132 | T = tok_ids.shape[-1] 133 | x = self.embed(tok_ids) + self.pos_embed[:, :T] 134 | if return_idx is not None: 135 | assert isinstance(return_idx, int) 136 | assert 0 <= return_idx and return_idx <= self.n_layers 137 | for layer_idx, layer in enumerate(self.torso): 138 | if layer_idx == return_idx: 139 | return x 140 | x = layer(x) 141 | else: 142 | x = self.torso(x) 143 | logits = self.lm_head(x) 144 | return F.log_softmax(logits, dim=-1) 145 | 146 | def generate(self, prompt, n_toks=50, temperature=0.8, break_on_end=True): 147 | assert temperature >= 0.0 148 | toks = enc(prompt, add_begin=True).to(self.lm_head.weight.device) 149 | 150 | for _ in range(n_toks): 151 | with torch.no_grad(): 152 | logprobs = self.forward(toks)[0, -1] 153 | if temperature == 0: 154 | next_tok = logprobs.argmax().item() 155 | else: 156 | next_tok = dists.Categorical( 157 | logits=logprobs * (1 / temperature) 158 | ).sample() 159 | toks = torch.cat((toks, torch.tensor([[next_tok]]).to(toks.device)), dim=-1) 160 | if break_on_end and next_tok == enc("[END]").item(): 161 | break 162 | if toks.shape[1] >= self.max_seq_len: 163 | break 164 | return dec(toks[:, 1:])[0] 165 | 166 | def sparse_mlp(self, mlp_tag=None, mlp=None): 167 | ''' 168 | Returns `get_sparse_mlp_acts`, which takes in tok_ids and returns sparse mlp activations. It optionally allows `indices`. 169 | ''' 170 | assert not (mlp_tag is None and mlp is None) 171 | 172 | parse_output = parse_mlp_tag(mlp_tag) 173 | 174 | 175 | if parse_output is False: 176 | assert False, dedent( 177 | 'Failed to parse mlp.' 178 | ) 179 | # assert False, dedent( 180 | # """ 181 | # [STUB] 182 | # That\'s not a valid MLP tag. Here are some examples of MLP tags: 183 | # M0, A2, Rm0, Ra1, Mo3 184 | # They start with a string in [M, A, Rm, Ra, Mo] 185 | # representing mlp map, attn out SAE, residual pre-mlp SAE, residual pre-attn SAE, and MLP out SAE respectively. 186 | # and they end with a number representing the layer. 187 | 188 | # You can also specify individual feature_idxs, e.g. lm['A2.100'](tok_ids) to get the activations of neuron 100. 
189 | # """ 190 | # ) 191 | else: 192 | file, mlp_type, layer, feature_idx = parse_output 193 | mlp_tag = mlp_type + str(layer) 194 | if mlp is None: 195 | sparse_mlp = SparseMLP.from_pretrained(file).to(device=self.device, dtype=self.dtype) 196 | else: 197 | sparse_mlp = mlp.to(device=self.device, dtype=self.dtype) 198 | # else: 199 | # assert False, dedent( 200 | # """ 201 | # mlp_tag {mlp_tag} not found in tiny_model.sparse_mlps or DEFAULT_SPARSE_MLPS 202 | 203 | # [STUB]: unimplemented 204 | # To add a sparse_mlp, do e.g. 205 | # tiny_model.set_saes({ 206 | # \'M2\': SparseMLP.from_pretrained(\'mlp_map/M0_S-1_B0_P0\') 207 | # }) 208 | 209 | # Available keys (of form {mlp_type}{layer}) are: 210 | # M0..3 (for MLPs) 211 | # A0..3 (for Attn out) 212 | # Rm0..3 (for SAE on the residual before MLP) 213 | # Ra0..3 (for SAE on the residual stream before attn) 214 | # Mo0..3 (for SAE on MLP out) 215 | 216 | # See https://huggingface.co/noanabeshima/tiny_model/tree/main for available sparse MLPs. 217 | # """ 218 | # ) 219 | 220 | def get_sparse_mlp_acts(tok_ids, indices=feature_idx): 221 | x = self.forward(tok_ids, return_idx=layer) 222 | if mlp_type == "Ra": 223 | return sparse_mlp.get_acts(x, indices=indices) 224 | attn_out = self.torso[layer].attn(x) 225 | if mlp_type == "A": 226 | return sparse_mlp.get_acts(attn_out, indices=indices) 227 | x = attn_out + x 228 | if mlp_type in {"M", "Rm"}: 229 | return sparse_mlp.get_acts(x, indices=indices) 230 | else: 231 | assert mlp_type == "Mo", "mlp_type must be one of Ra, A, M, Rm, Mo" 232 | mlp_out = self.torso[layer].mlp(x) 233 | return sparse_mlp.get_acts(mlp_out, indices=indices) 234 | 235 | return get_sparse_mlp_acts 236 | 237 | def __getitem__(self, mlp_tag): 238 | """ 239 | To be used like: 240 | sparse_acts = lm['A0'](tok_ids, indices=[1,5,100]) 241 | sparse_acts = lm['M1'](tok_ids, indices=slice(0,100)) 242 | sparse_acts = lm['M3'](tok_ids, indices=0) 243 | 244 | or for single neurons 245 | 246 | sparse_acts = lm['M2N100'](tok_ids) 247 | sparse_acts = lm['M2.100'](tok_ids) 248 | """ 249 | return self.sparse_mlp(mlp_tag) 250 | 251 | 252 | def get_state_dict(model_fname="tiny_model"): 253 | assert model_fname in [ 254 | "tiny_model", 255 | "tiny_model_2L_1E", 256 | "tiny_model_2L_3E" 257 | ], "There are 3 models available: `tiny_model`, `tiny_model_2L_1E`, and `tiny_model_2L_3E`." 258 | state_dict = torch.load( 259 | hf_hub_download(repo_id="noanabeshima/tiny_model", filename=f"{model_fname}.pt"), 260 | map_location=torch.device('cpu'), 261 | weights_only=True 262 | ) 263 | return state_dict 264 | --------------------------------------------------------------------------------