├── ConceptVectors_data ├── llama2-7b_concepts.json ├── llama2-7b_concepts_dev.json ├── llama2-7b_concepts_test.json ├── olmo-7b_concepts.json ├── olmo-7b_concepts_dev.json ├── olmo-7b_concepts_test.json └── relation_for_KE │ ├── .idea │ ├── .gitignore │ ├── deployment.xml │ ├── inspectionProfiles │ │ ├── Project_Default.xml │ │ └── profiles_settings.xml │ ├── misc.xml │ ├── modules.xml │ ├── relation_for_KE.iml │ └── vcs.xml │ ├── label_prop_dict.json │ ├── llama_relation_object.json │ ├── llama_relation_object_dev.json │ ├── llama_relation_object_test.json │ ├── olmo_relation_object.json │ ├── olmo_relation_object_dev.json │ ├── olmo_relation_object_test.json │ └── relation_to_template.json ├── Concept_Validation_Experiments ├── .ipynb_checkpoints │ ├── Concept_Validation_Experiment-checkpoint.ipynb │ └── olmo_Concept_Validation_Experiment-checkpoint.ipynb └── Concept_Validation_Experiments.ipynb ├── Jailbreak ├── __pycache__ │ └── evaluate_util.cpython-38.pyc ├── evaluate_util.py ├── jailbreak.ipynb ├── llama_jailbreak_German_qa.json └── olmo_jailbreak_German_qa.json ├── LICENSE ├── README.md ├── all_forget_llama.sh ├── all_forget_olmo.sh ├── config ├── ds_config.json ├── forget.yaml └── model_config.yaml ├── data_module.py ├── dataloader.py ├── evaluate_llama.py ├── evaluate_olmo.py ├── evaluate_util.py ├── forget.py ├── memit ├── .gitattributes ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── baselines │ ├── README.md │ ├── ft │ │ ├── __init__.py │ │ ├── ft_hparams.py │ │ └── ft_main.py │ └── mend │ │ ├── README.md │ │ ├── __init__.py │ │ ├── algs │ │ ├── enn.py │ │ ├── ft.py │ │ └── mend.py │ │ ├── config │ │ ├── alg │ │ │ ├── efk.yaml │ │ │ ├── enn.yaml │ │ │ ├── ft.yaml │ │ │ └── mend.yaml │ │ ├── config.yaml │ │ ├── experiment │ │ │ ├── fc.yaml │ │ │ ├── gen.yaml │ │ │ └── qa.yaml │ │ └── model │ │ │ ├── bart-base.yaml │ │ │ ├── bert-base.yaml │ │ │ ├── distilgpt2.yaml │ │ │ ├── gpt2.yaml │ │ │ ├── gpt2large.yaml │ │ │ ├── gpt2medium.yaml │ │ │ ├── gpt2xl.yaml │ │ │ ├── gptj.yaml │ │ │ ├── gptneo27.yaml │ │ │ ├── t5large.yaml │ │ │ ├── t5small.yaml │ │ │ ├── t5xl.yaml │ │ │ └── t5xxl.yaml │ │ ├── data_classes │ │ ├── fever.py │ │ ├── nq.py │ │ ├── wiki.py │ │ └── zsre.py │ │ ├── editable_model.py │ │ ├── hooks.py │ │ ├── losses.py │ │ ├── mend_hparams.py │ │ ├── mend_main.py │ │ ├── models.py │ │ ├── nn.py │ │ ├── oracle.py │ │ ├── requirements.txt │ │ ├── run.py │ │ ├── trainer.py │ │ └── utils.py ├── dsets │ ├── __init__.py │ ├── attr_snippets.py │ ├── counterfact.py │ ├── knowns.py │ ├── tfidf_stats.py │ └── zsre.py ├── experiments │ ├── __init__.py │ ├── causal_trace.py │ ├── evaluate.py │ ├── evaluate_olmo.py │ ├── memit_jailbreak_evaluate.py │ ├── memit_jailbreak_evaluate_olmo.py │ ├── py │ │ ├── demo.py │ │ ├── eval_utils_counterfact.py │ │ └── eval_utils_zsre.py │ ├── summarize.py │ └── sweep.py ├── forget_memit.sh ├── forget_memit_olmo.sh ├── globals.yml ├── hparams │ ├── FT │ │ ├── EleutherAI_gpt-j-6B_constr.json │ │ ├── EleutherAI_gpt-j-6B_unconstr.json │ │ ├── EleutherAI_gpt-j-6B_wd.json │ │ ├── gpt2-large_constr.json │ │ ├── gpt2-medium_constr.json │ │ ├── gpt2-xl_attn.json │ │ ├── gpt2-xl_constr.json │ │ └── gpt2-xl_unconstr.json │ ├── MEMIT │ │ ├── EleutherAI_gpt-j-6B.json │ │ ├── gpt2-xl.json │ │ ├── llama2-7b.json │ │ └── olmo-7b.json │ ├── MEND │ │ ├── EleutherAI_gpt-j-6B.json │ │ ├── EleutherAI_gpt-j-6B_CF.json │ │ ├── gpt2-xl.json │ │ ├── gpt2-xl_CF.json │ │ └── gpt2-xl_zsRE.json │ └── ROME │ │ ├── EleutherAI_gpt-j-6B.json │ │ ├── 
EleutherAI_gpt-neox-20b.json │ │ ├── gpt2-large.json │ │ ├── gpt2-medium.json │ │ └── gpt2-xl.json ├── memit │ ├── __init__.py │ ├── compute_ks.py │ ├── compute_z.py │ ├── memit_hparams.py │ └── memit_main.py ├── notebooks │ ├── average_causal_effects.ipynb │ ├── causal_trace.ipynb │ ├── causal_trace_frozen_mlp_attn.ipynb │ ├── memit.ipynb │ └── vis │ │ ├── table_population.ipynb │ │ ├── table_population_zsre.ipynb │ │ ├── visualize_multi_results.ipynb │ │ └── visualize_sweep_results.ipynb ├── rome │ ├── README.md │ ├── __init__.py │ ├── compute_u.py │ ├── compute_v.py │ ├── layer_stats.py │ ├── repr_tools.py │ ├── rome_hparams.py │ ├── rome_main.py │ └── tok_dataset.py ├── scaling_curves.sh ├── scripts │ ├── causal_trace.sh │ ├── colab_reqs │ │ ├── additional.txt │ │ └── rome.txt │ ├── collect_layer_stats.sh │ ├── ipynb_drop_output.py │ ├── memit.yml │ ├── setup_clean_ipynb.sh │ └── setup_conda.sh ├── transformer_utils │ ├── README.md │ ├── __init__.py │ ├── requirements.txt │ ├── setup.py │ └── src │ │ └── transformer_utils │ │ ├── __init__.py │ │ ├── logit_lens │ │ ├── __init__.py │ │ ├── hooks.py │ │ ├── layer_names.py │ │ └── plotting.py │ │ ├── low_memory │ │ ├── __init__.py │ │ ├── enable.py │ │ ├── load.py │ │ └── load_context.py │ │ ├── partial_forward │ │ └── __init__.py │ │ └── util │ │ ├── __init__.py │ │ ├── module_utils.py │ │ ├── python_utils.py │ │ └── tfm_utils.py ├── util │ ├── __init__.py │ ├── generate.py │ ├── globals.py │ ├── hparams.py │ ├── logit_lens.py │ ├── nethook.py │ ├── perplexity.py │ └── runningstats.py └── zsre_evals.sh ├── requirements.txt └── utils.py /ConceptVectors_data/relation_for_KE/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 7 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/relation_for_KE.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 
| 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /ConceptVectors_data/relation_for_KE/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Jailbreak/__pycache__/evaluate_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/Jailbreak/__pycache__/evaluate_util.cpython-38.pyc -------------------------------------------------------------------------------- /all_forget_llama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Call the program in a loop, passing a different order parameter each time 4 | 5 | #Testing on ConceptVectors Test set of LLaMA 6 | for i in {0..94} #18 26 27 #{0..9} 7 | do 8 | python forget.py order=$i batch_size=4 gradient_accumulation_steps=8 num_epochs=1 gradient_checkpointing=True lr=2e-1 forget_loss=grad_ascent ft_type=Needle set=test 9 | done 10 | 11 | -------------------------------------------------------------------------------- /all_forget_olmo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Testing on ConceptVectors Test set of OLMo 4 | for i in {0..161} 5 | do 6 | python forget.py order=$i batch_size=4 gradient_accumulation_steps=8 num_epochs=1 gradient_checkpointing=False lr=2e-1 forget_loss=grad_ascent ft_type=Needle set=test 7 | done 8 | 9 | 10 | #olmo jailbreak: 4 37 40 44 59 77 90 105 141 147 -------------------------------------------------------------------------------- /config/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "offload_optimizer": { 5 | "device": "none", 6 | "pin_memory": true 7 | }, 8 | "offload_param": { 9 | "device": "none", 10 | "pin_memory": true 11 | }, 12 | "overlap_comm": true, 13 | "contiguous_gradients": true, 14 | "sub_group_size": 1e9, 15 | "reduce_bucket_size": "auto", 16 | "stage3_prefetch_bucket_size": "auto", 17 | "stage3_param_persistence_threshold": "auto", 18 | "stage3_max_live_parameters": 1e9, 19 | "stage3_max_reuse_distance": 1e9, 20 | "stage3_gather_16bit_weights_on_model_save": true 21 | }, 22 | "train_batch_size": "auto", 23 | "train_micro_batch_size_per_gpu": "auto", 24 | "gradient_accumulation_steps": "auto", 25 | "bf16": { 26 | "enabled": true 27 | } 28 | }
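The `"auto"` values in ds_config.json are only resolved when the file is passed through the Hugging Face `Trainer` DeepSpeed integration rather than handed to DeepSpeed directly. A minimal sketch of that wiring, with illustrative argument values (the repo's actual entry point is `forget.py`, which may configure this differently):

```python
from transformers import TrainingArguments

# The HF integration fills in every "auto" field of ds_config.json
# (batch sizes, bucket sizes) from these arguments when the Trainer is built.
training_args = TrainingArguments(
    output_dir="unlearn_results",       # illustrative output path
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    bf16=True,                          # matches "bf16": {"enabled": true}
    deepspeed="config/ds_config.json",  # the ZeRO-3 config shown above
)
```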
-------------------------------------------------------------------------------- /config/forget.yaml: -------------------------------------------------------------------------------- 1 | model_family: llama2-7b #olmo-7b 2 | model_path: /root/autodl-tmp/transformers/llama2-7b-chat-hf 3 | data_path: /root/Unlearn_Harry_Potter/Baselines/ConceptMap/ConceptMap_data 4 | results_save_path: /root/autodl-tmp/unlearn_results/${model_family}/${forget_loss} 5 | save_dir: ${model_path}/1GPU_${forget_loss}_${lr}_${split}_epoch${num_epochs}_batch${batch_size}_accum${gradient_accumulation_steps}_beta${beta}_ref${ref_policy}_eval${eval_steps}_seed${seed}_${run_index} 6 | 7 | set: test #dev or test 8 | order: 10 9 | LoRA: 10 | r: 0 11 | alpha: 32 12 | dropout: 0.05 13 | 14 | lr: 5e-5 #5e-4; lr should be larger for Needle 15 | split: wikipedia #wikipedia, pretraining_data 16 | batch_size: 1 17 | gradient_accumulation_steps: 16 18 | gradient_checkpointing: False #gradient_checkpointing is not supported for olmo-7b 19 | num_epochs: 10 20 | forget_loss: npo_KL # type: grad_ascent, grad_diff, npo, npo_grad_diff, npo_KL, dpo 21 | ft_type: Full #Full, Sparse, MEMIT, all_value_vectors, Needle, 22 | loss_threshold: -100 23 | 24 | npo_coeff: 1.0 25 | grad_diff_coeff: 1.0 26 | KL_coeff: 1.0 27 | ref_policy: fine_tuned 28 | beta: 0.1 29 | weight_decay: 0.01 30 | 31 | seed: 999 32 | run_index: 1 33 | overwrite_dir: True 34 | eval_steps: 1000000000 35 | warmup_steps: steps_per_epoch 36 | 37 | -------------------------------------------------------------------------------- /config/model_config.yaml: -------------------------------------------------------------------------------- 1 | llama2-7b: 2 | hf_key: "/root/autodl-tmp/transformers/llama2-7b-chat-hf" 3 | question_start_tag: "[INST] " 4 | question_end_tag: " [/INST]" 5 | answer_tag: "" 6 | flash_attention2: "false" 7 | gradient_checkpointing: "true" 8 | ft_model_path: "/root/autodl-tmp/transformers/final_ft_noLORA_5_epochs_inst_lr1e-05_llama2-7b_full/checkpoint-625" #this model will be used for unlearning by default 9 | olmo-7b: 10 | hf_key: "/root/autodl-tmp/transformers/OLMo-7B" 11 | question_start_tag: "Question: " 12 | question_end_tag: "\n" 13 | answer_tag: "Answer: " 14 | flash_attention2: "false" 15 | gradient_checkpointing: "false" 16 | ft_model_path: "/root/autodl-tmp/transformers/final_ft_noLORA_5_epochs_inst_lr1e-05_olmo-7b_full/checkpoint-625" 17 | 18 | 19 | 20 |
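The `question_start_tag` / `question_end_tag` / `answer_tag` fields above control how each question–answer pair is rendered before tokenization. A minimal sketch of how these tags are presumably combined (the helper name and the exact concatenation are illustrative and not taken from the repository's data module):

```python
def format_example(cfg: dict, question: str, answer: str) -> str:
    """Wrap a QA pair with the chat tags defined in model_config.yaml."""
    return (
        cfg["question_start_tag"] + question + cfg["question_end_tag"]
        + cfg["answer_tag"] + answer
    )

llama_cfg = {"question_start_tag": "[INST] ", "question_end_tag": " [/INST]", "answer_tag": ""}
olmo_cfg = {"question_start_tag": "Question: ", "question_end_tag": "\n", "answer_tag": "Answer: "}

print(format_example(llama_cfg, "Who wrote Harry Potter?", "J. K. Rowling"))
# [INST] Who wrote Harry Potter? [/INST]J. K. Rowling
print(format_example(olmo_cfg, "Who wrote Harry Potter?", "J. K. Rowling"))
# Question: Who wrote Harry Potter?
# Answer: J. K. Rowling
```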
-------------------------------------------------------------------------------- /memit/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=clean_ipynb 2 | -------------------------------------------------------------------------------- /memit/.gitignore: -------------------------------------------------------------------------------- 1 | # Pipeline dumps, data directory 2 | results 3 | data 4 | !notebooks/data 5 | *_tmp_*_.json 6 | *_kmeng01g1gn* 7 | 8 | # Pre-trained hypernetworks 9 | baselines/*/weights 10 | 11 | # Mac specific 12 | .idea 13 | .vscode 14 | .DS_Store 15 | 16 | # Latex 17 | *.aux 18 | *.dvi 19 | *.fdb_latexmk 20 | *.fls 21 | *.log 22 | *.pdf 23 | *.synctex.gz 24 | *.out 25 | *.toc 26 | *.nps 27 | 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | *.py.swp 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | share/python-wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | cover/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | .pybuilder/ 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | # For a library or package, you might want to ignore these files since the code is 115 | # intended to run in multiple environments; otherwise, check them in: 116 | # .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | -------------------------------------------------------------------------------- /memit/CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below."
3 | preferred-citation: 4 | type: article 5 | authors: 6 | - family-names: "Meng" 7 | given-names: "Kevin" 8 | - family-names: "Sen Sharma" 9 | given-names: "Arnab" 10 | - family-names: "Andonian" 11 | given-names: "Alex" 12 | - family-names: "Belinkov" 13 | given-names: "Yonatan" 14 | - family-names: "Bau" 15 | given-names: "David" 16 | journal: "arXiv preprint arXiv:2210.07229" 17 | title: "Mass-Editing Memory in a Transformer" 18 | year: 2022 -------------------------------------------------------------------------------- /memit/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kevin Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /memit/README.md: -------------------------------------------------------------------------------- 1 | # MEMIT: Mass-Editing Memory in a Transformer 2 | 3 | Editing thousands of facts into a transformer memory at once. 4 | 5 | 6 | 7 | ## Table of Contents 8 | 9 | - [Installation](#installation) 10 | - [MEMIT Algorithm Demo](#memit-algorithm-demo) 11 | - [Running the Full Evaluation Suite](#running-the-full-evaluation-suite) 12 | - [Generating Scaling Curves](#generating-scaling-curves) 13 | - [How to Cite](#how-to-cite) 14 | 15 | ## Installation 16 | 17 | We recommend `conda` for managing Python, CUDA, and PyTorch; `pip` is for everything else. To get started, simply install `conda` and run: 18 | ```bash 19 | CONDA_HOME=$CONDA_HOME ./scripts/setup_conda.sh 20 | ``` 21 | 22 | `$CONDA_HOME` should be the path to your `conda` installation, e.g., `~/miniconda3`. 23 | 24 | ## MEMIT Algorithm Demo 25 | 26 | [`notebooks/memit.ipynb`](notebooks/memit.ipynb) demonstrates MEMIT. The API is simple: specify a *requested rewrite* of the following form: 27 | 28 | ```python 29 | request = [ 30 | { 31 | "prompt": "{} plays the sport of", 32 | "subject": "LeBron James", 33 | "target_new": { 34 | "str": "football" 35 | } 36 | }, 37 | { 38 | "prompt": "{} plays the sport of", 39 | "subject": "Michael Jordan", 40 | "target_new": { 41 | "str": "baseball" 42 | } 43 | }, 44 | ] 45 | ``` 46 | 47 | Other similar examples are included in the notebook.
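The request list can also be applied outside the notebook. A minimal sketch, assuming the package exposes `MEMITHyperParams` and `apply_memit_to_model` (names inferred from the repository layout; check `memit/memit_main.py` for the exact signature):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from memit import MEMITHyperParams, apply_memit_to_model  # assumed exports

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").cuda()
tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tok.pad_token = tok.eos_token

# Hyperparameter file shipped with the repo for GPT-J (see hparams/MEMIT/).
hparams = MEMITHyperParams.from_json("hparams/MEMIT/EleutherAI_gpt-j-6B.json")

# Applies every edit in `request` at once; the returned dict holds the original
# weights of each modified matrix so the edit can be rolled back later.
model_new, orig_weights = apply_memit_to_model(
    model, tok, request, hparams, return_orig_weights=True
)
```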
48 | 49 | ## Running the Full Evaluation Suite 50 | 51 | [`experiments/evaluate.py`](experiments/evaluate.py) can be used to evaluate any method in [`baselines/`](baselines/). 52 | 53 | For example: 54 | ``` 55 | python3 -m experiments.evaluate \ 56 | --alg_name=MEMIT \ 57 | --model_name=EleutherAI/gpt-j-6B \ 58 | --hparams_fname=EleutherAI_gpt-j-6B.json \ 59 | --num_edits=10000 \ 60 | --use_cache 61 | ``` 62 | Results from each run are stored at `results/<method_name>/run_<run_id>` in a specific format: 63 | ```bash 64 | results/ 65 | |__ MEMIT/ 66 | |__ run_<run_id>/ 67 | |__ params.json 68 | |__ case_0.json 69 | |__ case_1.json 70 | |__ ... 71 | |__ case_10000.json 72 | ``` 73 | 74 | To summarize the results, you can use [`experiments/summarize.py`](experiments/summarize.py): 75 | ```bash 76 | python3 -m experiments.summarize --dir_name=MEMIT --runs=run_<run_id>,run_<run_id> 77 | ``` 78 | 79 | Running `python3 -m experiments.evaluate -h` or `python3 -m experiments.summarize -h` provides details about command-line flags. 80 | 81 | ## How to Cite 82 | 83 | ```bibtex 84 | @article{meng2022memit, 85 | title={Mass Editing Memory in a Transformer}, 86 | author={Kevin Meng and Sen Sharma, Arnab and Alex Andonian and Yonatan Belinkov and David Bau}, 87 | journal={arXiv preprint arXiv:2210.07229}, 88 | year={2022} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /memit/baselines/README.md: -------------------------------------------------------------------------------- 1 | We compare ROME against several open-source state-of-the-art model editors. All are implemented in their respective folders. Implementations other than FT/FT+L are adapted from third parties. 2 | - Fine-Tuning (`ft`): Direct fine-tuning. 3 | - Constrained Fine-Tuning (`ft`): FT with $L_\infty$ norm constraint. Inspired by Zhu et al. [[Paper]](https://arxiv.org/abs/2012.00363) 4 | - Knowledge Neurons (`kn`): Dai et al. [[Code]](https://github.com/EleutherAI/knowledge-neurons) [[Paper]](https://arxiv.org/abs/2104.08696) 5 | - Knowledge Editor (`efk`): De Cao et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2104.08164) 6 | - Model Editor Networks with Gradient Decomposition (`mend`): Mitchell et al.
[[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2110.11309) -------------------------------------------------------------------------------- /memit/baselines/ft/__init__.py: -------------------------------------------------------------------------------- 1 | from .ft_main import FTHyperParams, apply_ft_to_model, execute_ft 2 | -------------------------------------------------------------------------------- /memit/baselines/ft/ft_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from util.hparams import HyperParams 5 | 6 | 7 | @dataclass 8 | class FTHyperParams(HyperParams): 9 | # Method 10 | layers: List[int] 11 | num_steps: int 12 | lr: float 13 | weight_decay: float 14 | kl_factor: float 15 | norm_constraint: float 16 | 17 | # Module templates 18 | rewrite_module_tmp: str 19 | layer_module_tmp: str 20 | mlp_module_tmp: str 21 | attn_module_tmp: str 22 | ln_f_module: str 23 | lm_head_module: str 24 | 25 | # Defaults 26 | batch_size: int = 64 27 | wd_power_law: tuple = None # Scale weight decay by number of edits 28 | -------------------------------------------------------------------------------- /memit/baselines/ft/ft_main.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict, List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from util import nethook 8 | 9 | from .ft_hparams import FTHyperParams 10 | 11 | 12 | def apply_ft_to_model( 13 | model: AutoModelForCausalLM, 14 | tok: AutoTokenizer, 15 | requests: List[Dict], 16 | hparams: FTHyperParams, 17 | copy=False, 18 | return_orig_weights=False, 19 | **kwargs: Any, 20 | ) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]: 21 | """ 22 | Returns a model with the desired changes. 23 | :param copy: If true, will preserve the original model while creating a new one to edit. 24 | Note that you are responsible for deallocating the new model's memory to avoid leaks. 25 | :return: (1) the updated model, (2) the weights that changed 26 | """ 27 | 28 | weights_copy = {} 29 | if copy: 30 | model = deepcopy(model) 31 | 32 | deltas = execute_ft(model, tok, requests, hparams) 33 | 34 | with torch.no_grad(): 35 | for w_name, upd_matrix in deltas.items(): 36 | w = nethook.get_parameter(model, w_name) 37 | if return_orig_weights and w_name not in weights_copy: 38 | weights_copy[w_name] = w.detach().clone() 39 | 40 | w[...] 
+= upd_matrix 41 | 42 | print(f"New weights successfully inserted into {list(deltas.keys())}") 43 | 44 | return model, weights_copy 45 | 46 | 47 | def execute_ft( 48 | model: AutoModelForCausalLM, 49 | tok: AutoTokenizer, 50 | requests: List[Dict], 51 | hparams: FTHyperParams, 52 | **kwargs: Any, 53 | ) -> Dict[str, Tuple[torch.Tensor]]: 54 | """ 55 | Executes the FT update algorithm for the specified update at the specified layer 56 | Invariant: model at beginning of function == model at end of function 57 | """ 58 | 59 | # Update target and print info 60 | requests = deepcopy(requests) 61 | for request in requests: 62 | if request["target_new"]["str"][0] != " ": 63 | # Space required for correct tokenization 64 | request["target_new"]["str"] = " " + request["target_new"]["str"] 65 | print( 66 | f"Executing FT algo for: " 67 | f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]" 68 | ) 69 | 70 | # Retrieve weights that user desires to change 71 | weights = { 72 | n: p 73 | for n, p in model.named_parameters() 74 | for layer in hparams.layers 75 | if hparams.rewrite_module_tmp.format(layer) in n 76 | } 77 | # Save old weights for future restoration 78 | weights_copy = {k: v.detach().clone() for k, v in weights.items()} 79 | print(f"Weights to be updated: {list(weights.keys())}") 80 | 81 | # Define inputs 82 | texts = [r["prompt"].format(r["subject"]) for r in requests] 83 | targets = [r["target_new"]["str"] for r in requests] 84 | 85 | # Configure optimizer / gradients 86 | wd = ( 87 | hparams.weight_decay 88 | if not isinstance(hparams.wd_power_law, tuple) 89 | else (len(requests) ** hparams.wd_power_law[0]) 90 | * np.exp(hparams.wd_power_law[1]) 91 | ) 92 | print(f"Using weight decay of {wd} for {len(requests)} edits") 93 | opt = torch.optim.Adam( 94 | [v for _, v in weights.items()], 95 | lr=hparams.lr, 96 | weight_decay=wd, 97 | ) 98 | for name, w in model.named_parameters(): 99 | w.requires_grad = name in weights 100 | 101 | # Update loop: intervene at layers simultaneously 102 | loss_meter = AverageMeter() 103 | for it in range(hparams.num_steps): 104 | print(20 * "=") 105 | print(f"Epoch: {it}") 106 | print(20 * "=") 107 | loss_meter.reset() 108 | 109 | for txt, tgt in zip( 110 | chunks(texts, hparams.batch_size), chunks(targets, hparams.batch_size) 111 | ): 112 | inputs = tok(txt, return_tensors="pt", padding=True).to("cuda") 113 | target_ids = tok(tgt, return_tensors="pt", padding=True)["input_ids"].to( 114 | "cuda" 115 | ) 116 | last_token_inds = inputs["attention_mask"].sum(dim=1) - 1 117 | loss_mask = target_ids != tok.unk_token_id 118 | 119 | opt.zero_grad() 120 | bs = inputs["input_ids"].shape[0] 121 | probs = torch.nn.functional.log_softmax( 122 | model(**inputs).logits[torch.arange(bs), last_token_inds], dim=-1 123 | ) 124 | loss = -(torch.gather(probs, 1, target_ids) * loss_mask).sum( 125 | 1 126 | ) / loss_mask.sum(1) 127 | loss = loss.mean() 128 | print(f"Batch loss {loss.item()}") 129 | loss_meter.update(loss.item(), n=bs) 130 | 131 | if loss.item() >= 1e-2: 132 | loss.backward() 133 | opt.step() 134 | 135 | if type(hparams.norm_constraint) is float: 136 | eps = hparams.norm_constraint 137 | with torch.no_grad(): 138 | for k, v in weights.items(): 139 | v[...] 
= torch.clamp( 140 | v, min=weights_copy[k] - eps, max=weights_copy[k] + eps 141 | ) 142 | 143 | print(f"Total loss {loss_meter.avg}") 144 | 145 | if loss_meter.avg < 1e-2: 146 | break 147 | 148 | deltas = {k: (weights[k] - weights_copy[k]).detach() for k in weights} 149 | 150 | # Restore state of original model 151 | with torch.no_grad(): 152 | for k, v in weights.items(): 153 | v[...] = weights_copy[k] 154 | 155 | print(f"Deltas successfully computed for {list(weights.keys())}") 156 | 157 | return deltas 158 | 159 | 160 | def chunks(arr, n): 161 | """Yield successive n-sized chunks from arr.""" 162 | chunk = [] 163 | for a in arr: 164 | chunk.append(a) 165 | if len(chunk) == n: 166 | yield chunk 167 | chunk = [] 168 | if len(chunk) > 0: 169 | yield chunk 170 | 171 | 172 | class AverageMeter: 173 | """Computes and stores the average and current value""" 174 | 175 | def __init__(self): 176 | self.reset() 177 | 178 | def reset(self): 179 | self.val = 0 180 | self.avg = 0 181 | self.sum = 0 182 | self.count = 0 183 | 184 | def update(self, val, n=1): 185 | self.val = val 186 | self.sum += val * n 187 | self.count += n 188 | self.avg = self.sum / self.count 189 | -------------------------------------------------------------------------------- /memit/baselines/mend/README.md: -------------------------------------------------------------------------------- 1 | # MEND: Model Editing Networks using Gradient Decomposition 2 | 3 | If you run into any issues with the code, you can open an issue and/or email me at `eric.mitchell@cs.stanford.edu` 4 | 5 | ## Setup 6 | 7 | ### Environment 8 | 9 | This codebase uses Python 3.7.9. Other versions may work as well. 10 | 11 | Create a virtualenv ([pyenv](https://github.com/pyenv/pyenv) can help with this) 12 | and install the dependencies: 13 | 14 | $ python -m venv env 15 | $ source env/bin/activate 16 | (env) $ pip install -r requirements.txt 17 | 18 | ### Data 19 | 20 | You can download the data needed for this project from 21 | [this Google Drive link](https://drive.google.com/drive/folders/1jAqBE45jEKR-5pMkwxlVQ0V8eKxqWbxA?usp=sharing). 22 | Unzip each sub-directory into `mend/data` and you should be good to go. 23 | 24 | ## Running the code 25 | 26 | Run MEND training/evaluation for distilGPT-2 on the wikitext editing problem with: 27 | 28 | (env) $ python -m run +alg=mend +experiment=gen +model=distilgpt2 data.wiki_webtext=False 29 | 30 | Other valid algs include `efk` ([KnowledgeEditor](https://arxiv.org/abs/2104.08164)) 31 | and `enn` ([Editable Neural Networks](https://arxiv.org/abs/2004.00345)). Valid experiments 32 | include `fc` (FEVER fact checking) and `qa` (zsRE question-answering). Splits and rephrases 33 | for both come from [De Cao et. al](https://arxiv.org/abs/2104.08164). Check `config/model` 34 | for options for editable models (note that all models don't work for all experiments; GPT-style 35 | models only work with `gen`, seq2seq models only work with `qa`, and BERT only works with `fc`). 36 | 37 | Also note that in the paper, we sample locality data from different datasets depending on the model. 38 | By default, training will use [Natural Questions](https://ai.google.com/research/NaturalQuestions) 39 | data (not zsRE data) for computing drawdown in the `qa` experiment and 40 | [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/). 
For models such as the `distilgpt2` 41 | model we use (which was fine-tuned on wikitext) or the BART-base model, this behavior should be 42 | disabled with `data.wiki_webtext=False` or `data.zsre_nq=False`, respectively. 43 | 44 | ## Citing the paper 45 | 46 | If this code or paper was useful, please consider using the following citation: 47 | 48 | @article{mitchell2021fast, 49 | title={Fast Model Editing at Scale}, 50 | author={Mitchell, Eric and Lin, Charles and Bosselut, Antoine and Finn, Chelsea and Manning, Christopher D.}, 51 | year={2021}, 52 | journal={CoRR}, 53 | url={https://arxiv.org/pdf/2110.11309.pdf} 54 | } 55 | -------------------------------------------------------------------------------- /memit/baselines/mend/__init__.py: -------------------------------------------------------------------------------- 1 | from .mend_hparams import MENDHyperParams 2 | from .mend_main import MendRewriteExecutor 3 | -------------------------------------------------------------------------------- /memit/baselines/mend/algs/enn.py: -------------------------------------------------------------------------------- 1 | import higher 2 | import torch 3 | import torch.nn as nn 4 | from editable_model import EditableModel 5 | from utils import _logits 6 | 7 | 8 | def fomaml_callback(all_grads): 9 | return [g.detach() if g is not None else None for g in all_grads] 10 | 11 | 12 | class ENN(EditableModel): 13 | def __init__( 14 | self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None 15 | ): 16 | super().__init__(model, config, model_constructor) 17 | 18 | if edit_lrs is None: 19 | edit_lrs = nn.Parameter( 20 | torch.tensor([config.edit_lr] * len(self.config.model.inner_params)) 21 | ) 22 | self.edit_lrs = edit_lrs 23 | 24 | if edit_loss_fn is not None: 25 | self.edit_loss_fn = edit_loss_fn 26 | 27 | self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x 28 | 29 | def outer_parameters(self): 30 | if self.config.no_grad_layers is None: 31 | return super().outer_parameters() 32 | else: 33 | params = [self.edit_lrs] 34 | for m in self.model.modules(): 35 | if isinstance(m, nn.ModuleList): 36 | params.extend(list(m[self.config.no_grad_layers :].parameters())) 37 | return params 38 | 39 | def get_state_dict(self): 40 | return self.state_dict() 41 | 42 | def edit(self, batch, condition=None, detach_history=False): 43 | opt = torch.optim.SGD( 44 | [ 45 | {"params": p, "lr": None} 46 | for (n, p) in self.model.named_parameters() 47 | if n in self.config.model.inner_params 48 | ] 49 | ) 50 | with torch.enable_grad(), higher.innerloop_ctx( 51 | self.model, 52 | opt, 53 | override={"lr": list(self.edit_lrs)}, 54 | copy_initial_weights=False, 55 | track_higher_grads=self.training, 56 | in_place=True, 57 | ) as (fmodel, diffopt): 58 | fmodel.eval() 59 | for edit_step in range(self.config.enn.n_edit_steps): 60 | output = _logits(fmodel(**batch)) 61 | loss = self.edit_loss_fn(output, batch["labels"])["nll"] 62 | diffopt.step(loss, grad_callback=self.grad_callback) 63 | 64 | if not detach_history: 65 | model_edited = fmodel 66 | else: 67 | model_edited = self.model_constructor() 68 | model_edited.load_state_dict(fmodel.state_dict()) 69 | model_edited.train(self.training) 70 | 71 | return ( 72 | ENN( 73 | model_edited, 74 | self.config, 75 | self.model_constructor, 76 | edit_lrs=self.edit_lrs, 77 | edit_loss_fn=self.edit_loss_fn, 78 | ), 79 | {}, 80 | ) 81 | 82 | 83 | def test(): 84 | import copy 85 | import types 86 | 87 | import transformers 88 | 89 | model = 
transformers.GPT2LMHeadModel.from_pretrained("gpt2") 90 | 91 | config = types.SimpleNamespace() 92 | config.edit_lr = 0.1 93 | config.model.inner_params = [ 94 | "transformer.h.9.mlp.c_fc.weight", 95 | "transformer.h.9.mlp.c_proj.weight", 96 | "transformer.h.10.mlp.c_fc.weight", 97 | "transformer.h.10.mlp.c_proj.weight", 98 | "transformer.h.11.mlp.c_fc.weight", 99 | "transformer.h.11.mlp.c_proj.weight", 100 | ] 101 | config.enn = {"n_edit_steps": 2, "first_order": False} 102 | 103 | enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda() 104 | 105 | x = torch.arange(100).view(5, 20).cuda() + 1000 106 | 107 | edited = enn.edit(x, masks=torch.ones_like(x), labels=x) 108 | 109 | orig_param = [ 110 | p 111 | for (n, p) in enn.model.named_parameters() 112 | if n == config.model.inner_params[-1] 113 | ][0] 114 | edited_param = [ 115 | p 116 | for (n, p) in edited.model.named_parameters() 117 | if n == config.model.inner_params[-1] 118 | ][0] 119 | 120 | print((orig_param - edited_param).abs().max()) 121 | edited.eval() 122 | print( 123 | enn(x, labels=x).loss, 124 | edited(x, labels=x).loss, 125 | edited.edit_loss_fn(edited(x).logits, x)["nll"], 126 | ) 127 | edited.edit_loss_fn(edited(x).logits, x).backward() 128 | import pdb 129 | 130 | pdb.set_trace() 131 | 132 | 133 | if __name__ == "__main__": 134 | with torch.autograd.set_detect_anomaly(True): 135 | test() 136 | -------------------------------------------------------------------------------- /memit/baselines/mend/algs/ft.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import higher 4 | import torch 5 | import torch.nn as nn 6 | from editable_model import EditableModel 7 | from higher.patch import monkeypatch as make_functional 8 | from losses import kl_loc_loss 9 | from utils import _inner_params, _logits 10 | 11 | 12 | class FT(EditableModel): 13 | """ 14 | Fine-tuning approach. Does not require training. 
15 | """ 16 | 17 | def __init__(self, model, config, model_constructor, edit_loss_fn=None): 18 | super().__init__(model, config, model_constructor) 19 | 20 | if edit_loss_fn is not None: 21 | self.edit_loss_fn = edit_loss_fn 22 | 23 | self.locality_loss_fn = kl_loc_loss 24 | self.loc_ids = None 25 | self.loc_masks = None 26 | self.loc_sampler = None 27 | 28 | def _edit_loss(self, model, p0, p_edited, edit_batch): 29 | output = _logits(model(**edit_batch, params=p_edited)) 30 | loss_dict = self.edit_loss_fn(output, edit_batch["labels"]) 31 | l_edit, acc = loss_dict["nll"], loss_dict["acc"] 32 | if self.config.ft.locality.enabled: 33 | if self.config.ft.locality.oracle: 34 | loc_batch = next(self.loc_sampler)["loc"] 35 | else: 36 | raise NotImplementedError 37 | 38 | with torch.no_grad(): 39 | original_base_logits = _logits(model(**loc_batch, params=p0)) 40 | edited_base_logits = _logits(model(**loc_batch, params=p_edited)) 41 | kl_mask = loc_batch.get( 42 | "decoder_attention_mask", loc_batch["attention_mask"] 43 | ) 44 | l_loc = self.locality_loss_fn( 45 | original_base_logits, edited_base_logits, mask=kl_mask 46 | ) 47 | loss = l_loc + self.config.ft.locality.cedit * l_edit 48 | else: 49 | l_loc = torch.tensor(float("nan")) 50 | loss = l_edit 51 | return loss, l_edit, l_loc, acc 52 | 53 | def accuracy(self, output, labels): 54 | if output.shape[-1] != 1: 55 | shifted_output = output.argmax(-1)[:, :-1] 56 | shifted_labels = labels[:, 1:] 57 | to_predict = (shifted_labels != -100).sum() 58 | correct = (shifted_output == shifted_labels).sum() 59 | acc = correct.float() / to_predict.float() 60 | else: 61 | acc = ((output > 0) == labels.bool()).sum().float() 62 | return acc 63 | 64 | def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p): 65 | return ( 66 | f"step: {step}".ljust(14) 67 | + f"loss: {loss.item():.5f}".ljust(18) 68 | + f"l_edit: {l_edit.item():.5f}".ljust(18) 69 | + f"l_loc: {l_loc.item():.5f}".ljust(18) 70 | + f"acc: {acc.item():.2f}".ljust(14) 71 | + f"norm: {res_p.view(-1).norm().item():.5f}" 72 | ) 73 | 74 | def edit(self, batch, condition=None, detach_history=False): 75 | edit_model = self.model.eval() 76 | p0 = list(edit_model.named_parameters()) 77 | 78 | if not isinstance(edit_model, higher.patch._MonkeyPatchBase): 79 | edit_model = make_functional( 80 | self.model, track_higher_grads=False, in_place=True 81 | ) 82 | 83 | packed_residuals = {} 84 | opt_params = [] 85 | for n, p in _inner_params( 86 | edit_model.named_parameters(), self.config.model.inner_params 87 | ): 88 | if self.config.ft.rank is not None: 89 | u = nn.Parameter( 90 | torch.randn(p.shape[0], self.config.ft.rank, device=p.device) 91 | * self.config.ft.init_std 92 | ) 93 | v = nn.Parameter( 94 | torch.zeros(self.config.ft.rank, p.shape[1], device=p.device) 95 | ) 96 | res = [u, v] 97 | else: 98 | res = [nn.Parameter(torch.zeros_like(p, device=p.device))] 99 | 100 | packed_residuals[n] = res 101 | opt_params.extend(res) 102 | 103 | assert len(opt_params) == len(self.config.model.inner_params) 104 | OptClass = getattr(torch.optim, self.config.ft.opt) 105 | opt = OptClass(opt_params, lr=self.config.edit_lr) 106 | 107 | start_time = time.time() 108 | for edit_step in range(self.config.ft.max_edit_steps): 109 | if self.config.ft.time_limit is not None and ( 110 | time.time() - start_time > self.config.ft.time_limit 111 | ): 112 | break 113 | residuals = { 114 | k: v[0] @ v[1] if len(v) == 2 else v[0] 115 | for k, v in packed_residuals.items() 116 | } 117 | edited_params = [ 118 | p if n not in residuals 
else p.detach() + residuals[n] for n, p in p0 119 | ] 120 | loss, l_edit, l_loc, acc = self._edit_loss( 121 | edit_model, [p for n, p in p0], edited_params, batch 122 | ) 123 | 124 | if self.config.ft.verbose: 125 | residual = list(residuals.values())[-1] 126 | print( 127 | self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual), 128 | end="\r", 129 | ) 130 | 131 | if acc == 1.0: 132 | break 133 | 134 | for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)): 135 | p.grad = g 136 | torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip) 137 | opt.step() 138 | opt.zero_grad() 139 | 140 | if detach_history: 141 | new_model = self.model_constructor() 142 | new_model.load_state_dict(edit_model.state_dict()) 143 | edit_model = new_model 144 | edit_model.train(self.training) 145 | 146 | return ( 147 | FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), 148 | {}, 149 | ) 150 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/alg/efk.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: efk 4 | train_base: False 5 | lr: 1e-5 6 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/alg/enn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: enn 4 | train_base: True 5 | enn: 6 | first_order: False 7 | n_edit_steps: 1 8 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/alg/ft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | train_base: False 4 | alg: ft 5 | edit_lr: 5e-6 6 | ft: 7 | verbose: false 8 | max_edit_steps: 100 9 | time_limit: null 10 | locality: 11 | enabled: false 12 | oracle: true 13 | cedit: 1e-2 14 | batch_size: 1 15 | rank: null 16 | opt: RMSprop 17 | init_std: 0.01 18 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/alg/mend.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: mend 4 | lr: 1e-6 5 | train_base: False 6 | edit_lr: 1e-4 7 | lr_lr: 1e-4 8 | mend: 9 | one_sided: False 10 | n_hidden: 1 11 | hidden_dim: null 12 | init: id 13 | norm: True 14 | combine: True 15 | x_only: False 16 | delta_only: False 17 | act: relu 18 | rank: 1920 19 | mlp_class: IDMLP 20 | shared: True 21 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/config.yaml: -------------------------------------------------------------------------------- 1 | alg: enn 2 | lr: 1e-5 3 | edit_lr: 1e-2 4 | seed: 0 5 | debug: False 6 | model_save_pt: 5000 7 | edit_bs: 1 8 | silent: False 9 | max_iters: 1000000 10 | log_interval: 100 11 | val_interval: 5000 12 | lr_lr: 1e-3 13 | batch_size: 2 14 | val_batch_size: 5 15 | accumulate_bs: 10 16 | cedit: 0.1 17 | cloc: 1.0 18 | cbase: 1.0 19 | val_steps: 500 20 | device: cuda 21 | base_loss: distill 22 | oracle: False 23 | train: True 24 | train_base: True 25 | opt: Adam 26 | single_batch: False 27 | archive: null 28 | grad_clip: 100. 
29 | ref: null 30 | early_stop_patience: 20000 31 | early_stop_key: "loss/total_edit_val" 32 | dropout: 0.0 33 | tokenizer: null 34 | results_dir: null 35 | no_grad_layers: null 36 | eval_only: False 37 | half: False 38 | save: False 39 | 40 | model: 41 | pt: null 42 | 43 | data: 44 | path: null 45 | rephrase: true 46 | zsre_nq: true 47 | nq_path: ${hydra:runtime.cwd}/data/nq 48 | wiki_webtext: true 49 | n_edits: 1 50 | 51 | eval: 52 | verbose: True 53 | log_interval: 100 54 | final_eval: True 55 | 56 | hydra: 57 | run: 58 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}} 59 | sweep: 60 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f} 61 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /memit/baselines/mend/config/experiment/fc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: fc 4 | dataset: fever 5 | cbase: 1.0 6 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/experiment/gen.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: gen 4 | dataset: wikitext-103 5 | cbase: 10.0 6 | data: 7 | path: ${hydra:runtime.cwd}/data/10token/data/self_sample/ 8 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/experiment/qa.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: qa 4 | dataset: zsre 5 | cbase: 1.0 6 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/bart-base.yaml: -------------------------------------------------------------------------------- 1 | name: facebook/bart-base 2 | class_name: BartForConditionalGeneration 3 | tokenizer_class: BartTokenizerFast 4 | tokenizer_name: facebook/bart-base 5 | inner_params: 6 | - model.encoder.layers.4.fc1.weight 7 | - model.encoder.layers.4.fc2.weight 8 | - model.encoder.layers.5.fc1.weight 9 | - model.encoder.layers.5.fc2.weight 10 | - model.decoder.layers.4.fc1.weight 11 | - model.decoder.layers.4.fc2.weight 12 | - model.decoder.layers.5.fc1.weight 13 | - model.decoder.layers.5.fc2.weight 14 | 15 | pt: ${hydra:runtime.cwd}/data/zsre/QA_model.ckpt -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/bert-base.yaml: -------------------------------------------------------------------------------- 1 | name: bert-base-uncased 2 | class_name: BertClassifier 3 | tokenizer_class: BertTokenizerFast 4 | tokenizer_name: bert-base-uncased 5 | inner_params: 6 | - model.encoder.layer.9.intermediate.dense.weight 7 | - model.encoder.layer.9.output.dense.weight 8 | - model.encoder.layer.10.intermediate.dense.weight 9 | - model.encoder.layer.10.output.dense.weight 10 | - model.encoder.layer.11.intermediate.dense.weight 11 | - model.encoder.layer.11.output.dense.weight 12 | 13 | pt: ${hydra:runtime.cwd}/data/fever/FC_model.ckpt -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/distilgpt2.yaml: -------------------------------------------------------------------------------- 1 | name: MYX4567/distilgpt2-finetuned-wikitext2 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: distilgpt2 5 | inner_params: 6 | - transformer.h.3.mlp.c_fc.weight 7 | - 
transformer.h.3.mlp.c_proj.weight 8 | - transformer.h.4.mlp.c_fc.weight 9 | - transformer.h.4.mlp.c_proj.weight 10 | - transformer.h.5.mlp.c_fc.weight 11 | - transformer.h.5.mlp.c_proj.weight 12 | 13 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2 5 | inner_params: 6 | - transformer.h.9.mlp.c_proj.weight 7 | - transformer.h.9.mlp.c_fc.weight 8 | - transformer.h.10.mlp.c_proj.weight 9 | - transformer.h.10.mlp.c_fc.weight 10 | - transformer.h.11.mlp.c_proj.weight 11 | - transformer.h.11.mlp.c_fc.weight -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/gpt2large.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-large 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-large 5 | inner_params: 6 | - transformer.h.33.mlp.c_proj.weight 7 | - transformer.h.33.mlp.c_fc.weight 8 | - transformer.h.34.mlp.c_proj.weight 9 | - transformer.h.34.mlp.c_fc.weight 10 | - transformer.h.35.mlp.c_proj.weight 11 | - transformer.h.35.mlp.c_fc.weight 12 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/gpt2medium.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-medium 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-medium 5 | inner_params: 6 | - transformer.h.21.mlp.c_proj.weight 7 | - transformer.h.21.mlp.c_fc.weight 8 | - transformer.h.22.mlp.c_proj.weight 9 | - transformer.h.22.mlp.c_fc.weight 10 | - transformer.h.23.mlp.c_proj.weight 11 | - transformer.h.23.mlp.c_fc.weight -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/gpt2xl.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-xl 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-xl 5 | inner_params: 6 | - transformer.h.45.mlp.c_proj.weight 7 | - transformer.h.45.mlp.c_fc.weight 8 | - transformer.h.46.mlp.c_proj.weight 9 | - transformer.h.46.mlp.c_fc.weight 10 | - transformer.h.47.mlp.c_proj.weight 11 | - transformer.h.47.mlp.c_fc.weight 12 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/gptj.yaml: -------------------------------------------------------------------------------- 1 | name: EleutherAI/gpt-j-6B 2 | class_name: GPTJForCausalLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: EleutherAI/gpt-j-6B 5 | inner_params: 6 | - transformer.h.25.mlp.fc_in.weight 7 | - transformer.h.25.mlp.fc_out.weight 8 | - transformer.h.26.mlp.fc_in.weight 9 | - transformer.h.26.mlp.fc_out.weight 10 | - transformer.h.27.mlp.fc_in.weight 11 | - transformer.h.27.mlp.fc_out.weight 12 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/gptneo27.yaml: -------------------------------------------------------------------------------- 1 | name: EleutherAI/gpt-neo-2.7B 2 | class_name: GPTNeoForCausalLM 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: EleutherAI/gpt-neo-2.7B 5 | 
inner_params: 6 | - transformer.h.29.mlp.c_fc.weight 7 | - transformer.h.29.mlp.c_proj.weight 8 | - transformer.h.30.mlp.c_fc.weight 9 | - transformer.h.30.mlp.c_proj.weight 10 | - transformer.h.31.mlp.c_fc.weight 11 | - transformer.h.31.mlp.c_proj.weight 12 | -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/t5large.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-large-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-large-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/t5small.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-small-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-small-ssm-nq 5 | inner_params: 6 | - encoder.block.6.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.6.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.7.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.7.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.6.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.6.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.7.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.7.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/t5xl.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-xl-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-xl-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /memit/baselines/mend/config/model/t5xxl.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-xxl-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-xxl-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - 
decoder.block.23.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /memit/baselines/mend/data_classes/fever.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import jsonlines 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from utils import EditBatchSampler, dict_to 8 | 9 | POSITIVE_CLASS = "SUPPORTS" 10 | 11 | 12 | class BinaryAugmentedKILT(Dataset): 13 | def __init__(self, tokenizer, data_path, config, max_length=32): 14 | super().__init__() 15 | self.tokenizer = tokenizer 16 | self.data = [] 17 | self.config = config 18 | 19 | def extract(d): 20 | extracted = { 21 | k: d[k] 22 | for k in [ 23 | "logit", 24 | "input", 25 | "prediction", 26 | "alternatives", 27 | "filtered_rephrases", 28 | ] 29 | } 30 | extracted["label"] = d["output"][0]["answer"] 31 | return extracted 32 | 33 | with jsonlines.open(data_path) as f: 34 | for d in f: 35 | if len(d["alternatives"]) > 0 and len(d["filtered_rephrases"]) > 0: 36 | self.data.append(extract(d)) 37 | 38 | self.max_length = max_length 39 | 40 | def __len__(self): 41 | return len(self.data) 42 | 43 | def __getitem__(self, item): 44 | obj = self.data[item] 45 | rephrase = random.choice(self.data[item]["filtered_rephrases"]) 46 | output = { 47 | "label": obj["label"] == POSITIVE_CLASS, 48 | "src": obj["input"], 49 | "rephrase": rephrase, 50 | "pred": obj["prediction"] == POSITIVE_CLASS, 51 | "alt": obj["alternatives"][0] == POSITIVE_CLASS, 52 | "cond_flip": "{} >> {} || {}".format( 53 | obj["prediction"], 54 | obj["alternatives"][0], 55 | obj["input"], 56 | ), 57 | "cond_orig": "{} >> {} || {}".format( 58 | obj["prediction"], 59 | obj["prediction"], 60 | obj["input"], 61 | ), 62 | "logit": obj["logit"], 63 | } 64 | 65 | return output 66 | 67 | def collate_fn(self, batch): 68 | src = [b["src"] for b in batch] 69 | rephrase = [batch[-1]["rephrase"]] 70 | 71 | flip_label = np.random.uniform() > 0.5 72 | predictions = [b["pred"] for b in batch] 73 | labels = [b["label"] for b in batch] 74 | labels[-1] = predictions[ 75 | -1 76 | ] # the last element in the batch is special (the edit element) 77 | cond = [batch[-1]["cond_orig"]] 78 | if flip_label: 79 | labels[-1] = batch[-1]["alt"] 80 | cond = [batch[-1]["cond_flip"]] 81 | 82 | batches = {} 83 | for k1, v1 in {"": src, "cond_": cond, "rephrase_": rephrase}.items(): 84 | encoded = self.tokenizer( 85 | v1, 86 | return_tensors="pt", 87 | padding=True, 88 | max_length=self.max_length, 89 | truncation=True, 90 | ) 91 | for k2, v2 in encoded.items(): 92 | batches[f"{k1}{k2}"] = v2 93 | 94 | batches["predictions"] = torch.tensor(predictions).long().view(-1, 1) 95 | batches["labels"] = torch.tensor(labels).long().view(-1, 1) 96 | batches["raw"] = batch 97 | return batches 98 | 99 | def edit_generator(self, batch_size, n=None): 100 | if n is None: 101 | n = len(self) 102 | sampler = EditBatchSampler( 103 | n, memorize_mode=self.config.single_batch, seed=self.config.seed 104 | ) 105 | while True: 106 | edit_idxs, loc_idxs = sampler.sample(batch_size) 107 | assert len(edit_idxs) == 1 108 | idxs = loc_idxs + edit_idxs 109 | toks = self.collate_fn([self[idx] for idx in idxs]) 110 | 111 | pass_keys = ["input_ids", "attention_mask", "labels"] 112 | edit_inner = {k: v[-1:] for k, v in toks.items() if k in pass_keys} 113 | if self.config.data.rephrase: 114 | edit_outer = {} 115 | 
edit_outer["input_ids"] = toks["rephrase_input_ids"] 116 | edit_outer["attention_mask"] = toks["rephrase_attention_mask"] 117 | edit_outer["labels"] = edit_inner["labels"] 118 | else: 119 | edit_outer = edit_inner 120 | loc = {k: v[:-1] for k, v in toks.items() if k in pass_keys} 121 | cond = { 122 | "input_ids": toks["cond_input_ids"], 123 | "attention_mask": toks["cond_attention_mask"], 124 | } 125 | 126 | batch = { 127 | "edit_inner": edit_inner, 128 | "edit_outer": edit_outer, 129 | "loc": loc, 130 | "cond": cond, 131 | } 132 | yield dict_to(batch, self.config.device) 133 | -------------------------------------------------------------------------------- /memit/baselines/mend/data_classes/nq.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class NQDataset: 5 | def __init__(self, path: str, tokenizer, config): 6 | with open(path, "r") as f: 7 | self.data = json.load(f) 8 | 9 | self.questions = self.data["questions"] 10 | self.answers = self.data["answers"] 11 | self.tokenizer = tokenizer 12 | self.config = config 13 | 14 | def __getitem__(self, idx): 15 | idx = idx % len(self.questions) 16 | return self.questions[idx], self.answers[idx] 17 | 18 | @staticmethod 19 | def generate( 20 | out_path: str, 21 | prompt: bool = False, 22 | capitalize: bool = True, 23 | question_mark: bool = True, 24 | ): 25 | import os 26 | 27 | import datasets 28 | 29 | def process(text): 30 | if capitalize: 31 | text = text[0].capitalize() + text[1:] 32 | if question_mark: 33 | text = text + "?" 34 | if prompt: 35 | text = "nq question: " + text 36 | return text 37 | 38 | def extract(d): 39 | questions = [process(q["text"]) for q in d["question"]] 40 | answers = [ 41 | [a["text"][0] for a in ann["short_answers"] if len(a["text"])] 42 | for ann in d["annotations"] 43 | ] 44 | questions = [q for q, a in zip(questions, answers) if len(a)] 45 | answers = [min(a, key=len) for a in answers if len(a)] 46 | return questions, answers 47 | 48 | train = datasets.load_dataset("natural_questions", split="train") 49 | tq, ta = extract(train) 50 | val = datasets.load_dataset("natural_questions", split="validation") 51 | vq, va = extract(val) 52 | 53 | if not os.path.exists(out_path): 54 | os.makedirs(out_path) 55 | with open(f"{out_path}/train.json", "w") as f: 56 | json.dump({"questions": tq, "answers": ta}, f) 57 | with open(f"{out_path}/validation.json", "w") as f: 58 | json.dump({"questions": vq, "answers": va}, f) 59 | 60 | 61 | if __name__ == "__main__": 62 | import argparse 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--out_path", type=str, default="data/nq") 66 | args = parser.parse_args() 67 | NQDataset.generate(args.out_path) 68 | -------------------------------------------------------------------------------- /memit/baselines/mend/data_classes/wiki.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import random 5 | 6 | from datasets import load_dataset 7 | from torch.utils.data import Dataset 8 | from utils import EditBatchSampler, dict_to, scr 9 | 10 | LOG = logging.getLogger(__name__) 11 | 12 | 13 | def is_ascii(s): 14 | return all(ord(c) < 128 for c in s) 15 | 16 | 17 | def filter_text(iterator): 18 | valid = [] 19 | for text in iterator: 20 | if len(text.split(" ")) < 50: 21 | continue 22 | if not is_ascii(text): 23 | continue 24 | valid.append(text) 25 | 26 | return valid 27 | 28 | 29 | class GenDataset(Dataset): 30 | def __init__( 31 
| self, 32 | split: str, 33 | tokenizer, 34 | config, 35 | edit_path: str, 36 | pct: int = 10, 37 | max_length: int = 200, 38 | ): 39 | version = "wikitext-103-raw-v1" 40 | split_str = f"{split}[:{pct}%]" if split == "train" else split 41 | LOG.info(f"Loading wikitext version {version}, split {split_str}") 42 | base_samples = load_dataset( 43 | "wikitext", version, cache_dir=scr(), split=split_str 44 | )["text"] 45 | self.base_samples = filter_text(base_samples) 46 | with open(edit_path + split[:5] + ".json", "r") as f: 47 | self.edit_samples = json.load(f) 48 | 49 | self.tok = tokenizer 50 | self.config = config 51 | self.max_length = max_length 52 | self.n_tokens = self.edit_samples["n_tokens"] 53 | 54 | len_base = len(self.base_samples) 55 | len_edit = len(self.edit_samples["original"]) 56 | LOG.info(f"Loaded {len_base} wiki-103 samples and {len_edit} edit samples") 57 | 58 | if config.data.wiki_webtext: 59 | self.use_wiki = True 60 | LOG.info("** Using webtext for wiki base samples **") 61 | webtext = load_dataset( 62 | "stas/openwebtext-10k", split="train", cache_dir=scr() 63 | )["text"] 64 | n_train = int(len(webtext) * 0.9) 65 | if split == "train": 66 | self.base_samples = webtext[:n_train] 67 | else: 68 | self.base_samples = webtext[n_train:] 69 | else: 70 | self.use_wiki = False 71 | 72 | def edit_generator(self, batch_size, n=None): 73 | if n is None: 74 | n = len(self) 75 | sampler = EditBatchSampler( 76 | n, 77 | memorize_mode=self.config.single_batch, 78 | loc_disjoint=not self.use_wiki, 79 | seed=self.config.seed, 80 | ) 81 | while True: 82 | edit_idxs, loc_idxs = sampler.sample(batch_size) 83 | 84 | edit_batch = [self.edit_samples["completions"][idx] for idx in edit_idxs] 85 | loc_batch = [ 86 | self.base_samples[idx % len(self.base_samples)] for idx in loc_idxs 87 | ] 88 | 89 | edit_toks = self.tok(edit_batch, padding=True, return_tensors="pt") 90 | loc_toks = self.tok( 91 | loc_batch, 92 | padding=True, 93 | return_tensors="pt", 94 | truncation=self.config.data.wiki_webtext, 95 | max_length=self.max_length, 96 | ) 97 | 98 | edit_inner = {**edit_toks} 99 | edit_inner["labels"] = self.get_edit_labels(edit_toks["input_ids"]) 100 | 101 | edit_outer = copy.deepcopy(edit_inner) 102 | if self.config.data.rephrase: 103 | lens = (edit_outer["input_ids"] != -100).sum(-1) 104 | remove = random.randint(0, (min(lens) - self.n_tokens) // 2) 105 | for k, v in edit_outer.items(): 106 | edit_outer[k] = v[:, remove:] 107 | 108 | loc = {**loc_toks} 109 | loc["labels"] = self.get_labels(loc_toks["input_ids"]) 110 | cond = {**edit_toks} 111 | 112 | batch = { 113 | "edit_inner": edit_inner, 114 | "edit_outer": edit_outer, 115 | "loc": loc, 116 | "cond": cond, 117 | } 118 | 119 | yield dict_to(batch, self.config.device) 120 | 121 | def __len__(self): 122 | return len(self.edit_samples["original"]) 123 | 124 | def _check_padding(self, ids): 125 | if (ids[:, 0] == self.tok.pad_token_id).any(): 126 | raise ValueError("Left-padding not supported for GPT2") 127 | 128 | def get_edit_labels(self, ids): 129 | self._check_padding(ids) 130 | 131 | labels = ids.clone() 132 | end_idxs = (labels != self.tok.pad_token_id).sum(-1) 133 | for batch_idx, end_idx in enumerate(end_idxs): 134 | labels[batch_idx, : end_idx - self.n_tokens] = -100 135 | labels[labels == self.tok.pad_token_id] = -100 136 | return labels 137 | 138 | def get_labels(self, ids): 139 | self._check_padding(ids) 140 | 141 | return ids.masked_fill(ids == self.tok.pad_token_id, -100) 142 | 143 | def __getitem__(self, idx): 144 | return 
self.base_samples[idx] 145 | -------------------------------------------------------------------------------- /memit/baselines/mend/editable_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .losses import masked_log_probs 4 | from .utils import _logits, shift_targets 5 | 6 | 7 | class EditableModel(nn.Module): 8 | def __init__(self, model, config, model_constructor): 9 | super().__init__() 10 | 11 | self.model = model 12 | self.config = config 13 | self.model_constructor = model_constructor 14 | 15 | def _edit_loss_fn(pred, targ): 16 | return masked_log_probs(pred, targ, shift=shift_targets(self.config)) 17 | 18 | self.edit_loss_fn = _edit_loss_fn 19 | self.loc_loss_fn = _edit_loss_fn 20 | 21 | def edit(self, batch, condition=None, detach_history=False): 22 | raise NotImplementedError 23 | 24 | def forward(self, *inputs, **kwargs): 25 | return _logits(self.model(*inputs, **kwargs)) 26 | 27 | def outer_parameters(self): 28 | return self.parameters() 29 | 30 | def base_loss(self, input_ids, attention_masks, label_ids): 31 | pass 32 | -------------------------------------------------------------------------------- /memit/baselines/mend/hooks.py: -------------------------------------------------------------------------------- 1 | from .utils import parent_module 2 | 3 | 4 | def linear_backward_hook(mod, grad_in, grad_out): 5 | if not hasattr(mod, "weight"): 6 | print(f"{mod} has no weight!") 7 | return 8 | 9 | if hasattr(mod.weight, "__x__"): 10 | assert len(grad_out) == 1 11 | # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2) 12 | mod.weight.__delta__ = grad_out[0].detach() 13 | else: 14 | print(f"{mod} has no __x__") 15 | 16 | 17 | def linear_forward_hook(mod, activations, output): 18 | assert len(activations) == 1 19 | mod.weight.__x__ = activations[0].detach() 20 | 21 | 22 | def hook_model(model, pnames): 23 | handles = [] 24 | for m in [parent_module(model, pname) for pname in pnames]: 25 | handles.append(m.register_full_backward_hook(linear_backward_hook)) 26 | handles.append(m.register_forward_hook(linear_forward_hook)) 27 | 28 | model.handles = handles 29 | -------------------------------------------------------------------------------- /memit/baselines/mend/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def kl_loc_loss(pre, post, mask=None): 6 | pre = pre.to(torch.float32) 7 | post = post.to(torch.float32) 8 | 9 | sequence = pre.dim() == 3 10 | pre_ = pre.view(-1, pre.shape[-1]) 11 | post_ = post.view(pre_.shape) 12 | assert pre_.shape[0] == post_.shape[0] 13 | 14 | if not sequence: 15 | if pre_.shape[-1] == 1: # No masking needed for binary classification 16 | return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + ( 17 | (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post)) 18 | ).mean() 19 | else: # We have sequences of predictions; masking needed 20 | if pre_.shape[-1] > 1: 21 | assert mask is not None 22 | mask_ = mask.view(pre_.shape[0]) 23 | kl = ( 24 | pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1)) 25 | ).sum(-1) 26 | return (kl * mask_).sum() / mask_.sum() 27 | 28 | raise NotImplementedError 29 | 30 | 31 | def binary_log_probs(pred, targ): 32 | neg_mask = torch.ones_like(pred) 33 | neg_mask[targ == 0] *= -1 34 | pred = pred * neg_mask 35 | log_probs = F.logsigmoid(pred) 36 | acc = (log_probs.exp() > 
0.5).float().mean() 37 | return { 38 | "acc": acc, 39 | "log_prob": log_probs.mean(), 40 | "prob": log_probs.exp().mean(), 41 | "nll": -log_probs.mean(), 42 | "n_tokens": log_probs.shape[0], 43 | } 44 | 45 | 46 | def multiclass_log_probs(pred, targ, shift=True): 47 | NULL_TOKEN = 0 # a placeholder used for masked target locations 48 | 49 | pred = pred.clone() 50 | targ = targ.clone() 51 | if shift and pred.dim() == 3: # Dealing with sequences 52 | pred = pred[:, :-1] # Remove last prediction in sequence 53 | targ = targ[:, 1:] # Shift to align predictions and targets 54 | 55 | mask = targ != -100 56 | targ[~mask] = NULL_TOKEN # Can be any valid token, since we'll throw them out 57 | unmasked_log_probs = pred.log_softmax(-1).gather(-1, targ.unsqueeze(-1)).squeeze(-1) 58 | 59 | pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN) 60 | correct = pred_ids == targ 61 | if pred.dim() == 3: 62 | correct = (pred_ids == targ).all(-1) # We want to get the whole sequence right 63 | acc = correct.float().mean() 64 | 65 | n_tokens = mask.float().sum() 66 | log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens 67 | prob = (unmasked_log_probs.exp() * mask.float()).sum() / n_tokens 68 | return { 69 | "acc": acc, 70 | "log_prob": log_prob, 71 | "prob": prob, 72 | "n_tokens": n_tokens, 73 | "nll": -log_prob, 74 | } 75 | 76 | 77 | def masked_log_probs(pred, targ, shift=True): 78 | pred = pred.to(torch.float32) 79 | 80 | if not (pred.dim() == 2 or pred.dim() == 3): 81 | raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}") 82 | 83 | if pred.shape[-1] == 1: 84 | return binary_log_probs(pred, targ) 85 | else: 86 | return multiclass_log_probs(pred, targ, shift=shift) 87 | -------------------------------------------------------------------------------- /memit/baselines/mend/mend_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from util.hparams import HyperParams 4 | 5 | 6 | @dataclass 7 | class MENDHyperParams(HyperParams): 8 | lr_scale: float 9 | n_toks: int 10 | model_name: str 11 | counterfact: bool 12 | mini: bool 13 | zsre: bool 14 | -------------------------------------------------------------------------------- /memit/baselines/mend/mend_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from typing import Dict, List 4 | 5 | import hydra 6 | import torch 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | 9 | from util.globals import * 10 | 11 | from .algs.mend import MEND 12 | from .mend_hparams import MENDHyperParams 13 | 14 | 15 | class MendRewriteExecutor: 16 | method_name = "MEND" 17 | 18 | def __init__(self): 19 | self.is_init = False 20 | 21 | def init_model(self, model, tok, params): 22 | train_ds = ( 23 | "counterfact-" if params.counterfact else ("zsre-" if params.zsre else "") 24 | ) 25 | mini_string = "mini-" if params.mini else "" 26 | 27 | model_name = "gpt2-xl" if params.model_name == "gpt2-xl" else "gpt-j-6b" 28 | modelcode = "gpt2xl" if params.model_name == "gpt2-xl" else "gptj" 29 | model_filename = ( 30 | f"mend-{mini_string}{params.n_toks}tok-{train_ds}{model_name}.pt" 31 | ) 32 | model_dir = "baselines/mend/weights" 33 | 34 | os.makedirs(model_dir, exist_ok=True) 35 | if not os.path.isfile(f"{model_dir}/{model_filename}"): 36 | remote_url = f"{REMOTE_ROOT_URL}/data/weights/{model_filename}" 37 | print(f"Attempting to download from {remote_url}") 38 | 
torch.hub.download_url_to_file(remote_url, f"{model_dir}/{model_filename}") 39 | with hydra.initialize(config_path="config", job_name="run"): 40 | config = hydra.compose( 41 | config_name="config", 42 | overrides=[ 43 | "+alg=mend", 44 | "+experiment=gen", 45 | f"+model={modelcode}", 46 | f"data.path=data/{params.n_toks}token/data/self_sample/", 47 | ], 48 | ) 49 | 50 | def add_padding(tokenizer, model): 51 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 52 | model.resize_token_embeddings(len(tokenizer)) 53 | model.transformer.wte.weight.data[ 54 | -1 55 | ] = model.transformer.wte.weight.data.mean(0) 56 | 57 | # Customize the gpt2xl and tokenizer 58 | self.model = model 59 | self.tokenizer = tok 60 | add_padding(self.tokenizer, self.model) 61 | 62 | # Load the trained MEND model 63 | self.alg = MEND(self.model, config, lambda: deepcopy(self.model)) 64 | d = torch.load(f"{model_dir}/{model_filename}") 65 | self.alg.load_state_dict( 66 | {k.replace("gtn.", "mend."): v for k, v in d["model"].items()} 67 | ) 68 | self.alg.cuda() 69 | 70 | # Disable unneeded gradients 71 | for n, p in self.model.named_parameters(): 72 | if n not in config.model.inner_params: 73 | p.requires_grad = False 74 | self.is_init = True 75 | 76 | def reset_model(self): 77 | self.is_init = False 78 | del self.model, self.tokenizer, self.alg 79 | 80 | def apply_to_model( 81 | self, 82 | model: AutoModelForCausalLM, 83 | tok: AutoTokenizer, 84 | requests: List[Dict], 85 | hparams: MENDHyperParams, 86 | copy=False, 87 | return_orig_weights=False, 88 | ): 89 | """ 90 | Given a request, for example 91 | {'prompt': '{} has the position of', 92 | 'subject': 'Charles Herman Helmsing', 93 | 'relation_id': 'P39', 94 | 'target_new': {'str': 'President', 'id': 'Q11696'}, 95 | 'target_true': {'str': 'bishop', 'id': 'Q29182'}} 96 | Returns a dictionary of numpy arrays that specifies 97 | how mend will change the weights of the model. 98 | """ 99 | 100 | if not self.is_init: 101 | self.init_model(model, tok, hparams) 102 | 103 | weights_copy = {} 104 | model = deepcopy(self.model) if copy else self.model 105 | 106 | # Define i/o 107 | targets = [ 108 | (" " if request["target_new"]["str"][0] != " " else "") 109 | + request["target_new"]["str"] 110 | for request in requests 111 | ] 112 | sentences = [ 113 | request["prompt"].format(request["subject"]) + targets[i] 114 | for i, request in enumerate(requests) 115 | ] 116 | 117 | # Tokenize 118 | sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to( 119 | "cuda" 120 | ) 121 | target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to( 122 | "cuda" 123 | ) 124 | 125 | # Define labels 126 | label_tok = deepcopy(sent_tok["input_ids"]) 127 | for i in range(label_tok.size(0)): 128 | target_len = target_tok["attention_mask"][i].sum() 129 | padding_len = ( 130 | sent_tok["input_ids"].size(1) - sent_tok["attention_mask"][i].sum() 131 | ) 132 | label_tok[i][: -target_len - padding_len] = -100 133 | label_tok[i][label_tok[i] == self.tokenizer.pad_token_id] = -100 134 | 135 | # Run MEND 136 | edit_inner = dict( 137 | input_ids=sent_tok["input_ids"], 138 | attention_mask=sent_tok["attention_mask"], 139 | labels=label_tok, 140 | ) 141 | cond = {k: sent_tok[k] for k in ["input_ids", "attention_mask"]} 142 | _, model_info = self.alg.edit(edit_inner, cond, return_factors=True) 143 | factors = { 144 | k + "." 
+ n: v.detach().cpu().numpy() 145 | for k, pair in model_info["factors"].items() 146 | for n, v in zip("uv", pair) 147 | } 148 | # Also keep these learned LRs. 149 | factors["edit_lrs"] = self.alg.edit_lrs.detach().cpu().numpy() 150 | 151 | # Edit! 152 | d = factors 153 | torch_factors = {k: torch.tensor(v) for k, v in d.items()} 154 | eli = 0 155 | edit_lrs = torch_factors["edit_lrs"] 156 | 157 | with torch.no_grad(): 158 | for n, p in model.named_parameters(): 159 | uname, vname = f"{n}.u", f"{n}.v" 160 | if uname in torch_factors: 161 | if return_orig_weights and n not in weights_copy: 162 | weights_copy[n] = p.detach().clone() 163 | 164 | if "gpt2" in hparams.model_name: 165 | delta = torch_factors[uname].t() @ torch_factors[vname] 166 | elif "gpt-j-6B" in hparams.model_name: 167 | delta = torch_factors[vname].t() @ torch_factors[uname] 168 | else: 169 | raise ValueError("Unknown model") 170 | p.add_((delta * edit_lrs[eli] * hparams.lr_scale).to(p.device)) 171 | eli += 1 172 | 173 | return model, weights_copy 174 | -------------------------------------------------------------------------------- /memit/baselines/mend/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | 8 | from .utils import scr 9 | 10 | LOG = logging.getLogger(__name__) 11 | 12 | 13 | class CastModule(nn.Module): 14 | def __init__( 15 | self, 16 | module: nn.Module, 17 | in_cast: torch.dtype = torch.float32, 18 | out_cast: torch.dtype = None, 19 | ): 20 | super().__init__() 21 | 22 | self.underlying = module 23 | self.in_cast = in_cast 24 | self.out_cast = out_cast 25 | 26 | def cast(self, obj, dtype): 27 | if dtype is None: 28 | return obj 29 | 30 | if isinstance(obj, torch.Tensor): 31 | return obj.to(dtype) 32 | else: 33 | return obj 34 | 35 | def forward(self, *args, **kwargs): 36 | args = tuple(self.cast(a, self.in_cast) for a in args) 37 | kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()} 38 | outputs = self.underlying(*args, **kwargs) 39 | if isinstance(outputs, torch.Tensor): 40 | outputs = self.cast(outputs, self.out_cast) 41 | elif isinstance(outputs, tuple): 42 | outputs = tuple(self.cast(o, self.out_cast) for o in outputs) 43 | else: 44 | raise RuntimeError(f"Not sure how to cast type {type(outputs)}") 45 | return outputs 46 | 47 | def extra_repr(self): 48 | return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}" 49 | 50 | 51 | class BertClassifier(torch.nn.Module): 52 | def __init__(self, model_name, hidden_dim=768): 53 | super().__init__() 54 | self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr()) 55 | self.classifier = torch.nn.Linear(hidden_dim, 1) 56 | 57 | @property 58 | def config(self): 59 | return self.model.config 60 | 61 | def forward(self, *args, **kwargs): 62 | filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"} 63 | return self.classifier(self.model(*args, **filtered_kwargs)[1]) 64 | 65 | 66 | def get_model(config): 67 | if config.model.class_name == "BertClassifier": 68 | model = BertClassifier(config.model.name) 69 | else: 70 | ModelClass = getattr(transformers, config.model.class_name) 71 | LOG.info( 72 | f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}" 73 | ) 74 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr()) 75 | 76 | if config.model.pt is not None: 77 | LOG.info(f"Loading model initialization from 
{config.model.pt}") 78 | state_dict = torch.load(config.model.pt, map_location="cpu") 79 | 80 | try: 81 | model.load_state_dict(state_dict) 82 | except RuntimeError: 83 | LOG.info("Default load failed; stripping prefix and trying again.") 84 | state_dict = {re.sub("^model.", "", k): v for k, v in state_dict.items()} 85 | 86 | model.load_state_dict(state_dict) 87 | 88 | LOG.info("Loaded model initialization") 89 | 90 | if config.dropout is not None: 91 | n_reset = 0 92 | for m in model.modules(): 93 | if isinstance(m, nn.Dropout): 94 | m.p = config.dropout 95 | n_reset += 1 96 | 97 | if hasattr(m, "dropout"): # Requires for BART, which uses F.dropout 98 | if isinstance(m.dropout, float): 99 | m.dropout = config.dropout 100 | n_reset += 1 101 | 102 | if hasattr( 103 | m, "activation_dropout" 104 | ): # Requires for BART, which uses F.dropout 105 | if isinstance(m.activation_dropout, float): 106 | m.activation_dropout = config.dropout 107 | n_reset += 1 108 | 109 | LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}") 110 | 111 | param_names = [n for n, _ in model.named_parameters()] 112 | bad_inner_params = [p for p in config.model.inner_params if p not in param_names] 113 | if len(bad_inner_params) != 0: 114 | raise ValueError( 115 | f"Params {bad_inner_params} do not exist in model of type {type(model)}." 116 | ) 117 | 118 | if config.no_grad_layers is not None: 119 | if config.half: 120 | model.bfloat16() 121 | 122 | def upcast(mod): 123 | modlist = None 124 | for child in mod.children(): 125 | if isinstance(child, nn.ModuleList): 126 | assert modlist is None, f"Found multiple modlists for {mod}" 127 | modlist = child 128 | if modlist is None: 129 | raise RuntimeError("Couldn't find a ModuleList child") 130 | 131 | LOG.info( 132 | f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting" 133 | ) 134 | modlist[config.no_grad_layers :].to(torch.float32) 135 | modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers]) 136 | modlist[-1] = CastModule( 137 | modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16 138 | ) 139 | 140 | parents = [] 141 | if hasattr(model, "transformer"): 142 | parents.append(model.transformer) 143 | if hasattr(model, "encoder"): 144 | parents.append(model.encoder) 145 | if hasattr(model, "decoder"): 146 | parents.append(model.decoder) 147 | if hasattr(model, "model"): 148 | parents.extend([model.model.encoder, model.model.decoder]) 149 | 150 | for t in parents: 151 | t.no_grad_layers = config.no_grad_layers 152 | if config.half: 153 | upcast(t) 154 | 155 | if config.half: 156 | idxs = [] 157 | for p in config.model.inner_params: 158 | for comp in p.split("."): 159 | if comp.isdigit(): 160 | idxs.append(int(comp)) 161 | max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers) 162 | for pidx, p in enumerate(config.model.inner_params): 163 | comps = p.split(".") 164 | if max_idx in comps or min_idx in comps: 165 | index = ( 166 | comps.index(max_idx) 167 | if max_idx in comps 168 | else comps.index(min_idx) 169 | ) 170 | comps.insert(index + 1, "underlying") 171 | new_p = ".".join(comps) 172 | LOG.info( 173 | f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'" 174 | ) 175 | config.model.inner_params[pidx] = new_p 176 | 177 | return model 178 | 179 | 180 | def get_tokenizer(config): 181 | tok_name = ( 182 | config.model.tokenizer_name 183 | if config.model.tokenizer_name is not None 184 | else config.model.name 185 | ) 186 | return getattr(transformers, 
config.model.tokenizer_class).from_pretrained( 187 | tok_name, cache_dir=scr() 188 | ) 189 | 190 | 191 | if __name__ == "__main__": 192 | m = BertClassifier("bert-base-uncased") 193 | m(torch.arange(5)[None, :]) 194 | import pdb 195 | 196 | pdb.set_trace() 197 | -------------------------------------------------------------------------------- /memit/baselines/mend/oracle.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from higher.patch import monkeypatch as make_functional 6 | from losses import kl_loc_loss, masked_log_probs 7 | 8 | 9 | def test_rank1(model, dataset, config): 10 | model.eval() 11 | generator = dataset.edit_generator(21) 12 | 13 | history = [] 14 | for example in generator: 15 | edit_model = make_functional(model, track_higher_grads=False) 16 | residuals = {} 17 | opt_list = [] 18 | print(config.model.inner_params) 19 | for n, p in edit_model.named_parameters(): 20 | if n in config.model.inner_params: 21 | std = 0.01 22 | u = nn.Parameter(torch.randn(p.shape[0], 1, device=p.device) * std) 23 | v = nn.Parameter(torch.randn(1, p.shape[1], device=p.device) * std) 24 | assert ( 25 | u @ v 26 | ).shape == p.shape, f"got {(u@v).shape}, expected {p.shape}" 27 | 28 | residuals[n] = (u, v) 29 | opt_list.extend([u, v]) 30 | 31 | res_opt = torch.optim.SGD(opt_list, lr=100) 32 | 33 | acc = 0 34 | it = 0 35 | ids_train = example["loc_ids"][:10] 36 | ids_val = example["loc_ids"][10:] 37 | with torch.inference_mode(): 38 | original_logits_train = model(ids_train) 39 | original_logits_val = model(ids_val) 40 | if hasattr(original_logits_train, "logits"): 41 | original_logits_train = original_logits_train.logits 42 | original_logits_val = original_logits_val.logits 43 | 44 | while acc < 1 and it < 1000: 45 | fast_params = [] 46 | for n, p in edit_model.named_parameters(): 47 | if n in residuals: 48 | u, v = residuals[n] 49 | fast_params.append(p.detach() + (u @ v)) 50 | else: 51 | fast_params.append(p.detach()) 52 | 53 | loc_pred = edit_model(ids_train, params=fast_params) 54 | if hasattr(loc_pred, "logits"): 55 | loc_pred = loc_pred.logits 56 | 57 | loc_loss = kl_loc_loss(original_logits_train, loc_pred) 58 | 59 | pred_log = edit_model(example["edit_inner_ids"], params=fast_params) 60 | if hasattr(pred_log, "logits"): 61 | pred_log = pred_log.logits 62 | prob_dict = masked_log_probs(pred_log, example["edit_inner_labels"]) 63 | edit_loss = prob_dict["nll"] 64 | acc = prob_dict["acc"] 65 | 66 | loss = loc_loss + 0.0002 * edit_loss 67 | with torch.inference_mode(): 68 | loc_pred_val = edit_model(ids_val, params=fast_params) 69 | if hasattr(loc_pred_val, "logits"): 70 | loc_pred_val = loc_pred_val.logits 71 | 72 | if pred_log.dim() == 3: 73 | facc = ( 74 | ( 75 | pred_log.argmax(-1)[0, -10:-1] 76 | == example["edit_inner_labels"][0, -9:] 77 | ) 78 | .float() 79 | .mean() 80 | ) 81 | ret = ( 82 | (original_logits_val.argmax(-1) == loc_pred_val.argmax(-1)) 83 | .float() 84 | .mean() 85 | ) 86 | else: 87 | facc = (pred_log > 0) == example["edit_inner_labels"] 88 | ret = ( 89 | ((original_logits_val > 0) == (loc_pred_val > 0)).float().mean() 90 | ) 91 | 92 | print( 93 | f"{it}, ({loss.item():.6f}, {loc_loss.item():.4f}, {edit_loss.item():.4f}), {facc.item():.2f}, {ret.item():.4f} {(u@v).view(-1).norm().item():.5f}", 94 | end="\r", 95 | ) 96 | 97 | for p, g in zip(opt_list, torch.autograd.grad(loss, opt_list)): 98 | p.grad = g 99 | res_opt.step() 100 | res_opt.zero_grad() 101 | 102 | 
it += 1 103 | 104 | if acc == 1: 105 | history.append(1) 106 | else: 107 | history.append(0) 108 | 109 | print() 110 | print(len(history), sum(history) / len(history), ret.item()) 111 | -------------------------------------------------------------------------------- /memit/baselines/mend/requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | numpy 3 | torch 4 | click==7.1.2 # Spacy breaks for click>=8.0 5 | spacy 6 | allennlp 7 | git+git://github.com/eric-mitchell/higher@master # For in-place functional models 8 | git+git://github.com/eric-mitchell/transformers@master # To enable gradient disabling for some models 9 | datasets 10 | jsonlines 11 | wandb 12 | -------------------------------------------------------------------------------- /memit/baselines/mend/run.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | import logging 4 | import random 5 | 6 | import hydra 7 | import models 8 | import numpy as np 9 | import torch 10 | import utils 11 | from omegaconf import OmegaConf 12 | from trainer import EditTrainer 13 | 14 | OmegaConf.register_new_resolver("uuid", lambda: utils.uuid()) 15 | 16 | 17 | logging.basicConfig( 18 | format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s", 19 | level=logging.INFO, 20 | ) 21 | LOG = logging.getLogger(__name__) 22 | 23 | 24 | def add_padding(tokenizer, model): 25 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 26 | model.resize_token_embeddings(len(tokenizer)) 27 | model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0) 28 | 29 | 30 | @hydra.main(config_path="config", config_name="config") 31 | def run(config): 32 | LOG.info(f"\n\n{OmegaConf.to_yaml(config)}\n") 33 | base_dir = hydra.utils.get_original_cwd() 34 | LOG.info(f"Project base directory: {base_dir}") 35 | 36 | random.seed(config.seed) 37 | np.random.seed(config.seed) 38 | torch.manual_seed(config.seed) 39 | 40 | model = models.get_model(config) 41 | tokenizer = models.get_tokenizer(config) 42 | 43 | if config.task == "gen" or config.task == "wiki": 44 | add_padding(tokenizer, model) 45 | from data_classes.wiki import GenDataset 46 | 47 | train_set = GenDataset("train", tokenizer, config, config.data.path, pct=10) 48 | val_set = GenDataset("validation", tokenizer, config, config.data.path, pct=10) 49 | elif config.task == "fc" or config.task == "fever": 50 | from data_classes.fever import BinaryAugmentedKILT 51 | 52 | train_set = BinaryAugmentedKILT( 53 | tokenizer, f"{base_dir}/data/fever/fever-train-kilt.jsonl", config 54 | ) 55 | val_set = BinaryAugmentedKILT( 56 | tokenizer, f"{base_dir}/data/fever/fever-dev-kilt.jsonl", config 57 | ) 58 | elif config.task == "qa" or config.task == "zsre": 59 | from data_classes.zsre import Seq2SeqAugmentedKILT 60 | 61 | train_set = Seq2SeqAugmentedKILT( 62 | tokenizer, 63 | f"{base_dir}/data/zsre/structured_zeroshot-train-new_annotated_final.jsonl", 64 | config, 65 | ) 66 | val_set = Seq2SeqAugmentedKILT( 67 | tokenizer, 68 | f"{base_dir}/data/zsre/structured_zeroshot-dev-new_annotated_final.jsonl", 69 | config, 70 | ) 71 | else: 72 | raise ValueError(f"Unrecognized task {config.task}") 73 | 74 | alg_module = importlib.import_module(f"algs.{config.alg}") 75 | LOG.info(f"Loading class {config.alg.upper()} from module {alg_module}") 76 | AlgClass = getattr(alg_module, config.alg.upper()) 77 | alg = AlgClass(model, config, lambda: copy.deepcopy(model)) 78 | 79 | if 
config.alg == "ft" and config.ft.locality.enabled: 80 | if config.ft.locality.oracle: 81 | alg.loc_sampler = train_set.edit_generator( 82 | config.ft.locality.batch_size + 1 83 | ) 84 | else: 85 | state = np.random.get_state() 86 | np.random.seed(0) 87 | loc_batch = next( 88 | train_set.edit_generator(config.ft.locality.batch_size + 1) 89 | )["loc"] 90 | np.random.set_state(state) 91 | alg.loc_ids = loc_batch["input_ids"] 92 | alg.loc_masks = loc_batch["attention_mask"] 93 | 94 | trainer = EditTrainer(alg, config, train_set, val_set) 95 | trainer.run() 96 | 97 | 98 | if __name__ == "__main__": 99 | run() 100 | -------------------------------------------------------------------------------- /memit/dsets/__init__.py: -------------------------------------------------------------------------------- 1 | from .attr_snippets import AttributeSnippets 2 | from .counterfact import CounterFactDataset, MultiCounterFactDataset 3 | from .knowns import KnownsDataset 4 | from .tfidf_stats import get_tfidf_vectorizer 5 | from .zsre import MENDQADataset 6 | -------------------------------------------------------------------------------- /memit/dsets/attr_snippets.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from util.globals import * 8 | 9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/attribute_snippets.json" 10 | 11 | 12 | class AttributeSnippets: 13 | """ 14 | Contains wikipedia snippets discussing entities that have some property. 15 | 16 | More formally, given a tuple t = (s, r, o): 17 | - Let snips = AttributeSnippets(DATA_DIR) 18 | - snips[r][o] is a list of wikipedia articles for all s' such that t' = (s', r, o) is valid. 19 | """ 20 | 21 | def __init__(self, data_dir: str): 22 | data_dir = Path(data_dir) 23 | snips_loc = data_dir / "attribute_snippets.json" 24 | if not snips_loc.exists(): 25 | print(f"{snips_loc} does not exist. Downloading from {REMOTE_URL}") 26 | data_dir.mkdir(exist_ok=True, parents=True) 27 | torch.hub.download_url_to_file(REMOTE_URL, snips_loc) 28 | 29 | with open(snips_loc, "r") as f: 30 | snippets_list = json.load(f) 31 | 32 | snips = collections.defaultdict(lambda: collections.defaultdict(list)) 33 | 34 | for el in snippets_list: 35 | rid, tid = el["relation_id"], el["target_id"] 36 | for sample in el["samples"]: 37 | snips[rid][tid].append(sample) 38 | 39 | self._data = snips 40 | self.snippets_list = snippets_list 41 | 42 | def __getitem__(self, item): 43 | return self._data[item] 44 | -------------------------------------------------------------------------------- /memit/dsets/counterfact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from util.globals import * 9 | 10 | REMOTE_ROOT = f"{REMOTE_ROOT_URL}/data/dsets" 11 | 12 | 13 | class ConceptvectorsDataset(Dataset): 14 | def __init__( 15 | self, 16 | data_dir: str, 17 | multi: bool = False, 18 | size: typing.Optional[int] = None, 19 | *args, 20 | **kwargs, 21 | ): 22 | data_dir = Path(data_dir) 23 | cf_loc = data_dir / ( 24 | "counterfact.json" if not multi else "multi_counterfact.json" 25 | ) 26 | if not cf_loc.exists(): 27 | remote_url = f"{REMOTE_ROOT}/{'multi_' if multi else ''}counterfact.json" 28 | print(f"{cf_loc} does not exist. 
Downloading from {remote_url}") 29 | data_dir.mkdir(exist_ok=True, parents=True) 30 | torch.hub.download_url_to_file(remote_url, cf_loc) 31 | 32 | with open(cf_loc, "r") as f: 33 | self.data = json.load(f) 34 | if size is not None: 35 | self.data = self.data[:size] 36 | 37 | print(f"Loaded dataset with {len(self)} elements") 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, item): 43 | return self.data[item] 44 | 45 | 46 | class CounterFactDataset(Dataset): 47 | def __init__( 48 | self, 49 | data_dir: str, 50 | multi: bool = False, 51 | size: typing.Optional[int] = None, 52 | *args, 53 | **kwargs, 54 | ): 55 | data_dir = Path(data_dir) 56 | cf_loc = data_dir / ( 57 | "counterfact.json" if not multi else "multi_counterfact.json" 58 | ) 59 | if not cf_loc.exists(): 60 | remote_url = f"{REMOTE_ROOT}/{'multi_' if multi else ''}counterfact.json" 61 | print(f"{cf_loc} does not exist. Downloading from {remote_url}") 62 | data_dir.mkdir(exist_ok=True, parents=True) 63 | torch.hub.download_url_to_file(remote_url, cf_loc) 64 | 65 | with open(cf_loc, "r") as f: 66 | self.data = json.load(f) 67 | if size is not None: 68 | self.data = self.data[:size] 69 | 70 | print(f"Loaded dataset with {len(self)} elements") 71 | 72 | def __len__(self): 73 | return len(self.data) 74 | 75 | def __getitem__(self, item): 76 | return self.data[item] 77 | 78 | 79 | class MultiCounterFactDataset(CounterFactDataset): 80 | def __init__( 81 | self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs 82 | ): 83 | super().__init__(data_dir, *args, multi=True, size=size, **kwargs) 84 | -------------------------------------------------------------------------------- /memit/dsets/knowns.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from util.globals import * 9 | 10 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/known_1000.json" 11 | 12 | 13 | class KnownsDataset(Dataset): 14 | def __init__(self, data_dir: str, *args, **kwargs): 15 | data_dir = Path(data_dir) 16 | known_loc = data_dir / "known_1000.json" 17 | if not known_loc.exists(): 18 | print(f"{known_loc} does not exist. Downloading from {REMOTE_URL}") 19 | data_dir.mkdir(exist_ok=True, parents=True) 20 | torch.hub.download_url_to_file(REMOTE_URL, known_loc) 21 | 22 | with open(known_loc, "r") as f: 23 | self.data = json.load(f) 24 | 25 | print(f"Loaded dataset with {len(self)} elements") 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | def __getitem__(self, item): 31 | return self.data[item] 32 | -------------------------------------------------------------------------------- /memit/dsets/tfidf_stats.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import scipy.sparse as sp 7 | import torch 8 | from sklearn.feature_extraction.text import TfidfVectorizer 9 | 10 | from dsets import AttributeSnippets 11 | from util.globals import * 12 | 13 | REMOTE_IDF_URL = f"{REMOTE_ROOT_URL}/data/dsets/idf.npy" 14 | REMOTE_VOCAB_URL = f"{REMOTE_ROOT_URL}/data/dsets/tfidf_vocab.json" 15 | 16 | 17 | def get_tfidf_vectorizer(data_dir: str): 18 | """ 19 | Returns an sklearn TF-IDF vectorizer. See their website for docs. 20 | Loading hack inspired by some online blog post lol. 
21 | """ 22 | 23 | data_dir = Path(data_dir) 24 | 25 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json" 26 | if not (idf_loc.exists() and vocab_loc.exists()): 27 | collect_stats(data_dir) 28 | 29 | idf = np.load(idf_loc) 30 | with open(vocab_loc, "r") as f: 31 | vocab = json.load(f) 32 | 33 | class MyVectorizer(TfidfVectorizer): 34 | TfidfVectorizer.idf_ = idf 35 | 36 | vec = MyVectorizer() 37 | vec.vocabulary_ = vocab 38 | vec._tfidf._idf_diag = sp.spdiags(idf, diags=0, m=len(idf), n=len(idf)) 39 | 40 | return vec 41 | 42 | 43 | def collect_stats(data_dir: str): 44 | """ 45 | Uses wikipedia snippets to collect statistics over a corpus of English text. 46 | Retrieved later when computing TF-IDF vectors. 47 | """ 48 | 49 | data_dir = Path(data_dir) 50 | data_dir.mkdir(exist_ok=True, parents=True) 51 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json" 52 | 53 | try: 54 | print(f"Downloading IDF cache from {REMOTE_IDF_URL}") 55 | torch.hub.download_url_to_file(REMOTE_IDF_URL, idf_loc) 56 | print(f"Downloading TF-IDF vocab cache from {REMOTE_VOCAB_URL}") 57 | torch.hub.download_url_to_file(REMOTE_VOCAB_URL, vocab_loc) 58 | return 59 | except Exception as e: 60 | print(f"Error downloading file:", e) 61 | print("Recomputing TF-IDF stats...") 62 | 63 | snips_list = AttributeSnippets(data_dir).snippets_list 64 | documents = list(chain(*[[y["text"] for y in x["samples"]] for x in snips_list])) 65 | 66 | vec = TfidfVectorizer() 67 | vec.fit(documents) 68 | 69 | idfs = vec.idf_ 70 | vocab = vec.vocabulary_ 71 | 72 | np.save(data_dir / "idf.npy", idfs) 73 | with open(data_dir / "tfidf_vocab.json", "w") as f: 74 | json.dump(vocab, f, indent=1) 75 | -------------------------------------------------------------------------------- /memit/dsets/zsre.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | from util.globals import * 8 | 9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/zsre_mend_eval.json" 10 | 11 | 12 | class MENDQADataset: 13 | """ 14 | Dataset of factual knowledge based on zsRE. 15 | Specifically selected from the QA validation slice from Mitchell et al. 16 | Project page: http://nlp.cs.washington.edu/zeroshot/ 17 | """ 18 | 19 | def __init__(self, data_dir: str, tok: AutoTokenizer, size=None, *args, **kwargs): 20 | data_dir = Path(data_dir) 21 | zsre_loc = data_dir / "zsre_mend_eval.json" 22 | if not zsre_loc.exists(): 23 | print(f"{zsre_loc} does not exist. Downloading from {REMOTE_URL}") 24 | data_dir.mkdir(exist_ok=True, parents=True) 25 | torch.hub.download_url_to_file(REMOTE_URL, zsre_loc) 26 | 27 | with open(zsre_loc, "r") as f: 28 | raw = json.load(f) 29 | 30 | data = [] 31 | for i, record in enumerate(raw): 32 | assert ( 33 | "nq question: " in record["loc"] 34 | ), f"Neighborhood prompt missing `nq question:`. Check for errors?" 35 | ans_toks = tok(" " + record["loc_ans"])["input_ids"] 36 | data.append( 37 | { 38 | "case_id": i, 39 | "requested_rewrite": { 40 | "prompt": record["src"].replace(record["subject"], "{}"), 41 | "subject": record["subject"], 42 | "target_new": {"str": record["answers"][0]}, 43 | "target_true": {"str": "<|endoftext|>"}, 44 | }, 45 | "paraphrase_prompts": [record["rephrase"]], 46 | "neighborhood_prompts": [ 47 | { 48 | "prompt": record["loc"] + "?" 
+ tok.decode(ans_toks[:i]), 49 | "target": tok.decode(ans_toks[i]), 50 | } 51 | for i in range(len(ans_toks)) 52 | ], 53 | "attribute_prompts": [], 54 | "generation_prompts": [], 55 | } 56 | ) 57 | 58 | self._data = data[:size] 59 | 60 | def __getitem__(self, item): 61 | return self._data[item] 62 | 63 | def __len__(self): 64 | return len(self._data) 65 | -------------------------------------------------------------------------------- /memit/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/memit/experiments/__init__.py -------------------------------------------------------------------------------- /memit/experiments/py/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, List, Tuple 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from baselines.ft import FTHyperParams, apply_ft_to_model 9 | from memit import MEMITHyperParams, apply_memit_to_model 10 | from rome import ROMEHyperParams, apply_rome_to_model 11 | from util import nethook 12 | from util.generate import generate_fast 13 | from util.globals import * 14 | 15 | 16 | def demo_model_editing( 17 | model: AutoModelForCausalLM, 18 | tok: AutoTokenizer, 19 | requests: List[Dict], 20 | generation_prompts: List[str], 21 | alg_name: str = "ROME", 22 | ) -> Tuple[AutoModelForCausalLM, Dict[str, torch.Tensor]]: 23 | """ 24 | Applies the selected model editing algorithm. Generates text both before and after 25 | for comparison of model behavior. Returns the updated model and the original values of 26 | weights that were changed. 27 | """ 28 | 29 | nethook.set_requires_grad(True, model) 30 | 31 | RewritingParamsClass, apply_method, hparams_prefix, hparams_suffix = load_alg( 32 | alg_name 33 | ) 34 | params_name = ( 35 | HPARAMS_DIR 36 | / hparams_prefix 37 | / f"{model.config._name_or_path.replace('/', '_')}{hparams_suffix}.json" 38 | ) 39 | 40 | print_loud(f"Retrieving {alg_name} hyperparameters") 41 | print("Loading from", params_name) 42 | hparams = RewritingParamsClass.from_json(params_name) 43 | print(hparams) 44 | 45 | print_loud("Generating pre-update text") 46 | pre_update_text = generate_fast(model, tok, generation_prompts, max_out_len=100) 47 | print(pre_update_text) 48 | 49 | print_loud(f"Applying {alg_name} to model") 50 | model_new, orig_weights = apply_method( 51 | model, 52 | tok, 53 | requests, 54 | hparams, 55 | return_orig_weights=True, 56 | ) 57 | 58 | print_loud("Generating post-update text") 59 | post_update_text = generate_fast( 60 | model_new, tok, generation_prompts, max_out_len=100 61 | ) 62 | print(post_update_text) 63 | 64 | print_loud("Summarizing differences") 65 | for i, (prompt, pre, post) in enumerate( 66 | zip(generation_prompts, pre_update_text, post_update_text) 67 | ): 68 | if i > 0: 69 | print("".join(["-" for _ in range(10)])) 70 | 71 | prompt_str = "[Prompt]:" 72 | pre_str = f"[Pre-{alg_name}]:" 73 | post_str = f"[Post-{alg_name}]:" 74 | pad_to = 1 + max(len(prompt_str), len(pre_str), len(post_str)) 75 | 76 | for s, t in zip([prompt_str, post_str, pre_str], [prompt, post, pre]): 77 | print(s.ljust(pad_to), t) 78 | 79 | return model_new, orig_weights 80 | 81 | 82 | def load_alg(alg_name): 83 | """ 84 | Loads dependencies for the desired algorithm. 
85 | Implementation is slightly awkward to prevent unnecessary imports on Colab. 86 | 87 | The return value is a tuple of the following: 88 | 1. Class for storing hyperparameters 89 | 2. Method for applying rewrites 90 | 3. Location of parameters 91 | 4. Predefined suffix for the param file 92 | """ 93 | assert alg_name in [ 94 | "FT", 95 | "FT-L", 96 | "FT-AttnEdit", 97 | "MEND", 98 | "MEND-CF", 99 | "MEND-zsRE", 100 | "ROME", 101 | "MEMIT", 102 | ] 103 | 104 | if alg_name == "ROME": 105 | return ROMEHyperParams, apply_rome_to_model, "ROME", "" 106 | elif alg_name == "MEMIT": 107 | return MEMITHyperParams, apply_memit_to_model, "MEMIT", "" 108 | elif "FT" in alg_name: 109 | d = { 110 | "FT": (FTHyperParams, apply_ft_to_model, "FT", "_unconstr"), 111 | "FT-AttnEdit": (FTHyperParams, apply_ft_to_model, "FT", "_attn"), 112 | "FT-L": (FTHyperParams, apply_ft_to_model, "FT", "_constr"), 113 | } 114 | return d[alg_name] 115 | else: 116 | from baselines.mend import MENDHyperParams, MendRewriteExecutor 117 | 118 | d = { 119 | "MEND": (MENDHyperParams, MendRewriteExecutor().apply_to_model, "MEND", ""), 120 | "MEND-CF": ( 121 | MENDHyperParams, 122 | MendRewriteExecutor().apply_to_model, 123 | "MEND", 124 | "_CF", 125 | ), 126 | "MEND-zsRE": ( 127 | MENDHyperParams, 128 | MendRewriteExecutor().apply_to_model, 129 | "MEND", 130 | "_zsRE", 131 | ), 132 | } 133 | return d[alg_name] 134 | 135 | 136 | def print_loud(x, pad=3): 137 | """ 138 | Prints a string with # box for emphasis. 139 | 140 | Example: 141 | ############################ 142 | # # 143 | # Applying ROME to model # 144 | # # 145 | ############################ 146 | """ 147 | 148 | n = len(x) 149 | print() 150 | print("".join(["#" for _ in range(n + 2 * pad)])) 151 | print("#" + "".join([" " for _ in range(n + 2 * (pad - 1))]) + "#") 152 | print( 153 | "#" 154 | + "".join([" " for _ in range(pad - 1)]) 155 | + x 156 | + "".join([" " for _ in range(pad - 1)]) 157 | + "#" 158 | ) 159 | print("#" + "".join([" " for _ in range(n + 2 * (pad - 1))]) + "#") 160 | print("".join(["#" for _ in range(n + 2 * pad)])) 161 | 162 | 163 | class StopExecution(Exception): 164 | def _render_traceback_(self): 165 | pass 166 | 167 | 168 | def stop_execution(): 169 | raise StopExecution 170 | -------------------------------------------------------------------------------- /memit/experiments/py/eval_utils_zsre.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains evaluation utilities for pytorch-based rewriting methods. 3 | To use, simply call `compute_rewrite_quality_zsre` with the 4 | appropriate arguments, which returns a dictionary containing them. 5 | """ 6 | 7 | import typing 8 | from itertools import chain 9 | 10 | import numpy as np 11 | import torch 12 | from sklearn.feature_extraction.text import TfidfVectorizer 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | from dsets import AttributeSnippets 16 | 17 | 18 | def compute_rewrite_quality_zsre( 19 | model: AutoModelForCausalLM, 20 | tok: AutoTokenizer, 21 | record: typing.Dict, 22 | snips: AttributeSnippets, 23 | vec: TfidfVectorizer, 24 | ) -> typing.Dict: 25 | """ 26 | Given a rewritten model, computes generalization and specificity metrics for 27 | the desired rewrite (passed in via the CounterFact dataset record). Returns a 28 | dictionary containing those metrics. 29 | 30 | :param model: Rewritten model 31 | :param tok: Tokenizer 32 | :param record: CounterFact dataset record 33 | :paran snips: ??? 34 | :param vec: ??? 
35 | :return: Dictionary containing rewriting metrics 36 | """ 37 | 38 | # First, unpack rewrite evaluation record. 39 | subject, target_new, target_true = ( 40 | record["requested_rewrite"][x] for x in ["subject", "target_new", "target_true"] 41 | ) 42 | rewrite_prompts = [record["requested_rewrite"]["prompt"].format(subject)] 43 | paraphrase_prompts = record["paraphrase_prompts"] 44 | neighborhood_prompts = record["neighborhood_prompts"] 45 | 46 | # Form a list of lists of prefixes to test. 47 | prob_prompts = [ 48 | rewrite_prompts, 49 | paraphrase_prompts, 50 | ] 51 | # Flatten all the evaluated prefixes into one list. 52 | target_tok = tok(" " + target_new["str"])["input_ids"] 53 | inp_prompts_og = list(chain(*prob_prompts)) 54 | inp_prompts = [ 55 | el + tok.decode(target_tok[:i]) 56 | for el in inp_prompts_og 57 | for i in range(len(target_tok)) 58 | ] 59 | inp_targets = [ 60 | tok.decode(target_tok[i]) 61 | for _ in range(len(inp_prompts_og)) 62 | for i in range(len(target_tok)) 63 | ] 64 | 65 | stuff_probs = test_batch_prediction_acc(model, tok, inp_prompts, inp_targets) 66 | 67 | # Predict for neighborhood prompts (dictionary format). 68 | neighborhood_correct = test_batch_prediction_acc( 69 | model, 70 | tok, 71 | [ 72 | el["prompt"].format(record["requested_rewrite"]) 73 | for el in neighborhood_prompts 74 | ], 75 | [el["target"] for el in neighborhood_prompts], 76 | ) 77 | 78 | probs = stuff_probs + neighborhood_correct 79 | 80 | # Unflatten the results again into a list of lists. 81 | cutoffs = [0] + np.cumsum( 82 | [l * len(target_tok) for l in map(len, prob_prompts)] 83 | ).tolist() 84 | ret_probs = [probs[cutoffs[i - 1] : cutoffs[i]] for i in range(1, len(cutoffs))] 85 | # Structure the restuls as a dictionary. 86 | ret = { 87 | f"{key}_correct": ret_probs[i] 88 | for i, key in enumerate( 89 | [ 90 | "rewrite_prompts", 91 | "paraphrase_prompts", 92 | ] 93 | ) 94 | } 95 | ret["neighborhood_prompts_correct"] = neighborhood_correct 96 | 97 | return ret 98 | 99 | 100 | def test_batch_prediction_acc(model, tok, prompts: typing.List[str], target): 101 | prompt_tok = tok( 102 | prompts, 103 | padding=True, 104 | return_tensors="pt", 105 | ).to("cuda") 106 | 107 | with torch.no_grad(): 108 | logits = model(**prompt_tok).logits 109 | last_non_masked = prompt_tok["attention_mask"].sum(1) - 1 110 | to_gather = last_non_masked.unsqueeze(1).repeat(1, logits.size(-1)).unsqueeze(1) 111 | gathered = torch.gather(logits, 1, to_gather).squeeze(1) 112 | ans = torch.argmax(gathered, dim=1) 113 | 114 | correct_id = tok(target, padding=True, return_tensors="pt").to("cuda")[ 115 | "input_ids" 116 | ] 117 | # Temporary hack to deal with foreign characters. 
118 | correct_id = correct_id[:, 0].squeeze() 119 | 120 | return (ans == correct_id).detach().cpu().numpy().tolist() 121 | -------------------------------------------------------------------------------- /memit/experiments/sweep.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from typing import Dict, List, Tuple 7 | 8 | import torch 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | from experiments.evaluate import HPARAMS_DIR 12 | from experiments.evaluate import main as eval_main 13 | 14 | TMP_PARAMS_TEMPLATE = "sweep_params_tmp_{}_.json" 15 | 16 | 17 | def exec_sweep( 18 | alg_name: str, 19 | model_tok: Tuple[AutoModelForCausalLM, AutoTokenizer], 20 | hparams_fname: str, 21 | ds_name: str, 22 | sweep_dir: Path, 23 | num_records: int, 24 | generation_test_interval: bool, 25 | num_edits: int, 26 | use_cache: bool, 27 | ): 28 | # Configure hparams 29 | with open(HPARAMS_DIR / alg_name / hparams_fname, "r") as f: 30 | hparams_orig = json.load(f) 31 | with open(Path("results") / sweep_dir / "config.json", "r") as f: 32 | sweep_config = json.load(f) 33 | sweep_keys = list(sweep_config.keys()) 34 | 35 | # Sweep 36 | for s_i, state in enumerate(get_states([], sweep_config, sweep_keys)): 37 | # Set dirs 38 | tmp_params_name = TMP_PARAMS_TEMPLATE.format(time.time_ns()) 39 | tmp_params_path = HPARAMS_DIR / alg_name / tmp_params_name 40 | 41 | # Set new hparams 42 | hparams_new = deepcopy(hparams_orig) 43 | for key_num, state_num in enumerate(state): 44 | k = sweep_keys[key_num] 45 | hparams_new[k] = sweep_config[k][state_num] 46 | print(f"Sweep {s_i}: Setting {k} = {hparams_new[k]}") 47 | 48 | with open(tmp_params_path, "w") as f: 49 | json.dump(hparams_new, f) 50 | 51 | # Execute 52 | eval_main( 53 | alg_name, 54 | model_name=model_tok, 55 | hparams_fname=tmp_params_name, 56 | ds_name=ds_name, 57 | dataset_size_limit=num_records, 58 | continue_from_run="run_000", 59 | skip_generation_tests=(generation_test_interval == -1), 60 | generation_test_interval=generation_test_interval, 61 | conserve_memory=False, 62 | dir_name=sweep_dir / f"{num_edits}_edits_setting_{s_i}", 63 | num_edits=num_edits, 64 | use_cache=use_cache, 65 | ) 66 | 67 | # Clean up 68 | os.remove(tmp_params_path) 69 | 70 | 71 | def get_states( 72 | state: List, 73 | sweep_config: Dict, 74 | sweep_keys: List, 75 | ): 76 | """ 77 | Standard recursive procedure for generating all possible configurations. 78 | """ 79 | 80 | ans = [] 81 | if len(state) < len(sweep_config): 82 | for i in range(len(sweep_config[sweep_keys[len(state)]])): 83 | for s in get_states(state + [i], sweep_config, sweep_keys): 84 | ans.append(s) 85 | else: 86 | ans.append(state) 87 | return ans 88 | 89 | 90 | if __name__ == "__main__": 91 | import argparse 92 | 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument( 95 | "--alg_name", choices=["MEMIT", "FT", "ROME", "MEND"], required=True 96 | ) 97 | parser.add_argument( 98 | "--model_name", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"], required=True 99 | ) 100 | parser.add_argument("--hparams_fname", type=str, required=True) 101 | parser.add_argument( 102 | "--ds_name", 103 | choices=["mcf", "cf", "zsre"], 104 | default="mcf", 105 | help="Dataset to perform evaluations on. 
Either CounterFact (cf), MultiCounterFact (mcf), or zsRE (zsre).", 106 | ) 107 | parser.add_argument("--min_records", type=int, default=None) 108 | parser.add_argument("--max_records", type=int, default=None) 109 | parser.add_argument( 110 | "--num_edits", 111 | type=str, 112 | default="1", 113 | help="Number of rewrites to perform simultaneously.", 114 | ) 115 | parser.add_argument( 116 | "--generation_test_interval", 117 | type=int, 118 | default=-1, 119 | help="One generation test is performed every [flag_value] iterations. If -1, generation tests are skipped.", 120 | ) 121 | parser.add_argument("--sweep_dir", type=str) 122 | parser.add_argument( 123 | "--use_cache", 124 | dest="use_cache", 125 | action="store_true", 126 | help="Use cached k/v pairs (MEMIT and ROME only)", 127 | ) 128 | 129 | args = parser.parse_args() 130 | assert args.sweep_dir is not None, f"Must specify a sweep_dir." 131 | 132 | model = AutoModelForCausalLM.from_pretrained(args.model_name).to("cuda") 133 | tok = AutoTokenizer.from_pretrained(args.model_name) 134 | tok.pad_token = tok.eos_token 135 | 136 | for cur_num_edits in list(map(int, args.num_edits.split(","))): 137 | torch.cuda.empty_cache() 138 | 139 | num_records = ( 140 | None if args.max_records is None 141 | else min(args.max_records, cur_num_edits) 142 | ) 143 | if args.min_records is not None: 144 | num_records = max(args.min_records, cur_num_edits) 145 | 146 | exec_sweep( 147 | args.alg_name, 148 | (model, tok), 149 | args.hparams_fname, 150 | args.ds_name, 151 | Path(args.sweep_dir), 152 | num_records, 153 | args.generation_test_interval, 154 | cur_num_edits, 155 | args.use_cache, 156 | ) 157 | -------------------------------------------------------------------------------- /memit/forget_memit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Call the program in a loop, passing a different order argument each time 4 | #for i in {0..94} 5 | #do 6 | # #python -m experiments.evaluate --order=$i 7 | # python -m experiments.evaluate --order=$i 8 | # 9 | #done 10 | 11 | for i in 16 18 21 26 27 38 42 47 49 54 12 | do 13 | #python -m experiments.evaluate --order=$i 14 | python -m experiments.memit_jailbreak_evaluate --order=$i 15 | 16 | done -------------------------------------------------------------------------------- /memit/forget_memit_olmo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Call the program in a loop, passing a different order argument each time 4 | for i in {37..161} 5 | do 6 | python -m experiments.evaluate_olmo --order=$i 7 | 8 | done 9 | 10 | #for i in 4 37 40 44 59 77 90 105 141 147 11 | #do 12 | # python -m experiments.memit_jailbreak_evaluate_olmo --order=$i 13 | # 14 | #done 15 | -------------------------------------------------------------------------------- /memit/globals.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # Result files 3 | RESULTS_DIR: "results" 4 | 5 | # Data files 6 | DATA_DIR: "data" 7 | STATS_DIR: "data/stats" 8 | KV_DIR: "/share/projects/rewriting-knowledge/kvs" 9 | 10 | # Hyperparameters 11 | HPARAMS_DIR: "hparams" 12 | 13 | # Remote URLs 14 | REMOTE_ROOT_URL: "https://memit.baulab.info" 15 | -------------------------------------------------------------------------------- /memit/hparams/FT/EleutherAI_gpt-j-6B_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 5e-5, 10 | 
"rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "lm_head" 16 | } -------------------------------------------------------------------------------- /memit/hparams/FT/EleutherAI_gpt-j-6B_unconstr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 21 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": false, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "lm_head" 16 | } -------------------------------------------------------------------------------- /memit/hparams/FT/EleutherAI_gpt-j-6B_wd.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 21 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": false, 8 | "wd_power_law": [-0.87028798, 0.15589562], 9 | "kl_factor": 0, 10 | "norm_constraint": false, 11 | "rewrite_module_tmp": "transformer.h.{}", 12 | "layer_module_tmp": "transformer.h.{}", 13 | "mlp_module_tmp": "transformer.h.{}.mlp", 14 | "attn_module_tmp": "transformer.h.{}.attn", 15 | "ln_f_module": "transformer.ln_f", 16 | "lm_head_module": "lm_head" 17 | } -------------------------------------------------------------------------------- /memit/hparams/FT/gpt2-large_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 1e-3, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /memit/hparams/FT/gpt2-medium_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 2e-3, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /memit/hparams/FT/gpt2-xl_attn.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 33 4 | ], 5 | "num_steps": 25, 6 | "lr": 1e-3, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 1e-3, 10 | "rewrite_module_tmp": "transformer.h.{}.attn.c_attn", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } 
-------------------------------------------------------------------------------- /memit/hparams/FT/gpt2-xl_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 5e-4, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /memit/hparams/FT/gpt2-xl_unconstr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 1 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": false, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /memit/hparams/MEMIT/EleutherAI_gpt-j-6B.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 3, 4, 5, 6, 7, 8 4 | ], 5 | "clamp_norm_factor": 0.75, 6 | "layer_selection": "all", 7 | "fact_token": "subject_last", 8 | "v_num_grad_steps": 25, 9 | "v_lr": 5e-1, 10 | "v_loss_layer": 27, 11 | "v_weight_decay": 0.5, 12 | "kl_factor": 0.0625, 13 | "mom2_adjustment": true, 14 | "mom2_update_weight": 15000, 15 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 16 | "layer_module_tmp": "transformer.h.{}", 17 | "mlp_module_tmp": "transformer.h.{}.mlp", 18 | "attn_module_tmp": "transformer.h.{}.attn", 19 | "ln_f_module": "transformer.ln_f", 20 | "lm_head_module": "lm_head", 21 | "mom2_dataset": "wikipedia", 22 | "mom2_n_samples": 100000, 23 | "mom2_dtype": "float32" 24 | } 25 | -------------------------------------------------------------------------------- /memit/hparams/MEMIT/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [13, 14, 15, 16, 17], 3 | "clamp_norm_factor": 0.75, 4 | "layer_selection": "all", 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 47, 9 | "v_weight_decay": 0.5, 10 | "kl_factor": 0.0625, 11 | "mom2_adjustment": true, 12 | "mom2_update_weight": 20000, 13 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 14 | "layer_module_tmp": "transformer.h.{}", 15 | "mlp_module_tmp": "transformer.h.{}.mlp", 16 | "attn_module_tmp": "transformer.h.{}.attn", 17 | "ln_f_module": "transformer.ln_f", 18 | "lm_head_module": "transformer.wte", 19 | "mom2_dataset": "wikipedia", 20 | "mom2_n_samples": 100000, 21 | "mom2_dtype": "float32" 22 | } -------------------------------------------------------------------------------- /memit/hparams/MEMIT/llama2-7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [23], 3 | "clamp_norm_factor": 4, 4 | "layer_selection": "all", 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 25, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 31, 9 | "v_weight_decay": 1e-3, 10 | "kl_factor": 0.0625, 11 | 
"mom2_adjustment": true, 12 | "mom2_update_weight": 15000, 13 | "rewrite_module_tmp": "model.layers.{}.mlp.down_proj", 14 | "layer_module_tmp": "model.layers.{}", 15 | "mlp_module_tmp": "model.layers.{}.mlp", 16 | "attn_module_tmp": "model.layers.{}.self_attn", 17 | "ln_f_module": "model.norm", 18 | "lm_head_module": "lm_head", 19 | "mom2_dataset": "wikipedia", 20 | "mom2_n_samples": 100000, 21 | "mom2_dtype": "float32" 22 | } -------------------------------------------------------------------------------- /memit/hparams/MEMIT/olmo-7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [23], 3 | "clamp_norm_factor": 4, 4 | "layer_selection": "all", 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 25, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 31, 9 | "v_weight_decay": 1e-3, 10 | "kl_factor": 0.0625, 11 | "mom2_adjustment": true, 12 | "mom2_update_weight": 15000, 13 | "rewrite_module_tmp": "model.transformer.blocks.{}.ff_out", 14 | "layer_module_tmp": "model.transformer.blocks.{}", 15 | "mlp_module_tmp": "model.transformer.blocks.{}.mlp", 16 | "attn_module_tmp": "model.transformer.blocks.{}.self_attn", 17 | "ln_f_module": "model.transformer.ln_f", 18 | "lm_head_module": "model.transformer.ff_out", 19 | "mom2_dataset": "wikipedia", 20 | "mom2_n_samples": 100000, 21 | "mom2_dtype": "float32" 22 | } 23 | 24 | -------------------------------------------------------------------------------- /memit/hparams/MEND/EleutherAI_gpt-j-6B.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 10, 4 | "model_name": "EleutherAI/gpt-j-6B", 5 | "counterfact": false, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /memit/hparams/MEND/EleutherAI_gpt-j-6B_CF.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 10, 4 | "model_name": "EleutherAI/gpt-j-6B", 5 | "counterfact": true, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /memit/hparams/MEND/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 10, 4 | "model_name": "gpt2-xl", 5 | "counterfact": false, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /memit/hparams/MEND/gpt2-xl_CF.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 1, 4 | "model_name": "gpt2-xl", 5 | "counterfact": true, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /memit/hparams/MEND/gpt2-xl_zsRE.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 1, 4 | "model_name": "gpt2-xl", 5 | "counterfact": false, 6 | "zsre": true, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /memit/hparams/ROME/EleutherAI_gpt-j-6B.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 5 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 27, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | 
"kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "lm_head", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /memit/hparams/ROME/EleutherAI_gpt-neox-20b.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 15 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 43, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "gpt_neox.layers.{}.mlp.dense_4h_to_h", 15 | "layer_module_tmp": "gpt_neox.layers.{}", 16 | "mlp_module_tmp": "gpt_neox.layers.{}.mlp", 17 | "attn_module_tmp": "gpt_neox.layers.{}.attention", 18 | "ln_f_module": "gpt_neox.final_layer_norm", 19 | "lm_head_module": "embed_out", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /memit/hparams/ROME/gpt2-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 12 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 35, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "transformer.wte", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /memit/hparams/ROME/gpt2-medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 8 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 23, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 3, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "transformer.wte", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /memit/hparams/ROME/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 17 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 47, 9 | 
"v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "transformer.wte", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /memit/memit/__init__.py: -------------------------------------------------------------------------------- 1 | from .memit_main import MEMITHyperParams, apply_memit_to_model -------------------------------------------------------------------------------- /memit/memit/compute_ks.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from .compute_z import get_module_input_output_at_words 8 | from .memit_hparams import MEMITHyperParams 9 | 10 | 11 | def compute_ks( 12 | model: AutoModelForCausalLM, 13 | tok: AutoTokenizer, 14 | requests: Dict, 15 | hparams: MEMITHyperParams, 16 | layer: int, 17 | context_templates: List[str], 18 | ): 19 | layer_ks = get_module_input_output_at_words( 20 | model, 21 | tok, 22 | layer, 23 | context_templates=[ 24 | context.format(request["prompt"]) 25 | for request in requests 26 | for context_type in context_templates 27 | for context in context_type 28 | ], 29 | words=[ 30 | request["subject"] 31 | for request in requests 32 | for context_type in context_templates 33 | for _ in context_type 34 | ], 35 | module_template=hparams.rewrite_module_tmp, 36 | fact_token_strategy=hparams.fact_token, 37 | )[0] 38 | 39 | context_type_lens = [0] + [len(context_type) for context_type in context_templates] 40 | context_len = sum(context_type_lens) 41 | context_type_csum = np.cumsum(context_type_lens).tolist() 42 | 43 | ans = [] 44 | for i in range(0, layer_ks.size(0), context_len): 45 | tmp = [] 46 | for j in range(len(context_type_csum) - 1): 47 | start, end = context_type_csum[j], context_type_csum[j + 1] 48 | tmp.append(layer_ks[i + start : i + end].mean(0)) 49 | ans.append(torch.stack(tmp, 0).mean(0)) 50 | return torch.stack(ans, dim=0) 51 | -------------------------------------------------------------------------------- /memit/memit/memit_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Literal 3 | 4 | from util.hparams import HyperParams 5 | 6 | 7 | @dataclass 8 | class MEMITHyperParams(HyperParams): 9 | # Method 10 | layers: List[int] 11 | layer_selection: Literal["all", "random"] 12 | fact_token: Literal[ 13 | "last", "subject_first", "subject_last", "subject_first_after_last" 14 | ] 15 | v_num_grad_steps: int 16 | v_lr: float 17 | v_loss_layer: int 18 | v_weight_decay: float 19 | clamp_norm_factor: float 20 | kl_factor: float 21 | mom2_adjustment: bool 22 | mom2_update_weight: float 23 | 24 | # Module templates 25 | rewrite_module_tmp: str 26 | layer_module_tmp: str 27 | mlp_module_tmp: str 28 | attn_module_tmp: str 29 | ln_f_module: str 30 | lm_head_module: str 31 | 32 | # Statistics 33 | mom2_dataset: str 34 | mom2_n_samples: int 35 | mom2_dtype: str 36 | 
-------------------------------------------------------------------------------- /memit/notebooks/vis/visualize_multi_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d465f696", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "9bdfca4c", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "%matplotlib inline\n", 22 | "%config InlineBackend.figure_format = 'retina'\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from experiments.summarize import main as summarize_main\n", 25 | "from pathlib import Path\n", 26 | "import math" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "451eb471", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "RESULTS_DIR = Path(\"results/iclr\")\n", 37 | "DATA = {}\n", 38 | "KEYS = None\n", 39 | "for method_dir in RESULTS_DIR.iterdir():\n", 40 | " method_name = str(method_dir).split(\"/\")[-1]\n", 41 | " print(method_name)\n", 42 | " n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n", 43 | " for n_edit_folder in n_edit_folders:\n", 44 | " n_edits = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0]\n", 45 | " try:\n", 46 | " res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n", 47 | "\n", 48 | " DATA[method_name] = DATA.get(method_name, {})\n", 49 | " DATA[method_name][n_edits] = res\n", 50 | " if KEYS is None:\n", 51 | " KEYS = list(res.keys())\n", 52 | " except:\n", 53 | " pass\n", 54 | "\n", 55 | "print({k: list(v.keys()) for k, v in DATA.items()})" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "7b9f0860", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "plt.rcParams[\"figure.dpi\"] = 200\n", 66 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 67 | "\n", 68 | "SMALL_SIZE = 14\n", 69 | "MEDIUM_SIZE = 15\n", 70 | "BIGGER_SIZE = 16\n", 71 | "\n", 72 | "plt.rc(\"font\", size=SMALL_SIZE) # controls default text sizes\n", 73 | "plt.rc(\"axes\", titlesize=BIGGER_SIZE) # fontsize of the axes title\n", 74 | "plt.rc(\"axes\", labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n", 75 | "plt.rc(\"xtick\", labelsize=SMALL_SIZE) # fontsize of the tick labels\n", 76 | "plt.rc(\"ytick\", labelsize=SMALL_SIZE) # fontsize of the tick labels\n", 77 | "plt.rc(\"legend\", fontsize=SMALL_SIZE) # legend fontsize\n", 78 | "plt.rc(\"figure\", titlesize=BIGGER_SIZE) # fontsize of the figure title" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "d8b41acc", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "TITLES = {\n", 89 | " \"post_score\": \"Score (S)\",\n", 90 | " \"post_rewrite_success\": \"Efficacy Succ. (ES)\",\n", 91 | " \"post_paraphrase_success\": \"Generalization Succ. (PS)\",\n", 92 | " \"post_neighborhood_success\": \"Specificity Succ. (NS)\",\n", 93 | " \"post_rewrite_acc\": \"Efficacy Acc (EA)\",\n", 94 | " \"post_paraphrase_acc\": \"Generalization Acc. (PA)\",\n", 95 | " \"post_neighborhood_acc\": \"Specificity Acc. 
(NA)\",\n", 96 | " \"post_reference_score\": \"Consistency (RS)\",\n", 97 | "}\n", 98 | "\n", 99 | "SHOW_KEYS = list(TITLES.keys())" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "a1d443f7", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "SHOW_KEYS = KEYS\n", 110 | "SHOW_KEYS.pop(SHOW_KEYS.index(\"run_dir\"))\n", 111 | "TITLES = {k: k for k in SHOW_KEYS}" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "49efeea0", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "w = 4\n", 122 | "h = math.ceil(len(KEYS) / w)\n", 123 | "plt.figure(figsize=(w * 3.5, h * 2.5))\n", 124 | "\n", 125 | "assert all(k in KEYS for k in SHOW_KEYS)\n", 126 | "for i, key in enumerate(SHOW_KEYS):\n", 127 | " plt.subplot(h, w, i + 1)\n", 128 | " for method, results in sorted([(k, v) for k, v in DATA.items() if \"_fix\" not in k]):\n", 129 | " try:\n", 130 | " n_edits = list(map(int, results.keys()))\n", 131 | " values = [\n", 132 | " f[0] if (type(f := results[str(n)][key]) is tuple) else f\n", 133 | " for n in n_edits\n", 134 | " ]\n", 135 | " plt.plot(n_edits, values, marker=\"o\", markersize=4, label=method)\n", 136 | " plt.xlabel(\"# Edits\")\n", 137 | " # plt.ylabel(\"metric value\")\n", 138 | " plt.title(TITLES[key])\n", 139 | " plt.legend()\n", 140 | " except:\n", 141 | " pass\n", 142 | "plt.tight_layout()\n", 143 | "plt.savefig(\"tmp.pdf\")\n", 144 | "plt.show()" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "ae8e7ea4", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [] 154 | } 155 | ], 156 | "metadata": { 157 | "accelerator": "GPU", 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.9.7" 174 | }, 175 | "vscode": { 176 | "interpreter": { 177 | "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4" 178 | } 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 5 183 | } 184 | -------------------------------------------------------------------------------- /memit/notebooks/vis/visualize_sweep_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d465f696", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "9bdfca4c", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# enable high-resolution figure\n", 22 | "%matplotlib inline\n", 23 | "%config InlineBackend.figure_format = 'retina'\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from experiments.summarize import main as summarize_main\n", 26 | "from pathlib import Path\n", 27 | "import math" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "451eb471", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "RESULTS_DIR = Path(\"results/sweeps\")\n", 38 | "DATA = {}\n", 39 | "KEYS = None\n", 40 | "for method_dir in RESULTS_DIR.iterdir():\n", 41 | " 
method_name = str(method_dir).split(\"/\")[-1]\n", 42 | " print(method_name)\n", 43 | " n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n", 44 | " for n_edit_folder in n_edit_folders:\n", 45 | " n_edits = int(str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0])\n", 46 | " setting_id = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[-1]\n", 47 | " try:\n", 48 | " res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n", 49 | "\n", 50 | " DATA[method_name] = DATA.get(method_name, {})\n", 51 | " DATA[method_name][n_edits] = DATA[method_name].get(n_edits, {})\n", 52 | " DATA[method_name][n_edits][setting_id] = res\n", 53 | "\n", 54 | " if KEYS is None:\n", 55 | " KEYS = list(res.keys())\n", 56 | " except:\n", 57 | " pass\n", 58 | "\n", 59 | "print({k: list(v.keys()) for k, v in DATA.items()})" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "49efeea0", 66 | "metadata": { 67 | "scrolled": false 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "for method, all_n_edits in sorted([(k, v) for k, v in DATA.items()]):\n", 72 | " for n_edits, results in sorted([(k, v) for k, v in all_n_edits.items()]):\n", 73 | " w = 4\n", 74 | " h = math.ceil(len(KEYS) / w)\n", 75 | " plt.figure(figsize=(w * 3.5, h * 2.5))\n", 76 | " if \"run_dir\" in KEYS:\n", 77 | " KEYS.pop(KEYS.index(\"run_dir\"))\n", 78 | " for i, key in enumerate(KEYS):\n", 79 | " plt.subplot(w, h, i + 1)\n", 80 | "\n", 81 | " try:\n", 82 | " setting_ids = list(map(int, results.keys()))\n", 83 | " values = [\n", 84 | " f[0] if (type(f := results[str(n)][key]) is tuple) else f\n", 85 | " for n in setting_ids\n", 86 | " ]\n", 87 | " plt.plot(setting_ids, values, marker=\"o\", markersize=4, label=method)\n", 88 | " plt.xlabel(\"setting_id\")\n", 89 | " plt.ylabel(\"metric value\")\n", 90 | " plt.title(f\"{n_edits} edits: {key}\")\n", 91 | " plt.legend()\n", 92 | " except:\n", 93 | " pass\n", 94 | "\n", 95 | " plt.tight_layout()\n", 96 | " plt.show()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "ae8e7ea4", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [] 106 | } 107 | ], 108 | "metadata": { 109 | "accelerator": "GPU", 110 | "kernelspec": { 111 | "display_name": "Python 3 (ipykernel)", 112 | "language": "python", 113 | "name": "python3" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.9.7" 126 | }, 127 | "vscode": { 128 | "interpreter": { 129 | "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4" 130 | } 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 5 135 | } 136 | -------------------------------------------------------------------------------- /memit/rome/README.md: -------------------------------------------------------------------------------- 1 | # ROME 2 | This package provides a self-contained implementation of Rank-One Model Editing (ROME). 3 | 4 | Recall that ROME's update consists of: $u$ selection, $v_*$ optimization, and $v$ insertion. 5 | * [`compute_u.py`](compute_u.py): Chooses a $u$ vector. 6 | * [`compute_v.py`](compute_v.py): Choose a $v_*$ via optimization, then computes $v$. 7 | * [`rome_main.py`](rome_main.py): Instruments main logic. 
8 | * [`rome_hparams.py`](rome_hparams.py): Interface for specifying hyperparameters. Inherits from the base [`hparams.py`](../util/hparams.py) module. 9 | 10 | For estimating second moment statistics of keys ($C = KK^T$), we provide the `layer_stats` module. See the [main README](../README.md) for usage instructions. 11 | * [`layer_stats.py`](layer_stats.py): Logic for retrieving and caching key statistics. 12 | * [`tok_dataset.py`](tok_dataset.py): Utilities for creating a dataset of tokens. -------------------------------------------------------------------------------- /memit/rome/__init__.py: -------------------------------------------------------------------------------- 1 | from .rome_main import ROMEHyperParams, apply_rome_to_model, execute_rome 2 | -------------------------------------------------------------------------------- /memit/rome/compute_u.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, List 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from rome import repr_tools 9 | from util.globals import * 10 | 11 | from .layer_stats import layer_stats 12 | from .rome_hparams import ROMEHyperParams 13 | 14 | # Cache variables 15 | inv_mom2_cache = {} 16 | 17 | 18 | def get_inv_cov( 19 | model: AutoModelForCausalLM, 20 | tok: AutoTokenizer, 21 | layer_name: str, 22 | mom2_dataset: str, 23 | mom2_n_samples: str, 24 | mom2_dtype: str, 25 | ) -> torch.Tensor: 26 | """ 27 | Retrieves covariance statistics, then computes the algebraic inverse. 28 | Caches result for future use. 29 | """ 30 | 31 | global inv_mom2_cache 32 | 33 | model_name = model.config._name_or_path.replace("/", "_") 34 | key = (model_name, layer_name) 35 | 36 | if key not in inv_mom2_cache: 37 | print( 38 | f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. " 39 | f"The result will be cached to avoid repetitive computation." 40 | ) 41 | stat = layer_stats( 42 | model, 43 | tok, 44 | layer_name, 45 | STATS_DIR, 46 | mom2_dataset, 47 | to_collect=["mom2"], 48 | sample_size=mom2_n_samples, 49 | precision=mom2_dtype, 50 | ) 51 | inv_mom2_cache[key] = torch.inverse( 52 | stat.mom2.moment().to("cuda") 53 | ).float() # Cast back to float32 54 | 55 | return inv_mom2_cache[key] 56 | 57 | 58 | def compute_u( 59 | model: AutoModelForCausalLM, 60 | tok: AutoTokenizer, 61 | request: Dict, 62 | hparams: ROMEHyperParams, 63 | layer: int, 64 | context_templates: List[str], 65 | ) -> torch.Tensor: 66 | """ 67 | Computes the right vector used in constructing the rank-1 update matrix. 68 | """ 69 | 70 | print("Computing left vector (u)...") 71 | 72 | # Compute projection token 73 | word_repr_args = dict( 74 | model=model, 75 | tok=tok, 76 | layer=layer, 77 | module_template=hparams.rewrite_module_tmp, 78 | track="in", 79 | ) 80 | if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0: 81 | word = request["subject"] 82 | print(f"Selected u projection object {word}") 83 | cur_repr = repr_tools.get_reprs_at_word_tokens( 84 | context_templates=[ 85 | templ.format(request["prompt"]) for templ in context_templates 86 | ], 87 | words=[word for _ in range(len(context_templates))], 88 | subtoken=hparams.fact_token[len("subject_") :], 89 | **word_repr_args, 90 | ).mean(0) 91 | elif hparams.fact_token == "last": 92 | # Heuristic to choose last word. Not a huge deal if there's a minor 93 | # edge case (e.g.
multi-token word) because the function below will 94 | # take the last token. 95 | cur_repr = repr_tools.get_reprs_at_idxs( 96 | contexts=[ 97 | templ.format(request["prompt"].format(request["subject"])) 98 | for templ in context_templates 99 | ], 100 | idxs=[[-1] for _ in range(len(context_templates))], 101 | **word_repr_args, 102 | ).mean(0) 103 | print("Selected u projection token with last token") 104 | else: 105 | raise ValueError(f"fact_token={hparams.fact_token} not recognized") 106 | 107 | # Apply inverse second moment adjustment 108 | u = cur_repr 109 | if hparams.mom2_adjustment: 110 | u = get_inv_cov( 111 | model, 112 | tok, 113 | hparams.rewrite_module_tmp.format(layer), 114 | hparams.mom2_dataset, 115 | hparams.mom2_n_samples, 116 | hparams.mom2_dtype, 117 | ) @ u.unsqueeze(1) 118 | u = u.squeeze() 119 | 120 | return u / u.norm() 121 | -------------------------------------------------------------------------------- /memit/rome/layer_stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from datasets import load_dataset 6 | from tqdm.auto import tqdm 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | 9 | from util.globals import * 10 | from util.nethook import Trace, set_requires_grad 11 | from util.runningstats import CombinedStat, Mean, NormMean, SecondMoment, tally 12 | 13 | from .tok_dataset import ( 14 | TokenizedDataset, 15 | dict_to_, 16 | flatten_masked_batch, 17 | length_collation, 18 | ) 19 | 20 | STAT_TYPES = { 21 | "mom2": SecondMoment, 22 | "mean": Mean, 23 | "norm_mean": NormMean, 24 | } 25 | 26 | 27 | def main(): 28 | """ 29 | Command-line utility to precompute cached stats. 30 | """ 31 | import argparse 32 | 33 | parser = argparse.ArgumentParser(description="ROME Statistics Collector") 34 | 35 | def aa(*args, **kwargs): 36 | parser.add_argument(*args, **kwargs) 37 | 38 | aa("--model_name", default="gpt2-xl", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"]) 39 | aa("--dataset", default="wikipedia", choices=["wikitext", "wikipedia"]) 40 | aa("--layers", default=[17], type=lambda x: list(map(int, x.split(",")))) 41 | aa("--to_collect", default=["mom2"], type=lambda x: x.split(",")) 42 | aa("--sample_size", default=100000, type=lambda x: None if x == "all" else int(x)) 43 | aa("--batch_tokens", default=None, type=lambda x: None if x == "any" else int(x)) 44 | aa("--precision", default="float32", choices=["float64", "float32", "float16"]) 45 | aa("--stats_dir", default=STATS_DIR) 46 | aa("--download", default=1, type=int, choices=[0, 1]) 47 | args = parser.parse_args() 48 | 49 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 50 | model = AutoModelForCausalLM.from_pretrained(args.model_name).eval().cuda() 51 | set_requires_grad(False, model) 52 | 53 | for layer_num in args.layers: 54 | print( 55 | f"Computing stats for layer {layer_num} of {args.model_name} " 56 | f'over {args.sample_size or "all"} samples of {args.dataset}. ' 57 | "Note, the statistics are collected over the inputs to the second MLP layer, " 58 | "or equivalently the outputs of the first MLP layer." 
59 | ) 60 | proj_layer_name = "c_proj" if "gpt2" in args.model_name else "fc_out" 61 | layer_name = f"transformer.h.{layer_num}.mlp.{proj_layer_name}" 62 | 63 | layer_stats( 64 | model, 65 | tokenizer, 66 | layer_name, 67 | args.stats_dir, 68 | args.dataset, 69 | args.to_collect, 70 | sample_size=args.sample_size, 71 | precision=args.precision, 72 | batch_tokens=args.batch_tokens, 73 | download=args.download, 74 | ) 75 | 76 | 77 | def layer_stats( 78 | model, 79 | tokenizer, 80 | layer_name, 81 | stats_dir, 82 | ds_name, 83 | to_collect, 84 | model_name=None, 85 | sample_size=None, 86 | precision=None, 87 | batch_tokens=None, 88 | download=True, 89 | progress=tqdm, 90 | force_recompute=False, 91 | ): 92 | """ 93 | Function to load or compute cached stats. 94 | """ 95 | 96 | def get_ds(): 97 | raw_ds = load_dataset( 98 | ds_name, 99 | dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name], 100 | ) 101 | if 'olmo' in model.name_or_path.lower(): 102 | maxlen = model.config.max_sequence_length 103 | else: 104 | maxlen = model.config.max_position_embeddings 105 | if batch_tokens is not None and batch_tokens < maxlen: 106 | maxlen = batch_tokens 107 | return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen) 108 | 109 | # Continue with computation of statistics 110 | batch_size = 100 # Examine this many dataset texts at once 111 | if 'olmo' in model.name_or_path.lower(): 112 | npos = model.config.max_sequence_length 113 | else: 114 | npos = model.config.max_position_embeddings 115 | 116 | if batch_tokens is None: 117 | batch_tokens = npos * 3 # Sort and divide into batches with this many tokens 118 | if precision is None: 119 | precision = "float64" 120 | dtype = getattr(torch, precision) 121 | size_suffix = "" if sample_size is None else f"_{sample_size}" 122 | if batch_tokens < npos: 123 | size_suffix = f"_t{batch_tokens}" + size_suffix 124 | if model_name is None: 125 | model_name = model.config._name_or_path.replace("/", "_") 126 | 127 | stats_dir = Path(stats_dir) 128 | file_extension = f"{model_name}/{ds_name}_stats/{layer_name}_{precision}_{'-'.join(sorted(to_collect))}{size_suffix}.npz" 129 | filename = stats_dir / file_extension 130 | 131 | if not filename.exists() and download: 132 | remote_url = f"{REMOTE_ROOT_URL}/data/stats/{file_extension}" 133 | try: 134 | print(f"Attempting to download {file_extension} from {remote_url}.") 135 | (stats_dir / "/".join(file_extension.split("/")[:-1])).mkdir( 136 | exist_ok=True, parents=True 137 | ) 138 | torch.hub.download_url_to_file(remote_url, filename) 139 | print("Successfully downloaded.") 140 | except Exception as e: 141 | print(f"Unable to download due to {e}.
Computing locally....") 142 | 143 | ds = get_ds() if not filename.exists() else None 144 | 145 | if progress is None: 146 | progress = lambda x: x 147 | 148 | stat = CombinedStat(**{k: STAT_TYPES[k]() for k in to_collect}) 149 | loader = tally( 150 | stat, 151 | ds, 152 | cache=(filename if not force_recompute else None), 153 | sample_size=sample_size, 154 | batch_size=batch_size, 155 | collate_fn=length_collation(batch_tokens), 156 | pin_memory=True, 157 | random_sample=1, 158 | num_workers=2, 159 | ) 160 | batch_count = -(-(sample_size or len(ds)) // batch_size) 161 | with torch.no_grad(): 162 | for batch_group in progress(loader, total=batch_count): 163 | for batch in batch_group: 164 | batch = dict_to_(batch, "cuda") 165 | with Trace( 166 | model, layer_name, retain_input=True, retain_output=False, stop=True 167 | ) as tr: 168 | model(**batch) 169 | feats = flatten_masked_batch(tr.input, batch["attention_mask"]) 170 | # feats = flatten_masked_batch(tr.output, batch["attention_mask"]) 171 | feats = feats.to(dtype=dtype) 172 | stat.add(feats) 173 | return stat 174 | 175 | 176 | if __name__ == "__main__": 177 | main() 178 | -------------------------------------------------------------------------------- /memit/rome/repr_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains utilities for extracting token representations and indices 3 | from string templates. Used in computing the left and right vectors for ROME. 4 | """ 5 | 6 | from copy import deepcopy 7 | from typing import List 8 | 9 | import torch 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | from util import nethook 13 | 14 | 15 | def get_reprs_at_word_tokens( 16 | model: AutoModelForCausalLM, 17 | tok: AutoTokenizer, 18 | context_templates: List[str], 19 | words: List[str], 20 | layer: int, 21 | module_template: str, 22 | subtoken: str, 23 | track: str = "in", 24 | ) -> torch.Tensor: 25 | """ 26 | Retrieves the last token representation of `word` in `context_template` 27 | when `word` is substituted into `context_template`. See `get_last_word_idx_in_template` 28 | for more details. 29 | """ 30 | 31 | idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken) 32 | return get_reprs_at_idxs( 33 | model, 34 | tok, 35 | [context_templates[i].format(words[i]) for i in range(len(words))], 36 | idxs, 37 | layer, 38 | module_template, 39 | track, 40 | ) 41 | 42 | 43 | def get_words_idxs_in_templates( 44 | tok: AutoTokenizer, context_templates: str, words: str, subtoken: str 45 | ) -> int: 46 | """ 47 | Given list of template strings, each with *one* format specifier 48 | (e.g. "{} plays basketball"), and words to be substituted into the 49 | template, computes the post-tokenization index of their last tokens. 
50 | """ 51 | 52 | assert all( 53 | tmp.count("{}") == 1 for tmp in context_templates 54 | ), "We currently do not support multiple fill-ins for context" 55 | 56 | # Compute prefixes and suffixes of the tokenized context 57 | fill_idxs = [tmp.index("{}") for tmp in context_templates] 58 | prefixes, suffixes = [ 59 | tmp[: fill_idxs[i]] for i, tmp in enumerate(context_templates) 60 | ], [tmp[fill_idxs[i] + 2 :] for i, tmp in enumerate(context_templates)] 61 | words = deepcopy(words) 62 | 63 | # Pre-process tokens 64 | for i, prefix in enumerate(prefixes): 65 | if len(prefix) > 0: 66 | assert prefix[-1] == " " 67 | prefix = prefix[:-1] 68 | 69 | prefixes[i] = prefix 70 | words[i] = f" {words[i].strip()}" 71 | 72 | # Tokenize to determine lengths 73 | assert len(prefixes) == len(words) == len(suffixes) 74 | n = len(prefixes) 75 | batch_tok = tok([*prefixes, *words, *suffixes]) 76 | prefixes_tok, words_tok, suffixes_tok = [ 77 | batch_tok[i : i + n] for i in range(0, n * 3, n) 78 | ] 79 | prefixes_len, words_len, suffixes_len = [ 80 | [len(el) for el in tok_list] 81 | for tok_list in [prefixes_tok, words_tok, suffixes_tok] 82 | ] 83 | 84 | # Compute indices of last tokens 85 | if subtoken == "last" or subtoken == "first_after_last": 86 | return [ 87 | [ 88 | prefixes_len[i] 89 | + words_len[i] 90 | - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0) 91 | ] 92 | # If suffix is empty, there is no "first token after the last". 93 | # So, just return the last token of the word. 94 | for i in range(n) 95 | ] 96 | elif subtoken == "first": 97 | return [[prefixes_len[i]] for i in range(n)] 98 | else: 99 | raise ValueError(f"Unknown subtoken type: {subtoken}") 100 | 101 | 102 | def get_reprs_at_idxs( 103 | model: AutoModelForCausalLM, 104 | tok: AutoTokenizer, 105 | contexts: List[str], 106 | idxs: List[List[int]], 107 | layer: int, 108 | module_template: str, 109 | track: str = "in", 110 | ) -> torch.Tensor: 111 | """ 112 | Runs input through model and returns averaged representations of the tokens 113 | at each index in `idxs`. 
114 | """ 115 | 116 | def _batch(n): 117 | for i in range(0, len(contexts), n): 118 | yield contexts[i : i + n], idxs[i : i + n] 119 | 120 | assert track in {"in", "out", "both"} 121 | both = track == "both" 122 | tin, tout = ( 123 | (track == "in" or both), 124 | (track == "out" or both), 125 | ) 126 | module_name = module_template.format(layer) 127 | to_return = {"in": [], "out": []} 128 | 129 | def _process(cur_repr, batch_idxs, key): 130 | nonlocal to_return 131 | cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr 132 | for i, idx_list in enumerate(batch_idxs): 133 | to_return[key].append(cur_repr[i][idx_list].mean(0)) 134 | 135 | for batch_contexts, batch_idxs in _batch(n=128): 136 | contexts_tok = tok(batch_contexts, padding=True, return_token_type_ids = False, return_tensors="pt").to( 137 | next(model.parameters()).device 138 | ) 139 | 140 | with torch.no_grad(): 141 | with nethook.Trace( 142 | module=model, 143 | layer=module_name, 144 | retain_input=tin, 145 | retain_output=tout, 146 | ) as tr: 147 | model(**contexts_tok) 148 | 149 | if tin: 150 | _process(tr.input, batch_idxs, "in") 151 | if tout: 152 | _process(tr.output, batch_idxs, "out") 153 | 154 | to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0} 155 | 156 | if len(to_return) == 1: 157 | return to_return["in"] if tin else to_return["out"] 158 | else: 159 | return to_return["in"], to_return["out"] 160 | -------------------------------------------------------------------------------- /memit/rome/rome_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from util.hparams import HyperParams 5 | 6 | 7 | @dataclass 8 | class ROMEHyperParams(HyperParams): 9 | # Method 10 | layers: List[int] 11 | fact_token: str 12 | v_num_grad_steps: int 13 | v_lr: float 14 | v_loss_layer: int 15 | v_weight_decay: float 16 | clamp_norm_factor: float 17 | kl_factor: float 18 | mom2_adjustment: bool 19 | context_template_length_params: List[List[int]] 20 | 21 | # Module templates 22 | rewrite_module_tmp: str 23 | layer_module_tmp: str 24 | mlp_module_tmp: str 25 | attn_module_tmp: str 26 | ln_f_module: str 27 | lm_head_module: str 28 | 29 | # Statistics 30 | mom2_dataset: str 31 | mom2_n_samples: int 32 | mom2_dtype: str 33 | -------------------------------------------------------------------------------- /memit/rome/tok_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TokenizedDataset(Dataset): 7 | """ 8 | Converts a dataset of text samples into a dataset of token sequences, 9 | as converted by a supplied tokenizer. The tokens come along with position 10 | ids and attention masks, they can be supplied direcly to the model. 
11 | """ 12 | 13 | def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"): 14 | self.text_dataset = text_dataset 15 | self.field = field 16 | self.tokenizer = tokenizer 17 | self.maxlen = maxlen 18 | if hasattr(text_dataset, "info"): 19 | self.info = text_dataset.info 20 | 21 | def __len__(self): 22 | return len(self.text_dataset) 23 | 24 | def __getitem__(self, i): 25 | text = self.text_dataset[i] 26 | if self.field is not None: 27 | text = text[self.field] 28 | token_list = self.tokenizer.encode( 29 | text, truncation=True, max_length=self.maxlen 30 | ) 31 | position_ids = list(range(len(token_list))) 32 | attention_mask = [1] * len(token_list) 33 | return dict( 34 | input_ids=torch.tensor(token_list), 35 | #position_ids=torch.tensor(position_ids), 36 | attention_mask=torch.tensor(attention_mask), 37 | ) 38 | 39 | 40 | def dict_to_(data, device): 41 | """ 42 | Moves a dictionary of tensors to the specified device. 43 | """ 44 | for k in data: 45 | data[k] = data[k].to(device) 46 | return data 47 | 48 | 49 | def length_collation(token_size): 50 | """ 51 | Sorts a batch of sequences and breaks it up into subbatches 52 | of same-sized sequences, padding as needed. Each batch 53 | has no more than token_size total tokens (or a single 54 | sequence, if the sequence happens to be larger). 55 | """ 56 | 57 | def collate_fn(items): 58 | items = sorted(items, key=lambda x: -len(x["input_ids"])) 59 | batches = [] 60 | batch = [] 61 | batch_width = 0 62 | for item in items: 63 | item_width = len(item["input_ids"]) 64 | if item_width == 0: 65 | break 66 | if batch_width * (len(batch) + 1) > token_size: 67 | batches.append(make_padded_batch(batch)) 68 | batch = [] 69 | batch_width = 0 70 | if not batch: 71 | batch_width = item_width 72 | batch.append(item) 73 | if len(batch): 74 | batches.append(make_padded_batch(batch)) 75 | return batches 76 | 77 | return collate_fn 78 | 79 | 80 | def make_padded_batch(items): 81 | """ 82 | Pads sequences in a batch, so they are all the same length as the longest. 83 | """ 84 | max_len = max(len(d["input_ids"]) for d in items) 85 | if max_len == 0: 86 | return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]} 87 | return { 88 | k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True) 89 | for k, v in items[0].items() 90 | } 91 | 92 | 93 | def flatten_masked_batch(data, mask): 94 | """ 95 | Flattens feature data, ignoring items that are masked out of attention. 96 | """ 97 | flat_data = data.view(-1, data.size(-1)) 98 | attended_tokens = mask.view(-1).nonzero()[:, 0] 99 | return flat_data[attended_tokens] 100 | -------------------------------------------------------------------------------- /memit/scaling_curves.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Constants 5 | DIR="scaling" 6 | MIN_NUM_RECORDS="10000" 7 | GEN_TEST_INTERV="10" 8 | N_EDITS="1,56,100,316,562,1000,1778,3162,5623,10000" 9 | 10 | # Run configurations 11 | MODEL_NAME="EleutherAI/gpt-j-6B" 12 | ALG_NAMES=("FT" "MEND" "ROME" "MEMIT") 13 | HPARAMS_FNAMES=("EleutherAI_gpt-j-6B_wd.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json") 14 | 15 | # Execute 16 | for i in ${!ALG_NAMES[@]} 17 | do 18 | alg_name=${ALG_NAMES[$i]} 19 | hparams_fname=${HPARAMS_FNAMES[$i]} 20 | 21 | echo "Running evals for $alg_name..." 22 | sweep_dir="$DIR/$alg_name" 23 | 24 | if [ -d "results/$sweep_dir" ]; then 25 | echo "Note: results/$sweep_dir already exists! 
Continuing from previous run..." 26 | fi 27 | 28 | echo "Dumping results at results/$sweep_dir" 29 | mkdir -p results/$sweep_dir 30 | echo "{}" > results/$sweep_dir/config.json 31 | 32 | python3 -m experiments.sweep --alg_name=$alg_name --model_name=$MODEL_NAME --hparams_fname=$hparams_fname --sweep_dir=$sweep_dir --min_num_records=$MIN_NUM_RECORDS --num_edits=$N_EDITS --generation_test_interval=$GEN_TEST_INTERV --use_cache 33 | done 34 | 35 | exit 0 36 | -------------------------------------------------------------------------------- /memit/scripts/causal_trace.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from parent directory of script 4 | SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) 5 | cd "$(dirname ${SCRIPT_DIR})" 6 | 7 | python -m experiments.causal_trace --model_name "EleutherAI/gpt-j-6B" --noise_level 0.025 8 | python -m experiments.causal_trace --model_name "gpt2-xl" --noise_level 0.1 9 | python -m experiments.causal_trace --model_name "EleutherAI/gpt-neox-20b" --noise_level 0.03 10 | -------------------------------------------------------------------------------- /memit/scripts/colab_reqs/additional.txt: -------------------------------------------------------------------------------- 1 | allennlp==2.9.0 2 | einops==0.4.0 3 | higher==0.2.1 4 | hydra-core==1.1.1 -------------------------------------------------------------------------------- /memit/scripts/colab_reqs/rome.txt: -------------------------------------------------------------------------------- 1 | datasets==1.18.3 2 | python-dotenv==0.19.2 3 | git+https://github.com/kmeng01/transformers-colab@allennlp-compat 4 | -------------------------------------------------------------------------------- /memit/scripts/collect_layer_stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from parent directory of script 4 | SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) 5 | cd "$(dirname ${SCRIPT_DIR})" 6 | 7 | run_gpu_0() { 8 | CUDA_VISIBLE_DEVICES=0 python -m rome.layer_stats --layers=$(seq -s, 0 1 27) --sample_size 100000 --model_name=EleutherAI/gpt-j-6B 9 | } 10 | 11 | run_gpu_1() { 12 | CUDA_VISIBLE_DEVICES=1 python -m rome.layer_stats --layers=$(seq -s, 0 1 27) --sample_size 100000 --model_name=EleutherAI/gpt-j-6B 13 | } 14 | 15 | # run_gpu_0 &>stats0.out& 16 | run_gpu_1 &>stats1.out& 17 | -------------------------------------------------------------------------------- /memit/scripts/ipynb_drop_output.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Suppress output and prompt numbers in git version control. 5 | 6 | This script will tell git to ignore prompt numbers and cell output 7 | when looking at ipynb files UNLESS their metadata contains: 8 | 9 | "git": { 10 | "keep_outputs": true 11 | }, 12 | 13 | The notebooks themselves are not changed. 14 | 15 | See also this blogpost: http://pascalbugnion.net/blog/ipython-notebooks-and-git.html. 16 | 17 | Usage instructions 18 | ================== 19 | 20 | 1. Put this script in a directory that is on the system's path. 21 | For future reference, I will assume you saved it in 22 | `~/scripts/ipynb_drop_output`. 23 | 2. Make sure it is executable by typing the command 24 | `chmod +x ~/scripts/ipynb_drop_output`. 25 | 3. 
Register a filter for ipython notebooks by 26 | putting the following line in `~/.config/git/attributes`: 27 | `*.ipynb filter=clean_ipynb` 28 | 4. Connect this script to the filter by running the following 29 | git commands: 30 | 31 | git config --global filter.clean_ipynb.clean ipynb_drop_output 32 | git config --global filter.clean_ipynb.smudge cat 33 | 34 | To tell git NOT to ignore the output and prompts for a notebook, 35 | open the notebook's metadata (Edit > Edit Notebook Metadata). A 36 | panel should open containing the lines: 37 | 38 | { 39 | "name" : "", 40 | "signature" : "some very long hash" 41 | } 42 | 43 | Add an extra line so that the metadata now looks like: 44 | 45 | { 46 | "name" : "", 47 | "signature" : "don't change the hash, but add a comma at the end of the line", 48 | "git" : { "keep_outputs" : true } 49 | } 50 | 51 | You may need to "touch" the notebooks for git to actually register a change, if 52 | your notebooks are already under version control. 53 | 54 | Notes 55 | ===== 56 | 57 | 58 | This script is inspired by http://stackoverflow.com/a/20844506/827862, but 59 | lets the user specify whether the output of a notebook should be kept 60 | in the notebook's metadata, and works for IPython v3.0. 61 | """ 62 | 63 | import json 64 | import sys 65 | 66 | nb = sys.stdin.read() 67 | 68 | json_in = json.loads(nb) 69 | nb_metadata = json_in["metadata"] 70 | keep_output = False 71 | if "git" in nb_metadata: 72 | if "keep_outputs" in nb_metadata["git"] and nb_metadata["git"]["keep_outputs"]: 73 | keep_output = True 74 | if keep_output: 75 | sys.stdout.write(nb) 76 | exit() 77 | 78 | 79 | ipy_version = int(json_in["nbformat"]) - 1 # nbformat is 1 more than actual version. 80 | 81 | 82 | def strip_output_from_cell(cell): 83 | if "outputs" in cell: 84 | cell["outputs"] = [] 85 | if "prompt_number" in cell: 86 | del cell["prompt_number"] 87 | if "execution_count" in cell: 88 | cell["execution_count"] = None 89 | 90 | 91 | if ipy_version == 2: 92 | for sheet in json_in["worksheets"]: 93 | for cell in sheet["cells"]: 94 | strip_output_from_cell(cell) 95 | else: 96 | for cell in json_in["cells"]: 97 | strip_output_from_cell(cell) 98 | 99 | json.dump( 100 | json_in, 101 | sys.stdout, 102 | sort_keys=True, 103 | indent=1, 104 | separators=(",", ": "), 105 | ensure_ascii=False, 106 | ) 107 | # https://stackoverflow.com/questions/729692/why-should-text-files-end-with-a-newline 108 | sys.stdout.write("\n") 109 | -------------------------------------------------------------------------------- /memit/scripts/memit.yml: -------------------------------------------------------------------------------- 1 | name: memit 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.9.7 7 | - pip=21.2.4 8 | - cudatoolkit=11.3 9 | - pytorch==1.12.1 10 | - pip: 11 | - einops==0.4.0 12 | - higher==0.2.1 13 | - hydra-core==1.2.0 14 | - transformers==4.23.1 15 | - datasets==1.18.3 16 | - matplotlib==3.6.1 17 | - spacy==3.4.1 18 | - scipy==1.9.2 19 | - scikit-learn==1.0.2 20 | - nltk==3.7 21 | - jupyter==1.0.0 -------------------------------------------------------------------------------- /memit/scripts/setup_clean_ipynb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from directory of script 4 | cd "$(dirname "${BASH_SOURCE[0]}")" 5 | 6 | # Set up git config filters so huge output of notebooks is not committed.
7 | git config filter.clean_ipynb.clean "$(pwd)/ipynb_drop_output.py" 8 | git config filter.clean_ipynb.smudge cat 9 | git config filter.clean_ipynb.required true 10 | -------------------------------------------------------------------------------- /memit/scripts/setup_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from directory of script 4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 5 | cd $SCRIPT_DIR 6 | 7 | # Detect operating system 8 | unameOut="$(uname -s)" 9 | case "${unameOut}" in 10 | Linux*) machine=Linux;; 11 | Darwin*) machine=Mac;; 12 | CYGWIN*) machine=Cygwin;; 13 | MINGW*) machine=MinGw;; 14 | *) machine="UNKNOWN:${unameOut}" 15 | esac 16 | 17 | if [ $machine != "Linux" ] && [ $machine != "Mac" ] 18 | then 19 | echo "Conda setup script is only available on Linux and Mac." 20 | exit 1 21 | else 22 | echo "Running on $machine..." 23 | fi 24 | 25 | if [[ -z "${CONDA_HOME}" ]]; then 26 | echo "Please specify the CONDA_HOME environment variable (it might look something like ~/miniconda3)." 27 | exit 1 28 | else 29 | echo "Found CONDA_HOME=${CONDA_HOME}." 30 | fi 31 | 32 | RECIPE=${RECIPE:-memit} 33 | ENV_NAME="${ENV_NAME:-${RECIPE}}" 34 | echo "Creating conda environment ${ENV_NAME}..." 35 | 36 | if [[ ! $(type -P conda) ]] 37 | then 38 | echo "conda not in PATH" 39 | echo "read: https://conda.io/docs/user-guide/install/index.html" 40 | exit 1 41 | fi 42 | 43 | if df "${HOME}/.conda" --type=afs > /dev/null 2>&1 44 | then 45 | echo "Not installing: your ~/.conda directory is on AFS." 46 | echo "Use 'ln -s /some/nfs/dir ~/.conda' to avoid using up your AFS quota." 47 | exit 1 48 | fi 49 | 50 | # Build new environment 51 | conda env create --name=${ENV_NAME} -f ${RECIPE}.yml 52 | -------------------------------------------------------------------------------- /memit/transformer_utils/README.md: -------------------------------------------------------------------------------- 1 | ## transformer-utils 2 | 3 | Utilities for the HuggingFace [transformers](https://github.com/huggingface/transformers/) library, focused on loading and using large pretrained autoregressive language models like GPT-2 and GPT-Neo. 4 | 5 | This package is unofficial and not associated with HuggingFace. 6 | 7 | Features: 8 | 9 | - Load large (~2.7B) models in low-resource environments like Google Colab 10 | - Get activations from any part of the model, without running parts you don't need 11 | - Interpret models with the "logit lens" 12 | - For background, see 13 | - ["interpreting GPT: the logit lens"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) by nostalgebraist 14 | - ["Finding the Words to Say: Hidden State Visualizations for Language Models"](https://jalammar.github.io/hidden-states/) by Jay Alammar 15 | 16 | ## Example usage 17 | 18 | ### Load in a low-memory environment 19 | 20 | Loading a 2.7B model: 21 | 22 | ```python 23 | from transformer_utils.low_memory import enable_low_memory_load 24 | 25 | enable_low_memory_load() 26 | 27 | model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B') 28 | ``` 29 | 30 | This works fine in an ordinary (non-Pro) Google Colab notebook, with ~12 GB RAM and a standard free-tier GPU. 31 | 32 | Inference will work up to the full context window length of 2048 tokens without memory issues.
33 | 34 | ### Logit lens 35 | 36 | ```python 37 | import torch 38 | import transformers 39 | from transformer_utils.low_memory import enable_low_memory_load 40 | 41 | enable_low_memory_load() 42 | 43 | tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2") 44 | model = transformers.AutoModelForCausalLM.from_pretrained('gpt2-xl') 45 | 46 | def text_to_input_ids(text): 47 | toks = tokenizer.encode(text) 48 | return torch.as_tensor(toks).view(1, -1).cuda() 49 | 50 | input_ids = text_to_input_ids("This is an example. You can probably think of a more fun text to use than this one.") 51 | 52 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45) # logits 53 | 54 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, probs=True) # probabilities 55 | 56 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, kl=True) # K-L divergence 57 | ``` 58 | 59 | You can also do some other things that aren't in the original blog posts. This will break down the transformer blocks into their attention and MLP parts: 60 | 61 | ```python 62 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, include_subblocks=True) 63 | ``` 64 | 65 | You can also change the definition of the "decoder" to include some of the later blocks/subblocks of the model. This helps especially in interpreting GPT-Neo hidden states. 66 | 67 | ```python 68 | # assume we have a 48-layer model 69 | # so 'h.47' is the final layer 70 | 71 | # include last layer in decoder 72 | plot_logit_lens( 73 | model, tokenizer, input_ids, start_ix=0, end_ix=45, 74 | decoder_layer_names=['h.47', 'final_layernorm', 'lm_head'] 75 | ) 76 | 77 | # include just the last MLP subblock in decoder 78 | plot_logit_lens( 79 | model, tokenizer, input_ids, start_ix=0, end_ix=45, 80 | decoder_layer_names=['h.47.mlp', 'final_layernorm', 'lm_head'] 81 | ) 82 | ``` 83 | 84 | ### Get activations from any part of the model 85 | 86 | ###### ...and without running parts you don't need 87 | 88 | ```python 89 | from transformer_utils.partial_forward import partial_forward 90 | 91 | output = partial_forward( 92 | model, # your `transformers` model 93 | [ 94 | 'h.0', # output of the 1st layer 95 | 'h.2.attn.c_attn', # query/key/value matrix from the 3rd layer 96 | 'h.5.mlp.c_proj', # feed-forward activations from the 6th layer 97 | ], 98 | input_ids # the input to run 99 | ) 100 | 101 | # each of these is a tensor 102 | output['h.0'] 103 | output['h.2.attn.c_attn'] 104 | output['h.5.mlp.c_proj'] 105 | ``` 106 | 107 | For efficiency, `partial_forward` doesn't run any part of the model later than the ones you specify in `output_names`. 108 | 109 | For example, suppose `model` above was GPT-2 XL. Then it has 48 layers. But the forward pass in the code above stops running after the 6th layer of 48 -- so the compute and memory cost is far lower than a full `model.forward`. 110 | 111 | This makes it easy to write new "heads" that do further computation on the model's activations.
112 | 113 | Some examples: 114 | 115 | ##### Using the first two layers of a model as features extractors for binary classification 116 | 117 | ```python 118 | output_names=['h.0', 'h.1',] 119 | classifier_hidden_size=768 120 | 121 | feature_vector_size = base_model.config.n_embd * len(output_names) 122 | 123 | classifier = nn.Sequential( 124 | nn.Linear(feature_vector_size, classifier_hidden_size), 125 | nn.ReLU(), 126 | nn.Linear(classifier_hidden_size, 2), 127 | ) 128 | 129 | opt = torch.optim.Adam(classifier.parameters()) 130 | 131 | for input_ids, targets in dataset: # `dataset` is your classification train data 132 | with torch.no_grad(): 133 | hidden_states = partial_forward( 134 | base_model, 135 | output_names, 136 | input_ids, 137 | ) 138 | 139 | # shape (batch, sequence, len(output_names) * model's hidden size) 140 | feature_vector = torch.cat( 141 | [hidden_states[name] for name in output_names], 142 | dim=-1 143 | ) 144 | 145 | # shape (batch, sequence, 2) 146 | classifier_out = classifier(feature_vector) 147 | 148 | # simple avg pool over sequence dim -- in practice find attention works well for this step :) 149 | # shape (batch, 2) 150 | logits = classifier_out.mean(dim=1) 151 | 152 | loss = F.cross_entropy(target=targets, input=logits) 153 | loss.backward() 154 | opt.step() 155 | opt.zero_grad() 156 | ``` 157 | 158 | 159 | ##### Finetuning the first two layers of a model 160 | 161 | This is exactly the same as the above, with just two changes: 162 | 163 | - Remove the `with torch.no_grad()` wrapper around `partial_forward` 164 | - Optimize the base model's params too: 165 | - `opt = torch.optim.Adam(list(classifier.parameters()) + list(base_model.parameters()))` 166 | 167 | If you want to train a model like these ones for real use, I recommend writing a custom `nn.Module`. [See here](https://github.com/nostalgebraist/nostalgebraist-autoresponder/blob/fd96e9482186f5dbeaa27bd6179087c892c577d6/selector_model/selector_nn_neo.py#L263) for an example. 
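For completeness, a minimal sketch of that finetuning variant, i.e. the classification loop above with those two changes applied (it reuses the hypothetical `base_model`, `classifier`, `dataset` and `output_names` from the previous example, and assumes `torch`, `F` (`torch.nn.functional`) and `partial_forward` are already imported):

```python
# Same assumptions as the classification example above.
opt = torch.optim.Adam(list(classifier.parameters()) + list(base_model.parameters()))

for input_ids, targets in dataset:
    # no `torch.no_grad()` here, so gradients flow into the first two layers of base_model
    hidden_states = partial_forward(base_model, output_names, input_ids)

    feature_vector = torch.cat([hidden_states[name] for name in output_names], dim=-1)
    logits = classifier(feature_vector).mean(dim=1)

    loss = F.cross_entropy(input=logits, target=targets)
    loss.backward()
    opt.step()
    opt.zero_grad()
```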
168 | -------------------------------------------------------------------------------- /memit/transformer_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/memit/transformer_utils/__init__.py -------------------------------------------------------------------------------- /memit/transformer_utils/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | seaborn 4 | tqdm 5 | colorcet 6 | -------------------------------------------------------------------------------- /memit/transformer_utils/setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from setuptools import setup, find_packages 3 | 4 | # The directory containing this file 5 | HERE = pathlib.Path(__file__).parent 6 | 7 | # The text of the README file 8 | README = (HERE / "README.md").read_text() 9 | 10 | setup( 11 | name='transformer-utils', 12 | description="Large autoregressive language modeling helpers", 13 | long_description=README, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/nostalgebraist/transformer-utils", 16 | author="nostalgebraist", 17 | author_email="nostalgebraist@gmail.com", 18 | license="MIT", 19 | version='0.1.0', 20 | packages=find_packages('src'), 21 | package_dir={'': 'src'}, 22 | install_requires=[ 23 | 'torch', 24 | 'transformers', 25 | 'seaborn', 26 | 'tqdm', 27 | 'colorcet' 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .low_memory import low_memory_load 2 | from . import util 3 | from . import partial_forward 4 | 5 | __all__ = [ 6 | 'low_memory_load', 7 | 'util', 8 | 'partial_forward', 9 | ] 10 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/logit_lens/__init__.py: -------------------------------------------------------------------------------- 1 | from .hooks import make_lens_hooks, clear_lens_hooks 2 | from . import plotting 3 | from .plotting import plot_logit_lens 4 | 5 | __all__ = [ 6 | 'make_lens_hooks', 7 | 'clear_lens_hooks', 8 | 'plotting', 9 | 'plot_logit_lens' 10 | ] 11 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/logit_lens/hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..util.python_utils import make_print_if_verbose 5 | from ..util.module_utils import get_child_module_by_names 6 | 7 | _RESID_SUFFIXES = {".attn", ".mlp"} 8 | 9 | 10 | def blocks_input_locator(model: nn.Module): 11 | """ 12 | HF usually (always?) places a dropout after the input embeddings. 
13 | TODO: avoid depending on this 14 | """ 15 | dropouts_on_base_model = [ 16 | mod for mod in model.base_model.children() 17 | if isinstance(mod, nn.Dropout) 18 | ] 19 | if len(dropouts_on_base_model) > 0: 20 | return lambda: dropouts_on_base_model[0] 21 | raise ValueError('could not identify blocks input') 22 | 23 | 24 | def final_layernorm_locator(model: nn.Module): 25 | layernorms_on_base_model = [ 26 | mod for mod in model.base_model.children() 27 | if isinstance(mod, nn.LayerNorm) 28 | ] 29 | if len(layernorms_on_base_model) > 0: 30 | return lambda: layernorms_on_base_model[0] 31 | raise ValueError('could not identify ln_f') 32 | 33 | 34 | def _locate_special_modules(model): 35 | if not hasattr(model, "_blocks_input_getter"): 36 | model._blocks_input_getter = blocks_input_locator(model) 37 | 38 | if not hasattr(model, "_ln_f_getter"): 39 | model._ln_f_getter = final_layernorm_locator(model) 40 | 41 | 42 | def _get_layer(model, name): 43 | if name == "input": 44 | return model._blocks_input_getter() 45 | if name == "final_layernorm": 46 | return model._ln_f_getter() 47 | 48 | model_with_module = model if name == "lm_head" else model.base_model 49 | return get_child_module_by_names(model_with_module, name.split(".")) 50 | 51 | 52 | def _sqz(x): 53 | if isinstance(x, torch.Tensor): 54 | return x 55 | try: 56 | return x[0] 57 | except: 58 | return x 59 | 60 | 61 | def _get_layer_and_compose_with_ln(model, name): 62 | if name.endswith('.attn'): 63 | lname = name[:-len('.attn')] + '.ln_1' 64 | ln = _get_layer(model, lname) 65 | elif name.endswith('.mlp'): 66 | lname = name[:-len('.mlp')] + '.ln_2' 67 | ln = _get_layer(model, lname) 68 | else: 69 | ln = lambda x: x 70 | return lambda x: _get_layer(model, name)(ln(x)) 71 | 72 | 73 | def make_decoder(model, decoder_layer_names=['final_layernorm', 'lm_head']): 74 | _locate_special_modules(model) 75 | 76 | decoder_layers = [_get_layer_and_compose_with_ln(model, name) for name in decoder_layer_names] 77 | 78 | def _decoder(x): 79 | for name, layer in zip(decoder_layer_names, decoder_layers): 80 | layer_out = _sqz(layer(_sqz(x))) 81 | 82 | # TODO: DRY 83 | is_resid = any([name.endswith(s) for s in _RESID_SUFFIXES]) 84 | if is_resid: 85 | x = x + layer_out 86 | else: 87 | x = layer_out 88 | return x 89 | return _decoder 90 | 91 | 92 | def make_lens_hooks( 93 | model, 94 | layer_names: list, 95 | decoder_layer_names: list = ['final_layernorm', 'lm_head'], 96 | verbose=True, 97 | start_ix=None, 98 | end_ix=None, 99 | ): 100 | vprint = make_print_if_verbose(verbose) 101 | 102 | clear_lens_hooks(model) 103 | 104 | def _opt_slice(x, start_ix, end_ix): 105 | if not start_ix: 106 | start_ix = 0 107 | if not end_ix: 108 | end_ix = x.shape[1] 109 | return x[:, start_ix:end_ix, :] 110 | 111 | _locate_special_modules(model) 112 | 113 | for attr in ["_layer_logits", "_layer_logits_handles"]: 114 | if not hasattr(model, attr): 115 | setattr(model, attr, {}) 116 | 117 | # TODO: better naming 118 | model._ordered_layer_names = layer_names 119 | 120 | model._lens_decoder = make_decoder(model, decoder_layer_names) 121 | 122 | def _make_record_logits_hook(name): 123 | model._layer_logits[name] = None 124 | 125 | is_resid = any([name.endswith(s) for s in _RESID_SUFFIXES]) 126 | 127 | def _record_logits_hook(module, input, output) -> None: 128 | del model._layer_logits[name] 129 | ln_f = model._ln_f_getter() 130 | 131 | if is_resid: 132 | decoder_in = model._last_resid + _sqz(output) 133 | else: 134 | decoder_in = _sqz(output) 135 | # print(decoder_in.shape) 136 | 
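# _lens_decoder applies the configured decoder layers (by default final_layernorm followed by lm_head) to this intermediate hidden state, yielding the per-layer vocabulary logits that the lens records.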
decoder_out = model._lens_decoder(decoder_in) 137 | # print(decoder_out.shape) 138 | decoder_out = _opt_slice(decoder_out, start_ix, end_ix) 139 | 140 | model._layer_logits[name] = decoder_out.detach().cpu().numpy() 141 | model._last_resid = decoder_in 142 | 143 | return _record_logits_hook 144 | 145 | def _hook_already_there(name): 146 | handle = model._layer_logits_handles.get(name) 147 | if not handle: 148 | return False 149 | layer = _get_layer(model, name) 150 | return handle.id in layer._forward_hooks 151 | 152 | for name in layer_names: 153 | if _hook_already_there(name): 154 | vprint(f"skipping layer {name}, hook already exists") 155 | continue 156 | layer = _get_layer(model, name) 157 | handle = layer.register_forward_hook(_make_record_logits_hook(name)) 158 | model._layer_logits_handles[name] = handle 159 | 160 | 161 | def clear_lens_hooks(model): 162 | if hasattr(model, "_layer_logits_handles"): 163 | for k, v in model._layer_logits_handles.items(): 164 | v.remove() 165 | 166 | ks = list(model._layer_logits_handles.keys()) 167 | for k in ks: 168 | del model._layer_logits_handles[k] 169 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/logit_lens/layer_names.py: -------------------------------------------------------------------------------- 1 | from ..util.module_utils import get_child_module_by_names 2 | 3 | 4 | def make_layer_names( 5 | model, 6 | block_step=1, 7 | include_input=True, 8 | force_include_output=True, 9 | include_subblocks=False, 10 | decoder_layer_names: list = ['final_layernorm', 'lm_head'] 11 | ): 12 | h = get_child_module_by_names(model.base_model, ["h"]) 13 | h_names = [f"h.{i}" for i in range(len(h))] 14 | 15 | last_h_name = h_names[-1] 16 | 17 | h_names = h_names[::block_step] 18 | if force_include_output and last_h_name not in h_names: 19 | h_names.append(last_h_name) 20 | 21 | if include_subblocks: 22 | names = [sub_name for name in h_names for sub_name in (f"{name}.attn", name)] 23 | else: 24 | names = h_names 25 | 26 | if include_input: 27 | names = ["input"] + names 28 | 29 | def _subset(a, b): 30 | return a == b or a.startswith(b + ".") 31 | 32 | def _names_overlap(a, b): 33 | return _subset(a, b) or _subset(b, a) 34 | 35 | names = [name for name in names 36 | if not any([_names_overlap(name, dname) for dname in decoder_layer_names]) 37 | ] 38 | 39 | return names 40 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/low_memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_context import LazyLinearAPICompatible, LazyTransformersConv1D, LowMemoryLoadContext 2 | from .load import low_memory_load 3 | from .enable import enable_low_memory_load, disable_low_memory_load 4 | 5 | __all__ = [ 6 | 'LazyLinearAPICompatible', 7 | 'LazyTransformersConv1D', 8 | 'LowMemoryLoadContext', 9 | 'low_memory_load', 10 | 'enable_low_memory_load', 11 | 'disable_low_memory_load' 12 | ] 13 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/low_memory/enable.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | from .load import low_memory_load 4 | from ..util.tfm_utils import huggingface_model_local_paths 5 | 6 | _TFM_PRETRAINED_MODEL_FROM_PRETRAINED_ORIGINAL = transformers.modeling_utils.PreTrainedModel.from_pretrained 
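# Keep a handle on the unpatched from_pretrained (above) so that disable_low_memory_load() can restore it later.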
7 | 8 | 9 | def low_memory_from_pretrained(pretrained_model_name_or_path, *args, **kwargs): 10 | config_path, model_path = huggingface_model_local_paths(pretrained_model_name_or_path) 11 | 12 | model = low_memory_load(config_path=config_path, model_path=model_path, verbose=False) 13 | 14 | return model 15 | 16 | 17 | def enable_low_memory_load(): 18 | transformers.modeling_utils.PreTrainedModel.from_pretrained = low_memory_from_pretrained 19 | 20 | 21 | def disable_low_memory_load(): 22 | transformers.modeling_utils.PreTrainedModel.from_pretrained = _TFM_PRETRAINED_MODEL_FROM_PRETRAINED_ORIGINAL 23 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/low_memory/load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | from ..util.python_utils import make_print_if_verbose 5 | from ..util.module_utils import get_child_module_by_names 6 | from ..util.tfm_utils import normalize_inconsistent_state_dict_keys 7 | from .load_context import LowMemoryLoadContext 8 | 9 | 10 | DEFAULT_GENERIC_INPUT = torch.as_tensor([[0]]) 11 | 12 | 13 | def modify_weights_after_load(model): 14 | """the part of PreTrainedModel.init_weights that isn't initializing weights""" 15 | # Prune heads if needed 16 | if model.config.pruned_heads: 17 | model.prune_heads(model.config.pruned_heads) 18 | 19 | # Tie weights if needed 20 | model.tie_weights() 21 | 22 | 23 | def low_memory_load( 24 | config_path, 25 | model_path, 26 | config_cls=None, 27 | model_cls=None, 28 | high_memory_device="cuda:1", 29 | generic_input=DEFAULT_GENERIC_INPUT, 30 | verbose=True, 31 | ): 32 | vprint = make_print_if_verbose(verbose) 33 | 34 | if isinstance(high_memory_device, str): 35 | high_memory_device = torch.device(high_memory_device) 36 | 37 | if config_cls is None: 38 | config_cls = transformers.AutoConfig 39 | 40 | vprint("start") 41 | 42 | with LowMemoryLoadContext(): 43 | config = config_cls.from_pretrained(config_path) 44 | 45 | vprint("made config obj") 46 | 47 | state_dict = torch.load( 48 | model_path, 49 | map_location=high_memory_device, 50 | ) 51 | 52 | state_dict = normalize_inconsistent_state_dict_keys(state_dict) 53 | 54 | vprint("loaded state dict") 55 | 56 | # uses lazy init, no memory 57 | if model_cls is None: 58 | model = transformers.AutoModelForCausalLM.from_config(config) 59 | else: 60 | model = model_cls(config=config) 61 | 62 | vprint("made model obj") 63 | 64 | # START gpu --> cpu --> gpu handoff, one leaf module at a time 65 | handled = set() 66 | 67 | for name in dict(model.named_parameters()).keys(): 68 | prefix = name.rpartition(".")[0] 69 | mod = get_child_module_by_names(model, prefix.split(".")) 70 | 71 | if prefix in handled: 72 | continue 73 | 74 | vprint((name, prefix, mod)) 75 | 76 | mk, uk, er = [], [], [] 77 | mod._load_from_state_dict( 78 | state_dict, 79 | prefix=prefix + ".", 80 | local_metadata={}, 81 | strict=True, 82 | missing_keys=mk, 83 | unexpected_keys=uk, 84 | error_msgs=er, 85 | ) 86 | vprint((mk, uk, er)) 87 | mod.to(high_memory_device) 88 | sdks = [k for k in state_dict if k.startswith(prefix)] 89 | for k in sdks: 90 | del state_dict[k] 91 | handled.add(prefix) 92 | 93 | # END gpu --> cpu --> gpu handoff, one leaf module at a time 94 | 95 | vprint("loaded params into memory") 96 | 97 | # does the buffers 98 | model = model.to(high_memory_device) 99 | 100 | vprint("loaded all into memory") 101 | 102 | # does stuff like weight tying, now 
that the weights actually exist 103 | modify_weights_after_load(model) 104 | 105 | model.eval() 106 | 107 | # ensures we materialize the lazy params (and delete the hooks for doing so), before doing anything else 108 | # 109 | # if you add pre-hooks before doing this step, you get OrderedDict mutation exceptions 110 | with torch.no_grad(): 111 | out = model(generic_input.to(high_memory_device)) 112 | out = None 113 | 114 | return model 115 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/low_memory/load_context.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | import transformers 3 | from torch.nn.modules.lazy import LazyModuleMixin 4 | from torch.nn.parameter import UninitializedParameter 5 | 6 | _TORCH_NN_ORIGINAL_LINEAR = torch.nn.Linear 7 | 8 | _TFM_PRETRAINED_MODEL_INIT_WEIGHTS_ORIGINAL = ( 9 | transformers.modeling_utils.PreTrainedModel.init_weights 10 | ) 11 | _TFM_CONV1D_ORIGINAL = transformers.modeling_utils.Conv1D 12 | 13 | 14 | def init_weights_without_init(self): 15 | pass 16 | 17 | 18 | class LazyLinearAPICompatible(torch.nn.LazyLinear): 19 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 20 | super().__init__(out_features=out_features, bias=bias) 21 | 22 | 23 | class LazyTransformersConv1D(LazyModuleMixin, _TFM_CONV1D_ORIGINAL): 24 | cls_to_become = _TFM_CONV1D_ORIGINAL 25 | weight: UninitializedParameter 26 | 27 | def __init__(self, nf, nx): 28 | super().__init__(nf=nf, nx=0) 29 | self.nx = 0 30 | self.weight = UninitializedParameter() 31 | 32 | def reset_parameters(self) -> None: 33 | if not self.has_uninitialized_params() and self.nx != 0: 34 | super().reset_parameters() 35 | 36 | def initialize_parameters(self, input) -> None: 37 | if self.has_uninitialized_params(): 38 | with torch.no_grad(): 39 | self.nx = input.shape[-1] 40 | self.weight.materialize((self.nf, self.nx)) 41 | self.reset_parameters() 42 | 43 | 44 | class LowMemoryLoadContext: 45 | def __enter__(self): 46 | torch.nn.Linear = LazyLinearAPICompatible 47 | transformers.modeling_utils.Conv1D = LazyTransformersConv1D 48 | transformers.PreTrainedModel.init_weights = init_weights_without_init 49 | 50 | def __exit__(self, exc_type, exc_value, exc_traceback): 51 | torch.nn.Linear = _TORCH_NN_ORIGINAL_LINEAR 52 | transformers.modeling_utils.Conv1D = _TFM_CONV1D_ORIGINAL 53 | transformers.PreTrainedModel.init_weights = ( 54 | _TFM_PRETRAINED_MODEL_INIT_WEIGHTS_ORIGINAL 55 | ) 56 | return exc_type is None 57 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/partial_forward/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import inspect 3 | 4 | from ..util.python_utils import make_print_if_verbose 5 | 6 | 7 | PARTIAL_FORWARD_FORCE_FALSE_KWARGS = { 8 | "use_cache", 9 | "output_attentions", 10 | "output_hidden_states", 11 | "return_dict", 12 | } 13 | 14 | PARTIAL_FORWARD_FORCE_FALSE_KWARGS_MSG = """`partial_forward` was passed the argument {kwarg} but will ignore it. 
15 | 16 | `partial_forward` ignores arguments that configure output shape in `transformers`, since its output shape is configured entirely through the `output_names` argument.""" 17 | 18 | VALIDATE_OUTPUT_BASE_MODEL_MSG = """Some `output_names` were not found on the model (a `{model_class_name}`), but exist on its base model (a `{base_model_class_name}`). 19 | 20 | Try either passing `model.base_model` as the model, OR adding the string '{base_model_prefix}.' to the start of each output name. 21 | 22 | Names not found: {names}""" 23 | 24 | VALIDATE_OUTPUT_NOT_FOUND_MSG = """Some `output_names` were not found on the model. 25 | 26 | To see valid output names, try `dict(model.named_modules()).keys()`. 27 | 28 | Names not found: {names}""" 29 | 30 | 31 | class AfterStoppingPointException(Exception): 32 | pass 33 | 34 | 35 | def _validate_output_names(model, output_names): 36 | if output_names is None: 37 | return 38 | 39 | findable_names = dict(model.named_modules()).keys() 40 | 41 | findable_names_base_model = set() 42 | if hasattr(model, "base_model") and hasattr(model, "base_model_prefix"): 43 | findable_names_base_model = dict(model.base_model.named_modules()).keys() 44 | 45 | problem_names = [name for name in output_names if name not in findable_names] 46 | 47 | base_model_names = [ 48 | name for name in problem_names if name in findable_names_base_model 49 | ] 50 | 51 | if len(base_model_names) > 0: 52 | raise ValueError( 53 | VALIDATE_OUTPUT_BASE_MODEL_MSG.format( 54 | model_class_name=model.__class__.__name__, 55 | base_model_class_name=model.base_model.__class__.__name__, 56 | base_model_prefix=model.base_model_prefix, 57 | names=base_model_names, 58 | ) 59 | ) 60 | 61 | if len(problem_names) > 0: 62 | raise ValueError(VALIDATE_OUTPUT_NOT_FOUND_MSG.format(names=problem_names)) 63 | 64 | 65 | def add_partial_forward_hooks(model, verbose=False, debug=False, output_names=None): 66 | vprint = make_print_if_verbose(verbose) 67 | dprint = make_print_if_verbose(debug) 68 | 69 | _validate_output_names(model, output_names) 70 | 71 | can_skip = output_names is not None 72 | can_skip = can_skip and hasattr(model, "_partial_forward_force_false_kwargs") 73 | 74 | names_to_mods = {} 75 | indices_to_names = {} 76 | names, mods = [], [] 77 | for i, (name, mod) in enumerate(model.named_modules()): 78 | if hasattr(mod, "_partial_forward_name") and mod._partial_forward_name != name: 79 | can_skip = False 80 | 81 | mod._partial_forward_name = name 82 | indices_to_names[i] = name 83 | 84 | names.append(name) 85 | mods.append(mod) 86 | names_to_mods[name] = mod 87 | 88 | if output_names is not None: 89 | should_have_hook = name in output_names 90 | already_has_hook = hasattr(mod, "_record_to_sink_handle") 91 | can_skip = can_skip and (should_have_hook == already_has_hook) 92 | 93 | if can_skip: 94 | dprint("already have partial forward hooks, skipping") 95 | return 96 | 97 | sig = inspect.signature(model.__class__.forward) 98 | model._partial_forward_force_false_kwargs = ( 99 | PARTIAL_FORWARD_FORCE_FALSE_KWARGS.intersection(sig.parameters.keys()) 100 | ) 101 | 102 | def _record_to_sink_hook(module, input, output) -> None: 103 | if hasattr(model, "_output_sink_names"): 104 | this_name = module._partial_forward_name 105 | dprint(f"reached output of {repr(this_name)}") 106 | dprint(f"model._output_sink_names: {model._output_sink_names}") 107 | 108 | if this_name in model._output_sink_names: 109 | dprint(f"{repr(this_name)} in sink") 110 | 111 | to_record = output 112 | if isinstance(to_record, tuple) and 
len(to_record) == 1: 113 | to_record = to_record[0] 114 | 115 | model._output_sink[this_name] = to_record 116 | 117 | if all([name in model._output_sink for name in model._output_sink_names]): 118 | dprint("have all model._output_sink_names, stopping") 119 | 120 | raise AfterStoppingPointException 121 | 122 | for name, mod in zip(names, mods): 123 | if hasattr(mod, "_record_to_sink_handle"): 124 | vprint(f"clearing existing handle at {repr(name)}") 125 | mod._record_to_sink_handle.remove() 126 | 127 | if output_names is None or name in output_names: 128 | rts_handle = mod.register_forward_hook(_record_to_sink_hook) 129 | mod._record_to_sink_handle = rts_handle 130 | 131 | 132 | def partial_forward( 133 | model, 134 | output_names, 135 | *args, 136 | verbose=False, 137 | debug=False, 138 | **kwargs, 139 | ): 140 | vprint = make_print_if_verbose(verbose) 141 | 142 | add_partial_forward_hooks( 143 | model, verbose=verbose, debug=debug, output_names=output_names 144 | ) 145 | 146 | for k in model._partial_forward_force_false_kwargs: 147 | if kwargs.get(k): 148 | warnings.warn(PARTIAL_FORWARD_FORCE_FALSE_KWARGS_MSG.format(kwarg=repr(k))) 149 | kwargs[k] = False 150 | 151 | model._output_sink_names = output_names 152 | 153 | if hasattr(model, "_output_sink"): 154 | vprint("clearing existing _output_sink") 155 | for v in model._output_sink.values(): 156 | del v 157 | del model._output_sink 158 | 159 | model._output_sink = {} 160 | 161 | try: 162 | model(*args, **kwargs) 163 | except AfterStoppingPointException as e: 164 | pass 165 | 166 | del model._output_sink_names 167 | 168 | return_val = model._output_sink 169 | del model._output_sink 170 | 171 | return return_val 172 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/memit/transformer_utils/src/transformer_utils/util/__init__.py -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/util/module_utils.py: -------------------------------------------------------------------------------- 1 | from .python_utils import make_print_if_verbose 2 | 3 | 4 | def get_child_module_by_names(module, names): 5 | obj = module 6 | for getter in map(lambda name: lambda obj: getattr(obj, name), names): 7 | obj = getter(obj) 8 | return obj 9 | 10 | 11 | def get_leaf_modules(module, verbose=False): 12 | vprint = make_print_if_verbose(verbose) 13 | 14 | names = [] 15 | leaves = [] 16 | handled = set() 17 | 18 | for param_name in dict(module.named_parameters()).keys(): 19 | mod_name = param_name.rpartition(".")[0] 20 | mod = get_child_module_by_names(module, mod_name.split(".")) 21 | 22 | if mod_name in handled: 23 | continue 24 | 25 | vprint((param_name, mod_name, mod)) 26 | 27 | names.append(mod_name) 28 | leaves.append(mod) 29 | handled.add(mod_name) 30 | 31 | return names, leaves 32 | -------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/util/python_utils.py: -------------------------------------------------------------------------------- 1 | def make_print_if_verbose(verbose: bool): 2 | def vprint(*args, **kwargs): 3 | if verbose: 4 | print(*args, **kwargs) 5 | 6 | return vprint 7 | 
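The `partial_forward` implementation above relies on a simple trick: a forward hook records the layer output it cares about and then raises `AfterStoppingPointException`, which the caller catches, so no later layers ever run. A self-contained toy sketch of that pattern (the three-layer model and the names here are illustrative, not part of this repository):

```python
import torch
import torch.nn as nn

class StopForward(Exception):
    """Raised from a hook to abort the rest of the forward pass."""

toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
captured = {}

def record_and_stop(module, inputs, output):
    captured["first_linear"] = output  # keep the activation we wanted
    raise StopForward()

handle = toy_model[0].register_forward_hook(record_and_stop)
try:
    toy_model(torch.randn(2, 8))
except StopForward:
    pass  # expected: the layers after the hooked one were never executed
finally:
    handle.remove()

print(captured["first_linear"].shape)  # torch.Size([2, 16])
```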
-------------------------------------------------------------------------------- /memit/transformer_utils/src/transformer_utils/util/tfm_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import transformers.file_utils 4 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING 5 | 6 | 7 | def fix_config_with_missing_model_type(model_name, config_path): 8 | with open(config_path, 'r', encoding='utf-8') as f: 9 | config = json.load(f) 10 | 11 | model_type = config.get('model_type') 12 | 13 | # cf https://github.com/huggingface/transformers/blob/v4.5.1/src/transformers/models/auto/configuration_auto.py#L403 14 | # 15 | # we reproduce that logic here, but save the fixed config to the json file 16 | # so it will work more robustly, i.e. even if you are not using `AutoConfig` 17 | if model_type is None: 18 | for pattern, config_class in CONFIG_MAPPING.items(): 19 | if pattern in model_name: 20 | config['model_type'] = config_class.model_type 21 | 22 | with open(config_path, 'w', encoding='utf-8') as f: 23 | json.dump(config, f) 24 | 25 | 26 | def get_local_path_from_huggingface_cdn(key, filename): 27 | archive_file = transformers.file_utils.hf_bucket_url( 28 | key, 29 | filename=filename, 30 | ) 31 | 32 | resolved_archive_file = transformers.file_utils.cached_path( 33 | archive_file, 34 | ) 35 | return resolved_archive_file 36 | 37 | 38 | def huggingface_model_local_paths(model_name): 39 | config_path = get_local_path_from_huggingface_cdn(model_name, "config.json") 40 | 41 | fix_config_with_missing_model_type(model_name, config_path) 42 | 43 | model_path = get_local_path_from_huggingface_cdn(model_name, "pytorch_model.bin") 44 | 45 | return config_path, model_path 46 | 47 | 48 | def normalize_inconsistent_state_dict_keys(state_dict): 49 | normalized = {} 50 | 51 | for k in state_dict.keys(): 52 | if k.startswith("transformer."): 53 | normalized[k] = state_dict[k] 54 | else: 55 | normalized["transformer." + k] = state_dict[k] 56 | return normalized 57 | -------------------------------------------------------------------------------- /memit/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .logit_lens import LogitLens 2 | -------------------------------------------------------------------------------- /memit/util/generate.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | from typing import List, Optional 3 | 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from util.logit_lens import LogitLens 8 | 9 | 10 | def generate_interactive( 11 | model: AutoModelForCausalLM, 12 | tok: AutoTokenizer, 13 | top_k: int = 5, 14 | max_out_len: int = 200, 15 | compare_against: Optional[AutoModelForCausalLM] = None, 16 | use_logit_lens: bool = False, 17 | layer_module_tmp: str = "transformer.h.{}", 18 | ln_f_module: str = "transformer.ln_f", 19 | lm_head_module: str = "lm_head", 20 | ): 21 | """ 22 | Puts generation in a loop. Allows users to repeatedly provide inputs 23 | with which text is generated. 
24 | """ 25 | 26 | if use_logit_lens: 27 | llens_gen = LogitLens( 28 | model, 29 | tok, 30 | layer_module_tmp, 31 | ln_f_module, 32 | lm_head_module, 33 | disabled=not use_logit_lens, 34 | ) 35 | if compare_against: 36 | llens_vanilla = LogitLens( 37 | compare_against, 38 | tok, 39 | layer_module_tmp, 40 | ln_f_module, 41 | lm_head_module, 42 | disabled=not use_logit_lens, 43 | ) 44 | 45 | while True: 46 | prompt = input("Enter a prompt: ").strip(" \r\t\n") 47 | 48 | print( 49 | f"Argument Model: " 50 | f"{generate_fast(model, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}" 51 | ) 52 | if compare_against: 53 | print( 54 | f"Baseline Model: " 55 | f"{generate_fast(compare_against, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}" 56 | ) 57 | 58 | if use_logit_lens: 59 | inp_prompt = tok([prompt], padding=True, return_tensors="pt").to( 60 | next(model.parameters()).device 61 | ) 62 | 63 | with llens_gen: 64 | model(**inp_prompt) 65 | print("\n--- Argument Model Logit Lens ---") 66 | llens_gen.pprint() 67 | 68 | if compare_against: 69 | with llens_vanilla: 70 | compare_against(**inp_prompt) 71 | print("--- Baseline Model Logit Lens ---") 72 | llens_vanilla.pprint() 73 | 74 | print() 75 | 76 | 77 | def generate_fast( 78 | model: AutoModelForCausalLM, 79 | tok: AutoTokenizer, 80 | prompts: List[str], 81 | n_gen_per_prompt: int = 1, 82 | top_k: int = 5, 83 | max_out_len: int = 200, 84 | ): 85 | """ 86 | Fast, parallelized auto-regressive text generation with top-k sampling. 87 | Our custom implementation. 88 | """ 89 | 90 | # Unroll prompts and tokenize 91 | inp = [prompt for prompt in prompts for _ in range(n_gen_per_prompt)] 92 | inp_tok = tok(inp, padding=True, return_tensors="pt").to( 93 | next(model.parameters()).device 94 | ) 95 | input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"] 96 | batch_size = input_ids.size(0) 97 | 98 | # Setup storage of fast generation with attention caches. 99 | # `cur_context` is used to define the range of inputs that are not yet 100 | # stored in `past_key_values`. At each step, we are generating the 101 | # next token for the index at `cur_context.stop + 1`. 
102 | past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item()) 103 | 104 | with torch.no_grad(): 105 | while input_ids.size(1) < max_out_len: # while not exceeding max output length 106 | model_out = model( 107 | input_ids=input_ids[:, cur_context], 108 | attention_mask=None if ('llama' in model.name_or_path.lower() or 'olmo' in model.name_or_path.lower()) else attention_mask[:, cur_context], 109 | past_key_values=past_key_values, 110 | use_cache=True, 111 | ) 112 | logits, past_key_values = model_out.logits, model_out.past_key_values 113 | softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1) 114 | 115 | # Top-k sampling 116 | tk = torch.topk(softmax_out, top_k, dim=1).indices 117 | softmax_out_top_k = torch.gather(softmax_out, 1, tk) 118 | softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None] 119 | new_tok_indices = torch.multinomial(softmax_out_top_k, 1) 120 | new_toks = torch.gather(tk, 1, new_tok_indices) 121 | 122 | # If we're currently generating the continuation for the last token in `input_ids`, 123 | # create a new index so we can insert the new token 124 | if cur_context.stop == input_ids.size(1): 125 | attention_mask = torch.cat( 126 | [attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1 127 | ) 128 | input_ids = torch.cat( 129 | [ 130 | input_ids, 131 | input_ids.new_ones(batch_size, 1) * tok.pad_token_id, 132 | ], 133 | dim=1, 134 | ) 135 | 136 | last_non_masked = attention_mask.sum(1) - 1 137 | for i in range(batch_size): 138 | new_idx = last_non_masked[i] + 1 139 | if last_non_masked[i].item() + 1 != cur_context.stop: 140 | continue 141 | 142 | # Stop generating if we've already maxed out for this prompt 143 | if new_idx < max_out_len: 144 | input_ids[i][new_idx] = new_toks[i] 145 | attention_mask[i][new_idx] = 1 146 | 147 | cur_context = slice(cur_context.stop, cur_context.stop + 1) 148 | 149 | txt = [tok.decode(x) for x in input_ids.detach().cpu().numpy().tolist()] 150 | txt = [ 151 | unicodedata.normalize("NFKD", x) 152 | .replace("\n\n", " ") 153 | .replace("<|endoftext|>", "") 154 | for x in txt 155 | ] 156 | 157 | return txt 158 | -------------------------------------------------------------------------------- /memit/util/globals.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | 5 | with open("globals.yml", "r") as stream: 6 | data = yaml.safe_load(stream) 7 | 8 | (RESULTS_DIR, DATA_DIR, STATS_DIR, HPARAMS_DIR, KV_DIR) = ( 9 | Path(z) 10 | for z in [ 11 | data["RESULTS_DIR"], 12 | data["DATA_DIR"], 13 | data["STATS_DIR"], 14 | data["HPARAMS_DIR"], 15 | data["KV_DIR"], 16 | ] 17 | ) 18 | 19 | REMOTE_ROOT_URL = data["REMOTE_ROOT_URL"] 20 | -------------------------------------------------------------------------------- /memit/util/hparams.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class HyperParams: 7 | """ 8 | Simple wrapper to store hyperparameters for Python-based rewriting methods.
9 | """ 10 | 11 | @classmethod 12 | def from_json(cls, fpath): 13 | with open(fpath, "r") as f: 14 | data = json.load(f) 15 | 16 | return cls(**data) 17 | -------------------------------------------------------------------------------- /memit/util/logit_lens.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from util import nethook 8 | 9 | 10 | class LogitLens: 11 | """ 12 | Applies the LM head at the output of each hidden layer, then analyzes the 13 | resultant token probability distribution. 14 | 15 | Only works when hooking outputs of *one* individual generation. 16 | 17 | Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens 18 | 19 | Warning: when running multiple times (e.g. generation), will return 20 | outputs _only_ for the last processing step. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: AutoModelForCausalLM, 26 | tok: AutoTokenizer, 27 | layer_module_tmp: str, 28 | ln_f_module: str, 29 | lm_head_module: str, 30 | disabled: bool = False, 31 | ): 32 | self.disabled = disabled 33 | self.model, self.tok = model, tok 34 | self.n_layers = self.model.config.n_layer 35 | 36 | self.lm_head, self.ln_f = ( 37 | nethook.get_module(model, lm_head_module), 38 | nethook.get_module(model, ln_f_module), 39 | ) 40 | 41 | self.output: Optional[Dict] = None 42 | self.td: Optional[nethook.TraceDict] = None 43 | self.trace_layers = [ 44 | layer_module_tmp.format(layer) for layer in range(self.n_layers) 45 | ] 46 | 47 | def __enter__(self): 48 | if not self.disabled: 49 | self.td = nethook.TraceDict( 50 | self.model, 51 | self.trace_layers, 52 | retain_input=False, 53 | retain_output=True, 54 | ) 55 | self.td.__enter__() 56 | 57 | def __exit__(self, *args): 58 | if self.disabled: 59 | return 60 | self.td.__exit__(*args) 61 | 62 | self.output = {layer: [] for layer in range(self.n_layers)} 63 | 64 | with torch.no_grad(): 65 | for layer, (_, t) in enumerate(self.td.items()): 66 | cur_out = t.output[0] 67 | assert ( 68 | cur_out.size(0) == 1 69 | ), "Make sure you're only running LogitLens on single generations only." 70 | 71 | self.output[layer] = torch.softmax( 72 | self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1 73 | ) 74 | 75 | return self.output 76 | 77 | def pprint(self, k=5): 78 | to_print = defaultdict(list) 79 | 80 | for layer, pred in self.output.items(): 81 | rets = torch.topk(pred[0], k) 82 | for i in range(k): 83 | to_print[layer].append( 84 | ( 85 | self.tok.decode(rets[1][i]), 86 | round(rets[0][i].item() * 1e2) / 1e2, 87 | ) 88 | ) 89 | 90 | print( 91 | "\n".join( 92 | [ 93 | f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}" 94 | for layer in range(self.n_layers) 95 | ] 96 | ) 97 | ) 98 | -------------------------------------------------------------------------------- /memit/util/perplexity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | 5 | def perplexity( 6 | model: AutoModelForCausalLM, 7 | tok: AutoTokenizer, 8 | text: str, 9 | max_input_length: int = None, 10 | ): 11 | """ 12 | Computes perplexity of a piece of text, measured on a reference model. 13 | Text is truncated to max_input_length tokens. 
14 | """ 15 | 16 | inputs = tok( 17 | [text], return_tensors="pt", max_length=max_input_length, truncation=True 18 | ).to("cuda") 19 | 20 | logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2) 21 | log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0] 22 | 23 | # Perplexity = exp(-1/N * log P(x_1, ..., x_n)) 24 | return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item() 25 | -------------------------------------------------------------------------------- /memit/zsre_evals.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Constants 5 | N_EDITS="10000" 6 | 7 | # Run configurations 8 | MODEL_NAME="EleutherAI/gpt-j-6B" 9 | ALG_NAMES=("FT" "MEND" "ROME" "MEMIT") 10 | HPARAMS_FNAMES=("EleutherAI_gpt-j-6B_wd.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json") 11 | 12 | # Execute 13 | for i in ${!ALG_NAMES[@]} 14 | do 15 | alg_name=${ALG_NAMES[$i]} 16 | hparams_fname=${HPARAMS_FNAMES[$i]} 17 | 18 | echo "Running evals for $alg_name..." 19 | 20 | python3 -m experiments.evaluate --alg_name=$alg_name --model_name=$MODEL_NAME --hparams_fname=$hparams_fname --num_edits=$N_EDITS --use_cache --dataset_size_limit=$N_EDITS --ds_name=zsre 21 | done 22 | 23 | exit 0 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ai2-olmo==0.2.5 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | anyio==4.3.0 5 | better-abc==0.0.3 6 | bitsandbytes==0.43.1 7 | datasets==2.18.0 8 | evaluate==0.4.2 9 | huggingface-hub==0.21.4 10 | hydra-core==1.3.2 11 | jupyter 12 | matplotlib==3.7.5 13 | nltk==3.8.1 14 | numpy==1.24.1 15 | openai==1.25.0 16 | pandas==2.0.3 17 | peft==0.10.0 18 | pillow==10.2.0 19 | PyYAML==6.0.1 20 | rouge==1.0.1 21 | rouge-score==0.1.2 22 | scipy==1.10.1 23 | sentencepiece==0.2.0 24 | statistics==1.0.3.5 25 | scikit-learn=1.5.2 26 | tokenizers==0.15.2 27 | tornado==6.4 28 | tqdm==4.66.2 29 | transformer-lens==1.15.0 30 | transformers==4.38.2 31 | urllib3==1.26.13 32 | 33 | 34 | --------------------------------------------------------------------------------