├── .gitignore
├── data
│   ├── monk.txt
│   └── solitude.txt
├── LICENSE
├── requirements.txt
├── README.md
└── scripts
    ├── segment.py
    ├── utils.py
    ├── train_probe.py
    ├── test_probe.py
    ├── readout.py
    └── training.py
/.gitignore:
--------------------------------------------------------------------------------
1 | env
2 | scripts/__pycache__
3 | scripts/wandb
4 | checkpoints
5 | logs
6 |
--------------------------------------------------------------------------------
/data/monk.txt:
--------------------------------------------------------------------------------
1 | Monk's compositions and improvisations feature dissonances and angular melodic twists, often using flat ninths, flat fifths, unexpected chromatic notes together, low bass notes and stride, and fast whole tone runs, combining a highly percussive attack with abrupt, dramatic use of switched key releases, silences, and hesitations.
2 |
--------------------------------------------------------------------------------
/data/solitude.txt:
--------------------------------------------------------------------------------
1 | Many years later, as he faced the firing squad, Colonel Aureliano Buendía was to remember that distant afternoon when his father took him to discover ice. At that time Macondo was a village of twenty adobe houses, built on the bank of a river of clear water that ran along a bed of polished stones, which were white and enormous, like prehistoric eggs. The world was so recent that many things lacked names, and in order to indicate them it was necessary to point.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Sheridan Feucht
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.31.0
2 | annotated-types==0.7.0
3 | bidict==0.23.1
4 | blis==0.7.11
5 | catalogue==2.0.10
6 | certifi==2024.6.2
7 | charset-normalizer==3.3.2
8 | click==8.1.7
9 | cloudpathlib==0.18.1
10 | confection==0.1.5
11 | cymem==2.0.8
12 | diffusers==0.29.1
13 | docker-pycreds==0.4.0
14 | einops==0.8.0
15 | filelock==3.15.4
16 | fsspec==2024.6.0
17 | gitdb==4.0.11
18 | GitPython==3.1.43
19 | h11==0.14.0
20 | huggingface-hub==0.23.4
21 | idna==3.7
22 | importlib-metadata==7.2.1
23 | jinja2==3.1.4
24 | langcodes==3.4.0
25 | language-data==1.2.0
26 | marisa-trie==1.2.0
27 | markdown-it-py==3.0.0
28 | MarkupSafe==2.1.5
29 | mdurl==0.1.2
30 | mpmath==1.3.0
31 | murmurhash==1.0.10
32 | networkx==3.1
33 | nnsight==0.2.19
34 | numpy==1.24.4
35 | nvidia-cublas-cu12==12.1.3.1
36 | nvidia-cuda-cupti-cu12==12.1.105
37 | nvidia-cuda-nvrtc-cu12==12.1.105
38 | nvidia-cuda-runtime-cu12==12.1.105
39 | nvidia-cudnn-cu12==8.9.2.26
40 | nvidia-cufft-cu12==11.0.2.54
41 | nvidia-curand-cu12==10.3.2.106
42 | nvidia-cusolver-cu12==11.4.5.107
43 | nvidia-cusparse-cu12==12.1.0.106
44 | nvidia-nccl-cu12==2.20.5
45 | nvidia-nvjitlink-cu12==12.5.40
46 | nvidia-nvtx-cu12==12.1.105
47 | packaging==24.1
48 | pandas==2.0.3
49 | pillow==10.3.0
50 | platformdirs==4.2.2
51 | preshed==3.0.9
52 | protobuf==5.27.1
53 | psutil==6.0.0
54 | pydantic==2.7.4
55 | pydantic-core==2.18.4
56 | pygments==2.18.0
57 | python-dateutil==2.9.0.post0
58 | python-engineio==4.9.1
59 | python-socketio==5.11.3
60 | pytorch-warmup==0.1.1
61 | pytz==2024.1
62 | PyYAML==6.0.1
63 | regex==2024.5.15
64 | requests==2.32.3
65 | rich==13.7.1
66 | safetensors==0.4.3
67 | sentencepiece==0.2.0
68 | sentry-sdk==2.6.0
69 | setproctitle==1.3.3
70 | shellingham==1.5.4
71 | simple-websocket==1.0.0
72 | six==1.16.0
73 | smart-open==7.0.4
74 | smmap==5.0.1
75 | spacy==3.7.5
76 | spacy-legacy==3.0.12
77 | spacy-loggers==1.0.5
78 | srsly==2.4.8
79 | sympy==1.12.1
80 | thinc==8.2.5
81 | tokenizers==0.19.1
82 | torch==2.3.1
83 | torchvision==0.18.1
84 | tqdm==4.66.4
85 | transformers==4.41.2
86 | triton==2.3.1
87 | typer==0.12.3
88 | typing-extensions==4.12.2
89 | tzdata==2024.1
90 | urllib3==2.2.2
91 | wandb==0.17.3
92 | wasabi==1.1.3
93 | weasel==0.4.1
94 | websocket-client==1.8.0
95 | wrapt==1.16.0
96 | wsproto==1.2.0
97 | zipp==3.19.2
98 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Token Erasure as a Footprint of Implicit Vocabulary Items in LLMs
2 | How do LLMs process multi-token words, common phrases, and named entities? We discover a pattern of token erasure that we hypothesize to be a 'footprint' of how LLMs process unnatural tokenization.
3 |
4 | Read more about our paper here:
5 | 🌐 https://footprints.baulab.info
6 | 📄 https://arxiv.org/abs/2406.20086
7 |
8 |
9 |
10 | ## Setup
11 | To run our code, clone this repository and create a new virtual environment using Python 3.8.10:
12 | ```
13 | python3 -m venv env
14 | source env/bin/activate
15 | pip install -r requirements.txt
16 | ```
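The Llama checkpoints used below are gated on Hugging Face, so you may also need to log in with an account that has been granted access before the scripts can download the weights:
```
huggingface-cli login
```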
17 |
18 | ## Segmenting a Document
19 | An implementation of Algorithm 1 in our paper is provided in `segment.py`. This script can be run from the `scripts` directory like so:
20 | ```
21 | python segment.py --document my_doc.txt --model meta-llama/Llama-2-7b-hf
22 | ```
23 | allowing you to segment any paragraph of text into high-scoring token sequences. For example, running the script on the included `data/monk.txt` produces output like:
24 | ```
25 | segments from highest to lowest score:
26 | 'dramatic' 0.5815845847818102
27 | 'twists' 0.5553912909024803
28 | 'low bass' 0.41476866118921824
29 | 'cuss' 0.3979072428604316
30 | 'fifth' 0.3842911866668146
31 | 'using' 0.3568337553491195
32 | ...
33 | 'ive' -0.07994025301498671
34 | 's' -0.14006704260206485
35 | 'ations' -0.2306471753618856
36 | 'itions' -0.3348596893891435
37 | ```
38 | Adding the `--output_html` flag will also save an HTML file to the folder `./logs/html`, bolding all multi-token sequences and coloring them blue if they have a higher erasure score.
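For example, to segment the included sample document and save the HTML view:
```
python segment.py --document ../data/monk.txt --output_html
```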
39 |
40 |
41 |
42 | ## "Reading Out" Vocabulary Examples
43 | To apply this segmentation algorithm to an entire dataset (as seen in Tables 3 through 6), run
44 | ```
45 | python readout.py --model meta-llama/Meta-Llama-3-8B --dataset ../data/wikipedia_test_500.csv
46 | ```
47 | which specifically replicates Appendix Table 6. You can use your own dataset CSV, as long as it contains a 'text' column with the documents you want to analyze; a minimal way to build one is sketched below.
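If you want to build your own dataset, a short pandas snippet like this (a sketch; `my_dataset.csv` is just an example name) produces a compatible file:
```python
import pandas as pd

# any list of documents works, as long as the column is named 'text'
docs = [
    "Thelonious Monk was an American jazz pianist and composer.",
    "Many years later, as he faced the firing squad...",
]
pd.DataFrame({"text": docs}).to_csv("data/my_dataset.csv", index=False)
```
written here from the repository root; pass it to `readout.py` as `--dataset ../data/my_dataset.csv` when running from `scripts/`.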
48 |
49 | ## Loading Our Probes
50 | Checkpoints for each of the linear probes used in our paper are available at https://huggingface.co/sfeucht/footprints. To load a linear probe used in this paper, run the following code snippet:
51 |
52 | ```python
53 | import torch
54 | import torch.nn as nn
55 | from huggingface_hub import hf_hub_download
56 |
57 | class LinearModel(nn.Module):
58 | def __init__(self, input_size, output_size, bias=False):
59 | super(LinearModel, self).__init__()
60 | self.fc = nn.Linear(input_size, output_size, bias=bias)
61 | def forward(self, x):
62 | output = self.fc(x)
63 | return output
64 |
65 | # example: llama-2-7b probe at layer 0, predicting 3 tokens ago
66 | # predicting the next token would be `layer0_tgtidx1.ckpt`
67 | checkpoint_path = hf_hub_download(
68 | repo_id="sfeucht/footprints",
69 | filename="llama-2-7b/layer0_tgtidx-3.ckpt"
70 | )
71 |
72 | # model_size is 4096 for both models.
73 | # vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b
74 | probe = LinearModel(4096, 32000).cuda()
75 | probe.load_state_dict(torch.load(checkpoint_path))
76 | ```
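Continuing the snippet above, the probe simply maps a residual-stream vector to logits over the vocabulary. Here is a rough usage sketch (not taken from the paper's scripts): it assumes access to the gated Llama-2 weights and takes `hidden_states[1]` as the layer-0 output, since `hidden_states[0]` is the embedding layer.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
lm = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).cuda()

# hidden state at the last prompt position, after the first transformer block
inputs = tok("The Empire State Building", return_tensors="pt").to("cuda")
h = lm(**inputs, output_hidden_states=True).hidden_states[1][0, -1]

# the probe loaded above was trained to predict the token 3 positions back
logits = probe(h.float())
print(tok.decode([logits.argmax().item()]))
```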
77 |
78 | ## Training Your Own Probes
79 | We have provided the probes used for the paper above. However, if you would still like to train your own linear probes, we provide code for training and testing linear probes on Llama hidden states in `./scripts`. To train a probe on e.g. layer 12 to predict two tokens ago, run
80 | ```
81 | python train_probe.py --layer 12 --target_idx -2
82 | ```
83 | and a linear model will be trained on Llama-2-7b by default and stored as a checkpoint in `./checkpoints`. These checkpoints can then be read by `./scripts/test_probe.py` and tested on either CounterFact tokens, Wikipedia tokens (multi-token words or spaCy entities), or plain Pile tokens. Test results are stored in `./logs`.
84 | ```
85 | python test_probe.py --checkpoint ../checkpoints/Llama-2-7b-hf/.../final.ckpt --test_data counterfact_expanded.csv
86 | ```
87 |
88 | ## Datasets Used
89 | We use three datasets in this paper, which can all be found in `./data`.
90 |
91 | - CounterFact [(Meng et al., 2022)](https://rome.baulab.info/)
92 |   - `counterfact_expanded.csv` was used for all of the CounterFact tests in the paper, and includes additional rows beyond the original CounterFact dataset.
93 | - Pile [(Gao et al., 2020)](https://pile.eleuther.ai/)
94 | - `train_tiny_1000.csv` was used to train all of the probes.
95 | - `val_tiny_500.csv` was used to validate probe hyperparameters.
96 | - `test_tiny_500.csv` was used for overall Pile test results.
97 | - Wikipedia [(Wikimedia Foundation, 2022)](https://huggingface.co/datasets/legacy-datasets/wikipedia)
98 | - `wikipedia_test_500.csv` was used for overall Wikipedia test results.
99 | - `wikipedia_val_500.csv` and `wikipedia_train_1000.csv` were not used in this work, but are included for completeness.
100 |
--------------------------------------------------------------------------------
/scripts/segment.py:
--------------------------------------------------------------------------------
1 | '''
2 | Segment one document into highest-scoring token subsequences.
3 | '''
4 | import os
5 | import argparse
6 | import math
7 | from nnsight import LanguageModel
8 | from readout import get_doc_info, partition_doc, get_probe
9 |
10 | ## HTML Formatting
11 | def col(s):
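    # maps an erasure score in [-1, 1] to an inline style attribute, e.g.
    # col(1.0) -> "style='background-color:rgb(0, 98, 255)'" (fully saturated blue)
    # col(0.0) or any negative score -> "" (no highlighting)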
12 | # make sure between -1 and 1
13 | s = max(-1, min(1, s))
14 |
15 | # blue (high) 0, 98, 255
16 | blue_r = 0
17 | blue_g = 98
18 | blue_b = 255
19 |
20 | if (s > 0):
21 | r = math.floor(255 + s * (blue_r - 255))
22 | g = math.floor(255 + s * (blue_g - 255))
23 | b = math.floor(255 + s * (blue_b - 255))
24 |
25 | return f"style='background-color:rgb({r}, {g}, {b})'"
26 | else:
27 | return ""
28 |
29 | def html_span(segments, doc_unis, tokenizer):
30 | # all the tokens in original doc
31 | doc_tokens = [t for t in doc_unis['tok-1']]
32 |
33 |     boxs_col = lambda x: tokenizer(f"<span {col(x)}>")['input_ids'][1:]  # open colored box
34 |     boxs = lambda x: tokenizer("<span>")['input_ids'][1:]  # open plain box
35 |     boxe = tokenizer("</span>")['input_ids'][1:]  # close box
36 |
37 |     blds = tokenizer("<b>")['input_ids'][1:]  # open bold
38 |     blde = tokenizer("</b>")['input_ids'][1:]  # close bold
39 |
40 | bos_tok = tokenizer("<s>")['input_ids'][1:]
41 |
42 | out = []
43 | for (x, y), val in segments:
44 | tokens = doc_tokens[x:y+1]
45 |
46 |         # replace the BOS token id (1) with tokens for the literal string '<s>'
47 |         if tokens[0] == 1:
48 |             tokens = bos_tok + tokens[1:] # renders as '<s>' in the output
49 |
50 | if len(tokens) > 1:
51 | to_appnd = [*boxs_col(val), *blds, *tokens, *blde, *boxe]
52 | else:
53 | to_appnd = [*boxs(val), *tokens, *boxe]
54 |
55 | out += to_appnd
56 |
57 | return tokenizer.decode(out)
58 |
59 | def html_view(segments, doc_unis, tokenizer):
60 | # sort segments to be in document order
61 | segments = sorted(segments, key=lambda t: t[0][0])
62 | inject = html_span(segments, doc_unis, tokenizer)
63 |
64 | # white out non-mtw segment values
65 | whiteout = []
66 | for (x,y), val in segments:
67 | if x == y:
68 | whiteout.append(((x,y), 0))
69 | else:
70 | whiteout.append(((x,y), val))
71 |
72 |     # NOTE: the original HTML template is not preserved in this copy of the file;
73 |     # below is a minimal stand-in page wrapper around the decoded segments.
74 |     template = f"""
75 |     <!DOCTYPE html>
76 |     <html>
77 |     <head>
78 |     <meta charset="utf-8">
79 |     <style>
80 |         body {{ font-family: sans-serif; margin: 2em; line-height: 1.8; }}
81 |         span {{ padding: 1px 2px; border-radius: 3px; }}
82 |     </style>
83 |     </head>
84 |     <body>
85 |     <p>
86 |     {inject}
87 |     </p>
88 |     </body>
89 |     </html>
90 |     """
94 |
95 | return template
96 |
97 |
98 | def load_model_and_probes(path, layer_start=1, layer_end=9):
99 | model = LanguageModel(path, device_map='cuda')
100 | tokenizer = model.tokenizer
101 |
102 | hf_name = {
103 | 'meta-llama/Llama-2-7b-hf' : 'llama-2-7b',
104 | 'meta-llama/Meta-Llama-3-8B' : 'llama-3-8b'
105 | }[path]
106 |     # probes are loaded for the requested start/end layers
107 |     # (defaults match the paper: layer_start=1, layer_end=9)
108 |
109 | start_probes, end_probes = [], []
110 | start_probes.append(get_probe(layer_start, 0, hf_name))
111 | end_probes.append(get_probe(layer_end, 0, hf_name))
112 |
113 | start_probes.append(get_probe(layer_start, -1, hf_name))
114 | end_probes.append(get_probe(layer_end, -1, hf_name))
115 |
116 | start_probes.append(get_probe(layer_start, -2, hf_name))
117 | end_probes.append(get_probe(layer_end, -2, hf_name))
118 |
119 | return model, tokenizer, start_probes, end_probes
120 |
121 |
122 | def main(args):
123 | # load in model and probes
124 |     model, tokenizer, start_probes, end_probes = load_model_and_probes(args.model, args.layer_start, args.layer_end)
125 |
126 | # read in given txt file
127 | with open(args.document, 'r') as f:
128 | input_text = f.read().strip()
129 |
130 | # tokenize
131 | tokens = tokenizer(input_text)['input_ids'][:args.max_length]
132 |
133 | # get probe info and partition document
134 | doc_info = get_doc_info(tokens, model, args.layer_start, args.layer_end, start_probes, end_probes, tokenizer)
135 | segments = partition_doc(doc_info)
136 |
137 | # save html output if desired
138 | if args.output_html:
139 | html_output = html_view(segments, doc_info, tokenizer)
140 |
141 | write_dir = "../logs/html/"
142 | fname = f"{args.document.split('/')[-1][:-4]}.html"
143 | os.makedirs(write_dir, exist_ok=True)
144 |
145 | print("saving html output as " + write_dir + fname)
146 |         with open(write_dir + fname, 'w') as f:
147 | f.write(html_output)
148 |
149 | # print document segments
150 | print("\nsegments from highest to lowest score:")
151 | for (p, q), val in segments:
152 | text = tokenizer.decode(tokens[p : q+1])
153 | print(repr(text), '\t', val)
154 |
155 | if __name__ == "__main__":
156 | parser = argparse.ArgumentParser()
157 |
158 | parser.add_argument("--document", default="../data/monk.txt")
159 | parser.add_argument('--layer_start', type=int, default=1)
160 | parser.add_argument('--layer_end', type=int, default=9)
161 | parser.add_argument('--max_length', type=int, default=256)
162 |
163 | parser.add_argument('--output_html', action='store_true')
164 | parser.set_defaults(output_html=False)
165 |
166 | parser.add_argument('--model', default='meta-llama/Llama-2-7b-hf',
167 | choices=['meta-llama/Meta-Llama-3-8B', 'meta-llama/Llama-2-7b-hf'])
168 |
169 | args = parser.parse_args()
170 |
171 | main(args)
172 |
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Small helper functions used for training and testing.
3 | '''
4 | import torch
5 |
6 | def acc(pred_dist, target_toks):
7 | return 100 * (sum(pred_dist.argmax(dim=-1) == target_toks) / len(target_toks))
8 |
9 | def _topktoks(logits, k=1):
10 | _, top_tokens = logits.topk(k=k, dim=-1)
11 | return top_tokens
12 |
13 | def _topkprobs(logits, tokenizer, k=5):
14 | top_probs, top_tokens = torch.softmax(logits, dim=0).topk(k=k, dim=-1)
15 | out = {}
16 | for i in range(k):
17 | out[f"top_{i+1}_prob"] = top_probs[i].item()
18 | out[f"top_{i+1}_tok_id"] = top_tokens[i].item()
19 | out[f"top_{i+1}_tok"] = tokenizer.decode(top_tokens[i].tolist())
20 | return out
21 |
22 | '''
23 | The next three functions handle entity_mask lists for the function test() in training.py
24 | '''
25 | # assume you are given i as the last position in ngram mask
26 | def ngram_size(i, entity_mask):
27 | assert i >= 0
28 | assert entity_mask[i] == 1
29 | assert i + 1 == len(entity_mask) or entity_mask[i+1] == 0
30 |
31 | if i == 0:
32 | return 1 # unigram [1]
33 | size = 1
34 | j = i - 1
35 |
36 | while entity_mask[j] == 1:
37 | size += 1
38 | j -= 1
39 | if j < 0:
40 | break
41 | return size
42 |
43 | # is position i the last position in an entity in this mask?
44 | def is_entity_last(i, entity_mask):
45 | if entity_mask[i] == 0:
46 | return False
47 | else:
48 | if i + 1 == len(entity_mask):
49 | return True # end of sequence, i is last.
50 |
51 | # otherwise i is last if the next token is 0.
52 | return entity_mask[i+1] == 0
53 |
54 | # this is about to get thorny.
55 | # returns whether there is an entity ngram of our chosen shape at i based on target index:
56 | # 0 counts only unigram entities, e.g. [Iran]
57 | # -1 counts only bigram entities, e.g. [New, York]
58 | # -2 counts only trigram entities, e.g. [Empire, State, Building]
59 | # -3 counts only 4-gram entities, e.g. [Co, ca, Co, la]
60 | # 1 counts only bigram entities but the other way.
61 | def is_entity_ngram(i, entity_mask, target_idx=-1):
62 | if entity_mask[i] == 0:
63 | return False
64 |
65 | # otherwise, the current guy is an entity, so let's check if the previous/future one
66 | # (to make an ngram) and the ones in between are also an entity as well
67 | else:
68 | if target_idx == 0:
69 | if i-1<0 and i+1 >= len(entity_mask): # [1]
70 | return bool(entity_mask[i])
71 | elif i-1<0 and i+1 < len(entity_mask): # [1,0...]
72 | return bool(entity_mask[i]) and not entity_mask[i+1]
73 | elif i-1>=0 and i+1 >= len(entity_mask): # [...0,1]
74 | return bool(entity_mask[i]) and not entity_mask[i-1]
75 | else: # i-1>=0 and i+1 < len(entity_mask), they're both within bounds
76 | return bool(entity_mask[i]) and not entity_mask[i+1] and not entity_mask[i-1]
77 |
78 | # backwards bigram. [1,1,0,0,0]
79 | elif target_idx == -1:
80 | if i-1<0:
81 | return False
82 | if i-2<0:
83 | if i+1 < len(entity_mask):
84 | return not entity_mask[i+1] and bool(entity_mask[i] and entity_mask[i-1])
85 | else:
86 | return bool(entity_mask[i] and entity_mask[i-1])
87 | else:
88 | if i+1 < len(entity_mask):
89 | return not entity_mask[i+1] and not entity_mask[i-2] and bool(entity_mask[i] and entity_mask[i-1])
90 | else:
91 | return not entity_mask[i-2] and bool(entity_mask[i] and entity_mask[i-1])
92 |
93 | # backwards trigram [0,1,1,1,0]
94 | elif target_idx == -2:
95 | if i-2<0:
96 | return False
97 | if i-3<0:
98 | if i+1 < len(entity_mask):
99 | return not entity_mask[i+1] and sum(entity_mask[i-2:i+1]) == 3
100 | else:
101 | return sum(entity_mask[i-2:i+1]) == 3
102 | else:
103 | if i+1 < len(entity_mask):
104 | return not entity_mask[i+1] and not entity_mask[i-3] and (sum(entity_mask[i-2:i+1]) == 3)
105 | else:
106 | return not entity_mask[i-3] and (sum(entity_mask[i-2:i+1]) == 3)
107 |
108 | # backwards 4-gram [0,1,1,1,1]
109 | elif target_idx == -3:
110 | if i-3<0:
111 | return False
112 | if i-4<0:
113 | if i+1 < len(entity_mask): # [1,1,1,1,0...]
114 | return not entity_mask[i+1] and sum(entity_mask[i-3:i+1]) == 4
115 | else: # [1,1,1,1]
116 | return sum(entity_mask[i-3:i+1]) == 4
117 | else:
118 | if i+1 < len(entity_mask): # [0,1,1,1,1,0...]
119 | return not entity_mask[i+1] and not entity_mask[i-4] and (sum(entity_mask[i-3:i+1]) == 4)
120 | else: # [...0,1,1,1,1]
121 | return not entity_mask[i-4] and (sum(entity_mask[i-3:i+1]) == 4)
122 |
123 | # forwards bigram [0,1,1,0]
124 | elif target_idx == 1:
125 | if i+1 >= len(entity_mask): # [...,1]
126 | return False
127 | if i+2 >= len(entity_mask): # [...,1,0]
128 | if i-1 < 0: # [1,0]
129 | return bool(entity_mask[i] and entity_mask[i+1])
130 | else: # [...0,1,0]
131 | return not entity_mask[i-1] and bool(entity_mask[i] and entity_mask[i+1])
132 | else:
133 | if i-1 < 0: # [1,0,...]
134 | return not entity_mask[i+2] and bool(entity_mask[i] and entity_mask[i+1])
135 | else: # [...0,1,0,...]
136 | return not entity_mask[i-1] and not entity_mask[i+2] and bool(entity_mask[i] and entity_mask[i+1])
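
# Illustration (added sketch, not used by the training code): how the helpers
# above behave on a small mask where positions 1-2 form a bigram entity and
# position 4 is a unigram entity. Run `python utils.py` to check the claims.
if __name__ == "__main__":
    mask = [0, 1, 1, 0, 1]
    assert is_entity_last(2, mask)                      # position 2 ends the bigram
    assert ngram_size(2, mask) == 2
    assert is_entity_ngram(2, mask, target_idx=-1)      # backwards bigram [_,1,1,0,_]
    assert is_entity_ngram(4, mask, target_idx=0)       # unigram entity at the end
    assert not is_entity_ngram(2, mask, target_idx=-2)  # positions 0-2 are not a trigram
    print("entity mask examples hold")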
--------------------------------------------------------------------------------
/scripts/train_probe.py:
--------------------------------------------------------------------------------
1 | '''
2 | Train a new linear probe to predict a token offset (target_idx) at a specific
3 | layer. For example, you can train a probe with target_idx=-2 and layer=12 to
4 | take hidden states at layer 12 (the 13th layer) and predict what was two tokens
5 | before that hidden state.
6 | '''
7 | import os
8 | import csv
9 | import wandb
10 | import torch
11 | import argparse
12 |
13 | import pandas as pd
14 | import torch.nn.functional as F
15 | import pytorch_warmup as warmup
16 |
17 | from torch.utils.data import DataLoader
18 | from torch.optim.lr_scheduler import ReduceLROnPlateau
19 | from nnsight import LanguageModel
20 |
21 | from training import LinearModel, train_epoch, test, DocDataset, DocCollate
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 | wandb.login()
25 |
26 | def datasetname(input):
27 | return input.split('/')[-1][:-4]
28 |
29 | def add_args(s, args):
30 | for k, v in vars(args).items():
31 | if k in ['probe_bsz', 'probe_epochs']:
32 | s += f"-{k[6:]}{v}"
33 | elif k in ['probe_lr']:
34 | s += f"-{k[6:]}" + "{:1.5f}".format(v)
35 | return s
36 |
37 | def main(args):
38 | model = LanguageModel(args.model, device_map=device)
39 | tokenizer = model.tokenizer
40 |     tokenizer.add_special_tokens({'pad_token':'<pad>'})  # pad string assumed; it was lost in this copy
41 |
42 | VOCAB_SIZE = model.config.vocab_size
43 | MODEL_SIZE = model.config.hidden_size
44 | MODEL_NAME = args.model.split('/')[-1]
45 |
46 | # max_seq_len from below sources is 2048, but changing to 512 for memory/speed
47 | # https://github.com/meta-llama/llama/blob/main/llama/model.py#L31
48 | # https://github.com/meta-llama/llama3/blob/bf8d18cd087a4a0b3f61075b7de0b86cf6c70697/llama/model.py#L32
49 | WINDOW_SIZE = args.window_size
50 |
51 | for param in model.parameters():
52 | param.requires_grad = False
53 |
54 | run_name = add_args(f"{MODEL_NAME}LAYER{args.layer}-TGTIDX{args.target_idx}-{datasetname(args.train_data)}", args)
55 | wandb.init(project = args.wandb_proj, name = run_name, config = args, settings=wandb.Settings(start_method="fork"))
56 |
57 | run_name += f'-{wandb.run.id}'
58 | wandb.run.name = run_name
59 |
60 | if args.probe_bsz > 10:
61 | print(f"Warning: batch size represents number of documents (each doc contains a few hundred tokens). {args.probe_bsz}>10, you may want to use a smaller batch size.")
62 |
63 | # make dirs that include the wandb id
64 | checkpoint_dir = f"../checkpoints/{MODEL_NAME}/{run_name}"
65 | log_dir = f"../logs/{MODEL_NAME}/{run_name}"
66 | os.makedirs(checkpoint_dir, exist_ok=True)
67 | os.makedirs(log_dir, exist_ok=True)
68 |
69 | # load data csvs
70 | train_data = pd.read_csv(args.train_data)
71 | val_data = pd.read_csv(args.val_data)
72 | test_data = pd.read_csv(args.test_data)
73 |
74 | # pass in subjects from counterfact dataset as "entities" to split during testing
75 | which_entity = "subject"
76 | entities = None
77 | if test_data is not None:
78 | if which_entity in test_data.columns:
79 | entities = list(test_data[which_entity])
80 |
81 | train_dataset = DocDataset(model, tokenizer, args.layer, args.target_idx, train_data, WINDOW_SIZE, VOCAB_SIZE, device)
82 | val_dataset = DocDataset(model, tokenizer, args.layer, args.target_idx, val_data, WINDOW_SIZE, VOCAB_SIZE, device)
83 | test_dataset = DocDataset(model, tokenizer, args.layer, args.target_idx, test_data, WINDOW_SIZE, VOCAB_SIZE, device, entities=entities)
84 |
85 | linear_probe = LinearModel(MODEL_SIZE, VOCAB_SIZE).to(device)
86 | wandb.watch(linear_probe)
87 |
88 | collate_fn = DocCollate(args.layer, args.target_idx, tokenizer, model, WINDOW_SIZE, device)
89 |
90 | train_loader = DataLoader(dataset=train_dataset, batch_size=args.probe_bsz, collate_fn=collate_fn,
91 | drop_last=True, pin_memory=False, shuffle=True)
92 | val_loader = DataLoader(dataset=val_dataset, batch_size=args.probe_bsz, collate_fn=collate_fn,
93 | drop_last=True, pin_memory=False)
94 | test_loader = DataLoader(dataset=test_dataset, batch_size=args.probe_bsz, collate_fn=collate_fn,
95 | drop_last=True, pin_memory=False)
96 |
97 | optimizer = torch.optim.AdamW(linear_probe.parameters(), lr=args.probe_lr, weight_decay=args.probe_wd) # no momentum
98 | warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)
99 |
100 | criterion = {
101 | "ce" : F.cross_entropy,
102 | "mse" : F.mse_loss
103 | }[args.criterion]
104 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
105 |
106 | print('training linear probe...')
107 | batches_seen = 0
108 | for epoch in range(args.probe_epochs):
109 | print('# Epoch {} #'.format(epoch))
110 | batches_seen = train_epoch(epoch, linear_probe, train_loader, criterion, optimizer, warmup_scheduler, args.accumulate, args.clip_threshold, batches_seen, device)
111 |
112 | # log validation loss at the end of each epoch
113 | val_loss, val_acc, val_topk_acc, val_entity_ng_acc, val_other_acc = test(linear_probe, val_loader, criterion, device, tokenizer, args.target_idx, return_results=False)
114 | wandb.log({"val_loss": val_loss, "val_acc": val_acc, "val_topk_acc": val_topk_acc, "val_entity_ng_acc": val_entity_ng_acc, "val_other_acc":val_other_acc})
115 |
116 | if warmup_scheduler is not None:
117 | with warmup_scheduler.dampening():
118 | scheduler.step(val_loss)
119 | else:
120 | scheduler.step(val_loss)
121 |
122 | # Get final testing accuracy and prediction results
123 | torch.save(linear_probe.state_dict(), f"{checkpoint_dir}/final.ckpt")
124 | test_loss, test_acc, test_topk_acc, test_entity_ng_acc, test_other_acc, test_results = test(linear_probe, test_loader, criterion, device, tokenizer, args.target_idx, return_results=True)
125 | test_results.to_csv(log_dir + f"/{datasetname(args.test_data)}_results.csv", quoting=csv.QUOTE_ALL, encoding='utf-8')
126 |
127 | print('Test Loss: {:10.4f} Accuracy: {:3.4f}%\n'.format(test_loss, test_acc))
128 | wandb.log({"test_loss": test_loss, "test_acc": test_acc, "test_topk_acc": test_topk_acc, "test_entity_ng_acc": test_entity_ng_acc, "test_other_acc": test_other_acc})
129 | wandb.finish()
130 |
131 |
132 | if __name__ == '__main__':
133 | parser = argparse.ArgumentParser()
134 |
135 | # training info for linear probe
136 | parser.add_argument('--probe_bsz', type=int, default=1)
137 | parser.add_argument('--probe_lr', type=float, default=0.1)
138 | parser.add_argument('--probe_wd', type=float, default=0.001)
139 | parser.add_argument('--probe_epochs', type=int, default=8)
140 |
141 | parser.add_argument('--wandb_proj', type=str, default='footprints')
142 | parser.add_argument('--accumulate', type=int, default=30)
143 | parser.add_argument('--clip_threshold', type=float, default=0.1)
144 |
145 | parser.add_argument('--window_size', type=int, default=512)
146 | parser.add_argument('--num_workers', type=int, default=12)
147 | parser.add_argument('--criterion', type=str, choices=['mse', 'ce'], default='ce')
148 |
149 | # document data locations
150 | parser.add_argument('--train_data', type=str, default='../data/train_tiny_1000.csv')
151 | parser.add_argument('--val_data', type=str, default='../data/val_tiny_500.csv')
152 | parser.add_argument('--test_data', type=str, default='../data/test_tiny_500.csv')
153 |
154 | # required specifications for where probe is trained
155 | parser.add_argument('--layer', type=int, required=True,
156 | help='which layer to train the probe at, from -1...32 where -1 is embedding layer and 32 is output pre-softmax.')
157 | parser.add_argument('--target_idx', type=int, required=True,
158 | help='which token the probe should predict from current hidden state (e.g. 0 for current token, -1 for prev)')
159 | parser.add_argument('--model', type=str, choices=['meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B', 'EleutherAI/pythia-6.9b'], default='meta-llama/Llama-2-7b-hf')
160 |
161 |
162 | args = parser.parse_args()
163 | main(args)
164 |
165 |
166 |
--------------------------------------------------------------------------------
/scripts/test_probe.py:
--------------------------------------------------------------------------------
1 | '''
2 | Load a probe checkpoint from `train_probe.py` and get test results for it on another dataset.
3 | If you're loading a llama 3 checkpoint, make sure to specify --model meta-llama/Meta-Llama-3-8B
4 | '''
5 | import os
6 | import csv
7 | import torch
8 | import argparse
9 | import pandas as pd
10 | import numpy as np
11 | import regex as re
12 | import spacy
13 |
14 | import torch.nn.functional as F
15 | from collections import Counter
16 | from torch.utils.data import DataLoader
17 | from nnsight import LanguageModel
18 | from training import LinearModel, test, DocDataset, DocCollate
19 |
20 | torch.manual_seed(0)
21 |
22 | idx_to_n = {
23 | 1 : 2,
24 | 0 : 1,
25 | -1 : 2,
26 | -2 : 3,
27 | -3 : 4
28 | }
29 |
30 | def idx_to_zip(idx, toks):
31 |     if idx == 0:
32 |         # target is the current token itself
33 |         return zip(toks, toks)
34 |     elif idx < 0:
35 |         # e.g. idx=-2 pairs each token with the token two positions after it
36 |         skip = -idx
37 |         return zip(toks[:-skip], toks[skip:])
38 |     elif idx == 1:
39 |         # next-token direction
40 |         return zip(toks[1:], toks[:-1])
41 |
42 | def datasetname(input):
43 | return input.split('/')[-1][:-4]
44 |
45 | # get nice string when saving wikipedia results
46 | def wikiextra(datasetname, spacy):
47 | if "wikipedia" in datasetname:
48 | if spacy:
49 | return "_mte"
50 | else:
51 | return "_mtw"
52 | else:
53 | return ""
54 |
55 | # take in results dataframe and add columns `ngram_train_ct` and `log_ngram_train_freq` giving the frequency of each ngram in the training dataset.
56 | # note: this only counts SKIP ngrams, so "the big tower" and "the small tower" are the same for tgtidx=-2.
57 | def train_ngram_info(results, train_df, tokenizer, model_name, target_idx):
58 | # get information about train ngram frequencies
59 | gt_ctr = Counter()
60 | all_gt_ct = 0
61 | for d in list(train_df['text']):
62 | if "llama" in model_name:
63 | toks = tokenizer(d)['input_ids']
64 | else:
65 | bos = tokenizer.bos_token
66 | toks = tokenizer(bos+d)['input_ids']
67 | gt_ctr += Counter(idx_to_zip(target_idx, toks))
68 | all_gt_ct += len(toks)
69 |
70 | # then for each ngram in results, save that ngram's train count
71 | results['all_train_ct'] = [all_gt_ct for _ in range(len(results))]
72 | new = []
73 | for i, row in results.iterrows():
74 | # want to be able to change this based on target_idx
75 | try:
76 | ct = gt_ctr[(row['actual_tok_id'], row['current_tok_id'])]
77 | except KeyError:
78 | ct = 0
79 |
80 | row['ngram_train_ct'] = ct
81 |
82 | if row['ngram_train_ct'] > 0:
83 | row['log_ngram_train_freq'] = np.log(row['ngram_train_ct'] / row['all_train_ct'])
84 | else:
85 | row['log_ngram_train_freq'] = 0
86 | new.append(row)
87 |
88 | return pd.DataFrame(new)
89 |
90 | # returns a |V|x|V| matrix where each row is conditional probability of all preceding tokens
91 | # given that token as the succeeding token. e.g. row for "York" has a high probability on "New"
92 | # https://jofrhwld.github.io/teaching/courses/2022_lin517/python_sessions/04_session5.html#conditional-probability
93 | def train_conditional_probs(train_df, train_name, tokenizer, model_name, vocab_size, target_idx=-1):
94 | q_name = train_name + f"_qmodel{target_idx}.ckpt"
95 | if q_name in os.listdir("../data/qmodels"):
96 | # load the previous matrix for this dataset and target_idx=-1
97 | return torch.load(f"../data/qmodels/{q_name}")
98 |
99 | # edit this to account for diff target_idxs
100 | ng_ctr = Counter()
101 | ug_ctr = Counter()
102 | out = torch.zeros(size=(vocab_size, vocab_size))
103 | for doc in list(train_df['text']):
104 | if "llama" in model_name:
105 | toks = tokenizer(doc)['input_ids']
106 | else:
107 | bos = tokenizer.bos_token
108 | toks = tokenizer(bos+doc)['input_ids']
109 | ng_ctr += Counter(idx_to_zip(target_idx, toks)) # ngrams(toks, idx_to_n[target_idx]))
110 | ug_ctr += Counter(toks)
111 | assert(len(ug_ctr) <= vocab_size)
112 |
113 | # each ROW correspnds to the second token
114 | # right now we fill with joint probabilities
115 | for tup, ct in ng_ctr.items():
116 | out[tup] = ct / sum(ng_ctr.values())
117 |
118 | # then divide each row by p(x1) out of all unigrams
119 | unigram_probs = torch.zeros(size=(vocab_size,))
120 | for i in range(vocab_size):
121 | unigram_probs[i] = ug_ctr[i] / sum(ug_ctr.values())
122 |     assert torch.isclose(torch.sum(unigram_probs), torch.tensor(1.0), atol=1e-4)  # floats will not sum to exactly 1
123 |
124 | # avoid divide by 0 errors by just dividing by 1 (since numerator will also be 0 in those cases)
125 | denom = torch.where(unigram_probs == 0, torch.ones_like(unigram_probs), unigram_probs)
126 | out = (out.T / denom).T
127 | print(out)
128 |
129 |     torch.save(out, f"../data/qmodels/{q_name}")  # q_name already ends in .ckpt; save where it is loaded from above
130 | return out
131 |
132 |
133 | def main(args):
134 | device = torch.device(f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu")
135 |
136 | model = LanguageModel(args.model, device_map=device)
137 | tokenizer = model.tokenizer
138 |     tokenizer.add_special_tokens({'pad_token':'<pad>'})  # pad string assumed; it was lost in this copy
139 |
140 |     VOCAB_SIZE = model.config.vocab_size
141 | MODEL_SIZE = model.config.hidden_size
142 | MODEL_NAME = args.model.split('/')[-1]
143 |
144 | # window size is actually 2048 but I choose 512 for brevity
145 | WINDOW_SIZE = args.window_size
146 |
147 | for p in model.parameters():
148 | p.requires_grad = False
149 |
150 | probe = LinearModel(MODEL_SIZE, VOCAB_SIZE).to(device)
151 | probe.load_state_dict(torch.load(args.checkpoint))
152 |
153 | # intuit target_idx from the filename
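    # e.g. a checkpoint path containing "...LAYER12-TGTIDX-2-train_tiny_1000-..."
    # yields target_idx = -2 here (and layer = 12 below)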
154 | if "TGTIDX" in args.checkpoint:
155 | s = re.search(r'TGTIDX-\d+|TGTIDX\d+', args.checkpoint).group()
156 | target_idx = int(s[6:])
157 | else:
158 | raise Exception("Can't infer target index from checkpoint: " + args.checkpoint)
159 |
160 | # intuit layer from the filename
161 | if "LAYER" in args.checkpoint:
162 | s = re.search(r'LAYER-\d+|LAYER\d+', args.checkpoint).group()
163 | layer = int(s[5:])
164 | else:
165 | raise Exception("Can't infer layer from checkpoint: " + args.checkpoint)
166 |
167 | collate_fn = DocCollate(layer, target_idx, tokenizer, model, WINDOW_SIZE, device)
168 |
169 | test_data = pd.read_csv(args.test_data)
170 |
171 |     if args.test_data == "../data/counterfact_expanded.csv":
172 | corr_str = {
173 | 'Llama-2-7b-hf' : 'llama-2-7b',
174 | 'Meta-Llama-3-8B' : 'llama-3-8b'
175 | }[MODEL_NAME]
176 |
177 | test_data = test_data.loc[test_data[f'{corr_str}_correct']]
178 | print(f"pruned down to only correct CounterFact answers, {len(test_data)}")
179 |
180 | # pass in subjects from counterfact dataset as entities
181 | which_entity = "subject"
182 | entities = None
183 | if test_data is not None:
184 | if which_entity in test_data.columns:
185 | entities = list(test_data[which_entity])
186 |
187 | if 'wikipedia' in args.test_data:
188 | if args.wiki_spacy:
189 | nlp = spacy.load("en_core_web_sm")
190 | entities = []
191 | for d in test_data['text']:
192 | doc = nlp(d)
193 | # https://stackoverflow.com/questions/70185150/return-all-possible-entity-types-from-spacy-model
194 | # ['ORG', 'CARDINAL', 'DATE', 'GPE', 'PERSON', 'MONEY', 'PRODUCT', 'TIME', 'PERCENT', 'WORK_OF_ART', 'QUANTITY', 'NORP', 'LOC', 'EVENT', 'ORDINAL', 'FAC', 'LAW', 'LANGUAGE']
195 | # we want non-number ones. no dates, money, cardinals, time etc.
196 | desired_types = ['ORG', 'GPE', 'PERSON', 'PRODUCT', 'WORK_OF_ART', 'NORP', 'LOC', 'EVENT', 'FAC', 'LAW', 'LANGUAGE']
197 | entities += [e.text for e in doc.ents if e.label_ in desired_types]
198 | entities = list(set(entities))
199 | print(entities[:10])
200 | else:
201 | multi_tok = []
202 | for txt in test_data['text']:
203 | txt = re.sub(r'[^\w\s]', '', txt)
204 | txt = re.sub(r'[0-9]', '', txt)
205 | for word in Counter(txt.split()):
206 | if len(tokenizer(word)['input_ids'][1:]) > 1:
207 | multi_tok.append(word)
208 | print(multi_tok[:10])
209 | entities = list(set(multi_tok))
210 |
211 | test_dataset = DocDataset(model, tokenizer, layer, target_idx, test_data, WINDOW_SIZE, VOCAB_SIZE, device, entities=entities)
212 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, collate_fn=collate_fn)
213 |
214 | criterion = {
215 | "ce" : F.cross_entropy,
216 | "mse" : F.mse_loss
217 | }[args.criterion]
218 |
219 | test_loss, test_acc, test_topk_acc, test_entity_ng_acc, test_other_acc, test_results = test(probe, test_loader,
220 | criterion, device, tokenizer, target_idx, return_results=True)
221 |
222 |     model_folder = args.checkpoint.split('/')[2] # e.g. 'Llama-2-7b-hf' for ../checkpoints/Llama-2-7b-hf/<run_name>/final.ckpt
223 | run_name = args.checkpoint.split('/')[-2]
224 | log_dir = f"../logs/{model_folder}/{run_name}/"
225 | out_csv = f"{datasetname(args.test_data)}_results{wikiextra(args.test_data, args.wiki_spacy)}.csv"
226 |
227 | os.makedirs(log_dir, exist_ok=True)
228 | print(log_dir + out_csv)
229 | test_results.to_csv(log_dir + out_csv, quoting=csv.QUOTE_ALL, encoding='utf-8')
230 | print('Test Loss: {:10.4f} Accuracy: {:3.4f}%\n'.format(test_loss, test_acc))
231 | print('Test Top-5 Accuracy: {:3.4f}%\n'.format(test_topk_acc))
232 | print('Test Accuracy for Entity Ngrams: {:3.4f}% (Other: {:3.4f})\n'.format(test_entity_ng_acc, test_other_acc))
233 |
234 | return test_loss, test_acc, test_entity_ng_acc, test_other_acc
235 |
236 | if __name__ == '__main__':
237 | parser = argparse.ArgumentParser()
238 |
239 | # defaults
240 | parser.add_argument('--criterion', type=str, choices=['mse', 'ce'], default='ce')
241 | parser.add_argument('--cuda', type=int, default=0)
242 | parser.add_argument('--num_workers', type=int, default=12)
243 | parser.add_argument('--window_size', type=int, default=512)
244 |
245 | # what dataset to test on
246 | parser.add_argument('--test_data', type=str,
247 | choices=[
248 | '../data/counterfact_expanded.csv',
249 | '../data/test_tiny_500.csv',
250 | '../data/wikipedia_test_500.csv'
251 | ], default="../data/test_tiny_500.csv")
252 |
253 | # for wikipedia dataset, do MTE if True, otherwise do MTW
254 | parser.add_argument('--wiki_spacy', action='store_true')
255 | parser.set_defaults(wiki_spacy=False)
256 |
257 | # specify probe checkpoint and model. tests the same layer and target_idx.
258 | parser.add_argument('--model', type=str, choices=['meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B'], default='meta-llama/Llama-2-7b-hf')
259 | parser.add_argument('--checkpoint', type=str, required=True, help="e.g. ../checkpoints/Llama-2-7b-hf/llamaLAYER12-TGTIDX-2.../final.ckpt")
260 |
261 | args = parser.parse_args()
262 | main(args)
263 |
--------------------------------------------------------------------------------
/scripts/readout.py:
--------------------------------------------------------------------------------
1 | '''
2 | Given a dataset csv with the column 'text' and a choice between Llama-2-7b and
3 | Llama-3-8b, "read out" the vocabulary of that model based on the given dataset.
4 | '''
5 | import os
6 | import re
7 | import csv
8 | import time
9 | import torch
10 | import pickle
11 | import argparse
12 | import numpy as np
13 | import pandas as pd
14 |
15 | from collections import Counter
16 | from transformers import AutoTokenizer
17 | from training import LinearModel
18 | from nnsight import LanguageModel
19 | from huggingface_hub import hf_hub_download
20 |
21 | tgt_idxs = ['i0', 'in1', 'in2', 'in3']
22 |
23 | def datasetname(s):
24 | if s == "wikipedia":
25 | return s
26 | elif s == "bigbio/med_qa":
27 | return s.split('/')[-1]
28 | elif s == "bigbio/blurb":
29 | return "blurb"
30 | elif s == "wmt/wmt14":
31 | return "wmt14"
32 | else:
33 | return s.split('/')[-1][:-4]
34 |
35 | # load in the outputs of `algorithm1.py` to double check and use as basis for this.
36 | def load_scores(model, dataset):
37 | dir = f"../logs/{model}/candidates/{dataset}_layer1_9/"
38 | dfs = []
39 | for f in os.listdir(dir):
40 | if "scores" in f:
41 | df = pd.read_csv(dir + f)
42 | df['doc_idx'] = int(re.search(r'\d+', f).group(0))
43 | dfs.append(df)
44 |
45 | return pd.concat(dfs).drop(columns=["Unnamed: 0"])
46 |
47 | def get_probe(layer, target_idx, model):
48 | # example: llama-2-7b probe at layer 0, predicting 3 tokens ago
49 | # predicting the next token would be `layer0_tgtidx1.ckpt`
50 | checkpoint_path = hf_hub_download(
51 | repo_id="sfeucht/footprints",
52 | filename=f"{model}/layer{layer}_tgtidx{target_idx}.ckpt"
53 | )
54 |
55 | # model_size is 4096 for both models.
56 | # vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b
57 | model_size = 4096
58 | if model == 'llama-2-7b':
59 | vocab_size = 32000
60 | elif model == 'llama-3-8b':
61 | vocab_size = 128256
62 |
63 | probe = LinearModel(model_size, vocab_size).cuda()
64 | probe.load_state_dict(torch.load(checkpoint_path))
65 |
66 | return probe
67 |
68 |
69 | '''
70 | Implementation of "Erasure Score"
71 | '''
72 | def psi(doc_info, i, j):
73 | sn = doc_info.iloc[i:j+1]
74 |
75 | idx_to_key = {
76 | 0 : 'tok-1_i0_probdelta',
77 | -1 : 'tok-1_in1_probdelta',
78 | -2 : 'tok-1_in2_probdelta',
79 | -3 : 'tok-1_in3_probdelta'
80 | }
81 |
82 | # we're doing 0 indexing for t, so this is different from the paper
83 | # in the paper we did start-end, so we have to flip all these to end-start and add - in front of probdelta
84 | # also, we have to do (t + idx < i) since we're using absolute idxs
85 |
86 | ideal = 1
87 | score = -sn.iloc[-1]['tok-1_i0_probdelta']
88 | for t, row in sn.iterrows():
89 | # for idx in range(-3, 0): # OPTION include -3 information
90 | for idx in range(-2, 0):
91 | # 0 - 3 means it's outside bounds
92 | # 1 - 1 is inside bounds (predicting t=0)
93 | # 3 - 2 is inside bounds (predicting t=1 from t=3)
94 | sign = -1 if (t + idx < i) else 1
95 |
96 | try:
97 | sc = -row[idx_to_key[idx]]
98 | if not np.isnan(sc):
99 | score += sign * sc
100 | ideal += 1
101 | except KeyError:
102 | pass
103 |
104 | return score / ideal
105 |
106 | '''
107 | Given doc_info, apply Algorithm 1 to segment this particular document into non-
108 | overlapping high-scoring chunks. Allow for unigram segments to fill in the gaps.
109 | '''
110 | def partition_doc(doc_info):
111 | # create and initialize the matrix
112 | n = len(doc_info)
113 |
114 | # implement as np array
115 | dp = np.ones((n, n))
116 | dp = dp * -1
117 |
118 | # fill out the matrix
119 | for i in range(n):
120 | for j in range(n):
121 | if i <= j:
122 | if True: # j - i < 6:
123 | dp[i][j] = psi(doc_info, i, j)
124 |
125 | # get the top scores in order
126 | x, y = np.unravel_index(np.argsort(-dp.flatten()), dp.shape)
127 | coords = np.array(list(zip(x, y)))
128 |
129 | # go through all the top ngrams and add them to list, marking which ones become invalid as we go.
130 | segments = []
131 | for p, q in coords:
132 | if p <= q:
133 | val = dp[p, q]
134 |
135 | valid = True
136 | for (x, y), _ in segments:
137 | if x > q or y < p:
138 | pass
139 | else:
140 | valid = False
141 | break
142 |
143 | if valid:
144 | segments.append(((p,q), val))
145 |
146 | # validate that the segments fully cover doc
147 | all_ranges = []
148 | for (x, y), val in segments:
149 | r = range(x, y+1)
150 | all_ranges += list(r)
151 | if set(all_ranges) == set(range(len(dp))):
152 | print("segments have full coverage")
153 | else:
154 | print("WARNING: segments did not fully cover doc")
155 |
156 | return segments
157 |
158 | # take the segmentation of a document and read out the multi-token words
159 | # that were identified via this approach.
160 | def read_out_entries(segments, tokens, filter_unis=True):
161 | entries = []
162 | vals = []
163 | for (x, y), val in segments:
164 | r = range(x, y+1)
165 | if not (x==y and filter_unis):
166 | entries.append([tokens[idx] for idx in r])
167 | vals.append(val)
168 |
169 | return entries, vals
170 |
171 | '''
172 | Run all the possible probes for a specific token
173 | '''
174 | def get_tok_metrics(toks, i, start_probes, end_probes, start_states, end_states, tgt_idxs):
175 | corrstart, corrend, probdelta = {}, {}, {}
176 |
177 | # run each pair of probes on hidden states
178 | for start_probe, end_probe, s in zip(start_probes, end_probes, tgt_idxs):
179 | label = {
180 | 'i0' : toks[i],
181 | 'in1' : toks[i - 1] if i >= 1 else None,
182 | 'in2' : toks[i - 2] if i >= 2 else None,
183 | 'in3' : toks[i - 3] if i >= 3 else None
184 | }[s]
185 |
186 | if label is not None:
187 | start_logits = start_probe(start_states[i]).squeeze().detach().cpu()
188 | end_logits = end_probe(end_states[i]).squeeze().detach().cpu()
189 |
190 | corrstart[s] = start_logits.argmax() == label
191 | corrend[s] = end_logits.argmax() == label
192 | probdelta[s] = end_logits.softmax(dim=-1)[label].item() - start_logits.softmax(dim=-1)[label].item()
193 |
194 | del start_logits, end_logits
195 |
196 | return corrstart, corrend, probdelta
197 |
198 |
199 | '''
200 | given a bunch of tokens and the states for the tokens at a layer, create a dataframe
201 | with every possible probdelta (for different target indices) for each token.
202 | '''
203 | def get_doc_info(tokens, model, layer_start, layer_end, start_probes, end_probes, tokenizer):
204 | tgt_idxs = ['i0', 'in1', 'in2', 'in3']
205 |
206 | # get hidden states for tokens
207 | with torch.no_grad():
208 | with model.trace(tokens):
209 | ss = model.model.layers[layer_start].output[0].squeeze().save()
210 | es = model.model.layers[layer_end].output[0].squeeze().save()
211 |
212 | start_states = ss.detach()
213 | end_states = es.detach()
214 | del ss, es
215 |
216 | # per token: tok-1, decoded, tok-1_i0_iscorr, tok-1_i0_rankdelta, tok-1_i0_logitdelta, tok-1_i0_probdelta
217 | rows = []
218 | for i, ug_tok in enumerate(tokens):
219 | row = {'decoded' : tokenizer.decode(ug_tok), 'n' : 1}
220 |
221 | # for each token in the ng run all the relevant probes
222 | corrstart, corrend, probdelta = \
223 | get_tok_metrics(tokens, i, start_probes, end_probes, start_states, end_states, tgt_idxs)
224 |
225 | # save this token
226 |         row['tok-1'] = ug_tok
227 |
228 | # save i0, in1, in2, in3 for token-1 in the unigram
229 | for s in tgt_idxs:
230 | if s in corrstart.keys():
231 | row[f'tok-1_{s}_corrstart'] = corrstart[s].item()
232 | row[f'tok-1_{s}_corrend'] = corrend[s].item()
233 | row[f'tok-1_{s}_probdelta'] = probdelta[s]
234 |
235 | rows.append(row)
236 |
237 | del start_states, end_states
238 | torch.cuda.empty_cache()
239 |
240 | return pd.DataFrame(rows)
241 |
242 | def main(args):
243 | model = LanguageModel(args.model, device_map='cuda')
244 | tokenizer = AutoTokenizer.from_pretrained(args.model)
245 |
246 | MODEL_NAME = args.model.split('/')[-1]
247 | WINDOW_SIZE = 256
248 |
249 | # dataset we want to index
250 | dataset = pd.read_csv(args.dataset)
251 |
252 | dump_dir = f"../logs/{MODEL_NAME}/readout/{datasetname(args.dataset)}_layer{args.layer_start}_{args.layer_end}/"
253 | os.makedirs(dump_dir, exist_ok=True)
254 |
255 | hf_string = {
256 | 'Llama-2-7b-hf' : 'llama-2-7b',
257 | 'Meta-Llama-3-8B' : 'llama-3-8b'
258 | }[MODEL_NAME]
259 |
260 | # load in the probes at layer_start and layer_end
261 | start_probes, end_probes = [], []
262 | start_probes.append(get_probe(args.layer_start, 0, hf_string))
263 | end_probes.append(get_probe(args.layer_end, 0, hf_string))
264 |
265 | start_probes.append(get_probe(args.layer_start, -1, hf_string))
266 | end_probes.append(get_probe(args.layer_end, -1, hf_string))
267 |
268 | start_probes.append(get_probe(args.layer_start, -2, hf_string))
269 | end_probes.append(get_probe(args.layer_end, -2, hf_string))
270 |
271 | ctr = 0
272 | all_ctr = Counter()
273 | sum_scores = {}
274 | tik = time.time()
275 | for doc_idx, doc in enumerate(dataset['text']):
276 | tokens = tokenizer(doc)['input_ids'][:WINDOW_SIZE]
277 |
278 | # get probe probability information for this doc_idx
279 | fname = f"docinfo_{doc_idx}.csv"
280 | try:
281 | doc_df = pd.read_csv(dump_dir + fname)
282 | print(f"loaded {dump_dir + fname}")
283 | except FileNotFoundError:
284 | doc_df = get_doc_info(tokens, model, args.layer_start, args.layer_end, start_probes, end_probes, tokenizer)
285 | doc_df.to_csv(dump_dir + fname, quoting=csv.QUOTE_ALL)
286 | print(f"saved {dump_dir + fname}, {len(tokens)} tokens in {datasetname(args.dataset)}")
287 |
288 | # segment doc with partition_doc
289 | picklename = f"segments_{doc_idx}.pkl"
290 | try:
291 | with open(dump_dir + picklename, 'rb') as f:
292 | segments = pickle.load(f)
293 | print(f"loaded segments from {dump_dir + picklename}")
294 |
295 | except FileNotFoundError:
296 | print(f"partitioning doc {doc_idx}...")
297 | segments = partition_doc(doc_df)
298 |
299 | with open(dump_dir + picklename, 'wb') as f:
300 | pickle.dump(segments, f)
301 | print(f"saved segments to {dump_dir + picklename}")
302 |
303 | # filter out the unigrams when you're "reading out" the vocabulary
304 | entries, vals = read_out_entries(segments, tokens, filter_unis=True)
305 | decoded_entries = [tokenizer.decode(e) for e in entries]
306 |
307 | # add to running totals
308 | all_ctr += Counter(decoded_entries)
309 | for de, v in zip(decoded_entries, vals):
310 | try:
311 | sum_scores[de] += v
312 | except KeyError:
313 | sum_scores[de] = v
314 |
315 | ctr += 1
316 | if args.n_examples > 0:
317 | if ctr >= args.n_examples:
318 | break # cut off
319 |
320 | tok = time.time()
321 | print("minutes taken:", (tok-tik) / 60)
322 |
323 | assert all_ctr.keys() == sum_scores.keys()
324 |
325 | # save counts of all vocabulary items
326 | cts_fname = f"cts_0thru{doc_idx}.pkl"
327 | with open(dump_dir + cts_fname, 'wb') as f:
328 | pickle.dump(all_ctr, f)
329 | print(f"saved counts at {dump_dir + cts_fname}")
330 |
331 | # calculate average scores for each vocab item
332 | avg_scores = {}
333 | for k, v in sum_scores.items():
334 | avg_scores[k] = v / all_ctr[k]
335 |
336 | # save average scores for each vocab item
337 | avgs_fname = f"avgs_0thru{doc_idx}.pkl"
338 | with open(dump_dir + avgs_fname, 'wb') as f:
339 | pickle.dump(avg_scores, f)
340 | print(f"saved averages at {dump_dir + avgs_fname}")
341 |
342 |     ctr = 0
343 |     print("\nTop 50 Vocabulary Entries")
344 |     for k, v in sorted(avg_scores.items(), key=lambda kv: kv[1], reverse=True):
345 |         print(repr(k), '\t', v)
346 |         ctr += 1
347 |         if ctr >= 50:
348 |             break
349 |
350 | if __name__ == "__main__":
351 | parser = argparse.ArgumentParser()
352 |
353 | parser.add_argument("--dataset", default="../data/test_tiny_500.csv")
354 | parser.add_argument('--layer_start', type=int, default=1)
355 | parser.add_argument('--layer_end', type=int, default=9)
356 | parser.add_argument('--n_examples', type=int, default=-1, help="-1 to use the whole dataset")
357 |
358 | parser.add_argument('--model', default='meta-llama/Llama-2-7b-hf',
359 | choices=['meta-llama/Meta-Llama-3-8B', 'meta-llama/Llama-2-7b-hf'])
360 |
361 | args = parser.parse_args()
362 |
363 | main(args)
--------------------------------------------------------------------------------
/scripts/training.py:
--------------------------------------------------------------------------------
1 | '''
2 | Functions used in train_probe.py and test_probe.py for training and data loading.
3 | '''
4 | import torch
5 | import wandb
6 | import pandas as pd
7 |
8 | from tqdm import tqdm
9 | import torch.nn as nn
10 | from torch.utils.data import Dataset
11 | from utils import acc, _topktoks, _topkprobs, ngram_size, is_entity_last, is_entity_ngram
12 |
13 | class LinearModel(nn.Module):
14 | def __init__(self, input_size, output_size, bias=False):
15 | super(LinearModel, self).__init__()
16 | self.fc = nn.Linear(input_size, output_size, bias=bias)
17 |
18 | def forward(self, x):
19 | output = self.fc(x)
20 | return output
21 |
22 | '''
23 | Trains a probe for a single epoch and logs loss/accuracy.
24 |
25 | Parameters:
26 | epoch: index of the current epoch
27 | probe: linear model to be trained
28 | train_loader: DataLoader of training data
29 | criterion: what loss function to use
30 | optimizer: torch optimizer for linear model
31 | warmup_scheduler: scheduler for learning rate
32 | accumulate: how many batches to wait until updating / clearing grads
33 | clip_threshold: threshold for gradient clipping.
34 | batches_seen: no. batches seen before this epoch
35 | device: device of Llama model
36 |
37 | Returns:
38 |     total number of batches seen so far (batches_seen plus this epoch's batches)
39 | '''
40 | def train_epoch(epoch, probe, train_loader, criterion, optimizer, warmup_scheduler, accumulate, clip_threshold, batches_seen, device):
41 | probe.train()
42 |
43 | for batch_idx, (hidden_states, target_toks, _, _, _) in enumerate(train_loader):
44 | hidden_states, target_toks = hidden_states.to(device), target_toks.to(device)
45 | assert(not torch.isnan(hidden_states).any() and not torch.isinf(hidden_states).any())
46 |
47 | # get probe predictions and convert to toks if needed
48 | output = probe(hidden_states.float()).to(device)
49 |
50 | # then calculate loss with the target tokens
51 | loss = criterion(output, target_toks, reduction="mean")
52 | loss.backward()
53 |
54 | if batch_idx % accumulate == 0 and batch_idx > 0:
55 | torch.nn.utils.clip_grad_norm_(probe.parameters(), clip_threshold)
56 | optimizer.step()
57 | optimizer.zero_grad()
58 |
59 | loss = loss.detach().item()
60 |
61 | # learning rate warmup for AdamW.
62 | if warmup_scheduler is not None:
63 | if batch_idx < len(train_loader)-1:
64 | with warmup_scheduler.dampening():
65 | pass
66 |
67 |         # log training accuracy/loss every max(accumulate, 10) batches, and on the last batch
68 | if batch_idx % max(accumulate, 10) == 0 or batch_idx == len(train_loader) - 1:
69 | train_acc = acc(output.cpu(), target_toks.cpu())
70 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tTraining Acc:{:3.3f}%\tBatch Loss: {:.6f} ({} tokens)'.format(
71 | epoch, batch_idx, len(train_loader), 100. * batch_idx / len(train_loader),
72 | train_acc.item(), loss, hidden_states.size()[0]))
73 |
74 | wandb.log({"train_loss": loss, "train_acc": train_acc,
75 | "epoch": epoch, "batches_seen": 1 + batch_idx + batches_seen})
76 |
77 | return 1 + batch_idx + batches_seen
78 |
79 |
80 | '''
81 | Given a trained LinearModel and a DataLoader of test examples, test the model on a given test_loader and return csv of results (optionally)
82 | '''
83 | def test(probe, test_loader, criterion, device, tokenizer, target_idx, return_results=False):
84 | probe.eval()
85 | total_loss = 0.0
86 | total_toks = 0
87 | correct = 0
88 | topk_correct = 0
89 | results = []
90 |
91 | n_entity_ngrams = 0
92 | n_entity_ngrams_correct = 0
93 | n_other = 0
94 | n_other_correct = 0
95 | with torch.no_grad():
96 | for (data, target_toks, curr_toks, doc_idxs, entity_mask) in tqdm(test_loader):
97 | if data is None:
98 | continue
99 |
100 | output = probe(data.to(device).float()).to(device)
101 |
102 | loss = criterion(output, target_toks.to(device), reduction="mean")
103 | total_loss += loss.detach().item()
104 |
105 | for i, v in enumerate(output.cpu()):
106 | doc_id = doc_idxs[i]
107 | current_tok = _topktoks(curr_toks[i])
108 | actual_tok = target_toks[i]
109 | predicted_tok = _topktoks(v)
110 | this_loss = criterion(v, target_toks[i])
111 | is_correct = predicted_tok == actual_tok
112 |
113 | total_toks += 1
114 | if is_correct:
115 | correct += 1
116 | if actual_tok in _topktoks(v, k=5): # less interpretable NOW THAT PINV
117 | topk_correct += 1
118 |
119 | kl_divergence = -1
120 | q_target_log_prob = torch.inf
121 | p_target_log_prob = torch.inf
122 |
123 |
124 | if entity_mask is not None:
125 | this_is_entity_ngram = bool(is_entity_ngram(i, entity_mask, target_idx=target_idx))
126 | if this_is_entity_ngram:
127 | n_entity_ngrams += 1
128 | n_entity_ngrams_correct += int(is_correct)
129 | else: # count up coarse "other" values.
130 | n_other += 1
131 | n_other_correct += int(is_correct)
132 |
133 | n = 0
134 | this_is_entity_last = is_entity_last(i, entity_mask)
135 | if this_is_entity_last: # if entity save how big the entity is
136 | n = ngram_size(i, entity_mask)
137 |
138 | if return_results:
139 | # BOS token becomes encoded as NaN in pandas here
140 | curr_result = {
141 | "doc_id" : doc_id.item(),
142 | "current_tok_id" : current_tok.item(),
143 | "actual_tok_id" : actual_tok.item(),
144 | "predicted_tok_id" : predicted_tok.item(),
145 | "current_tok" : tokenizer.decode(current_tok.tolist()),
146 | "actual_tok" : tokenizer.decode(actual_tok.tolist()),
147 | "predicted_tok" : tokenizer.decode(predicted_tok.tolist()),
148 | "loss" : this_loss.item(),
149 |
150 | # from this information you can split into entity_last, entity_notlast, and nonentity.
151 | # as well as the case "entity_ngram": whether this probe is predicting the first tok FROM the last tok of an entity.
152 | "is_entity" : entity_mask[i].item() if entity_mask is not None else -1,
153 | "is_entity_last" : this_is_entity_last if entity_mask is not None else -1,
154 | "is_entity_ngram" : this_is_entity_ngram if entity_mask is not None else -1,
155 | "n" : n,
156 |
157 | "kl_divergence" : kl_divergence,
158 | "q_target_log_prob" : q_target_log_prob,
159 | "p_target_log_prob" : p_target_log_prob,
160 | **_topkprobs(v, tokenizer)
161 | }
162 |
163 | results.append(curr_result)
164 |
165 |
166 | test_loss = total_loss / len(test_loader) # average batch loss (total_loss is a sum of per-batch mean losses)
167 | test_acc = 100 * correct / total_toks
168 | topk_acc = 100 * topk_correct / total_toks
169 |
170 | if n_entity_ngrams > 0:
171 | entity_ngram_acc = 100 * n_entity_ngrams_correct / n_entity_ngrams
172 | else:
173 | entity_ngram_acc = -1
174 |
175 | if n_other > 0:
176 | other_acc = 100 * n_other_correct / n_other
177 | else:
178 | other_acc = -1
179 |
180 | if return_results:
181 | return test_loss, test_acc, topk_acc, entity_ngram_acc, other_acc, pd.DataFrame(results)
182 | else:
183 | return test_loss, test_acc, topk_acc, entity_ngram_acc, other_acc
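# A minimal sketch of how test() might be called (the probe, loader, and the choice of
# F.cross_entropy as criterion are assumptions for illustration, not fixed by this file):
#
#   import torch.nn.functional as F
#   test_loss, test_acc, topk_acc, ngram_acc, other_acc, df = test(
#       probe, loader, F.cross_entropy, device, tokenizer, target_idx=-1, return_results=True)
#   df.to_csv("logs/probe_results.csv", index=False)  # "logs/" output path is hypothetical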
184 |
185 | '''
186 | Dataset for retrieving tokenized documents from csv along with masks that mark
187 | which tokens correspond to "entities"
188 | '''
189 | class DocDataset(Dataset):
190 | def __init__(self, model, tokenizer, layer_name, target_idx, dataset_csv, window_size, vocab_size, device, entities=None):
191 | self.model = model
192 | self.tokenizer = tokenizer
193 | self.layer_name = layer_name # -1 is the embedding layer, 0-31 for transformer layers, 32 for the final layer norm output right before the LM head
194 | self.target_idx = target_idx # -1 is previous token, 0 is current, etc.
195 | self.dataset_csv = dataset_csv
196 | self.window_size = window_size
197 | self.vocab_size = vocab_size
198 | self.device = device
199 | self.entities = entities # list of strings for the entities we want to mark in the entity mask
200 |
201 | if self.entities is not None:
202 | self.entities = [self.tokenize(e, bos=False) for e in self.entities]
203 |
204 | # llama tokenizer already adds BOS token
205 | def tokenize(self, text, bos=True):
206 | if bos:
207 | t = self.tokenizer(text)['input_ids']
208 | else:
209 | t = self.tokenizer(text)['input_ids'][1:]
210 |
211 | # this makes sure that entity mask is also truncated to window size
212 | if len(t) > self.window_size:
213 | return t[:self.window_size]
214 | else:
215 | return t
216 |
217 | # iterate through the sequence and mark every position where subseq occurs
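# e.g. starting from a zero mask, sequence = [5, 7, 9, 7, 9] with subseq = [7, 9]
# yields mask = [0, 1, 1, 1, 1]: both occurrences of the entity's token ids get marked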
218 | def mask_iterator(self, sequence, subseq, mask):
219 | sequence = list(sequence.cpu())
220 | if len(subseq) <= len(sequence):
221 | for i in range(len(sequence)-len(subseq)+1):
222 | assert len(sequence[i:i+len(subseq)]) == len(subseq)
223 | if (sequence[i:i+len(subseq)] == subseq):
224 | mask[i:i+len(subseq)] = 1
225 | return torch.Tensor(mask)
226 |
227 | # returns number of documents, not tokens
228 | def __len__(self):
229 | if self.dataset_csv is not None:
230 | return len(self.dataset_csv)
231 |
232 | # get document tokens and mask
233 | def __getitem__(self, index):
234 | doc = self.dataset_csv.iloc[index]
235 | doc_string = str(doc['text'])
236 |
237 | # need this for entity mask calculations
238 | tokens = torch.tensor(self.tokenize(doc_string))
239 |
240 | entity_mask = torch.zeros_like(tokens)
241 | if self.entities is not None:
242 | for e in self.entities:
243 | entity_mask = self.mask_iterator(tokens, e, entity_mask)
244 |
245 | return torch.tensor(index), doc_string, tokens, entity_mask
246 |
247 | '''
248 | Collate function that pads a batch of documents, retrieves hidden states for the given model with nnsight, and aligns them with their target tokens
249 | '''
250 | class DocCollate(object):
251 | def __init__(self, layer, target_idx, tokenizer, model, window_size, device):
252 | self.layer = layer
253 | self.target_idx = target_idx
254 | self.tokenizer = tokenizer
255 | self.model = model
256 | self.window_size = window_size
257 | self.device = device
258 |
259 | def __call__(self, batch):
260 | # pad all the strings and save attention mask
261 | strings = [s for (_, s, _, _) in batch]
262 | tokenized = self.tokenizer(strings, return_tensors='pt', padding=True, truncation=True, max_length=self.window_size)
263 | attention_mask = tokenized['attention_mask']
264 |
265 | with self.model.trace(tokenized):
266 | if 'llama' in self.model.config._name_or_path.lower():
267 | if self.layer == -1:
268 | states = self.model.model.embed_tokens.output.save()
269 | elif self.layer == 32:
270 | states = self.model.model.norm.output.save()
271 | else:
272 | states = self.model.model.layers[self.layer].output[0].save()
273 | else: # pythia
274 | if self.layer == -1:
275 | states = self.model.gpt_neox.embed_in.output.save()
276 | elif self.layer == 32:
277 | states = self.model.gpt_neox.final_layer_norm.output.save()
278 | else:
279 | states = self.model.gpt_neox.layers[self.layer].output[0].save()
280 |
281 | # then loop through the whole batch so hidden states, tokens, and doc_idxs all go through the same alignment logic
282 | source_hss = []
283 | target_toks = []
284 | current_toks = []
285 | doc_idxs = []
286 | entity_masks = []
287 | for i, doc in enumerate(batch):
288 | # batch looks like [doc0:(0, text, tokens, mask), doc1:(1, text, tokens, mask)...]
289 | doc_idx, tokens, entity_mask = (a.cpu() for a in doc if type(a)!=str)
290 |
291 | # get the hidden states we just calculated, and trim off the PAD tokens
292 | # for llama 2 7b, padding is always at the beginning, so keep the last sum(attention_mask[i]) positions
293 | hidden_states = states[i][-sum(attention_mask[i]):]
294 | assert (len(hidden_states) == len(tokens))
295 |
296 | # make sure that hidden_states has enough tokens to handle the given target_idx;
297 | # if target_idx would fall outside the bounds of hidden_states, skip this doc.
298 | if abs(self.target_idx) >= len(hidden_states):
299 | continue
300 |
301 | # target_idx == -1:
302 | # source_hss: BOS [this is an example sentence]
303 | # target_toks: [BOS this is an example] sentence
304 | # target_idx == -2:
305 | # (BOS this [is an) example sentence]
306 | # target_idx == -3:
307 | # (BOS this is) [an example sentence]
308 | if self.target_idx < 0:
309 | pos = abs(self.target_idx)
310 | source_hss.append(hidden_states[pos:])
311 | target_toks.append(tokens[:-pos])
312 | current_toks.append(tokens[pos:])
313 | doc_idxs.append(torch.tensor([doc_idx for _ in range(len(hidden_states[pos:]))], device='cpu'))
314 | entity_masks.append(entity_mask[pos:])
315 |
316 | # target_idx == 1:
317 | # source_hss: [BOS this is an example] sentence
318 | # target_toks: BOS [this is an example sentence]
319 | # target_idx == 2:
320 | # [BOS this (is an] example sentence)
321 | # target_idx == 3:
322 | # [BOS this is] (an example sentence)
323 | elif self.target_idx > 0:
324 | pos = abs(self.target_idx)
325 | source_hss.append(hidden_states[:-pos])
326 | target_toks.append(tokens[pos:])
327 | current_toks.append(tokens[:-pos])
328 | doc_idxs.append(torch.tensor([doc_idx for _ in range(len(hidden_states[:-pos]))], device='cpu'))
329 | entity_masks.append(entity_mask[:-pos])
330 |
331 | # exclude predicting bos_embedding -> BOS
332 | elif self.target_idx == 0:
333 | source_hss.append(hidden_states[1:])
334 | target_toks.append(tokens[1:])
335 | current_toks.append(tokens[1:])
336 | doc_idxs.append(torch.tensor([doc_idx for _ in range(len(hidden_states[1:]))], device='cpu'))
337 | entity_masks.append(entity_mask[1:])
338 |
339 | # sometimes every doc in the batch is too short for the given target_idx, so nothing was collected
340 | if len(source_hss) > 0:
341 | source_hss = torch.cat(source_hss)
342 | target_toks = torch.cat(target_toks)
343 | current_toks = torch.cat(current_toks)
344 | doc_idxs = torch.cat(doc_idxs)
345 | entity_masks = torch.cat(entity_masks)
346 | return (source_hss, target_toks, current_toks, doc_idxs, entity_masks)
347 | else:
348 | return None, None, None, None, None
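# A minimal sketch of wiring DocDataset and DocCollate into a DataLoader, assuming an
# nnsight LanguageModel; the model name, csv path, and hyperparameters below are
# hypothetical and only illustrate the constructor signatures above:
#
#   import pandas as pd
#   from torch.utils.data import DataLoader
#   from nnsight import LanguageModel
#
#   model = LanguageModel("meta-llama/Llama-2-7b-hf", device_map="cuda")
#   tokenizer = model.tokenizer
#   docs = pd.read_csv("data/docs.csv")      # expects a 'text' column
#   layer, target_idx = 12, -1               # probe layer-12 states for the previous token
#
#   dataset = DocDataset(model, tokenizer, layer, target_idx, docs,
#                        window_size=512, vocab_size=32000, device="cuda")
#   collate = DocCollate(layer, target_idx, tokenizer, model, window_size=512, device="cuda")
#   loader = DataLoader(dataset, batch_size=8, collate_fn=collate)
#
#   hidden_states, target_toks, current_toks, doc_idxs, entity_masks = next(iter(loader))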
--------------------------------------------------------------------------------