├── ConceptVectors_data
│   ├── llama2-7b_concepts.json
│   ├── llama2-7b_concepts_dev.json
│   ├── llama2-7b_concepts_test.json
│   ├── olmo-7b_concepts.json
│   ├── olmo-7b_concepts_dev.json
│   ├── olmo-7b_concepts_test.json
│   └── relation_for_KE
│       ├── .idea
│       │   ├── .gitignore
│       │   ├── deployment.xml
│       │   ├── inspectionProfiles
│       │   │   ├── Project_Default.xml
│       │   │   └── profiles_settings.xml
│       │   ├── misc.xml
│       │   ├── modules.xml
│       │   ├── relation_for_KE.iml
│       │   └── vcs.xml
│       ├── label_prop_dict.json
│       ├── llama_relation_object.json
│       ├── llama_relation_object_dev.json
│       ├── llama_relation_object_test.json
│       ├── olmo_relation_object.json
│       ├── olmo_relation_object_dev.json
│       ├── olmo_relation_object_test.json
│       └── relation_to_template.json
├── Concept_Validation_Experiments
│   ├── .ipynb_checkpoints
│   │   ├── Concept_Validation_Experiment-checkpoint.ipynb
│   │   └── olmo_Concept_Validation_Experiment-checkpoint.ipynb
│   └── Concept_Validation_Experiments.ipynb
├── Jailbreak
│   ├── __pycache__
│   │   └── evaluate_util.cpython-38.pyc
│   ├── evaluate_util.py
│   ├── jailbreak.ipynb
│   ├── llama_jailbreak_German_qa.json
│   └── olmo_jailbreak_German_qa.json
├── LICENSE
├── README.md
├── all_forget_llama.sh
├── all_forget_olmo.sh
├── config
│   ├── ds_config.json
│   ├── forget.yaml
│   └── model_config.yaml
├── data_module.py
├── dataloader.py
├── evaluate_llama.py
├── evaluate_olmo.py
├── evaluate_util.py
├── forget.py
├── memit
│   ├── .gitattributes
│   ├── .gitignore
│   ├── CITATION.cff
│   ├── LICENSE
│   ├── README.md
│   ├── baselines
│   │   ├── README.md
│   │   ├── ft
│   │   │   ├── __init__.py
│   │   │   ├── ft_hparams.py
│   │   │   └── ft_main.py
│   │   └── mend
│   │       ├── README.md
│   │       ├── __init__.py
│   │       ├── algs
│   │       │   ├── enn.py
│   │       │   ├── ft.py
│   │       │   └── mend.py
│   │       ├── config
│   │       │   ├── alg
│   │       │   │   ├── efk.yaml
│   │       │   │   ├── enn.yaml
│   │       │   │   ├── ft.yaml
│   │       │   │   └── mend.yaml
│   │       │   ├── config.yaml
│   │       │   ├── experiment
│   │       │   │   ├── fc.yaml
│   │       │   │   ├── gen.yaml
│   │       │   │   └── qa.yaml
│   │       │   └── model
│   │       │       ├── bart-base.yaml
│   │       │       ├── bert-base.yaml
│   │       │       ├── distilgpt2.yaml
│   │       │       ├── gpt2.yaml
│   │       │       ├── gpt2large.yaml
│   │       │       ├── gpt2medium.yaml
│   │       │       ├── gpt2xl.yaml
│   │       │       ├── gptj.yaml
│   │       │       ├── gptneo27.yaml
│   │       │       ├── t5large.yaml
│   │       │       ├── t5small.yaml
│   │       │       ├── t5xl.yaml
│   │       │       └── t5xxl.yaml
│   │       ├── data_classes
│   │       │   ├── fever.py
│   │       │   ├── nq.py
│   │       │   ├── wiki.py
│   │       │   └── zsre.py
│   │       ├── editable_model.py
│   │       ├── hooks.py
│   │       ├── losses.py
│   │       ├── mend_hparams.py
│   │       ├── mend_main.py
│   │       ├── models.py
│   │       ├── nn.py
│   │       ├── oracle.py
│   │       ├── requirements.txt
│   │       ├── run.py
│   │       ├── trainer.py
│   │       └── utils.py
│   ├── dsets
│   │   ├── __init__.py
│   │   ├── attr_snippets.py
│   │   ├── counterfact.py
│   │   ├── knowns.py
│   │   ├── tfidf_stats.py
│   │   └── zsre.py
│   ├── experiments
│   │   ├── __init__.py
│   │   ├── causal_trace.py
│   │   ├── evaluate.py
│   │   ├── evaluate_olmo.py
│   │   ├── memit_jailbreak_evaluate.py
│   │   ├── memit_jailbreak_evaluate_olmo.py
│   │   ├── py
│   │   │   ├── demo.py
│   │   │   ├── eval_utils_counterfact.py
│   │   │   └── eval_utils_zsre.py
│   │   ├── summarize.py
│   │   └── sweep.py
│   ├── forget_memit.sh
│   ├── forget_memit_olmo.sh
│   ├── globals.yml
│   ├── hparams
│   │   ├── FT
│   │   │   ├── EleutherAI_gpt-j-6B_constr.json
│   │   │   ├── EleutherAI_gpt-j-6B_unconstr.json
│   │   │   ├── EleutherAI_gpt-j-6B_wd.json
│   │   │   ├── gpt2-large_constr.json
│   │   │   ├── gpt2-medium_constr.json
│   │   │   ├── gpt2-xl_attn.json
│   │   │   ├── gpt2-xl_constr.json
│   │   │   └── gpt2-xl_unconstr.json
│   │   ├── MEMIT
│   │   │   ├── EleutherAI_gpt-j-6B.json
│   │   │   ├── gpt2-xl.json
│   │   │   ├── llama2-7b.json
│   │   │   └── olmo-7b.json
│   │   ├── MEND
│   │   │   ├── EleutherAI_gpt-j-6B.json
│   │   │   ├── EleutherAI_gpt-j-6B_CF.json
│   │   │   ├── gpt2-xl.json
│   │   │   ├── gpt2-xl_CF.json
│   │   │   └── gpt2-xl_zsRE.json
│   │   └── ROME
│   │       ├── EleutherAI_gpt-j-6B.json
│   │       ├── EleutherAI_gpt-neox-20b.json
│   │       ├── gpt2-large.json
│   │       ├── gpt2-medium.json
│   │       └── gpt2-xl.json
│   ├── memit
│   │   ├── __init__.py
│   │   ├── compute_ks.py
│   │   ├── compute_z.py
│   │   ├── memit_hparams.py
│   │   └── memit_main.py
│   ├── notebooks
│   │   ├── average_causal_effects.ipynb
│   │   ├── causal_trace.ipynb
│   │   ├── causal_trace_frozen_mlp_attn.ipynb
│   │   ├── memit.ipynb
│   │   └── vis
│   │       ├── table_population.ipynb
│   │       ├── table_population_zsre.ipynb
│   │       ├── visualize_multi_results.ipynb
│   │       └── visualize_sweep_results.ipynb
│   ├── rome
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── compute_u.py
│   │   ├── compute_v.py
│   │   ├── layer_stats.py
│   │   ├── repr_tools.py
│   │   ├── rome_hparams.py
│   │   ├── rome_main.py
│   │   └── tok_dataset.py
│   ├── scaling_curves.sh
│   ├── scripts
│   │   ├── causal_trace.sh
│   │   ├── colab_reqs
│   │   │   ├── additional.txt
│   │   │   └── rome.txt
│   │   ├── collect_layer_stats.sh
│   │   ├── ipynb_drop_output.py
│   │   ├── memit.yml
│   │   ├── setup_clean_ipynb.sh
│   │   └── setup_conda.sh
│   ├── transformer_utils
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   ├── setup.py
│   │   └── src
│   │       └── transformer_utils
│   │           ├── __init__.py
│   │           ├── logit_lens
│   │           │   ├── __init__.py
│   │           │   ├── hooks.py
│   │           │   ├── layer_names.py
│   │           │   └── plotting.py
│   │           ├── low_memory
│   │           │   ├── __init__.py
│   │           │   ├── enable.py
│   │           │   ├── load.py
│   │           │   └── load_context.py
│   │           ├── partial_forward
│   │           │   └── __init__.py
│   │           └── util
│   │               ├── __init__.py
│   │               ├── module_utils.py
│   │               ├── python_utils.py
│   │               └── tfm_utils.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── generate.py
│   │   ├── globals.py
│   │   ├── hparams.py
│   │   ├── logit_lens.py
│   │   ├── nethook.py
│   │   ├── perplexity.py
│   │   └── runningstats.py
│   └── zsre_evals.sh
├── requirements.txt
└── utils.py
/ConceptVectors_data/relation_for_KE/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/relation_for_KE.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ConceptVectors_data/relation_for_KE/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Jailbreak/__pycache__/evaluate_util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/Jailbreak/__pycache__/evaluate_util.cpython-38.pyc
--------------------------------------------------------------------------------
/all_forget_llama.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Call the program in a loop, passing a different order index each time
4 |
5 | #Testing on ConceptVectors Test set of LLaMA
6 | for i in {0..94} #18 26 27 #{0..9}
7 | do
8 | python forget.py order=$i batch_size=4 gradient_accumulation_steps=8 num_epochs=1 gradient_checkpointing=True lr=2e-1 forget_loss=grad_ascent ft_type=Needle set=test
9 | done
10 |
11 |
--------------------------------------------------------------------------------
/all_forget_olmo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #Testing on ConceptVectors Test set of OLMo
4 | for i in {0..161}
5 | do
6 | python forget.py order=$i batch_size=4 gradient_accumulation_steps=8 num_epochs=1 gradient_checkpointing=False lr=2e-1 forget_loss=grad_ascent ft_type=Needle set=test
7 | done
8 |
9 |
10 | #olmo jailbreak: 4 37 40 44 59 77 90 105 141 147
--------------------------------------------------------------------------------
/config/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "offload_optimizer": {
5 | "device": "none",
6 | "pin_memory": true
7 | },
8 | "offload_param": {
9 | "device": "none",
10 | "pin_memory": true
11 | },
12 | "overlap_comm": true,
13 | "contiguous_gradients": true,
14 | "sub_group_size": 1e9,
15 | "reduce_bucket_size": "auto",
16 | "stage3_prefetch_bucket_size": "auto",
17 | "stage3_param_persistence_threshold": "auto",
18 | "stage3_max_live_parameters": 1e9,
19 | "stage3_max_reuse_distance": 1e9,
20 | "stage3_gather_16bit_weights_on_model_save": true
21 | },
22 | "train_batch_size": "auto",
23 | "train_micro_batch_size_per_gpu": "auto",
24 | "gradient_accumulation_steps": "auto",
25 | "bf16": {
26 | "enabled": true
27 | }
28 | }
--------------------------------------------------------------------------------
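The `"auto"` fields above are placeholders that the launching framework resolves at runtime. A minimal sketch of wiring this file into a run, assuming the unlearning trainer is built on Hugging Face `Trainer` (an assumption; `forget.py` and `dataloader.py` are not reproduced in this dump), with illustrative batch sizes:

```python
# Sketch only: hand config/ds_config.json to DeepSpeed via Hugging Face TrainingArguments.
# The use of HF Trainer here is an assumption; batch sizes and output_dir are illustrative.
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/unlearn_run",       # hypothetical output directory
    per_device_train_batch_size=4,          # resolves train_micro_batch_size_per_gpu="auto"
    gradient_accumulation_steps=8,          # resolves gradient_accumulation_steps="auto"
    bf16=True,                              # matches "bf16": {"enabled": true}
    deepspeed="config/ds_config.json",      # enables ZeRO stage 3 as configured above
)
# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()
```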
/config/forget.yaml:
--------------------------------------------------------------------------------
1 | model_family: llama2-7b #olmo-7b
2 | model_path: /root/autodl-tmp/transformers/llama2-7b-chat-hf
3 | data_path: /root/Unlearn_Harry_Potter/Baselines/ConceptMap/ConceptMap_data
4 | results_save_path: /root/autodl-tmp/unlearn_results/${model_family}/${forget_loss}
5 | save_dir: ${model_path}/1GPU_${forget_loss}_${lr}_${split}_epoch${num_epochs}_batch${batch_size}_accum${gradient_accumulation_steps}_beta${beta}_ref${ref_policy}_eval${eval_steps}_seed${seed}_${run_index}
6 |
7 | set: test #dev or test
8 | order: 10
9 | LoRA:
10 | r: 0
11 | alpha: 32
12 | dropout: 0.05
13 |
14 | lr: 5e-5 #5e-4; lr should be larger for Needle
15 | split: wikipedia #wikipedia, pretraining_data
16 | batch_size: 1
17 | gradient_accumulation_steps: 16
18 | gradient_checkpointing: False #gradient_checkpointing is not supported for olmo-7b
19 | num_epochs: 10
20 | forget_loss: npo_KL # type: grad_ascent, grad_diff, npo, npo_grad_diff, npo_KL, dpo
21 | ft_type: Full #Full, Sparse, MEMIT, all_value_vectors, Needle
22 | loss_threshold: -100
23 |
24 | npo_coeff: 1.0
25 | grad_diff_coeff: 1.0
26 | KL_coeff: 1.0
27 | ref_policy: fine_tuned
28 | beta: 0.1
29 | weight_decay: 0.01
30 |
31 | seed: 999
32 | run_index: 1
33 | overwrite_dir: True
34 | eval_steps: 1000000000
35 | warmup_steps: steps_per_epoch
36 |
37 |
--------------------------------------------------------------------------------
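The shell scripts above (`all_forget_llama.sh`, `all_forget_olmo.sh`) override these fields from the command line (`order=$i batch_size=4 ...`), which is Hydra-style syntax. A minimal sketch of an entry point that would consume this file, assuming `forget.py` loads it through Hydra (an assumption; only the field names below come from `forget.yaml`):

```python
# Minimal sketch, assuming forget.py loads config/forget.yaml through Hydra.
# Only the field names are taken from forget.yaml; the entry point itself is assumed.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="config", config_name="forget")
def main(cfg: DictConfig) -> None:
    # Overrides such as `order=3 forget_loss=grad_ascent ft_type=Needle set=test`
    # replace the defaults defined in forget.yaml before this function runs.
    print(cfg.model_family, cfg.set, cfg.order, cfg.forget_loss, cfg.ft_type, cfg.lr)


if __name__ == "__main__":
    main()
```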
/config/model_config.yaml:
--------------------------------------------------------------------------------
1 | llama2-7b:
2 | hf_key: "/root/autodl-tmp/transformers/llama2-7b-chat-hf"
3 | question_start_tag: "[INST] "
4 | question_end_tag: " [/INST]"
5 | answer_tag: ""
6 | flash_attention2: "false"
7 | gradient_checkpointing: "true"
8 | ft_model_path: "/root/autodl-tmp/transformers/final_ft_noLORA_5_epochs_inst_lr1e-05_llama2-7b_full/checkpoint-625" #this model will be used for unlearning by default
9 | olmo-7b:
10 | hf_key: "/root/autodl-tmp/transformers/OLMo-7B"
11 | question_start_tag: "Question: "
12 | question_end_tag: "\n"
13 | answer_tag: "Answer: "
14 | flash_attention2: "false"
15 | gradient_checkpointing: "false"
16 | ft_model_path: "/root/autodl-tmp/transformers/final_ft_noLORA_5_epochs_inst_lr1e-05_olmo-7b_full/checkpoint-625"
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
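Each model family defines its own prompt tags and fine-tuned checkpoint path. A minimal sketch of building a QA prompt from these tags; only the YAML keys come from the file above, and the helper itself is illustrative rather than the repository's own code:

```python
# Minimal sketch: wrap a question with the per-family prompt tags from config/model_config.yaml.
# The build_prompt helper is illustrative; only the YAML keys are taken from the file above.
import yaml

with open("config/model_config.yaml") as f:
    model_cfg = yaml.safe_load(f)


def build_prompt(question: str, family: str = "llama2-7b") -> str:
    cfg = model_cfg[family]
    # llama2-7b -> "[INST] <question> [/INST]"; olmo-7b -> "Question: <question>\nAnswer: "
    return f"{cfg['question_start_tag']}{question}{cfg['question_end_tag']}{cfg['answer_tag']}"


print(build_prompt("Which house does Harry Potter belong to?", family="olmo-7b"))
```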
/memit/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb filter=clean_ipynb
2 |
--------------------------------------------------------------------------------
/memit/.gitignore:
--------------------------------------------------------------------------------
1 | # Pipeline dumps, data directory
2 | results
3 | data
4 | !notebooks/data
5 | *_tmp_*_.json
6 | *_kmeng01g1gn*
7 |
8 | # Pre-trained hypernetworks
9 | baselines/*/weights
10 |
11 | # Mac specific
12 | .idea
13 | .vscode
14 | .DS_Store
15 |
16 | # Latex
17 | *.aux
18 | *.dvi
19 | *.fdb_latexmk
20 | *.fls
21 | *.log
22 | *.pdf
23 | *.synctex.gz
24 | *.out
25 | *.toc
26 | *.nps
27 |
28 | # Byte-compiled / optimized / DLL files
29 | __pycache__/
30 | *.py[cod]
31 | *$py.class
32 | *.py.swp
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | share/python-wheels/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 | MANIFEST
56 |
57 | # PyInstaller
58 | # Usually these files are written by a python script from a template
59 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
60 | *.manifest
61 | *.spec
62 |
63 | # Installer logs
64 | pip-log.txt
65 | pip-delete-this-directory.txt
66 |
67 | # Unit test / coverage reports
68 | htmlcov/
69 | .tox/
70 | .nox/
71 | .coverage
72 | .coverage.*
73 | .cache
74 | nosetests.xml
75 | coverage.xml
76 | *.cover
77 | *.py,cover
78 | .hypothesis/
79 | .pytest_cache/
80 | cover/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | local_settings.py
89 | db.sqlite3
90 | db.sqlite3-journal
91 |
92 | # Flask stuff:
93 | instance/
94 | .webassets-cache
95 |
96 | # Scrapy stuff:
97 | .scrapy
98 |
99 | # Sphinx documentation
100 | docs/_build/
101 |
102 | # PyBuilder
103 | .pybuilder/
104 | target/
105 |
106 | # Jupyter Notebook
107 | .ipynb_checkpoints
108 |
109 | # IPython
110 | profile_default/
111 | ipython_config.py
112 |
113 | # pyenv
114 | # For a library or package, you might want to ignore these files since the code is
115 | # intended to run in multiple environments; otherwise, check them in:
116 | # .python-version
117 |
118 | # pipenv
119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
122 | # install all needed dependencies.
123 | #Pipfile.lock
124 |
125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
126 | __pypackages__/
127 |
128 | # Celery stuff
129 | celerybeat-schedule
130 | celerybeat.pid
131 |
132 | # SageMath parsed files
133 | *.sage.py
134 |
135 | # Environments
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 |
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 |
147 | # Rope project settings
148 | .ropeproject
149 |
150 | # mkdocs documentation
151 | /site
152 |
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 |
158 | # Pyre type checker
159 | .pyre/
160 |
161 | # pytype static type analyzer
162 | .pytype/
163 |
164 | # Cython debug symbols
165 | cython_debug/
166 |
--------------------------------------------------------------------------------
/memit/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | preferred-citation:
4 | type: article
5 | authors:
6 | - family-names: "Meng"
7 | given-names: "Kevin"
8 | - family-names: "Sen Sharma"
9 | given-names: "Arnab"
10 | - family-names: "Andonian"
11 | given-names: "Alex"
12 | - family-names: "Belinkov"
13 | given-names: "Yonatan"
14 | - family-names: "Bau"
15 | given-names: "David"
16 | journal: "arXiv preprint arXiv:2210.07229"
17 | title: "Mass-Editing Memory in a Transformer"
18 | year: 2022
--------------------------------------------------------------------------------
/memit/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Kevin Meng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/memit/README.md:
--------------------------------------------------------------------------------
1 | # MEMIT: Mass-Editing Memory in a Transformer
2 |
3 | Editing thousands of facts into a transformer memory at once.
4 |
5 |
6 |
7 | ## Table of Contents
8 |
9 | - [Installation](#installation)
10 | - [MEMIT Algorithm Demo](#memit-algorithm-demo)
11 | - [Running the Full Evaluation Suite](#running-the-full-evaluation-suite)
12 | - [Generating Scaling Curves](#generating-scaling-curves)
13 | - [How to Cite](#how-to-cite)
14 |
15 | ## Installation
16 |
17 | We recommend `conda` for managing Python, CUDA, and PyTorch; `pip` is for everything else. To get started, simply install `conda` and run:
18 | ```bash
19 | CONDA_HOME=$CONDA_HOME ./scripts/setup_conda.sh
20 | ```
21 |
22 | `$CONDA_HOME` should be the path to your `conda` installation, e.g., `~/miniconda3`.
23 |
24 | ## MEMIT Algorithm Demo
25 |
26 | [`notebooks/memit.ipynb`](notebooks/memit.ipynb) demonstrates MEMIT. The API is simple: specify a *requested rewrite* of the following form:
27 |
28 | ```python
29 | request = [
30 | {
31 | "prompt": "{} plays the sport of",
32 | "subject": "LeBron James",
33 | "target_new": {
34 | "str": "football"
35 | }
36 | },
37 | {
38 | "prompt": "{} plays the sport of",
39 | "subject": "Michael Jordan",
40 | "target_new": {
41 | "str": "baseball"
42 | }
43 | },
44 | ]
45 | ```
46 |
47 | Other similar example(s) are included in the notebook.
48 |
49 | ## Running the Full Evaluation Suite
50 |
51 | [`experiments/evaluate.py`](experiments/evaluate.py) can be used to evaluate any method in [`baselines/`](baselines/).
52 |
53 | For example:
54 | ```
55 | python3 -m experiments.evaluate \
56 | --alg_name=MEMIT \
57 | --model_name=EleutherAI/gpt-j-6B \
58 | --hparams_fname=EleutherAI_gpt-j-6B.json \
59 | --num_edits=10000 \
60 | --use_cache
61 | ```
62 | Results from each run are stored at `results/<method_name>/run_<run_id>` in a specific format:
63 | ```bash
64 | results/
65 | |__ MEMIT/
66 |     |__ run_<run_id>/
67 |         |__ params.json
68 |         |__ case_0.json
69 |         |__ case_1.json
70 |         |__ ...
71 |         |__ case_10000.json
72 | ```
73 |
74 | To summarize the results, you can use [`experiments/summarize.py`](experiments/summarize.py):
75 | ```bash
76 | python3 -m experiments.summarize --dir_name=MEMIT --runs=run_<run1>,run_<run2>
77 | ```
78 |
79 | Running `python3 -m experiments.evaluate -h` or `python3 -m experiments.summarize -h` provides details about command-line flags.
80 |
81 | ## How to Cite
82 |
83 | ```bibtex
84 | @article{meng2022memit,
85 | title={Mass Editing Memory in a Transformer},
86 | author={Kevin Meng and Sen Sharma, Arnab and Alex Andonian and Yonatan Belinkov and David Bau},
87 | journal={arXiv preprint arXiv:2210.07229},
88 | year={2022}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
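For reference alongside the README above, a minimal sketch of applying a *requested rewrite* programmatically. It assumes the `MEMITHyperParams` and `apply_memit_to_model` exports from `memit/memit/` and that the JSON files under `hparams/MEMIT/` match the hyperparameter dataclass fields; neither is shown in this dump, so treat the snippet as an illustration rather than the repository's own demo code (that lives in `notebooks/memit.ipynb`).

```python
# Illustrative sketch only: the imports, hparams loading, and return_orig_weights flag are
# assumptions; the request format follows the README above.
import json

from transformers import AutoModelForCausalLM, AutoTokenizer

from memit import MEMITHyperParams, apply_memit_to_model  # assumed package exports

model = AutoModelForCausalLM.from_pretrained("gpt2-xl").cuda()
tok = AutoTokenizer.from_pretrained("gpt2-xl")
tok.pad_token = tok.eos_token

request = [
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    },
]

# Assumes the JSON keys line up with the MEMITHyperParams dataclass fields.
hparams = MEMITHyperParams(**json.load(open("hparams/MEMIT/gpt2-xl.json")))

edited_model, orig_weights = apply_memit_to_model(
    model, tok, request, hparams, return_orig_weights=True
)
```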
/memit/baselines/README.md:
--------------------------------------------------------------------------------
1 | We compare ROME against several open-source, state-of-the-art model editors. All are implemented in their respective folders. Implementations other than FT/FT+L are adapted from third parties.
2 | - Fine-Tuning (`ft`): Direct fine-tuning.
3 | - Constrained Fine-Tuning (`ft`): FT with $L_\infty$ norm constraint. Inspired by Zhu et al. [[Paper]](https://arxiv.org/abs/2012.00363)
4 | - Knowledge Neurons (`kn`): Dai et al. [[Code]](https://github.com/EleutherAI/knowledge-neurons) [[Paper]](https://arxiv.org/abs/2104.08696)
5 | - Knowledge Editor (`efk`): De Cao et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2104.08164)
6 | - Model Editor Networks with Gradient Decomposition (`mend`): Mitchell et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2110.11309)
--------------------------------------------------------------------------------
/memit/baselines/ft/__init__.py:
--------------------------------------------------------------------------------
1 | from .ft_main import FTHyperParams, apply_ft_to_model, execute_ft
2 |
--------------------------------------------------------------------------------
/memit/baselines/ft/ft_hparams.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | from util.hparams import HyperParams
5 |
6 |
7 | @dataclass
8 | class FTHyperParams(HyperParams):
9 | # Method
10 | layers: List[int]
11 | num_steps: int
12 | lr: float
13 | weight_decay: float
14 | kl_factor: float
15 | norm_constraint: float
16 |
17 | # Module templates
18 | rewrite_module_tmp: str
19 | layer_module_tmp: str
20 | mlp_module_tmp: str
21 | attn_module_tmp: str
22 | ln_f_module: str
23 | lm_head_module: str
24 |
25 | # Defaults
26 | batch_size: int = 64
27 | wd_power_law: tuple = None # Scale weight decay by number of edits
28 |
--------------------------------------------------------------------------------
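`FTHyperParams` mirrors the JSON files under `hparams/FT/`. A minimal sketch of constructing it from one of those files with plain `json`; the repo's own loader in `util/hparams.py` is not reproduced in this dump, so direct dataclass construction is used here and assumes the JSON keys match the fields above:

```python
# Minimal sketch: build FTHyperParams from a JSON file in hparams/FT/.
# Direct dataclass construction stands in for the util/hparams.py loader (not shown in this
# dump) and assumes the JSON keys match the dataclass fields exactly.
import json

from baselines.ft import FTHyperParams  # exported in baselines/ft/__init__.py

with open("hparams/FT/gpt2-xl_constr.json") as f:
    cfg = json.load(f)

hparams = FTHyperParams(**cfg)
print(hparams.layers, hparams.lr, hparams.norm_constraint)
```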
/memit/baselines/ft/ft_main.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Any, Dict, List, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from util import nethook
8 |
9 | from .ft_hparams import FTHyperParams
10 |
11 |
12 | def apply_ft_to_model(
13 | model: AutoModelForCausalLM,
14 | tok: AutoTokenizer,
15 | requests: List[Dict],
16 | hparams: FTHyperParams,
17 | copy=False,
18 | return_orig_weights=False,
19 | **kwargs: Any,
20 | ) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
21 | """
22 | Returns a model with the desired changes.
23 | :param copy: If true, will preserve the original model while creating a new one to edit.
24 | Note that you are responsible for deallocating the new model's memory to avoid leaks.
25 | :return: (1) the updated model, (2) the weights that changed
26 | """
27 |
28 | weights_copy = {}
29 | if copy:
30 | model = deepcopy(model)
31 |
32 | deltas = execute_ft(model, tok, requests, hparams)
33 |
34 | with torch.no_grad():
35 | for w_name, upd_matrix in deltas.items():
36 | w = nethook.get_parameter(model, w_name)
37 | if return_orig_weights and w_name not in weights_copy:
38 | weights_copy[w_name] = w.detach().clone()
39 |
40 | w[...] += upd_matrix
41 |
42 | print(f"New weights successfully inserted into {list(deltas.keys())}")
43 |
44 | return model, weights_copy
45 |
46 |
47 | def execute_ft(
48 | model: AutoModelForCausalLM,
49 | tok: AutoTokenizer,
50 | requests: List[Dict],
51 | hparams: FTHyperParams,
52 | **kwargs: Any,
53 | ) -> Dict[str, Tuple[torch.Tensor]]:
54 | """
55 | Executes the FT update algorithm for the specified update at the specified layer
56 | Invariant: model at beginning of function == model at end of function
57 | """
58 |
59 | # Update target and print info
60 | requests = deepcopy(requests)
61 | for request in requests:
62 | if request["target_new"]["str"][0] != " ":
63 | # Space required for correct tokenization
64 | request["target_new"]["str"] = " " + request["target_new"]["str"]
65 | print(
66 | f"Executing FT algo for: "
67 | f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
68 | )
69 |
70 | # Retrieve weights that user desires to change
71 | weights = {
72 | n: p
73 | for n, p in model.named_parameters()
74 | for layer in hparams.layers
75 | if hparams.rewrite_module_tmp.format(layer) in n
76 | }
77 | # Save old weights for future restoration
78 | weights_copy = {k: v.detach().clone() for k, v in weights.items()}
79 | print(f"Weights to be updated: {list(weights.keys())}")
80 |
81 | # Define inputs
82 | texts = [r["prompt"].format(r["subject"]) for r in requests]
83 | targets = [r["target_new"]["str"] for r in requests]
84 |
85 | # Configure optimizer / gradients
86 | wd = (
87 | hparams.weight_decay
88 | if not isinstance(hparams.wd_power_law, tuple)
89 | else (len(requests) ** hparams.wd_power_law[0])
90 | * np.exp(hparams.wd_power_law[1])
91 | )
92 | print(f"Using weight decay of {wd} for {len(requests)} edits")
93 | opt = torch.optim.Adam(
94 | [v for _, v in weights.items()],
95 | lr=hparams.lr,
96 | weight_decay=wd,
97 | )
98 | for name, w in model.named_parameters():
99 | w.requires_grad = name in weights
100 |
101 | # Update loop: intervene at layers simultaneously
102 | loss_meter = AverageMeter()
103 | for it in range(hparams.num_steps):
104 | print(20 * "=")
105 | print(f"Epoch: {it}")
106 | print(20 * "=")
107 | loss_meter.reset()
108 |
109 | for txt, tgt in zip(
110 | chunks(texts, hparams.batch_size), chunks(targets, hparams.batch_size)
111 | ):
112 | inputs = tok(txt, return_tensors="pt", padding=True).to("cuda")
113 | target_ids = tok(tgt, return_tensors="pt", padding=True)["input_ids"].to(
114 | "cuda"
115 | )
116 | last_token_inds = inputs["attention_mask"].sum(dim=1) - 1
117 | loss_mask = target_ids != tok.unk_token_id
118 |
119 | opt.zero_grad()
120 | bs = inputs["input_ids"].shape[0]
121 | probs = torch.nn.functional.log_softmax(
122 | model(**inputs).logits[torch.arange(bs), last_token_inds], dim=-1
123 | )
124 | loss = -(torch.gather(probs, 1, target_ids) * loss_mask).sum(
125 | 1
126 | ) / loss_mask.sum(1)
127 | loss = loss.mean()
128 | print(f"Batch loss {loss.item()}")
129 | loss_meter.update(loss.item(), n=bs)
130 |
131 | if loss.item() >= 1e-2:
132 | loss.backward()
133 | opt.step()
134 |
135 | if type(hparams.norm_constraint) is float:
136 | eps = hparams.norm_constraint
137 | with torch.no_grad():
138 | for k, v in weights.items():
139 | v[...] = torch.clamp(
140 | v, min=weights_copy[k] - eps, max=weights_copy[k] + eps
141 | )
142 |
143 | print(f"Total loss {loss_meter.avg}")
144 |
145 | if loss_meter.avg < 1e-2:
146 | break
147 |
148 | deltas = {k: (weights[k] - weights_copy[k]).detach() for k in weights}
149 |
150 | # Restore state of original model
151 | with torch.no_grad():
152 | for k, v in weights.items():
153 | v[...] = weights_copy[k]
154 |
155 | print(f"Deltas successfully computed for {list(weights.keys())}")
156 |
157 | return deltas
158 |
159 |
160 | def chunks(arr, n):
161 | """Yield successive n-sized chunks from arr."""
162 | chunk = []
163 | for a in arr:
164 | chunk.append(a)
165 | if len(chunk) == n:
166 | yield chunk
167 | chunk = []
168 | if len(chunk) > 0:
169 | yield chunk
170 |
171 |
172 | class AverageMeter:
173 | """Computes and stores the average and current value"""
174 |
175 | def __init__(self):
176 | self.reset()
177 |
178 | def reset(self):
179 | self.val = 0
180 | self.avg = 0
181 | self.sum = 0
182 | self.count = 0
183 |
184 | def update(self, val, n=1):
185 | self.val = val
186 | self.sum += val * n
187 | self.count += n
188 | self.avg = self.sum / self.count
189 |
--------------------------------------------------------------------------------
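A minimal usage sketch for `apply_ft_to_model`, relying only on the signature and request shape shown above; the model name, hparams file, and CUDA device are assumptions. The final loop mirrors how `execute_ft` itself restores the original weights.

```python
# Usage sketch for apply_ft_to_model; model name, hparams path, and CUDA device are assumptions.
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from baselines.ft import FTHyperParams, apply_ft_to_model
from util import nethook

model = AutoModelForCausalLM.from_pretrained("gpt2-xl").cuda()
tok = AutoTokenizer.from_pretrained("gpt2-xl")
tok.pad_token = tok.eos_token  # GPT-2 tokenizers define no pad token by default

request = [
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    },
]

hparams = FTHyperParams(**json.load(open("hparams/FT/gpt2-xl_constr.json")))
edited_model, weights_copy = apply_ft_to_model(
    model, tok, request, hparams, return_orig_weights=True
)

# Roll back the edit by restoring the saved weights (same pattern as execute_ft's restore loop).
with torch.no_grad():
    for name, w in weights_copy.items():
        nethook.get_parameter(edited_model, name)[...] = w
```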
/memit/baselines/mend/README.md:
--------------------------------------------------------------------------------
1 | # MEND: Model Editing Networks using Gradient Decomposition
2 |
3 | If you run into any issues with the code, you can open an issue and/or email me at `eric.mitchell@cs.stanford.edu`
4 |
5 | ## Setup
6 |
7 | ### Environment
8 |
9 | This codebase uses Python 3.7.9. Other versions may work as well.
10 |
11 | Create a virtualenv ([pyenv](https://github.com/pyenv/pyenv) can help with this)
12 | and install the dependencies:
13 |
14 | $ python -m venv env
15 | $ source env/bin/activate
16 | (env) $ pip install -r requirements.txt
17 |
18 | ### Data
19 |
20 | You can download the data needed for this project from
21 | [this Google Drive link](https://drive.google.com/drive/folders/1jAqBE45jEKR-5pMkwxlVQ0V8eKxqWbxA?usp=sharing).
22 | Unzip each sub-directory into `mend/data` and you should be good to go.
23 |
24 | ## Running the code
25 |
26 | Run MEND training/evaluation for distilGPT-2 on the wikitext editing problem with:
27 |
28 | (env) $ python -m run +alg=mend +experiment=gen +model=distilgpt2 data.wiki_webtext=False
29 |
30 | Other valid algs include `efk` ([KnowledgeEditor](https://arxiv.org/abs/2104.08164))
31 | and `enn` ([Editable Neural Networks](https://arxiv.org/abs/2004.00345)). Valid experiments
32 | include `fc` (FEVER fact checking) and `qa` (zsRE question-answering). Splits and rephrases
33 | for both come from [De Cao et al.](https://arxiv.org/abs/2104.08164). Check `config/model`
34 | for options for editable models (note that not all models work for all experiments; GPT-style
35 | models only work with `gen`, seq2seq models only work with `qa`, and BERT only works with `fc`).
36 |
37 | Also note that in the paper, we sample locality data from different datasets depending on the model.
38 | By default, training will use [Natural Questions](https://ai.google.com/research/NaturalQuestions)
39 | data (not zsRE data) for computing drawdown in the `qa` experiment and
40 | [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) for the `gen` experiment. For models such as the `distilgpt2`
41 | model we use (which was fine-tuned on wikitext) or the BART-base model, this behavior should be
42 | disabled with `data.wiki_webtext=False` or `data.zsre_nq=False`, respectively.
43 |
44 | ## Citing the paper
45 |
46 | If this code or paper was useful, please consider using the following citation:
47 |
48 | @article{mitchell2021fast,
49 | title={Fast Model Editing at Scale},
50 | author={Mitchell, Eric and Lin, Charles and Bosselut, Antoine and Finn, Chelsea and Manning, Christopher D.},
51 | year={2021},
52 | journal={CoRR},
53 | url={https://arxiv.org/pdf/2110.11309.pdf}
54 | }
55 |
--------------------------------------------------------------------------------
/memit/baselines/mend/__init__.py:
--------------------------------------------------------------------------------
1 | from .mend_hparams import MENDHyperParams
2 | from .mend_main import MendRewriteExecutor
3 |
--------------------------------------------------------------------------------
/memit/baselines/mend/algs/enn.py:
--------------------------------------------------------------------------------
1 | import higher
2 | import torch
3 | import torch.nn as nn
4 | from editable_model import EditableModel
5 | from utils import _logits
6 |
7 |
8 | def fomaml_callback(all_grads):
9 | return [g.detach() if g is not None else None for g in all_grads]
10 |
11 |
12 | class ENN(EditableModel):
13 | def __init__(
14 | self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None
15 | ):
16 | super().__init__(model, config, model_constructor)
17 |
18 | if edit_lrs is None:
19 | edit_lrs = nn.Parameter(
20 | torch.tensor([config.edit_lr] * len(self.config.model.inner_params))
21 | )
22 | self.edit_lrs = edit_lrs
23 |
24 | if edit_loss_fn is not None:
25 | self.edit_loss_fn = edit_loss_fn
26 |
27 | self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x
28 |
29 | def outer_parameters(self):
30 | if self.config.no_grad_layers is None:
31 | return super().outer_parameters()
32 | else:
33 | params = [self.edit_lrs]
34 | for m in self.model.modules():
35 | if isinstance(m, nn.ModuleList):
36 | params.extend(list(m[self.config.no_grad_layers :].parameters()))
37 | return params
38 |
39 | def get_state_dict(self):
40 | return self.state_dict()
41 |
42 | def edit(self, batch, condition=None, detach_history=False):
43 | opt = torch.optim.SGD(
44 | [
45 | {"params": p, "lr": None}
46 | for (n, p) in self.model.named_parameters()
47 | if n in self.config.model.inner_params
48 | ]
49 | )
50 | with torch.enable_grad(), higher.innerloop_ctx(
51 | self.model,
52 | opt,
53 | override={"lr": list(self.edit_lrs)},
54 | copy_initial_weights=False,
55 | track_higher_grads=self.training,
56 | in_place=True,
57 | ) as (fmodel, diffopt):
58 | fmodel.eval()
59 | for edit_step in range(self.config.enn.n_edit_steps):
60 | output = _logits(fmodel(**batch))
61 | loss = self.edit_loss_fn(output, batch["labels"])["nll"]
62 | diffopt.step(loss, grad_callback=self.grad_callback)
63 |
64 | if not detach_history:
65 | model_edited = fmodel
66 | else:
67 | model_edited = self.model_constructor()
68 | model_edited.load_state_dict(fmodel.state_dict())
69 | model_edited.train(self.training)
70 |
71 | return (
72 | ENN(
73 | model_edited,
74 | self.config,
75 | self.model_constructor,
76 | edit_lrs=self.edit_lrs,
77 | edit_loss_fn=self.edit_loss_fn,
78 | ),
79 | {},
80 | )
81 |
82 |
83 | def test():
84 | import copy
85 | import types
86 |
87 | import transformers
88 |
89 | model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
90 |
91 | config = types.SimpleNamespace()
92 | config.edit_lr = 0.1
93 | config.model.inner_params = [
94 | "transformer.h.9.mlp.c_fc.weight",
95 | "transformer.h.9.mlp.c_proj.weight",
96 | "transformer.h.10.mlp.c_fc.weight",
97 | "transformer.h.10.mlp.c_proj.weight",
98 | "transformer.h.11.mlp.c_fc.weight",
99 | "transformer.h.11.mlp.c_proj.weight",
100 | ]
101 | config.enn = {"n_edit_steps": 2, "first_order": False}
102 |
103 | enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda()
104 |
105 | x = torch.arange(100).view(5, 20).cuda() + 1000
106 |
107 | edited = enn.edit(x, masks=torch.ones_like(x), labels=x)
108 |
109 | orig_param = [
110 | p
111 | for (n, p) in enn.model.named_parameters()
112 | if n == config.model.inner_params[-1]
113 | ][0]
114 | edited_param = [
115 | p
116 | for (n, p) in edited.model.named_parameters()
117 | if n == config.model.inner_params[-1]
118 | ][0]
119 |
120 | print((orig_param - edited_param).abs().max())
121 | edited.eval()
122 | print(
123 | enn(x, labels=x).loss,
124 | edited(x, labels=x).loss,
125 | edited.edit_loss_fn(edited(x).logits, x)["nll"],
126 | )
127 | edited.edit_loss_fn(edited(x).logits, x).backward()
128 | import pdb
129 |
130 | pdb.set_trace()
131 |
132 |
133 | if __name__ == "__main__":
134 | with torch.autograd.set_detect_anomaly(True):
135 | test()
136 |
--------------------------------------------------------------------------------
/memit/baselines/mend/algs/ft.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import higher
4 | import torch
5 | import torch.nn as nn
6 | from editable_model import EditableModel
7 | from higher.patch import monkeypatch as make_functional
8 | from losses import kl_loc_loss
9 | from utils import _inner_params, _logits
10 |
11 |
12 | class FT(EditableModel):
13 | """
14 | Fine-tuning approach. Does not require training.
15 | """
16 |
17 | def __init__(self, model, config, model_constructor, edit_loss_fn=None):
18 | super().__init__(model, config, model_constructor)
19 |
20 | if edit_loss_fn is not None:
21 | self.edit_loss_fn = edit_loss_fn
22 |
23 | self.locality_loss_fn = kl_loc_loss
24 | self.loc_ids = None
25 | self.loc_masks = None
26 | self.loc_sampler = None
27 |
28 | def _edit_loss(self, model, p0, p_edited, edit_batch):
29 | output = _logits(model(**edit_batch, params=p_edited))
30 | loss_dict = self.edit_loss_fn(output, edit_batch["labels"])
31 | l_edit, acc = loss_dict["nll"], loss_dict["acc"]
32 | if self.config.ft.locality.enabled:
33 | if self.config.ft.locality.oracle:
34 | loc_batch = next(self.loc_sampler)["loc"]
35 | else:
36 | raise NotImplementedError
37 |
38 | with torch.no_grad():
39 | original_base_logits = _logits(model(**loc_batch, params=p0))
40 | edited_base_logits = _logits(model(**loc_batch, params=p_edited))
41 | kl_mask = loc_batch.get(
42 | "decoder_attention_mask", loc_batch["attention_mask"]
43 | )
44 | l_loc = self.locality_loss_fn(
45 | original_base_logits, edited_base_logits, mask=kl_mask
46 | )
47 | loss = l_loc + self.config.ft.locality.cedit * l_edit
48 | else:
49 | l_loc = torch.tensor(float("nan"))
50 | loss = l_edit
51 | return loss, l_edit, l_loc, acc
52 |
53 | def accuracy(self, output, labels):
54 | if output.shape[-1] != 1:
55 | shifted_output = output.argmax(-1)[:, :-1]
56 | shifted_labels = labels[:, 1:]
57 | to_predict = (shifted_labels != -100).sum()
58 | correct = (shifted_output == shifted_labels).sum()
59 | acc = correct.float() / to_predict.float()
60 | else:
61 | acc = ((output > 0) == labels.bool()).sum().float()
62 | return acc
63 |
64 | def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p):
65 | return (
66 | f"step: {step}".ljust(14)
67 | + f"loss: {loss.item():.5f}".ljust(18)
68 | + f"l_edit: {l_edit.item():.5f}".ljust(18)
69 | + f"l_loc: {l_loc.item():.5f}".ljust(18)
70 | + f"acc: {acc.item():.2f}".ljust(14)
71 | + f"norm: {res_p.view(-1).norm().item():.5f}"
72 | )
73 |
74 | def edit(self, batch, condition=None, detach_history=False):
75 | edit_model = self.model.eval()
76 | p0 = list(edit_model.named_parameters())
77 |
78 | if not isinstance(edit_model, higher.patch._MonkeyPatchBase):
79 | edit_model = make_functional(
80 | self.model, track_higher_grads=False, in_place=True
81 | )
82 |
83 | packed_residuals = {}
84 | opt_params = []
85 | for n, p in _inner_params(
86 | edit_model.named_parameters(), self.config.model.inner_params
87 | ):
88 | if self.config.ft.rank is not None:
89 | u = nn.Parameter(
90 | torch.randn(p.shape[0], self.config.ft.rank, device=p.device)
91 | * self.config.ft.init_std
92 | )
93 | v = nn.Parameter(
94 | torch.zeros(self.config.ft.rank, p.shape[1], device=p.device)
95 | )
96 | res = [u, v]
97 | else:
98 | res = [nn.Parameter(torch.zeros_like(p, device=p.device))]
99 |
100 | packed_residuals[n] = res
101 | opt_params.extend(res)
102 |
103 | assert len(opt_params) == len(self.config.model.inner_params)
104 | OptClass = getattr(torch.optim, self.config.ft.opt)
105 | opt = OptClass(opt_params, lr=self.config.edit_lr)
106 |
107 | start_time = time.time()
108 | for edit_step in range(self.config.ft.max_edit_steps):
109 | if self.config.ft.time_limit is not None and (
110 | time.time() - start_time > self.config.ft.time_limit
111 | ):
112 | break
113 | residuals = {
114 | k: v[0] @ v[1] if len(v) == 2 else v[0]
115 | for k, v in packed_residuals.items()
116 | }
117 | edited_params = [
118 | p if n not in residuals else p.detach() + residuals[n] for n, p in p0
119 | ]
120 | loss, l_edit, l_loc, acc = self._edit_loss(
121 | edit_model, [p for n, p in p0], edited_params, batch
122 | )
123 |
124 | if self.config.ft.verbose:
125 | residual = list(residuals.values())[-1]
126 | print(
127 | self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual),
128 | end="\r",
129 | )
130 |
131 | if acc == 1.0:
132 | break
133 |
134 | for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)):
135 | p.grad = g
136 | torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip)
137 | opt.step()
138 | opt.zero_grad()
139 |
140 | if detach_history:
141 | new_model = self.model_constructor()
142 | new_model.load_state_dict(edit_model.state_dict())
143 | edit_model = new_model
144 | edit_model.train(self.training)
145 |
146 | return (
147 | FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn),
148 | {},
149 | )
150 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/alg/efk.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | alg: efk
4 | train_base: False
5 | lr: 1e-5
6 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/alg/enn.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | alg: enn
4 | train_base: True
5 | enn:
6 | first_order: False
7 | n_edit_steps: 1
8 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/alg/ft.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | train_base: False
4 | alg: ft
5 | edit_lr: 5e-6
6 | ft:
7 | verbose: false
8 | max_edit_steps: 100
9 | time_limit: null
10 | locality:
11 | enabled: false
12 | oracle: true
13 | cedit: 1e-2
14 | batch_size: 1
15 | rank: null
16 | opt: RMSprop
17 | init_std: 0.01
18 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/alg/mend.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | alg: mend
4 | lr: 1e-6
5 | train_base: False
6 | edit_lr: 1e-4
7 | lr_lr: 1e-4
8 | mend:
9 | one_sided: False
10 | n_hidden: 1
11 | hidden_dim: null
12 | init: id
13 | norm: True
14 | combine: True
15 | x_only: False
16 | delta_only: False
17 | act: relu
18 | rank: 1920
19 | mlp_class: IDMLP
20 | shared: True
21 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/config.yaml:
--------------------------------------------------------------------------------
1 | alg: enn
2 | lr: 1e-5
3 | edit_lr: 1e-2
4 | seed: 0
5 | debug: False
6 | model_save_pt: 5000
7 | edit_bs: 1
8 | silent: False
9 | max_iters: 1000000
10 | log_interval: 100
11 | val_interval: 5000
12 | lr_lr: 1e-3
13 | batch_size: 2
14 | val_batch_size: 5
15 | accumulate_bs: 10
16 | cedit: 0.1
17 | cloc: 1.0
18 | cbase: 1.0
19 | val_steps: 500
20 | device: cuda
21 | base_loss: distill
22 | oracle: False
23 | train: True
24 | train_base: True
25 | opt: Adam
26 | single_batch: False
27 | archive: null
28 | grad_clip: 100.
29 | ref: null
30 | early_stop_patience: 20000
31 | early_stop_key: "loss/total_edit_val"
32 | dropout: 0.0
33 | tokenizer: null
34 | results_dir: null
35 | no_grad_layers: null
36 | eval_only: False
37 | half: False
38 | save: False
39 |
40 | model:
41 | pt: null
42 |
43 | data:
44 | path: null
45 | rephrase: true
46 | zsre_nq: true
47 | nq_path: ${hydra:runtime.cwd}/data/nq
48 | wiki_webtext: true
49 | n_edits: 1
50 |
51 | eval:
52 | verbose: True
53 | log_interval: 100
54 | final_eval: True
55 |
56 | hydra:
57 | run:
58 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}}
59 | sweep:
60 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f}
61 | subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
/memit/baselines/mend/config/experiment/fc.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task: fc
4 | dataset: fever
5 | cbase: 1.0
6 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/experiment/gen.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task: gen
4 | dataset: wikitext-103
5 | cbase: 10.0
6 | data:
7 | path: ${hydra:runtime.cwd}/data/10token/data/self_sample/
8 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/experiment/qa.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task: qa
4 | dataset: zsre
5 | cbase: 1.0
6 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/bart-base.yaml:
--------------------------------------------------------------------------------
1 | name: facebook/bart-base
2 | class_name: BartForConditionalGeneration
3 | tokenizer_class: BartTokenizerFast
4 | tokenizer_name: facebook/bart-base
5 | inner_params:
6 | - model.encoder.layers.4.fc1.weight
7 | - model.encoder.layers.4.fc2.weight
8 | - model.encoder.layers.5.fc1.weight
9 | - model.encoder.layers.5.fc2.weight
10 | - model.decoder.layers.4.fc1.weight
11 | - model.decoder.layers.4.fc2.weight
12 | - model.decoder.layers.5.fc1.weight
13 | - model.decoder.layers.5.fc2.weight
14 |
15 | pt: ${hydra:runtime.cwd}/data/zsre/QA_model.ckpt
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/bert-base.yaml:
--------------------------------------------------------------------------------
1 | name: bert-base-uncased
2 | class_name: BertClassifier
3 | tokenizer_class: BertTokenizerFast
4 | tokenizer_name: bert-base-uncased
5 | inner_params:
6 | - model.encoder.layer.9.intermediate.dense.weight
7 | - model.encoder.layer.9.output.dense.weight
8 | - model.encoder.layer.10.intermediate.dense.weight
9 | - model.encoder.layer.10.output.dense.weight
10 | - model.encoder.layer.11.intermediate.dense.weight
11 | - model.encoder.layer.11.output.dense.weight
12 |
13 | pt: ${hydra:runtime.cwd}/data/fever/FC_model.ckpt
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/distilgpt2.yaml:
--------------------------------------------------------------------------------
1 | name: MYX4567/distilgpt2-finetuned-wikitext2
2 | class_name: GPT2LMHeadModel
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: distilgpt2
5 | inner_params:
6 | - transformer.h.3.mlp.c_fc.weight
7 | - transformer.h.3.mlp.c_proj.weight
8 | - transformer.h.4.mlp.c_fc.weight
9 | - transformer.h.4.mlp.c_proj.weight
10 | - transformer.h.5.mlp.c_fc.weight
11 | - transformer.h.5.mlp.c_proj.weight
12 |
13 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/gpt2.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2
2 | class_name: GPT2LMHeadModel
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: gpt2
5 | inner_params:
6 | - transformer.h.9.mlp.c_proj.weight
7 | - transformer.h.9.mlp.c_fc.weight
8 | - transformer.h.10.mlp.c_proj.weight
9 | - transformer.h.10.mlp.c_fc.weight
10 | - transformer.h.11.mlp.c_proj.weight
11 | - transformer.h.11.mlp.c_fc.weight
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/gpt2large.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2-large
2 | class_name: GPT2LMHeadModel
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: gpt2-large
5 | inner_params:
6 | - transformer.h.33.mlp.c_proj.weight
7 | - transformer.h.33.mlp.c_fc.weight
8 | - transformer.h.34.mlp.c_proj.weight
9 | - transformer.h.34.mlp.c_fc.weight
10 | - transformer.h.35.mlp.c_proj.weight
11 | - transformer.h.35.mlp.c_fc.weight
12 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/gpt2medium.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2-medium
2 | class_name: GPT2LMHeadModel
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: gpt2-medium
5 | inner_params:
6 | - transformer.h.21.mlp.c_proj.weight
7 | - transformer.h.21.mlp.c_fc.weight
8 | - transformer.h.22.mlp.c_proj.weight
9 | - transformer.h.22.mlp.c_fc.weight
10 | - transformer.h.23.mlp.c_proj.weight
11 | - transformer.h.23.mlp.c_fc.weight
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/gpt2xl.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2-xl
2 | class_name: GPT2LMHeadModel
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: gpt2-xl
5 | inner_params:
6 | - transformer.h.45.mlp.c_proj.weight
7 | - transformer.h.45.mlp.c_fc.weight
8 | - transformer.h.46.mlp.c_proj.weight
9 | - transformer.h.46.mlp.c_fc.weight
10 | - transformer.h.47.mlp.c_proj.weight
11 | - transformer.h.47.mlp.c_fc.weight
12 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/gptj.yaml:
--------------------------------------------------------------------------------
1 | name: EleutherAI/gpt-j-6B
2 | class_name: GPTJForCausalLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: EleutherAI/gpt-j-6B
5 | inner_params:
6 | - transformer.h.25.mlp.fc_in.weight
7 | - transformer.h.25.mlp.fc_out.weight
8 | - transformer.h.26.mlp.fc_in.weight
9 | - transformer.h.26.mlp.fc_out.weight
10 | - transformer.h.27.mlp.fc_in.weight
11 | - transformer.h.27.mlp.fc_out.weight
12 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/gptneo27.yaml:
--------------------------------------------------------------------------------
1 | name: EleutherAI/gpt-neo-2.7B
2 | class_name: GPTNeoForCausalLM
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: EleutherAI/gpt-neo-2.7B
5 | inner_params:
6 | - transformer.h.29.mlp.c_fc.weight
7 | - transformer.h.29.mlp.c_proj.weight
8 | - transformer.h.30.mlp.c_fc.weight
9 | - transformer.h.30.mlp.c_proj.weight
10 | - transformer.h.31.mlp.c_fc.weight
11 | - transformer.h.31.mlp.c_proj.weight
12 |
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/t5large.yaml:
--------------------------------------------------------------------------------
1 | name: google/t5-large-ssm-nq
2 | class_name: AutoModelForSeq2SeqLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: google/t5-large-ssm-nq
5 | inner_params:
6 | - encoder.block.22.layer.1.DenseReluDense.wi.weight
7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight
8 | - encoder.block.23.layer.1.DenseReluDense.wi.weight
9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight
10 | - decoder.block.22.layer.2.DenseReluDense.wi.weight
11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight
12 | - decoder.block.23.layer.2.DenseReluDense.wi.weight
13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight
14 |
15 | pt: null
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/t5small.yaml:
--------------------------------------------------------------------------------
1 | name: google/t5-small-ssm-nq
2 | class_name: AutoModelForSeq2SeqLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: google/t5-small-ssm-nq
5 | inner_params:
6 | - encoder.block.6.layer.1.DenseReluDense.wi_0.weight
7 | - encoder.block.6.layer.1.DenseReluDense.wo.weight
8 | - encoder.block.7.layer.1.DenseReluDense.wi_0.weight
9 | - encoder.block.7.layer.1.DenseReluDense.wo.weight
10 | - decoder.block.6.layer.2.DenseReluDense.wi_0.weight
11 | - decoder.block.6.layer.2.DenseReluDense.wo.weight
12 | - decoder.block.7.layer.2.DenseReluDense.wi_0.weight
13 | - decoder.block.7.layer.2.DenseReluDense.wo.weight
14 |
15 | pt: null
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/t5xl.yaml:
--------------------------------------------------------------------------------
1 | name: google/t5-xl-ssm-nq
2 | class_name: AutoModelForSeq2SeqLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: google/t5-xl-ssm-nq
5 | inner_params:
6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight
7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight
8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight
9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight
10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight
11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight
12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight
13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight
14 |
15 | pt: null
--------------------------------------------------------------------------------
/memit/baselines/mend/config/model/t5xxl.yaml:
--------------------------------------------------------------------------------
1 | name: google/t5-xxl-ssm-nq
2 | class_name: AutoModelForSeq2SeqLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: google/t5-xxl-ssm-nq
5 | inner_params:
6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight
7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight
8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight
9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight
10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight
11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight
12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight
13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight
14 |
15 | pt: null
--------------------------------------------------------------------------------
/memit/baselines/mend/data_classes/fever.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import jsonlines
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | from utils import EditBatchSampler, dict_to
8 |
9 | POSITIVE_CLASS = "SUPPORTS"
10 |
11 |
12 | class BinaryAugmentedKILT(Dataset):
13 | def __init__(self, tokenizer, data_path, config, max_length=32):
14 | super().__init__()
15 | self.tokenizer = tokenizer
16 | self.data = []
17 | self.config = config
18 |
19 | def extract(d):
20 | extracted = {
21 | k: d[k]
22 | for k in [
23 | "logit",
24 | "input",
25 | "prediction",
26 | "alternatives",
27 | "filtered_rephrases",
28 | ]
29 | }
30 | extracted["label"] = d["output"][0]["answer"]
31 | return extracted
32 |
33 | with jsonlines.open(data_path) as f:
34 | for d in f:
35 | if len(d["alternatives"]) > 0 and len(d["filtered_rephrases"]) > 0:
36 | self.data.append(extract(d))
37 |
38 | self.max_length = max_length
39 |
40 | def __len__(self):
41 | return len(self.data)
42 |
43 | def __getitem__(self, item):
44 | obj = self.data[item]
45 | rephrase = random.choice(self.data[item]["filtered_rephrases"])
46 | output = {
47 | "label": obj["label"] == POSITIVE_CLASS,
48 | "src": obj["input"],
49 | "rephrase": rephrase,
50 | "pred": obj["prediction"] == POSITIVE_CLASS,
51 | "alt": obj["alternatives"][0] == POSITIVE_CLASS,
52 | "cond_flip": "{} >> {} || {}".format(
53 | obj["prediction"],
54 | obj["alternatives"][0],
55 | obj["input"],
56 | ),
57 | "cond_orig": "{} >> {} || {}".format(
58 | obj["prediction"],
59 | obj["prediction"],
60 | obj["input"],
61 | ),
62 | "logit": obj["logit"],
63 | }
64 |
65 | return output
66 |
67 | def collate_fn(self, batch):
68 | src = [b["src"] for b in batch]
69 | rephrase = [batch[-1]["rephrase"]]
70 |
71 | flip_label = np.random.uniform() > 0.5
72 | predictions = [b["pred"] for b in batch]
73 | labels = [b["label"] for b in batch]
74 | labels[-1] = predictions[
75 | -1
76 | ] # the last element in the batch is special (the edit element)
77 | cond = [batch[-1]["cond_orig"]]
78 | if flip_label:
79 | labels[-1] = batch[-1]["alt"]
80 | cond = [batch[-1]["cond_flip"]]
81 |
82 | batches = {}
83 | for k1, v1 in {"": src, "cond_": cond, "rephrase_": rephrase}.items():
84 | encoded = self.tokenizer(
85 | v1,
86 | return_tensors="pt",
87 | padding=True,
88 | max_length=self.max_length,
89 | truncation=True,
90 | )
91 | for k2, v2 in encoded.items():
92 | batches[f"{k1}{k2}"] = v2
93 |
94 | batches["predictions"] = torch.tensor(predictions).long().view(-1, 1)
95 | batches["labels"] = torch.tensor(labels).long().view(-1, 1)
96 | batches["raw"] = batch
97 | return batches
98 |
99 | def edit_generator(self, batch_size, n=None):
100 | if n is None:
101 | n = len(self)
102 | sampler = EditBatchSampler(
103 | n, memorize_mode=self.config.single_batch, seed=self.config.seed
104 | )
105 | while True:
106 | edit_idxs, loc_idxs = sampler.sample(batch_size)
107 | assert len(edit_idxs) == 1
108 | idxs = loc_idxs + edit_idxs
109 | toks = self.collate_fn([self[idx] for idx in idxs])
110 |
111 | pass_keys = ["input_ids", "attention_mask", "labels"]
112 | edit_inner = {k: v[-1:] for k, v in toks.items() if k in pass_keys}
113 | if self.config.data.rephrase:
114 | edit_outer = {}
115 | edit_outer["input_ids"] = toks["rephrase_input_ids"]
116 | edit_outer["attention_mask"] = toks["rephrase_attention_mask"]
117 | edit_outer["labels"] = edit_inner["labels"]
118 | else:
119 | edit_outer = edit_inner
120 | loc = {k: v[:-1] for k, v in toks.items() if k in pass_keys}
121 | cond = {
122 | "input_ids": toks["cond_input_ids"],
123 | "attention_mask": toks["cond_attention_mask"],
124 | }
125 |
126 | batch = {
127 | "edit_inner": edit_inner,
128 | "edit_outer": edit_outer,
129 | "loc": loc,
130 | "cond": cond,
131 | }
132 | yield dict_to(batch, self.config.device)
133 |
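Usage sketch (illustrative, not part of the repository): `edit_generator` yields dictionaries with
"edit_inner", "edit_outer", "loc", and "cond" sub-batches. The config object below is a minimal
stand-in exposing only the fields the class reads; the tokenizer name is an assumption, and the
data path mirrors the one used in run.py.

    from types import SimpleNamespace
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = SimpleNamespace(
        single_batch=False,                    # read by EditBatchSampler
        seed=0,
        device="cpu",
        data=SimpleNamespace(rephrase=True),   # use rephrases for the outer edit batch
    )
    ds = BinaryAugmentedKILT(tokenizer, "data/fever/fever-dev-kilt.jsonl", config)
    gen = ds.edit_generator(batch_size=8)
    batch = next(gen)                          # keys: edit_inner, edit_outer, loc, cond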
--------------------------------------------------------------------------------
/memit/baselines/mend/data_classes/nq.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | class NQDataset:
5 | def __init__(self, path: str, tokenizer, config):
6 | with open(path, "r") as f:
7 | self.data = json.load(f)
8 |
9 | self.questions = self.data["questions"]
10 | self.answers = self.data["answers"]
11 | self.tokenizer = tokenizer
12 | self.config = config
13 |
14 | def __getitem__(self, idx):
15 | idx = idx % len(self.questions)
16 | return self.questions[idx], self.answers[idx]
17 |
18 | @staticmethod
19 | def generate(
20 | out_path: str,
21 | prompt: bool = False,
22 | capitalize: bool = True,
23 | question_mark: bool = True,
24 | ):
25 | import os
26 |
27 | import datasets
28 |
29 | def process(text):
30 | if capitalize:
31 | text = text[0].capitalize() + text[1:]
32 | if question_mark:
33 | text = text + "?"
34 | if prompt:
35 | text = "nq question: " + text
36 | return text
37 |
38 | def extract(d):
39 | questions = [process(q["text"]) for q in d["question"]]
40 | answers = [
41 | [a["text"][0] for a in ann["short_answers"] if len(a["text"])]
42 | for ann in d["annotations"]
43 | ]
44 | questions = [q for q, a in zip(questions, answers) if len(a)]
45 | answers = [min(a, key=len) for a in answers if len(a)]
46 | return questions, answers
47 |
48 | train = datasets.load_dataset("natural_questions", split="train")
49 | tq, ta = extract(train)
50 | val = datasets.load_dataset("natural_questions", split="validation")
51 | vq, va = extract(val)
52 |
53 | if not os.path.exists(out_path):
54 | os.makedirs(out_path)
55 | with open(f"{out_path}/train.json", "w") as f:
56 | json.dump({"questions": tq, "answers": ta}, f)
57 | with open(f"{out_path}/validation.json", "w") as f:
58 | json.dump({"questions": vq, "answers": va}, f)
59 |
60 |
61 | if __name__ == "__main__":
62 | import argparse
63 |
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument("--out_path", type=str, default="data/nq")
66 | args = parser.parse_args()
67 | NQDataset.generate(args.out_path)
68 |
--------------------------------------------------------------------------------
/memit/baselines/mend/data_classes/wiki.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import logging
4 | import random
5 |
6 | from datasets import load_dataset
7 | from torch.utils.data import Dataset
8 | from utils import EditBatchSampler, dict_to, scr
9 |
10 | LOG = logging.getLogger(__name__)
11 |
12 |
13 | def is_ascii(s):
14 | return all(ord(c) < 128 for c in s)
15 |
16 |
17 | def filter_text(iterator):
18 | valid = []
19 | for text in iterator:
20 | if len(text.split(" ")) < 50:
21 | continue
22 | if not is_ascii(text):
23 | continue
24 | valid.append(text)
25 |
26 | return valid
27 |
28 |
29 | class GenDataset(Dataset):
30 | def __init__(
31 | self,
32 | split: str,
33 | tokenizer,
34 | config,
35 | edit_path: str,
36 | pct: int = 10,
37 | max_length: int = 200,
38 | ):
39 | version = "wikitext-103-raw-v1"
40 | split_str = f"{split}[:{pct}%]" if split == "train" else split
41 | LOG.info(f"Loading wikitext version {version}, split {split_str}")
42 | base_samples = load_dataset(
43 | "wikitext", version, cache_dir=scr(), split=split_str
44 | )["text"]
45 | self.base_samples = filter_text(base_samples)
46 | with open(edit_path + split[:5] + ".json", "r") as f:
47 | self.edit_samples = json.load(f)
48 |
49 | self.tok = tokenizer
50 | self.config = config
51 | self.max_length = max_length
52 | self.n_tokens = self.edit_samples["n_tokens"]
53 |
54 | len_base = len(self.base_samples)
55 | len_edit = len(self.edit_samples["original"])
56 | LOG.info(f"Loaded {len_base} wiki-103 samples and {len_edit} edit samples")
57 |
58 | if config.data.wiki_webtext:
59 | self.use_wiki = True
60 | LOG.info("** Using webtext for wiki base samples **")
61 | webtext = load_dataset(
62 | "stas/openwebtext-10k", split="train", cache_dir=scr()
63 | )["text"]
64 | n_train = int(len(webtext) * 0.9)
65 | if split == "train":
66 | self.base_samples = webtext[:n_train]
67 | else:
68 | self.base_samples = webtext[n_train:]
69 | else:
70 | self.use_wiki = False
71 |
72 | def edit_generator(self, batch_size, n=None):
73 | if n is None:
74 | n = len(self)
75 | sampler = EditBatchSampler(
76 | n,
77 | memorize_mode=self.config.single_batch,
78 | loc_disjoint=not self.use_wiki,
79 | seed=self.config.seed,
80 | )
81 | while True:
82 | edit_idxs, loc_idxs = sampler.sample(batch_size)
83 |
84 | edit_batch = [self.edit_samples["completions"][idx] for idx in edit_idxs]
85 | loc_batch = [
86 | self.base_samples[idx % len(self.base_samples)] for idx in loc_idxs
87 | ]
88 |
89 | edit_toks = self.tok(edit_batch, padding=True, return_tensors="pt")
90 | loc_toks = self.tok(
91 | loc_batch,
92 | padding=True,
93 | return_tensors="pt",
94 | truncation=self.config.data.wiki_webtext,
95 | max_length=self.max_length,
96 | )
97 |
98 | edit_inner = {**edit_toks}
99 | edit_inner["labels"] = self.get_edit_labels(edit_toks["input_ids"])
100 |
101 | edit_outer = copy.deepcopy(edit_inner)
102 | if self.config.data.rephrase:
103 | lens = (edit_outer["input_ids"] != -100).sum(-1)
104 | remove = random.randint(0, (min(lens) - self.n_tokens) // 2)
105 | for k, v in edit_outer.items():
106 | edit_outer[k] = v[:, remove:]
107 |
108 | loc = {**loc_toks}
109 | loc["labels"] = self.get_labels(loc_toks["input_ids"])
110 | cond = {**edit_toks}
111 |
112 | batch = {
113 | "edit_inner": edit_inner,
114 | "edit_outer": edit_outer,
115 | "loc": loc,
116 | "cond": cond,
117 | }
118 |
119 | yield dict_to(batch, self.config.device)
120 |
121 | def __len__(self):
122 | return len(self.edit_samples["original"])
123 |
124 | def _check_padding(self, ids):
125 | if (ids[:, 0] == self.tok.pad_token_id).any():
126 | raise ValueError("Left-padding not supported for GPT2")
127 |
128 | def get_edit_labels(self, ids):
129 | self._check_padding(ids)
130 |
131 | labels = ids.clone()
132 | end_idxs = (labels != self.tok.pad_token_id).sum(-1)
133 | for batch_idx, end_idx in enumerate(end_idxs):
134 | labels[batch_idx, : end_idx - self.n_tokens] = -100
135 | labels[labels == self.tok.pad_token_id] = -100
136 | return labels
137 |
138 | def get_labels(self, ids):
139 | self._check_padding(ids)
140 |
141 | return ids.masked_fill(ids == self.tok.pad_token_id, -100)
142 |
143 | def __getitem__(self, idx):
144 | return self.base_samples[idx]
145 |
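The label construction in `get_edit_labels` supervises only the final `n_tokens` positions of each
right-padded sequence and ignores the rest. A standalone toy illustration of that masking (pad id
and token values are made up):

    import torch

    pad_id, n_tokens = 0, 2                      # assume pad token id 0 and n_tokens = 2
    ids = torch.tensor([[5, 6, 7, 8, pad_id]])   # one right-padded sequence
    labels = ids.clone()
    end_idx = (labels != pad_id).sum(-1)         # non-pad tokens per row: tensor([4])
    labels[0, : end_idx[0] - n_tokens] = -100    # ignore everything before the completion tokens
    labels[labels == pad_id] = -100              # ignore padding
    print(labels)                                # tensor([[-100, -100,    7,    8, -100]])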
--------------------------------------------------------------------------------
/memit/baselines/mend/editable_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .losses import masked_log_probs
4 | from .utils import _logits, shift_targets
5 |
6 |
7 | class EditableModel(nn.Module):
8 | def __init__(self, model, config, model_constructor):
9 | super().__init__()
10 |
11 | self.model = model
12 | self.config = config
13 | self.model_constructor = model_constructor
14 |
15 | def _edit_loss_fn(pred, targ):
16 | return masked_log_probs(pred, targ, shift=shift_targets(self.config))
17 |
18 | self.edit_loss_fn = _edit_loss_fn
19 | self.loc_loss_fn = _edit_loss_fn
20 |
21 | def edit(self, batch, condition=None, detach_history=False):
22 | raise NotImplementedError
23 |
24 | def forward(self, *inputs, **kwargs):
25 | return _logits(self.model(*inputs, **kwargs))
26 |
27 | def outer_parameters(self):
28 | return self.parameters()
29 |
30 | def base_loss(self, input_ids, attention_masks, label_ids):
31 | pass
32 |
--------------------------------------------------------------------------------
/memit/baselines/mend/hooks.py:
--------------------------------------------------------------------------------
1 | from .utils import parent_module
2 |
3 |
4 | def linear_backward_hook(mod, grad_in, grad_out):
5 | if not hasattr(mod, "weight"):
6 | print(f"{mod} has no weight!")
7 | return
8 |
9 | if hasattr(mod.weight, "__x__"):
10 | assert len(grad_out) == 1
11 | # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2)
12 | mod.weight.__delta__ = grad_out[0].detach()
13 | else:
14 | print(f"{mod} has no __x__")
15 |
16 |
17 | def linear_forward_hook(mod, activations, output):
18 | assert len(activations) == 1
19 | mod.weight.__x__ = activations[0].detach()
20 |
21 |
22 | def hook_model(model, pnames):
23 | handles = []
24 | for m in [parent_module(model, pname) for pname in pnames]:
25 | handles.append(m.register_full_backward_hook(linear_backward_hook))
26 | handles.append(m.register_forward_hook(linear_forward_hook))
27 |
28 | model.handles = handles
29 |
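A minimal sketch of what `hook_model` caches, using a toy model (it assumes `parent_module`
resolves the parameter name "lin.weight" to the `lin` submodule, which is how the hooks above
expect to be attached):

    import torch
    import torch.nn as nn

    model = nn.Sequential()
    model.add_module("lin", nn.Linear(4, 3))
    hook_model(model, ["lin.weight"])        # forward hook caches inputs, backward hook caches output grads

    out = model(torch.randn(2, 4))
    out.sum().backward()
    print(model.lin.weight.__x__.shape)      # torch.Size([2, 4]): cached input activations
    print(model.lin.weight.__delta__.shape)  # torch.Size([2, 3]): cached output gradients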
--------------------------------------------------------------------------------
/memit/baselines/mend/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def kl_loc_loss(pre, post, mask=None):
6 | pre = pre.to(torch.float32)
7 | post = post.to(torch.float32)
8 |
9 | sequence = pre.dim() == 3
10 | pre_ = pre.view(-1, pre.shape[-1])
11 | post_ = post.view(pre_.shape)
12 | assert pre_.shape[0] == post_.shape[0]
13 |
14 | if not sequence:
15 | if pre_.shape[-1] == 1: # No masking needed for binary classification
16 | return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + (
17 | (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
18 | ).mean()
19 | else: # We have sequences of predictions; masking needed
20 | if pre_.shape[-1] > 1:
21 | assert mask is not None
22 | mask_ = mask.view(pre_.shape[0])
23 | kl = (
24 | pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))
25 | ).sum(-1)
26 | return (kl * mask_).sum() / mask_.sum()
27 |
28 | raise NotImplementedError
29 |
30 |
31 | def binary_log_probs(pred, targ):
32 | neg_mask = torch.ones_like(pred)
33 | neg_mask[targ == 0] *= -1
34 | pred = pred * neg_mask
35 | log_probs = F.logsigmoid(pred)
36 | acc = (log_probs.exp() > 0.5).float().mean()
37 | return {
38 | "acc": acc,
39 | "log_prob": log_probs.mean(),
40 | "prob": log_probs.exp().mean(),
41 | "nll": -log_probs.mean(),
42 | "n_tokens": log_probs.shape[0],
43 | }
44 |
45 |
46 | def multiclass_log_probs(pred, targ, shift=True):
47 | NULL_TOKEN = 0 # a placeholder used for masked target locations
48 |
49 | pred = pred.clone()
50 | targ = targ.clone()
51 | if shift and pred.dim() == 3: # Dealing with sequences
52 | pred = pred[:, :-1] # Remove last prediction in sequence
53 | targ = targ[:, 1:] # Shift to align predictions and targets
54 |
55 | mask = targ != -100
56 | targ[~mask] = NULL_TOKEN # Can be any valid token, since we'll throw them out
57 | unmasked_log_probs = pred.log_softmax(-1).gather(-1, targ.unsqueeze(-1)).squeeze(-1)
58 |
59 | pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
60 | correct = pred_ids == targ
61 | if pred.dim() == 3:
62 | correct = (pred_ids == targ).all(-1) # We want to get the whole sequence right
63 | acc = correct.float().mean()
64 |
65 | n_tokens = mask.float().sum()
66 | log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
67 | prob = (unmasked_log_probs.exp() * mask.float()).sum() / n_tokens
68 | return {
69 | "acc": acc,
70 | "log_prob": log_prob,
71 | "prob": prob,
72 | "n_tokens": n_tokens,
73 | "nll": -log_prob,
74 | }
75 |
76 |
77 | def masked_log_probs(pred, targ, shift=True):
78 | pred = pred.to(torch.float32)
79 |
80 | if not (pred.dim() == 2 or pred.dim() == 3):
81 | raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")
82 |
83 | if pred.shape[-1] == 1:
84 | return binary_log_probs(pred, targ)
85 | else:
86 | return multiclass_log_probs(pred, targ, shift=shift)
87 |
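A toy call to `masked_log_probs` for the sequence case (shapes and token ids are made up). With
`shift=True`, the logits at position t are scored against the target at position t+1, and -100
marks positions excluded from the loss:

    import torch

    pred = torch.randn(1, 4, 10)              # (batch, seq_len, vocab) logits
    targ = torch.tensor([[-100, 2, 7, 4]])    # -100 positions are ignored
    stats = masked_log_probs(pred, targ, shift=True)
    print(stats["acc"].item(), stats["nll"].item(), stats["n_tokens"].item())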
--------------------------------------------------------------------------------
/memit/baselines/mend/mend_hparams.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from util.hparams import HyperParams
4 |
5 |
6 | @dataclass
7 | class MENDHyperParams(HyperParams):
8 | lr_scale: float
9 | n_toks: int
10 | model_name: str
11 | counterfact: bool
12 | mini: bool
13 | zsre: bool
14 |
--------------------------------------------------------------------------------
/memit/baselines/mend/mend_main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | from typing import Dict, List
4 |
5 | import hydra
6 | import torch
7 | from transformers import AutoModelForCausalLM, AutoTokenizer
8 |
9 | from util.globals import *
10 |
11 | from .algs.mend import MEND
12 | from .mend_hparams import MENDHyperParams
13 |
14 |
15 | class MendRewriteExecutor:
16 | method_name = "MEND"
17 |
18 | def __init__(self):
19 | self.is_init = False
20 |
21 | def init_model(self, model, tok, params):
22 | train_ds = (
23 | "counterfact-" if params.counterfact else ("zsre-" if params.zsre else "")
24 | )
25 | mini_string = "mini-" if params.mini else ""
26 |
27 | model_name = "gpt2-xl" if params.model_name == "gpt2-xl" else "gpt-j-6b"
28 | modelcode = "gpt2xl" if params.model_name == "gpt2-xl" else "gptj"
29 | model_filename = (
30 | f"mend-{mini_string}{params.n_toks}tok-{train_ds}{model_name}.pt"
31 | )
32 | model_dir = "baselines/mend/weights"
33 |
34 | os.makedirs(model_dir, exist_ok=True)
35 | if not os.path.isfile(f"{model_dir}/{model_filename}"):
36 | remote_url = f"{REMOTE_ROOT_URL}/data/weights/{model_filename}"
37 |             print(f"Attempting to download from {remote_url}")
38 | torch.hub.download_url_to_file(remote_url, f"{model_dir}/{model_filename}")
39 | with hydra.initialize(config_path="config", job_name="run"):
40 | config = hydra.compose(
41 | config_name="config",
42 | overrides=[
43 | "+alg=mend",
44 | "+experiment=gen",
45 | f"+model={modelcode}",
46 | f"data.path=data/{params.n_toks}token/data/self_sample/",
47 | ],
48 | )
49 |
50 | def add_padding(tokenizer, model):
51 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
52 | model.resize_token_embeddings(len(tokenizer))
53 | model.transformer.wte.weight.data[
54 | -1
55 | ] = model.transformer.wte.weight.data.mean(0)
56 |
57 | # Customize the gpt2xl and tokenizer
58 | self.model = model
59 | self.tokenizer = tok
60 | add_padding(self.tokenizer, self.model)
61 |
62 | # Load the trained MEND model
63 | self.alg = MEND(self.model, config, lambda: deepcopy(self.model))
64 | d = torch.load(f"{model_dir}/{model_filename}")
65 | self.alg.load_state_dict(
66 | {k.replace("gtn.", "mend."): v for k, v in d["model"].items()}
67 | )
68 | self.alg.cuda()
69 |
70 | # Disable unneeded gradients
71 | for n, p in self.model.named_parameters():
72 | if n not in config.model.inner_params:
73 | p.requires_grad = False
74 | self.is_init = True
75 |
76 | def reset_model(self):
77 | self.is_init = False
78 | del self.model, self.tokenizer, self.alg
79 |
80 | def apply_to_model(
81 | self,
82 | model: AutoModelForCausalLM,
83 | tok: AutoTokenizer,
84 | requests: List[Dict],
85 | hparams: MENDHyperParams,
86 | copy=False,
87 | return_orig_weights=False,
88 | ):
89 | """
90 | Given a request, for example
91 | {'prompt': '{} has the position of',
92 | 'subject': 'Charles Herman Helmsing',
93 | 'relation_id': 'P39',
94 | 'target_new': {'str': 'President', 'id': 'Q11696'},
95 | 'target_true': {'str': 'bishop', 'id': 'Q29182'}}
96 | Returns a dictionary of numpy arrays that specifies
97 | how mend will change the weights of the model.
98 | """
99 |
100 | if not self.is_init:
101 | self.init_model(model, tok, hparams)
102 |
103 | weights_copy = {}
104 | model = deepcopy(self.model) if copy else self.model
105 |
106 | # Define i/o
107 | targets = [
108 | (" " if request["target_new"]["str"][0] != " " else "")
109 | + request["target_new"]["str"]
110 | for request in requests
111 | ]
112 | sentences = [
113 | request["prompt"].format(request["subject"]) + targets[i]
114 | for i, request in enumerate(requests)
115 | ]
116 |
117 | # Tokenize
118 | sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to(
119 | "cuda"
120 | )
121 | target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to(
122 | "cuda"
123 | )
124 |
125 | # Define labels
126 | label_tok = deepcopy(sent_tok["input_ids"])
127 | for i in range(label_tok.size(0)):
128 | target_len = target_tok["attention_mask"][i].sum()
129 | padding_len = (
130 | sent_tok["input_ids"].size(1) - sent_tok["attention_mask"][i].sum()
131 | )
132 | label_tok[i][: -target_len - padding_len] = -100
133 | label_tok[i][label_tok[i] == self.tokenizer.pad_token_id] = -100
134 |
135 | # Run MEND
136 | edit_inner = dict(
137 | input_ids=sent_tok["input_ids"],
138 | attention_mask=sent_tok["attention_mask"],
139 | labels=label_tok,
140 | )
141 | cond = {k: sent_tok[k] for k in ["input_ids", "attention_mask"]}
142 | _, model_info = self.alg.edit(edit_inner, cond, return_factors=True)
143 | factors = {
144 | k + "." + n: v.detach().cpu().numpy()
145 | for k, pair in model_info["factors"].items()
146 | for n, v in zip("uv", pair)
147 | }
148 | # Also keep these learned LRs.
149 | factors["edit_lrs"] = self.alg.edit_lrs.detach().cpu().numpy()
150 |
151 | # Edit!
152 | d = factors
153 | torch_factors = {k: torch.tensor(v) for k, v in d.items()}
154 | eli = 0
155 | edit_lrs = torch_factors["edit_lrs"]
156 |
157 | with torch.no_grad():
158 | for n, p in model.named_parameters():
159 | uname, vname = f"{n}.u", f"{n}.v"
160 | if uname in torch_factors:
161 | if return_orig_weights and n not in weights_copy:
162 | weights_copy[n] = p.detach().clone()
163 |
164 | if "gpt2" in hparams.model_name:
165 | delta = torch_factors[uname].t() @ torch_factors[vname]
166 | elif "gpt-j-6B" in hparams.model_name:
167 | delta = torch_factors[vname].t() @ torch_factors[uname]
168 | else:
169 | raise ValueError("Unknown model")
170 | p.add_((delta * edit_lrs[eli] * hparams.lr_scale).to(p.device))
171 | eli += 1
172 |
173 | return model, weights_copy
174 |
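A hedged end-to-end sketch of applying MEND through this executor. The model name, hyperparameter
path, and request contents are illustrative; `from_json` is the loader on the shared `HyperParams`
base class (the same one demo.py calls):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2-xl").cuda()
    tok = AutoTokenizer.from_pretrained("gpt2-xl")
    hparams = MENDHyperParams.from_json("hparams/MEND/gpt2-xl.json")

    request = {
        "prompt": "{} has the position of",
        "subject": "Charles Herman Helmsing",
        "target_new": {"str": "President"},
    }
    editor = MendRewriteExecutor()
    edited_model, orig_weights = editor.apply_to_model(
        model, tok, [request], hparams, return_orig_weights=True
    )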
--------------------------------------------------------------------------------
/memit/baselines/mend/models.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 |
8 | from .utils import scr
9 |
10 | LOG = logging.getLogger(__name__)
11 |
12 |
13 | class CastModule(nn.Module):
14 | def __init__(
15 | self,
16 | module: nn.Module,
17 | in_cast: torch.dtype = torch.float32,
18 | out_cast: torch.dtype = None,
19 | ):
20 | super().__init__()
21 |
22 | self.underlying = module
23 | self.in_cast = in_cast
24 | self.out_cast = out_cast
25 |
26 | def cast(self, obj, dtype):
27 | if dtype is None:
28 | return obj
29 |
30 | if isinstance(obj, torch.Tensor):
31 | return obj.to(dtype)
32 | else:
33 | return obj
34 |
35 | def forward(self, *args, **kwargs):
36 | args = tuple(self.cast(a, self.in_cast) for a in args)
37 | kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
38 | outputs = self.underlying(*args, **kwargs)
39 | if isinstance(outputs, torch.Tensor):
40 | outputs = self.cast(outputs, self.out_cast)
41 | elif isinstance(outputs, tuple):
42 | outputs = tuple(self.cast(o, self.out_cast) for o in outputs)
43 | else:
44 | raise RuntimeError(f"Not sure how to cast type {type(outputs)}")
45 | return outputs
46 |
47 | def extra_repr(self):
48 | return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"
49 |
50 |
51 | class BertClassifier(torch.nn.Module):
52 | def __init__(self, model_name, hidden_dim=768):
53 | super().__init__()
54 | self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr())
55 | self.classifier = torch.nn.Linear(hidden_dim, 1)
56 |
57 | @property
58 | def config(self):
59 | return self.model.config
60 |
61 | def forward(self, *args, **kwargs):
62 | filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"}
63 | return self.classifier(self.model(*args, **filtered_kwargs)[1])
64 |
65 |
66 | def get_model(config):
67 | if config.model.class_name == "BertClassifier":
68 | model = BertClassifier(config.model.name)
69 | else:
70 | ModelClass = getattr(transformers, config.model.class_name)
71 | LOG.info(
72 | f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}"
73 | )
74 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())
75 |
76 | if config.model.pt is not None:
77 | LOG.info(f"Loading model initialization from {config.model.pt}")
78 | state_dict = torch.load(config.model.pt, map_location="cpu")
79 |
80 | try:
81 | model.load_state_dict(state_dict)
82 | except RuntimeError:
83 | LOG.info("Default load failed; stripping prefix and trying again.")
84 | state_dict = {re.sub("^model.", "", k): v for k, v in state_dict.items()}
85 |
86 | model.load_state_dict(state_dict)
87 |
88 | LOG.info("Loaded model initialization")
89 |
90 | if config.dropout is not None:
91 | n_reset = 0
92 | for m in model.modules():
93 | if isinstance(m, nn.Dropout):
94 | m.p = config.dropout
95 | n_reset += 1
96 |
97 |             if hasattr(m, "dropout"):  # Required for BART, which uses F.dropout
98 | if isinstance(m.dropout, float):
99 | m.dropout = config.dropout
100 | n_reset += 1
101 |
102 | if hasattr(
103 | m, "activation_dropout"
104 |             ):  # Required for BART, which uses F.dropout
105 | if isinstance(m.activation_dropout, float):
106 | m.activation_dropout = config.dropout
107 | n_reset += 1
108 |
109 | LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")
110 |
111 | param_names = [n for n, _ in model.named_parameters()]
112 | bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
113 | if len(bad_inner_params) != 0:
114 | raise ValueError(
115 | f"Params {bad_inner_params} do not exist in model of type {type(model)}."
116 | )
117 |
118 | if config.no_grad_layers is not None:
119 | if config.half:
120 | model.bfloat16()
121 |
122 | def upcast(mod):
123 | modlist = None
124 | for child in mod.children():
125 | if isinstance(child, nn.ModuleList):
126 | assert modlist is None, f"Found multiple modlists for {mod}"
127 | modlist = child
128 | if modlist is None:
129 | raise RuntimeError("Couldn't find a ModuleList child")
130 |
131 | LOG.info(
132 | f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting"
133 | )
134 | modlist[config.no_grad_layers :].to(torch.float32)
135 | modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers])
136 | modlist[-1] = CastModule(
137 | modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16
138 | )
139 |
140 | parents = []
141 | if hasattr(model, "transformer"):
142 | parents.append(model.transformer)
143 | if hasattr(model, "encoder"):
144 | parents.append(model.encoder)
145 | if hasattr(model, "decoder"):
146 | parents.append(model.decoder)
147 | if hasattr(model, "model"):
148 | parents.extend([model.model.encoder, model.model.decoder])
149 |
150 | for t in parents:
151 | t.no_grad_layers = config.no_grad_layers
152 | if config.half:
153 | upcast(t)
154 |
155 | if config.half:
156 | idxs = []
157 | for p in config.model.inner_params:
158 | for comp in p.split("."):
159 | if comp.isdigit():
160 | idxs.append(int(comp))
161 | max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers)
162 | for pidx, p in enumerate(config.model.inner_params):
163 | comps = p.split(".")
164 | if max_idx in comps or min_idx in comps:
165 | index = (
166 | comps.index(max_idx)
167 | if max_idx in comps
168 | else comps.index(min_idx)
169 | )
170 | comps.insert(index + 1, "underlying")
171 | new_p = ".".join(comps)
172 | LOG.info(
173 | f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'"
174 | )
175 | config.model.inner_params[pidx] = new_p
176 |
177 | return model
178 |
179 |
180 | def get_tokenizer(config):
181 | tok_name = (
182 | config.model.tokenizer_name
183 | if config.model.tokenizer_name is not None
184 | else config.model.name
185 | )
186 | return getattr(transformers, config.model.tokenizer_class).from_pretrained(
187 | tok_name, cache_dir=scr()
188 | )
189 |
190 |
191 | if __name__ == "__main__":
192 | m = BertClassifier("bert-base-uncased")
193 | m(torch.arange(5)[None, :])
194 | import pdb
195 |
196 | pdb.set_trace()
197 |
--------------------------------------------------------------------------------
/memit/baselines/mend/oracle.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from higher.patch import monkeypatch as make_functional
6 | from losses import kl_loc_loss, masked_log_probs
7 |
8 |
9 | def test_rank1(model, dataset, config):
10 | model.eval()
11 | generator = dataset.edit_generator(21)
12 |
13 | history = []
14 | for example in generator:
15 | edit_model = make_functional(model, track_higher_grads=False)
16 | residuals = {}
17 | opt_list = []
18 | print(config.model.inner_params)
19 | for n, p in edit_model.named_parameters():
20 | if n in config.model.inner_params:
21 | std = 0.01
22 | u = nn.Parameter(torch.randn(p.shape[0], 1, device=p.device) * std)
23 | v = nn.Parameter(torch.randn(1, p.shape[1], device=p.device) * std)
24 | assert (
25 | u @ v
26 | ).shape == p.shape, f"got {(u@v).shape}, expected {p.shape}"
27 |
28 | residuals[n] = (u, v)
29 | opt_list.extend([u, v])
30 |
31 | res_opt = torch.optim.SGD(opt_list, lr=100)
32 |
33 | acc = 0
34 | it = 0
35 | ids_train = example["loc_ids"][:10]
36 | ids_val = example["loc_ids"][10:]
37 | with torch.inference_mode():
38 | original_logits_train = model(ids_train)
39 | original_logits_val = model(ids_val)
40 | if hasattr(original_logits_train, "logits"):
41 | original_logits_train = original_logits_train.logits
42 | original_logits_val = original_logits_val.logits
43 |
44 | while acc < 1 and it < 1000:
45 | fast_params = []
46 | for n, p in edit_model.named_parameters():
47 | if n in residuals:
48 | u, v = residuals[n]
49 | fast_params.append(p.detach() + (u @ v))
50 | else:
51 | fast_params.append(p.detach())
52 |
53 | loc_pred = edit_model(ids_train, params=fast_params)
54 | if hasattr(loc_pred, "logits"):
55 | loc_pred = loc_pred.logits
56 |
57 | loc_loss = kl_loc_loss(original_logits_train, loc_pred)
58 |
59 | pred_log = edit_model(example["edit_inner_ids"], params=fast_params)
60 | if hasattr(pred_log, "logits"):
61 | pred_log = pred_log.logits
62 | prob_dict = masked_log_probs(pred_log, example["edit_inner_labels"])
63 | edit_loss = prob_dict["nll"]
64 | acc = prob_dict["acc"]
65 |
66 | loss = loc_loss + 0.0002 * edit_loss
67 | with torch.inference_mode():
68 | loc_pred_val = edit_model(ids_val, params=fast_params)
69 | if hasattr(loc_pred_val, "logits"):
70 | loc_pred_val = loc_pred_val.logits
71 |
72 | if pred_log.dim() == 3:
73 | facc = (
74 | (
75 | pred_log.argmax(-1)[0, -10:-1]
76 | == example["edit_inner_labels"][0, -9:]
77 | )
78 | .float()
79 | .mean()
80 | )
81 | ret = (
82 | (original_logits_val.argmax(-1) == loc_pred_val.argmax(-1))
83 | .float()
84 | .mean()
85 | )
86 | else:
87 | facc = (pred_log > 0) == example["edit_inner_labels"]
88 | ret = (
89 | ((original_logits_val > 0) == (loc_pred_val > 0)).float().mean()
90 | )
91 |
92 | print(
93 | f"{it}, ({loss.item():.6f}, {loc_loss.item():.4f}, {edit_loss.item():.4f}), {facc.item():.2f}, {ret.item():.4f} {(u@v).view(-1).norm().item():.5f}",
94 | end="\r",
95 | )
96 |
97 | for p, g in zip(opt_list, torch.autograd.grad(loss, opt_list)):
98 | p.grad = g
99 | res_opt.step()
100 | res_opt.zero_grad()
101 |
102 | it += 1
103 |
104 | if acc == 1:
105 | history.append(1)
106 | else:
107 | history.append(0)
108 |
109 | print()
110 | print(len(history), sum(history) / len(history), ret.item())
111 |
--------------------------------------------------------------------------------
/memit/baselines/mend/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core
2 | numpy
3 | torch
4 | click==7.1.2 # Spacy breaks for click>=8.0
5 | spacy
6 | allennlp
7 | git+git://github.com/eric-mitchell/higher@master # For in-place functional models
8 | git+git://github.com/eric-mitchell/transformers@master # To enable gradient disabling for some models
9 | datasets
10 | jsonlines
11 | wandb
12 |
--------------------------------------------------------------------------------
/memit/baselines/mend/run.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import importlib
3 | import logging
4 | import random
5 |
6 | import hydra
7 | import models
8 | import numpy as np
9 | import torch
10 | import utils
11 | from omegaconf import OmegaConf
12 | from trainer import EditTrainer
13 |
14 | OmegaConf.register_new_resolver("uuid", lambda: utils.uuid())
15 |
16 |
17 | logging.basicConfig(
18 | format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
19 | level=logging.INFO,
20 | )
21 | LOG = logging.getLogger(__name__)
22 |
23 |
24 | def add_padding(tokenizer, model):
25 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
26 | model.resize_token_embeddings(len(tokenizer))
27 | model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0)
28 |
29 |
30 | @hydra.main(config_path="config", config_name="config")
31 | def run(config):
32 | LOG.info(f"\n\n{OmegaConf.to_yaml(config)}\n")
33 | base_dir = hydra.utils.get_original_cwd()
34 | LOG.info(f"Project base directory: {base_dir}")
35 |
36 | random.seed(config.seed)
37 | np.random.seed(config.seed)
38 | torch.manual_seed(config.seed)
39 |
40 | model = models.get_model(config)
41 | tokenizer = models.get_tokenizer(config)
42 |
43 | if config.task == "gen" or config.task == "wiki":
44 | add_padding(tokenizer, model)
45 | from data_classes.wiki import GenDataset
46 |
47 | train_set = GenDataset("train", tokenizer, config, config.data.path, pct=10)
48 | val_set = GenDataset("validation", tokenizer, config, config.data.path, pct=10)
49 | elif config.task == "fc" or config.task == "fever":
50 | from data_classes.fever import BinaryAugmentedKILT
51 |
52 | train_set = BinaryAugmentedKILT(
53 | tokenizer, f"{base_dir}/data/fever/fever-train-kilt.jsonl", config
54 | )
55 | val_set = BinaryAugmentedKILT(
56 | tokenizer, f"{base_dir}/data/fever/fever-dev-kilt.jsonl", config
57 | )
58 | elif config.task == "qa" or config.task == "zsre":
59 | from data_classes.zsre import Seq2SeqAugmentedKILT
60 |
61 | train_set = Seq2SeqAugmentedKILT(
62 | tokenizer,
63 | f"{base_dir}/data/zsre/structured_zeroshot-train-new_annotated_final.jsonl",
64 | config,
65 | )
66 | val_set = Seq2SeqAugmentedKILT(
67 | tokenizer,
68 | f"{base_dir}/data/zsre/structured_zeroshot-dev-new_annotated_final.jsonl",
69 | config,
70 | )
71 | else:
72 | raise ValueError(f"Unrecognized task {config.task}")
73 |
74 | alg_module = importlib.import_module(f"algs.{config.alg}")
75 | LOG.info(f"Loading class {config.alg.upper()} from module {alg_module}")
76 | AlgClass = getattr(alg_module, config.alg.upper())
77 | alg = AlgClass(model, config, lambda: copy.deepcopy(model))
78 |
79 | if config.alg == "ft" and config.ft.locality.enabled:
80 | if config.ft.locality.oracle:
81 | alg.loc_sampler = train_set.edit_generator(
82 | config.ft.locality.batch_size + 1
83 | )
84 | else:
85 | state = np.random.get_state()
86 | np.random.seed(0)
87 | loc_batch = next(
88 | train_set.edit_generator(config.ft.locality.batch_size + 1)
89 | )["loc"]
90 | np.random.set_state(state)
91 | alg.loc_ids = loc_batch["input_ids"]
92 | alg.loc_masks = loc_batch["attention_mask"]
93 |
94 | trainer = EditTrainer(alg, config, train_set, val_set)
95 | trainer.run()
96 |
97 |
98 | if __name__ == "__main__":
99 | run()
100 |
--------------------------------------------------------------------------------
/memit/dsets/__init__.py:
--------------------------------------------------------------------------------
1 | from .attr_snippets import AttributeSnippets
2 | from .counterfact import CounterFactDataset, MultiCounterFactDataset
3 | from .knowns import KnownsDataset
4 | from .tfidf_stats import get_tfidf_vectorizer
5 | from .zsre import MENDQADataset
6 |
--------------------------------------------------------------------------------
/memit/dsets/attr_snippets.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | from pathlib import Path
4 |
5 | import torch
6 |
7 | from util.globals import *
8 |
9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/attribute_snippets.json"
10 |
11 |
12 | class AttributeSnippets:
13 | """
14 | Contains wikipedia snippets discussing entities that have some property.
15 |
16 | More formally, given a tuple t = (s, r, o):
17 | - Let snips = AttributeSnippets(DATA_DIR)
18 | - snips[r][o] is a list of wikipedia articles for all s' such that t' = (s', r, o) is valid.
19 | """
20 |
21 | def __init__(self, data_dir: str):
22 | data_dir = Path(data_dir)
23 | snips_loc = data_dir / "attribute_snippets.json"
24 | if not snips_loc.exists():
25 | print(f"{snips_loc} does not exist. Downloading from {REMOTE_URL}")
26 | data_dir.mkdir(exist_ok=True, parents=True)
27 | torch.hub.download_url_to_file(REMOTE_URL, snips_loc)
28 |
29 | with open(snips_loc, "r") as f:
30 | snippets_list = json.load(f)
31 |
32 | snips = collections.defaultdict(lambda: collections.defaultdict(list))
33 |
34 | for el in snippets_list:
35 | rid, tid = el["relation_id"], el["target_id"]
36 | for sample in el["samples"]:
37 | snips[rid][tid].append(sample)
38 |
39 | self._data = snips
40 | self.snippets_list = snippets_list
41 |
42 | def __getitem__(self, item):
43 | return self._data[item]
44 |
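Usage sketch following the docstring above; the data directory and the relation/target ids are
illustrative (the ids are taken from the example request shown in mend_main.py and may not appear
in the snippet file):

    snips = AttributeSnippets("data")    # downloads attribute_snippets.json on first use
    articles = snips["P39"]["Q11696"]    # snippets for entities s' with (s', r=P39, o=Q11696)
    print(len(articles))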
--------------------------------------------------------------------------------
/memit/dsets/counterfact.py:
--------------------------------------------------------------------------------
1 | import json
2 | import typing
3 | from pathlib import Path
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | from util.globals import *
9 |
10 | REMOTE_ROOT = f"{REMOTE_ROOT_URL}/data/dsets"
11 |
12 |
13 | class ConceptvectorsDataset(Dataset):
14 | def __init__(
15 | self,
16 | data_dir: str,
17 | multi: bool = False,
18 | size: typing.Optional[int] = None,
19 | *args,
20 | **kwargs,
21 | ):
22 | data_dir = Path(data_dir)
23 | cf_loc = data_dir / (
24 | "counterfact.json" if not multi else "multi_counterfact.json"
25 | )
26 | if not cf_loc.exists():
27 | remote_url = f"{REMOTE_ROOT}/{'multi_' if multi else ''}counterfact.json"
28 | print(f"{cf_loc} does not exist. Downloading from {remote_url}")
29 | data_dir.mkdir(exist_ok=True, parents=True)
30 | torch.hub.download_url_to_file(remote_url, cf_loc)
31 |
32 | with open(cf_loc, "r") as f:
33 | self.data = json.load(f)
34 | if size is not None:
35 | self.data = self.data[:size]
36 |
37 | print(f"Loaded dataset with {len(self)} elements")
38 |
39 | def __len__(self):
40 | return len(self.data)
41 |
42 | def __getitem__(self, item):
43 | return self.data[item]
44 |
45 |
46 | class CounterFactDataset(Dataset):
47 | def __init__(
48 | self,
49 | data_dir: str,
50 | multi: bool = False,
51 | size: typing.Optional[int] = None,
52 | *args,
53 | **kwargs,
54 | ):
55 | data_dir = Path(data_dir)
56 | cf_loc = data_dir / (
57 | "counterfact.json" if not multi else "multi_counterfact.json"
58 | )
59 | if not cf_loc.exists():
60 | remote_url = f"{REMOTE_ROOT}/{'multi_' if multi else ''}counterfact.json"
61 | print(f"{cf_loc} does not exist. Downloading from {remote_url}")
62 | data_dir.mkdir(exist_ok=True, parents=True)
63 | torch.hub.download_url_to_file(remote_url, cf_loc)
64 |
65 | with open(cf_loc, "r") as f:
66 | self.data = json.load(f)
67 | if size is not None:
68 | self.data = self.data[:size]
69 |
70 | print(f"Loaded dataset with {len(self)} elements")
71 |
72 | def __len__(self):
73 | return len(self.data)
74 |
75 | def __getitem__(self, item):
76 | return self.data[item]
77 |
78 |
79 | class MultiCounterFactDataset(CounterFactDataset):
80 | def __init__(
81 | self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
82 | ):
83 | super().__init__(data_dir, *args, multi=True, size=size, **kwargs)
84 |
--------------------------------------------------------------------------------
/memit/dsets/knowns.py:
--------------------------------------------------------------------------------
1 | import json
2 | import typing
3 | from pathlib import Path
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | from util.globals import *
9 |
10 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/known_1000.json"
11 |
12 |
13 | class KnownsDataset(Dataset):
14 | def __init__(self, data_dir: str, *args, **kwargs):
15 | data_dir = Path(data_dir)
16 | known_loc = data_dir / "known_1000.json"
17 | if not known_loc.exists():
18 | print(f"{known_loc} does not exist. Downloading from {REMOTE_URL}")
19 | data_dir.mkdir(exist_ok=True, parents=True)
20 | torch.hub.download_url_to_file(REMOTE_URL, known_loc)
21 |
22 | with open(known_loc, "r") as f:
23 | self.data = json.load(f)
24 |
25 | print(f"Loaded dataset with {len(self)} elements")
26 |
27 | def __len__(self):
28 | return len(self.data)
29 |
30 | def __getitem__(self, item):
31 | return self.data[item]
32 |
--------------------------------------------------------------------------------
/memit/dsets/tfidf_stats.py:
--------------------------------------------------------------------------------
1 | import json
2 | from itertools import chain
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import scipy.sparse as sp
7 | import torch
8 | from sklearn.feature_extraction.text import TfidfVectorizer
9 |
10 | from dsets import AttributeSnippets
11 | from util.globals import *
12 |
13 | REMOTE_IDF_URL = f"{REMOTE_ROOT_URL}/data/dsets/idf.npy"
14 | REMOTE_VOCAB_URL = f"{REMOTE_ROOT_URL}/data/dsets/tfidf_vocab.json"
15 |
16 |
17 | def get_tfidf_vectorizer(data_dir: str):
18 | """
19 | Returns an sklearn TF-IDF vectorizer. See their website for docs.
20 | Loading hack inspired by some online blog post lol.
21 | """
22 |
23 | data_dir = Path(data_dir)
24 |
25 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json"
26 | if not (idf_loc.exists() and vocab_loc.exists()):
27 | collect_stats(data_dir)
28 |
29 | idf = np.load(idf_loc)
30 | with open(vocab_loc, "r") as f:
31 | vocab = json.load(f)
32 |
33 | class MyVectorizer(TfidfVectorizer):
34 | TfidfVectorizer.idf_ = idf
35 |
36 | vec = MyVectorizer()
37 | vec.vocabulary_ = vocab
38 | vec._tfidf._idf_diag = sp.spdiags(idf, diags=0, m=len(idf), n=len(idf))
39 |
40 | return vec
41 |
42 |
43 | def collect_stats(data_dir: str):
44 | """
45 | Uses wikipedia snippets to collect statistics over a corpus of English text.
46 | Retrieved later when computing TF-IDF vectors.
47 | """
48 |
49 | data_dir = Path(data_dir)
50 | data_dir.mkdir(exist_ok=True, parents=True)
51 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json"
52 |
53 | try:
54 | print(f"Downloading IDF cache from {REMOTE_IDF_URL}")
55 | torch.hub.download_url_to_file(REMOTE_IDF_URL, idf_loc)
56 | print(f"Downloading TF-IDF vocab cache from {REMOTE_VOCAB_URL}")
57 | torch.hub.download_url_to_file(REMOTE_VOCAB_URL, vocab_loc)
58 | return
59 | except Exception as e:
60 |         print("Error downloading file:", e)
61 | print("Recomputing TF-IDF stats...")
62 |
63 | snips_list = AttributeSnippets(data_dir).snippets_list
64 | documents = list(chain(*[[y["text"] for y in x["samples"]] for x in snips_list]))
65 |
66 | vec = TfidfVectorizer()
67 | vec.fit(documents)
68 |
69 | idfs = vec.idf_
70 | vocab = vec.vocabulary_
71 |
72 | np.save(data_dir / "idf.npy", idfs)
73 | with open(data_dir / "tfidf_vocab.json", "w") as f:
74 | json.dump(vocab, f, indent=1)
75 |
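Usage sketch (the data directory is illustrative); the returned object behaves like a fitted
sklearn TfidfVectorizer thanks to the loading hack above:

    vec = get_tfidf_vectorizer("data")    # downloads or recomputes IDF stats on first use
    mat = vec.transform(["The Space Needle is in Seattle."])
    print(mat.shape)                      # (1, vocabulary size), one sparse TF-IDF row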
--------------------------------------------------------------------------------
/memit/dsets/zsre.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import torch
5 | from transformers import AutoTokenizer
6 |
7 | from util.globals import *
8 |
9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/zsre_mend_eval.json"
10 |
11 |
12 | class MENDQADataset:
13 | """
14 | Dataset of factual knowledge based on zsRE.
15 | Specifically selected from the QA validation slice from Mitchell et al.
16 | Project page: http://nlp.cs.washington.edu/zeroshot/
17 | """
18 |
19 | def __init__(self, data_dir: str, tok: AutoTokenizer, size=None, *args, **kwargs):
20 | data_dir = Path(data_dir)
21 | zsre_loc = data_dir / "zsre_mend_eval.json"
22 | if not zsre_loc.exists():
23 | print(f"{zsre_loc} does not exist. Downloading from {REMOTE_URL}")
24 | data_dir.mkdir(exist_ok=True, parents=True)
25 | torch.hub.download_url_to_file(REMOTE_URL, zsre_loc)
26 |
27 | with open(zsre_loc, "r") as f:
28 | raw = json.load(f)
29 |
30 | data = []
31 | for i, record in enumerate(raw):
32 | assert (
33 | "nq question: " in record["loc"]
34 | ), f"Neighborhood prompt missing `nq question:`. Check for errors?"
35 | ans_toks = tok(" " + record["loc_ans"])["input_ids"]
36 | data.append(
37 | {
38 | "case_id": i,
39 | "requested_rewrite": {
40 | "prompt": record["src"].replace(record["subject"], "{}"),
41 | "subject": record["subject"],
42 | "target_new": {"str": record["answers"][0]},
43 | "target_true": {"str": "<|endoftext|>"},
44 | },
45 | "paraphrase_prompts": [record["rephrase"]],
46 | "neighborhood_prompts": [
47 | {
48 | "prompt": record["loc"] + "?" + tok.decode(ans_toks[:i]),
49 | "target": tok.decode(ans_toks[i]),
50 | }
51 | for i in range(len(ans_toks))
52 | ],
53 | "attribute_prompts": [],
54 | "generation_prompts": [],
55 | }
56 | )
57 |
58 | self._data = data[:size]
59 |
60 | def __getitem__(self, item):
61 | return self._data[item]
62 |
63 | def __len__(self):
64 | return len(self._data)
65 |
--------------------------------------------------------------------------------
/memit/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/memit/experiments/__init__.py
--------------------------------------------------------------------------------
/memit/experiments/py/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Dict, List, Tuple
4 |
5 | import torch
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 |
8 | from baselines.ft import FTHyperParams, apply_ft_to_model
9 | from memit import MEMITHyperParams, apply_memit_to_model
10 | from rome import ROMEHyperParams, apply_rome_to_model
11 | from util import nethook
12 | from util.generate import generate_fast
13 | from util.globals import *
14 |
15 |
16 | def demo_model_editing(
17 | model: AutoModelForCausalLM,
18 | tok: AutoTokenizer,
19 | requests: List[Dict],
20 | generation_prompts: List[str],
21 | alg_name: str = "ROME",
22 | ) -> Tuple[AutoModelForCausalLM, Dict[str, torch.Tensor]]:
23 | """
24 | Applies the selected model editing algorithm. Generates text both before and after
25 | for comparison of model behavior. Returns the updated model and the original values of
26 | weights that were changed.
27 | """
28 |
29 | nethook.set_requires_grad(True, model)
30 |
31 | RewritingParamsClass, apply_method, hparams_prefix, hparams_suffix = load_alg(
32 | alg_name
33 | )
34 | params_name = (
35 | HPARAMS_DIR
36 | / hparams_prefix
37 | / f"{model.config._name_or_path.replace('/', '_')}{hparams_suffix}.json"
38 | )
39 |
40 | print_loud(f"Retrieving {alg_name} hyperparameters")
41 | print("Loading from", params_name)
42 | hparams = RewritingParamsClass.from_json(params_name)
43 | print(hparams)
44 |
45 | print_loud("Generating pre-update text")
46 | pre_update_text = generate_fast(model, tok, generation_prompts, max_out_len=100)
47 | print(pre_update_text)
48 |
49 | print_loud(f"Applying {alg_name} to model")
50 | model_new, orig_weights = apply_method(
51 | model,
52 | tok,
53 | requests,
54 | hparams,
55 | return_orig_weights=True,
56 | )
57 |
58 | print_loud("Generating post-update text")
59 | post_update_text = generate_fast(
60 | model_new, tok, generation_prompts, max_out_len=100
61 | )
62 | print(post_update_text)
63 |
64 | print_loud("Summarizing differences")
65 | for i, (prompt, pre, post) in enumerate(
66 | zip(generation_prompts, pre_update_text, post_update_text)
67 | ):
68 | if i > 0:
69 | print("".join(["-" for _ in range(10)]))
70 |
71 | prompt_str = "[Prompt]:"
72 | pre_str = f"[Pre-{alg_name}]:"
73 | post_str = f"[Post-{alg_name}]:"
74 | pad_to = 1 + max(len(prompt_str), len(pre_str), len(post_str))
75 |
76 | for s, t in zip([prompt_str, post_str, pre_str], [prompt, post, pre]):
77 | print(s.ljust(pad_to), t)
78 |
79 | return model_new, orig_weights
80 |
81 |
82 | def load_alg(alg_name):
83 | """
84 | Loads dependencies for the desired algorithm.
85 | Implementation is slightly awkward to prevent unnecessary imports on Colab.
86 |
87 | The return value is a tuple of the following:
88 | 1. Class for storing hyperparameters
89 | 2. Method for applying rewrites
90 | 3. Location of parameters
91 | 4. Predefined suffix for the param file
92 | """
93 | assert alg_name in [
94 | "FT",
95 | "FT-L",
96 | "FT-AttnEdit",
97 | "MEND",
98 | "MEND-CF",
99 | "MEND-zsRE",
100 | "ROME",
101 | "MEMIT",
102 | ]
103 |
104 | if alg_name == "ROME":
105 | return ROMEHyperParams, apply_rome_to_model, "ROME", ""
106 | elif alg_name == "MEMIT":
107 | return MEMITHyperParams, apply_memit_to_model, "MEMIT", ""
108 | elif "FT" in alg_name:
109 | d = {
110 | "FT": (FTHyperParams, apply_ft_to_model, "FT", "_unconstr"),
111 | "FT-AttnEdit": (FTHyperParams, apply_ft_to_model, "FT", "_attn"),
112 | "FT-L": (FTHyperParams, apply_ft_to_model, "FT", "_constr"),
113 | }
114 | return d[alg_name]
115 | else:
116 | from baselines.mend import MENDHyperParams, MendRewriteExecutor
117 |
118 | d = {
119 | "MEND": (MENDHyperParams, MendRewriteExecutor().apply_to_model, "MEND", ""),
120 | "MEND-CF": (
121 | MENDHyperParams,
122 | MendRewriteExecutor().apply_to_model,
123 | "MEND",
124 | "_CF",
125 | ),
126 | "MEND-zsRE": (
127 | MENDHyperParams,
128 | MendRewriteExecutor().apply_to_model,
129 | "MEND",
130 | "_zsRE",
131 | ),
132 | }
133 | return d[alg_name]
134 |
135 |
136 | def print_loud(x, pad=3):
137 | """
138 | Prints a string with # box for emphasis.
139 |
140 | Example:
141 | ############################
142 | # #
143 | # Applying ROME to model #
144 | # #
145 | ############################
146 | """
147 |
148 | n = len(x)
149 | print()
150 | print("".join(["#" for _ in range(n + 2 * pad)]))
151 | print("#" + "".join([" " for _ in range(n + 2 * (pad - 1))]) + "#")
152 | print(
153 | "#"
154 | + "".join([" " for _ in range(pad - 1)])
155 | + x
156 | + "".join([" " for _ in range(pad - 1)])
157 | + "#"
158 | )
159 | print("#" + "".join([" " for _ in range(n + 2 * (pad - 1))]) + "#")
160 | print("".join(["#" for _ in range(n + 2 * pad)]))
161 |
162 |
163 | class StopExecution(Exception):
164 | def _render_traceback_(self):
165 | pass
166 |
167 |
168 | def stop_execution():
169 | raise StopExecution
170 |
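A hedged usage sketch mirroring the function above. The model name, request, and prompts are
illustrative, not taken from the repository; setting the pad token follows the convention used
elsewhere in this codebase (e.g. sweep.py):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2-xl").cuda()
    tok = AutoTokenizer.from_pretrained("gpt2-xl")
    tok.pad_token = tok.eos_token

    requests = [{
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    }]
    generation_prompts = ["LeBron James is known for"]

    model_new, orig_weights = demo_model_editing(
        model, tok, requests, generation_prompts, alg_name="MEMIT"
    )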
--------------------------------------------------------------------------------
/memit/experiments/py/eval_utils_zsre.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains evaluation utilities for pytorch-based rewriting methods.
3 | To use, simply call `compute_rewrite_quality_zsre` with the
4 | appropriate arguments; it returns a dictionary containing the computed metrics.
5 | """
6 |
7 | import typing
8 | from itertools import chain
9 |
10 | import numpy as np
11 | import torch
12 | from sklearn.feature_extraction.text import TfidfVectorizer
13 | from transformers import AutoModelForCausalLM, AutoTokenizer
14 |
15 | from dsets import AttributeSnippets
16 |
17 |
18 | def compute_rewrite_quality_zsre(
19 | model: AutoModelForCausalLM,
20 | tok: AutoTokenizer,
21 | record: typing.Dict,
22 | snips: AttributeSnippets,
23 | vec: TfidfVectorizer,
24 | ) -> typing.Dict:
25 | """
26 | Given a rewritten model, computes generalization and specificity metrics for
27 | the desired rewrite (passed in via the zsRE dataset record). Returns a
28 | dictionary containing those metrics.
29 |
30 | :param model: Rewritten model
31 | :param tok: Tokenizer
32 |     :param record: zsRE dataset record
33 |     :param snips: ???
34 | :param vec: ???
35 | :return: Dictionary containing rewriting metrics
36 | """
37 |
38 | # First, unpack rewrite evaluation record.
39 | subject, target_new, target_true = (
40 | record["requested_rewrite"][x] for x in ["subject", "target_new", "target_true"]
41 | )
42 | rewrite_prompts = [record["requested_rewrite"]["prompt"].format(subject)]
43 | paraphrase_prompts = record["paraphrase_prompts"]
44 | neighborhood_prompts = record["neighborhood_prompts"]
45 |
46 | # Form a list of lists of prefixes to test.
47 | prob_prompts = [
48 | rewrite_prompts,
49 | paraphrase_prompts,
50 | ]
51 | # Flatten all the evaluated prefixes into one list.
52 | target_tok = tok(" " + target_new["str"])["input_ids"]
53 | inp_prompts_og = list(chain(*prob_prompts))
54 | inp_prompts = [
55 | el + tok.decode(target_tok[:i])
56 | for el in inp_prompts_og
57 | for i in range(len(target_tok))
58 | ]
59 | inp_targets = [
60 | tok.decode(target_tok[i])
61 | for _ in range(len(inp_prompts_og))
62 | for i in range(len(target_tok))
63 | ]
64 |
65 | stuff_probs = test_batch_prediction_acc(model, tok, inp_prompts, inp_targets)
66 |
67 | # Predict for neighborhood prompts (dictionary format).
68 | neighborhood_correct = test_batch_prediction_acc(
69 | model,
70 | tok,
71 | [
72 | el["prompt"].format(record["requested_rewrite"])
73 | for el in neighborhood_prompts
74 | ],
75 | [el["target"] for el in neighborhood_prompts],
76 | )
77 |
78 | probs = stuff_probs + neighborhood_correct
79 |
80 | # Unflatten the results again into a list of lists.
81 | cutoffs = [0] + np.cumsum(
82 | [l * len(target_tok) for l in map(len, prob_prompts)]
83 | ).tolist()
84 | ret_probs = [probs[cutoffs[i - 1] : cutoffs[i]] for i in range(1, len(cutoffs))]
85 |     # Structure the results as a dictionary.
86 | ret = {
87 | f"{key}_correct": ret_probs[i]
88 | for i, key in enumerate(
89 | [
90 | "rewrite_prompts",
91 | "paraphrase_prompts",
92 | ]
93 | )
94 | }
95 | ret["neighborhood_prompts_correct"] = neighborhood_correct
96 |
97 | return ret
98 |
99 |
100 | def test_batch_prediction_acc(model, tok, prompts: typing.List[str], target):
101 | prompt_tok = tok(
102 | prompts,
103 | padding=True,
104 | return_tensors="pt",
105 | ).to("cuda")
106 |
107 | with torch.no_grad():
108 | logits = model(**prompt_tok).logits
109 | last_non_masked = prompt_tok["attention_mask"].sum(1) - 1
110 | to_gather = last_non_masked.unsqueeze(1).repeat(1, logits.size(-1)).unsqueeze(1)
111 | gathered = torch.gather(logits, 1, to_gather).squeeze(1)
112 | ans = torch.argmax(gathered, dim=1)
113 |
114 | correct_id = tok(target, padding=True, return_tensors="pt").to("cuda")[
115 | "input_ids"
116 | ]
117 | # Temporary hack to deal with foreign characters.
118 | correct_id = correct_id[:, 0].squeeze()
119 |
120 | return (ans == correct_id).detach().cpu().numpy().tolist()
121 |
--------------------------------------------------------------------------------
/memit/experiments/sweep.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | from copy import deepcopy
5 | from pathlib import Path
6 | from typing import Dict, List, Tuple
7 |
8 | import torch
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 |
11 | from experiments.evaluate import HPARAMS_DIR
12 | from experiments.evaluate import main as eval_main
13 |
14 | TMP_PARAMS_TEMPLATE = "sweep_params_tmp_{}_.json"
15 |
16 |
17 | def exec_sweep(
18 | alg_name: str,
19 | model_tok: Tuple[AutoModelForCausalLM, AutoTokenizer],
20 | hparams_fname: str,
21 | ds_name: str,
22 | sweep_dir: Path,
23 | num_records: int,
24 | generation_test_interval: bool,
25 | num_edits: int,
26 | use_cache: bool,
27 | ):
28 | # Configure hparams
29 | with open(HPARAMS_DIR / alg_name / hparams_fname, "r") as f:
30 | hparams_orig = json.load(f)
31 | with open(Path("results") / sweep_dir / "config.json", "r") as f:
32 | sweep_config = json.load(f)
33 | sweep_keys = list(sweep_config.keys())
34 |
35 | # Sweep
36 | for s_i, state in enumerate(get_states([], sweep_config, sweep_keys)):
37 | # Set dirs
38 | tmp_params_name = TMP_PARAMS_TEMPLATE.format(time.time_ns())
39 | tmp_params_path = HPARAMS_DIR / alg_name / tmp_params_name
40 |
41 | # Set new hparams
42 | hparams_new = deepcopy(hparams_orig)
43 | for key_num, state_num in enumerate(state):
44 | k = sweep_keys[key_num]
45 | hparams_new[k] = sweep_config[k][state_num]
46 | print(f"Sweep {s_i}: Setting {k} = {hparams_new[k]}")
47 |
48 | with open(tmp_params_path, "w") as f:
49 | json.dump(hparams_new, f)
50 |
51 | # Execute
52 | eval_main(
53 | alg_name,
54 | model_name=model_tok,
55 | hparams_fname=tmp_params_name,
56 | ds_name=ds_name,
57 | dataset_size_limit=num_records,
58 | continue_from_run="run_000",
59 | skip_generation_tests=(generation_test_interval == -1),
60 | generation_test_interval=generation_test_interval,
61 | conserve_memory=False,
62 | dir_name=sweep_dir / f"{num_edits}_edits_setting_{s_i}",
63 | num_edits=num_edits,
64 | use_cache=use_cache,
65 | )
66 |
67 | # Clean up
68 | os.remove(tmp_params_path)
69 |
70 |
71 | def get_states(
72 | state: List,
73 | sweep_config: Dict,
74 | sweep_keys: List,
75 | ):
76 | """
77 | Standard recursive procedure for generating all possible configurations.
78 | """
79 |
80 | ans = []
81 | if len(state) < len(sweep_config):
82 | for i in range(len(sweep_config[sweep_keys[len(state)]])):
83 | for s in get_states(state + [i], sweep_config, sweep_keys):
84 | ans.append(s)
85 | else:
86 | ans.append(state)
87 | return ans
88 |
89 |
90 | if __name__ == "__main__":
91 | import argparse
92 |
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument(
95 | "--alg_name", choices=["MEMIT", "FT", "ROME", "MEND"], required=True
96 | )
97 | parser.add_argument(
98 | "--model_name", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"], required=True
99 | )
100 | parser.add_argument("--hparams_fname", type=str, required=True)
101 | parser.add_argument(
102 | "--ds_name",
103 | choices=["mcf", "cf", "zsre"],
104 | default="mcf",
105 | help="Dataset to perform evaluations on. Either CounterFact (cf), MultiCounterFact (mcf), or zsRE (zsre).",
106 | )
107 | parser.add_argument("--min_records", type=int, default=None)
108 | parser.add_argument("--max_records", type=int, default=None)
109 | parser.add_argument(
110 | "--num_edits",
111 | type=str,
112 | default="1",
113 | help="Number of rewrites to perform simultaneously.",
114 | )
115 | parser.add_argument(
116 | "--generation_test_interval",
117 | type=int,
118 | default=-1,
119 | help="One generation test is performed every [flag_value] iterations. If -1, generation tests are skipped.",
120 | )
121 | parser.add_argument("--sweep_dir", type=str)
122 | parser.add_argument(
123 | "--use_cache",
124 | dest="use_cache",
125 | action="store_true",
126 | help="Use cached k/v pairs (MEMIT and ROME only)",
127 | )
128 |
129 | args = parser.parse_args()
130 | assert args.sweep_dir is not None, f"Must specify a sweep_dir."
131 |
132 | model = AutoModelForCausalLM.from_pretrained(args.model_name).to("cuda")
133 | tok = AutoTokenizer.from_pretrained(args.model_name)
134 | tok.pad_token = tok.eos_token
135 |
136 | for cur_num_edits in list(map(int, args.num_edits.split(","))):
137 | torch.cuda.empty_cache()
138 |
139 | num_records = (
140 | None if args.max_records is None
141 | else min(args.max_records, cur_num_edits)
142 | )
143 | if args.min_records is not None:
144 | num_records = max(args.min_records, cur_num_edits)
145 |
146 | exec_sweep(
147 | args.alg_name,
148 | (model, tok),
149 | args.hparams_fname,
150 | args.ds_name,
151 | Path(args.sweep_dir),
152 | num_records,
153 | args.generation_test_interval,
154 | cur_num_edits,
155 | args.use_cache,
156 | )
157 |
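158 | # Added note (illustrative; not part of the original script): the sweep
159 | # configuration read from results/<sweep_dir>/config.json maps each
160 | # hyperparameter name to a list of candidate values, and get_states()
161 | # enumerates every combination by index. A hypothetical config such as
162 | #     {"v_lr": [0.1, 0.5], "clamp_norm_factor": [0.75, 4]}
163 | # yields four settings: (0.1, 0.75), (0.1, 4), (0.5, 0.75), (0.5, 4).
164 | # An empty config ("{}", as written by scaling_curves.sh) yields a single
165 | # setting that runs with the unmodified hyperparameter file.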
--------------------------------------------------------------------------------
/memit/forget_memit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Call the program in a loop, passing a different order index each time
4 | #for i in {0..94}
5 | #do
6 | # #python -m experiments.evaluate --order=$i
7 | # python -m experiments.evaluate --order=$i
8 | #
9 | #done
10 |
11 | for i in 16 18 21 26 27 38 42 47 49 54
12 | do
13 | #python -m experiments.evaluate --order=$i
14 | python -m experiments.memit_jailbreak_evaluate --order=$i
15 |
16 | done
--------------------------------------------------------------------------------
/memit/forget_memit_olmo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Call the program in a loop, passing a different order index each time
4 | for i in {37..161}
5 | do
6 | python -m experiments.evaluate_olmo --order=$i
7 |
8 | done
9 |
10 | #for i in 4 37 40 44 59 77 90 105 141 147
11 | #do
12 | # python -m experiments.memit_jailbreak_evaluate_olmo --order=$i
13 | #
14 | #done
15 |
--------------------------------------------------------------------------------
/memit/globals.yml:
--------------------------------------------------------------------------------
1 | ---
2 | # Result files
3 | RESULTS_DIR: "results"
4 |
5 | # Data files
6 | DATA_DIR: "data"
7 | STATS_DIR: "data/stats"
8 | KV_DIR: "/share/projects/rewriting-knowledge/kvs"
9 |
10 | # Hyperparameters
11 | HPARAMS_DIR: "hparams"
12 |
13 | # Remote URLs
14 | REMOTE_ROOT_URL: "https://memit.baulab.info"
15 |
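16 | # Note added for clarity: these keys are consumed by util/globals.py (not shown
17 | # here), which the rest of the codebase imports as `from util.globals import *`
18 | # (see e.g. rome/layer_stats.py and rome/compute_u.py), exposing RESULTS_DIR,
19 | # DATA_DIR, STATS_DIR, KV_DIR, HPARAMS_DIR and REMOTE_ROOT_URL as module-level
20 | # constants.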
--------------------------------------------------------------------------------
/memit/hparams/FT/EleutherAI_gpt-j-6B_constr.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 0
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": 5e-5,
10 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "lm_head"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/EleutherAI_gpt-j-6B_unconstr.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 21
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": false,
10 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "lm_head"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/EleutherAI_gpt-j-6B_wd.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 21
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": false,
8 | "wd_power_law": [-0.87028798, 0.15589562],
9 | "kl_factor": 0,
10 | "norm_constraint": false,
11 | "rewrite_module_tmp": "transformer.h.{}",
12 | "layer_module_tmp": "transformer.h.{}",
13 | "mlp_module_tmp": "transformer.h.{}.mlp",
14 | "attn_module_tmp": "transformer.h.{}.attn",
15 | "ln_f_module": "transformer.ln_f",
16 | "lm_head_module": "lm_head"
17 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/gpt2-large_constr.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 0
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": 1e-3,
10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "transformer.wte"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/gpt2-medium_constr.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 0
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": 2e-3,
10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "transformer.wte"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/gpt2-xl_attn.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 33
4 | ],
5 | "num_steps": 25,
6 | "lr": 1e-3,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": 1e-3,
10 | "rewrite_module_tmp": "transformer.h.{}.attn.c_attn",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "transformer.wte"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/gpt2-xl_constr.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 0
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": 5e-4,
10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "transformer.wte"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/FT/gpt2-xl_unconstr.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 1
4 | ],
5 | "num_steps": 25,
6 | "lr": 5e-4,
7 | "weight_decay": 0,
8 | "kl_factor": 0,
9 | "norm_constraint": false,
10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
11 | "layer_module_tmp": "transformer.h.{}",
12 | "mlp_module_tmp": "transformer.h.{}.mlp",
13 | "attn_module_tmp": "transformer.h.{}.attn",
14 | "ln_f_module": "transformer.ln_f",
15 | "lm_head_module": "transformer.wte"
16 | }
--------------------------------------------------------------------------------
/memit/hparams/MEMIT/EleutherAI_gpt-j-6B.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 3, 4, 5, 6, 7, 8
4 | ],
5 | "clamp_norm_factor": 0.75,
6 | "layer_selection": "all",
7 | "fact_token": "subject_last",
8 | "v_num_grad_steps": 25,
9 | "v_lr": 5e-1,
10 | "v_loss_layer": 27,
11 | "v_weight_decay": 0.5,
12 | "kl_factor": 0.0625,
13 | "mom2_adjustment": true,
14 | "mom2_update_weight": 15000,
15 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out",
16 | "layer_module_tmp": "transformer.h.{}",
17 | "mlp_module_tmp": "transformer.h.{}.mlp",
18 | "attn_module_tmp": "transformer.h.{}.attn",
19 | "ln_f_module": "transformer.ln_f",
20 | "lm_head_module": "lm_head",
21 | "mom2_dataset": "wikipedia",
22 | "mom2_n_samples": 100000,
23 | "mom2_dtype": "float32"
24 | }
25 |
--------------------------------------------------------------------------------
/memit/hparams/MEMIT/gpt2-xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [13, 14, 15, 16, 17],
3 | "clamp_norm_factor": 0.75,
4 | "layer_selection": "all",
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 20,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 47,
9 | "v_weight_decay": 0.5,
10 | "kl_factor": 0.0625,
11 | "mom2_adjustment": true,
12 | "mom2_update_weight": 20000,
13 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
14 | "layer_module_tmp": "transformer.h.{}",
15 | "mlp_module_tmp": "transformer.h.{}.mlp",
16 | "attn_module_tmp": "transformer.h.{}.attn",
17 | "ln_f_module": "transformer.ln_f",
18 | "lm_head_module": "transformer.wte",
19 | "mom2_dataset": "wikipedia",
20 | "mom2_n_samples": 100000,
21 | "mom2_dtype": "float32"
22 | }
--------------------------------------------------------------------------------
/memit/hparams/MEMIT/llama2-7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [23],
3 | "clamp_norm_factor": 4,
4 | "layer_selection": "all",
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 25,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 31,
9 | "v_weight_decay": 1e-3,
10 | "kl_factor": 0.0625,
11 | "mom2_adjustment": true,
12 | "mom2_update_weight": 15000,
13 | "rewrite_module_tmp": "model.layers.{}.mlp.down_proj",
14 | "layer_module_tmp": "model.layers.{}",
15 | "mlp_module_tmp": "model.layers.{}.mlp",
16 | "attn_module_tmp": "model.layers.{}.self_attn",
17 | "ln_f_module": "model.norm",
18 | "lm_head_module": "lm_head",
19 | "mom2_dataset": "wikipedia",
20 | "mom2_n_samples": 100000,
21 | "mom2_dtype": "float32"
22 | }
--------------------------------------------------------------------------------
/memit/hparams/MEMIT/olmo-7b.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [23],
3 | "clamp_norm_factor": 4,
4 | "layer_selection": "all",
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 25,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 31,
9 | "v_weight_decay": 1e-3,
10 | "kl_factor": 0.0625,
11 | "mom2_adjustment": true,
12 | "mom2_update_weight": 15000,
13 | "rewrite_module_tmp": "model.transformer.blocks.{}.ff_out",
14 | "layer_module_tmp": "model.transformer.blocks.{}",
15 | "mlp_module_tmp": "model.transformer.blocks.{}.mlp",
16 | "attn_module_tmp": "model.transformer.blocks.{}.self_attn",
17 | "ln_f_module": "model.transformer.ln_f",
18 | "lm_head_module": "model.transformer.ff_out",
19 | "mom2_dataset": "wikipedia",
20 | "mom2_n_samples": 100000,
21 | "mom2_dtype": "float32"
22 | }
23 |
24 |
--------------------------------------------------------------------------------
/memit/hparams/MEND/EleutherAI_gpt-j-6B.json:
--------------------------------------------------------------------------------
1 | {
2 | "lr_scale": 1.0,
3 | "n_toks": 10,
4 | "model_name": "EleutherAI/gpt-j-6B",
5 | "counterfact": false,
6 | "zsre": false,
7 | "mini": false
8 | }
--------------------------------------------------------------------------------
/memit/hparams/MEND/EleutherAI_gpt-j-6B_CF.json:
--------------------------------------------------------------------------------
1 | {
2 | "lr_scale": 1.0,
3 | "n_toks": 10,
4 | "model_name": "EleutherAI/gpt-j-6B",
5 | "counterfact": true,
6 | "zsre": false,
7 | "mini": false
8 | }
--------------------------------------------------------------------------------
/memit/hparams/MEND/gpt2-xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "lr_scale": 1.0,
3 | "n_toks": 10,
4 | "model_name": "gpt2-xl",
5 | "counterfact": false,
6 | "zsre": false,
7 | "mini": false
8 | }
--------------------------------------------------------------------------------
/memit/hparams/MEND/gpt2-xl_CF.json:
--------------------------------------------------------------------------------
1 | {
2 | "lr_scale": 1.0,
3 | "n_toks": 1,
4 | "model_name": "gpt2-xl",
5 | "counterfact": true,
6 | "zsre": false,
7 | "mini": false
8 | }
--------------------------------------------------------------------------------
/memit/hparams/MEND/gpt2-xl_zsRE.json:
--------------------------------------------------------------------------------
1 | {
2 | "lr_scale": 1.0,
3 | "n_toks": 1,
4 | "model_name": "gpt2-xl",
5 | "counterfact": false,
6 | "zsre": true,
7 | "mini": false
8 | }
--------------------------------------------------------------------------------
/memit/hparams/ROME/EleutherAI_gpt-j-6B.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 5
4 | ],
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 20,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 27,
9 | "v_weight_decay": 0.5,
10 | "clamp_norm_factor": 4,
11 | "kl_factor": 0.0625,
12 | "mom2_adjustment": true,
13 | "context_template_length_params": [[5, 10], [10, 10]],
14 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out",
15 | "layer_module_tmp": "transformer.h.{}",
16 | "mlp_module_tmp": "transformer.h.{}.mlp",
17 | "attn_module_tmp": "transformer.h.{}.attn",
18 | "ln_f_module": "transformer.ln_f",
19 | "lm_head_module": "lm_head",
20 | "mom2_dataset": "wikipedia",
21 | "mom2_n_samples": 100000,
22 | "mom2_dtype": "float32"
23 | }
--------------------------------------------------------------------------------
/memit/hparams/ROME/EleutherAI_gpt-neox-20b.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 15
4 | ],
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 20,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 43,
9 | "v_weight_decay": 0.5,
10 | "clamp_norm_factor": 4,
11 | "kl_factor": 0.0625,
12 | "mom2_adjustment": true,
13 | "context_template_length_params": [[5, 10], [10, 10]],
14 | "rewrite_module_tmp": "gpt_neox.layers.{}.mlp.dense_4h_to_h",
15 | "layer_module_tmp": "gpt_neox.layers.{}",
16 | "mlp_module_tmp": "gpt_neox.layers.{}.mlp",
17 | "attn_module_tmp": "gpt_neox.layers.{}.attention",
18 | "ln_f_module": "gpt_neox.final_layer_norm",
19 | "lm_head_module": "embed_out",
20 | "mom2_dataset": "wikipedia",
21 | "mom2_n_samples": 100000,
22 | "mom2_dtype": "float32"
23 | }
--------------------------------------------------------------------------------
/memit/hparams/ROME/gpt2-large.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 12
4 | ],
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 20,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 35,
9 | "v_weight_decay": 0.5,
10 | "clamp_norm_factor": 4,
11 | "kl_factor": 0.0625,
12 | "mom2_adjustment": true,
13 | "context_template_length_params": [[5, 10], [10, 10]],
14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
15 | "layer_module_tmp": "transformer.h.{}",
16 | "mlp_module_tmp": "transformer.h.{}.mlp",
17 | "attn_module_tmp": "transformer.h.{}.attn",
18 | "ln_f_module": "transformer.ln_f",
19 | "lm_head_module": "transformer.wte",
20 | "mom2_dataset": "wikipedia",
21 | "mom2_n_samples": 100000,
22 | "mom2_dtype": "float32"
23 | }
--------------------------------------------------------------------------------
/memit/hparams/ROME/gpt2-medium.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 8
4 | ],
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 20,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 23,
9 | "v_weight_decay": 0.5,
10 | "clamp_norm_factor": 3,
11 | "kl_factor": 0.0625,
12 | "mom2_adjustment": true,
13 | "context_template_length_params": [[5, 10], [10, 10]],
14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
15 | "layer_module_tmp": "transformer.h.{}",
16 | "mlp_module_tmp": "transformer.h.{}.mlp",
17 | "attn_module_tmp": "transformer.h.{}.attn",
18 | "ln_f_module": "transformer.ln_f",
19 | "lm_head_module": "transformer.wte",
20 | "mom2_dataset": "wikipedia",
21 | "mom2_n_samples": 100000,
22 | "mom2_dtype": "float32"
23 | }
--------------------------------------------------------------------------------
/memit/hparams/ROME/gpt2-xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "layers": [
3 | 17
4 | ],
5 | "fact_token": "subject_last",
6 | "v_num_grad_steps": 20,
7 | "v_lr": 5e-1,
8 | "v_loss_layer": 47,
9 | "v_weight_decay": 0.5,
10 | "clamp_norm_factor": 4,
11 | "kl_factor": 0.0625,
12 | "mom2_adjustment": true,
13 | "context_template_length_params": [[5, 10], [10, 10]],
14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj",
15 | "layer_module_tmp": "transformer.h.{}",
16 | "mlp_module_tmp": "transformer.h.{}.mlp",
17 | "attn_module_tmp": "transformer.h.{}.attn",
18 | "ln_f_module": "transformer.ln_f",
19 | "lm_head_module": "transformer.wte",
20 | "mom2_dataset": "wikipedia",
21 | "mom2_n_samples": 100000,
22 | "mom2_dtype": "float32"
23 | }
--------------------------------------------------------------------------------
/memit/memit/__init__.py:
--------------------------------------------------------------------------------
1 | from .memit_main import MEMITHyperParams, apply_memit_to_model
--------------------------------------------------------------------------------
/memit/memit/compute_ks.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import numpy as np
4 | import torch
5 | from transformers import AutoModelForCausalLM, AutoTokenizer
6 |
7 | from .compute_z import get_module_input_output_at_words
8 | from .memit_hparams import MEMITHyperParams
9 |
10 |
11 | def compute_ks(
12 | model: AutoModelForCausalLM,
13 | tok: AutoTokenizer,
14 | requests: Dict,
15 | hparams: MEMITHyperParams,
16 | layer: int,
17 | context_templates: List[str],
18 | ):
19 | layer_ks = get_module_input_output_at_words(
20 | model,
21 | tok,
22 | layer,
23 | context_templates=[
24 | context.format(request["prompt"])
25 | for request in requests
26 | for context_type in context_templates
27 | for context in context_type
28 | ],
29 | words=[
30 | request["subject"]
31 | for request in requests
32 | for context_type in context_templates
33 | for _ in context_type
34 | ],
35 | module_template=hparams.rewrite_module_tmp,
36 | fact_token_strategy=hparams.fact_token,
37 | )[0]
38 |
39 | context_type_lens = [0] + [len(context_type) for context_type in context_templates]
40 | context_len = sum(context_type_lens)
41 | context_type_csum = np.cumsum(context_type_lens).tolist()
42 |
43 | ans = []
44 | for i in range(0, layer_ks.size(0), context_len):
45 | tmp = []
46 | for j in range(len(context_type_csum) - 1):
47 | start, end = context_type_csum[j], context_type_csum[j + 1]
48 | tmp.append(layer_ks[i + start : i + end].mean(0))
49 | ans.append(torch.stack(tmp, 0).mean(0))
50 | return torch.stack(ans, dim=0)
51 |
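52 | # Added note (not in the original file): layer_ks above holds one row per
53 | # (request, context template) pair. The loop walks it in blocks of context_len
54 | # rows (one block per request), averages within each context-template group,
55 | # then averages the group means, so the returned tensor contains one key vector
56 | # per request, with shape (len(requests), rewrite-module input dimension).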
--------------------------------------------------------------------------------
/memit/memit/memit_hparams.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Literal
3 |
4 | from util.hparams import HyperParams
5 |
6 |
7 | @dataclass
8 | class MEMITHyperParams(HyperParams):
9 | # Method
10 | layers: List[int]
11 | layer_selection: Literal["all", "random"]
12 | fact_token: Literal[
13 | "last", "subject_first", "subject_last", "subject_first_after_last"
14 | ]
15 | v_num_grad_steps: int
16 | v_lr: float
17 | v_loss_layer: int
18 | v_weight_decay: float
19 | clamp_norm_factor: float
20 | kl_factor: float
21 | mom2_adjustment: bool
22 | mom2_update_weight: float
23 |
24 | # Module templates
25 | rewrite_module_tmp: str
26 | layer_module_tmp: str
27 | mlp_module_tmp: str
28 | attn_module_tmp: str
29 | ln_f_module: str
30 | lm_head_module: str
31 |
32 | # Statistics
33 | mom2_dataset: str
34 | mom2_n_samples: int
35 | mom2_dtype: str
36 |
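37 | # Illustrative sketch (not part of the original file): the JSON files under
38 | # hparams/MEMIT/ mirror these fields one-to-one, so (assuming the base
39 | # HyperParams class adds no extra required fields) a config can be loaded by
40 | # unpacking the parsed JSON into the dataclass. The path below is an example:
41 | #
42 | #     import json
43 | #     with open("hparams/MEMIT/gpt2-xl.json") as f:
44 | #         hparams = MEMITHyperParams(**json.load(f))
45 | #     print(hparams.layers, hparams.clamp_norm_factor)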
--------------------------------------------------------------------------------
/memit/notebooks/vis/visualize_multi_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "d465f696",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "9bdfca4c",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "%matplotlib inline\n",
22 | "%config InlineBackend.figure_format = 'retina'\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "from experiments.summarize import main as summarize_main\n",
25 | "from pathlib import Path\n",
26 | "import math"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "451eb471",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "RESULTS_DIR = Path(\"results/iclr\")\n",
37 | "DATA = {}\n",
38 | "KEYS = None\n",
39 | "for method_dir in RESULTS_DIR.iterdir():\n",
40 | " method_name = str(method_dir).split(\"/\")[-1]\n",
41 | " print(method_name)\n",
42 | " n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n",
43 | " for n_edit_folder in n_edit_folders:\n",
44 | " n_edits = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0]\n",
45 | " try:\n",
46 | " res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n",
47 | "\n",
48 | " DATA[method_name] = DATA.get(method_name, {})\n",
49 | " DATA[method_name][n_edits] = res\n",
50 | " if KEYS is None:\n",
51 | " KEYS = list(res.keys())\n",
52 | " except:\n",
53 | " pass\n",
54 | "\n",
55 | "print({k: list(v.keys()) for k, v in DATA.items()})"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "id": "7b9f0860",
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "plt.rcParams[\"figure.dpi\"] = 200\n",
66 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
67 | "\n",
68 | "SMALL_SIZE = 14\n",
69 | "MEDIUM_SIZE = 15\n",
70 | "BIGGER_SIZE = 16\n",
71 | "\n",
72 | "plt.rc(\"font\", size=SMALL_SIZE) # controls default text sizes\n",
73 | "plt.rc(\"axes\", titlesize=BIGGER_SIZE) # fontsize of the axes title\n",
74 | "plt.rc(\"axes\", labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n",
75 | "plt.rc(\"xtick\", labelsize=SMALL_SIZE) # fontsize of the tick labels\n",
76 | "plt.rc(\"ytick\", labelsize=SMALL_SIZE) # fontsize of the tick labels\n",
77 | "plt.rc(\"legend\", fontsize=SMALL_SIZE) # legend fontsize\n",
78 | "plt.rc(\"figure\", titlesize=BIGGER_SIZE) # fontsize of the figure title"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "id": "d8b41acc",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "TITLES = {\n",
89 | " \"post_score\": \"Score (S)\",\n",
90 | " \"post_rewrite_success\": \"Efficacy Succ. (ES)\",\n",
91 | " \"post_paraphrase_success\": \"Generalization Succ. (PS)\",\n",
92 | " \"post_neighborhood_success\": \"Specificity Succ. (NS)\",\n",
93 | " \"post_rewrite_acc\": \"Efficacy Acc (EA)\",\n",
94 | " \"post_paraphrase_acc\": \"Generalization Acc. (PA)\",\n",
95 | " \"post_neighborhood_acc\": \"Specificity Acc. (NA)\",\n",
96 | " \"post_reference_score\": \"Consistency (RS)\",\n",
97 | "}\n",
98 | "\n",
99 | "SHOW_KEYS = list(TITLES.keys())"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "id": "a1d443f7",
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "SHOW_KEYS = KEYS\n",
110 | "SHOW_KEYS.pop(SHOW_KEYS.index(\"run_dir\"))\n",
111 | "TITLES = {k: k for k in SHOW_KEYS}"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "id": "49efeea0",
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "w = 4\n",
122 | "h = math.ceil(len(KEYS) / w)\n",
123 | "plt.figure(figsize=(w * 3.5, h * 2.5))\n",
124 | "\n",
125 | "assert all(k in KEYS for k in SHOW_KEYS)\n",
126 | "for i, key in enumerate(SHOW_KEYS):\n",
127 | " plt.subplot(h, w, i + 1)\n",
128 | " for method, results in sorted([(k, v) for k, v in DATA.items() if \"_fix\" not in k]):\n",
129 | " try:\n",
130 | " n_edits = list(map(int, results.keys()))\n",
131 | " values = [\n",
132 | " f[0] if (type(f := results[str(n)][key]) is tuple) else f\n",
133 | " for n in n_edits\n",
134 | " ]\n",
135 | " plt.plot(n_edits, values, marker=\"o\", markersize=4, label=method)\n",
136 | " plt.xlabel(\"# Edits\")\n",
137 | " # plt.ylabel(\"metric value\")\n",
138 | " plt.title(TITLES[key])\n",
139 | " plt.legend()\n",
140 | " except:\n",
141 | " pass\n",
142 | "plt.tight_layout()\n",
143 | "plt.savefig(\"tmp.pdf\")\n",
144 | "plt.show()"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "ae8e7ea4",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": []
154 | }
155 | ],
156 | "metadata": {
157 | "accelerator": "GPU",
158 | "kernelspec": {
159 | "display_name": "Python 3 (ipykernel)",
160 | "language": "python",
161 | "name": "python3"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.9.7"
174 | },
175 | "vscode": {
176 | "interpreter": {
177 | "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
178 | }
179 | }
180 | },
181 | "nbformat": 4,
182 | "nbformat_minor": 5
183 | }
184 |
--------------------------------------------------------------------------------
/memit/notebooks/vis/visualize_sweep_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "d465f696",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "9bdfca4c",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# enable high-resolution figure\n",
22 | "%matplotlib inline\n",
23 | "%config InlineBackend.figure_format = 'retina'\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "from experiments.summarize import main as summarize_main\n",
26 | "from pathlib import Path\n",
27 | "import math"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "id": "451eb471",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "RESULTS_DIR = Path(\"results/sweeps\")\n",
38 | "DATA = {}\n",
39 | "KEYS = None\n",
40 | "for method_dir in RESULTS_DIR.iterdir():\n",
41 | " method_name = str(method_dir).split(\"/\")[-1]\n",
42 | " print(method_name)\n",
43 | " n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n",
44 | " for n_edit_folder in n_edit_folders:\n",
45 | " n_edits = int(str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0])\n",
46 | " setting_id = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[-1]\n",
47 | " try:\n",
48 | " res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n",
49 | "\n",
50 | " DATA[method_name] = DATA.get(method_name, {})\n",
51 | " DATA[method_name][n_edits] = DATA[method_name].get(n_edits, {})\n",
52 | " DATA[method_name][n_edits][setting_id] = res\n",
53 | "\n",
54 | " if KEYS is None:\n",
55 | " KEYS = list(res.keys())\n",
56 | " except:\n",
57 | " pass\n",
58 | "\n",
59 | "print({k: list(v.keys()) for k, v in DATA.items()})"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "49efeea0",
66 | "metadata": {
67 | "scrolled": false
68 | },
69 | "outputs": [],
70 | "source": [
71 | "for method, all_n_edits in sorted([(k, v) for k, v in DATA.items()]):\n",
72 | " for n_edits, results in sorted([(k, v) for k, v in all_n_edits.items()]):\n",
73 | " w = 4\n",
74 | " h = math.ceil(len(KEYS) / w)\n",
75 | " plt.figure(figsize=(w * 3.5, h * 2.5))\n",
76 | " if \"run_dir\" in KEYS:\n",
77 | " KEYS.pop(KEYS.index(\"run_dir\"))\n",
78 | " for i, key in enumerate(KEYS):\n",
79 | " plt.subplot(w, h, i + 1)\n",
80 | "\n",
81 | " try:\n",
82 | " setting_ids = list(map(int, results.keys()))\n",
83 | " values = [\n",
84 | " f[0] if (type(f := results[str(n)][key]) is tuple) else f\n",
85 | " for n in setting_ids\n",
86 | " ]\n",
87 | " plt.plot(setting_ids, values, marker=\"o\", markersize=4, label=method)\n",
88 | " plt.xlabel(\"setting_id\")\n",
89 | " plt.ylabel(\"metric value\")\n",
90 | " plt.title(f\"{n_edits} edits: {key}\")\n",
91 | " plt.legend()\n",
92 | " except:\n",
93 | " pass\n",
94 | "\n",
95 | " plt.tight_layout()\n",
96 | " plt.show()"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "ae8e7ea4",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": []
106 | }
107 | ],
108 | "metadata": {
109 | "accelerator": "GPU",
110 | "kernelspec": {
111 | "display_name": "Python 3 (ipykernel)",
112 | "language": "python",
113 | "name": "python3"
114 | },
115 | "language_info": {
116 | "codemirror_mode": {
117 | "name": "ipython",
118 | "version": 3
119 | },
120 | "file_extension": ".py",
121 | "mimetype": "text/x-python",
122 | "name": "python",
123 | "nbconvert_exporter": "python",
124 | "pygments_lexer": "ipython3",
125 | "version": "3.9.7"
126 | },
127 | "vscode": {
128 | "interpreter": {
129 | "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
130 | }
131 | }
132 | },
133 | "nbformat": 4,
134 | "nbformat_minor": 5
135 | }
136 |
--------------------------------------------------------------------------------
/memit/rome/README.md:
--------------------------------------------------------------------------------
1 | # ROME
2 | This package provides a self-contained implementation of Rank-One Model Editing (ROME).
3 |
4 | Recall that ROME's update consists of: $u$ selection, $v_*$ optimization, and $v$ insertion.
5 | * [`compute_u.py`](compute_u.py): Chooses a $u$ vector.
6 | * [`compute_v.py`](compute_v.py): Chooses a $v_*$ via optimization, then computes $v$.
7 | * [`rome_main.py`](rome_main.py): Instruments main logic.
8 | * [`rome_hparams.py`](rome_hparams.py): Interface for specifying hyperparameters. Inherits from the base [`hparams.py`](../util/hparams.py) module.
9 |
10 | For estimating second moment statistics of keys ($C = KK^T$), we provide the `layer_stats` module. See the [main README](../README.md) for usage instructions.
11 | * [`layer_stats.py`](layer_stats.py): Logic for retrieving and caching key statistics.
12 | * [`tok_dataset.py`](tok_dataset.py): Utilities for creating a dataset of tokens.
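13 | 
14 | As a rough illustration (a minimal sketch, not the repository's actual implementation; see `rome_main.py` for that), the chosen $u$ and computed $v$ combine into a rank-one update of the rewritten MLP projection weight:
15 | 
16 | ```python
17 | import torch
18 | 
19 | def rank_one_update(W: torch.Tensor, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
20 |     """Add the outer product v u^T to W: a rank-one change nudging key direction u toward value v."""
21 |     return W + torch.outer(v, u)  # W: (out, in), v: (out,), u: (in,)
22 | ```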
--------------------------------------------------------------------------------
/memit/rome/__init__.py:
--------------------------------------------------------------------------------
1 | from .rome_main import ROMEHyperParams, apply_rome_to_model, execute_rome
2 |
--------------------------------------------------------------------------------
/memit/rome/compute_u.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Dict, List
4 |
5 | import torch
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 |
8 | from rome import repr_tools
9 | from util.globals import *
10 |
11 | from .layer_stats import layer_stats
12 | from .rome_hparams import ROMEHyperParams
13 |
14 | # Cache variables
15 | inv_mom2_cache = {}
16 |
17 |
18 | def get_inv_cov(
19 | model: AutoModelForCausalLM,
20 | tok: AutoTokenizer,
21 | layer_name: str,
22 | mom2_dataset: str,
23 | mom2_n_samples: str,
24 | mom2_dtype: str,
25 | ) -> torch.Tensor:
26 | """
27 | Retrieves covariance statistics, then computes the algebraic inverse.
28 | Caches result for future use.
29 | """
30 |
31 | global inv_mom2_cache
32 |
33 | model_name = model.config._name_or_path.replace("/", "_")
34 | key = (model_name, layer_name)
35 |
36 | if key not in inv_mom2_cache:
37 | print(
38 | f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
39 | f"The result will be cached to avoid repetitive computation."
40 | )
41 | stat = layer_stats(
42 | model,
43 | tok,
44 | layer_name,
45 | STATS_DIR,
46 | mom2_dataset,
47 | to_collect=["mom2"],
48 | sample_size=mom2_n_samples,
49 | precision=mom2_dtype,
50 | )
51 | inv_mom2_cache[key] = torch.inverse(
52 | stat.mom2.moment().to("cuda")
53 | ).float() # Cast back to float32
54 |
55 | return inv_mom2_cache[key]
56 |
57 |
58 | def compute_u(
59 | model: AutoModelForCausalLM,
60 | tok: AutoTokenizer,
61 | request: Dict,
62 | hparams: ROMEHyperParams,
63 | layer: int,
64 | context_templates: List[str],
65 | ) -> torch.Tensor:
66 | """
67 | Computes the right vector used in constructing the rank-1 update matrix.
68 | """
69 |
70 | print("Computing left vector (u)...")
71 |
72 | # Compute projection token
73 | word_repr_args = dict(
74 | model=model,
75 | tok=tok,
76 | layer=layer,
77 | module_template=hparams.rewrite_module_tmp,
78 | track="in",
79 | )
80 | if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0:
81 | word = request["subject"]
82 | print(f"Selected u projection object {word}")
83 | cur_repr = repr_tools.get_reprs_at_word_tokens(
84 | context_templates=[
85 | templ.format(request["prompt"]) for templ in context_templates
86 | ],
87 | words=[word for _ in range(len(context_templates))],
88 | subtoken=hparams.fact_token[len("subject_") :],
89 | **word_repr_args,
90 | ).mean(0)
91 | elif hparams.fact_token == "last":
92 | # Heuristic to choose last word. Not a huge deal if there's a minor
93 | # edge case (e.g. multi-token word) because the function below will
94 | # take the last token.
95 | cur_repr = repr_tools.get_reprs_at_idxs(
96 | contexts=[
97 | templ.format(request["prompt"].format(request["subject"]))
98 | for templ in context_templates
99 | ],
100 | idxs=[[-1] for _ in range(len(context_templates))],
101 | **word_repr_args,
102 | ).mean(0)
103 | print("Selected u projection token with last token")
104 | else:
105 | raise ValueError(f"fact_token={hparams.fact_token} not recognized")
106 |
107 | # Apply inverse second moment adjustment
108 | u = cur_repr
109 | if hparams.mom2_adjustment:
110 | u = get_inv_cov(
111 | model,
112 | tok,
113 | hparams.rewrite_module_tmp.format(layer),
114 | hparams.mom2_dataset,
115 | hparams.mom2_n_samples,
116 | hparams.mom2_dtype,
117 | ) @ u.unsqueeze(1)
118 | u = u.squeeze()
119 |
120 | return u / u.norm()
121 |
--------------------------------------------------------------------------------
/memit/rome/layer_stats.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import torch
5 | from datasets import load_dataset
6 | from tqdm.auto import tqdm
7 | from transformers import AutoModelForCausalLM, AutoTokenizer
8 |
9 | from util.globals import *
10 | from util.nethook import Trace, set_requires_grad
11 | from util.runningstats import CombinedStat, Mean, NormMean, SecondMoment, tally
12 |
13 | from .tok_dataset import (
14 | TokenizedDataset,
15 | dict_to_,
16 | flatten_masked_batch,
17 | length_collation,
18 | )
19 |
20 | STAT_TYPES = {
21 | "mom2": SecondMoment,
22 | "mean": Mean,
23 | "norm_mean": NormMean,
24 | }
25 |
26 |
27 | def main():
28 | """
29 | Command-line utility to precompute cached stats.
30 | """
31 | import argparse
32 |
33 | parser = argparse.ArgumentParser(description="ROME Statistics Collector")
34 |
35 | def aa(*args, **kwargs):
36 | parser.add_argument(*args, **kwargs)
37 |
38 | aa("--model_name", default="gpt2-xl", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"])
39 | aa("--dataset", default="wikipedia", choices=["wikitext", "wikipedia"])
40 | aa("--layers", default=[17], type=lambda x: list(map(int, x.split(","))))
41 | aa("--to_collect", default=["mom2"], type=lambda x: x.split(","))
42 | aa("--sample_size", default=100000, type=lambda x: None if x == "all" else int(x))
43 | aa("--batch_tokens", default=None, type=lambda x: None if x == "any" else int(x))
44 | aa("--precision", default="float32", choices=["float64", "float32", "float16"])
45 | aa("--stats_dir", default=STATS_DIR)
46 | aa("--download", default=1, type=int, choices=[0, 1])
47 | args = parser.parse_args()
48 |
49 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
50 | model = AutoModelForCausalLM.from_pretrained(args.model_name).eval().cuda()
51 | set_requires_grad(False, model)
52 |
53 | for layer_num in args.layers:
54 | print(
55 | f"Computing stats for layer {layer_num} of {args.model_name} "
56 | f'over {args.sample_size or "all"} samples of {args.dataset}. '
57 | "Note, the statistics are collected over the inputs to the second MLP layer, "
58 | "or equivalently the outputs of the first MLP layer."
59 | )
60 | proj_layer_name = "c_proj" if "gpt2" in args.model_name else "fc_out"
61 | layer_name = f"transformer.h.{layer_num}.mlp.{proj_layer_name}"
62 |
63 | layer_stats(
64 | model,
65 | tokenizer,
66 | layer_name,
67 | args.stats_dir,
68 | args.dataset,
69 | args.to_collect,
70 | sample_size=args.sample_size,
71 | precision=args.precision,
72 | batch_tokens=args.batch_tokens,
73 | download=args.download,
74 | )
75 |
76 |
77 | def layer_stats(
78 | model,
79 | tokenizer,
80 | layer_name,
81 | stats_dir,
82 | ds_name,
83 | to_collect,
84 | model_name=None,
85 | sample_size=None,
86 | precision=None,
87 | batch_tokens=None,
88 | download=True,
89 | progress=tqdm,
90 | force_recompute=False,
91 | ):
92 | """
93 | Function to load or compute cached stats.
94 | """
95 |
96 | def get_ds():
97 | raw_ds = load_dataset(
98 | ds_name,
99 | dict(wikitext="wikitext-103-raw-v1", wikipedia="20200501.en")[ds_name],
100 | )
101 | if 'olmo' in model.name_or_path.lower():
102 | maxlen = model.config.max_sequence_length
103 | else:
104 | maxlen = model.config.max_position_embeddings
105 | if batch_tokens is not None and batch_tokens < maxlen:
106 | maxlen = batch_tokens
107 | return TokenizedDataset(raw_ds["train"], tokenizer, maxlen=maxlen)
108 |
109 | # Continue with computation of statistics
110 | batch_size = 100 # Examine this many dataset texts at once
111 | if 'olmo' in model.name_or_path.lower():
112 | npos = model.config.max_sequence_length
113 | else:
114 | npos = model.config.max_position_embeddings
115 |
116 | if batch_tokens is None:
117 | batch_tokens = npos * 3 # Sort and divide into batches with this many tokens
118 | if precision is None:
119 | precision = "float64"
120 | dtype = getattr(torch, precision)
121 | size_suffix = "" if sample_size is None else f"_{sample_size}"
122 | if batch_tokens < npos:
123 | size_suffix = f"_t{batch_tokens}" + size_suffix
124 | if model_name is None:
125 | model_name = model.config._name_or_path.replace("/", "_")
126 |
127 | stats_dir = Path(stats_dir)
128 | file_extension = f"{model_name}/{ds_name}_stats/{layer_name}_{precision}_{'-'.join(sorted(to_collect))}{size_suffix}.npz"
129 | filename = stats_dir / file_extension
130 |
131 | if not filename.exists() and download:
132 | remote_url = f"{REMOTE_ROOT_URL}/data/stats/{file_extension}"
133 | try:
134 | print(f"Attempting to download {file_extension} from {remote_url}.")
135 | (stats_dir / "/".join(file_extension.split("/")[:-1])).mkdir(
136 | exist_ok=True, parents=True
137 | )
138 | torch.hub.download_url_to_file(remote_url, filename)
139 | print("Successfully downloaded.")
140 | except Exception as e:
141 | print(f"Unable to download due to {e}. Computing locally....")
142 |
143 | ds = get_ds() if not filename.exists() else None
144 |
145 | if progress is None:
146 | progress = lambda x: x
147 |
148 | stat = CombinedStat(**{k: STAT_TYPES[k]() for k in to_collect})
149 | loader = tally(
150 | stat,
151 | ds,
152 | cache=(filename if not force_recompute else None),
153 | sample_size=sample_size,
154 | batch_size=batch_size,
155 | collate_fn=length_collation(batch_tokens),
156 | pin_memory=True,
157 | random_sample=1,
158 | num_workers=2,
159 | )
160 | batch_count = -(-(sample_size or len(ds)) // batch_size)
161 | with torch.no_grad():
162 | for batch_group in progress(loader, total=batch_count):
163 | for batch in batch_group:
164 | batch = dict_to_(batch, "cuda")
165 | with Trace(
166 | model, layer_name, retain_input=True, retain_output=False, stop=True
167 | ) as tr:
168 | model(**batch)
169 | feats = flatten_masked_batch(tr.input, batch["attention_mask"])
170 | # feats = flatten_masked_batch(tr.output, batch["attention_mask"])
171 | feats = feats.to(dtype=dtype)
172 | stat.add(feats)
173 | return stat
174 |
175 |
176 | if __name__ == "__main__":
177 | main()
178 |
--------------------------------------------------------------------------------
/memit/rome/repr_tools.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains utilities for extracting token representations and indices
3 | from string templates. Used in computing the left and right vectors for ROME.
4 | """
5 |
6 | from copy import deepcopy
7 | from typing import List
8 |
9 | import torch
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 | from util import nethook
13 |
14 |
15 | def get_reprs_at_word_tokens(
16 | model: AutoModelForCausalLM,
17 | tok: AutoTokenizer,
18 | context_templates: List[str],
19 | words: List[str],
20 | layer: int,
21 | module_template: str,
22 | subtoken: str,
23 | track: str = "in",
24 | ) -> torch.Tensor:
25 | """
26 | Retrieves the last token representation of `word` in `context_template`
27 | when `word` is substituted into `context_template`. See `get_words_idxs_in_templates`
28 | for more details.
29 | """
30 |
31 | idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken)
32 | return get_reprs_at_idxs(
33 | model,
34 | tok,
35 | [context_templates[i].format(words[i]) for i in range(len(words))],
36 | idxs,
37 | layer,
38 | module_template,
39 | track,
40 | )
41 |
42 |
43 | def get_words_idxs_in_templates(
44 | tok: AutoTokenizer, context_templates: str, words: str, subtoken: str
45 | ) -> int:
46 | """
47 | Given list of template strings, each with *one* format specifier
48 | (e.g. "{} plays basketball"), and words to be substituted into the
49 | template, computes the post-tokenization index of their last tokens.
50 | """
51 |
52 | assert all(
53 | tmp.count("{}") == 1 for tmp in context_templates
54 | ), "We currently do not support multiple fill-ins for context"
55 |
56 | # Compute prefixes and suffixes of the tokenized context
57 | fill_idxs = [tmp.index("{}") for tmp in context_templates]
58 | prefixes, suffixes = [
59 | tmp[: fill_idxs[i]] for i, tmp in enumerate(context_templates)
60 | ], [tmp[fill_idxs[i] + 2 :] for i, tmp in enumerate(context_templates)]
61 | words = deepcopy(words)
62 |
63 | # Pre-process tokens
64 | for i, prefix in enumerate(prefixes):
65 | if len(prefix) > 0:
66 | assert prefix[-1] == " "
67 | prefix = prefix[:-1]
68 |
69 | prefixes[i] = prefix
70 | words[i] = f" {words[i].strip()}"
71 |
72 | # Tokenize to determine lengths
73 | assert len(prefixes) == len(words) == len(suffixes)
74 | n = len(prefixes)
75 | batch_tok = tok([*prefixes, *words, *suffixes])
76 | prefixes_tok, words_tok, suffixes_tok = [
77 | batch_tok[i : i + n] for i in range(0, n * 3, n)
78 | ]
79 | prefixes_len, words_len, suffixes_len = [
80 | [len(el) for el in tok_list]
81 | for tok_list in [prefixes_tok, words_tok, suffixes_tok]
82 | ]
83 |
84 | # Compute indices of last tokens
85 | if subtoken == "last" or subtoken == "first_after_last":
86 | return [
87 | [
88 | prefixes_len[i]
89 | + words_len[i]
90 | - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0)
91 | ]
92 | # If suffix is empty, there is no "first token after the last".
93 | # So, just return the last token of the word.
94 | for i in range(n)
95 | ]
96 | elif subtoken == "first":
97 | return [[prefixes_len[i]] for i in range(n)]
98 | else:
99 | raise ValueError(f"Unknown subtoken type: {subtoken}")
100 |
101 |
102 | def get_reprs_at_idxs(
103 | model: AutoModelForCausalLM,
104 | tok: AutoTokenizer,
105 | contexts: List[str],
106 | idxs: List[List[int]],
107 | layer: int,
108 | module_template: str,
109 | track: str = "in",
110 | ) -> torch.Tensor:
111 | """
112 | Runs input through model and returns averaged representations of the tokens
113 | at each index in `idxs`.
114 | """
115 |
116 | def _batch(n):
117 | for i in range(0, len(contexts), n):
118 | yield contexts[i : i + n], idxs[i : i + n]
119 |
120 | assert track in {"in", "out", "both"}
121 | both = track == "both"
122 | tin, tout = (
123 | (track == "in" or both),
124 | (track == "out" or both),
125 | )
126 | module_name = module_template.format(layer)
127 | to_return = {"in": [], "out": []}
128 |
129 | def _process(cur_repr, batch_idxs, key):
130 | nonlocal to_return
131 | cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr
132 | for i, idx_list in enumerate(batch_idxs):
133 | to_return[key].append(cur_repr[i][idx_list].mean(0))
134 |
135 | for batch_contexts, batch_idxs in _batch(n=128):
136 | contexts_tok = tok(batch_contexts, padding=True, return_token_type_ids = False, return_tensors="pt").to(
137 | next(model.parameters()).device
138 | )
139 |
140 | with torch.no_grad():
141 | with nethook.Trace(
142 | module=model,
143 | layer=module_name,
144 | retain_input=tin,
145 | retain_output=tout,
146 | ) as tr:
147 | model(**contexts_tok)
148 |
149 | if tin:
150 | _process(tr.input, batch_idxs, "in")
151 | if tout:
152 | _process(tr.output, batch_idxs, "out")
153 |
154 | to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}
155 |
156 | if len(to_return) == 1:
157 | return to_return["in"] if tin else to_return["out"]
158 | else:
159 | return to_return["in"], to_return["out"]
160 |
--------------------------------------------------------------------------------
/memit/rome/rome_hparams.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | from util.hparams import HyperParams
5 |
6 |
7 | @dataclass
8 | class ROMEHyperParams(HyperParams):
9 | # Method
10 | layers: List[int]
11 | fact_token: str
12 | v_num_grad_steps: int
13 | v_lr: float
14 | v_loss_layer: int
15 | v_weight_decay: float
16 | clamp_norm_factor: float
17 | kl_factor: float
18 | mom2_adjustment: bool
19 | context_template_length_params: List[List[int]]
20 |
21 | # Module templates
22 | rewrite_module_tmp: str
23 | layer_module_tmp: str
24 | mlp_module_tmp: str
25 | attn_module_tmp: str
26 | ln_f_module: str
27 | lm_head_module: str
28 |
29 | # Statistics
30 | mom2_dataset: str
31 | mom2_n_samples: int
32 | mom2_dtype: str
33 |
--------------------------------------------------------------------------------
/memit/rome/tok_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.utils.rnn import pad_sequence
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class TokenizedDataset(Dataset):
7 | """
8 | Converts a dataset of text samples into a dataset of token sequences,
9 | as converted by a supplied tokenizer. The tokens come along with position
10 | ids and attention masks, so they can be supplied directly to the model.
11 | """
12 |
13 | def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"):
14 | self.text_dataset = text_dataset
15 | self.field = field
16 | self.tokenizer = tokenizer
17 | self.maxlen = maxlen
18 | if hasattr(text_dataset, "info"):
19 | self.info = text_dataset.info
20 |
21 | def __len__(self):
22 | return len(self.text_dataset)
23 |
24 | def __getitem__(self, i):
25 | text = self.text_dataset[i]
26 | if self.field is not None:
27 | text = text[self.field]
28 | token_list = self.tokenizer.encode(
29 | text, truncation=True, max_length=self.maxlen
30 | )
31 | position_ids = list(range(len(token_list)))
32 | attention_mask = [1] * len(token_list)
33 | return dict(
34 | input_ids=torch.tensor(token_list),
35 | #position_ids=torch.tensor(position_ids),
36 | attention_mask=torch.tensor(attention_mask),
37 | )
38 |
39 |
40 | def dict_to_(data, device):
41 | """
42 | Moves a dictionary of tensors to the specified device.
43 | """
44 | for k in data:
45 | data[k] = data[k].to(device)
46 | return data
47 |
48 |
49 | def length_collation(token_size):
50 | """
51 | Sorts a batch of sequences and breaks it up into subbatches
52 | of same-sized sequences, padding as needed. Each batch
53 | has no more than token_size total tokens (or a single
54 | sequence, if the sequence happens to be larger).
55 | """
56 |
57 | def collate_fn(items):
58 | items = sorted(items, key=lambda x: -len(x["input_ids"]))
59 | batches = []
60 | batch = []
61 | batch_width = 0
62 | for item in items:
63 | item_width = len(item["input_ids"])
64 | if item_width == 0:
65 | break
66 | if batch_width * (len(batch) + 1) > token_size:
67 | batches.append(make_padded_batch(batch))
68 | batch = []
69 | batch_width = 0
70 | if not batch:
71 | batch_width = item_width
72 | batch.append(item)
73 | if len(batch):
74 | batches.append(make_padded_batch(batch))
75 | return batches
76 |
77 | return collate_fn
78 |
79 |
80 | def make_padded_batch(items):
81 | """
82 | Pads sequences in a batch, so they are all the same length as the longest.
83 | """
84 | max_len = max(len(d["input_ids"]) for d in items)
85 | if max_len == 0:
86 | return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]}
87 | return {
88 | k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True)
89 | for k, v in items[0].items()
90 | }
91 |
92 |
93 | def flatten_masked_batch(data, mask):
94 | """
95 | Flattens feature data, ignoring items that are masked out of attention.
96 | """
97 | flat_data = data.view(-1, data.size(-1))
98 | attended_tokens = mask.view(-1).nonzero()[:, 0]
99 | return flat_data[attended_tokens]
100 |
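101 | # Added worked example for length_collation (illustrative, not in the original
102 | # file): with token_size=8 and sequences of lengths 5, 4, 3 and 2, collate_fn
103 | # sorts them by descending length and greedily packs sub-batches under the token
104 | # budget, yielding groups of lengths [[5], [4, 3], [2]]; make_padded_batch then
105 | # pads each sub-batch to its longest member.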
--------------------------------------------------------------------------------
/memit/scaling_curves.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Constants
5 | DIR="scaling"
6 | MIN_NUM_RECORDS="10000"
7 | GEN_TEST_INTERV="10"
8 | N_EDITS="1,56,100,316,562,1000,1778,3162,5623,10000"
9 |
10 | # Run configurations
11 | MODEL_NAME="EleutherAI/gpt-j-6B"
12 | ALG_NAMES=("FT" "MEND" "ROME" "MEMIT")
13 | HPARAMS_FNAMES=("EleutherAI_gpt-j-6B_wd.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json")
14 |
15 | # Execute
16 | for i in ${!ALG_NAMES[@]}
17 | do
18 | alg_name=${ALG_NAMES[$i]}
19 | hparams_fname=${HPARAMS_FNAMES[$i]}
20 |
21 | echo "Running evals for $alg_name..."
22 | sweep_dir="$DIR/$alg_name"
23 |
24 | if [ -d "results/$sweep_dir" ]; then
25 | echo "Note: results/$sweep_dir already exists! Continuing from previous run..."
26 | fi
27 |
28 | echo "Dumping results at results/$sweep_dir"
29 | mkdir -p results/$sweep_dir
30 | echo "{}" > results/$sweep_dir/config.json
31 |
32 | python3 -m experiments.sweep --alg_name=$alg_name --model_name=$MODEL_NAME --hparams_fname=$hparams_fname --sweep_dir=$sweep_dir --min_num_records=$MIN_NUM_RECORDS --num_edits=$N_EDITS --generation_test_interval=$GEN_TEST_INTERV --use_cache
33 | done
34 |
35 | exit 0
36 |
--------------------------------------------------------------------------------
/memit/scripts/causal_trace.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Start from parent directory of script
4 | SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
5 | cd "$(dirname ${SCRIPT_DIR})"
6 |
7 | python -m experiments.causal_trace --model_name "EleutherAI/gpt-j-6B" --noise_level 0.025
8 | python -m experiments.causal_trace --model_name "gpt2-xl" --noise_level 0.1
9 | python -m experiments.causal_trace --model_name "EleutherAI/gpt-neox-20b" --noise_level 0.03
10 |
--------------------------------------------------------------------------------
/memit/scripts/colab_reqs/additional.txt:
--------------------------------------------------------------------------------
1 | allennlp==2.9.0
2 | einops==0.4.0
3 | higher==0.2.1
4 | hydra-core==1.1.1
--------------------------------------------------------------------------------
/memit/scripts/colab_reqs/rome.txt:
--------------------------------------------------------------------------------
1 | datasets==1.18.3
2 | python-dotenv==0.19.2
3 | git+https://github.com/kmeng01/transformers-colab@allennlp-compat
4 |
--------------------------------------------------------------------------------
/memit/scripts/collect_layer_stats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Start from parent directory of script
4 | SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
5 | cd "$(dirname ${SCRIPT_DIR})"
6 |
7 | run_gpu_0() {
8 | CUDA_VISIBLE_DEVICES=0 python -m rome.layer_stats --layers=$(seq -s, 0 1 27) --sample_size 100000 --model_name=EleutherAI/gpt-j-6B
9 | }
10 |
11 | run_gpu_1() {
12 | CUDA_VISIBLE_DEVICES=1 python -m rome.layer_stats --layers=$(seq -s, 0 1 27) --sample_size 100000 --model_name=EleutherAI/gpt-j-6B
13 | }
14 |
15 | # run_gpu_0 &>stats0.out&
16 | run_gpu_1 &>stats1.out&
17 |
--------------------------------------------------------------------------------
/memit/scripts/ipynb_drop_output.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | Suppress output and prompt numbers in git version control.
5 |
6 | This script will tell git to ignore prompt numbers and cell output
7 | when looking at ipynb files UNLESS their metadata contains:
8 |
9 | "git": {
10 | "keep_outputs": true
11 | },
12 |
13 | The notebooks themselves are not changed.
14 |
15 | See also this blogpost: http://pascalbugnion.net/blog/ipython-notebooks-and-git.html.
16 |
17 | Usage instructions
18 | ==================
19 |
20 | 1. Put this script in a directory that is on the system's path.
21 | For future reference, I will assume you saved it in
22 | `~/scripts/ipynb_drop_output`.
23 | 2. Make sure it is executable by typing the command
24 | `chmod +x ~/scripts/ipynb_drop_output`.
25 | 3. Register a filter for ipython notebooks by
26 | putting the following line in `~/.config/git/attributes`:
27 | `*.ipynb filter=clean_ipynb`
28 | 4. Connect this script to the filter by running the following
29 | git commands:
30 |
31 | git config --global filter.clean_ipynb.clean ipynb_drop_output
32 | git config --global filter.clean_ipynb.smudge cat
33 |
34 | To tell git NOT to ignore the output and prompts for a notebook,
35 | open the notebook's metadata (Edit > Edit Notebook Metadata). A
36 | panel should open containing the lines:
37 |
38 | {
39 | "name" : "",
40 | "signature" : "some very long hash"
41 | }
42 |
43 | Add an extra line so that the metadata now looks like:
44 |
45 | {
46 | "name" : "",
47 | "signature" : "don't change the hash, but add a comma at the end of the line",
48 | "git" : { "keep_outputs" : true }
49 | }
50 |
51 | You may need to "touch" the notebooks for git to actually register a change, if
52 | your notebooks are already under version control.
53 |
54 | Notes
55 | =====
56 |
57 |
58 | This script is inspired by http://stackoverflow.com/a/20844506/827862, but
59 | lets the user specify whether the output of a notebook should be kept
60 | in the notebook's metadata, and works for IPython v3.0.
61 | """
62 |
63 | import json
64 | import sys
65 |
66 | nb = sys.stdin.read()
67 |
68 | json_in = json.loads(nb)
69 | nb_metadata = json_in["metadata"]
70 | keep_output = False
71 | if "git" in nb_metadata:
72 | if "keep_outputs" in nb_metadata["git"] and nb_metadata["git"]["keep_outputs"]:
73 | keep_output = True
74 | if keep_output:
75 | sys.stdout.write(nb)
76 | exit()
77 |
78 |
79 | ipy_version = int(json_in["nbformat"]) - 1 # nbformat is 1 more than actual version.
80 |
81 |
82 | def strip_output_from_cell(cell):
83 | if "outputs" in cell:
84 | cell["outputs"] = []
85 | if "prompt_number" in cell:
86 | del cell["prompt_number"]
87 | if "execution_count" in cell:
88 | cell["execution_count"] = None
89 |
90 |
91 | if ipy_version == 2:
92 | for sheet in json_in["worksheets"]:
93 | for cell in sheet["cells"]:
94 | strip_output_from_cell(cell)
95 | else:
96 | for cell in json_in["cells"]:
97 | strip_output_from_cell(cell)
98 |
99 | json.dump(
100 | json_in,
101 | sys.stdout,
102 | sort_keys=True,
103 | indent=1,
104 | separators=(",", ": "),
105 | ensure_ascii=False,
106 | )
107 | # https://stackoverflow.com/questions/729692/why-should-text-files-end-with-a-newline
108 | sys.stdout.write("\n")
109 |
--------------------------------------------------------------------------------
/memit/scripts/memit.yml:
--------------------------------------------------------------------------------
1 | name: memit
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.9.7
7 | - pip=21.2.4
8 | - cudatoolkit=11.3
9 | - pytorch==1.12.1
10 | - pip:
11 | - einops==0.4.0
12 | - higher==0.2.1
13 | - hydra-core==1.2.0
14 | - transformers==4.23.1
15 | - datasets==1.18.3
16 | - matplotlib==3.6.1
17 | - spacy==3.4.1
18 | - scipy==1.9.2
19 | - scikit-learn==1.0.2
20 | - nltk==3.7
21 | - jupyter==1.0.0
--------------------------------------------------------------------------------
/memit/scripts/setup_clean_ipynb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Start from directory of script
4 | cd "$(dirname "${BASH_SOURCE[0]}")"
5 |
6 | # Set up git config filters so huge output of notebooks is not committed.
7 | git config filter.clean_ipynb.clean "$(pwd)/ipynb_drop_output.py"
8 | git config filter.clean_ipynb.smudge cat
9 | git config filter.clean_ipynb.required true
10 |
--------------------------------------------------------------------------------
/memit/scripts/setup_conda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Start from directory of script
4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
5 | cd $SCRIPT_DIR
6 |
7 | # Detect operating system
8 | unameOut="$(uname -s)"
9 | case "${unameOut}" in
10 | Linux*) machine=Linux;;
11 | Darwin*) machine=Mac;;
12 | CYGWIN*) machine=Cygwin;;
13 | MINGW*) machine=MinGw;;
14 | *) machine="UNKNOWN:${unameOut}"
15 | esac
16 |
17 | if [ $machine != "Linux" ] && [ $machine != "Mac" ]
18 | then
19 | echo "Conda setup script is only available on Linux and Mac."
20 | exit 1
21 | else
22 | echo "Running on $machine..."
23 | fi
24 |
25 | if [[ -z "${CONDA_HOME}" ]]; then
26 | echo "Please specify the CONDA_HOME environment variable (it might look something like ~/miniconda3)."
27 | exit 1
28 | else
29 | echo "Found CONDA_HOME=${CONDA_HOME}."
30 | fi
31 |
32 | RECIPE=${RECIPE:-memit}
33 | ENV_NAME="${ENV_NAME:-${RECIPE}}"
34 | echo "Creating conda environment ${ENV_NAME}..."
35 |
36 | if [[ ! $(type -P conda) ]]
37 | then
38 | echo "conda not in PATH"
39 | echo "read: https://conda.io/docs/user-guide/install/index.html"
40 | exit 1
41 | fi
42 |
43 | if df "${HOME}/.conda" --type=afs > /dev/null 2>&1
44 | then
45 | echo "Not installing: your ~/.conda directory is on AFS."
46 | echo "Use 'ln -s /some/nfs/dir ~/.conda' to avoid using up your AFS quota."
47 | exit 1
48 | fi
49 |
50 | # Build new environment
51 | conda env create --name=${ENV_NAME} -f ${RECIPE}.yml
52 |
--------------------------------------------------------------------------------
/memit/transformer_utils/README.md:
--------------------------------------------------------------------------------
1 | ## transformer-utils
2 |
3 | Utilities for the HuggingFace [transformers](https://github.com/huggingface/transformers/) library, focused on loading and using large pretrained autoregressive language models like GPT-2 and GPT-Neo.
4 |
5 | This package is unofficial and not associated with HuggingFace.
6 |
7 | Features:
8 |
9 | - Load large (~2.7B) models in low-resource environments like Google Colab
10 | - Get activations from any part of the model, without running parts you don't need
11 | - Interpret models with the "logit lens"
12 | - For background, see
13 | - ["interpreting GPT: the logit lens"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) by nostalgebraist
14 | - ["Finding the Words to Say: Hidden State Visualizations for Language Models"](https://jalammar.github.io/hidden-states/) by Jay Alammar
15 |
16 | ## Example usage
17 |
18 | ### Load in a low-memory environment
19 |
20 | Loading a 2.7B model:
21 |
22 | ```python
23 | from transformer_utils.low_memory import enable_low_memory_load
24 |
25 | enable_low_memory_load()
26 |
27 | model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B')
28 | ```
29 |
30 | This works fine in an ordinary (non-Pro) Google Colab notebook, with ~12 GB RAM and a T4 GPU.
31 |
32 | Inference will work up to the full context window length of 2048 tokens without memory issues.
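
Once loaded this way, the model behaves like any other `transformers` causal LM. A minimal follow-on sketch (not from the original README; the prompt and sampling settings are arbitrary placeholders):

```python
import transformers
from transformer_utils.low_memory import enable_low_memory_load

enable_low_memory_load()

tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-2.7B')

# the low-memory patch only changes how weights are loaded; generation is unchanged
input_ids = tokenizer("The logit lens is", return_tensors="pt").input_ids.to(model.device)
out = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_k=50)
print(tokenizer.decode(out[0]))
```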
33 |
34 | ### Logit lens
35 |
36 | ```python
37 | import torch
38 | import transformers
39 | from transformer_utils.low_memory import enable_low_memory_load
from transformer_utils.logit_lens import plot_logit_lens
40 |
41 | enable_low_memory_load()
42 |
43 | tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
44 | model = transformers.AutoModelForCausalLM.from_pretrained('gpt2-xl')
45 |
46 | def text_to_input_ids(text):
47 | toks = tokenizer.encode(text)
48 | return torch.as_tensor(toks).view(1, -1).cuda()
49 |
50 | input_ids = text_to_input_ids("This is an example. You can probably think of a more fun text to use than this one.")
51 |
52 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45) # logits
53 |
54 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, probs=True) # probabilities
55 |
56 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, kl=True) # K-L divergence
57 | ```
58 |
59 | You can also do some other things that aren't in the original blog posts. This will break down the transformer blocks into their attention and MLP parts:
60 |
61 | ```python
62 | plot_logit_lens(model, tokenizer, input_ids, start_ix=0, end_ix=45, include_subblocks=True)
63 | ```
64 |
65 | You can also change the definition of the "decoder" to include some of the later blocks/subblocks of the model. This helps especially in interpreting GPT-Neo hidden states.
66 |
67 | ```python
68 | # assume we have a 48-layer model
69 | # so 'h.47' is the final layer
70 |
71 | # include last layer in decoder
72 | plot_logit_lens(
73 | model, tokenizer, input_ids, start_ix=0, end_ix=45,
74 | decoder_layer_names=['h.47', 'final_layernorm', 'lm_head']
75 | )
76 |
77 | # include just the last MLP subblock in decoder
78 | plot_logit_lens(
79 | model, tokenizer, input_ids, start_ix=0, end_ix=45,
80 | decoder_layer_names=['h.47.mlp', 'final_layernorm', 'lm_head']
81 | )
82 | ```
83 |
84 | ### Get activations from any part of the model
85 |
86 | ###### ...and without running parts you don't need
87 |
88 | ```python
89 | from transformer_utils.partial_forward import partial_forward
90 |
91 | output = partial_forward(
92 |     model,  # your `transformers` model
93 |     [  # output_names
94 |         'h.0', # output of the 1st layer
95 |         'h.2.attn.c_attn', # query/key/value matrix from the 3rd layer
96 |         'h.5.mlp.c_proj', # feed-forward activations from the 6th layer
97 |     ],
98 |     input_ids,  # the input to run
99 | )
100 |
101 | # each of these is a tensor
102 | output['h.0']
103 | output['h.2.attn.c_attn']
104 | output['h.5.mlp.c_proj']
105 | ```
106 |
107 | For efficiency, `partial_forward` doesn't run any part of the model later than the ones you specify in `output_names`.
108 |
109 | For example, suppose `model` above was GPT-2 XL. Then it has 48 layers. But the forward pass in the code above stops running after the 6th layer of 48 -- so the compute and memory cost is far lower than a full `model.forward`.
110 |
111 | This makes it easy to write new "heads" that do further computation on the model's activations.
112 |
113 | Some examples:
114 |
115 | ##### Using the first two layers of a model as feature extractors for binary classification
116 |
117 | ```python
118 | output_names=['h.0', 'h.1',]
119 | classifier_hidden_size=768
120 |
121 | feature_vector_size = base_model.config.n_embd * len(output_names)
122 |
123 | classifier = nn.Sequential(
124 | nn.Linear(feature_vector_size, classifier_hidden_size),
125 | nn.ReLU(),
126 | nn.Linear(classifier_hidden_size, 2),
127 | )
128 |
129 | opt = torch.optim.Adam(classifier.parameters())
130 |
131 | for input_ids, targets in dataset: # `dataset` is your classification train data
132 | with torch.no_grad():
133 | hidden_states = partial_forward(
134 | base_model,
135 | output_names,
136 | input_ids,
137 | )
138 |
139 | # shape (batch, sequence, len(output_names) * model's hidden size)
140 | feature_vector = torch.cat(
141 | [hidden_states[name] for name in output_names],
142 | dim=-1
143 | )
144 |
145 | # shape (batch, sequence, 2)
146 | classifier_out = classifier(feature_vector)
147 |
148 | # simple avg pool over sequence dim -- in practice we find attention works well for this step :)
149 | # shape (batch, 2)
150 | logits = classifier_out.mean(dim=1)
151 |
152 | loss = F.cross_entropy(target=targets, input=logits)
153 | loss.backward()
154 | opt.step()
155 | opt.zero_grad()
156 | ```
157 |
158 |
159 | ##### Finetuning the first two layers of a model
160 |
161 | This is exactly the same as the above, with just two changes (a full sketch follows the list):
162 |
163 | - Remove the `with torch.no_grad()` wrapper around `partial_forward`
164 | - Optimize the base model's params too:
165 | - `opt = torch.optim.Adam(list(classifier.parameters()) + list(base_model.parameters()))`
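
Putting those two changes together, here is a minimal self-contained sketch. It is illustrative only and not from the original write-up: `gpt2` and the one-batch toy dataset are placeholders, and `use_cache=False` is passed so each block's output is a bare hidden-state tensor rather than a `(hidden_states, present)` tuple.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from transformer_utils.partial_forward import partial_forward

output_names = ['h.0', 'h.1']
base_model = transformers.AutoModel.from_pretrained('gpt2')  # bare transformer, so 'h.0' etc. resolve
classifier_hidden_size = 768

feature_vector_size = base_model.config.n_embd * len(output_names)
classifier = nn.Sequential(
    nn.Linear(feature_vector_size, classifier_hidden_size),
    nn.ReLU(),
    nn.Linear(classifier_hidden_size, 2),
)

# second change: the optimizer also updates the base model's parameters
opt = torch.optim.Adam(list(classifier.parameters()) + list(base_model.parameters()))

# toy stand-in for a real (input_ids, labels) dataset
dataset = [(torch.randint(0, base_model.config.vocab_size, (2, 16)), torch.tensor([0, 1]))]

for input_ids, targets in dataset:
    # first change: no `torch.no_grad()` wrapper, so gradients flow into the first two layers
    hidden_states = partial_forward(base_model, output_names, input_ids, use_cache=False)

    feature_vector = torch.cat([hidden_states[name] for name in output_names], dim=-1)
    logits = classifier(feature_vector).mean(dim=1)  # simple avg pool over the sequence dim

    loss = F.cross_entropy(input=logits, target=targets)
    loss.backward()
    opt.step()
    opt.zero_grad()
```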
166 |
167 | If you want to train a model like these ones for real use, I recommend writing a custom `nn.Module`. [See here](https://github.com/nostalgebraist/nostalgebraist-autoresponder/blob/fd96e9482186f5dbeaa27bd6179087c892c577d6/selector_model/selector_nn_neo.py#L263) for an example.
168 |
--------------------------------------------------------------------------------
/memit/transformer_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/memit/transformer_utils/__init__.py
--------------------------------------------------------------------------------
/memit/transformer_utils/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | seaborn
4 | tqdm
5 | colorcet
6 |
--------------------------------------------------------------------------------
/memit/transformer_utils/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from setuptools import setup, find_packages
3 |
4 | # The directory containing this file
5 | HERE = pathlib.Path(__file__).parent
6 |
7 | # The text of the README file
8 | README = (HERE / "README.md").read_text()
9 |
10 | setup(
11 | name='transformer-utils',
12 | description="Large autoregressive language modeling helpers",
13 | long_description=README,
14 | long_description_content_type="text/markdown",
15 | url="https://github.com/nostalgebraist/transformer-utils",
16 | author="nostalgebraist",
17 | author_email="nostalgebraist@gmail.com",
18 | license="MIT",
19 | version='0.1.0',
20 | packages=find_packages('src'),
21 | package_dir={'': 'src'},
22 | install_requires=[
23 | 'torch',
24 | 'transformers',
25 | 'seaborn',
26 | 'tqdm',
27 | 'colorcet'
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .low_memory import low_memory_load
2 | from . import util
3 | from . import partial_forward
4 |
5 | __all__ = [
6 | 'low_memory_load',
7 | 'util',
8 | 'partial_forward',
9 | ]
10 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/logit_lens/__init__.py:
--------------------------------------------------------------------------------
1 | from .hooks import make_lens_hooks, clear_lens_hooks
2 | from . import plotting
3 | from .plotting import plot_logit_lens
4 |
5 | __all__ = [
6 | 'make_lens_hooks',
7 | 'clear_lens_hooks',
8 | 'plotting',
9 | 'plot_logit_lens'
10 | ]
11 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/logit_lens/hooks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ..util.python_utils import make_print_if_verbose
5 | from ..util.module_utils import get_child_module_by_names
6 |
7 | _RESID_SUFFIXES = {".attn", ".mlp"}
8 |
9 |
10 | def blocks_input_locator(model: nn.Module):
11 | """
12 | HF usually (always?) places a dropout after the input embeddings.
13 | TODO: avoid depending on this
14 | """
15 | dropouts_on_base_model = [
16 | mod for mod in model.base_model.children()
17 | if isinstance(mod, nn.Dropout)
18 | ]
19 | if len(dropouts_on_base_model) > 0:
20 | return lambda: dropouts_on_base_model[0]
21 | raise ValueError('could not identify blocks input')
22 |
23 |
24 | def final_layernorm_locator(model: nn.Module):
25 | layernorms_on_base_model = [
26 | mod for mod in model.base_model.children()
27 | if isinstance(mod, nn.LayerNorm)
28 | ]
29 | if len(layernorms_on_base_model) > 0:
30 | return lambda: layernorms_on_base_model[0]
31 | raise ValueError('could not identify ln_f')
32 |
33 |
34 | def _locate_special_modules(model):
35 | if not hasattr(model, "_blocks_input_getter"):
36 | model._blocks_input_getter = blocks_input_locator(model)
37 |
38 | if not hasattr(model, "_ln_f_getter"):
39 | model._ln_f_getter = final_layernorm_locator(model)
40 |
41 |
42 | def _get_layer(model, name):
43 | if name == "input":
44 | return model._blocks_input_getter()
45 | if name == "final_layernorm":
46 | return model._ln_f_getter()
47 |
48 | model_with_module = model if name == "lm_head" else model.base_model
49 | return get_child_module_by_names(model_with_module, name.split("."))
50 |
51 |
52 | def _sqz(x):
53 | if isinstance(x, torch.Tensor):
54 | return x
55 | try:
56 | return x[0]
57 | except:
58 | return x
59 |
60 |
61 | def _get_layer_and_compose_with_ln(model, name):
62 | if name.endswith('.attn'):
63 | lname = name[:-len('.attn')] + '.ln_1'
64 | ln = _get_layer(model, lname)
65 | elif name.endswith('.mlp'):
66 | lname = name[:-len('.mlp')] + '.ln_2'
67 | ln = _get_layer(model, lname)
68 | else:
69 | ln = lambda x: x
70 | return lambda x: _get_layer(model, name)(ln(x))
71 |
72 |
73 | def make_decoder(model, decoder_layer_names=['final_layernorm', 'lm_head']):
74 | _locate_special_modules(model)
75 |
76 | decoder_layers = [_get_layer_and_compose_with_ln(model, name) for name in decoder_layer_names]
77 |
78 | def _decoder(x):
79 | for name, layer in zip(decoder_layer_names, decoder_layers):
80 | layer_out = _sqz(layer(_sqz(x)))
81 |
82 | # TODO: DRY
83 | is_resid = any([name.endswith(s) for s in _RESID_SUFFIXES])
84 | if is_resid:
85 | x = x + layer_out
86 | else:
87 | x = layer_out
88 | return x
89 | return _decoder
90 |
91 |
92 | def make_lens_hooks(
93 | model,
94 | layer_names: list,
95 | decoder_layer_names: list = ['final_layernorm', 'lm_head'],
96 | verbose=True,
97 | start_ix=None,
98 | end_ix=None,
99 | ):
100 | vprint = make_print_if_verbose(verbose)
101 |
102 | clear_lens_hooks(model)
103 |
104 | def _opt_slice(x, start_ix, end_ix):
105 | if not start_ix:
106 | start_ix = 0
107 | if not end_ix:
108 | end_ix = x.shape[1]
109 | return x[:, start_ix:end_ix, :]
110 |
111 | _locate_special_modules(model)
112 |
113 | for attr in ["_layer_logits", "_layer_logits_handles"]:
114 | if not hasattr(model, attr):
115 | setattr(model, attr, {})
116 |
117 | # TODO: better naming
118 | model._ordered_layer_names = layer_names
119 |
120 | model._lens_decoder = make_decoder(model, decoder_layer_names)
121 |
122 | def _make_record_logits_hook(name):
123 | model._layer_logits[name] = None
124 |
125 | is_resid = any([name.endswith(s) for s in _RESID_SUFFIXES])
126 |
127 | def _record_logits_hook(module, input, output) -> None:
128 | del model._layer_logits[name]
129 | ln_f = model._ln_f_getter()
130 |
131 | if is_resid:
132 | decoder_in = model._last_resid + _sqz(output)
133 | else:
134 | decoder_in = _sqz(output)
135 | # print(decoder_in.shape)
136 | decoder_out = model._lens_decoder(decoder_in)
137 | # print(decoder_out.shape)
138 | decoder_out = _opt_slice(decoder_out, start_ix, end_ix)
139 |
140 | model._layer_logits[name] = decoder_out.detach().cpu().numpy()
141 | model._last_resid = decoder_in
142 |
143 | return _record_logits_hook
144 |
145 | def _hook_already_there(name):
146 | handle = model._layer_logits_handles.get(name)
147 | if not handle:
148 | return False
149 | layer = _get_layer(model, name)
150 | return handle.id in layer._forward_hooks
151 |
152 | for name in layer_names:
153 | if _hook_already_there(name):
154 | vprint(f"skipping layer {name}, hook already exists")
155 | continue
156 | layer = _get_layer(model, name)
157 | handle = layer.register_forward_hook(_make_record_logits_hook(name))
158 | model._layer_logits_handles[name] = handle
159 |
160 |
161 | def clear_lens_hooks(model):
162 | if hasattr(model, "_layer_logits_handles"):
163 | for k, v in model._layer_logits_handles.items():
164 | v.remove()
165 |
166 | ks = list(model._layer_logits_handles.keys())
167 | for k in ks:
168 | del model._layer_logits_handles[k]
169 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/logit_lens/layer_names.py:
--------------------------------------------------------------------------------
1 | from ..util.module_utils import get_child_module_by_names
2 |
3 |
4 | def make_layer_names(
5 | model,
6 | block_step=1,
7 | include_input=True,
8 | force_include_output=True,
9 | include_subblocks=False,
10 | decoder_layer_names: list = ['final_layernorm', 'lm_head']
11 | ):
12 | h = get_child_module_by_names(model.base_model, ["h"])
13 | h_names = [f"h.{i}" for i in range(len(h))]
14 |
15 | last_h_name = h_names[-1]
16 |
17 | h_names = h_names[::block_step]
18 | if force_include_output and last_h_name not in h_names:
19 | h_names.append(last_h_name)
20 |
21 | if include_subblocks:
22 | names = [sub_name for name in h_names for sub_name in (f"{name}.attn", name)]
23 | else:
24 | names = h_names
25 |
26 | if include_input:
27 | names = ["input"] + names
28 |
29 | def _subset(a, b):
30 | return a == b or a.startswith(b + ".")
31 |
32 | def _names_overlap(a, b):
33 | return _subset(a, b) or _subset(b, a)
34 |
35 | names = [name for name in names
36 | if not any([_names_overlap(name, dname) for dname in decoder_layer_names])
37 | ]
38 |
39 | return names
40 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/low_memory/__init__.py:
--------------------------------------------------------------------------------
1 | from .load_context import LazyLinearAPICompatible, LazyTransformersConv1D, LowMemoryLoadContext
2 | from .load import low_memory_load
3 | from .enable import enable_low_memory_load, disable_low_memory_load
4 |
5 | __all__ = [
6 | 'LazyLinearAPICompatible',
7 | 'LazyTransformersConv1D',
8 | 'LowMemoryLoadContext',
9 | 'low_memory_load',
10 | 'enable_low_memory_load',
11 | 'disable_low_memory_load'
12 | ]
13 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/low_memory/enable.py:
--------------------------------------------------------------------------------
1 | import transformers
2 |
3 | from .load import low_memory_load
4 | from ..util.tfm_utils import huggingface_model_local_paths
5 |
6 | _TFM_PRETRAINED_MODEL_FROM_PRETRAINED_ORIGINAL = transformers.modeling_utils.PreTrainedModel.from_pretrained
7 |
8 |
9 | def low_memory_from_pretrained(pretrained_model_name_or_path, *args, **kwargs):
10 | config_path, model_path = huggingface_model_local_paths(pretrained_model_name_or_path)
11 |
12 | model = low_memory_load(config_path=config_path, model_path=model_path, verbose=False)
13 |
14 | return model
15 |
16 |
17 | def enable_low_memory_load():
18 | transformers.modeling_utils.PreTrainedModel.from_pretrained = low_memory_from_pretrained
19 |
20 |
21 | def disable_low_memory_load():
22 | transformers.modeling_utils.PreTrainedModel.from_pretrained = _TFM_PRETRAINED_MODEL_FROM_PRETRAINED_ORIGINAL
23 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/low_memory/load.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 |
4 | from ..util.python_utils import make_print_if_verbose
5 | from ..util.module_utils import get_child_module_by_names
6 | from ..util.tfm_utils import normalize_inconsistent_state_dict_keys
7 | from .load_context import LowMemoryLoadContext
8 |
9 |
10 | DEFAULT_GENERIC_INPUT = torch.as_tensor([[0]])
11 |
12 |
13 | def modify_weights_after_load(model):
14 | """the part of PreTrainedModel.init_weights that isn't initializing weights"""
15 | # Prune heads if needed
16 | if model.config.pruned_heads:
17 | model.prune_heads(model.config.pruned_heads)
18 |
19 | # Tie weights if needed
20 | model.tie_weights()
21 |
22 |
23 | def low_memory_load(
24 | config_path,
25 | model_path,
26 | config_cls=None,
27 | model_cls=None,
28 | high_memory_device="cuda:1",
29 | generic_input=DEFAULT_GENERIC_INPUT,
30 | verbose=True,
31 | ):
32 | vprint = make_print_if_verbose(verbose)
33 |
34 | if isinstance(high_memory_device, str):
35 | high_memory_device = torch.device(high_memory_device)
36 |
37 | if config_cls is None:
38 | config_cls = transformers.AutoConfig
39 |
40 | vprint("start")
41 |
42 | with LowMemoryLoadContext():
43 | config = config_cls.from_pretrained(config_path)
44 |
45 | vprint("made config obj")
46 |
47 | state_dict = torch.load(
48 | model_path,
49 | map_location=high_memory_device,
50 | )
51 |
52 | state_dict = normalize_inconsistent_state_dict_keys(state_dict)
53 |
54 | vprint("loaded state dict")
55 |
56 | # uses lazy init, no memory
57 | if model_cls is None:
58 | model = transformers.AutoModelForCausalLM.from_config(config)
59 | else:
60 | model = model_cls(config=config)
61 |
62 | vprint("made model obj")
63 |
64 | # START gpu --> cpu --> gpu handoff, one leaf module at a time
65 | handled = set()
66 |
67 | for name in dict(model.named_parameters()).keys():
68 | prefix = name.rpartition(".")[0]
69 | mod = get_child_module_by_names(model, prefix.split("."))
70 |
71 | if prefix in handled:
72 | continue
73 |
74 | vprint((name, prefix, mod))
75 |
76 | mk, uk, er = [], [], []
77 | mod._load_from_state_dict(
78 | state_dict,
79 | prefix=prefix + ".",
80 | local_metadata={},
81 | strict=True,
82 | missing_keys=mk,
83 | unexpected_keys=uk,
84 | error_msgs=er,
85 | )
86 | vprint((mk, uk, er))
87 | mod.to(high_memory_device)
88 | sdks = [k for k in state_dict if k.startswith(prefix)]
89 | for k in sdks:
90 | del state_dict[k]
91 | handled.add(prefix)
92 |
93 | # END gpu --> cpu --> gpu handoff, one leaf module at a time
94 |
95 | vprint("loaded params into memory")
96 |
97 | # does the buffers
98 | model = model.to(high_memory_device)
99 |
100 | vprint("loaded all into memory")
101 |
102 | # does stuff like weight tying, now that the weights actually exist
103 | modify_weights_after_load(model)
104 |
105 | model.eval()
106 |
107 | # ensures we materialize the lazy params (and delete the hooks for doing so), before doing anything else
108 | #
109 | # if you add pre-hooks before doing this step, you get OrderedDict mutation exceptions
110 | with torch.no_grad():
111 | out = model(generic_input.to(high_memory_device))
112 | out = None
113 |
114 | return model
115 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/low_memory/load_context.py:
--------------------------------------------------------------------------------
1 | import torch.nn
2 | import transformers
3 | from torch.nn.modules.lazy import LazyModuleMixin
4 | from torch.nn.parameter import UninitializedParameter
5 |
6 | _TORCH_NN_ORIGINAL_LINEAR = torch.nn.Linear
7 |
8 | _TFM_PRETRAINED_MODEL_INIT_WEIGHTS_ORIGINAL = (
9 | transformers.modeling_utils.PreTrainedModel.init_weights
10 | )
11 | _TFM_CONV1D_ORIGINAL = transformers.modeling_utils.Conv1D
12 |
13 |
14 | def init_weights_without_init(self):
15 | pass
16 |
17 |
18 | class LazyLinearAPICompatible(torch.nn.LazyLinear):
19 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
20 | super().__init__(out_features=out_features, bias=bias)
21 |
22 |
23 | class LazyTransformersConv1D(LazyModuleMixin, _TFM_CONV1D_ORIGINAL):
24 | cls_to_become = _TFM_CONV1D_ORIGINAL
25 | weight: UninitializedParameter
26 |
27 | def __init__(self, nf, nx):
28 | super().__init__(nf=nf, nx=0)
29 | self.nx = 0
30 | self.weight = UninitializedParameter()
31 |
32 | def reset_parameters(self) -> None:
33 | if not self.has_uninitialized_params() and self.nx != 0:
34 | super().reset_parameters()
35 |
36 | def initialize_parameters(self, input) -> None:
37 | if self.has_uninitialized_params():
38 | with torch.no_grad():
39 | self.nx = input.shape[-1]
40 | self.weight.materialize((self.nf, self.nx))
41 | self.reset_parameters()
42 |
43 |
44 | class LowMemoryLoadContext:
45 | def __enter__(self):
46 | torch.nn.Linear = LazyLinearAPICompatible
47 | transformers.modeling_utils.Conv1D = LazyTransformersConv1D
48 | transformers.PreTrainedModel.init_weights = init_weights_without_init
49 |
50 | def __exit__(self, exc_type, exc_value, exc_traceback):
51 | torch.nn.Linear = _TORCH_NN_ORIGINAL_LINEAR
52 | transformers.modeling_utils.Conv1D = _TFM_CONV1D_ORIGINAL
53 | transformers.PreTrainedModel.init_weights = (
54 | _TFM_PRETRAINED_MODEL_INIT_WEIGHTS_ORIGINAL
55 | )
56 | return exc_type is None
57 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/partial_forward/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import inspect
3 |
4 | from ..util.python_utils import make_print_if_verbose
5 |
6 |
7 | PARTIAL_FORWARD_FORCE_FALSE_KWARGS = {
8 | "use_cache",
9 | "output_attentions",
10 | "output_hidden_states",
11 | "return_dict",
12 | }
13 |
14 | PARTIAL_FORWARD_FORCE_FALSE_KWARGS_MSG = """`partial_forward` was passed the argument {kwarg} but will ignore it.
15 |
16 | `partial_forward` ignores arguments that configure output shape in `transformers`, since its output shape is configured entirely through the `output_names` argument."""
17 |
18 | VALIDATE_OUTPUT_BASE_MODEL_MSG = """Some `output_names` were not found on the model (a `{model_class_name}`), but exist on its base model (a `{base_model_class_name}`).
19 |
20 | Try either passing `model.base_model` as the model, OR adding the string '{base_model_prefix}.' to the start of each output name.
21 |
22 | Names not found: {names}"""
23 |
24 | VALIDATE_OUTPUT_NOT_FOUND_MSG = """Some `output_names` were not found on the model.
25 |
26 | To see valid output names, try `dict(model.named_modules()).keys()`.
27 |
28 | Names not found: {names}"""
29 |
30 |
31 | class AfterStoppingPointException(Exception):
32 | pass
33 |
34 |
35 | def _validate_output_names(model, output_names):
36 | if output_names is None:
37 | return
38 |
39 | findable_names = dict(model.named_modules()).keys()
40 |
41 | findable_names_base_model = set()
42 | if hasattr(model, "base_model") and hasattr(model, "base_model_prefix"):
43 | findable_names_base_model = dict(model.base_model.named_modules()).keys()
44 |
45 | problem_names = [name for name in output_names if name not in findable_names]
46 |
47 | base_model_names = [
48 | name for name in problem_names if name in findable_names_base_model
49 | ]
50 |
51 | if len(base_model_names) > 0:
52 | raise ValueError(
53 | VALIDATE_OUTPUT_BASE_MODEL_MSG.format(
54 | model_class_name=model.__class__.__name__,
55 | base_model_class_name=model.base_model.__class__.__name__,
56 | base_model_prefix=model.base_model_prefix,
57 | names=base_model_names,
58 | )
59 | )
60 |
61 | if len(problem_names) > 0:
62 | raise ValueError(VALIDATE_OUTPUT_NOT_FOUND_MSG.format(names=problem_names))
63 |
64 |
65 | def add_partial_forward_hooks(model, verbose=False, debug=False, output_names=None):
66 | vprint = make_print_if_verbose(verbose)
67 | dprint = make_print_if_verbose(debug)
68 |
69 | _validate_output_names(model, output_names)
70 |
71 | can_skip = output_names is not None
72 | can_skip = can_skip and hasattr(model, "_partial_forward_force_false_kwargs")
73 |
74 | names_to_mods = {}
75 | indices_to_names = {}
76 | names, mods = [], []
77 | for i, (name, mod) in enumerate(model.named_modules()):
78 | if hasattr(mod, "_partial_forward_name") and mod._partial_forward_name != name:
79 | can_skip = False
80 |
81 | mod._partial_forward_name = name
82 | indices_to_names[i] = name
83 |
84 | names.append(name)
85 | mods.append(mod)
86 | names_to_mods[name] = mod
87 |
88 | if output_names is not None:
89 | should_have_hook = name in output_names
90 | already_has_hook = hasattr(mod, "_record_to_sink_handle")
91 | can_skip = can_skip and (should_have_hook == already_has_hook)
92 |
93 | if can_skip:
94 | dprint("already have partial forward hooks, skipping")
95 | return
96 |
97 | sig = inspect.signature(model.__class__.forward)
98 | model._partial_forward_force_false_kwargs = (
99 | PARTIAL_FORWARD_FORCE_FALSE_KWARGS.intersection(sig.parameters.keys())
100 | )
101 |
102 | def _record_to_sink_hook(module, input, output) -> None:
103 | if hasattr(model, "_output_sink_names"):
104 | this_name = module._partial_forward_name
105 | dprint(f"reached output of {repr(this_name)}")
106 | dprint(f"model._output_sink_names: {model._output_sink_names}")
107 |
108 | if this_name in model._output_sink_names:
109 | dprint(f"{repr(this_name)} in sink")
110 |
111 | to_record = output
112 | if isinstance(to_record, tuple) and len(to_record) == 1:
113 | to_record = to_record[0]
114 |
115 | model._output_sink[this_name] = to_record
116 |
117 | if all([name in model._output_sink for name in model._output_sink_names]):
118 | dprint("have all model._output_sink_names, stopping")
119 |
120 | raise AfterStoppingPointException
121 |
122 | for name, mod in zip(names, mods):
123 | if hasattr(mod, "_record_to_sink_handle"):
124 | vprint(f"clearing existing handle at {repr(name)}")
125 | mod._record_to_sink_handle.remove()
126 |
127 | if output_names is None or name in output_names:
128 | rts_handle = mod.register_forward_hook(_record_to_sink_hook)
129 | mod._record_to_sink_handle = rts_handle
130 |
131 |
132 | def partial_forward(
133 | model,
134 | output_names,
135 | *args,
136 | verbose=False,
137 | debug=False,
138 | **kwargs,
139 | ):
140 | vprint = make_print_if_verbose(verbose)
141 |
142 | add_partial_forward_hooks(
143 | model, verbose=verbose, debug=debug, output_names=output_names
144 | )
145 |
146 | for k in model._partial_forward_force_false_kwargs:
147 | if kwargs.get(k):
148 | warnings.warn(PARTIAL_FORWARD_FORCE_FALSE_KWARGS_MSG.format(kwarg=repr(k)))
149 | kwargs[k] = False
150 |
151 | model._output_sink_names = output_names
152 |
153 | if hasattr(model, "_output_sink"):
154 | vprint("clearing existing _output_sink")
155 | for v in model._output_sink.values():
156 | del v
157 | del model._output_sink
158 |
159 | model._output_sink = {}
160 |
161 | try:
162 | model(*args, **kwargs)
163 | except AfterStoppingPointException as e:
164 | pass
165 |
166 | del model._output_sink_names
167 |
168 | return_val = model._output_sink
169 | del model._output_sink
170 |
171 | return return_val
172 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuaihong/ConceptVectors/607591b415043f7692bc17a9748de3d8ff3fc0c7/memit/transformer_utils/src/transformer_utils/util/__init__.py
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/util/module_utils.py:
--------------------------------------------------------------------------------
1 | from .python_utils import make_print_if_verbose
2 |
3 |
4 | def get_child_module_by_names(module, names):
5 | obj = module
6 | for getter in map(lambda name: lambda obj: getattr(obj, name), names):
7 | obj = getter(obj)
8 | return obj
9 |
10 |
11 | def get_leaf_modules(module, verbose=False):
12 | vprint = make_print_if_verbose(verbose)
13 |
14 | names = []
15 | leaves = []
16 | handled = set()
17 |
18 | for param_name in dict(module.named_parameters()).keys():
19 | mod_name = param_name.rpartition(".")[0]
20 | mod = get_child_module_by_names(module, mod_name.split("."))
21 |
22 | if mod_name in handled:
23 | continue
24 |
25 | vprint((param_name, mod_name, mod))
26 |
27 | names.append(mod_name)
28 | leaves.append(mod)
29 | handled.add(mod_name)
30 |
31 | return names, leaves
32 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/util/python_utils.py:
--------------------------------------------------------------------------------
1 | def make_print_if_verbose(verbose: bool):
2 | def vprint(*args, **kwargs):
3 | if verbose:
4 | print(*args, **kwargs)
5 |
6 | return vprint
7 |
--------------------------------------------------------------------------------
/memit/transformer_utils/src/transformer_utils/util/tfm_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import transformers.file_utils
4 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING
5 |
6 |
7 | def fix_config_with_missing_model_type(model_name, config_path):
8 | with open(config_path, 'r', encoding='utf-8') as f:
9 | config = json.load(f)
10 |
11 | model_type = config.get('model_type')
12 |
13 | # cf https://github.com/huggingface/transformers/blob/v4.5.1/src/transformers/models/auto/configuration_auto.py#L403
14 | #
15 | # we reproduce that logic here, but save the fixed config to the json file
16 | # so it will work more robustly, i.e. even if you are not using `AutoConfig`
17 | if model_type is None:
18 | for pattern, config_class in CONFIG_MAPPING.items():
19 | if pattern in model_name:
20 | config['model_type'] = config_class.model_type
21 |
22 | with open(config_path, 'w', encoding='utf-8') as f:
23 | json.dump(config, f)
24 |
25 |
26 | def get_local_path_from_huggingface_cdn(key, filename):
27 | archive_file = transformers.file_utils.hf_bucket_url(
28 | key,
29 | filename=filename,
30 | )
31 |
32 | resolved_archive_file = transformers.file_utils.cached_path(
33 | archive_file,
34 | )
35 | return resolved_archive_file
36 |
37 |
38 | def huggingface_model_local_paths(model_name):
39 | config_path = get_local_path_from_huggingface_cdn(model_name, "config.json")
40 |
41 | fix_config_with_missing_model_type(model_name, config_path)
42 |
43 | model_path = get_local_path_from_huggingface_cdn(model_name, "pytorch_model.bin")
44 |
45 | return config_path, model_path
46 |
47 |
48 | def normalize_inconsistent_state_dict_keys(state_dict):
49 | normalized = {}
50 |
51 | for k in state_dict.keys():
52 | if k.startswith("transformer."):
53 | normalized[k] = state_dict[k]
54 | else:
55 | normalized["transformer." + k] = state_dict[k]
56 | return normalized
57 |
--------------------------------------------------------------------------------
/memit/util/__init__.py:
--------------------------------------------------------------------------------
1 | from .logit_lens import LogitLens
2 |
--------------------------------------------------------------------------------
/memit/util/generate.py:
--------------------------------------------------------------------------------
1 | import unicodedata
2 | from typing import List, Optional
3 |
4 | import torch
5 | from transformers import AutoModelForCausalLM, AutoTokenizer
6 |
7 | from util.logit_lens import LogitLens
8 |
9 |
10 | def generate_interactive(
11 | model: AutoModelForCausalLM,
12 | tok: AutoTokenizer,
13 | top_k: int = 5,
14 | max_out_len: int = 200,
15 | compare_against: Optional[AutoModelForCausalLM] = None,
16 | use_logit_lens: bool = False,
17 | layer_module_tmp: str = "transformer.h.{}",
18 | ln_f_module: str = "transformer.ln_f",
19 | lm_head_module: str = "lm_head",
20 | ):
21 | """
22 | Puts generation in a loop. Allows users to repeatedly provide inputs
23 | with which text is generated.
24 | """
25 |
26 | if use_logit_lens:
27 | llens_gen = LogitLens(
28 | model,
29 | tok,
30 | layer_module_tmp,
31 | ln_f_module,
32 | lm_head_module,
33 | disabled=not use_logit_lens,
34 | )
35 | if compare_against:
36 | llens_vanilla = LogitLens(
37 | compare_against,
38 | tok,
39 | layer_module_tmp,
40 | ln_f_module,
41 | lm_head_module,
42 | disabled=not use_logit_lens,
43 | )
44 |
45 | while True:
46 | prompt = input("Enter a prompt: ").strip(" \r\t\n")
47 |
48 | print(
49 | f"Argument Model: "
50 | f"{generate_fast(model, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
51 | )
52 | if compare_against:
53 | print(
54 | f"Baseline Model: "
55 | f"{generate_fast(compare_against, tok, [prompt], n_gen_per_prompt=1, top_k=top_k, max_out_len=max_out_len)}"
56 | )
57 |
58 | if use_logit_lens:
59 | inp_prompt = tok([prompt], padding=True, return_tensors="pt").to(
60 | next(model.parameters()).device
61 | )
62 |
63 | with llens_gen:
64 | model(**inp_prompt)
65 | print("\n--- Argument Model Logit Lens ---")
66 | llens_gen.pprint()
67 |
68 | if compare_against:
69 | with llens_vanilla:
70 | compare_against(**inp_prompt)
71 | print("--- Baseline Model Logit Lens ---")
72 | llens_vanilla.pprint()
73 |
74 | print()
75 |
76 |
77 | def generate_fast(
78 | model: AutoModelForCausalLM,
79 | tok: AutoTokenizer,
80 | prompts: List[str],
81 | n_gen_per_prompt: int = 1,
82 | top_k: int = 5,
83 | max_out_len: int = 200,
84 | ):
85 | """
86 | Fast, parallelized auto-regressive text generation with top-k sampling.
87 | Our custom implementation.
88 | """
89 |
90 | # Unroll prompts and tokenize
91 | inp = [prompt for prompt in prompts for _ in range(n_gen_per_prompt)]
92 | inp_tok = tok(inp, padding=True, return_tensors="pt").to(
93 | next(model.parameters()).device
94 | )
95 | input_ids, attention_mask = inp_tok["input_ids"], inp_tok["attention_mask"]
96 | batch_size = input_ids.size(0)
97 |
98 | # Setup storage of fast generation with attention caches.
99 | # `cur_context` is used to define the range of inputs that are not yet
100 | # stored in `past_key_values`. At each step, we are generating the
101 | # next token for the index at `cur_context.stop`.
102 | past_key_values, cur_context = None, slice(0, attention_mask.sum(1).min().item())
103 |
104 | with torch.no_grad():
105 | while input_ids.size(1) < max_out_len: # while not exceeding max output length
106 | model_out = model(
107 | input_ids=input_ids[:, cur_context],
108 | attention_mask=None if ('llama' in model.name_or_path.lower() or 'olmo' in model.name_or_path.lower()) else attention_mask[:, cur_context],
109 | past_key_values=past_key_values,
110 | use_cache=True,
111 | )
112 | logits, past_key_values = model_out.logits, model_out.past_key_values
113 | softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)
114 |
115 | # Top-k sampling
116 | tk = torch.topk(softmax_out, top_k, dim=1).indices
117 | softmax_out_top_k = torch.gather(softmax_out, 1, tk)
118 | softmax_out_top_k = softmax_out_top_k / softmax_out_top_k.sum(1)[:, None]
119 | new_tok_indices = torch.multinomial(softmax_out_top_k, 1)
120 | new_toks = torch.gather(tk, 1, new_tok_indices)
121 |
122 | # If we're currently generating the continuation for the last token in `input_ids`,
123 | # create a new index so we can insert the new token
124 | if cur_context.stop == input_ids.size(1):
125 | attention_mask = torch.cat(
126 | [attention_mask, attention_mask.new_zeros(batch_size, 1)], dim=1
127 | )
128 | input_ids = torch.cat(
129 | [
130 | input_ids,
131 | input_ids.new_ones(batch_size, 1) * tok.pad_token_id,
132 | ],
133 | dim=1,
134 | )
135 |
136 | last_non_masked = attention_mask.sum(1) - 1
137 | for i in range(batch_size):
138 | new_idx = last_non_masked[i] + 1
139 | if last_non_masked[i].item() + 1 != cur_context.stop:
140 | continue
141 |
142 | # Stop generating if we've already maxed out for this prompt
143 | if new_idx < max_out_len:
144 | input_ids[i][new_idx] = new_toks[i]
145 | attention_mask[i][new_idx] = 1
146 |
147 | cur_context = slice(cur_context.stop, cur_context.stop + 1)
148 |
149 | txt = [tok.decode(x) for x in input_ids.detach().cpu().numpy().tolist()]
150 | txt = [
151 | unicodedata.normalize("NFKD", x)
152 | .replace("\n\n", " ")
153 | .replace("<|endoftext|>", "")
154 | for x in txt
155 | ]
156 |
157 | return txt
158 |
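# Usage sketch (illustrative, not part of the original file); assumes `tok.pad_token_id` is set:
#
#     texts = generate_fast(model, tok, ["The Eiffel Tower is in"], n_gen_per_prompt=2, top_k=5, max_out_len=50)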
--------------------------------------------------------------------------------
/memit/util/globals.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yaml
4 |
5 | with open("globals.yml", "r") as stream:
6 | data = yaml.safe_load(stream)
7 |
8 | (RESULTS_DIR, DATA_DIR, STATS_DIR, HPARAMS_DIR, KV_DIR) = (
9 | Path(z)
10 | for z in [
11 | data["RESULTS_DIR"],
12 | data["DATA_DIR"],
13 | data["STATS_DIR"],
14 | data["HPARAMS_DIR"],
15 | data["KV_DIR"],
16 | ]
17 | )
18 |
19 | REMOTE_ROOT_URL = data["REMOTE_ROOT_URL"]
20 |
--------------------------------------------------------------------------------
/memit/util/hparams.py:
--------------------------------------------------------------------------------
1 | import json
2 | from dataclasses import dataclass
3 |
4 |
5 | @dataclass
6 | class HyperParams:
7 | """
8 | Simple wrapper to store hyperparameters for Python-based rewriting methods.
9 | """
10 |
11 | @classmethod
12 | def from_json(cls, fpath):
13 | with open(fpath, "r") as f:
14 | data = json.load(f)
15 |
16 | return cls(**data)
17 |
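# Usage sketch (illustrative, not part of the original file): concrete hparam classes
# subclass HyperParams, declare their fields as a dataclass, and load them from JSON:
#
#     @dataclass
#     class MyHParams(HyperParams):
#         layers: list
#         lr: float
#
#     hp = MyHParams.from_json("path/to/hparams.json")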
--------------------------------------------------------------------------------
/memit/util/logit_lens.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | from transformers import AutoModelForCausalLM, AutoTokenizer
6 |
7 | from util import nethook
8 |
9 |
10 | class LogitLens:
11 | """
12 | Applies the LM head at the output of each hidden layer, then analyzes the
13 | resultant token probability distribution.
14 |
15 | Only works when hooking outputs of *one* individual generation.
16 |
17 | Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens
18 |
19 | Warning: when running multiple times (e.g. generation), will return
20 | outputs _only_ for the last processing step.
21 | """
22 |
23 | def __init__(
24 | self,
25 | model: AutoModelForCausalLM,
26 | tok: AutoTokenizer,
27 | layer_module_tmp: str,
28 | ln_f_module: str,
29 | lm_head_module: str,
30 | disabled: bool = False,
31 | ):
32 | self.disabled = disabled
33 | self.model, self.tok = model, tok
34 | self.n_layers = self.model.config.n_layer
35 |
36 | self.lm_head, self.ln_f = (
37 | nethook.get_module(model, lm_head_module),
38 | nethook.get_module(model, ln_f_module),
39 | )
40 |
41 | self.output: Optional[Dict] = None
42 | self.td: Optional[nethook.TraceDict] = None
43 | self.trace_layers = [
44 | layer_module_tmp.format(layer) for layer in range(self.n_layers)
45 | ]
46 |
47 | def __enter__(self):
48 | if not self.disabled:
49 | self.td = nethook.TraceDict(
50 | self.model,
51 | self.trace_layers,
52 | retain_input=False,
53 | retain_output=True,
54 | )
55 | self.td.__enter__()
56 |
57 | def __exit__(self, *args):
58 | if self.disabled:
59 | return
60 | self.td.__exit__(*args)
61 |
62 | self.output = {layer: [] for layer in range(self.n_layers)}
63 |
64 | with torch.no_grad():
65 | for layer, (_, t) in enumerate(self.td.items()):
66 | cur_out = t.output[0]
67 | assert (
68 | cur_out.size(0) == 1
69 | ), "Make sure you're only running LogitLens on single generations only."
70 |
71 | self.output[layer] = torch.softmax(
72 | self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1
73 | )
74 |
75 | return self.output
76 |
77 | def pprint(self, k=5):
78 | to_print = defaultdict(list)
79 |
80 | for layer, pred in self.output.items():
81 | rets = torch.topk(pred[0], k)
82 | for i in range(k):
83 | to_print[layer].append(
84 | (
85 | self.tok.decode(rets[1][i]),
86 | round(rets[0][i].item() * 1e2) / 1e2,
87 | )
88 | )
89 |
90 | print(
91 | "\n".join(
92 | [
93 | f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}"
94 | for layer in range(self.n_layers)
95 | ]
96 | )
97 | )
98 |
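# Usage sketch (illustrative, not part of the original file): hook a GPT-2-style model
# for one forward pass, then print each layer's top-k next-token guesses.
#
#     llens = LogitLens(model, tok, "transformer.h.{}", "transformer.ln_f", "lm_head")
#     with llens:
#         model(**tok("The Eiffel Tower is in", return_tensors="pt").to(model.device))
#     llens.pprint(k=5)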
--------------------------------------------------------------------------------
/memit/util/perplexity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 |
4 |
5 | def perplexity(
6 | model: AutoModelForCausalLM,
7 | tok: AutoTokenizer,
8 | text: str,
9 | max_input_length: int = None,
10 | ):
11 | """
12 | Computes perplexity of a piece of text, measured on a reference model.
13 | Text is truncated to max_input_length tokens.
14 | """
15 |
16 | inputs = tok(
17 | [text], return_tensors="pt", max_length=max_input_length, truncation=True
18 | ).to("cuda")
19 |
20 | logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2)
21 | log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0]
22 |
23 | # Perplexity = exp(-1/N * log P(x_1, ..., x_n))
24 | return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item()
25 |
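# Minimal usage sketch (illustrative, not part of the original file); "gpt2" is only an
# example checkpoint, and a CUDA device is assumed because `perplexity` moves inputs to "cuda".
if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(perplexity(model, tok, "The quick brown fox jumps over the lazy dog."))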
--------------------------------------------------------------------------------
/memit/zsre_evals.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Constants
5 | N_EDITS="10000"
6 |
7 | # Run configurations
8 | MODEL_NAME="EleutherAI/gpt-j-6B"
9 | ALG_NAMES=("FT" "MEND" "ROME" "MEMIT")
10 | HPARAMS_FNAMES=("EleutherAI_gpt-j-6B_wd.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json")
11 |
12 | # Execute
13 | for i in ${!ALG_NAMES[@]}
14 | do
15 | alg_name=${ALG_NAMES[$i]}
16 | hparams_fname=${HPARAMS_FNAMES[$i]}
17 |
18 | echo "Running evals for $alg_name..."
19 |
20 | python3 -m experiments.evaluate --alg_name=$alg_name --model_name=$MODEL_NAME --hparams_fname=$hparams_fname --num_edits=$N_EDITS --use_cache --dataset_size_limit=$N_EDITS --ds_name=zsre
21 | done
22 |
23 | exit 0
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ai2-olmo==0.2.5
2 | aiohttp==3.9.3
3 | aiosignal==1.3.1
4 | anyio==4.3.0
5 | better-abc==0.0.3
6 | bitsandbytes==0.43.1
7 | datasets==2.18.0
8 | evaluate==0.4.2
9 | huggingface-hub==0.21.4
10 | hydra-core==1.3.2
11 | jupyter
12 | matplotlib==3.7.5
13 | nltk==3.8.1
14 | numpy==1.24.1
15 | openai==1.25.0
16 | pandas==2.0.3
17 | peft==0.10.0
18 | pillow==10.2.0
19 | PyYAML==6.0.1
20 | rouge==1.0.1
21 | rouge-score==0.1.2
22 | scipy==1.10.1
23 | sentencepiece==0.2.0
24 | statistics==1.0.3.5
25 | scikit-learn==1.5.2
26 | tokenizers==0.15.2
27 | tornado==6.4
28 | tqdm==4.66.2
29 | transformer-lens==1.15.0
30 | transformers==4.38.2
31 | urllib3==1.26.13
32 |
33 |
34 |
--------------------------------------------------------------------------------