├── LICENSE ├── README.md ├── algs ├── efk.py ├── enn.py ├── ft.py └── mend.py ├── config ├── alg │ ├── efk.yaml │ ├── enn.yaml │ ├── ft.yaml │ └── mend.yaml ├── config.yaml ├── experiment │ ├── fc.yaml │ ├── gen.yaml │ └── qa.yaml └── model │ ├── bart-base.yaml │ ├── bert-base.yaml │ ├── distilgpt2.yaml │ ├── gpt2.yaml │ ├── gpt2large.yaml │ ├── gpt2medium.yaml │ ├── gpt2xl.yaml │ ├── gptj.yaml │ ├── gptneo27.yaml │ ├── t5large.yaml │ ├── t5small.yaml │ ├── t5xl.yaml │ └── t5xxl.yaml ├── data_classes ├── fever.py ├── nq.py ├── wiki.py └── zsre.py ├── editable_model.py ├── hooks.py ├── losses.py ├── models.py ├── nn.py ├── oracle.py ├── requirements.txt ├── run.py ├── trainer.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Eric Anthony Mitchell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MEND: Model Editing Networks using Gradient Decomposition 2 | 3 | If you run into any issues with the code, you can open an issue and/or email me at `eric.mitchell@cs.stanford.edu` 4 | 5 | ## Setup 6 | 7 | ### Environment 8 | 9 | This codebase uses Python 3.7.9. Other versions may work as well. 10 | 11 | Create a virtualenv ([pyenv](https://github.com/pyenv/pyenv) can help with this) 12 | and install the dependencies: 13 | 14 | $ python -m venv env 15 | $ source env/bin/activate 16 | (env) $ pip install -r requirements.txt 17 | 18 | ### Data 19 | 20 | You can download the data needed for this project from 21 | [this Google Drive link](https://drive.google.com/drive/folders/1jAqBE45jEKR-5pMkwxlVQ0V8eKxqWbxA?usp=sharing). 22 | Unzip each sub-directory into `mend/data` and you should be good to go. 23 | 24 | ## Running the code 25 | 26 | Run MEND training/evaluation for distilGPT-2 on the wikitext editing problem with: 27 | 28 | (env) $ python -m run +alg=mend +experiment=gen +model=distilgpt2 data.wiki_webtext=False 29 | 30 | Other valid algs include `efk` ([KnowledgeEditor](https://arxiv.org/abs/2104.08164)) 31 | and `enn` ([Editable Neural Networks](https://arxiv.org/abs/2004.00345)). Other valid experiments 32 | include `fc` (FEVER fact checking) and `qa` (zsRE question-answering). 
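For example, the corresponding commands for the other experiments follow the same pattern; something along these lines should work, using the model configs listed under `config/model` (the particular model choices here and any extra data overrides are just illustrative and may need adjusting for your setup):

    (env) $ python -m run +alg=mend +experiment=qa +model=t5large
    (env) $ python -m run +alg=mend +experiment=fc +model=bert-base
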
The splits, rephrases, and pre-trained
33 | BERT and BART models **required** for running `fc` and `qa`, respectively, come from
34 | [De Cao et al.](https://arxiv.org/abs/2104.08164) (see repo [here](https://github.com/nicola-decao/KnowledgeEditor)).
35 | Check `config/model` for options for editable models (note that not all models work with all experiments; GPT-style
36 | models only work with `gen`, seq2seq models only work with `qa`, and BERT only works with `fc`).
37 | 
38 | Also note that in the paper, we sample locality data from different datasets depending on the model.
39 | By default, training will use [Natural Questions](https://ai.google.com/research/NaturalQuestions)
40 | data (not zsRE data) for computing drawdown in the `qa` experiment and
41 | [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) data (not wikitext) for computing drawdown in the `gen` experiment. For models such as the `distilgpt2`
42 | model we use (which was fine-tuned on wikitext) or the BART-base model, this behavior should be
43 | disabled with `data.wiki_webtext=False` or `data.zsre_nq=False`, respectively.
44 | 
45 | ### Multi-edit experiments
46 | 
47 | For multi-edit experiments, it's important to configure batch sizing correctly. To run training &
48 | evaluation with `5` edits, for example, we pass the arguments `data.n_edits=5 batch_size=6 val_batch_size=6`.
49 | 
50 | These arguments are interpreted as using batches of size 6 during training and validation, with 5 of those
51 | batch elements used to apply edits to the model and the remaining example used to compute drawdown.
52 | 
53 | ## Citing the paper
54 | 
55 | If this code or paper was useful, please consider using the following citation:
56 | 
57 |     @inproceedings{mitchell2022fast,
58 |         title={Fast Model Editing at Scale},
59 |         author={Eric Mitchell and Charles Lin and Antoine Bosselut and Chelsea Finn and Christopher D Manning},
60 |         booktitle={International Conference on Learning Representations},
61 |         year={2022},
62 |         url={https://openreview.net/pdf?id=0DcZxeWfOPt}
63 |     }
64 | 
-------------------------------------------------------------------------------- /algs/efk.py: --------------------------------------------------------------------------------
1 | # Adapted from https://github.com/nicola-decao/KnowledgeEditor/blob/main/src/models/one_shot_learner.py
2 | """
3 | @inproceedings{decao2020editing,
4 |     title={Editing Factual Knowledge in Language Models},
5 |     author={Nicola De Cao and Wilker Aziz and Ivan Titov},
6 |     booktitle={arXiv pre-print 2104.08164},
7 |     url={https://arxiv.org/abs/2104.08164},
8 |     year={2021},
9 | }
10 | """
11 | 
12 | import torch
13 | import copy
14 | import higher
15 | from higher.patch import monkeypatch as make_functional
16 | from allennlp.modules.feedforward import FeedForward
17 | from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
18 | import logging
19 | 
20 | from editable_model import EditableModel
21 | from utils import _logits, _inner_params
22 | from models import BertClassifier
23 | from transformers import BartForConditionalGeneration, T5ForConditionalGeneration
24 | 
25 | 
26 | LOG = logging.getLogger(__name__)
27 | 
28 | 
29 | class EFK(EditableModel):
30 |     def __init__(self, model, config, model_constructor, editor=None):
31 |         super().__init__(model, config, model_constructor)
32 | 
33 |         if editor is None:
34 |             if isinstance(model, BertClassifier):
35 |                 embedding = model.model.embeddings.word_embeddings.weight.data
36 |             elif isinstance(model, BartForConditionalGeneration):
37 |                 embedding = model.model.shared.weight.data
38 | 
elif isinstance(model, T5ForConditionalGeneration): 39 | embedding = model.shared.weight.data 40 | else: 41 | embedding = model.transformer.wte.weight.data 42 | 43 | editor = OneShotLearner(model, vocab_dim=model.config.vocab_size, 44 | include_set=config.model.inner_params, 45 | embedding_dim=embedding.shape[-1], 46 | embedding_init=embedding.clone().to(torch.float32), 47 | max_scale=1) 48 | self.editor = editor 49 | 50 | def outer_parameters(self): 51 | return self.editor.parameters() 52 | 53 | def state_dict(self, destination=None, prefix="", keep_vars=False): 54 | state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars) # Get default state dict 55 | model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys() # Remove model params 56 | for k in model_keys: 57 | del state_dict[f"model.{k}"] 58 | state_dict["model_config"] = self.model.config # Include model config 59 | return state_dict 60 | 61 | def load_state_dict(self, state_dict, strict: bool = True): 62 | config = state_dict["model_config"] 63 | del state_dict["model_config"] 64 | if config != self.model.config: 65 | LOG.info("Loaded model config doesn't match current model config.") 66 | LOG.info(f"Loaded: {config}") 67 | LOG.info(f"Current: {self.model.config}") 68 | 69 | res = super().load_state_dict(state_dict, False) 70 | # We should only have missing keys for the model, and no unexpected keys 71 | assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model." 72 | assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys" 73 | return res 74 | 75 | def edit(self, batch, condition, detach_history=False): 76 | outputs = _logits(self.model(**batch)) 77 | loss = self.edit_loss_fn(outputs, batch["labels"])["nll"] 78 | 79 | names = set([n for n, p in self.model.named_parameters()]) 80 | pset = set(self.config.model.inner_params) 81 | for p in pset: 82 | assert p in names, f"inner param {p} not in model" 83 | 84 | grads = torch.autograd.grad( 85 | loss, 86 | [p for (n, p) in _inner_params(self.model.named_parameters(), self.config.model.inner_params)] 87 | ) 88 | 89 | params_dict = self.editor( 90 | condition["input_ids"] if condition is not None else batch["input_ids"], 91 | condition["attention_mask"] if condition is not None else batch["attention_mask"], 92 | {n: g.to(torch.float32) for (n, g) in zip(self.config.model.inner_params, grads)}, 93 | ) 94 | 95 | edited_model = self.model 96 | if not isinstance(edited_model, higher.patch._MonkeyPatchBase): 97 | edited_model = make_functional(edited_model, in_place=True) 98 | 99 | def new_param(n, p): 100 | if n not in params_dict: 101 | return p 102 | 103 | if p.shape[0] == params_dict[n].shape[0]: 104 | return p + params_dict[n] 105 | else: 106 | return p + params_dict[n].T 107 | 108 | edited_model.update_params( 109 | [new_param(n, p) for (n, p) in edited_model.named_parameters()] 110 | ) 111 | 112 | if detach_history: 113 | new_model = self.model_constructor() 114 | new_model.load_state_dict(edited_model.state_dict()) 115 | edited_model = new_model 116 | 117 | return EFK(edited_model, self.config, self.model_constructor, editor=self.editor), {} 118 | 119 | 120 | class ConditionedParameter(torch.nn.Module): 121 | def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1): 122 | super().__init__() 123 | self.parameter_shape = parameter.shape 124 | 125 | if len(self.parameter_shape) == 2: 126 | self.conditioners = torch.nn.Sequential( 127 | 
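                # For a 2-D parameter, this MLP maps the condition vector to 2*(rows+cols)+1 values; forward()
                # splits them into two rank-1 (row/column outer product) factors, used as an elementwise scale
                # and shift on the raw gradient, plus a single scalar that gates the size of the update.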
torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)), 128 | torch.nn.Tanh(), 129 | torch.nn.utils.weight_norm( 130 | torch.nn.Linear( 131 | hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1 132 | ) 133 | ), 134 | ) 135 | elif len(self.parameter_shape) == 1: 136 | self.conditioners = torch.nn.Sequential( 137 | torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)), 138 | torch.nn.Tanh(), 139 | torch.nn.utils.weight_norm( 140 | torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1) 141 | ), 142 | ) 143 | else: 144 | raise RuntimeError() 145 | 146 | self.max_scale = max_scale 147 | 148 | def forward(self, inputs, grad): 149 | if inputs.shape[0] > 1: 150 | raise RuntimeError("Can only condition on batches of size 1") 151 | 152 | if len(self.parameter_shape) == 2: 153 | ( 154 | conditioner_cola, 155 | conditioner_rowa, 156 | conditioner_colb, 157 | conditioner_rowb, 158 | conditioner_norm, 159 | ) = self.conditioners(inputs).split( 160 | [ 161 | self.parameter_shape[1], 162 | self.parameter_shape[0], 163 | self.parameter_shape[1], 164 | self.parameter_shape[0], 165 | 1, 166 | ], 167 | dim=-1, 168 | ) 169 | 170 | a = conditioner_rowa.softmax(-1).T @ conditioner_cola 171 | b = conditioner_rowb.softmax(-1).T @ conditioner_colb 172 | 173 | elif len(self.parameter_shape) == 1: 174 | a, b, conditioner_norm = self.conditioners(inputs).split( 175 | [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1 176 | ) 177 | else: 178 | raise RuntimeError() 179 | 180 | if a.squeeze().shape[0] != grad.shape[0]: 181 | return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze().T + b.squeeze().T) 182 | else: 183 | return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze() + b.squeeze()) 184 | 185 | 186 | class LSTMConditioner(torch.nn.Module): 187 | def __init__( 188 | self, 189 | vocab_dim=30522, 190 | embedding_dim=768, 191 | hidden_dim=256, 192 | output_dim=1024, 193 | embedding_init=None, 194 | ): 195 | super().__init__() 196 | self.embedding = torch.nn.Embedding( 197 | num_embeddings=vocab_dim, 198 | embedding_dim=embedding_dim, 199 | padding_idx=0, 200 | _weight=embedding_init, 201 | ) 202 | self.lstm = PytorchSeq2VecWrapper( 203 | torch.nn.LSTM( 204 | input_size=embedding_dim, 205 | hidden_size=hidden_dim, 206 | num_layers=1, 207 | bidirectional=True, 208 | batch_first=True, 209 | ) 210 | ) 211 | self.linear = FeedForward( 212 | input_dim=hidden_dim * 2, 213 | num_layers=1, 214 | hidden_dims=[output_dim], 215 | activations=[torch.nn.Tanh()], 216 | ) 217 | 218 | def forward(self, inputs, masks): 219 | return self.linear(self.lstm(self.embedding(inputs), masks)) 220 | 221 | 222 | class OneShotLearner(torch.nn.Module): 223 | def __init__( 224 | self, 225 | model, 226 | vocab_dim, 227 | embedding_dim=768, 228 | hidden_dim=512, 229 | condition_dim=768, 230 | include_set={}, 231 | max_scale=1e-3, 232 | embedding_init=None, 233 | ): 234 | super().__init__() 235 | 236 | self.param2conditioner_map = { 237 | n: "{}_conditioner".format(n).replace(".", "_") 238 | for n, p in model.named_parameters() 239 | if n in include_set 240 | } 241 | 242 | self.conditioners = torch.nn.ModuleDict( 243 | { 244 | self.param2conditioner_map[n]: ConditionedParameter( 245 | p, 246 | condition_dim, 247 | hidden_dim, 248 | max_scale=max_scale, 249 | ) 250 | for n, p in model.named_parameters() 251 | if n in include_set 252 | } 253 | ) 254 | 255 | self.condition = LSTMConditioner( 256 | vocab_dim, 257 | embedding_dim, 258 | hidden_dim, 259 
| condition_dim, 260 | embedding_init=embedding_init, 261 | ) 262 | 263 | def forward(self, inputs, masks, grads=None): 264 | condition = self.condition(inputs, masks) 265 | return { 266 | p: self.conditioners[self.param2conditioner_map[p]]( 267 | condition, 268 | grad=grads[p] if grads else None, 269 | ) 270 | for p, c in self.param2conditioner_map.items() 271 | } 272 | 273 | 274 | if __name__ == '__main__': 275 | import transformers 276 | import types 277 | 278 | model = transformers.GPT2LMHeadModel.from_pretrained("gpt2") 279 | 280 | config = types.SimpleNamespace() 281 | config.model.inner_params = [ 282 | "transformer.h.9.mlp.c_fc.weight", 283 | "transformer.h.9.mlp.c_proj.weight", 284 | "transformer.h.10.mlp.c_fc.weight", 285 | "transformer.h.10.mlp.c_proj.weight", 286 | "transformer.h.11.mlp.c_fc.weight", 287 | "transformer.h.11.mlp.c_proj.weight", 288 | ] 289 | 290 | efk = EFK(model, config, lambda: copy.deepcopy(model)).cuda() 291 | 292 | x = torch.arange(20).view(1, 20).cuda() + 1000 293 | orig_logits = efk(x).logits 294 | edited = efk.edit(x, masks=torch.ones_like(x), labels=x) 295 | post_logits = efk(x).logits 296 | 297 | assert torch.allclose(orig_logits, post_logits) 298 | 299 | orig_param = [p for (n, p) in efk.model.named_parameters() if n == config.model.inner_params[-1]][0] 300 | edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0] 301 | 302 | print((orig_param - edited_param).abs().max()) 303 | edited.eval() 304 | print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x))["nll"] 305 | edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x) 306 | print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss) 307 | import pdb; pdb.set_trace() 308 | -------------------------------------------------------------------------------- /algs/enn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import higher 4 | 5 | from editable_model import EditableModel 6 | from utils import _logits 7 | 8 | 9 | def fomaml_callback(all_grads): 10 | return [g.detach() if g is not None else None for g in all_grads] 11 | 12 | 13 | class ENN(EditableModel): 14 | def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None): 15 | super().__init__(model, config, model_constructor) 16 | 17 | if edit_lrs is None: 18 | edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params))) 19 | self.edit_lrs = edit_lrs 20 | 21 | if edit_loss_fn is not None: 22 | self.edit_loss_fn = edit_loss_fn 23 | 24 | self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x 25 | 26 | def outer_parameters(self): 27 | if self.config.no_grad_layers is None: 28 | return super().outer_parameters() 29 | else: 30 | params = [self.edit_lrs] 31 | for m in self.model.modules(): 32 | if isinstance(m, nn.ModuleList): 33 | params.extend(list(m[self.config.no_grad_layers:].parameters())) 34 | return params 35 | 36 | def get_state_dict(self): 37 | return self.state_dict() 38 | 39 | def edit(self, batch, condition=None, detach_history=False): 40 | opt = torch.optim.SGD([{"params": p, "lr": None} 41 | for (n, p) in self.model.named_parameters() if n in self.config.model.inner_params]) 42 | with torch.enable_grad(), higher.innerloop_ctx( 43 | self.model, 44 | opt, 45 | override={'lr': list(self.edit_lrs)}, 46 | copy_initial_weights=False, 47 | 
track_higher_grads=self.training, 48 | in_place=True 49 | ) as (fmodel, diffopt): 50 | fmodel.eval() 51 | for edit_step in range(self.config.enn.n_edit_steps): 52 | output = _logits(fmodel(**batch)) 53 | loss = self.edit_loss_fn(output, batch["labels"])["nll"] 54 | diffopt.step(loss, grad_callback=self.grad_callback) 55 | 56 | if not detach_history: 57 | model_edited = fmodel 58 | else: 59 | model_edited = self.model_constructor() 60 | model_edited.load_state_dict(fmodel.state_dict()) 61 | model_edited.train(self.training) 62 | 63 | return ENN(model_edited, self.config, self.model_constructor, edit_lrs=self.edit_lrs, edit_loss_fn=self.edit_loss_fn), {} 64 | 65 | 66 | def test(): 67 | import transformers 68 | import types 69 | import copy 70 | 71 | model = transformers.GPT2LMHeadModel.from_pretrained("gpt2") 72 | 73 | config = types.SimpleNamespace() 74 | config.edit_lr = 0.1 75 | config.model.inner_params = [ 76 | "transformer.h.9.mlp.c_fc.weight", 77 | "transformer.h.9.mlp.c_proj.weight", 78 | "transformer.h.10.mlp.c_fc.weight", 79 | "transformer.h.10.mlp.c_proj.weight", 80 | "transformer.h.11.mlp.c_fc.weight", 81 | "transformer.h.11.mlp.c_proj.weight", 82 | ] 83 | config.enn = { 84 | "n_edit_steps": 2, 85 | "first_order": False 86 | } 87 | 88 | enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda() 89 | 90 | x = torch.arange(100).view(5, 20).cuda() + 1000 91 | 92 | edited = enn.edit(x, masks=torch.ones_like(x), labels=x) 93 | 94 | orig_param = [p for (n, p) in enn.model.named_parameters() if n == config.model.inner_params[-1]][0] 95 | edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0] 96 | 97 | print((orig_param - edited_param).abs().max()) 98 | edited.eval() 99 | print(enn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"]) 100 | edited.edit_loss_fn(edited(x).logits, x).backward() 101 | import pdb; pdb.set_trace() 102 | 103 | 104 | if __name__ == '__main__': 105 | with torch.autograd.set_detect_anomaly(True): 106 | test() 107 | -------------------------------------------------------------------------------- /algs/ft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import higher 4 | from higher.patch import monkeypatch as make_functional 5 | import time 6 | 7 | from editable_model import EditableModel 8 | from utils import _logits, _inner_params 9 | from losses import kl_loc_loss 10 | 11 | 12 | class FT(EditableModel): 13 | """ 14 | Fine-tuning approach. Does not require training. 
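    At edit time, it directly optimizes a residual on the configured inner parameters (a low-rank
    u @ v factorization when `ft.rank` is set, otherwise a dense delta), optionally with a KL
    locality penalty against the pre-edit model's predictions on sampled locality data.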
15 | """ 16 | 17 | def __init__(self, model, config, model_constructor, edit_loss_fn=None): 18 | super().__init__(model, config, model_constructor) 19 | 20 | if edit_loss_fn is not None: 21 | self.edit_loss_fn = edit_loss_fn 22 | 23 | self.locality_loss_fn = kl_loc_loss 24 | self.loc_ids = None 25 | self.loc_masks = None 26 | self.loc_sampler = None 27 | 28 | def _edit_loss(self, model, p0, p_edited, edit_batch): 29 | output = _logits(model(**edit_batch, params=p_edited)) 30 | loss_dict = self.edit_loss_fn(output, edit_batch["labels"]) 31 | l_edit, acc = loss_dict["nll"], loss_dict["acc"] 32 | if self.config.ft.locality.enabled: 33 | if self.config.ft.locality.oracle: 34 | loc_batch = next(self.loc_sampler)["loc"] 35 | else: 36 | raise NotImplementedError 37 | 38 | with torch.no_grad(): 39 | original_base_logits = _logits(model(**loc_batch, params=p0)) 40 | edited_base_logits = _logits(model(**loc_batch, params=p_edited)) 41 | kl_mask = loc_batch.get("decoder_attention_mask", loc_batch["attention_mask"]) 42 | l_loc = self.locality_loss_fn(original_base_logits, edited_base_logits, mask=kl_mask) 43 | loss = l_loc + self.config.ft.locality.cedit * l_edit 44 | else: 45 | l_loc = torch.tensor(float('nan')) 46 | loss = l_edit 47 | return loss, l_edit, l_loc, acc 48 | 49 | def accuracy(self, output, labels): 50 | if output.shape[-1] != 1: 51 | shifted_output = output.argmax(-1)[:, :-1] 52 | shifted_labels = labels[:, 1:] 53 | to_predict = (shifted_labels != -100).sum() 54 | correct = (shifted_output == shifted_labels).sum() 55 | acc = correct.float() / to_predict.float() 56 | else: 57 | acc = ((output > 0) == labels.bool()).sum().float() 58 | return acc 59 | 60 | def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p): 61 | return ( 62 | f"step: {step}".ljust(14) + 63 | f"loss: {loss.item():.5f}".ljust(18) + 64 | f"l_edit: {l_edit.item():.5f}".ljust(18) + 65 | f"l_loc: {l_loc.item():.5f}".ljust(18) + 66 | f"acc: {acc.item():.2f}".ljust(14) + 67 | f"norm: {res_p.view(-1).norm().item():.5f}" 68 | ) 69 | 70 | def edit(self, batch, condition=None, detach_history=False): 71 | edit_model = self.model.eval() 72 | p0 = list(edit_model.named_parameters()) 73 | 74 | if not isinstance(edit_model, higher.patch._MonkeyPatchBase): 75 | edit_model = make_functional(self.model, track_higher_grads=False, in_place=True) 76 | 77 | packed_residuals = {} 78 | opt_params = [] 79 | for n, p in _inner_params(edit_model.named_parameters(), self.config.model.inner_params): 80 | if self.config.ft.rank is not None: 81 | u = nn.Parameter(torch.randn(p.shape[0], self.config.ft.rank, device=p.device) * self.config.ft.init_std) 82 | v = nn.Parameter(torch.zeros(self.config.ft.rank, p.shape[1], device=p.device)) 83 | res = [u, v] 84 | else: 85 | res = [nn.Parameter(torch.zeros_like(p, device=p.device))] 86 | 87 | packed_residuals[n] = res 88 | opt_params.extend(res) 89 | 90 | assert len(opt_params) == len(self.config.model.inner_params) 91 | OptClass = getattr(torch.optim, self.config.ft.opt) 92 | opt = OptClass(opt_params, lr=self.config.edit_lr) 93 | 94 | start_time = time.time() 95 | for edit_step in range(self.config.ft.max_edit_steps): 96 | if self.config.ft.time_limit is not None and (time.time() - start_time > self.config.ft.time_limit): 97 | break 98 | residuals = {k: v[0] @ v[1] if len(v) == 2 else v[0] for k, v in packed_residuals.items()} 99 | edited_params = [p if n not in residuals else p.detach() + residuals[n] for n, p in p0] 100 | loss, l_edit, l_loc, acc = self._edit_loss(edit_model, [p for n, p in p0], 
edited_params, batch) 101 | 102 | if self.config.ft.verbose: 103 | residual = list(residuals.values())[-1] 104 | print(self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual), end="\r") 105 | 106 | if acc == 1.0: 107 | break 108 | 109 | for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)): 110 | p.grad = g 111 | torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip) 112 | opt.step() 113 | opt.zero_grad() 114 | 115 | if detach_history: 116 | new_model = self.model_constructor() 117 | new_model.load_state_dict(edit_model.state_dict()) 118 | edit_model = new_model 119 | edit_model.train(self.training) 120 | 121 | return FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), {} 122 | -------------------------------------------------------------------------------- /algs/mend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | import transformers 6 | import higher 7 | import logging 8 | from higher.patch import monkeypatch as make_functional 9 | from collections import defaultdict 10 | 11 | from editable_model import EditableModel 12 | from hooks import hook_model 13 | import nn as local_nn 14 | from utils import _logits, _inner_params 15 | 16 | LOG = logging.getLogger(__name__) 17 | 18 | 19 | def update_counter(x, m, s, k): 20 | new_m = m + (x - m) / k 21 | new_s = s + (x - m) * (x - new_m) 22 | 23 | return new_m, new_s 24 | 25 | 26 | class GradientTransform(nn.Module): 27 | def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes = None): 28 | super().__init__() 29 | 30 | self.x_dim = x_dim 31 | self.delta_dim = delta_dim 32 | self.cfg = cfg 33 | if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only): 34 | raise ValueError("cfg.combine cannot be used with one-sided MEND variants") 35 | 36 | self.norm_init = False 37 | self.register_buffer("u_mean", torch.full((x_dim,), float("nan"))) 38 | self.register_buffer("v_mean", torch.full((delta_dim,), float("nan"))) 39 | self.register_buffer("u_std", torch.full((x_dim,), float("nan"))) 40 | self.register_buffer("v_std", torch.full((delta_dim,), float("nan"))) 41 | self.register_buffer("u_s", torch.full((x_dim,), float("nan"))) 42 | self.register_buffer("v_s", torch.full((delta_dim,), float("nan"))) 43 | self.register_buffer("k", torch.full((1,), float("nan"))) 44 | 45 | MlpClass = getattr(local_nn, cfg.mlp_class) 46 | LOG.info(f"Building Gradient Transform with MLP class {MlpClass}") 47 | 48 | def delta_net(): 49 | return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes) 50 | 51 | def x_net(): 52 | return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes) 53 | 54 | def combined_net(): 55 | return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2, 56 | cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes) 57 | 58 | def ID(): 59 | return lambda x, mode=None: x 60 | 61 | if cfg.combine: 62 | self.mlp = combined_net() 63 | elif cfg.one_sided: 64 | if x_dim > delta_dim: 65 | self.mlp1, self.mlp2 = ID(), delta_net() 66 | else: 67 | self.mlp1, self.mlp2 = x_net(), ID() 68 | elif cfg.x_only: 69 | self.mlp1, self.mlp2 = x_net(), ID() 70 | elif cfg.delta_only: 71 | self.mlp1, self.mlp2 = ID(), delta_net() 72 | else: 73 | self.mlp1, self.mlp2 = x_net(), delta_net() 74 | 75 | def forward(self, u, v, 
param_idx=None): 76 | u, v = u.to(torch.float32), v.to(torch.float32) 77 | 78 | u_ = u.view(-1, u.shape[-1]) 79 | v_ = v.view(-1, v.shape[-1]) 80 | 81 | nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1) # Skip batch elements with zero grad 82 | u_ = u_[nz_mask] 83 | v_ = v_[nz_mask] 84 | 85 | if self.training: 86 | for idx in range(u_.shape[0]): 87 | if not self.norm_init: 88 | self.u_mean = u_[idx].clone().detach() 89 | self.v_mean = v_[idx].clone().detach() 90 | self.u_s.zero_() 91 | self.v_s.zero_() 92 | self.k[:] = 1 93 | self.norm_init = True 94 | else: 95 | self.k += 1 96 | self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k) 97 | self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k) 98 | 99 | if self.k < 2: 100 | raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far") 101 | self.u_std = (self.u_s / (self.k - 1)) ** 0.5 102 | self.v_std = (self.v_s / (self.k - 1)) ** 0.5 103 | 104 | if self.cfg.norm: 105 | u_input = (u_ - self.u_mean) / (self.u_std + 1e-7) 106 | v_input = (v_ - self.v_mean) / (self.v_std + 1e-7) 107 | else: 108 | u_input = u_ 109 | v_input = v_ 110 | 111 | if self.cfg.combine: 112 | output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx) 113 | out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1) 114 | return out1, out2 115 | else: 116 | return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx) 117 | 118 | 119 | class MEND(EditableModel): 120 | def get_shape(self, p): 121 | # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear 122 | return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0]) 123 | 124 | def __init__(self, model, config, model_constructor, mend=None, edit_lrs=None): 125 | super().__init__(model, config, model_constructor) 126 | 127 | if edit_lrs is None: 128 | edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params))) 129 | self.edit_lrs = edit_lrs 130 | 131 | if not hasattr(self.model, "handles"): 132 | hook_model(self.model, self.config.model.inner_params) 133 | LOG.info(f"Hooked {len(self.model.handles)//2} modules") 134 | 135 | if config.mend.shared: 136 | shape_dict = defaultdict(list) 137 | for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params): 138 | shape_dict[self.get_shape(p)].append(n) 139 | self.shape_dict = shape_dict 140 | 141 | if mend is None: 142 | if not config.mend.shared: 143 | self.mend = nn.ModuleDict({ 144 | n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.mend) 145 | for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params) 146 | }) 147 | else: 148 | self.mend = nn.ModuleDict({ 149 | str(tuple(s)): GradientTransform(*s, config.mend, len(shape_dict[s])) 150 | for s in shape_dict.keys() 151 | }) 152 | else: 153 | self.mend = mend 154 | 155 | def state_dict(self, destination=None, prefix="", keep_vars=False): 156 | state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars) # Get default state dict 157 | model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys() # Remove model params 158 | for k in model_keys: 159 | del state_dict[f"model.{k}"] 160 | state_dict["model_config"] = self.model.config # Include model config 161 | return state_dict 162 | 163 | def load_state_dict(self, state_dict, strict: bool = True): 164 | config = state_dict["model_config"] 165 | del state_dict["model_config"] 166 | if config != 
self.model.config: 167 | LOG.info("Loaded model config doesn't match current model config.") 168 | LOG.info(f"Loaded: {config}") 169 | LOG.info(f"Current: {self.model.config}") 170 | 171 | res = super().load_state_dict(state_dict, False) 172 | # We should only have missing keys for the model, and no unexpected keys 173 | assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model." 174 | assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys" 175 | return res 176 | 177 | def outer_parameters(self): 178 | return list(self.mend.parameters()) + [self.edit_lrs] 179 | 180 | def edit(self, batch, condition=None, detach_history=False): 181 | outputs = _logits(self.model(**batch)) 182 | loss = self.edit_loss_fn(outputs, batch["labels"])["nll"] 183 | 184 | names = set([n for n, p in self.model.named_parameters()]) 185 | pset = set(self.config.model.inner_params) 186 | for p in pset: 187 | assert p in names, f"inner param {p} not in model" 188 | 189 | loss.backward() 190 | 191 | if self.config.mend.shared: 192 | param_idx = lambda n, p: self.shape_dict[self.get_shape(p)].index(n) if self.config.mend.shared else None # noqa: E731 193 | transformed_factors = { 194 | n: self.mend[str(tuple(self.get_shape(p)))](p.__x__, p.__delta__, param_idx(n, p)) 195 | for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params) 196 | } 197 | else: 198 | transformed_factors = { 199 | n: self.mend[n.replace(".", "#")](p.__x__, p.__delta__) 200 | for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params) 201 | } 202 | 203 | # Should be bi,bj->ji for nn.Linear, but GPT2 uses Conv1d instead... 204 | if isinstance(self.model, transformers.GPT2LMHeadModel): 205 | targ = "ij" 206 | else: 207 | targ = "ji" 208 | mean_grads = { 209 | n: torch.einsum(f"bi,bj->{targ}", x, delta) 210 | for n, (x, delta) in transformed_factors.items() 211 | } 212 | 213 | info_dict = {} 214 | idx = 0 215 | for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params): 216 | info_dict[f"grad/true_mag{idx}"] = p.grad.norm(2).item() 217 | info_dict[f"grad/pseudo_mag{idx}"] = mean_grads[n].norm(2).item() 218 | info_dict[f"grad/true_std{idx}"] = p.grad.std().item() 219 | info_dict[f"grad/pseudo_std{idx}"] = mean_grads[n].std().item() 220 | info_dict[f"grad/diff{idx}"] = (p.grad - mean_grads[n]).norm(2).item() 221 | info_dict[f"grad/cos{idx}"] = F.cosine_similarity(p.grad.reshape(-1), mean_grads[n].reshape(-1), dim=0).item() 222 | idx += 1 223 | 224 | self.model.zero_grad() 225 | 226 | assert len(self.edit_lrs) == len(list(mean_grads.items())) 227 | updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())} 228 | 229 | edited_model = self.model 230 | if not isinstance(edited_model, higher.patch._MonkeyPatchBase): 231 | edited_model = make_functional(edited_model, in_place=True) 232 | 233 | new_params = [] 234 | for n, p in edited_model.named_parameters(): 235 | if n in pset: 236 | new_params.append(p + updates[n]) 237 | else: 238 | new_params.append(p) 239 | 240 | edited_model.update_params(new_params) 241 | 242 | if detach_history: 243 | new_model = self.model_constructor() 244 | new_model.load_state_dict(edited_model.state_dict()) 245 | edited_model = new_model 246 | 247 | return MEND(edited_model, self.config, self.model_constructor, self.mend, edit_lrs=self.edit_lrs), info_dict 248 | 249 | 250 | if __name__ == '__main__': 251 | import types 252 | 253 | model = 
transformers.GPT2LMHeadModel.from_pretrained("gpt2") 254 | 255 | config = types.SimpleNamespace() 256 | config.model.inner_params = [ 257 | "transformer.h.9.mlp.c_fc.weight", 258 | "transformer.h.9.mlp.c_proj.weight", 259 | "transformer.h.10.mlp.c_fc.weight", 260 | "transformer.h.10.mlp.c_proj.weight", 261 | "transformer.h.11.mlp.c_fc.weight", 262 | "transformer.h.11.mlp.c_proj.weight", 263 | ] 264 | config.edit_lr = 0.0001 265 | 266 | config.mend = types.SimpleNamespace() 267 | config.mend.n_hidden = 1 268 | config.mend = config.mend.__dict__ 269 | 270 | mend = MEND(model, config, lambda: copy.deepcopy(model)).cuda() 271 | import pdb; pdb.set_trace() 272 | mend.load_state_dict(torch.load("test_state.pt")) 273 | x = torch.arange(20).view(1, 20).cuda() + 1000 274 | orig_logits = mend(x) 275 | edited = mend.edit(x, masks=torch.ones_like(x), labels=x) 276 | post_logits = mend(x) 277 | 278 | assert torch.allclose(orig_logits, post_logits) 279 | 280 | orig_param = [p for (n, p) in mend.model.named_parameters() if n == config.model.inner_params[-1]][0] 281 | edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0] 282 | 283 | LOG.info((orig_param - edited_param).abs().max()) 284 | edited.eval() 285 | LOG.info(mend(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"]) 286 | edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x) 287 | LOG.info(mend(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss) 288 | -------------------------------------------------------------------------------- /config/alg/efk.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: efk 4 | train_base: False 5 | lr: 1e-5 6 | -------------------------------------------------------------------------------- /config/alg/enn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: enn 4 | train_base: True 5 | enn: 6 | first_order: False 7 | n_edit_steps: 1 8 | -------------------------------------------------------------------------------- /config/alg/ft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | train_base: False 4 | alg: ft 5 | edit_lr: 5e-6 6 | ft: 7 | verbose: false 8 | max_edit_steps: 100 9 | time_limit: null 10 | locality: 11 | enabled: false 12 | oracle: true 13 | cedit: 1e-2 14 | batch_size: 1 15 | rank: null 16 | opt: RMSprop 17 | init_std: 0.01 18 | -------------------------------------------------------------------------------- /config/alg/mend.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: mend 4 | lr: 1e-6 5 | train_base: False 6 | edit_lr: 1e-4 7 | lr_lr: 1e-4 8 | mend: 9 | one_sided: False 10 | n_hidden: 1 11 | hidden_dim: null 12 | init: id 13 | norm: True 14 | combine: True 15 | x_only: False 16 | delta_only: False 17 | act: relu 18 | rank: 1920 19 | mlp_class: IDMLP 20 | shared: True 21 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | alg: enn 2 | lr: 1e-5 3 | edit_lr: 1e-2 4 | seed: 0 5 | debug: False 6 | model_save_pt: 5000 7 | edit_bs: 1 8 | silent: False 9 | max_iters: 1000000 10 | log_interval: 100 11 | val_interval: 5000 12 | lr_lr: 1e-3 13 | batch_size: 2 14 | 
val_batch_size: 5 15 | accumulate_bs: 10 16 | cedit: 0.1 17 | cloc: 1.0 18 | cbase: 1.0 19 | val_steps: 500 20 | device: cuda 21 | base_loss: distill 22 | oracle: False 23 | train: True 24 | train_base: True 25 | opt: Adam 26 | single_batch: False 27 | archive: null 28 | grad_clip: 100. 29 | ref: null 30 | early_stop_patience: 20000 31 | early_stop_key: "loss/total_edit_val" 32 | dropout: 0.0 33 | tokenizer: null 34 | results_dir: null 35 | no_grad_layers: null 36 | eval_only: False 37 | half: False 38 | save: False 39 | 40 | model: 41 | pt: null 42 | 43 | data: 44 | path: null 45 | rephrase: true 46 | zsre_nq: true 47 | nq_path: ${hydra:runtime.cwd}/data/nq 48 | wiki_webtext: true 49 | n_edits: 1 50 | 51 | eval: 52 | verbose: True 53 | log_interval: 100 54 | final_eval: True 55 | 56 | hydra: 57 | run: 58 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}} 59 | sweep: 60 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f} 61 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /config/experiment/fc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: fc 4 | dataset: fever 5 | cbase: 1.0 6 | -------------------------------------------------------------------------------- /config/experiment/gen.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: gen 4 | dataset: wikitext-103 5 | cbase: 10.0 6 | data: 7 | path: ${hydra:runtime.cwd}/data/10token/data/self_sample/ 8 | -------------------------------------------------------------------------------- /config/experiment/qa.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: qa 4 | dataset: zsre 5 | cbase: 1.0 6 | -------------------------------------------------------------------------------- /config/model/bart-base.yaml: -------------------------------------------------------------------------------- 1 | name: facebook/bart-base 2 | class_name: BartForConditionalGeneration 3 | tokenizer_class: BartTokenizerFast 4 | tokenizer_name: facebook/bart-base 5 | inner_params: 6 | - model.encoder.layers.4.fc1.weight 7 | - model.encoder.layers.4.fc2.weight 8 | - model.encoder.layers.5.fc1.weight 9 | - model.encoder.layers.5.fc2.weight 10 | - model.decoder.layers.4.fc1.weight 11 | - model.decoder.layers.4.fc2.weight 12 | - model.decoder.layers.5.fc1.weight 13 | - model.decoder.layers.5.fc2.weight 14 | 15 | pt: ${hydra:runtime.cwd}/data/zsre/QA_model.ckpt -------------------------------------------------------------------------------- /config/model/bert-base.yaml: -------------------------------------------------------------------------------- 1 | name: bert-base-uncased 2 | class_name: BertClassifier 3 | tokenizer_class: BertTokenizerFast 4 | tokenizer_name: bert-base-uncased 5 | inner_params: 6 | - model.encoder.layer.9.intermediate.dense.weight 7 | - model.encoder.layer.9.output.dense.weight 8 | - model.encoder.layer.10.intermediate.dense.weight 9 | - model.encoder.layer.10.output.dense.weight 10 | - model.encoder.layer.11.intermediate.dense.weight 11 | - model.encoder.layer.11.output.dense.weight 12 | 13 | pt: ${hydra:runtime.cwd}/data/fever/FC_model.ckpt -------------------------------------------------------------------------------- /config/model/distilgpt2.yaml: -------------------------------------------------------------------------------- 1 | name: MYX4567/distilgpt2-finetuned-wikitext2 2 
| class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: distilgpt2 5 | inner_params: 6 | - transformer.h.3.mlp.c_fc.weight 7 | - transformer.h.3.mlp.c_proj.weight 8 | - transformer.h.4.mlp.c_fc.weight 9 | - transformer.h.4.mlp.c_proj.weight 10 | - transformer.h.5.mlp.c_fc.weight 11 | - transformer.h.5.mlp.c_proj.weight 12 | 13 | -------------------------------------------------------------------------------- /config/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2 5 | inner_params: 6 | - transformer.h.9.mlp.c_proj.weight 7 | - transformer.h.9.mlp.c_fc.weight 8 | - transformer.h.10.mlp.c_proj.weight 9 | - transformer.h.10.mlp.c_fc.weight 10 | - transformer.h.11.mlp.c_proj.weight 11 | - transformer.h.11.mlp.c_fc.weight -------------------------------------------------------------------------------- /config/model/gpt2large.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-large 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-large 5 | inner_params: 6 | - transformer.h.33.mlp.c_proj.weight 7 | - transformer.h.33.mlp.c_fc.weight 8 | - transformer.h.34.mlp.c_proj.weight 9 | - transformer.h.34.mlp.c_fc.weight 10 | - transformer.h.35.mlp.c_proj.weight 11 | - transformer.h.35.mlp.c_fc.weight 12 | -------------------------------------------------------------------------------- /config/model/gpt2medium.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-medium 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-medium 5 | inner_params: 6 | - transformer.h.21.mlp.c_proj.weight 7 | - transformer.h.21.mlp.c_fc.weight 8 | - transformer.h.22.mlp.c_proj.weight 9 | - transformer.h.22.mlp.c_fc.weight 10 | - transformer.h.23.mlp.c_proj.weight 11 | - transformer.h.23.mlp.c_fc.weight -------------------------------------------------------------------------------- /config/model/gpt2xl.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-xl 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-xl 5 | inner_params: 6 | - transformer.h.45.mlp.c_proj.weight 7 | - transformer.h.45.mlp.c_fc.weight 8 | - transformer.h.46.mlp.c_proj.weight 9 | - transformer.h.46.mlp.c_fc.weight 10 | - transformer.h.47.mlp.c_proj.weight 11 | - transformer.h.47.mlp.c_fc.weight 12 | -------------------------------------------------------------------------------- /config/model/gptj.yaml: -------------------------------------------------------------------------------- 1 | name: EleutherAI/gpt-j-6B 2 | class_name: GPTJForCausalLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: EleutherAI/gpt-j-6B 5 | inner_params: 6 | - transformer.h.25.mlp.fc_in.weight 7 | - transformer.h.25.mlp.fc_out.weight 8 | - transformer.h.26.mlp.fc_in.weight 9 | - transformer.h.26.mlp.fc_out.weight 10 | - transformer.h.27.mlp.fc_in.weight 11 | - transformer.h.27.mlp.fc_out.weight 12 | -------------------------------------------------------------------------------- /config/model/gptneo27.yaml: -------------------------------------------------------------------------------- 1 | name: EleutherAI/gpt-neo-2.7B 2 | class_name: GPTNeoForCausalLM 3 | tokenizer_class: GPT2TokenizerFast 4 | 
tokenizer_name: EleutherAI/gpt-neo-2.7B 5 | inner_params: 6 | - transformer.h.29.mlp.c_fc.weight 7 | - transformer.h.29.mlp.c_proj.weight 8 | - transformer.h.30.mlp.c_fc.weight 9 | - transformer.h.30.mlp.c_proj.weight 10 | - transformer.h.31.mlp.c_fc.weight 11 | - transformer.h.31.mlp.c_proj.weight 12 | -------------------------------------------------------------------------------- /config/model/t5large.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-large-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-large-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /config/model/t5small.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-small-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-small-ssm-nq 5 | inner_params: 6 | - encoder.block.6.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.6.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.7.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.7.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.6.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.6.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.7.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.7.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /config/model/t5xl.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-xl-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-xl-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /config/model/t5xxl.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-xxl-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-xxl-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight 13 | - 
decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /data_classes/fever.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | from utils import EditBatchSampler, dict_to 6 | import random 7 | 8 | 9 | POSITIVE_CLASS = "SUPPORTS" 10 | 11 | 12 | class BinaryAugmentedKILT(Dataset): 13 | def __init__( 14 | self, 15 | tokenizer, 16 | data_path, 17 | config, 18 | max_length=32 19 | ): 20 | super().__init__() 21 | self.tokenizer = tokenizer 22 | self.data = [] 23 | self.config = config 24 | 25 | def extract(d): 26 | extracted = {k: d[k] for k in ["logit", "input", "prediction", "alternatives", "filtered_rephrases"]} 27 | extracted["label"] = d["output"][0]["answer"] 28 | return extracted 29 | 30 | with jsonlines.open(data_path) as f: 31 | for d in f: 32 | if len(d["alternatives"]) > 0 and len(d["filtered_rephrases"]) > 0: 33 | self.data.append(extract(d)) 34 | 35 | self.max_length = max_length 36 | 37 | def __len__(self): 38 | return len(self.data) 39 | 40 | def __getitem__(self, item): 41 | obj = self.data[item] 42 | rephrase = random.choice(self.data[item]["filtered_rephrases"]) 43 | output = { 44 | "label": obj["label"] == POSITIVE_CLASS, 45 | "src": obj["input"], 46 | "rephrase": rephrase, 47 | "pred": obj["prediction"] == POSITIVE_CLASS, 48 | "alt": obj["alternatives"][0] == POSITIVE_CLASS, 49 | "cond_flip": "{} >> {} || {}".format( 50 | obj["prediction"], 51 | obj["alternatives"][0], 52 | obj["input"], 53 | ), 54 | "cond_orig": "{} >> {} || {}".format( 55 | obj["prediction"], 56 | obj["prediction"], 57 | obj["input"], 58 | ), 59 | "logit": obj["logit"], 60 | } 61 | 62 | return output 63 | 64 | def collate_fn(self, batch): 65 | src = [b["src"] for b in batch] 66 | rephrase = [batch[-1]["rephrase"]] 67 | 68 | flip_label = np.random.uniform() > 0.5 69 | predictions = [b["pred"] for b in batch] 70 | labels = [b["label"] for b in batch] 71 | labels[-1] = predictions[-1] # the last element in the batch is special (the edit element) 72 | cond = [batch[-1]["cond_orig"]] 73 | if flip_label: 74 | labels[-1] = batch[-1]["alt"] 75 | cond = [batch[-1]["cond_flip"]] 76 | 77 | batches = {} 78 | for k1, v1 in {"": src, "cond_": cond, "rephrase_": rephrase}.items(): 79 | encoded = self.tokenizer( 80 | v1, 81 | return_tensors="pt", 82 | padding=True, 83 | max_length=self.max_length, 84 | truncation=True, 85 | ) 86 | for k2, v2 in encoded.items(): 87 | batches[f"{k1}{k2}"] = v2 88 | 89 | batches["predictions"] = torch.tensor(predictions).long().view(-1, 1) 90 | batches["labels"] = torch.tensor(labels).long().view(-1, 1) 91 | batches["raw"] = batch 92 | return batches 93 | 94 | def edit_generator(self, batch_size, n=None): 95 | if n is None: 96 | n = len(self) 97 | sampler = EditBatchSampler(n, memorize_mode=self.config.single_batch, seed=self.config.seed) 98 | while True: 99 | edit_idxs, loc_idxs = sampler.sample(batch_size) 100 | assert len(edit_idxs) == 1 101 | idxs = loc_idxs + edit_idxs 102 | toks = self.collate_fn([self[idx] for idx in idxs]) 103 | 104 | pass_keys = ["input_ids", "attention_mask", "labels"] 105 | edit_inner = {k: v[-1:] for k, v in toks.items() if k in pass_keys} 106 | if self.config.data.rephrase: 107 | edit_outer = {} 108 | edit_outer["input_ids"] = toks["rephrase_input_ids"] 109 | edit_outer["attention_mask"] = toks["rephrase_attention_mask"] 110 | 
edit_outer["labels"] = edit_inner["labels"] 111 | else: 112 | edit_outer = edit_inner 113 | loc = {k: v[:-1] for k, v in toks.items() if k in pass_keys} 114 | cond = {"input_ids": toks["cond_input_ids"], "attention_mask": toks["cond_attention_mask"]} 115 | 116 | batch = { 117 | "edit_inner": edit_inner, 118 | "edit_outer": edit_outer, 119 | "loc": loc, 120 | "cond": cond 121 | } 122 | yield dict_to(batch, self.config.device) 123 | -------------------------------------------------------------------------------- /data_classes/nq.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class NQDataset: 5 | def __init__(self, path: str, tokenizer, config): 6 | with open(path, "r") as f: 7 | self.data = json.load(f) 8 | 9 | self.questions = self.data["questions"] 10 | self.answers = self.data["answers"] 11 | self.tokenizer = tokenizer 12 | self.config = config 13 | 14 | def __getitem__(self, idx): 15 | idx = idx % len(self.questions) 16 | return self.questions[idx], self.answers[idx] 17 | 18 | @staticmethod 19 | def generate(out_path: str, prompt: bool = False, capitalize: bool = True, question_mark: bool = True): 20 | import datasets 21 | import os 22 | 23 | def process(text): 24 | if capitalize: 25 | text = text[0].capitalize() + text[1:] 26 | if question_mark: 27 | text = text + "?" 28 | if prompt: 29 | text = "nq question: " + text 30 | return text 31 | 32 | def extract(d): 33 | questions = [process(q["text"]) for q in d["question"]] 34 | answers = [[a['text'][0] for a in ann['short_answers'] if len(a['text'])] for ann in d['annotations']] 35 | questions = [q for q, a in zip(questions, answers) if len(a)] 36 | answers = [min(a, key=len) for a in answers if len(a)] 37 | return questions, answers 38 | 39 | train = datasets.load_dataset("natural_questions", split="train") 40 | tq, ta = extract(train) 41 | val = datasets.load_dataset("natural_questions", split="validation") 42 | vq, va = extract(val) 43 | 44 | if not os.path.exists(out_path): 45 | os.makedirs(out_path) 46 | with open(f"{out_path}/train.json", "w") as f: 47 | json.dump({"questions": tq, "answers": ta}, f) 48 | with open(f"{out_path}/validation.json", "w") as f: 49 | json.dump({"questions": vq, "answers": va}, f) 50 | 51 | 52 | if __name__ == "__main__": 53 | import argparse 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--out_path", type=str, default="data/nq") 56 | args = parser.parse_args() 57 | NQDataset.generate(args.out_path) 58 | -------------------------------------------------------------------------------- /data_classes/wiki.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from datasets import load_dataset 3 | import json 4 | from utils import EditBatchSampler, dict_to, scr 5 | import logging 6 | import random 7 | import copy 8 | 9 | LOG = logging.getLogger(__name__) 10 | 11 | 12 | def is_ascii(s): 13 | return all(ord(c) < 128 for c in s) 14 | 15 | 16 | def filter_text(iterator): 17 | valid = [] 18 | for text in iterator: 19 | if len(text.split(' ')) < 50: 20 | continue 21 | if not is_ascii(text): 22 | continue 23 | valid.append(text) 24 | 25 | return valid 26 | 27 | 28 | class GenDataset(Dataset): 29 | def __init__(self, split: str, tokenizer, config, edit_path: str, 30 | pct: int = 10, max_length: int = 200): 31 | version = 'wikitext-103-raw-v1' 32 | split_str = f'{split}[:{pct}%]' if split == "train" else split 33 | LOG.info(f"Loading wikitext version {version}, 
split {split_str}") 34 | base_samples = load_dataset( 35 | 'wikitext', 36 | version, 37 | cache_dir=scr(), 38 | split=split_str 39 | )["text"] 40 | self.base_samples = filter_text(base_samples) 41 | with open(edit_path + split[:5] + ".json", "r") as f: 42 | self.edit_samples = json.load(f) 43 | 44 | self.tok = tokenizer 45 | self.config = config 46 | self.max_length = max_length 47 | self.n_tokens = self.edit_samples["n_tokens"] 48 | 49 | len_base = len(self.base_samples) 50 | len_edit = len(self.edit_samples['original']) 51 | LOG.info(f"Loaded {len_base} wiki-103 samples and {len_edit} edit samples") 52 | 53 | if config.data.wiki_webtext: 54 | self.use_wiki = True 55 | LOG.info("** Using webtext for wiki base samples **") 56 | webtext = load_dataset('stas/openwebtext-10k', split="train", cache_dir=scr())["text"] 57 | n_train = int(len(webtext) * 0.9) 58 | if split == "train": 59 | self.base_samples = webtext[:n_train] 60 | else: 61 | self.base_samples = webtext[n_train:] 62 | else: 63 | self.use_wiki = False 64 | 65 | def edit_generator(self, batch_size, n=None): 66 | if n is None: 67 | n = len(self) 68 | sampler = EditBatchSampler(n, memorize_mode=self.config.single_batch, loc_disjoint=not self.use_wiki, 69 | seed=self.config.seed) 70 | while True: 71 | edit_idxs, loc_idxs = sampler.sample(batch_size) 72 | 73 | edit_batch = [self.edit_samples["completions"][idx] for idx in edit_idxs] 74 | loc_batch = [self.base_samples[idx % len(self.base_samples)] for idx in loc_idxs] 75 | 76 | edit_toks = self.tok(edit_batch, padding=True, return_tensors="pt") 77 | loc_toks = self.tok(loc_batch, padding=True, return_tensors="pt", 78 | truncation=self.config.data.wiki_webtext, max_length=self.max_length) 79 | 80 | edit_inner = {**edit_toks} 81 | edit_inner["labels"] = self.get_edit_labels(edit_toks["input_ids"]) 82 | 83 | edit_outer = copy.deepcopy(edit_inner) 84 | if self.config.data.rephrase: 85 | lens = (edit_outer["input_ids"] != -100).sum(-1) 86 | remove = random.randint(0, (min(lens) - self.n_tokens) // 2) 87 | for k, v in edit_outer.items(): 88 | edit_outer[k] = v[:, remove:] 89 | 90 | loc = {**loc_toks} 91 | loc["labels"] = self.get_labels(loc_toks["input_ids"]) 92 | cond = {**edit_toks} 93 | 94 | batch = { 95 | "edit_inner": edit_inner, 96 | "edit_outer": edit_outer, 97 | "loc": loc, 98 | "cond": cond 99 | } 100 | 101 | yield dict_to(batch, self.config.device) 102 | 103 | def __len__(self): 104 | return len(self.edit_samples["original"]) 105 | 106 | def _check_padding(self, ids): 107 | if (ids[:, 0] == self.tok.pad_token_id).any(): 108 | raise ValueError("Left-padding not supported for GPT2") 109 | 110 | def get_edit_labels(self, ids): 111 | self._check_padding(ids) 112 | 113 | labels = ids.clone() 114 | end_idxs = (labels != self.tok.pad_token_id).sum(-1) 115 | for batch_idx, end_idx in enumerate(end_idxs): 116 | labels[batch_idx, :end_idx - self.n_tokens] = -100 117 | labels[labels == self.tok.pad_token_id] = -100 118 | return labels 119 | 120 | def get_labels(self, ids): 121 | self._check_padding(ids) 122 | 123 | return ids.masked_fill(ids == self.tok.pad_token_id, -100) 124 | 125 | def __getitem__(self, idx): 126 | return self.base_samples[idx] 127 | -------------------------------------------------------------------------------- /data_classes/zsre.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | from torch.utils.data import Dataset 3 | import random 4 | from utils import EditBatchSampler, dict_to 5 | import torch 6 | from 
transformers import BartTokenizerFast, BartTokenizer 7 | import logging 8 | 9 | LOG = logging.getLogger(__name__) 10 | 11 | 12 | class Seq2SeqAugmentedKILT(Dataset): 13 | def __init__( 14 | self, 15 | tokenizer, 16 | data_path, 17 | config, 18 | max_length=32, 19 | return_view=False, 20 | all_views=False, 21 | ): 22 | super().__init__() 23 | self.tok = tokenizer 24 | self.data = [] 25 | self.config = config 26 | 27 | def extract(d): 28 | ex = {k: d[k] for k in ["input", "prediction", "alternatives", "filtered_rephrases", "output"]} 29 | if ex["input"] in ex["filtered_rephrases"]: 30 | ex["filtered_rephrases"].remove(ex["input"]) 31 | return ex 32 | 33 | with jsonlines.open(data_path) as f: 34 | for d in f: 35 | extracted = extract(d) 36 | if len(extracted["alternatives"]) > 0 and len(extracted["filtered_rephrases"]) > 0: 37 | self.data.append(extracted) 38 | 39 | self.max_length = max_length 40 | self.all_views = all_views 41 | self.return_view = return_view 42 | if self.config.data.zsre_nq and "train" not in data_path: 43 | self.use_nq = True 44 | LOG.info("** Using natural questions for zsre base samples **") 45 | from data_classes.nq import NQDataset 46 | self.nq = NQDataset(self.config.data.nq_path + ("/train.json" if "train" in data_path else "/validation.json"), 47 | tokenizer, config) 48 | else: 49 | self.use_nq = False 50 | 51 | def is_bart(self): 52 | return isinstance(self.tok, BartTokenizer) or isinstance(self.tok, BartTokenizerFast) 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | 57 | def __getitem__(self, item, seed=None): 58 | new_label = random.choice(self.data[item]["alternatives"]) 59 | rephrase = random.choice(self.data[item]["filtered_rephrases"]) 60 | output = { 61 | "src": self.data[item]["input"], 62 | "pred": self.data[item]["prediction"], 63 | "rephrase": rephrase, 64 | "alt": new_label, 65 | "answers": [x["answer"] for x in self.data[item]["output"]], 66 | "cond": "{} >> {} || {}".format( 67 | self.data[item]["prediction"], 68 | new_label, 69 | self.data[item]["input"], 70 | ), 71 | } 72 | 73 | return output 74 | 75 | def collate_fn(self, batch): 76 | src = [b["src"] for b in batch] 77 | ne = self.config.data.n_edits 78 | trg = ( 79 | [b["answers"][0] for b in batch[:-ne]] + 80 | [b["alt"] for b in batch[-ne:]] 81 | ) 82 | 83 | batches = { 84 | f"{k1}_{k2}": v2 85 | for k1, v1 in { 86 | "src": src, 87 | "trg": trg, 88 | "cond": [b["cond"] for b in batch[-ne:]], 89 | "rephrase": [b["rephrase"] for b in batch[-ne:]], 90 | }.items() 91 | for k2, v2 in self.tok( 92 | v1, 93 | return_tensors="pt", 94 | padding=True, 95 | max_length=self.max_length, 96 | truncation=True, 97 | ).items() 98 | } 99 | 100 | if self.is_bart(): # For consistency with Cao et al 101 | batches["trg_input_ids"][:, 0] = self.tok.eos_token_id 102 | batches["raw"] = batch 103 | return batches 104 | 105 | def _check_padding(self, ids): 106 | if (ids[:, 0] == self.tok.pad_token_id).any(): 107 | raise ValueError("Left-padding not supported") 108 | 109 | def get_edit_labels(self, labels): 110 | return labels.masked_fill(labels == self.tok.pad_token_id, -100) 111 | 112 | def edit_generator(self, batch_size, n=None): 113 | if n is None: 114 | n = len(self) 115 | sampler = EditBatchSampler(n, memorize_mode=self.config.single_batch, loc_disjoint=not self.use_nq, seed=self.config.seed) 116 | 117 | while True: 118 | edit_idxs, loc_idxs = sampler.sample(batch_size) 119 | assert len(edit_idxs) == 1 120 | idxs = loc_idxs + edit_idxs 121 | toks = self.collate_fn([self[idx] for idx in idxs]) 122 | 123 
| ne = self.config.data.n_edits 124 | edit_decoder_inputs = toks["trg_input_ids"][-ne:] 125 | edit_labels = self.get_edit_labels(edit_decoder_inputs) 126 | edit_attention_mask = toks["trg_attention_mask"][-ne:] 127 | 128 | edit_inner = {} 129 | edit_inner["input_ids"] = toks["src_input_ids"][-ne:] 130 | edit_inner["attention_mask"] = toks["src_attention_mask"][-ne:] 131 | if self.is_bart(): 132 | edit_inner["decoder_input_ids"] = edit_decoder_inputs 133 | edit_inner["decoder_attention_mask"] = edit_attention_mask 134 | edit_inner["labels"] = edit_labels 135 | 136 | if self.config.data.rephrase: 137 | edit_outer = {} 138 | edit_outer["input_ids"] = toks["rephrase_input_ids"] 139 | edit_outer["attention_mask"] = toks["rephrase_attention_mask"] 140 | if self.is_bart(): 141 | edit_outer["decoder_input_ids"] = edit_decoder_inputs 142 | edit_outer["decoder_attention_mask"] = edit_attention_mask 143 | edit_outer["labels"] = edit_labels 144 | else: 145 | edit_outer = edit_inner 146 | 147 | loc = {} 148 | if self.use_nq: 149 | batch = [self.nq[idx] for idx in loc_idxs] 150 | questions = [b[0] for b in batch] 151 | answers = [b[1] for b in batch] 152 | loc = dict(self.tok(questions, return_tensors="pt", padding=True, max_length=self.max_length, truncation=True)) 153 | trg_toks = dict(self.tok(answers, return_tensors="pt", padding=True, max_length=self.max_length, truncation=True)) 154 | if self.is_bart(): 155 | trg_toks["input_ids"][:, 0] = self.tok.eos_token_id 156 | loc["decoder_input_ids"] = trg_toks["input_ids"] 157 | loc["decoder_attention_mask"] = trg_toks["attention_mask"] 158 | loc["labels"] = self.get_edit_labels(trg_toks["input_ids"]) 159 | else: 160 | loc["input_ids"] = toks["src_input_ids"][:-ne] 161 | loc["attention_mask"] = toks["src_attention_mask"][:-ne] 162 | if self.is_bart(): 163 | loc["decoder_input_ids"] = toks["trg_input_ids"][:-ne] 164 | loc["decoder_attention_mask"] = toks["trg_attention_mask"][:-ne] 165 | loc["labels"] = self.get_edit_labels(toks["trg_input_ids"][:-ne]) 166 | 167 | cond = {k[5:]: v for k, v in toks.items() if k.startswith("cond")} 168 | 169 | batch = { 170 | "edit_inner": edit_inner, 171 | "edit_outer": edit_outer, 172 | "loc": loc, 173 | "cond": cond, 174 | "raw": toks["raw"] 175 | } 176 | 177 | yield dict_to(batch, self.config.device) 178 | -------------------------------------------------------------------------------- /editable_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from losses import masked_log_probs 4 | from utils import _logits, shift_targets 5 | 6 | 7 | class EditableModel(nn.Module): 8 | def __init__(self, model, config, model_constructor): 9 | super().__init__() 10 | 11 | self.model = model 12 | self.config = config 13 | self.model_constructor = model_constructor 14 | 15 | def _edit_loss_fn(pred, targ): 16 | return masked_log_probs(pred, targ, shift=shift_targets(self.config)) 17 | self.edit_loss_fn = _edit_loss_fn 18 | self.loc_loss_fn = _edit_loss_fn 19 | 20 | def edit(self, batch, condition=None, detach_history=False): 21 | raise NotImplementedError 22 | 23 | def forward(self, *inputs, **kwargs): 24 | return _logits(self.model(*inputs, **kwargs)) 25 | 26 | def outer_parameters(self): 27 | return self.parameters() 28 | 29 | def base_loss(self, input_ids, attention_masks, label_ids): 30 | pass 31 | -------------------------------------------------------------------------------- /hooks.py: 
-------------------------------------------------------------------------------- 1 | from utils import parent_module 2 | 3 | 4 | def linear_backward_hook(mod, grad_in, grad_out): 5 | if not hasattr(mod, "weight"): 6 | print(f"{mod} has no weight!") 7 | return 8 | 9 | if hasattr(mod.weight, "__x__"): 10 | assert len(grad_out) == 1 11 | # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2) 12 | mod.weight.__delta__ = grad_out[0].detach() 13 | else: 14 | print(f"{mod} has no __x__") 15 | 16 | 17 | def linear_forward_hook(mod, activations, output): 18 | assert len(activations) == 1 19 | mod.weight.__x__ = activations[0].detach() 20 | 21 | 22 | def hook_model(model, pnames): 23 | handles = [] 24 | for m in [parent_module(model, pname) for pname in pnames]: 25 | handles.append(m.register_full_backward_hook(linear_backward_hook)) 26 | handles.append(m.register_forward_hook(linear_forward_hook)) 27 | 28 | model.handles = handles 29 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def kl_loc_loss(pre, post, mask=None): 6 | pre = pre.to(torch.float32) 7 | post = post.to(torch.float32) 8 | 9 | sequence = pre.dim() == 3 10 | pre_ = pre.view(-1, pre.shape[-1]) 11 | post_ = post.view(pre_.shape) 12 | assert pre_.shape[0] == post_.shape[0] 13 | 14 | if not sequence: 15 | if pre_.shape[-1] == 1: # No masking needed for binary classification 16 | return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + ( 17 | (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post)) 18 | ).mean() 19 | else: # We have sequences of predictions; masking needed 20 | if pre_.shape[-1] > 1: 21 | assert mask is not None 22 | mask_ = mask.view(pre_.shape[0]) 23 | kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1) 24 | return (kl * mask_).sum() / mask_.sum() 25 | 26 | raise NotImplementedError 27 | 28 | 29 | def binary_log_probs(pred, targ): 30 | neg_mask = torch.ones_like(pred) 31 | neg_mask[targ == 0] *= -1 32 | pred = pred * neg_mask 33 | log_probs = F.logsigmoid(pred) 34 | acc = (log_probs.exp() > 0.5).float().mean() 35 | return { 36 | "acc": acc, 37 | "log_prob": log_probs.mean(), 38 | "prob": log_probs.exp().mean(), 39 | "nll": -log_probs.mean(), 40 | "n_tokens": log_probs.shape[0] 41 | } 42 | 43 | 44 | def multiclass_log_probs(pred, targ, shift=True): 45 | NULL_TOKEN = 0 # a placeholder used for masked target locations 46 | 47 | pred = pred.clone() 48 | targ = targ.clone() 49 | if shift and pred.dim() == 3: # Dealing with sequences 50 | pred = pred[:, :-1] # Remove last prediction in sequence 51 | targ = targ[:, 1:] # Shift to align predictions and targets 52 | 53 | mask = targ != -100 54 | targ[~mask] = NULL_TOKEN # Can be any valid token, since we'll throw them out 55 | unmasked_log_probs = pred.log_softmax(-1).gather(-1, targ.unsqueeze(-1)).squeeze(-1) 56 | 57 | pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN) 58 | correct = pred_ids == targ 59 | if pred.dim() == 3: 60 | correct = (pred_ids == targ).all(-1) # We want to get the whole sequence right 61 | acc = correct.float().mean() 62 | 63 | n_tokens = mask.float().sum() 64 | log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens 65 | prob = (unmasked_log_probs.exp() * mask.float()).sum() / n_tokens 66 | return { 67 | "acc": acc, 68 | "log_prob": log_prob, 69 | "prob": prob, 70 | 
"n_tokens": n_tokens, 71 | "nll": -log_prob 72 | } 73 | 74 | 75 | def masked_log_probs(pred, targ, shift=True): 76 | pred = pred.to(torch.float32) 77 | 78 | if not (pred.dim() == 2 or pred.dim() == 3): 79 | raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}") 80 | 81 | if pred.shape[-1] == 1: 82 | return binary_log_probs(pred, targ) 83 | else: 84 | return multiclass_log_probs(pred, targ, shift=shift) 85 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import torch.nn as nn 4 | import re 5 | import logging 6 | from utils import scr 7 | 8 | 9 | LOG = logging.getLogger(__name__) 10 | 11 | 12 | class CastModule(nn.Module): 13 | def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None): 14 | super().__init__() 15 | 16 | self.underlying = module 17 | self.in_cast = in_cast 18 | self.out_cast = out_cast 19 | 20 | def cast(self, obj, dtype): 21 | if dtype is None: 22 | return obj 23 | 24 | if isinstance(obj, torch.Tensor): 25 | return obj.to(dtype) 26 | else: 27 | return obj 28 | 29 | def forward(self, *args, **kwargs): 30 | args = tuple(self.cast(a, self.in_cast) for a in args) 31 | kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()} 32 | outputs = self.underlying(*args, **kwargs) 33 | if isinstance(outputs, torch.Tensor): 34 | outputs = self.cast(outputs, self.out_cast) 35 | elif isinstance(outputs, tuple): 36 | outputs = tuple(self.cast(o, self.out_cast) for o in outputs) 37 | else: 38 | raise RuntimeError(f"Not sure how to cast type {type(outputs)}") 39 | return outputs 40 | 41 | def extra_repr(self): 42 | return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}" 43 | 44 | 45 | class BertClassifier(torch.nn.Module): 46 | def __init__(self, model_name, hidden_dim=768): 47 | super().__init__() 48 | self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr()) 49 | self.classifier = torch.nn.Linear(hidden_dim, 1) 50 | 51 | @property 52 | def config(self): 53 | return self.model.config 54 | 55 | def forward(self, *args, **kwargs): 56 | filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"} 57 | return self.classifier(self.model(*args, **filtered_kwargs)[1]) 58 | 59 | 60 | def get_model(config): 61 | if config.model.class_name == "BertClassifier": 62 | model = BertClassifier(config.model.name) 63 | else: 64 | ModelClass = getattr(transformers, config.model.class_name) 65 | LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}") 66 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr()) 67 | 68 | if config.model.pt is not None: 69 | LOG.info(f"Loading model initialization from {config.model.pt}") 70 | state_dict = torch.load(config.model.pt, map_location="cpu") 71 | 72 | try: 73 | model.load_state_dict(state_dict) 74 | except RuntimeError: 75 | LOG.info("Default load failed; stripping prefix and trying again.") 76 | state_dict = {re.sub("^model.", "", k): v for k, v in state_dict.items()} 77 | 78 | model.load_state_dict(state_dict) 79 | 80 | LOG.info("Loaded model initialization") 81 | 82 | if config.dropout is not None: 83 | n_reset = 0 84 | for m in model.modules(): 85 | if isinstance(m, nn.Dropout): 86 | m.p = config.dropout 87 | n_reset += 1 88 | 89 | if hasattr(m, "dropout"): # Requires for BART, which uses F.dropout 90 | if 
isinstance(m.dropout, float): 91 | m.dropout = config.dropout 92 | n_reset += 1 93 | 94 | if hasattr(m, "activation_dropout"): # Requires for BART, which uses F.dropout 95 | if isinstance(m.activation_dropout, float): 96 | m.activation_dropout = config.dropout 97 | n_reset += 1 98 | 99 | LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}") 100 | 101 | param_names = [n for n, _ in model.named_parameters()] 102 | bad_inner_params = [p for p in config.model.inner_params if p not in param_names] 103 | if len(bad_inner_params) != 0: 104 | raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.") 105 | 106 | if config.no_grad_layers is not None: 107 | if config.half: 108 | model.bfloat16() 109 | 110 | def upcast(mod): 111 | modlist = None 112 | for child in mod.children(): 113 | if isinstance(child, nn.ModuleList): 114 | assert modlist is None, f"Found multiple modlists for {mod}" 115 | modlist = child 116 | if modlist is None: 117 | raise RuntimeError("Couldn't find a ModuleList child") 118 | 119 | LOG.info(f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting") 120 | modlist[config.no_grad_layers:].to(torch.float32) 121 | modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers]) 122 | modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16) 123 | 124 | parents = [] 125 | if hasattr(model, "transformer"): 126 | parents.append(model.transformer) 127 | if hasattr(model, "encoder"): 128 | parents.append(model.encoder) 129 | if hasattr(model, "decoder"): 130 | parents.append(model.decoder) 131 | if hasattr(model, "model"): 132 | parents.extend([model.model.encoder, model.model.decoder]) 133 | 134 | for t in parents: 135 | t.no_grad_layers = config.no_grad_layers 136 | if config.half: 137 | upcast(t) 138 | 139 | if config.half: 140 | idxs = [] 141 | for p in config.model.inner_params: 142 | for comp in p.split('.'): 143 | if comp.isdigit(): 144 | idxs.append(int(comp)) 145 | max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers) 146 | for pidx, p in enumerate(config.model.inner_params): 147 | comps = p.split('.') 148 | if max_idx in comps or min_idx in comps: 149 | index = comps.index(max_idx) if max_idx in comps else comps.index(min_idx) 150 | comps.insert(index + 1, 'underlying') 151 | new_p = '.'.join(comps) 152 | LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'") 153 | config.model.inner_params[pidx] = new_p 154 | 155 | return model 156 | 157 | 158 | def get_tokenizer(config): 159 | tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name 160 | return getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr()) 161 | 162 | 163 | if __name__ == '__main__': 164 | m = BertClassifier("bert-base-uncased") 165 | m(torch.arange(5)[None, :]) 166 | import pdb; pdb.set_trace() 167 | -------------------------------------------------------------------------------- /nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import logging 5 | 6 | LOG = logging.getLogger(__name__) 7 | 8 | 9 | class IDMLP(nn.Module): 10 | def __init__( 11 | self, 12 | indim: int, 13 | outdim: int, 14 | hidden_dim: int, 15 | n_hidden: int, 16 | init: str = None, 17 | act: str = None, 18 | rank: int = None, 19 | n_modes: int = None 20 | ): 21 | super().__init__() 22 | LOG.info(f"Building IDMLP ({init}) 
{[indim] * (n_hidden + 2)}") 23 | self.layers = nn.ModuleList( 24 | [ 25 | LRLinear(indim, indim, rank=rank, relu=idx < n_hidden, init=init, n_modes=n_modes) 26 | for idx in range(n_hidden + 1) 27 | ] 28 | ) 29 | 30 | def forward(self, x, mode=None): 31 | for layer in self.layers: 32 | x = layer(x, mode=mode) 33 | 34 | return x 35 | 36 | 37 | class LRLinear(nn.Module): 38 | def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None): 39 | super().__init__() 40 | 41 | mid_dim = min(rank, inf) 42 | if init == "id": 43 | self.u = nn.Parameter(torch.zeros(outf, mid_dim)) 44 | self.v = nn.Parameter(torch.randn(mid_dim, inf)) 45 | elif init == "xavier": 46 | self.u = nn.Parameter(torch.empty(outf, mid_dim)) 47 | self.v = nn.Parameter(torch.empty(mid_dim, inf)) 48 | nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu")) 49 | nn.init.xavier_uniform_(self.v.data, gain=1.0) 50 | else: 51 | raise ValueError(f"Unrecognized initialization {init}") 52 | 53 | if n_modes is not None: 54 | self.mode_shift = nn.Embedding(n_modes, outf) 55 | self.mode_shift.weight.data.zero_() 56 | self.mode_scale = nn.Embedding(n_modes, outf) 57 | self.mode_scale.weight.data.fill_(1) 58 | 59 | self.n_modes = n_modes 60 | self.bias = nn.Parameter(torch.zeros(outf)) 61 | self.inf = inf 62 | self.init = init 63 | 64 | def forward(self, x, mode=None): 65 | if mode is not None: 66 | assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it" 67 | assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}" 68 | assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})" 69 | 70 | pre_act = (self.u @ (self.v @ x.T)).T 71 | if self.bias is not None: 72 | pre_act += self.bias 73 | 74 | if mode is not None: 75 | if not isinstance(mode, torch.Tensor): 76 | mode = torch.tensor(mode).to(x.device) 77 | scale, shift = self.mode_scale(mode), self.mode_shift(mode) 78 | pre_act = pre_act * scale + shift 79 | 80 | # need clamp instead of relu so gradient at 0 isn't 0 81 | acts = pre_act.clamp(min=0) 82 | if self.init == "id": 83 | return acts + x 84 | else: 85 | return acts 86 | 87 | 88 | class MLP(nn.Module): 89 | def __init__( 90 | self, 91 | indim: int, 92 | outdim: int, 93 | hidden_dim: int, 94 | n_hidden: int, 95 | init: str = "xavier_uniform", 96 | act: str = "relu", 97 | rank: int = None, 98 | ): 99 | super().__init__() 100 | 101 | self.init = init 102 | 103 | if act == "relu": 104 | self.act = nn.ReLU() 105 | elif act == "learned": 106 | self.act = ActMLP(10, 1) 107 | else: 108 | raise ValueError(f"Unrecognized activation function '{act}'") 109 | 110 | if hidden_dim is None: 111 | hidden_dim = outdim * 2 112 | 113 | if init.startswith("id") and outdim != indim: 114 | LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})") 115 | outdim = indim 116 | 117 | if init == "id": 118 | old_hidden_dim = hidden_dim 119 | if hidden_dim < indim * 2: 120 | hidden_dim = indim * 2 121 | 122 | if hidden_dim % indim != 0: 123 | hidden_dim += hidden_dim % indim 124 | 125 | if old_hidden_dim != hidden_dim: 126 | LOG.info( 127 | f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}" 128 | ) 129 | 130 | if init == "id_alpha": 131 | self.alpha = nn.Parameter(torch.zeros(1, outdim)) 132 | 133 | dims = [indim] + [hidden_dim] * n_hidden + [outdim] 134 | LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})") 135 | 136 | layers = [] 137 | for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])): 138 | if rank is None: 139 | 
layers.append(nn.Linear(ind, outd)) 140 | else: 141 | layers.append(LRLinear(ind, outd, rank=rank)) 142 | if idx < n_hidden: 143 | layers.append(self.act) 144 | 145 | if rank is None: 146 | if init == "id": 147 | if n_hidden > 0: 148 | layers[0].weight.data = torch.eye(indim).repeat( 149 | hidden_dim // indim, 1 150 | ) 151 | layers[0].weight.data[hidden_dim // 2:] *= -1 152 | layers[-1].weight.data = torch.eye(outdim).repeat( 153 | 1, hidden_dim // outdim 154 | ) 155 | layers[-1].weight.data[:, hidden_dim // 2:] *= -1 156 | layers[-1].weight.data /= (hidden_dim // indim) / 2.0 157 | 158 | for layer in layers: 159 | if isinstance(layer, nn.Linear): 160 | if init == "ortho": 161 | nn.init.orthogonal_(layer.weight) 162 | elif init == "id": 163 | if layer.weight.shape[0] == layer.weight.shape[1]: 164 | layer.weight.data = torch.eye(hidden_dim) 165 | else: 166 | gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0 167 | nn.init.xavier_uniform_(layer.weight, gain=gain) 168 | 169 | layer.bias.data[:] = 0 170 | 171 | layers[-1].bias = None 172 | self.mlp = nn.Sequential(*layers) 173 | 174 | def forward(self, x): 175 | if self.init == "id_alpha": 176 | return x + self.alpha * self.mlp(x) 177 | else: 178 | return self.mlp(x) 179 | 180 | 181 | if __name__ == "__main__": 182 | logging.basicConfig( 183 | format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s", 184 | level=logging.INFO, 185 | ) 186 | m0 = MLP(1000, 1000, 1500, 3) 187 | m1 = MLP(1000, 1000, 1500, 3, init="id") 188 | m2 = MLP(1000, 1000, 1500, 3, init="id_alpha") 189 | m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned") 190 | 191 | x = 0.01 * torch.randn(999, 1000) 192 | 193 | y0 = m0(x) 194 | y1 = m1(x) 195 | y2 = m2(x) 196 | y3 = m3(x) 197 | 198 | print("y0", (y0 - x).abs().max()) 199 | print("y1", (y1 - x).abs().max()) 200 | print("y2", (y2 - x).abs().max()) 201 | print("y3", (y3 - x).abs().max()) 202 | 203 | assert not torch.allclose(y0, x) 204 | assert torch.allclose(y1, x) 205 | assert torch.allclose(y2, x) 206 | assert not torch.allclose(y3, x) 207 | import pdb; pdb.set_trace() # fmt: skip 208 | -------------------------------------------------------------------------------- /oracle.py: -------------------------------------------------------------------------------- 1 | from higher.patch import monkeypatch as make_functional 2 | from copy import deepcopy 3 | import torch 4 | import torch.nn as nn 5 | from losses import kl_loc_loss, masked_log_probs 6 | 7 | 8 | def test_rank1(model, dataset, config): 9 | model.eval() 10 | generator = dataset.edit_generator(21) 11 | 12 | history = [] 13 | for example in generator: 14 | edit_model = make_functional(model, track_higher_grads=False) 15 | residuals = {} 16 | opt_list = [] 17 | print(config.model.inner_params) 18 | for n, p in edit_model.named_parameters(): 19 | if n in config.model.inner_params: 20 | std = 0.01 21 | u = nn.Parameter(torch.randn(p.shape[0], 1, device=p.device) * std) 22 | v = nn.Parameter(torch.randn(1, p.shape[1], device=p.device) * std) 23 | assert (u@v).shape == p.shape, f"got {(u@v).shape}, expected {p.shape}" 24 | 25 | residuals[n] = (u,v) 26 | opt_list.extend([u,v]) 27 | 28 | res_opt = torch.optim.SGD(opt_list, lr=100) 29 | 30 | acc = 0 31 | it = 0 32 | ids_train = example["loc_ids"][:10] 33 | ids_val = example["loc_ids"][10:] 34 | with torch.inference_mode(): 35 | original_logits_train = model(ids_train) 36 | original_logits_val = model(ids_val) 37 | if hasattr(original_logits_train, "logits"): 38 | original_logits_train = 
original_logits_train.logits 39 | original_logits_val = original_logits_val.logits 40 | 41 | while acc < 1 and it < 1000: 42 | fast_params = [] 43 | for n, p in edit_model.named_parameters(): 44 | if n in residuals: 45 | u,v = residuals[n] 46 | fast_params.append(p.detach() + (u @ v)) 47 | else: 48 | fast_params.append(p.detach()) 49 | 50 | loc_pred = edit_model(ids_train, params=fast_params) 51 | if hasattr(loc_pred, "logits"): 52 | loc_pred = loc_pred.logits 53 | 54 | loc_loss = kl_loc_loss(original_logits_train, loc_pred) 55 | 56 | pred_log = edit_model(example["edit_inner_ids"], params=fast_params) 57 | if hasattr(pred_log, "logits"): 58 | pred_log = pred_log.logits 59 | prob_dict = masked_log_probs(pred_log, example["edit_inner_labels"]) 60 | edit_loss = prob_dict["nll"] 61 | acc = prob_dict["acc"] 62 | 63 | loss = loc_loss + 0.0002 * edit_loss 64 | with torch.inference_mode(): 65 | loc_pred_val = edit_model(ids_val, params=fast_params) 66 | if hasattr(loc_pred_val, "logits"): 67 | loc_pred_val = loc_pred_val.logits 68 | 69 | if pred_log.dim() == 3: 70 | facc = (pred_log.argmax(-1)[0,-10:-1] == example["edit_inner_labels"][0,-9:]).float().mean() 71 | ret = (original_logits_val.argmax(-1) == loc_pred_val.argmax(-1)).float().mean() 72 | else: 73 | facc = (pred_log > 0) == example["edit_inner_labels"] 74 | ret = ((original_logits_val > 0) == (loc_pred_val > 0)).float().mean() 75 | 76 | print(f"{it}, ({loss.item():.6f}, {loc_loss.item():.4f}, {edit_loss.item():.4f}), {facc.item():.2f}, {ret.item():.4f} {(u@v).view(-1).norm().item():.5f}", end="\r") 77 | 78 | for p, g in zip(opt_list, torch.autograd.grad(loss, opt_list)): 79 | p.grad = g 80 | res_opt.step() 81 | res_opt.zero_grad() 82 | 83 | it += 1 84 | 85 | if acc == 1: 86 | history.append(1) 87 | else: 88 | history.append(0) 89 | 90 | print() 91 | print(len(history), sum(history)/len(history), ret.item()) 92 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | numpy 3 | torch 4 | click==7.1.2 # Spacy breaks for click>=8.0 5 | spacy 6 | allennlp 7 | git+https://github.com/eric-mitchell/higher@master # For in-place functional models 8 | git+https://github.com/eric-mitchell/transformers@master # To enable gradient disabling for some models 9 | datasets 10 | jsonlines 11 | wandb 12 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import importlib 4 | import logging 5 | 6 | import hydra 7 | from omegaconf import OmegaConf 8 | import numpy as np 9 | import torch 10 | import utils 11 | 12 | 13 | from trainer import EditTrainer 14 | import models 15 | 16 | 17 | OmegaConf.register_new_resolver("uuid", lambda: utils.uuid()) 18 | 19 | 20 | logging.basicConfig(format='%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s', 21 | level=logging.INFO) 22 | LOG = logging.getLogger(__name__) 23 | 24 | 25 | def add_padding(tokenizer, model): 26 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 27 | model.resize_token_embeddings(len(tokenizer)) 28 | model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0) 29 | 30 | 31 | @hydra.main(config_path='config', config_name='config') 32 | def run(config): 33 | LOG.info(f"\n\n{OmegaConf.to_yaml(config)}\n") 34 | base_dir = hydra.utils.get_original_cwd() 35 | 
LOG.info(f"Project base directory: {base_dir}") 36 | 37 | random.seed(config.seed) 38 | np.random.seed(config.seed) 39 | torch.manual_seed(config.seed) 40 | 41 | model = models.get_model(config) 42 | tokenizer = models.get_tokenizer(config) 43 | 44 | if config.task == "gen" or config.task == "wiki": 45 | add_padding(tokenizer, model) 46 | from data_classes.wiki import GenDataset 47 | 48 | train_set = GenDataset("train", tokenizer, config, config.data.path, pct=10) 49 | val_set = GenDataset("validation", tokenizer, config, config.data.path, pct=10) 50 | elif config.task == "fc" or config.task == "fever": 51 | from data_classes.fever import BinaryAugmentedKILT 52 | 53 | train_set = BinaryAugmentedKILT(tokenizer, f"{base_dir}/data/fever/fever-train-kilt.jsonl", config) 54 | val_set = BinaryAugmentedKILT(tokenizer, f"{base_dir}/data/fever/fever-dev-kilt.jsonl", config) 55 | elif config.task == "qa" or config.task == "zsre": 56 | from data_classes.zsre import Seq2SeqAugmentedKILT 57 | 58 | train_set = Seq2SeqAugmentedKILT(tokenizer, f"{base_dir}/data/zsre/structured_zeroshot-train-new_annotated_final.jsonl", 59 | config) 60 | val_set = Seq2SeqAugmentedKILT(tokenizer, f"{base_dir}/data/zsre/structured_zeroshot-dev-new_annotated_final.jsonl", 61 | config) 62 | else: 63 | raise ValueError(f"Unrecognized task {config.task}") 64 | 65 | alg_module = importlib.import_module(f"algs.{config.alg}") 66 | LOG.info(f"Loading class {config.alg.upper()} from module {alg_module}") 67 | AlgClass = getattr(alg_module, config.alg.upper()) 68 | alg = AlgClass(model, config, lambda: copy.deepcopy(model)) 69 | 70 | if config.alg == "ft" and config.ft.locality.enabled: 71 | if config.ft.locality.oracle: 72 | alg.loc_sampler = train_set.edit_generator(config.ft.locality.batch_size + 1) 73 | else: 74 | state = np.random.get_state() 75 | np.random.seed(0) 76 | loc_batch = next(train_set.edit_generator(config.ft.locality.batch_size + 1))["loc"] 77 | np.random.set_state(state) 78 | alg.loc_ids = loc_batch["input_ids"] 79 | alg.loc_masks = loc_batch["attention_mask"] 80 | 81 | trainer = EditTrainer(alg, config, train_set, val_set) 82 | trainer.run() 83 | 84 | 85 | if __name__ == "__main__": 86 | run() 87 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import tempfile 5 | import time 6 | import json 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | from omegaconf import OmegaConf 11 | 12 | import wandb 13 | 14 | from losses import kl_loc_loss 15 | import utils 16 | from utils import _logits, safe_backward, RunningStatAverager, EarlyStopper, formatted_timestamp, time_delta_seconds 17 | 18 | 19 | LOG = logging.getLogger(__name__) 20 | 21 | 22 | class BaseTrainer: 23 | def __init__(self, model, config, train_set: Dataset, val_set: Dataset): 24 | self.model = model 25 | self.config = config 26 | 27 | if config.train_base: 28 | self.original_model = self.model.model_constructor() 29 | self.original_model.load_state_dict(self.model.model.state_dict()) 30 | self.original_model.to(self.config.device) 31 | else: 32 | self.original_model = self.model.model 33 | 34 | self.model.to(self.config.device) 35 | 36 | self.train_set = train_set 37 | self.val_set = val_set 38 | 39 | if self.config.eval_only: 40 | # Eval once and quit 41 | self.config.max_iters = 0 42 | 43 | if not self.config.eval_only: 44 | self.OptimizerClass = 
getattr(torch.optim, config.opt) 45 | LOG.info(f"Building optimizer {self.OptimizerClass} with lr {config.lr}") 46 | self.opt = self.OptimizerClass(self.model.outer_parameters(), lr=config.lr) 47 | 48 | if config.archive is not None: 49 | archive, config.archive = utils.load_archive(str(config.archive)) 50 | self.model.load_state_dict(archive["model"]) 51 | del archive["model"] 52 | if not self.config.eval_only: 53 | self.opt.load_state_dict(archive["opt"]) 54 | del archive["opt"] 55 | 56 | self.archive = archive # Save for later to load e.g. lr_opt params if they exist 57 | else: 58 | self.archive = None 59 | 60 | # outfiles 61 | with open(os.getcwd() + "/config.json", "w") as f: 62 | json.dump(OmegaConf.to_container(config), f) 63 | 64 | model_dir = os.path.join(os.getcwd(), 'models') 65 | if not (self.config.debug and not self.config.save): 66 | os.makedirs(model_dir) 67 | run_date = os.getcwd().split('/')[-1] 68 | self.run_date = run_date 69 | safe_model_name = self.config.model.name.split("/")[-1] # Make sure no slashes 70 | self.save_path = f"{model_dir}/{safe_model_name}.{run_date}" 71 | 72 | if not (self.config.debug or self.config.eval_only): 73 | wandb_dir = tempfile.mkdtemp() 74 | wandb_name = f"{self.config.dataset} - {self.config.alg} - {safe_model_name} - {run_date}" 75 | if self.config.ref is not None: 76 | wandb_name += f" - {self.config.ref}" 77 | LOG.info(f"Writing wandb run \"{wandb_name}\" to {wandb_dir}") 78 | wandb.init( 79 | project="efk", 80 | entity="patchable-lm", 81 | config=utils.flatten_dict(self.config), 82 | name=wandb_name, 83 | dir=wandb_dir, 84 | tags=[self.config.ref] if self.config.ref is not None else None 85 | ) 86 | 87 | self.start_time = formatted_timestamp() 88 | 89 | def save_state(self, stats): 90 | if (self.config.debug and not self.config.save) or self.config.eval_only: 91 | return 92 | 93 | obj = { 94 | "model": self.model.state_dict(), 95 | "opt": self.opt.state_dict(), 96 | "lr_opt": self.lr_opt.state_dict() if self.lr_opt is not None else None, 97 | "val_stats": stats, 98 | "start_time": self.start_time, 99 | "elapsed_time": time_delta_seconds(self.start_time), 100 | "step": self.global_iter 101 | } 102 | LOG.info(f"Saving model to {self.save_path}") 103 | 104 | if os.path.exists(self.save_path): 105 | bk_path = f"{self.save_path}.bk" 106 | LOG.info(f"Moving old archive to {bk_path}") 107 | os.rename(self.save_path, bk_path) 108 | 109 | torch.save(obj, self.save_path) 110 | LOG.info("Write complete.") 111 | 112 | def echo(self, train_step, info_dict, pretty=False): 113 | if not self.config.silent: 114 | sep = "\n" if pretty else "; " 115 | 116 | def key_format(k): 117 | return k.ljust(20) if pretty else k 118 | LOG.info(f"Step {train_step}:") 119 | LOG.info(sep.join([f"{key_format(k)}: {v: 0.5f}" for k, v in info_dict.items()])) 120 | 121 | def wandb_log(self, step, info_dict): 122 | if not (self.config.debug or self.config.eval_only): 123 | wandb.log(info_dict, step=step) 124 | 125 | def run(self): 126 | averager = RunningStatAverager("train") 127 | stopper = EarlyStopper(self.config.early_stop_patience, self.config.early_stop_key) 128 | self.global_iter = 0 129 | for global_iter in range(0, self.config.max_iters): 130 | self.global_iter = global_iter 131 | 132 | if not self.config.eval_only: 133 | train_info = self.train_step() 134 | averager.add(train_info) 135 | 136 | if global_iter % self.config.log_interval == 0: 137 | avg_info = averager.average() 138 | averager.reset() 139 | self.echo(global_iter, avg_info) 140 | 
self.wandb_log(global_iter, avg_info) 141 | 142 | if global_iter % self.config.val_interval == 0: 143 | val_info = self.validate(steps=self.config.val_steps) 144 | self.echo(global_iter, val_info) 145 | self.wandb_log(global_iter, val_info) 146 | 147 | if stopper.update(self.global_iter, val_info): 148 | self.save_state(val_info) # New best 149 | 150 | if stopper.should_stop(): 151 | LOG.info(f"No decrease in {self.config.early_stop_key} for {self.config.early_stop_patience} steps") 152 | break 153 | 154 | if not self.config.eval_only: 155 | LOG.info(f"Training complete after {self.global_iter+1} steps.") 156 | 157 | if not self.config.eval.final_eval: 158 | return 159 | 160 | if not self.config.eval_only: 161 | if (not self.config.debug) or self.config.save: 162 | archive = torch.load(self.save_path, map_location="cpu") 163 | LOG.info(f"Loading best model from step {archive['step']}, elapsed time {archive['elapsed_time']}") 164 | self.model.to("cpu") 165 | self.model.load_state_dict(archive["model"]) 166 | self.model.to(self.config.device) 167 | 168 | val_steps = 200 if self.config.debug else None 169 | val_info = self.validate(log=True, steps=val_steps) 170 | self.echo(self.global_iter, val_info, pretty=True) 171 | self.wandb_log(self.global_iter + self.config.val_interval, val_info) 172 | 173 | if self.config.results_dir is not None: 174 | results_path = f"{self.config.results_dir}/results_{self.run_date}.json" 175 | latest_path = f"{self.config.results_dir}/results_latest.json" 176 | else: 177 | results_path = f"{os.getcwd()}/results.json" 178 | latest_path = f"{os.getcwd()}/results_latest.json" 179 | 180 | with open(results_path, "w") as f: 181 | json.dump({"results": val_info, "config": OmegaConf.to_container(self.config)}, f) 182 | LOG.info("Wrote results to:") 183 | LOG.info(results_path) 184 | 185 | shutil.copy(results_path, latest_path) 186 | LOG.info("Copied to:") 187 | LOG.info(latest_path) 188 | 189 | 190 | class EditTrainer(BaseTrainer): 191 | def __init__(self, model, config, train_set: Dataset, val_set: Dataset): 192 | super().__init__(model, config, train_set, val_set) 193 | 194 | self.edit_gen = self.train_set.edit_generator(batch_size=config.batch_size) 195 | if hasattr(model, "edit_lrs") and not self.config.eval_only: 196 | self.lr_opt = self.OptimizerClass([model.edit_lrs], config.lr_lr) 197 | if self.archive is not None: 198 | self.lr_opt.load_state_dict(self.archive["lr_opt"]) 199 | else: 200 | self.lr_opt = None 201 | 202 | if hasattr(self.config, "ft"): 203 | if getattr(self.config.ft, "use_locality", False): 204 | batch = next(self.edit_gen) 205 | self.model.loc_ids = batch["loc"]["input_ids"] 206 | self.model.loc_masks = batch["loc"]["attention_mask"] 207 | 208 | def edit_step(self, batch, training: bool): 209 | self.model.train(training) 210 | self.original_model.train(training) 211 | 212 | with torch.no_grad(): 213 | base_logits = self.model(**batch["loc"]) 214 | 215 | # Do the edit 216 | start = time.time() 217 | edited_model, model_info = self.model.edit(batch["edit_inner"], batch["cond"]) 218 | edit_time = time.time() - start 219 | 220 | with torch.set_grad_enabled(training): 221 | # Editing loss 222 | post_edit_logits = edited_model(**batch["edit_outer"]) 223 | l_edit = self.model.edit_loss_fn(post_edit_logits, batch["edit_outer"]["labels"])["nll"] 224 | 225 | # Locality loss 226 | post_base_logits = edited_model(**batch["loc"]) 227 | kl_mask = batch["loc"].get("decoder_attention_mask", batch["loc"]["attention_mask"]) 228 | l_loc = 
kl_loc_loss(base_logits.detach(), post_base_logits, mask=kl_mask) 229 | 230 | l_total_edit = self.config.cedit * l_edit + self.config.cloc * l_loc 231 | 232 | if training: 233 | safe_backward(l_total_edit, self.model.outer_parameters(), self.config.accumulate_bs) 234 | 235 | # Collect some useful metrics 236 | with torch.no_grad(): 237 | post_edit_dict = self.model.edit_loss_fn(post_edit_logits, batch["edit_outer"]["labels"]) 238 | post_loc_dict = self.model.loc_loss_fn(post_base_logits, batch["loc"]["labels"]) 239 | pre_loc_dict = self.model.loc_loss_fn(base_logits, batch["loc"]["labels"]) 240 | 241 | info_dict = {} 242 | info_dict['loss/edit'] = l_edit.item() 243 | info_dict['loss/loc'] = l_loc.item() 244 | info_dict['edit/acc'] = post_edit_dict["acc"].item() 245 | info_dict['edit/log_prob'] = post_edit_dict["log_prob"].item() 246 | info_dict['edit/prob'] = post_edit_dict["prob"].item() 247 | info_dict["acc/pre"] = pre_loc_dict["acc"].item() 248 | info_dict["acc/post"] = post_loc_dict["acc"].item() 249 | info_dict["nll/pre"] = pre_loc_dict["nll"].item() 250 | info_dict["nll/post"] = post_loc_dict["nll"].item() 251 | info_dict["n_tokens/pre"] = post_loc_dict["n_tokens"] 252 | info_dict["n_tokens/post"] = post_loc_dict["n_tokens"] 253 | info_dict["time/edit"] = edit_time 254 | 255 | # Base loss 256 | if self.config.train_base: 257 | with torch.no_grad(): 258 | original_logits = _logits(self.original_model(**batch["loc"])) 259 | original_loc_dict = self.model.loc_loss_fn(original_logits, batch["loc"]["labels"]) 260 | 261 | base_logits = self.model(**batch["loc"]) 262 | l_base = kl_loc_loss(original_logits.detach(), base_logits, mask=kl_mask.detach()) 263 | 264 | if training: 265 | safe_backward(l_base, self.model.outer_parameters(), self.config.accumulate_bs, allow_unused=True) 266 | 267 | info_dict['loss/base'] = l_base.item() 268 | info_dict['nll/original'] = original_loc_dict["nll"].item() 269 | info_dict['acc/original'] = original_loc_dict["acc"].item() 270 | info_dict["n_tokens/original"] = original_loc_dict["n_tokens"] 271 | else: 272 | l_base = torch.tensor(0.) 
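# Note: the total objective assembled just below is l_total = l_total_edit + cbase * l_base,
# where l_total_edit = cedit * l_edit + cloc * l_loc was computed above and l_base stays
# zero unless config.train_base is enabled.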
273 | 274 | l_total = l_total_edit + self.config.cbase * l_base 275 | 276 | info_dict["loss/total"] = l_total.item() 277 | info_dict["loss/total_edit"] = l_total_edit.item() 278 | info_dict["memory/alloc_max"] = torch.cuda.max_memory_allocated() 279 | info_dict["memory/res_max"] = torch.cuda.max_memory_reserved() 280 | info_dict = {**info_dict, **model_info} 281 | 282 | return l_total, l_edit, l_loc, l_base, info_dict 283 | 284 | def train_step(self): 285 | l_total, l_edit, l_loc, l_base, info_dict = self.edit_step(next(self.edit_gen), training=True) 286 | 287 | if self.global_iter > 0 and self.global_iter % self.config.accumulate_bs == 0: 288 | grad = torch.nn.utils.clip_grad_norm_(self.model.outer_parameters(), self.config.grad_clip, 289 | error_if_nonfinite=True) 290 | info_dict['grad'] = grad.item() 291 | 292 | self.opt.step() 293 | self.opt.zero_grad() 294 | 295 | if self.lr_opt is not None: 296 | self.lr_opt.step() 297 | self.lr_opt.zero_grad() 298 | 299 | for lr_idx, lr in enumerate(self.model.edit_lrs): 300 | info_dict[f'lr/lr{lr_idx}'] = lr.item() 301 | 302 | return info_dict 303 | 304 | def _inline_validation_log(self, step, stats, start_time, steps): 305 | elapsed = (time.time() - start_time) / (step + 1) 306 | prog = f"{step+1}/{steps}".ljust(20) 307 | acc = f"{stats['edit/acc_val']:<12.5f}" 308 | if self.config.task in ["fc", "qa"]: 309 | draw_pre = f"{stats['acc/pre_val']:<12.5f}" 310 | draw_post = f"{stats['acc/post_val']:<12.5f}" 311 | draw_diff = f"{stats['acc/pre_val']-stats['acc/post_val']:<12.5f}" 312 | dn = "acc" # drawdown name 313 | elif self.config.task in ["gen"]: 314 | draw_pre = f"{stats['perplexity/pre_val']:<12.5f}" 315 | draw_post = f"{stats['perplexity/post_val']:<12.5f}" 316 | draw_diff = f"{stats['perplexity/post_val']-stats['perplexity/pre_val']:<12.5f}" 317 | dn = "ppl" # drawdown name 318 | else: 319 | raise RuntimeError(f"Didn't recognize task {self.config.task}") 320 | 321 | LOG.info(f"Step {prog} edit: {acc} {dn}_pre: {draw_pre} {dn}_post: {draw_post} {dn}_delta: {draw_diff} it_time: {elapsed:.4f}") 322 | 323 | def validate(self, steps=None, log: bool = False): 324 | if steps is None or steps > len(self.val_set): 325 | steps = len(self.val_set) 326 | 327 | if log: 328 | LOG.info(f"Beginning evaluation for {steps} steps...") 329 | averager = RunningStatAverager("val") 330 | val_edit_gen = self.val_set.edit_generator(batch_size=self.config.val_batch_size, n=steps) 331 | 332 | start_time = time.time() 333 | for val_step in range(steps): 334 | _, _, _, _, info_dict = self.edit_step(next(val_edit_gen), training=False) 335 | averager.add(info_dict) 336 | 337 | if log and self.config.eval.verbose and (val_step + 1) % self.config.eval.log_interval == 0: 338 | self._inline_validation_log(val_step, averager.average(), start_time, steps) 339 | 340 | if log and self.config.eval.verbose: 341 | self._inline_validation_log(val_step, averager.average(), start_time, steps) 342 | elapsed = time.time() - start_time 343 | stats = averager.average() 344 | stats["eval_time/elapsed"] = elapsed 345 | stats["eval_time/average"] = elapsed / steps 346 | 347 | return stats 348 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import typing 3 | import numpy as np 4 | import struct 5 | import os 6 | import getpass 7 | import hydra 8 | import logging 9 | import torch 10 | from collections import defaultdict 11 | import math 12 | 13 | 
14 | LOG = logging.getLogger(__name__) 15 | 16 | 17 | def _inner_params(named_parameters, inner_names): 18 | param_dict = dict(named_parameters) 19 | return [(n, param_dict[n]) for n in inner_names] 20 | 21 | 22 | def shift_targets(config): 23 | return "t5" not in config.model.name.lower() 24 | 25 | 26 | def scr(): 27 | if os.path.exists("/scr-ssd"): 28 | scr_dir = "/scr-ssd/" + getpass.getuser() 29 | else: 30 | scr_dir = "/scr/" + getpass.getuser() 31 | 32 | if not os.path.exists(scr_dir): 33 | os.makedirs(scr_dir) 34 | 35 | return scr_dir 36 | 37 | 38 | def uuid(digits=4): 39 | if not hasattr(uuid, "uuid_value"): 40 | uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits) 41 | 42 | return uuid.uuid_value 43 | 44 | 45 | def formatted_timestamp(time=None): 46 | if time is None: 47 | time = datetime.datetime.now() 48 | return time.strftime("%d/%m/%Y-%H:%M:%S/%f") 49 | 50 | 51 | def time_delta_seconds(start, finish=None): 52 | assert type(start) == str 53 | 54 | t1 = datetime.datetime.strptime(start, "%d/%m/%Y-%H:%M:%S/%f") 55 | if finish is not None: 56 | assert type(finish) == str 57 | t2 = datetime.datetime.strptime(finish, "%d/%m/%Y-%H:%M:%S/%f") 58 | else: 59 | t2 = datetime.datetime.now() 60 | 61 | return (t2 - t1).total_seconds() 62 | 63 | 64 | def dict_to(d, device): 65 | new_dict = {} 66 | for k, v in d.items(): 67 | if isinstance(v, torch.Tensor): 68 | new_dict[k] = v.to(device) 69 | elif isinstance(v, dict): 70 | new_dict[k] = dict_to(v, device) 71 | else: 72 | new_dict[k] = v 73 | 74 | return new_dict 75 | 76 | 77 | def safe_backward(loss, parameters, accumulate=1, allow_unused=False): 78 | parameters = list(parameters) # Capture the generator output 79 | grads = torch.autograd.grad(loss, parameters, allow_unused=allow_unused) 80 | nan, inf = False, False 81 | for g in grads: 82 | if g is not None: 83 | nan |= g.isnan().any().item() 84 | inf |= g.isinf().any().item() 85 | 86 | if not (nan or inf): 87 | for p, g in zip(parameters, grads): 88 | if g is None: 89 | continue 90 | 91 | if p.grad is None: 92 | p.grad = g / accumulate 93 | else: 94 | p.grad += g / accumulate 95 | else: 96 | LOG.info(f"Skipping grad accumulation because inf: {inf} nan: {nan}") 97 | 98 | 99 | def _logits(x): 100 | return x if not hasattr(x, "logits") else x.logits 101 | 102 | 103 | def load_archive(path): 104 | import torch 105 | 106 | if not os.path.exists(path): 107 | # We've not passed an explicit path, but a part of the filename 108 | wd = hydra.utils.get_original_cwd() 109 | directories = ["outputs", "multirun"] 110 | matches = [] 111 | for d in directories: 112 | search = os.path.join(wd, d) 113 | for run_dir in os.listdir(search): 114 | if path in run_dir: 115 | matches.append(os.path.join(search, run_dir)) 116 | assert len(matches) == 1, f">1 matches for search {path}; specify exact path" 117 | 118 | full_run_dir = matches[0] 119 | if "0" in os.listdir(full_run_dir): 120 | full_run_dir = os.path.join(full_run_dir, "0") 121 | models_dir = os.path.join(full_run_dir, "models") 122 | models = os.listdir(models_dir) 123 | non_bk = [m for m in models if not m.endswith(".bk")] 124 | assert ( 125 | len(non_bk) == 1 126 | ), f"Expected a single model in {models_dir}, got {len(non_bk)}" 127 | path = os.path.join(models_dir, non_bk[0]) 128 | 129 | LOG.info(f"Loading checkpoint from {path}") 130 | archive = torch.load(path, map_location="cpu") 131 | LOG.info("Load complete.") 132 | 133 | return archive, path 134 | 135 | 136 | def flatten_dict(d): 137 | to_process = list(d.items()) 138 | 
output = {} 139 | while len(to_process): 140 | k, v = to_process.pop() 141 | if isinstance(v, typing.MutableMapping): 142 | to_process.extend([(f"{k}.{k_}", v_) for (k_, v_) in v.items()]) 143 | else: 144 | assert k not in output.keys(), "Somehow ended up with duplicate keys" 145 | output[k] = v 146 | 147 | return output 148 | 149 | 150 | class EarlyStopper: 151 | def __init__(self, patience: int, key: str): 152 | self.best_value = 1e9 153 | self.best_iter = 0 154 | self.current_iter = 0 155 | self.key = key 156 | self.patience = patience 157 | self._stop = False 158 | 159 | def update(self, idx, stats): 160 | assert self.key in stats, f"'{self.key}' not in stats dict" 161 | value = stats[self.key] 162 | new_best = value < self.best_value 163 | if new_best: 164 | self.best_value = value 165 | self.best_iter = idx 166 | 167 | self.current_iter = idx 168 | return new_best 169 | 170 | def should_stop(self): 171 | self._stop |= self.current_iter - self.best_iter >= self.patience 172 | return self._stop 173 | 174 | 175 | class RunningStatAverager: 176 | def __init__(self, suffix="", exclude=["grad/"], compute_ppl: bool = True): 177 | self.underlying = None 178 | self.suffix = suffix 179 | self.exclude = exclude 180 | self.compute_ppl = compute_ppl 181 | 182 | self.reset() 183 | 184 | def add(self, d: dict): 185 | for k, v in d.items(): 186 | if not any([k.startswith(prefix) for prefix in self.exclude]): 187 | if len(self.suffix): 188 | self.underlying[f"{k}_{self.suffix}"].append(v) 189 | else: 190 | self.underlying[k].append(v) 191 | 192 | def average(self): 193 | average = {} 194 | for k, v in self.underlying.items(): 195 | if not k.startswith("nll/"): 196 | average[k] = sum(v) / len(v) 197 | else: 198 | assert len(k.split("/")) == 2, f"Invalid key {k}" 199 | name = k.split("/")[1] 200 | token_counts = self.underlying[f"n_tokens/{name}"] 201 | total_nll = sum([nll * c for nll, c in zip(v, token_counts)]) 202 | average[k] = total_nll / sum(token_counts) 203 | if self.compute_ppl: 204 | average[f"perplexity/{name}"] = math.e ** average[k] 205 | 206 | return {k: v if not isinstance(v, torch.Tensor) else v.item() for k, v in average.items()} 207 | 208 | def reset(self): 209 | self.underlying = defaultdict(list) 210 | 211 | 212 | class EditBatchSampler: 213 | def __init__(self, n, n_edits=1, memorize_mode=False, loc_disjoint=True, seed=0): 214 | self.memorize_mode = memorize_mode 215 | self.n = n 216 | self.n_edits = n_edits 217 | self.loc_disjoint = loc_disjoint 218 | self.rng = np.random.default_rng(seed) 219 | self._init() 220 | 221 | def _init(self): 222 | self.perm = self.rng.permutation(self.n) 223 | self.edit_position = 0 224 | 225 | def sample(self, batch_size): 226 | assert ( 227 | batch_size > self.n_edits 228 | ), "Batch size is interpreted such that batch_size = n_edits + n_loc" 229 | 230 | if self.memorize_mode: 231 | return list(range(self.n_edits)), list(range(batch_size - self.n_edits)) 232 | 233 | if self.edit_position >= self.n: 234 | self._init() 235 | 236 | edit_idxs = self.perm[self.edit_position: self.edit_position + self.n_edits] 237 | self.edit_position += self.n_edits 238 | 239 | loc_idxs = self.rng.choice(self.n, batch_size - self.n_edits) 240 | if self.loc_disjoint: 241 | while len(np.intersect1d(edit_idxs, loc_idxs)) > 0: 242 | loc_idxs = self.rng.choice(self.n, batch_size - self.n_edits) 243 | 244 | return edit_idxs.tolist(), loc_idxs.tolist() 245 | 246 | 247 | def parent_module(model, pname): 248 | comps = pname.split('.') 249 | parent = model 250 | for comp in 
comps[:-1]: 251 | if hasattr(parent, comp): 252 | parent = getattr(parent, comp) 253 | elif comp.isdigit(): 254 | parent = parent[int(comp)] 255 | else: 256 | raise RuntimeError(f"Couldn't find child module {comp}") 257 | assert hasattr(parent, comps[-1]) 258 | return parent 259 | 260 | 261 | if __name__ == '__main__': 262 | import random 263 | stopper = EarlyStopper(1000, "loss/edit") 264 | 265 | data = [(100 * idx, {"loss/edit": 2 ** (1 - idx / 10) + random.random()}) for idx in range(100)] 266 | 267 | for d in data: 268 | stopper.update(*d) 269 | print(stopper.current_iter, stopper.should_stop(), stopper.best_iter, d[1]["loss/edit"]) 270 | --------------------------------------------------------------------------------
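A minimal usage sketch for the batching helpers in `utils.py` above (illustrative, not part of the repository; it assumes `utils.py` is importable from the working directory and uses made-up sizes): `EditBatchSampler.sample` returns disjoint edit and locality indices, and `dict_to` recursively moves a nested batch onto the target device, which is how the data classes ship batches to `config.device`.

    import torch
    from utils import EditBatchSampler, dict_to

    # Draw 1 edit example and 5 locality examples from a 100-item dataset.
    sampler = EditBatchSampler(n=100, n_edits=1, loc_disjoint=True, seed=0)
    edit_idxs, loc_idxs = sampler.sample(batch_size=6)
    assert len(edit_idxs) == 1 and len(loc_idxs) == 5
    assert not set(edit_idxs) & set(loc_idxs)  # loc_disjoint=True keeps them separate

    # dict_to moves every tensor in a (possibly nested) dict to the given device.
    batch = {"edit_inner": {"input_ids": torch.zeros(1, 4, dtype=torch.long)}}
    batch = dict_to(batch, "cpu")  # the data classes pass config.device here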