├── .gitignore ├── data ├── monk.txt └── solitude.txt ├── LICENSE ├── requirements.txt ├── README.md └── scripts ├── segment.py ├── utils.py ├── train_probe.py ├── test_probe.py ├── readout.py └── training.py /.gitignore: -------------------------------------------------------------------------------- 1 | env 2 | scripts/__pycache__ 3 | scripts/wandb 4 | checkpoints 5 | logs 6 | -------------------------------------------------------------------------------- /data/monk.txt: -------------------------------------------------------------------------------- 1 | Monk's compositions and improvisations feature dissonances and angular melodic twists, often using flat ninths, flat fifths, unexpected chromatic notes together, low bass notes and stride, and fast whole tone runs, combining a highly percussive attack with abrupt, dramatic use of switched key releases, silences, and hesitations. 2 | -------------------------------------------------------------------------------- /data/solitude.txt: -------------------------------------------------------------------------------- 1 | Many years later, as he faced the firing squad, Colonel Aureliano Buendía was to remember that distant afternoon when his father took him to discover ice. At that time Macondo was a village of twenty adobe houses, built on the bank of a river of clear water that ran along a bed of polished stones, which were white and enormous, like prehistoric eggs. The world was so recent that many things lacked names, and in order to indicate them it was necessary to point. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sheridan Feucht 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.31.0 2 | annotated-types==0.7.0 3 | bidict==0.23.1 4 | blis==0.7.11 5 | catalogue==2.0.10 6 | certifi==2024.6.2 7 | charset-normalizer==3.3.2 8 | click==8.1.7 9 | cloudpathlib==0.18.1 10 | confection==0.1.5 11 | cymem==2.0.8 12 | diffusers==0.29.1 13 | docker-pycreds==0.4.0 14 | einops==0.8.0 15 | filelock==3.15.4 16 | fsspec==2024.6.0 17 | gitdb==4.0.11 18 | GitPython==3.1.43 19 | h11==0.14.0 20 | huggingface-hub==0.23.4 21 | idna==3.7 22 | importlib-metadata==7.2.1 23 | jinja2==3.1.4 24 | langcodes==3.4.0 25 | language-data==1.2.0 26 | marisa-trie==1.2.0 27 | markdown-it-py==3.0.0 28 | MarkupSafe==2.1.5 29 | mdurl==0.1.2 30 | mpmath==1.3.0 31 | murmurhash==1.0.10 32 | networkx==3.1 33 | nnsight==0.2.19 34 | numpy==1.24.4 35 | nvidia-cublas-cu12==12.1.3.1 36 | nvidia-cuda-cupti-cu12==12.1.105 37 | nvidia-cuda-nvrtc-cu12==12.1.105 38 | nvidia-cuda-runtime-cu12==12.1.105 39 | nvidia-cudnn-cu12==8.9.2.26 40 | nvidia-cufft-cu12==11.0.2.54 41 | nvidia-curand-cu12==10.3.2.106 42 | nvidia-cusolver-cu12==11.4.5.107 43 | nvidia-cusparse-cu12==12.1.0.106 44 | nvidia-nccl-cu12==2.20.5 45 | nvidia-nvjitlink-cu12==12.5.40 46 | nvidia-nvtx-cu12==12.1.105 47 | packaging==24.1 48 | pandas==2.0.3 49 | pillow==10.3.0 50 | platformdirs==4.2.2 51 | preshed==3.0.9 52 | protobuf==5.27.1 53 | psutil==6.0.0 54 | pydantic==2.7.4 55 | pydantic-core==2.18.4 56 | pygments==2.18.0 57 | python-dateutil==2.9.0.post0 58 | python-engineio==4.9.1 59 | python-socketio==5.11.3 60 | pytorch-warmup==0.1.1 61 | pytz==2024.1 62 | PyYAML==6.0.1 63 | regex==2024.5.15 64 | requests==2.32.3 65 | rich==13.7.1 66 | safetensors==0.4.3 67 | sentencepiece==0.2.0 68 | sentry-sdk==2.6.0 69 | setproctitle==1.3.3 70 | shellingham==1.5.4 71 | simple-websocket==1.0.0 72 | six==1.16.0 73 | smart-open==7.0.4 74 | smmap==5.0.1 75 | spacy==3.7.5 76 | spacy-legacy==3.0.12 77 | spacy-loggers==1.0.5 78 | srsly==2.4.8 79 | sympy==1.12.1 80 | thinc==8.2.5 81 | tokenizers==0.19.1 82 | torch==2.3.1 83 | torchvision==0.18.1 84 | tqdm==4.66.4 85 | transformers==4.41.2 86 | triton==2.3.1 87 | typer==0.12.3 88 | typing-extensions==4.12.2 89 | tzdata==2024.1 90 | urllib3==2.2.2 91 | wandb==0.17.3 92 | wasabi==1.1.3 93 | weasel==0.4.1 94 | websocket-client==1.8.0 95 | wrapt==1.16.0 96 | wsproto==1.2.0 97 | zipp==3.19.2 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Token Erasure as a Footprint of Implicit Vocabulary Items in LLMs 2 | How do LLMs process multi-token words, common phrases, and named entities? We discover a pattern of token erasure that we hypothesize to be a 'footprint' of how LLMs process unnatural tokenization. 3 | 4 | Read more about our paper here:
5 | 🌐 https://footprints.baulab.info
6 | 📄 https://arxiv.org/abs/2406.20086 7 | 8 | 9 | 10 | ## Setup 11 | To run our code, clone this repository and create a new virtual environment using Python 3.8.10: 12 | ``` 13 | python3 -m venv env 14 | source env/bin/activate 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Segmenting a Document 19 | An implementation of Algorithm 1 in our paper is provided in `segment.py`. This script can be run like so: 20 | ``` 21 | python segment.py --document my_doc.txt --model meta-llama/Llama-2-7b-hf 22 | ``` 23 | allowing you to segment any paragraph of text into high-scoring token sequences. 24 | ``` 25 | segments from highest to lowest score: 26 | 'dramatic' 0.5815845847818102 27 | 'twists' 0.5553912909024803 28 | 'low bass' 0.41476866118921824 29 | 'cuss' 0.3979072428604316 30 | 'fifth' 0.3842911866668146 31 | 'using' 0.3568337553491195 32 | ... 33 | 'ive' -0.07994025301498671 34 | 's' -0.14006704260206485 35 | 'ations' -0.2306471753618856 36 | 'itions' -0.3348596893891435 37 | ``` 38 | Adding the `--output_html` flag will also save an HTML file in the style of the below example to the folder `./logs/html`, bolding all multi-token sequences and coloring them blue if they have a higher erasure score. 39 | 40 | Monk example from website 41 | 42 | ## "Reading Out" Vocabulary Examples 43 | To apply this segmentation algorithm to an entire dataset (as seen in Tables 3 through 6), run 44 | ``` 45 | python readout.py --model meta-llama/Meta-Llama-3-8B --dataset ../data/wikipedia_test_500.csv 46 | ``` 47 | which specifically replicates Appendix Table 6. You can use your own dataset csv, as long as it contains a 'text' column with the documents you want to analyze. 48 | 49 | ## Loading Our Probes 50 | Checkpoints for each of the linear probes used in our paper are available at https://huggingface.co/sfeucht/footprints. To load a linear probe used in this paper, run the following code snippet: 51 | 52 | ```python 53 | import torch 54 | import torch.nn as nn 55 | from huggingface_hub import hf_hub_download 56 | 57 | class LinearModel(nn.Module): 58 | def __init__(self, input_size, output_size, bias=False): 59 | super(LinearModel, self).__init__() 60 | self.fc = nn.Linear(input_size, output_size, bias=bias) 61 | def forward(self, x): 62 | output = self.fc(x) 63 | return output 64 | 65 | # example: llama-2-7b probe at layer 0, predicting 3 tokens ago 66 | # predicting the next token would be `layer0_tgtidx1.ckpt` 67 | checkpoint_path = hf_hub_download( 68 | repo_id="sfeucht/footprints", 69 | filename="llama-2-7b/layer0_tgtidx-3.ckpt" 70 | ) 71 | 72 | # model_size is 4096 for both models. 73 | # vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b 74 | probe = LinearModel(4096, 32000).cuda() 75 | probe.load_state_dict(torch.load(checkpoint_path)) 76 | ``` 77 | 78 | ## Training Your Own Probes 79 | We have provided the probes used for the paper above. However, if you would still like to train your own linear probes, we provide code for training and testing linear probes on Llama hidden states in `./scripts`. To train a probe on e.g. layer 12 to predict two tokens ago, run 80 | ``` 81 | python train_probe.py --layer 12 --target_idx -2 82 | ``` 83 | and a linear model will be trained on Llama-2-7b by default and stored as a checkpoint in `./checkpoints`. These checkpoints can then be read by `./scripts/test_probe.py` and tested on either CounterFact tokens, Wikipedia tokens (multi-token words or spaCy entities), or plain Pile tokens. Test results are stored in `./logs`. 
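For example, to test a trained Llama-2-7b probe on the expanded CounterFact set (for the Wikipedia dataset, you can also add `--wiki_spacy` to test on spaCy entities instead of multi-token words):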
84 | ``` 85 | python test_probe.py --checkpoint ../checkpoints/Llama-2-7b-hf/.../final.ckpt --test_data counterfact_expanded.csv 86 | ``` 87 | 88 | ## Datasets Used 89 | We use three datasets in this paper, which can all be found in `./data`. 90 | 91 | - CounterFact [(Meng et al., 2022)](https://rome.baulab.info/) 92 | - `counterfact_expanded.csv` was used for all of the CounterFact tests in the paper, and includes rows in addition to the original CounterFact dataset. 93 | - Pile [(Gao et al., 2020)](https://pile.eleuther.ai/) 94 | - `train_tiny_1000.csv` was used to train all of the probes. 95 | - `val_tiny_500.csv` was used to validate probe hyperparameters. 96 | - `test_tiny_500.csv` was used for overall Pile test results. 97 | - Wikipedia [(Wikimedia Foundation, 2022)](https://huggingface.co/datasets/legacy-datasets/wikipedia) 98 | - `wikipedia_test_500.csv` was used for overall Wikipedia test results. 99 | - `wikipedia_val_500.csv` and `wikipedia_train_1000.csv` were not used in this work, but are included for completeness. 100 | -------------------------------------------------------------------------------- /scripts/segment.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Segment one document into highest-scoring token subsequences. 3 | ''' 4 | import os 5 | import argparse 6 | import math 7 | from nnsight import LanguageModel 8 | from readout import get_doc_info, partition_doc, get_probe 9 | 10 | ## HTML Formatting 11 | def col(s): 12 | # make sure between -1 and 1 13 | s = max(-1, min(1, s)) 14 | 15 | # blue (high) 0, 98, 255 16 | blue_r = 0 17 | blue_g = 98 18 | blue_b = 255 19 | 20 | if (s > 0): 21 | r = math.floor(255 + s * (blue_r - 255)) 22 | g = math.floor(255 + s * (blue_g - 255)) 23 | b = math.floor(255 + s * (blue_b - 255)) 24 | 25 | return f"style='background-color:rgb({r}, {g}, {b})'" 26 | else: 27 | return "" 28 | 29 | def html_span(segments, doc_unis, tokenizer): 30 | # all the tokens in original doc 31 | doc_tokens = [t for t in doc_unis['tok-1']] 32 | 33 | boxs_col = lambda x: tokenizer(f"")['input_ids'][1:] 34 | boxs = lambda x: tokenizer(f"")['input_ids'][1:] 35 | boxe = tokenizer("")['input_ids'][1:] 36 | 37 | blds = tokenizer("")['input_ids'][1:] 38 | blde = tokenizer("")['input_ids'][1:] 39 | 40 | bos_tok = tokenizer("<s>")['input_ids'][1:] 41 | 42 | out = [] 43 | for (x, y), val in segments: 44 | tokens = doc_tokens[x:y+1] 45 | 46 | # replace with s 47 | if tokens[0] == 1: 48 | tokens = bos_tok + tokens[1:] # 's' 49 | 50 | if len(tokens) > 1: 51 | to_appnd = [*boxs_col(val), *blds, *tokens, *blde, *boxe] 52 | else: 53 | to_appnd = [*boxs(val), *tokens, *boxe] 54 | 55 | out += to_appnd 56 | 57 | return tokenizer.decode(out) 58 | 59 | def html_view(segments, doc_unis, tokenizer): 60 | # sort segments to be in document order 61 | segments = sorted(segments, key=lambda t: t[0][0]) 62 | inject = html_span(segments, doc_unis, tokenizer) 63 | 64 | # white out non-mtw segment values 65 | whiteout = [] 66 | for (x,y), val in segments: 67 | if x == y: 68 | whiteout.append(((x,y), 0)) 69 | else: 70 | whiteout.append(((x,y), val)) 71 | 72 | template = f""" 73 | 74 | 75 | 76 | 84 | 85 | 86 | 87 |

88 | {inject} 89 |

90 | 91 | 92 | 93 | """ 94 | 95 | return template 96 | 97 | 98 | def load_model_and_probes(path): 99 | model = LanguageModel(path, device_map='cuda') 100 | tokenizer = model.tokenizer 101 | 102 | hf_name = { 103 | 'meta-llama/Llama-2-7b-hf' : 'llama-2-7b', 104 | 'meta-llama/Meta-Llama-3-8B' : 'llama-3-8b' 105 | }[path] 106 | layer_start = 1 107 | layer_end = 9 108 | 109 | start_probes, end_probes = [], [] 110 | start_probes.append(get_probe(layer_start, 0, hf_name)) 111 | end_probes.append(get_probe(layer_end, 0, hf_name)) 112 | 113 | start_probes.append(get_probe(layer_start, -1, hf_name)) 114 | end_probes.append(get_probe(layer_end, -1, hf_name)) 115 | 116 | start_probes.append(get_probe(layer_start, -2, hf_name)) 117 | end_probes.append(get_probe(layer_end, -2, hf_name)) 118 | 119 | return model, tokenizer, start_probes, end_probes 120 | 121 | 122 | def main(args): 123 | # load in model and probes 124 | model, tokenizer, start_probes, end_probes = load_model_and_probes(args.model) 125 | 126 | # read in given txt file 127 | with open(args.document, 'r') as f: 128 | input_text = f.read().strip() 129 | 130 | # tokenize 131 | tokens = tokenizer(input_text)['input_ids'][:args.max_length] 132 | 133 | # get probe info and partition document 134 | doc_info = get_doc_info(tokens, model, args.layer_start, args.layer_end, start_probes, end_probes, tokenizer) 135 | segments = partition_doc(doc_info) 136 | 137 | # save html output if desired 138 | if args.output_html: 139 | html_output = html_view(segments, doc_info, tokenizer) 140 | 141 | write_dir = "../logs/html/" 142 | fname = f"{args.document.split('/')[-1][:-4]}.html" 143 | os.makedirs(write_dir, exist_ok=True) 144 | 145 | print("saving html output as " + write_dir + fname) 146 | with open(f"../logs/html/{args.document.split('/')[-1][:-4]}.html", 'w') as f: 147 | f.write(html_output) 148 | 149 | # print document segments 150 | print("\nsegments from highest to lowest score:") 151 | for (p, q), val in segments: 152 | text = tokenizer.decode(tokens[p : q+1]) 153 | print(repr(text), '\t', val) 154 | 155 | if __name__ == "__main__": 156 | parser = argparse.ArgumentParser() 157 | 158 | parser.add_argument("--document", default="../data/monk.txt") 159 | parser.add_argument('--layer_start', type=int, default=1) 160 | parser.add_argument('--layer_end', type=int, default=9) 161 | parser.add_argument('--max_length', type=int, default=256) 162 | 163 | parser.add_argument('--output_html', action='store_true') 164 | parser.set_defaults(output_html=False) 165 | 166 | parser.add_argument('--model', default='meta-llama/Llama-2-7b-hf', 167 | choices=['meta-llama/Meta-Llama-3-8B', 'meta-llama/Llama-2-7b-hf']) 168 | 169 | args = parser.parse_args() 170 | 171 | main(args) 172 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Small helper functions used for training and testing. 
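The entity_mask values handled below are 0/1 lists over token positions, where 1 marks a token that is part of an entity; for example, [0, 1, 1, 0] marks a two-token entity ending at index 2.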
3 | ''' 4 | import torch 5 | 6 | def acc(pred_dist, target_toks): 7 | return 100 * (sum(pred_dist.argmax(dim=-1) == target_toks) / len(target_toks)) 8 | 9 | def _topktoks(logits, k=1): 10 | _, top_tokens = logits.topk(k=k, dim=-1) 11 | return top_tokens 12 | 13 | def _topkprobs(logits, tokenizer, k=5): 14 | top_probs, top_tokens = torch.softmax(logits, dim=0).topk(k=k, dim=-1) 15 | out = {} 16 | for i in range(k): 17 | out[f"top_{i+1}_prob"] = top_probs[i].item() 18 | out[f"top_{i+1}_tok_id"] = top_tokens[i].item() 19 | out[f"top_{i+1}_tok"] = tokenizer.decode(top_tokens[i].tolist()) 20 | return out 21 | 22 | ''' 23 | The next three functions handle entity_mask lists for the function test() in training.py 24 | ''' 25 | # assume you are given i as the last position in ngram mask 26 | def ngram_size(i, entity_mask): 27 | assert i >= 0 28 | assert entity_mask[i] == 1 29 | assert i + 1 == len(entity_mask) or entity_mask[i+1] == 0 30 | 31 | if i == 0: 32 | return 1 # unigram [1] 33 | size = 1 34 | j = i - 1 35 | 36 | while entity_mask[j] == 1: 37 | size += 1 38 | j -= 1 39 | if j < 0: 40 | break 41 | return size 42 | 43 | # is position i the last position in an entity in this mask? 44 | def is_entity_last(i, entity_mask): 45 | if entity_mask[i] == 0: 46 | return False 47 | else: 48 | if i + 1 == len(entity_mask): 49 | return True # end of sequence, i is last. 50 | 51 | # otherwise i is last if the next token is 0. 52 | return entity_mask[i+1] == 0 53 | 54 | # this is about to get thorny. 55 | # returns whether there is an entity ngram of our chosen shape at i based on target index: 56 | # 0 counts only unigram entities, e.g. [Iran] 57 | # -1 counts only bigram entities, e.g. [New, York] 58 | # -2 counts only trigram entities, e.g. [Empire, State, Building] 59 | # -3 counts only 4-gram entities, e.g. [Co, ca, Co, la] 60 | # 1 counts only bigram entities but the other way. 61 | def is_entity_ngram(i, entity_mask, target_idx=-1): 62 | if entity_mask[i] == 0: 63 | return False 64 | 65 | # otherwise, the current guy is an entity, so let's check if the previous/future one 66 | # (to make an ngram) and the ones in between are also an entity as well 67 | else: 68 | if target_idx == 0: 69 | if i-1<0 and i+1 >= len(entity_mask): # [1] 70 | return bool(entity_mask[i]) 71 | elif i-1<0 and i+1 < len(entity_mask): # [1,0...] 72 | return bool(entity_mask[i]) and not entity_mask[i+1] 73 | elif i-1>=0 and i+1 >= len(entity_mask): # [...0,1] 74 | return bool(entity_mask[i]) and not entity_mask[i-1] 75 | else: # i-1>=0 and i+1 < len(entity_mask), they're both within bounds 76 | return bool(entity_mask[i]) and not entity_mask[i+1] and not entity_mask[i-1] 77 | 78 | # backwards bigram. 
[1,1,0,0,0] 79 | elif target_idx == -1: 80 | if i-1<0: 81 | return False 82 | if i-2<0: 83 | if i+1 < len(entity_mask): 84 | return not entity_mask[i+1] and bool(entity_mask[i] and entity_mask[i-1]) 85 | else: 86 | return bool(entity_mask[i] and entity_mask[i-1]) 87 | else: 88 | if i+1 < len(entity_mask): 89 | return not entity_mask[i+1] and not entity_mask[i-2] and bool(entity_mask[i] and entity_mask[i-1]) 90 | else: 91 | return not entity_mask[i-2] and bool(entity_mask[i] and entity_mask[i-1]) 92 | 93 | # backwards trigram [0,1,1,1,0] 94 | elif target_idx == -2: 95 | if i-2<0: 96 | return False 97 | if i-3<0: 98 | if i+1 < len(entity_mask): 99 | return not entity_mask[i+1] and sum(entity_mask[i-2:i+1]) == 3 100 | else: 101 | return sum(entity_mask[i-2:i+1]) == 3 102 | else: 103 | if i+1 < len(entity_mask): 104 | return not entity_mask[i+1] and not entity_mask[i-3] and (sum(entity_mask[i-2:i+1]) == 3) 105 | else: 106 | return not entity_mask[i-3] and (sum(entity_mask[i-2:i+1]) == 3) 107 | 108 | # backwards 4-gram [0,1,1,1,1] 109 | elif target_idx == -3: 110 | if i-3<0: 111 | return False 112 | if i-4<0: 113 | if i+1 < len(entity_mask): # [1,1,1,1,0...] 114 | return not entity_mask[i+1] and sum(entity_mask[i-3:i+1]) == 4 115 | else: # [1,1,1,1] 116 | return sum(entity_mask[i-3:i+1]) == 4 117 | else: 118 | if i+1 < len(entity_mask): # [0,1,1,1,1,0...] 119 | return not entity_mask[i+1] and not entity_mask[i-4] and (sum(entity_mask[i-3:i+1]) == 4) 120 | else: # [...0,1,1,1,1] 121 | return not entity_mask[i-4] and (sum(entity_mask[i-3:i+1]) == 4) 122 | 123 | # forwards bigram [0,1,1,0] 124 | elif target_idx == 1: 125 | if i+1 >= len(entity_mask): # [...,1] 126 | return False 127 | if i+2 >= len(entity_mask): # [...,1,0] 128 | if i-1 < 0: # [1,0] 129 | return bool(entity_mask[i] and entity_mask[i+1]) 130 | else: # [...0,1,0] 131 | return not entity_mask[i-1] and bool(entity_mask[i] and entity_mask[i+1]) 132 | else: 133 | if i-1 < 0: # [1,0,...] 134 | return not entity_mask[i+2] and bool(entity_mask[i] and entity_mask[i+1]) 135 | else: # [...0,1,0,...] 136 | return not entity_mask[i-1] and not entity_mask[i+2] and bool(entity_mask[i] and entity_mask[i+1]) -------------------------------------------------------------------------------- /scripts/train_probe.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Train a new linear probe to predict a token offset (target_idx) at a specific 3 | layer. For example, you can train a probe with target_idx=-2 and layer=12 to 4 | take hidden states at layer 12 (the 13th layer) and predict what was two tokens 5 | before that hidden state. 
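For example, running `python train_probe.py --layer 12 --target_idx -2` trains such a probe on Llama-2-7b (the default --model) and saves the final checkpoint under ../checkpoints/.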
6 | ''' 7 | import os 8 | import csv 9 | import wandb 10 | import torch 11 | import argparse 12 | 13 | import pandas as pd 14 | import torch.nn.functional as F 15 | import pytorch_warmup as warmup 16 | 17 | from torch.utils.data import DataLoader 18 | from torch.optim.lr_scheduler import ReduceLROnPlateau 19 | from nnsight import LanguageModel 20 | 21 | from training import LinearModel, train_epoch, test, DocDataset, DocCollate 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | wandb.login() 25 | 26 | def datasetname(input): 27 | return input.split('/')[-1][:-4] 28 | 29 | def add_args(s, args): 30 | for k, v in vars(args).items(): 31 | if k in ['probe_bsz', 'probe_epochs']: 32 | s += f"-{k[6:]}{v}" 33 | elif k in ['probe_lr']: 34 | s += f"-{k[6:]}" + "{:1.5f}".format(v) 35 | return s 36 | 37 | def main(args): 38 | model = LanguageModel(args.model, device_map=device) 39 | tokenizer = model.tokenizer 40 | tokenizer.add_special_tokens({'pad_token':''}) 41 | 42 | VOCAB_SIZE = model.config.vocab_size 43 | MODEL_SIZE = model.config.hidden_size 44 | MODEL_NAME = args.model.split('/')[-1] 45 | 46 | # max_seq_len from below sources is 2048, but changing to 512 for memory/speed 47 | # https://github.com/meta-llama/llama/blob/main/llama/model.py#L31 48 | # https://github.com/meta-llama/llama3/blob/bf8d18cd087a4a0b3f61075b7de0b86cf6c70697/llama/model.py#L32 49 | WINDOW_SIZE = args.window_size 50 | 51 | for param in model.parameters(): 52 | param.requires_grad = False 53 | 54 | run_name = add_args(f"{MODEL_NAME}LAYER{args.layer}-TGTIDX{args.target_idx}-{datasetname(args.train_data)}", args) 55 | wandb.init(project = args.wandb_proj, name = run_name, config = args, settings=wandb.Settings(start_method="fork")) 56 | 57 | run_name += f'-{wandb.run.id}' 58 | wandb.run.name = run_name 59 | 60 | if args.probe_bsz > 10: 61 | print(f"Warning: batch size represents number of documents (each doc contains a few hundred tokens). 
{args.probe_bsz}>10, you may want to use a smaller batch size.") 62 | 63 | # make dirs that include the wandb id 64 | checkpoint_dir = f"../checkpoints/{MODEL_NAME}/{run_name}" 65 | log_dir = f"../logs/{MODEL_NAME}/{run_name}" 66 | os.makedirs(checkpoint_dir, exist_ok=True) 67 | os.makedirs(log_dir, exist_ok=True) 68 | 69 | # load data csvs 70 | train_data = pd.read_csv(args.train_data) 71 | val_data = pd.read_csv(args.val_data) 72 | test_data = pd.read_csv(args.test_data) 73 | 74 | # pass in subjects from counterfact dataset as "entities" to split during testing 75 | which_entity = "subject" 76 | entities = None 77 | if test_data is not None: 78 | if which_entity in test_data.columns: 79 | entities = list(test_data[which_entity]) 80 | 81 | train_dataset = DocDataset(model, tokenizer, args.layer, args.target_idx, train_data, WINDOW_SIZE, VOCAB_SIZE, device) 82 | val_dataset = DocDataset(model, tokenizer, args.layer, args.target_idx, val_data, WINDOW_SIZE, VOCAB_SIZE, device) 83 | test_dataset = DocDataset(model, tokenizer, args.layer, args.target_idx, test_data, WINDOW_SIZE, VOCAB_SIZE, device, entities=entities) 84 | 85 | linear_probe = LinearModel(MODEL_SIZE, VOCAB_SIZE).to(device) 86 | wandb.watch(linear_probe) 87 | 88 | collate_fn = DocCollate(args.layer, args.target_idx, tokenizer, model, WINDOW_SIZE, device) 89 | 90 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.probe_bsz, collate_fn=collate_fn, 91 | drop_last=True, pin_memory=False, shuffle=True) 92 | val_loader = DataLoader(dataset=val_dataset, batch_size=args.probe_bsz, collate_fn=collate_fn, 93 | drop_last=True, pin_memory=False) 94 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.probe_bsz, collate_fn=collate_fn, 95 | drop_last=True, pin_memory=False) 96 | 97 | optimizer = torch.optim.AdamW(linear_probe.parameters(), lr=args.probe_lr, weight_decay=args.probe_wd) # no momentum 98 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer) 99 | 100 | criterion = { 101 | "ce" : F.cross_entropy, 102 | "mse" : F.mse_loss 103 | }[args.criterion] 104 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3) 105 | 106 | print('training linear probe...') 107 | batches_seen = 0 108 | for epoch in range(args.probe_epochs): 109 | print('# Epoch {} #'.format(epoch)) 110 | batches_seen = train_epoch(epoch, linear_probe, train_loader, criterion, optimizer, warmup_scheduler, args.accumulate, args.clip_threshold, batches_seen, device) 111 | 112 | # log validation loss at the end of each epoch 113 | val_loss, val_acc, val_topk_acc, val_entity_ng_acc, val_other_acc = test(linear_probe, val_loader, criterion, device, tokenizer, args.target_idx, return_results=False) 114 | wandb.log({"val_loss": val_loss, "val_acc": val_acc, "val_topk_acc": val_topk_acc, "val_entity_ng_acc": val_entity_ng_acc, "val_other_acc":val_other_acc}) 115 | 116 | if warmup_scheduler is not None: 117 | with warmup_scheduler.dampening(): 118 | scheduler.step(val_loss) 119 | else: 120 | scheduler.step(val_loss) 121 | 122 | # Get final testing accuracy and prediction results 123 | torch.save(linear_probe.state_dict(), f"{checkpoint_dir}/final.ckpt") 124 | test_loss, test_acc, test_topk_acc, test_entity_ng_acc, test_other_acc, test_results = test(linear_probe, test_loader, criterion, device, tokenizer, args.target_idx, return_results=True) 125 | test_results.to_csv(log_dir + f"/{datasetname(args.test_data)}_results.csv", quoting=csv.QUOTE_ALL, encoding='utf-8') 126 | 127 | print('Test Loss: {:10.4f} Accuracy: 
{:3.4f}%\n'.format(test_loss, test_acc)) 128 | wandb.log({"test_loss": test_loss, "test_acc": test_acc, "test_topk_acc": test_topk_acc, "test_entity_ng_acc": test_entity_ng_acc, "test_other_acc": test_other_acc}) 129 | wandb.finish() 130 | 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser() 134 | 135 | # training info for linear probe 136 | parser.add_argument('--probe_bsz', type=int, default=1) 137 | parser.add_argument('--probe_lr', type=float, default=0.1) 138 | parser.add_argument('--probe_wd', type=float, default=0.001) 139 | parser.add_argument('--probe_epochs', type=int, default=8) 140 | 141 | parser.add_argument('--wandb_proj', type=str, default='footprints') 142 | parser.add_argument('--accumulate', type=int, default=30) 143 | parser.add_argument('--clip_threshold', type=float, default=0.1) 144 | 145 | parser.add_argument('--window_size', type=int, default=512) 146 | parser.add_argument('--num_workers', type=int, default=12) 147 | parser.add_argument('--criterion', type=str, choices=['mse', 'ce'], default='ce') 148 | 149 | # document data locations 150 | parser.add_argument('--train_data', type=str, default='../data/train_tiny_1000.csv') 151 | parser.add_argument('--val_data', type=str, default='../data/val_tiny_500.csv') 152 | parser.add_argument('--test_data', type=str, default='../data/test_tiny_500.csv') 153 | 154 | # required specifications for where probe is trained 155 | parser.add_argument('--layer', type=int, required=True, 156 | help='which layer to train the probe at, from -1...32 where -1 is embedding layer and 32 is output pre-softmax.') 157 | parser.add_argument('--target_idx', type=int, required=True, 158 | help='which token the probe should predict from current hidden state (e.g. 0 for current token, -1 for prev)') 159 | parser.add_argument('--model', type=str, choices=['meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B', 'EleutherAI/pythia-6.9b'], default='meta-llama/Llama-2-7b-hf') 160 | 161 | 162 | args = parser.parse_args() 163 | main(args) 164 | 165 | 166 | -------------------------------------------------------------------------------- /scripts/test_probe.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Load a probe checkpoint from `train_probe.py and get test results on it for another dataset. 
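For example: python test_probe.py --checkpoint ../checkpoints/Llama-2-7b-hf/.../final.ckpt --test_data ../data/counterfact_expanded.csv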
3 | If you're loading a llama 3 checkpoint, make sure to specify --model meta-llama/Meta-Llama-3-8B 4 | ''' 5 | import os 6 | import csv 7 | import torch 8 | import argparse 9 | import pandas as pd 10 | import numpy as np 11 | import regex as re 12 | import spacy 13 | 14 | import torch.nn.functional as F 15 | from collections import Counter 16 | from torch.utils.data import DataLoader 17 | from nnsight import LanguageModel 18 | from training import LinearModel, test, DocDataset, DocCollate 19 | 20 | torch.manual_seed(0) 21 | 22 | idx_to_n = { 23 | 1 : 2, 24 | 0 : 1, 25 | -1 : 2, 26 | -2 : 3, 27 | -3 : 4 28 | } 29 | 30 | def idx_to_zip(idx, toks): 31 | if idx <= 0: 32 | skip = { 33 | 0 : 0, 34 | -1 : 1, 35 | -2 : 2, 36 | -3 : 3 37 | }[idx] 38 | return zip(toks[:-skip], toks[skip:]) 39 | elif idx == 1: 40 | return zip(toks[1:], toks[:-1]) 41 | 42 | def datasetname(input): 43 | return input.split('/')[-1][:-4] 44 | 45 | # get nice string when saving wikipedia results 46 | def wikiextra(datasetname, spacy): 47 | if "wikipedia" in datasetname: 48 | if spacy: 49 | return "_mte" 50 | else: 51 | return "_mtw" 52 | else: 53 | return "" 54 | 55 | # take in results dataframe and add column `train_freq` indicating frequency of that ngram in the training dataset. 56 | # note: this only does the SKIP ngrams. so "the big tower" and "the small tower" are the same for tgtidx=-2. 57 | def train_ngram_info(results, train_df, tokenizer, model_name, target_idx): 58 | # get information about train ngram frequencies 59 | gt_ctr = Counter() 60 | all_gt_ct = 0 61 | for d in list(train_df['text']): 62 | if "llama" in model_name: 63 | toks = tokenizer(d)['input_ids'] 64 | else: 65 | bos = tokenizer.bos_token 66 | toks = tokenizer(bos+d)['input_ids'] 67 | gt_ctr += Counter(idx_to_zip(target_idx, toks)) 68 | all_gt_ct += len(toks) 69 | 70 | # then for each ngram in results, save that ngram's train count 71 | results['all_train_ct'] = [all_gt_ct for _ in range(len(results))] 72 | new = [] 73 | for i, row in results.iterrows(): 74 | # want to be able to change this based on target_idx 75 | try: 76 | ct = gt_ctr[(row['actual_tok_id'], row['current_tok_id'])] 77 | except KeyError: 78 | ct = 0 79 | 80 | row['ngram_train_ct'] = ct 81 | 82 | if row['ngram_train_ct'] > 0: 83 | row['log_ngram_train_freq'] = np.log(row['ngram_train_ct'] / row['all_train_ct']) 84 | else: 85 | row['log_ngram_train_freq'] = 0 86 | new.append(row) 87 | 88 | return pd.DataFrame(new) 89 | 90 | # returns a |V|x|V| matrix where each row is conditional probability of all preceding tokens 91 | # given that token as the succeeding token. e.g. 
row for "York" has a high probability on "New" 92 | # https://jofrhwld.github.io/teaching/courses/2022_lin517/python_sessions/04_session5.html#conditional-probability 93 | def train_conditional_probs(train_df, train_name, tokenizer, model_name, vocab_size, target_idx=-1): 94 | q_name = train_name + f"_qmodel{target_idx}.ckpt" 95 | if q_name in os.listdir("../data/qmodels"): 96 | # load the previous matrix for this dataset and target_idx=-1 97 | return torch.load(f"../data/qmodels/{q_name}") 98 | 99 | # edit this to account for diff target_idxs 100 | ng_ctr = Counter() 101 | ug_ctr = Counter() 102 | out = torch.zeros(size=(vocab_size, vocab_size)) 103 | for doc in list(train_df['text']): 104 | if "llama" in model_name: 105 | toks = tokenizer(doc)['input_ids'] 106 | else: 107 | bos = tokenizer.bos_token 108 | toks = tokenizer(bos+doc)['input_ids'] 109 | ng_ctr += Counter(idx_to_zip(target_idx, toks)) # ngrams(toks, idx_to_n[target_idx])) 110 | ug_ctr += Counter(toks) 111 | assert(len(ug_ctr) <= vocab_size) 112 | 113 | # each ROW correspnds to the second token 114 | # right now we fill with joint probabilities 115 | for tup, ct in ng_ctr.items(): 116 | out[tup] = ct / sum(ng_ctr.values()) 117 | 118 | # then divide each row by p(x1) out of all unigrams 119 | unigram_probs = torch.zeros(size=(vocab_size,)) 120 | for i in range(vocab_size): 121 | unigram_probs[i] = ug_ctr[i] / sum(ug_ctr.values()) 122 | assert(torch.sum(unigram_probs) == 1) 123 | 124 | # avoid divide by 0 errors by just dividing by 1 (since numerator will also be 0 in those cases) 125 | denom = torch.where(unigram_probs == 0, torch.ones_like(unigram_probs), unigram_probs) 126 | out = (out.T / denom).T 127 | print(out) 128 | 129 | torch.save(out, f"../data/{q_name}.ckpt") 130 | return out 131 | 132 | 133 | def main(args): 134 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") 135 | 136 | model = LanguageModel(args.model, device_map=device) 137 | tokenizer = model.tokenizer 138 | tokenizer.add_special_tokens({'pad_token':''}) 139 | 140 | VOCAB_SIZE = model.vocab_size 141 | MODEL_SIZE = model.config.hidden_size 142 | MODEL_NAME = args.model.split('/')[-1] 143 | 144 | # window size is actually 2048 but I choose 512 for brevity 145 | WINDOW_SIZE = args.window_size 146 | 147 | for p in model.parameters(): 148 | p.requires_grad = False 149 | 150 | probe = LinearModel(MODEL_SIZE, VOCAB_SIZE).to(device) 151 | probe.load_state_dict(torch.load(args.checkpoint)) 152 | 153 | # intuit target_idx from the filename 154 | if "TGTIDX" in args.checkpoint: 155 | s = re.search(r'TGTIDX-\d+|TGTIDX\d+', args.checkpoint).group() 156 | target_idx = int(s[6:]) 157 | else: 158 | raise Exception("Can't infer target index from checkpoint: " + args.checkpoint) 159 | 160 | # intuit layer from the filename 161 | if "LAYER" in args.checkpoint: 162 | s = re.search(r'LAYER-\d+|LAYER\d+', args.checkpoint).group() 163 | layer = int(s[5:]) 164 | else: 165 | raise Exception("Can't infer layer from checkpoint: " + args.checkpoint) 166 | 167 | collate_fn = DocCollate(layer, target_idx, tokenizer, model, WINDOW_SIZE, device) 168 | 169 | test_data = pd.read_csv(args.test_data) 170 | 171 | if args.test_data == "../data/counterfact_expandeds.csv": 172 | corr_str = { 173 | 'Llama-2-7b-hf' : 'llama-2-7b', 174 | 'Meta-Llama-3-8B' : 'llama-3-8b' 175 | }[MODEL_NAME] 176 | 177 | test_data = test_data.loc[test_data[f'{corr_str}_correct']] 178 | print(f"pruned down to only correct CounterFact answers, {len(test_data)}") 179 | 180 | # pass in 
subjects from counterfact dataset as entities 181 | which_entity = "subject" 182 | entities = None 183 | if test_data is not None: 184 | if which_entity in test_data.columns: 185 | entities = list(test_data[which_entity]) 186 | 187 | if 'wikipedia' in args.test_data: 188 | if args.wiki_spacy: 189 | nlp = spacy.load("en_core_web_sm") 190 | entities = [] 191 | for d in test_data['text']: 192 | doc = nlp(d) 193 | # https://stackoverflow.com/questions/70185150/return-all-possible-entity-types-from-spacy-model 194 | # ['ORG', 'CARDINAL', 'DATE', 'GPE', 'PERSON', 'MONEY', 'PRODUCT', 'TIME', 'PERCENT', 'WORK_OF_ART', 'QUANTITY', 'NORP', 'LOC', 'EVENT', 'ORDINAL', 'FAC', 'LAW', 'LANGUAGE'] 195 | # we want non-number ones. no dates, money, cardinals, time etc. 196 | desired_types = ['ORG', 'GPE', 'PERSON', 'PRODUCT', 'WORK_OF_ART', 'NORP', 'LOC', 'EVENT', 'FAC', 'LAW', 'LANGUAGE'] 197 | entities += [e.text for e in doc.ents if e.label_ in desired_types] 198 | entities = list(set(entities)) 199 | print(entities[:10]) 200 | else: 201 | multi_tok = [] 202 | for txt in test_data['text']: 203 | txt = re.sub(r'[^\w\s]', '', txt) 204 | txt = re.sub(r'[0-9]', '', txt) 205 | for word in Counter(txt.split()): 206 | if len(tokenizer(word)['input_ids'][1:]) > 1: 207 | multi_tok.append(word) 208 | print(multi_tok[:10]) 209 | entities = list(set(multi_tok)) 210 | 211 | test_dataset = DocDataset(model, tokenizer, layer, target_idx, test_data, WINDOW_SIZE, VOCAB_SIZE, device, entities=entities) 212 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, collate_fn=collate_fn) 213 | 214 | criterion = { 215 | "ce" : F.cross_entropy, 216 | "mse" : F.mse_loss 217 | }[args.criterion] 218 | 219 | test_loss, test_acc, test_topk_acc, test_entity_ng_acc, test_other_acc, test_results = test(probe, test_loader, 220 | criterion, device, tokenizer, target_idx, return_results=True) 221 | 222 | model_folder = args.checkpoint.split('/')[2] # hf-llama-2 223 | run_name = args.checkpoint.split('/')[-2] 224 | log_dir = f"../logs/{model_folder}/{run_name}/" 225 | out_csv = f"{datasetname(args.test_data)}_results{wikiextra(args.test_data, args.wiki_spacy)}.csv" 226 | 227 | os.makedirs(log_dir, exist_ok=True) 228 | print(log_dir + out_csv) 229 | test_results.to_csv(log_dir + out_csv, quoting=csv.QUOTE_ALL, encoding='utf-8') 230 | print('Test Loss: {:10.4f} Accuracy: {:3.4f}%\n'.format(test_loss, test_acc)) 231 | print('Test Top-5 Accuracy: {:3.4f}%\n'.format(test_topk_acc)) 232 | print('Test Accuracy for Entity Ngrams: {:3.4f}% (Other: {:3.4f})\n'.format(test_entity_ng_acc, test_other_acc)) 233 | 234 | return test_loss, test_acc, test_entity_ng_acc, test_other_acc 235 | 236 | if __name__ == '__main__': 237 | parser = argparse.ArgumentParser() 238 | 239 | # defaults 240 | parser.add_argument('--criterion', type=str, choices=['mse', 'ce'], default='ce') 241 | parser.add_argument('--cuda', type=int, default=0) 242 | parser.add_argument('--num_workers', type=int, default=12) 243 | parser.add_argument('--window_size', type=int, default=512) 244 | 245 | # what dataset to test on 246 | parser.add_argument('--test_data', type=str, 247 | choices=[ 248 | '../data/counterfact_expanded.csv', 249 | '../data/test_tiny_500.csv', 250 | '../data/wikipedia_test_500.csv' 251 | ], default="../data/test_tiny_500.csv") 252 | 253 | # for wikipedia dataset, do MTE if True, otherwise do MTW 254 | parser.add_argument('--wiki_spacy', action='store_true') 255 | parser.set_defaults(wiki_spacy=False) 256 | 257 | # specify probe checkpoint and model. 
tests the same layer and target_idx. 258 | parser.add_argument('--model', type=str, choices=['meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B'], default='meta-llama/Llama-2-7b-hf') 259 | parser.add_argument('--checkpoint', type=str, required=True, help="e.g. ../checkpoints/Llama-2-7b-hf/llamaLAYER12-TGTIDX-2.../final.ckpt") 260 | 261 | args = parser.parse_args() 262 | main(args) 263 | -------------------------------------------------------------------------------- /scripts/readout.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Given a dataset csv with the column 'text' and a choice between Llama-2-7b and 3 | Llama-3-8b, "read out" the vocabulary of that model based on the given datset. 4 | ''' 5 | import os 6 | import re 7 | import csv 8 | import time 9 | import torch 10 | import pickle 11 | import argparse 12 | import numpy as np 13 | import pandas as pd 14 | 15 | from collections import Counter 16 | from transformers import AutoTokenizer 17 | from training import LinearModel 18 | from nnsight import LanguageModel 19 | from huggingface_hub import hf_hub_download 20 | 21 | tgt_idxs = ['i0', 'in1', 'in2', 'in3'] 22 | 23 | def datasetname(s): 24 | if s == "wikipedia": 25 | return s 26 | elif s == "bigbio/med_qa": 27 | return s.split('/')[-1] 28 | elif s == "bigbio/blurb": 29 | return "blurb" 30 | elif s == "wmt/wmt14": 31 | return "wmt14" 32 | else: 33 | return s.split('/')[-1][:-4] 34 | 35 | # load in the outputs of `algorithm1.py` to double check and use as basis for this. 36 | def load_scores(model, dataset): 37 | dir = f"../logs/{model}/candidates/{dataset}_layer1_9/" 38 | dfs = [] 39 | for f in os.listdir(dir): 40 | if "scores" in f: 41 | df = pd.read_csv(dir + f) 42 | df['doc_idx'] = int(re.search(r'\d+', f).group(0)) 43 | dfs.append(df) 44 | 45 | return pd.concat(dfs).drop(columns=["Unnamed: 0"]) 46 | 47 | def get_probe(layer, target_idx, model): 48 | # example: llama-2-7b probe at layer 0, predicting 3 tokens ago 49 | # predicting the next token would be `layer0_tgtidx1.ckpt` 50 | checkpoint_path = hf_hub_download( 51 | repo_id="sfeucht/footprints", 52 | filename=f"{model}/layer{layer}_tgtidx{target_idx}.ckpt" 53 | ) 54 | 55 | # model_size is 4096 for both models. 
56 | # vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b 57 | model_size = 4096 58 | if model == 'llama-2-7b': 59 | vocab_size = 32000 60 | elif model == 'llama-3-8b': 61 | vocab_size = 128256 62 | 63 | probe = LinearModel(model_size, vocab_size).cuda() 64 | probe.load_state_dict(torch.load(checkpoint_path)) 65 | 66 | return probe 67 | 68 | 69 | ''' 70 | Implementation of "Erasure Score" 71 | ''' 72 | def psi(doc_info, i, j): 73 | sn = doc_info.iloc[i:j+1] 74 | 75 | idx_to_key = { 76 | 0 : 'tok-1_i0_probdelta', 77 | -1 : 'tok-1_in1_probdelta', 78 | -2 : 'tok-1_in2_probdelta', 79 | -3 : 'tok-1_in3_probdelta' 80 | } 81 | 82 | # we're doing 0 indexing for t, so this is different from the paper 83 | # in the paper we did start-end, so we have to flip all these to end-start and add - in front of probdelta 84 | # also, we have to do (t + idx < i) since we're using absolute idxs 85 | 86 | ideal = 1 87 | score = -sn.iloc[-1]['tok-1_i0_probdelta'] 88 | for t, row in sn.iterrows(): 89 | # for idx in range(-3, 0): # OPTION include -3 information 90 | for idx in range(-2, 0): 91 | # 0 - 3 means it's outside bounds 92 | # 1 - 1 is inside bounds (predicting t=0) 93 | # 3 - 2 is inside bounds (predicting t=1 from t=3) 94 | sign = -1 if (t + idx < i) else 1 95 | 96 | try: 97 | sc = -row[idx_to_key[idx]] 98 | if not np.isnan(sc): 99 | score += sign * sc 100 | ideal += 1 101 | except KeyError: 102 | pass 103 | 104 | return score / ideal 105 | 106 | ''' 107 | Given doc_info, apply Algorithm 1 to segment this particular document into non- 108 | overlapping high-scoring chunks. Allow for unigram segments to fill in the gaps. 109 | ''' 110 | def partition_doc(doc_info): 111 | # create and initialize the matrix 112 | n = len(doc_info) 113 | 114 | # implement as np array 115 | dp = np.ones((n, n)) 116 | dp = dp * -1 117 | 118 | # fill out the matrix 119 | for i in range(n): 120 | for j in range(n): 121 | if i <= j: 122 | if True: # j - i < 6: 123 | dp[i][j] = psi(doc_info, i, j) 124 | 125 | # get the top scores in order 126 | x, y = np.unravel_index(np.argsort(-dp.flatten()), dp.shape) 127 | coords = np.array(list(zip(x, y))) 128 | 129 | # go through all the top ngrams and add them to list, marking which ones become invalid as we go. 130 | segments = [] 131 | for p, q in coords: 132 | if p <= q: 133 | val = dp[p, q] 134 | 135 | valid = True 136 | for (x, y), _ in segments: 137 | if x > q or y < p: 138 | pass 139 | else: 140 | valid = False 141 | break 142 | 143 | if valid: 144 | segments.append(((p,q), val)) 145 | 146 | # validate that the segments fully cover doc 147 | all_ranges = [] 148 | for (x, y), val in segments: 149 | r = range(x, y+1) 150 | all_ranges += list(r) 151 | if set(all_ranges) == set(range(len(dp))): 152 | print("segments have full coverage") 153 | else: 154 | print("WARNING: segments did not fully cover doc") 155 | 156 | return segments 157 | 158 | # take the segmentation of a document and read out the multi-token words 159 | # that were identified via this approach. 
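# For illustration only (hypothetical token ids): with tokens = [1, 450, 4523, 3815] and
# segments = [((1, 2), 0.42), ((0, 0), 0.0), ((3, 3), -0.1)], filter_unis=True keeps only
# the multi-token span (1, 2), returning entries = [[450, 4523]] and vals = [0.42].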
160 | def read_out_entries(segments, tokens, filter_unis=True): 161 | entries = [] 162 | vals = [] 163 | for (x, y), val in segments: 164 | r = range(x, y+1) 165 | if not (x==y and filter_unis): 166 | entries.append([tokens[idx] for idx in r]) 167 | vals.append(val) 168 | 169 | return entries, vals 170 | 171 | ''' 172 | Run all the possible probes for a specific token 173 | ''' 174 | def get_tok_metrics(toks, i, start_probes, end_probes, start_states, end_states, tgt_idxs): 175 | corrstart, corrend, probdelta = {}, {}, {} 176 | 177 | # run each pair of probes on hidden states 178 | for start_probe, end_probe, s in zip(start_probes, end_probes, tgt_idxs): 179 | label = { 180 | 'i0' : toks[i], 181 | 'in1' : toks[i - 1] if i >= 1 else None, 182 | 'in2' : toks[i - 2] if i >= 2 else None, 183 | 'in3' : toks[i - 3] if i >= 3 else None 184 | }[s] 185 | 186 | if label is not None: 187 | start_logits = start_probe(start_states[i]).squeeze().detach().cpu() 188 | end_logits = end_probe(end_states[i]).squeeze().detach().cpu() 189 | 190 | corrstart[s] = start_logits.argmax() == label 191 | corrend[s] = end_logits.argmax() == label 192 | probdelta[s] = end_logits.softmax(dim=-1)[label].item() - start_logits.softmax(dim=-1)[label].item() 193 | 194 | del start_logits, end_logits 195 | 196 | return corrstart, corrend, probdelta 197 | 198 | 199 | ''' 200 | given a bunch of tokens and the states for the tokens at a layer, create a dataframe 201 | with every possible probdelta (for different target indices) for each token. 202 | ''' 203 | def get_doc_info(tokens, model, layer_start, layer_end, start_probes, end_probes, tokenizer): 204 | tgt_idxs = ['i0', 'in1', 'in2', 'in3'] 205 | 206 | # get hidden states for tokens 207 | with torch.no_grad(): 208 | with model.trace(tokens): 209 | ss = model.model.layers[layer_start].output[0].squeeze().save() 210 | es = model.model.layers[layer_end].output[0].squeeze().save() 211 | 212 | start_states = ss.detach() 213 | end_states = es.detach() 214 | del ss, es 215 | 216 | # per token: tok-1, decoded, tok-1_i0_iscorr, tok-1_i0_rankdelta, tok-1_i0_logitdelta, tok-1_i0_probdelta 217 | rows = [] 218 | for i, ug_tok in enumerate(tokens): 219 | row = {'decoded' : tokenizer.decode(ug_tok), 'n' : 1} 220 | 221 | # for each token in the ng run all the relevant probes 222 | corrstart, corrend, probdelta = \ 223 | get_tok_metrics(tokens, i, start_probes, end_probes, start_states, end_states, tgt_idxs) 224 | 225 | # save this token 226 | row[f'tok-1'] = ug_tok 227 | 228 | # save i0, in1, in2, in3 for token-1 in the unigram 229 | for s in tgt_idxs: 230 | if s in corrstart.keys(): 231 | row[f'tok-1_{s}_corrstart'] = corrstart[s].item() 232 | row[f'tok-1_{s}_corrend'] = corrend[s].item() 233 | row[f'tok-1_{s}_probdelta'] = probdelta[s] 234 | 235 | rows.append(row) 236 | 237 | del start_states, end_states 238 | torch.cuda.empty_cache() 239 | 240 | return pd.DataFrame(rows) 241 | 242 | def main(args): 243 | model = LanguageModel(args.model, device_map='cuda') 244 | tokenizer = AutoTokenizer.from_pretrained(args.model) 245 | 246 | MODEL_NAME = args.model.split('/')[-1] 247 | WINDOW_SIZE = 256 248 | 249 | # dataset we want to index 250 | dataset = pd.read_csv(args.dataset) 251 | 252 | dump_dir = f"../logs/{MODEL_NAME}/readout/{datasetname(args.dataset)}_layer{args.layer_start}_{args.layer_end}/" 253 | os.makedirs(dump_dir, exist_ok=True) 254 | 255 | hf_string = { 256 | 'Llama-2-7b-hf' : 'llama-2-7b', 257 | 'Meta-Llama-3-8B' : 'llama-3-8b' 258 | }[MODEL_NAME] 259 | 260 | # load in the probes 
at layer_start and layer_end 261 | start_probes, end_probes = [], [] 262 | start_probes.append(get_probe(args.layer_start, 0, hf_string)) 263 | end_probes.append(get_probe(args.layer_end, 0, hf_string)) 264 | 265 | start_probes.append(get_probe(args.layer_start, -1, hf_string)) 266 | end_probes.append(get_probe(args.layer_end, -1, hf_string)) 267 | 268 | start_probes.append(get_probe(args.layer_start, -2, hf_string)) 269 | end_probes.append(get_probe(args.layer_end, -2, hf_string)) 270 | 271 | ctr = 0 272 | all_ctr = Counter() 273 | sum_scores = {} 274 | tik = time.time() 275 | for doc_idx, doc in enumerate(dataset['text']): 276 | tokens = tokenizer(doc)['input_ids'][:WINDOW_SIZE] 277 | 278 | # get probe probability information for this doc_idx 279 | fname = f"docinfo_{doc_idx}.csv" 280 | try: 281 | doc_df = pd.read_csv(dump_dir + fname) 282 | print(f"loaded {dump_dir + fname}") 283 | except FileNotFoundError: 284 | doc_df = get_doc_info(tokens, model, args.layer_start, args.layer_end, start_probes, end_probes, tokenizer) 285 | doc_df.to_csv(dump_dir + fname, quoting=csv.QUOTE_ALL) 286 | print(f"saved {dump_dir + fname}, {len(tokens)} tokens in {datasetname(args.dataset)}") 287 | 288 | # segment doc with partition_doc 289 | picklename = f"segments_{doc_idx}.pkl" 290 | try: 291 | with open(dump_dir + picklename, 'rb') as f: 292 | segments = pickle.load(f) 293 | print(f"loaded segments from {dump_dir + picklename}") 294 | 295 | except FileNotFoundError: 296 | print(f"partitioning doc {doc_idx}...") 297 | segments = partition_doc(doc_df) 298 | 299 | with open(dump_dir + picklename, 'wb') as f: 300 | pickle.dump(segments, f) 301 | print(f"saved segments to {dump_dir + picklename}") 302 | 303 | # filter out the unigrams when you're "reading out" the vocabulary 304 | entries, vals = read_out_entries(segments, tokens, filter_unis=True) 305 | decoded_entries = [tokenizer.decode(e) for e in entries] 306 | 307 | # add to running totals 308 | all_ctr += Counter(decoded_entries) 309 | for de, v in zip(decoded_entries, vals): 310 | try: 311 | sum_scores[de] += v 312 | except KeyError: 313 | sum_scores[de] = v 314 | 315 | ctr += 1 316 | if args.n_examples > 0: 317 | if ctr >= args.n_examples: 318 | break # cut off 319 | 320 | tok = time.time() 321 | print("minutes taken:", (tok-tik) / 60) 322 | 323 | assert all_ctr.keys() == sum_scores.keys() 324 | 325 | # save counts of all vocabulary items 326 | cts_fname = f"cts_0thru{doc_idx}.pkl" 327 | with open(dump_dir + cts_fname, 'wb') as f: 328 | pickle.dump(all_ctr, f) 329 | print(f"saved counts at {dump_dir + cts_fname}") 330 | 331 | # calculate average scores for each vocab item 332 | avg_scores = {} 333 | for k, v in sum_scores.items(): 334 | avg_scores[k] = v / all_ctr[k] 335 | 336 | # save average scores for each vocab item 337 | avgs_fname = f"avgs_0thru{doc_idx}.pkl" 338 | with open(dump_dir + avgs_fname, 'wb') as f: 339 | pickle.dump(avg_scores, f) 340 | print(f"saved averages at {dump_dir + avgs_fname}") 341 | 342 | ctr = 0 343 | print("\nTop 50 Vocabulary Entries") 344 | for k, v in avg_scores.items(): 345 | print(repr(k), '\t', v) 346 | ctr += 1 347 | if ctr > 50: 348 | break 349 | 350 | if __name__ == "__main__": 351 | parser = argparse.ArgumentParser() 352 | 353 | parser.add_argument("--dataset", default="../data/test_tiny_500.csv") 354 | parser.add_argument('--layer_start', type=int, default=1) 355 | parser.add_argument('--layer_end', type=int, default=9) 356 | parser.add_argument('--n_examples', type=int, default=-1, help="-1 to use the whole 
dataset") 357 | 358 | parser.add_argument('--model', default='meta-llama/Llama-2-7b-hf', 359 | choices=['meta-llama/Meta-Llama-3-8B', 'meta-llama/Llama-2-7b-hf']) 360 | 361 | args = parser.parse_args() 362 | 363 | main(args) -------------------------------------------------------------------------------- /scripts/training.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Functions used in train_probe.py and test_probe.py for training and data loading. 3 | ''' 4 | import torch 5 | import wandb 6 | import pandas as pd 7 | 8 | from tqdm import tqdm 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | from utils import acc, _topktoks, _topkprobs, ngram_size, is_entity_last, is_entity_ngram 12 | 13 | class LinearModel(nn.Module): 14 | def __init__(self, input_size, output_size, bias=False): 15 | super(LinearModel, self).__init__() 16 | self.fc = nn.Linear(input_size, output_size, bias=bias) 17 | 18 | def forward(self, x): 19 | output = self.fc(x) 20 | return output 21 | 22 | ''' 23 | Trains a probe for a single epoch and logs loss/accuracy. 24 | 25 | Parameters: 26 | epoch: index of the current epoch 27 | probe: linear model to be trained 28 | train_loader: DataLoader of training data 29 | criterion: what loss function to use 30 | optimizer: torch optimizer for linear model 31 | warmup_scheduler: scheduler for learning rate 32 | accumulate: how many batches to wait until updating / clearing grads 33 | clip_threshold: threshold for gradient clipping. 34 | batches_seen: no. batches seen before this epoch 35 | device: device of Llama model 36 | 37 | Returns: 38 | None 39 | ''' 40 | def train_epoch(epoch, probe, train_loader, criterion, optimizer, warmup_scheduler, accumulate, clip_threshold, batches_seen, device): 41 | probe.train() 42 | 43 | for batch_idx, (hidden_states, target_toks, _, _, _) in enumerate(train_loader): 44 | hidden_states, target_toks = hidden_states.to(device), target_toks.to(device) 45 | assert(not torch.isnan(hidden_states).any() and not torch.isinf(hidden_states).any()) 46 | 47 | # get probe predictions and convert to toks if needed 48 | output = probe(hidden_states.float()).to(device) 49 | 50 | # then calculate loss with the target tokens 51 | loss = criterion(output, target_toks, reduction="mean") 52 | loss.backward() 53 | 54 | if batch_idx % accumulate == 0 and batch_idx > 0: 55 | torch.nn.utils.clip_grad_norm_(probe.parameters(), clip_threshold) 56 | optimizer.step() 57 | optimizer.zero_grad() 58 | 59 | loss = loss.detach().item() 60 | 61 | # learning rate warmup for AdamW. 62 | if warmup_scheduler is not None: 63 | if batch_idx < len(train_loader)-1: 64 | with warmup_scheduler.dampening(): 65 | pass 66 | 67 | # print training accuracy/loss every 10 epochs, and on the last epoch 68 | if batch_idx % max(accumulate, 10) == 0 or batch_idx == len(train_loader) - 1: 69 | train_acc = acc(output.cpu(), target_toks.cpu()) 70 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tTraining Acc:{:3.3f}%\tBatch Loss: {:.6f} ({} tokens)'.format( 71 | epoch, batch_idx, len(train_loader), 100. 
* batch_idx / len(train_loader), 72 | train_acc.item(), loss, hidden_states.size()[0])) 73 | 74 | wandb.log({"train_loss": loss, "train_acc": train_acc, 75 | "epoch": epoch, "batches_seen": 1 + batch_idx + batches_seen}) 76 | 77 | return 1 + batch_idx + batches_seen 78 | 79 | 80 | ''' 81 | Given a trained LinearModel and a DataLoader of test examples, test the model on a given test_loader and return csv of results (optionally) 82 | ''' 83 | def test(probe, test_loader, criterion, device, tokenizer, target_idx, return_results=False): 84 | probe.eval() 85 | total_loss = 0.0 86 | total_toks = 0 87 | correct = 0 88 | topk_correct = 0 89 | results = [] 90 | 91 | n_entity_ngrams = 0 92 | n_entity_ngrams_correct = 0 93 | n_other = 0 94 | n_other_correct = 0 95 | with torch.no_grad(): 96 | for (data, target_toks, curr_toks, doc_idxs, entity_mask) in tqdm(test_loader): 97 | if data is None: 98 | continue 99 | 100 | output = probe(data.to(device).float()).to(device) 101 | 102 | loss = criterion(output, target_toks.to(device), reduction="mean") 103 | total_loss += loss.detach().item() 104 | 105 | for i, v in enumerate(output.cpu()): 106 | doc_id = doc_idxs[i] 107 | current_tok = _topktoks(curr_toks[i]) 108 | actual_tok = target_toks[i] 109 | predicted_tok = _topktoks(v) 110 | this_loss = criterion(v, target_toks[i]) 111 | is_correct = predicted_tok == actual_tok 112 | 113 | total_toks += 1 114 | if is_correct: 115 | correct += 1 116 | if actual_tok in _topktoks(v, k=5): # less interpretable NOW THAT PINV 117 | topk_correct += 1 118 | 119 | kl_divergence = -1 120 | q_target_log_prob = torch.inf 121 | p_target_log_prob = torch.inf 122 | 123 | 124 | if entity_mask is not None: 125 | this_is_entity_ngram = bool(is_entity_ngram(i, entity_mask, target_idx=target_idx)) 126 | if this_is_entity_ngram: 127 | n_entity_ngrams += 1 128 | n_entity_ngrams_correct += int(is_correct) 129 | else: # count up coarse "other" values. 130 | n_other += 1 131 | n_other_correct += int(is_correct) 132 | 133 | n = 0 134 | this_is_entity_last = is_entity_last(i, entity_mask) 135 | if this_is_entity_last: # if entity save how big the entity is 136 | n = ngram_size(i, entity_mask) 137 | 138 | if return_results: 139 | # BOS token becomes encoded as NaN in pandas here 140 | curr_result = { 141 | "doc_id" : doc_id.item(), 142 | "current_tok_id" : current_tok.item(), 143 | "actual_tok_id" : actual_tok.item(), 144 | "predicted_tok_id" : predicted_tok.item(), 145 | "current_tok" : tokenizer.decode(current_tok.tolist()), 146 | "actual_tok" : tokenizer.decode(actual_tok.tolist()), 147 | "predicted_tok" : tokenizer.decode(predicted_tok.tolist()), 148 | "loss" : this_loss.item(), 149 | 150 | # from this information you can split into entity_last, entity_notlast, and nonentity. 151 | # as well as the case "entity_ngram": whether this probe is predicting the first tok FROM the last tok of an entity. 152 | "is_entity" : entity_mask[i].item() if entity_mask is not None else -1, 153 | "is_entity_last" : this_is_entity_last if entity_mask is not None else -1, 154 | "is_entity_ngram" : this_is_entity_ngram if entity_mask is not None else -1, 155 | "n" : n, 156 | 157 | "kl_divergence" : kl_divergence, 158 | "q_target_log_prob" : q_target_log_prob, 159 | "p_target_log_prob" : p_target_log_prob, 160 | **_topkprobs(v, tokenizer) 161 | } 162 | 163 | results.append(curr_result) 164 | 165 | 166 | test_loss = total_loss / len(test_loader) # divide total average loss by no. 
batches 167 | test_acc = 100 * correct / total_toks 168 | topk_acc = 100 * topk_correct / total_toks 169 | 170 | if n_entity_ngrams > 0: 171 | entity_ngram_acc = 100 * n_entity_ngrams_correct / n_entity_ngrams 172 | else: 173 | entity_ngram_acc = -1 174 | 175 | if n_other > 0: 176 | other_acc = 100 * n_other_correct / n_other 177 | else: 178 | other_acc = -1 179 | 180 | if return_results: 181 | return test_loss, test_acc, topk_acc, entity_ngram_acc, other_acc, pd.DataFrame(results) 182 | else: 183 | return test_loss, test_acc, topk_acc, entity_ngram_acc, other_acc 184 | 185 | ''' 186 | Dataset for retrieving tokenized documents from csv along with masks that mark 187 | which tokens correspond to "entities" 188 | ''' 189 | class DocDataset(Dataset): 190 | def __init__(self, model, tokenizer, layer_name, target_idx, dataset_csv, window_size, vocab_size, device, entities=None): 191 | self.model = model 192 | self.tokenizer = tokenizer 193 | self.layer_name = layer_name # -1 is embedding, 0-31 for layers, 32 for logits right at the end 194 | self.target_idx = target_idx # -1 is previous token, 0 is current, etc. 195 | self.dataset_csv = dataset_csv 196 | self.window_size = window_size 197 | self.vocab_size = vocab_size 198 | self.device = device 199 | self.entities = entities # list of strings that are the entities we want to mask out 200 | 201 | if self.entities is not None: 202 | self.entities = [self.tokenize(e, bos=False) for e in self.entities] 203 | 204 | # llama tokenizer already adds BOS token 205 | def tokenize(self, text, bos=True): 206 | if bos: 207 | t = self.tokenizer(text)['input_ids'] 208 | else: 209 | t = self.tokenizer(text)['input_ids'][1:] 210 | 211 | # this makes sure that entity mask is also truncated to window size 212 | if len(t) > self.window_size: 213 | return t[:self.window_size] 214 | else: 215 | return t 216 | 217 | # iterate through sequence and mark subseq occurs in sequence 218 | def mask_iterator(self, sequence, subseq, mask): 219 | sequence = list(sequence.cpu()) 220 | if len(subseq) <= len(sequence): 221 | for i in range(len(sequence)-len(subseq)+1): 222 | assert len(sequence[i:i+len(subseq)]) == len(subseq) 223 | if (sequence[i:i+len(subseq)] == subseq): 224 | mask[i:i+len(subseq)] = 1 225 | return torch.Tensor(mask) 226 | 227 | # returns number of documents, not tokens 228 | def __len__(self): 229 | if self.dataset_csv is not None: 230 | return len(self.dataset_csv) 231 | 232 | # get document tokens and mask 233 | def __getitem__(self, index): 234 | doc = self.dataset_csv.iloc[index] 235 | doc_string = str(doc['text']) 236 | 237 | # need this for entity mask calculations 238 | tokens = torch.tensor(self.tokenize(doc_string)) 239 | 240 | entity_mask = torch.zeros_like(tokens) 241 | if self.entities is not None: 242 | for e in self.entities: 243 | entity_mask = self.mask_iterator(tokens, e, entity_mask) 244 | 245 | return torch.tensor(index), doc_string, tokens, entity_mask 246 | 247 | ''' 248 | Bloated collate function that takes sequences and retrieves hidden states for the given model using nnsight 249 | ''' 250 | class DocCollate(object): 251 | def __init__(self, layer, target_idx, tokenizer, model, window_size, device): 252 | self.layer = layer 253 | self.target_idx = target_idx 254 | self.tokenizer = tokenizer 255 | self.model = model 256 | self.window_size = window_size 257 | self.device = device 258 | 259 | def __call__(self, batch): 260 | # pad all the strings and save attention mask 261 | strings = [s for (_, s, _, _) in batch] 262 | tokenized = 
self.tokenizer(strings, return_tensors='pt', padding=True, truncation=True, max_length=self.window_size) 263 | attention_mask = tokenized['attention_mask'] 264 | 265 | with self.model.trace(tokenized): 266 | if 'llama' in self.model.config._name_or_path.lower(): 267 | if self.layer == -1: 268 | states = self.model.model.embed_tokens.output.save() 269 | elif self.layer == 32: 270 | states = self.model.model.norm.output.save() 271 | else: 272 | states = self.model.model.layers[self.layer].output[0].save() 273 | else: # pythia 274 | if self.layer == -1: 275 | states = self.model.gpt_neox.embed_in.output.save() 276 | elif self.layer == 32: 277 | states = self.model.gpt_neox.final_layer_norm.output.save() 278 | else: 279 | states = self.model.gpt_neox.layers[self.layer].output[0].save() 280 | 281 | # then loop through the entire thing to keep same logic for embs, tokens and doc_idxs 282 | source_hss = [] 283 | target_toks = [] 284 | current_toks = [] 285 | doc_idxs = [] 286 | entity_masks = [] 287 | for i, doc in enumerate(batch): 288 | # batch looks like [doc0:(0, text, tokens, mask), doc1:(1, text, tokens, mask)...] 289 | doc_idx, tokens, entity_mask = (a.cpu() for a in doc if type(a)!=str) 290 | 291 | # get the hidden states we just calculated, and trim off the PAD tokens 292 | # for llama 2 7b padding is always at the beginning 293 | hidden_states = states[i][-sum(attention_mask[i]):] 294 | assert (len(hidden_states) == len(tokens)) 295 | 296 | # make sure that hidden_states has enough tokens to deal with the given target_idx. 297 | # if the target_idx is gonna be outside the bounds of hidden_states, we want to skip doc. 298 | if abs(self.target_idx) >= len(hidden_states): 299 | continue 300 | 301 | # target_idx == -1: 302 | # source_hss: BOS [this is an example sentence] 303 | # target_toks: [BOS this is an example] sentence 304 | # target_idx == -2: 305 | # (BOS this [is an) example sentence] 306 | # target_idx == -3: 307 | # (BOS this is) [an example sentence] 308 | if self.target_idx < 0: 309 | pos = abs(self.target_idx) 310 | source_hss.append(hidden_states[pos:]) 311 | target_toks.append(tokens[:-pos]) 312 | current_toks.append(tokens[pos:]) 313 | doc_idxs.append(torch.tensor([doc_idx for _ in range(len(hidden_states[pos:]))], device='cpu')) 314 | entity_masks.append(entity_mask[pos:]) 315 | 316 | # target_idx == 1: 317 | # source_hss: [BOS this is an example] sentence 318 | # target_toks: BOS [this is an example sentence] 319 | # target_idx == 2: 320 | # [BOS this (is an] example sentence) 321 | # target_idx == 3: 322 | # [BOS this is] (an example sentence) 323 | elif self.target_idx > 0: 324 | pos = abs(self.target_idx) 325 | source_hss.append(hidden_states[:-pos]) 326 | target_toks.append(tokens[pos:]) 327 | current_toks.append(tokens[:-pos]) 328 | doc_idxs.append(torch.tensor([doc_idx for _ in range(len(hidden_states[:-pos]))], device='cpu')) 329 | entity_masks.append(entity_mask[:-pos]) 330 | 331 | # exclude predicting bos_embedding -> BOS 332 | elif self.target_idx == 0: 333 | source_hss.append(hidden_states[1:]) 334 | target_toks.append(tokens[1:]) 335 | current_toks.append(tokens[1:]) 336 | doc_idxs.append(torch.tensor([doc_idx for _ in range(len(hidden_states[1:]))], device='cpu')) 337 | entity_masks.append(entity_mask[1:]) 338 | 339 | # sometimes docs are too small 340 | if len(source_hss) > 0: 341 | source_hss = torch.cat(source_hss) 342 | target_toks = torch.cat(target_toks) 343 | current_toks = torch.cat(current_toks) 344 | doc_idxs = torch.cat(doc_idxs) 345 | 
entity_masks = torch.cat(entity_masks) 346 | return (source_hss, target_toks, current_toks, doc_idxs, entity_masks) 347 | else: 348 | return None, None, None, None, None --------------------------------------------------------------------------------
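
Note on the target_idx convention used in DocCollate.__call__ above: for a negative target_idx the probe reads the hidden state at a position and is trained to recover the token that occurred |target_idx| positions earlier; for a positive target_idx it predicts the token target_idx positions later; target_idx == 0 probes for the current token while skipping the BOS position. The toy example below is not part of the repository; the token strings and the align helper are illustrative only, and simply reproduce the same slicing on plain strings.

def align(tokens, target_idx):
    # mirrors the slicing in DocCollate.__call__: the first element of each pair
    # stands in for the position whose hidden state is fed to the probe, the
    # second is the token the probe is trained to predict
    if target_idx < 0:
        pos = abs(target_idx)
        return list(zip(tokens[pos:], tokens[:-pos]))
    elif target_idx > 0:
        return list(zip(tokens[:-target_idx], tokens[target_idx:]))
    else:
        return list(zip(tokens[1:], tokens[1:]))  # exclude predicting BOS from the BOS embedding

toks = ["BOS", "this", "is", "an", "example", "sentence"]
print(align(toks, -1))
# [('this', 'BOS'), ('is', 'this'), ('an', 'is'), ('example', 'an'), ('sentence', 'example')]
print(align(toks, 1))
# [('BOS', 'this'), ('this', 'is'), ('is', 'an'), ('an', 'example'), ('example', 'sentence')]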
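
For orientation, here is a rough sketch of how the pieces in training.py fit together when training and evaluating a probe. This is not code from the repository (train_probe.py and test_probe.py drive the real pipeline); the file name train_docs.csv, the hyperparameters, the left-padding setup, the wandb project name, and the use of F.cross_entropy as the criterion are all assumptions made for illustration.

import pandas as pd
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from nnsight import LanguageModel
from training import DocDataset, DocCollate, LinearModel, train_epoch, test  # assumes scripts/ is on the path

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.bos_token   # assumption: Llama-2 ships without a pad token
tokenizer.padding_side = "left"             # DocCollate trims padding from the front of each sequence
model = LanguageModel(model_name, device_map="auto")

wandb.init(project="token-erasure-probe", mode="disabled")  # train_epoch logs to wandb

layer, target_idx, window_size = 5, -1, 256  # illustrative values: probe layer 5 for the previous token
df = pd.read_csv("train_docs.csv")           # hypothetical csv with a 'text' column

dataset = DocDataset(model, tokenizer, layer, target_idx, df, window_size,
                     vocab_size=model.config.vocab_size, device=device)
collate = DocCollate(layer, target_idx, tokenizer, model, window_size, device)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate)

probe = LinearModel(model.config.hidden_size, model.config.vocab_size).to(device)
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-4)

batches_seen = 0
for epoch in range(3):
    batches_seen = train_epoch(epoch, probe, loader, F.cross_entropy, optimizer,
                               warmup_scheduler=None, accumulate=4, clip_threshold=1.0,
                               batches_seen=batches_seen, device=device)

# evaluating on the training loader here only to keep the sketch short;
# a held-out DocDataset/DataLoader would normally be used instead
loss, top1_acc, top5_acc, ngram_acc, other_acc = test(probe, loader, F.cross_entropy,
                                                      device, tokenizer, target_idx)
print(f"test loss {loss:.4f}, top-1 acc {top1_acc:.2f}%, top-5 acc {top5_acc:.2f}%")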