├── .gitignore ├── LICENSE.txt ├── README.md ├── data ├── benchmark │ ├── popular.json │ ├── random.json │ └── recent.json └── stats │ ├── popular_stats.json │ ├── random_stats.json │ └── recent_stats.json ├── requirements.txt └── src ├── benchmark.py ├── benchmark_statistics.py ├── build_benchmark.py ├── build_benchmark_tests.py ├── build_counterfactual_examples.py ├── build_logical_constraints.py ├── create_relation2optional_targets.py ├── dataset_statistics.py ├── evaluation.py ├── fact.py ├── filter_benchmark_by_model.py ├── generations └── sampled_entities_divided_to_buckets_5000.json ├── memit ├── .gitattributes ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── baselines │ ├── README.md │ ├── ft │ │ ├── __init__.py │ │ ├── ft_hparams.py │ │ └── ft_main.py │ └── mend │ │ ├── README.md │ │ ├── __init__.py │ │ ├── algs │ │ ├── enn.py │ │ ├── ft.py │ │ └── mend.py │ │ ├── config │ │ ├── alg │ │ │ ├── efk.yaml │ │ │ ├── enn.yaml │ │ │ ├── ft.yaml │ │ │ └── mend.yaml │ │ ├── config.yaml │ │ ├── experiment │ │ │ ├── fc.yaml │ │ │ ├── gen.yaml │ │ │ └── qa.yaml │ │ └── model │ │ │ ├── bart-base.yaml │ │ │ ├── bert-base.yaml │ │ │ ├── distilgpt2.yaml │ │ │ ├── gpt2.yaml │ │ │ ├── gpt2large.yaml │ │ │ ├── gpt2medium.yaml │ │ │ ├── gpt2xl.yaml │ │ │ ├── gptj.yaml │ │ │ ├── gptneo27.yaml │ │ │ ├── t5large.yaml │ │ │ ├── t5small.yaml │ │ │ ├── t5xl.yaml │ │ │ └── t5xxl.yaml │ │ ├── data_classes │ │ ├── fever.py │ │ ├── nq.py │ │ ├── wiki.py │ │ └── zsre.py │ │ ├── editable_model.py │ │ ├── hooks.py │ │ ├── losses.py │ │ ├── mend_hparams.py │ │ ├── mend_main.py │ │ ├── models.py │ │ ├── nn.py │ │ ├── oracle.py │ │ ├── requirements.txt │ │ ├── run.py │ │ ├── trainer.py │ │ └── utils.py ├── dsets │ ├── __init__.py │ ├── attr_snippets.py │ ├── counterfact.py │ ├── knowns.py │ ├── tfidf_stats.py │ └── zsre.py ├── experiments │ ├── __init__.py │ ├── causal_trace.py │ ├── evaluate.py │ ├── plot_causal_trace_avg.py │ ├── py │ │ ├── demo.py │ │ ├── eval_utils_counterfact.py │ │ └── eval_utils_zsre.py │ ├── summarize.py │ └── sweep.py ├── globals.yml ├── hparams │ ├── FT │ │ ├── EleutherAI_gpt-j-6B_constr.json │ │ ├── EleutherAI_gpt-j-6B_unconstr.json │ │ ├── EleutherAI_gpt-j-6B_wd.json │ │ ├── gpt2-large_constr.json │ │ ├── gpt2-medium_constr.json │ │ ├── gpt2-xl_attn.json │ │ ├── gpt2-xl_constr.json │ │ └── gpt2-xl_unconstr.json │ ├── MEMIT │ │ ├── EleutherAI_gpt-j-6B.json │ │ └── gpt2-xl.json │ ├── MEND │ │ ├── EleutherAI_gpt-j-6B.json │ │ ├── EleutherAI_gpt-j-6B_CF.json │ │ ├── gpt2-xl.json │ │ ├── gpt2-xl_CF.json │ │ └── gpt2-xl_zsRE.json │ └── ROME │ │ ├── EleutherAI_gpt-j-6B.json │ │ ├── EleutherAI_gpt-neox-20b.json │ │ ├── gpt2-large.json │ │ ├── gpt2-medium.json │ │ ├── gpt2-xl.json │ │ └── llama-7b.json ├── memit │ ├── __init__.py │ ├── compute_ks.py │ ├── compute_z.py │ ├── memit_hparams.py │ └── memit_main.py ├── notebooks │ ├── average_causal_effects.ipynb │ ├── baselines │ ├── causal_trace.ipynb │ ├── causal_trace_frozen_mlp_attn.ipynb │ ├── data │ ├── dsets │ ├── experiments │ ├── globals.yml │ ├── hparams │ ├── memit │ ├── memit.ipynb │ ├── rome │ ├── util │ └── vis │ │ ├── experiments │ │ ├── globals.yml │ │ ├── table_population.ipynb │ │ ├── table_population_zsre.ipynb │ │ ├── util │ │ ├── visualize_multi_results.ipynb │ │ └── visualize_sweep_results.ipynb ├── rome │ ├── README.md │ ├── __init__.py │ ├── compute_u.py │ ├── compute_v.py │ ├── layer_stats.py │ ├── repr_tools.py │ ├── rome_hparams.py │ ├── rome_main.py │ └── tok_dataset.py ├── scaling_curves.sh ├── scripts │ ├── 
causal_trace.sh │ ├── colab_reqs │ │ ├── additional.txt │ │ └── rome.txt │ ├── collect_layer_stats.sh │ ├── ipynb_drop_output.py │ ├── memit.yml │ ├── setup_clean_ipynb.sh │ └── setup_conda.sh ├── util │ ├── __init__.py │ ├── generate.py │ ├── globals.py │ ├── hparams.py │ ├── logit_lens.py │ ├── nethook.py │ ├── perplexity.py │ └── runningstats.py └── zsre_evals.sh ├── modeleditor.py ├── query.py ├── queryexecutor.py ├── relation.py ├── testcase.py ├── testrunner.py ├── two_hop_phrases.py ├── utils.py └── wikidata ├── config.py ├── ent_label2id.json.zip ├── ent_to_neighbourhood_subgraph.py ├── ent_to_num_of_facts.py ├── most_viewed_entities.py ├── recently_modified_facts.py ├── relation_to_optional_targets.py ├── relations.py ├── sample_facts_to_edit.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eden Biran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evaluating the Ripple Effects of Knowledge Editing in Language Models 2 | 3 | This repository contains the official code of the paper: ["Evaluating the Ripple Effects of Knowledge Editing in Language Models"](https://arxiv.org/abs/2307.12976). 4 | 5 | ## Setup 6 | 7 | The benchmark creation and all experiments and evaluations were conducted in a Python 3.9 environment. 8 | To clone the repository and set up the environment, please run the following commands: 9 | ```shell 10 | git clone https://github.com/edenbiran/RippleEdits.git 11 | cd RippleEdits 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## RippleEdits Benchmark 16 | 17 | The benchmark files and statistics can be found under `data/benchmark/` and `data/stats/`. 
The benchmark is split into three files named according to the benchmark's three subsets: `RECENT`, `RANDOM` and `POPULAR`.
For more details, please refer to Section 4 of the paper.

The source code for generating the benchmark can be found under `src/`.
Generating the benchmark from scratch can be done using `src/build_benchmark.py`.
Benchmark popularity statistics can be extracted using `src/benchmark_statistics.py`.

Each benchmark JSON file contains a list of entries.
Each entry is an edit containing the edit information (which also includes the original fact, if applicable) and the six evaluation criteria.
Each evaluation criterion contains a list of tests, where each test specifies the test prompt, answers and conditions.
An example (shortened for brevity) of an edit entry can be seen below:
```json
{
    "example_type": "popular",
    "edit": {
        "prompt": "The name of the country of citizenship of Leonardo DiCaprio is Syria.",
        "subject_id": "Q38111",
        "relation": "COUNTRY_OF_CITIZENSHIP",
        "target_id": "Q858",
        "original_fact": {
            "prompt": "The name of the country of citizenship of Leonardo DiCaprio is United States of America.",
            "subject_id": "Q38111",
            "relation": "COUNTRY_OF_CITIZENSHIP",
            "target_id": "Q30"
        }
    },
    "Relation_Specifity": [
        {
            "test_queries": [
                {
                    "prompt": "The name of the mother of Leonardo DiCaprio is",
                    "answers": [
                        {
                            "value": "Irmelin DiCaprio",
                            "aliases": [
                                "Irmelin Indenbirken",
                                "Irmelin Indenbirken-DiCaprio"
                            ]
                        }
                    ],
                    "query_type": "regular",
                    "subject_id": "Q38111",
                    "relation": "MOTHER",
                    "target_ids": [
                        "Q22984557"
                    ],
                    "phrase": null
                }
            ],
            "test_condition": "OR",
            "condition_queries": [
                {
                    "prompt": "The name of the mother of Leonardo DiCaprio is",
                    "answers": [
                        {
                            "value": "Irmelin DiCaprio",
                            "aliases": [
                                "Irmelin Indenbirken",
                                "Irmelin Indenbirken-DiCaprio"
                            ]
                        }
                    ],
                    "query_type": "regular",
                    "subject_id": "Q38111",
                    "relation": "MOTHER",
                    "target_ids": [
                        "Q22984557"
                    ],
                    "phrase": null
                }
            ]
        },
        ...
    ],
    "Logical_Generalization": [...],
    "Subject_Aliasing": [...],
    "Compositionality_I": [...],
    "Compositionality_II": [...],
    "Forgetfulness": [...]
}
```

## Evaluation

The source code for all evaluations of the benchmark can be found under `src/`.
All evaluations can be conducted using `src/evaluation.py`.

To evaluate the benchmark on a language model that is not currently supported, extend the class `QueryExecutor` in `src/queryexecutor.py` and add the new `QueryExecutor` to `src/evaluation.py`.

To evaluate the benchmark on a knowledge editing technique that is not currently supported, extend the class `ModelEditor` in `src/modeleditor.py` and add the new `ModelEditor` to `src/evaluation.py`.
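For quick, standalone inspection of the raw benchmark files (outside of `src/evaluation.py`), the entries can also be read directly with Python's `json` module. The sketch below is illustrative rather than part of the repository's tooling; it assumes only the entry schema shown above and the default file layout under `data/benchmark/`:

```python
import json

# Load one of the three benchmark subsets: recent.json, random.json or popular.json.
with open('data/benchmark/popular.json', 'r', encoding='utf-8') as f:
    entries = json.load(f)  # a list of edit entries, as described above

# The six evaluation criteria that appear as top-level keys of every entry.
CRITERIA = ['Relation_Specifity', 'Logical_Generalization', 'Subject_Aliasing',
            'Compositionality_I', 'Compositionality_II', 'Forgetfulness']

for entry in entries[:3]:
    edit = entry['edit']
    print(f"Edit: {edit['prompt']} (relation: {edit['relation']})")
    for criterion in CRITERIA:
        testcases = entry.get(criterion, [])
        n_test_queries = sum(len(tc['test_queries']) for tc in testcases)
        print(f"  {criterion}: {len(testcases)} test cases, {n_test_queries} test queries")
```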
108 | 109 | ## Citation 110 | ``` 111 | @article{cohen2024evaluating, 112 | title={Evaluating the ripple effects of knowledge editing in language models}, 113 | author={Cohen, Roi and Biran, Eden and Yoran, Ori and Globerson, Amir and Geva, Mor}, 114 | journal={Transactions of the Association for Computational Linguistics}, 115 | volume={12}, 116 | pages={283--298}, 117 | year={2024}, 118 | publisher={MIT Press One Broadway, 12th Floor, Cambridge, Massachusetts 02142, USA~…} 119 | } 120 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.14.0 2 | higher==0.2.1 3 | hydra-core==1.3.2 4 | jsonlines==3.1.0 5 | matplotlib==3.7.2 6 | nltk==3.8.1 7 | numpy==1.24.4 8 | omegaconf==2.3.0 9 | openai==0.27.8 10 | pandas==2.0.3 11 | PyYAML==6.0.1 12 | qwikidata==0.4.2 13 | Requests==2.31.0 14 | scikit_learn==1.3.0 15 | scipy==1.11.1 16 | torch==2.0.1+cu118 17 | tqdm==4.65.0 18 | transformers==4.31.0 19 | wandb==0.15.6 20 | wptools==0.4.17 21 | -------------------------------------------------------------------------------- /src/benchmark_statistics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import wptools 6 | import functools 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | 10 | from benchmark import Dataset 11 | 12 | 13 | plt.rcParams.update({'text.usetex': True}) 14 | plt.rcParams.update({'font.size': 16}) 15 | plt.rcParams.update({'figure.figsize': (14, 3)}) 16 | 17 | 18 | @functools.lru_cache() 19 | def get_entity_info(entity, is_id=True): 20 | claim_count = None 21 | backlinks = None 22 | views = None 23 | 24 | try: 25 | if is_id: 26 | page = wptools.page(wikibase=entity, silent=True) 27 | else: 28 | page = wptools.page(entity, silent=True) 29 | page.REQUEST_LIMIT = 500 30 | page.get_wikidata() 31 | claim_count = sum([len(values) for _, values in page.data['claims'].items()]) 32 | page.get_more() 33 | backlinks = len(page.data['backlinks']) 34 | views = page.data['views'] 35 | except (LookupError, StopIteration) as e: 36 | print(f'Error looking up {entity}: {e}') 37 | 38 | return claim_count, backlinks, views 39 | 40 | 41 | def get_axis_stats(tests): 42 | test_count = len(tests) 43 | test_query_count = 0 44 | condition_query_count = 0 45 | for testcase in tests: 46 | test_query_count += len(testcase.get_test_queries()) 47 | condition_query_count += len(testcase.get_condition_queries()) 48 | return test_count, test_query_count, condition_query_count 49 | 50 | 51 | def get_example_stats(example): 52 | example_type = type(example).__name__ 53 | 54 | subject_id = example.fact._subject_id 55 | subject_claim_count, subject_backlinks, subject_views = get_entity_info(subject_id) 56 | 57 | target_id = example.fact._target_id 58 | target_claim_count, target_backlinks, target_views = get_entity_info(target_id) 59 | 60 | relation = example.fact._relation.name 61 | 62 | axis_stats = [ 63 | get_axis_stats(example.making_up_tests), 64 | get_axis_stats(example.logical_constraints), 65 | get_axis_stats(example.subject_paraphrasing_tests), 66 | get_axis_stats(example.two_hop_tests), 67 | get_axis_stats(example.forward_two_hop_tests), 68 | get_axis_stats(example.prev_storage_tests) 69 | ] 70 | test_count, test_query_count, condition_query_count = (sum(x) for x in zip(*axis_stats)) 71 | 72 | return { 73 | 'example_type': example_type, 74 
| 75 | 'subject_id': subject_id, 76 | 'subject_claim_count': subject_claim_count, 77 | 'subject_backlinks': subject_backlinks, 78 | 'subject_views': subject_views, 79 | 80 | 'target_id': target_id, 81 | 'target_claim_count': target_claim_count, 82 | 'target_backlinks': target_backlinks, 83 | 'target_views': target_views, 84 | 85 | 'relation': relation, 86 | 87 | 'test_count': test_count, 88 | 'test_query_count': test_query_count, 89 | 'condition_query_count': condition_query_count 90 | } 91 | 92 | 93 | def relation_counts_to_axis(counts): 94 | return [s.replace('_', ' ').lower() for s in counts.index], counts.values * 100 95 | 96 | 97 | def display_statistics(dfs, args): 98 | for df in dfs: 99 | df['avg_conditions_per_test'] = df['condition_query_count'] / df['test_count'] 100 | print('Statistics:') 101 | print(df.describe().to_string()) 102 | print('--------------------------') 103 | 104 | print('Relations:') 105 | print(df['relation'].value_counts(normalize=True)) 106 | 107 | if args.plot: 108 | fig, axes = plt.subplots(1, len(dfs), sharey=True) 109 | for i, (ax, df, title) in enumerate(zip(axes, dfs, args.titles)): 110 | x, y = relation_counts_to_axis(df['relation'].value_counts(normalize=True)[:10]) 111 | ax.bar(x, y) 112 | start, end = ax.get_ylim() 113 | ax.yaxis.set_ticks(np.arange(start, end, 5)) 114 | ax.set_xticklabels(x, rotation=60, ha='right', rotation_mode='anchor') 115 | ax.set_title('\\textsc{%s}' % title) 116 | if i == 0: 117 | ax.set_ylabel('\% of edits') 118 | fig.savefig(args.plot, bbox_inches='tight') 119 | 120 | 121 | def main(args): 122 | if args.benchmark: 123 | print(f'Loading benchmark from {args.benchmark}') 124 | dataset = Dataset.from_file(args.benchmark) 125 | 126 | print('Collecting statistics') 127 | stats = [] 128 | for example in tqdm(dataset.examples): 129 | stats.append(get_example_stats(example)) 130 | stats_df = pd.DataFrame(stats) 131 | 132 | if not args.statistics: 133 | args.statistics = 'statistics.json' 134 | print(f'Saving statistics to {args.statistics}') 135 | stats_df.to_json(args.statistics) 136 | stats_dfs = [stats_df] 137 | 138 | elif args.statistics: 139 | print(f'Loading statistics from {args.statistics}') 140 | stats_dfs = [pd.read_json(s) for s in args.statistics] 141 | 142 | else: 143 | raise Exception('Wrong arguments given') 144 | 145 | display_statistics(stats_dfs, args) 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('-b', '--benchmark', help='The benchmark file path') 151 | parser.add_argument('-s', '--statistics', nargs='*', help='The statistics file paths. 
' 152 | 'One path if output, multiple if input.') 153 | parser.add_argument('-p', '--plot', help='The relations plot file path') 154 | parser.add_argument('-t', '--titles', nargs='*', help='The relations plot titles') 155 | main(parser.parse_args()) 156 | -------------------------------------------------------------------------------- /src/build_benchmark_tests.py: -------------------------------------------------------------------------------- 1 | from wikidata.relations import our_relations, relation2impacted_relations, relation2phrase 2 | from wikidata.utils import subject_relation_to_targets, ent_to_relation_ids, get_label, get_aliases, get_description, \ 3 | subjects_given_relation_target 4 | from build_logical_constraints import generate_constraints 5 | from utils import create_test_example_given_input_targets 6 | from relation import Relation 7 | from query import Query, TwoHopQuery 8 | from testcase import TestCase 9 | from two_hop_phrases import relation_couple_to_phrase 10 | 11 | 12 | def making_up_axis(subject_id: str, relation: Relation): 13 | tests = [] 14 | 15 | if relation not in Relation: 16 | return tests 17 | 18 | impacted_relations = relation.impacted_relations() 19 | for other_relation in Relation: 20 | if other_relation == relation or other_relation in impacted_relations: 21 | continue 22 | corresponding_targets = subject_relation_to_targets(subject_id, other_relation) 23 | if not corresponding_targets: 24 | continue 25 | test_query = Query(subject_id, other_relation, corresponding_targets) 26 | condition_queries = [test_query] 27 | tests.append(TestCase(test_query=test_query, condition_queries=condition_queries)) 28 | 29 | return tests 30 | 31 | 32 | def logical_constraints_axis(subject_id: str, relation: Relation, target_id: str): 33 | return generate_constraints(subject_id, relation, target_id) 34 | 35 | 36 | def subject_aliasing_axis(subject_id: str, relation: Relation, target_id: str): 37 | tests = [] 38 | subject_aliases = get_aliases(subject_id) 39 | for alias in subject_aliases: 40 | phrase = relation.phrase(alias) 41 | test_query = Query(subject_id, relation, target_id, phrase) 42 | condition_queries = [test_query] 43 | tests.append(TestCase(test_query=test_query, condition_queries=condition_queries)) 44 | return tests 45 | 46 | 47 | def two_hop_axis(subject_id: str, relation: Relation, target_id: str): 48 | tests = [] 49 | if not target_id or target_id[0] != 'Q': 50 | return tests 51 | target_relations = ent_to_relation_ids(target_id) 52 | for relation_id in target_relations: 53 | second_relation_enum = Relation.id_to_enum(relation_id) 54 | if second_relation_enum is None: 55 | continue 56 | second_hop_targets = subject_relation_to_targets(target_id, second_relation_enum) 57 | for second_hop_target in second_hop_targets: 58 | phrase = relation_couple_to_phrase(relation, second_relation_enum) 59 | if phrase is None: 60 | continue 61 | phrase = phrase.replace('', get_label(subject_id)) 62 | test_query = TwoHopQuery(subject_id, relation, target_id, second_relation_enum, second_hop_target, phrase) 63 | condition_queries = [Query(target_id, second_relation_enum, second_hop_target)] 64 | tests.append(TestCase(test_query=test_query, condition_queries=condition_queries)) 65 | return tests 66 | 67 | 68 | def forward_two_hop_axis(subject_id: str, relation: Relation, target_id: str): 69 | tests = [] 70 | if not target_id or target_id[0] != 'Q': 71 | return tests 72 | for backward_relation in Relation: 73 | backward_relation_id = backward_relation.id() 74 | 
backward_subjects = subjects_given_relation_target(backward_relation_id, subject_id) 75 | for backward_subject in backward_subjects: 76 | phrase = relation_couple_to_phrase(backward_relation, relation) 77 | if phrase is None: 78 | continue 79 | phrase = phrase.replace('', get_label(backward_subject)) 80 | test_query = TwoHopQuery(backward_subject, backward_relation, subject_id, relation, target_id, phrase) 81 | condition_queries = [Query(backward_subject, backward_relation, subject_id)] 82 | tests.append(TestCase(test_query=test_query, condition_queries=condition_queries)) 83 | return tests 84 | 85 | 86 | # def temporal_axis(subject_id: str, relation: Relation, previous_target_id: str): 87 | # tests = [] 88 | # if relation.is_modification(): 89 | # return tests 90 | # test_query = Query(subject_id, relation, previous_target_id) 91 | # condition_queries = [test_query] 92 | # tests.append(TestCase(test_query=test_query, condition_queries=condition_queries)) 93 | # return tests 94 | 95 | def temporal_axis(subject_id: str, relation: Relation, target_id: str): 96 | tests = [] 97 | if relation.is_modification(): 98 | return tests 99 | wikidata_targets = subject_relation_to_targets(subject_id, relation) 100 | relational_phrase = relation.phrase(get_label(subject_id)) 101 | if 'is' in relational_phrase: 102 | prefix = relational_phrase[:-3] 103 | elif 'are' in relational_phrase: 104 | prefix = relational_phrase[:-4] 105 | phrase = prefix + f', which is not {get_label(target_id)}, is' 106 | test_query = Query(subject_id, relation, wikidata_targets, phrase=phrase) 107 | condition_queries = [test_query] 108 | tests.append(TestCase(test_query=test_query, condition_queries=condition_queries)) 109 | return tests 110 | 111 | 112 | # for test in subject_aliasing_axis('Q42', 'occupation', 'Q36834'): 113 | # print(test) 114 | 115 | 116 | -------------------------------------------------------------------------------- /src/build_counterfactual_examples.py: -------------------------------------------------------------------------------- 1 | import random 2 | from wikidata.utils import load_json, write_json 3 | 4 | 5 | if __name__ == '__main__': 6 | relation2optional_targets = load_json('./wikidata/relation2optional_targets.json') 7 | sampled_facts = load_json('./wikidata/100_sampled_facts.json') 8 | dataset = [] 9 | for fact in sampled_facts: 10 | subject, relation_target = fact 11 | relation, target = relation_target 12 | optional_targets = relation2optional_targets[relation] 13 | random_target = random.sample(optional_targets, 1)[0] 14 | counterfactual = (subject, relation, random_target) 15 | dataset.append({'fact': (subject, relation, target), 'counterfactual': counterfactual}) 16 | 17 | print(fact) 18 | print(counterfactual) 19 | print('\n') 20 | 21 | write_json(dataset, './generations/fact_and_counterfactual_samples1.json') -------------------------------------------------------------------------------- /src/create_relation2optional_targets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import defaultdict 5 | from relation import Relation 6 | 7 | 8 | checkable_relations = [relation.formal_name() for relation in Relation] 9 | 10 | 11 | def get_relation2optional_targets(wikidata_dir: str): 12 | relevant_files = [] 13 | for file in os.listdir(wikidata_dir): 14 | if file[-5:] == '.json': 15 | relevant_files.append(os.path.join(wikidata_dir, file)) 16 | 17 | result_dict = defaultdict(set) 18 | for i, 
path in enumerate(relevant_files): 19 | print(f'{i+1}/{len(relevant_files)}') 20 | with open(path, 'r+', encoding='utf-8') as f: 21 | curr_part = json.load(f) 22 | for subject, facts in curr_part.items(): 23 | for relation, target in facts: 24 | if relation in checkable_relations and len(result_dict[relation]) < 100000: 25 | result_dict[relation].add(target) 26 | 27 | result_dict = {relation: list(targets) for relation, targets in result_dict.items()} 28 | return result_dict 29 | 30 | 31 | if __name__ == '__main__': 32 | wikidata_dir = './wikidata/wikidata_full_kg/filtered_relations' 33 | relation2optional_targets = get_relation2optional_targets(wikidata_dir) 34 | with open('./wikidata/relation2optional_targets_new_limited.json', 'w+', encoding='utf-8') as f: 35 | json.dump(relation2optional_targets, f) 36 | print(relation2optional_targets.keys()) 37 | print(len(relation2optional_targets)) -------------------------------------------------------------------------------- /src/dataset_statistics.py: -------------------------------------------------------------------------------- 1 | from build_benchmark import construct_recently_modified_benchmark 2 | from wikidata.utils import load_json, ent_label2id 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | recently_modified_facts = construct_recently_modified_benchmark() 7 | 8 | # number of facts / popularities 9 | collection = [] 10 | for example in recently_modified_facts: 11 | ent2num_of_facts = load_json('./subject2num_of_facts.json') 12 | subject_id = example.fact._subject_id 13 | collection.append(ent2num_of_facts[ent_label2id(subject_id)]) 14 | print(len(collection)) 15 | -------------------------------------------------------------------------------- /src/fact.py: -------------------------------------------------------------------------------- 1 | from relation import Relation 2 | from wikidata.utils import get_label 3 | from query import Query 4 | 5 | 6 | class Fact: 7 | 8 | def __init__(self, subject_id, relation, target_id): 9 | self._subject_id = subject_id 10 | self._relation = relation 11 | self._target_id = target_id 12 | 13 | def get_subject_label(self): 14 | return get_label(self._subject_id) 15 | 16 | def get_target_label(self): 17 | return get_label(self._target_id) 18 | 19 | def get_relation_label(self): 20 | return self._relation.name.replace('_', ' ') 21 | 22 | def get_fact_query(self): 23 | return Query(self._subject_id, self._relation, self._target_id) 24 | 25 | def get_fact_prompt(self): 26 | return self._relation.phrase(get_label(self._subject_id)) 27 | 28 | def get_fact_phrased(self): 29 | return self._relation.phrase(get_label(self._subject_id)) + f' {get_label(self._target_id)}.' 
30 | 31 | def to_dict(self): 32 | return { 33 | 'prompt': self.get_fact_phrased(), 34 | 'subject_id': self._subject_id, 35 | 'relation': self._relation.name, 36 | 'target_id': self._target_id 37 | } 38 | 39 | @staticmethod 40 | def from_dict(d): 41 | return Fact(d['subject_id'], Relation[d['relation']], d['target_id']) 42 | 43 | def __str__(self): 44 | return f'({self.get_subject_label()}, {self.get_relation_label()}, {self.get_target_label()})' 45 | 46 | -------------------------------------------------------------------------------- /src/filter_benchmark_by_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from benchmark import Dataset, RecentlyAddedExample, CounterFactualExample 4 | from queryexecutor import GPT2QueryExecutor, GPTJQueryExecutor, GPTNeoXQueryExecutor, LlamaQueryExecutor, \ 5 | GPT3QueryExecutor 6 | from testrunner import TestRunner, ExampleResult, TestResult 7 | 8 | 9 | def get_query_executor(model_name): 10 | if model_name.startswith('gpt2-'): 11 | mode_size = model_name.split('-')[1] 12 | return GPT2QueryExecutor(mode_size) 13 | elif model_name == 'gpt-j': 14 | return GPTJQueryExecutor() 15 | elif model_name == 'gpt-neox': 16 | return GPTNeoXQueryExecutor() 17 | elif model_name.startswith('llama-'): 18 | mode_size = model_name.split('-')[1] 19 | return LlamaQueryExecutor(mode_size) 20 | elif model_name == 'gpt-3': 21 | return GPT3QueryExecutor() 22 | else: 23 | raise Exception('Unknown model name') 24 | 25 | 26 | def filter_tests(test_runner, example, testcases, include_all_facts): 27 | example_result, test_results = test_runner.run_testcases(example, testcases) 28 | if not include_all_facts and example_result != ExampleResult.EXECUTED: 29 | return None 30 | return [test for test in testcases if test not in test_results[TestResult.NOT_EXECUTED]] 31 | 32 | 33 | def main(args): 34 | print('Loading dataset') 35 | dataset = Dataset.from_file(args.benchmark) 36 | print('Loading model') 37 | query_executor = get_query_executor(args.model) 38 | test_runner = TestRunner(query_executor, None) 39 | filtered_examples = [] 40 | test_count = 0 41 | filtered_test_count = 0 42 | 43 | for i, example in enumerate(dataset.examples): 44 | print(f'Example {i + 1} / {len(dataset.examples)}: {example.fact.to_dict()}') 45 | 46 | test_count += len(example.making_up_tests) + len(example.logical_constraints) + \ 47 | len(example.subject_paraphrasing_tests) + len(example.two_hop_tests) + \ 48 | len(example.prev_storage_tests) 49 | prev_filtered_test_count = filtered_test_count 50 | 51 | filtered_making_up_tests = filter_tests(test_runner, example, example.making_up_tests, args.include_all_facts) 52 | filtered_test_count += len(filtered_making_up_tests) 53 | 54 | if filtered_making_up_tests is None: # Example shouldn't be included at all 55 | continue 56 | 57 | filtered_logical_constraints = filter_tests(test_runner, example, example.logical_constraints, 58 | args.include_all_facts) 59 | filtered_test_count += len(filtered_logical_constraints) 60 | 61 | filtered_subject_paraphrasing_tests = filter_tests(test_runner, example, example.subject_paraphrasing_tests, 62 | args.include_all_facts) 63 | filtered_test_count += len(filtered_subject_paraphrasing_tests) 64 | 65 | filtered_two_hop_tests = filter_tests(test_runner, example, example.two_hop_tests, args.include_all_facts) 66 | filtered_test_count += len(filtered_two_hop_tests) 67 | 68 | filtered_prev_storage_tests = filter_tests(test_runner, example, example.prev_storage_tests, 
69 | args.include_all_facts) 70 | filtered_test_count += len(filtered_prev_storage_tests) 71 | 72 | if prev_filtered_test_count == filtered_test_count: # Example has no tests that passed the filter 73 | continue 74 | 75 | if isinstance(example, RecentlyAddedExample): 76 | filtered_examples.append(RecentlyAddedExample(example.fact, filtered_making_up_tests, 77 | filtered_logical_constraints, 78 | filtered_subject_paraphrasing_tests, 79 | filtered_two_hop_tests, filtered_prev_storage_tests)) 80 | elif isinstance(example, CounterFactualExample): 81 | filtered_examples.append(CounterFactualExample(example.fact, example.previous_fact, 82 | filtered_making_up_tests, filtered_logical_constraints, 83 | filtered_subject_paraphrasing_tests, filtered_two_hop_tests, 84 | filtered_prev_storage_tests)) 85 | 86 | print(f'Filtered dataset has {len(filtered_examples)} / {len(dataset.examples)} examples') 87 | print(f'Filtered dataset has {filtered_test_count} / {test_count} tests') 88 | print('Saving filtered dataset') 89 | filtered_dataset = Dataset(filtered_examples) 90 | filtered_dataset.to_file(args.output) 91 | 92 | print('Done') 93 | 94 | 95 | if __name__ == '__main__': 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('benchmark', help='The benchmark file path') 98 | parser.add_argument('model', help='The model name', choices=['gpt2-medium', 'gpt2-large', 'gpt2-xl', 99 | 'gpt-j', 'gpt-neox', 100 | 'llama-7b', 'llama-13b', 101 | 'gpt-3']) 102 | parser.add_argument('output', help='The output filtered benchmark file path') 103 | parser.add_argument('--include-all-facts', action=argparse.BooleanOptionalAction, default=True, 104 | help='Whether to include all facts or only facts that pass the known/unknown test') 105 | 106 | main(parser.parse_args()) 107 | -------------------------------------------------------------------------------- /src/memit/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=clean_ipynb 2 | -------------------------------------------------------------------------------- /src/memit/.gitignore: -------------------------------------------------------------------------------- 1 | # Pipeline dumps, data directory 2 | results 3 | data 4 | !notebooks/data 5 | *_tmp_*_.json 6 | *_kmeng01g1gn* 7 | 8 | # Pre-trained hypernetworks 9 | baselines/*/weights 10 | 11 | # Mac specific 12 | .idea 13 | .vscode 14 | .DS_Store 15 | 16 | # Latex 17 | *.aux 18 | *.dvi 19 | *.fdb_latexmk 20 | *.fls 21 | *.log 22 | *.pdf 23 | *.synctex.gz 24 | *.out 25 | *.toc 26 | *.nps 27 | 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | *.py.swp 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | share/python-wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | cover/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | .pybuilder/ 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | # For a library or package, you might want to ignore these files since the code is 115 | # intended to run in multiple environments; otherwise, check them in: 116 | # .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | -------------------------------------------------------------------------------- /src/memit/CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | preferred-citation: 4 | type: article 5 | authors: 6 | - family-names: "Meng" 7 | given-names: "Kevin" 8 | - family-names: "Sen Sharma" 9 | given-names: "Arnab" 10 | - family-names: "Andonian" 11 | given-names: "Alex" 12 | - family-names: "Belinkov" 13 | given-names: "Yonatan" 14 | - family-names: "Bau" 15 | given-names: "David" 16 | journal: "arXiv preprint arXiv:2210.07229" 17 | title: "Mass-Editing Memory in a Transformer" 18 | year: 2022 -------------------------------------------------------------------------------- /src/memit/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kevin Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/memit/README.md: -------------------------------------------------------------------------------- 1 | # MEMIT: Mass-Editing Memory in a Transformer 2 | 3 | Editing thousands of facts into a transformer memory at once. 4 | 5 | 6 | 7 | ## Table of Contents 8 | 9 | - [Installation](#installation) 10 | - [MEMIT Algorithm Demo](#memit-algorithm-demo) 11 | - [Running the Full Evaluation Suite](#running-the-full-evaluation-suite) 12 | - [Generating Scaling Curves](#generating-scaling-curves) 13 | - [How to Cite](#how-to-cite) 14 | 15 | ## Installation 16 | 17 | We recommend `conda` for managing Python, CUDA, and PyTorch; `pip` is for everything else. To get started, simply install `conda` and run: 18 | ```bash 19 | CONDA_HOME=$CONDA_HOME ./scripts/setup_conda.sh 20 | ``` 21 | 22 | `$CONDA_HOME` should be the path to your `conda` installation, e.g., `~/miniconda3`. 23 | 24 | ## MEMIT Algorithm Demo 25 | 26 | [`notebooks/memit.ipynb`](notebooks/memit.ipynb) demonstrates MEMIT. The API is simple; simply specify a *requested rewrite* of the following form: 27 | 28 | ```python 29 | request = [ 30 | { 31 | "prompt": "{} plays the sport of", 32 | "subject": "LeBron James", 33 | "target_new": { 34 | "str": "football" 35 | } 36 | }, 37 | { 38 | "prompt": "{} plays the sport of", 39 | "subject": "Michael Jordan", 40 | "target_new": { 41 | "str": "baseball" 42 | } 43 | }, 44 | ] 45 | ``` 46 | 47 | Other similar example(s) are included in the notebook. 48 | 49 | ## Running the Full Evaluation Suite 50 | 51 | [`experiments/evaluate.py`](experiments/evaluate.py) can be used to evaluate any method in [`baselines/`](baselines/). 
52 | 53 | For example: 54 | ``` 55 | python3 -m experiments.evaluate \ 56 | --alg_name=MEMIT \ 57 | --model_name=EleutherAI/gpt-j-6B \ 58 | --hparams_fname=EleutherAI_gpt-j-6B.json \ 59 | --num_edits=10000 \ 60 | --use_cache 61 | ``` 62 | Results from each run are stored at `results//run_` in a specific format: 63 | ```bash 64 | results/ 65 | |__ MEMIT/ 66 | |__ run_/ 67 | |__ params.json 68 | |__ case_0.json 69 | |__ case_1.json 70 | |__ ... 71 | |__ case_10000.json 72 | ``` 73 | 74 | To summarize the results, you can use [`experiments/summarize.py`](experiments/summarize.py): 75 | ```bash 76 | python3 -m experiments.summarize --dir_name=MEMIT --runs=run_,run_ 77 | ``` 78 | 79 | Running `python3 -m experiments.evaluate -h` or `python3 -m experiments.summarize -h` provides details about command-line flags. 80 | 81 | ## How to Cite 82 | 83 | ```bibtex 84 | @article{meng2022memit, 85 | title={Mass Editing Memory in a Transformer}, 86 | author={Kevin Meng and Sen Sharma, Arnab and Alex Andonian and Yonatan Belinkov and David Bau}, 87 | journal={arXiv preprint arXiv:2210.07229}, 88 | year={2022} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /src/memit/baselines/README.md: -------------------------------------------------------------------------------- 1 | We compare ROME against several open sourced state-of-the-art model editors. All are implemented in their respective folders. Implementations other than FT/FT+L are adapted from third parties. 2 | - Fine-Tuning (`ft`): Direct fine-tuning. 3 | - Constrained Fine-Tuning (`ft`): FT with $L_\infty$ norm constraint. Inspired by Zhu et al. [[Paper]](https://arxiv.org/abs/2012.00363) 4 | - Knowledge Neurons (`kn`): Dai et al. [[Code]](https://github.com/EleutherAI/knowledge-neurons) [[Paper]](https://arxiv.org/abs/2104.08696) 5 | - Knowledge Editor (`efk`): De Cao et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2104.08164) 6 | - Model Editor Networks with Gradient Decomposition (`mend`): Mitchell et al. 
[[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2110.11309) -------------------------------------------------------------------------------- /src/memit/baselines/ft/__init__.py: -------------------------------------------------------------------------------- 1 | from .ft_main import FTHyperParams, apply_ft_to_model, execute_ft 2 | -------------------------------------------------------------------------------- /src/memit/baselines/ft/ft_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from util.hparams import HyperParams 5 | 6 | 7 | @dataclass 8 | class FTHyperParams(HyperParams): 9 | # Method 10 | layers: List[int] 11 | num_steps: int 12 | lr: float 13 | weight_decay: float 14 | kl_factor: float 15 | norm_constraint: float 16 | 17 | # Module templates 18 | rewrite_module_tmp: str 19 | layer_module_tmp: str 20 | mlp_module_tmp: str 21 | attn_module_tmp: str 22 | ln_f_module: str 23 | lm_head_module: str 24 | 25 | # Defaults 26 | batch_size: int = 64 27 | wd_power_law: tuple = None # Scale weight decay by number of edits 28 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/README.md: -------------------------------------------------------------------------------- 1 | # MEND: Model Editing Networks using Gradient Decomposition 2 | 3 | If you run into any issues with the code, you can open an issue and/or email me at `eric.mitchell@cs.stanford.edu` 4 | 5 | ## Setup 6 | 7 | ### Environment 8 | 9 | This codebase uses Python 3.7.9. Other versions may work as well. 10 | 11 | Create a virtualenv ([pyenv](https://github.com/pyenv/pyenv) can help with this) 12 | and install the dependencies: 13 | 14 | $ python -m venv env 15 | $ source env/bin/activate 16 | (env) $ pip install -r requirements.txt 17 | 18 | ### Data 19 | 20 | You can download the data needed for this project from 21 | [this Google Drive link](https://drive.google.com/drive/folders/1jAqBE45jEKR-5pMkwxlVQ0V8eKxqWbxA?usp=sharing). 22 | Unzip each sub-directory into `mend/data` and you should be good to go. 23 | 24 | ## Running the code 25 | 26 | Run MEND training/evaluation for distilGPT-2 on the wikitext editing problem with: 27 | 28 | (env) $ python -m run +alg=mend +experiment=gen +model=distilgpt2 data.wiki_webtext=False 29 | 30 | Other valid algs include `efk` ([KnowledgeEditor](https://arxiv.org/abs/2104.08164)) 31 | and `enn` ([Editable Neural Networks](https://arxiv.org/abs/2004.00345)). Valid experiments 32 | include `fc` (FEVER fact checking) and `qa` (zsRE question-answering). Splits and rephrases 33 | for both come from [De Cao et. al](https://arxiv.org/abs/2104.08164). Check `config/model` 34 | for options for editable models (note that all models don't work for all experiments; GPT-style 35 | models only work with `gen`, seq2seq models only work with `qa`, and BERT only works with `fc`). 36 | 37 | Also note that in the paper, we sample locality data from different datasets depending on the model. 38 | By default, training will use [Natural Questions](https://ai.google.com/research/NaturalQuestions) 39 | data (not zsRE data) for computing drawdown in the `qa` experiment and 40 | [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/). 
For models such as the `distilgpt2` 41 | model we use (which was fine-tuned on wikitext) or the BART-base model, this behavior should be 42 | disabled with `data.wiki_webtext=False` or `data.zsre_nq=False`, respectively. 43 | 44 | ## Citing the paper 45 | 46 | If this code or paper was useful, please consider using the following citation: 47 | 48 | @article{mitchell2021fast, 49 | title={Fast Model Editing at Scale}, 50 | author={Mitchell, Eric and Lin, Charles and Bosselut, Antoine and Finn, Chelsea and Manning, Christopher D.}, 51 | year={2021}, 52 | journal={CoRR}, 53 | url={https://arxiv.org/pdf/2110.11309.pdf} 54 | } 55 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/__init__.py: -------------------------------------------------------------------------------- 1 | from .mend_hparams import MENDHyperParams 2 | from .mend_main import MendRewriteExecutor 3 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/algs/enn.py: -------------------------------------------------------------------------------- 1 | import higher 2 | import torch 3 | import torch.nn as nn 4 | from editable_model import EditableModel 5 | from utils import _logits 6 | 7 | 8 | def fomaml_callback(all_grads): 9 | return [g.detach() if g is not None else None for g in all_grads] 10 | 11 | 12 | class ENN(EditableModel): 13 | def __init__( 14 | self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None 15 | ): 16 | super().__init__(model, config, model_constructor) 17 | 18 | if edit_lrs is None: 19 | edit_lrs = nn.Parameter( 20 | torch.tensor([config.edit_lr] * len(self.config.model.inner_params)) 21 | ) 22 | self.edit_lrs = edit_lrs 23 | 24 | if edit_loss_fn is not None: 25 | self.edit_loss_fn = edit_loss_fn 26 | 27 | self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x 28 | 29 | def outer_parameters(self): 30 | if self.config.no_grad_layers is None: 31 | return super().outer_parameters() 32 | else: 33 | params = [self.edit_lrs] 34 | for m in self.model.modules(): 35 | if isinstance(m, nn.ModuleList): 36 | params.extend(list(m[self.config.no_grad_layers :].parameters())) 37 | return params 38 | 39 | def get_state_dict(self): 40 | return self.state_dict() 41 | 42 | def edit(self, batch, condition=None, detach_history=False): 43 | opt = torch.optim.SGD( 44 | [ 45 | {"params": p, "lr": None} 46 | for (n, p) in self.model.named_parameters() 47 | if n in self.config.model.inner_params 48 | ] 49 | ) 50 | with torch.enable_grad(), higher.innerloop_ctx( 51 | self.model, 52 | opt, 53 | override={"lr": list(self.edit_lrs)}, 54 | copy_initial_weights=False, 55 | track_higher_grads=self.training, 56 | in_place=True, 57 | ) as (fmodel, diffopt): 58 | fmodel.eval() 59 | for edit_step in range(self.config.enn.n_edit_steps): 60 | output = _logits(fmodel(**batch)) 61 | loss = self.edit_loss_fn(output, batch["labels"])["nll"] 62 | diffopt.step(loss, grad_callback=self.grad_callback) 63 | 64 | if not detach_history: 65 | model_edited = fmodel 66 | else: 67 | model_edited = self.model_constructor() 68 | model_edited.load_state_dict(fmodel.state_dict()) 69 | model_edited.train(self.training) 70 | 71 | return ( 72 | ENN( 73 | model_edited, 74 | self.config, 75 | self.model_constructor, 76 | edit_lrs=self.edit_lrs, 77 | edit_loss_fn=self.edit_loss_fn, 78 | ), 79 | {}, 80 | ) 81 | 82 | 83 | def test(): 84 | import copy 85 | import types 86 | 87 | import transformers 88 | 89 | model = 
transformers.GPT2LMHeadModel.from_pretrained("gpt2") 90 | 91 | config = types.SimpleNamespace() 92 | config.edit_lr = 0.1 93 | config.model.inner_params = [ 94 | "transformer.h.9.mlp.c_fc.weight", 95 | "transformer.h.9.mlp.c_proj.weight", 96 | "transformer.h.10.mlp.c_fc.weight", 97 | "transformer.h.10.mlp.c_proj.weight", 98 | "transformer.h.11.mlp.c_fc.weight", 99 | "transformer.h.11.mlp.c_proj.weight", 100 | ] 101 | config.enn = {"n_edit_steps": 2, "first_order": False} 102 | 103 | enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda() 104 | 105 | x = torch.arange(100).view(5, 20).cuda() + 1000 106 | 107 | edited = enn.edit(x, masks=torch.ones_like(x), labels=x) 108 | 109 | orig_param = [ 110 | p 111 | for (n, p) in enn.model.named_parameters() 112 | if n == config.model.inner_params[-1] 113 | ][0] 114 | edited_param = [ 115 | p 116 | for (n, p) in edited.model.named_parameters() 117 | if n == config.model.inner_params[-1] 118 | ][0] 119 | 120 | print((orig_param - edited_param).abs().max()) 121 | edited.eval() 122 | print( 123 | enn(x, labels=x).loss, 124 | edited(x, labels=x).loss, 125 | edited.edit_loss_fn(edited(x).logits, x)["nll"], 126 | ) 127 | edited.edit_loss_fn(edited(x).logits, x).backward() 128 | import pdb 129 | 130 | pdb.set_trace() 131 | 132 | 133 | if __name__ == "__main__": 134 | with torch.autograd.set_detect_anomaly(True): 135 | test() 136 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/algs/ft.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import higher 4 | import torch 5 | import torch.nn as nn 6 | from editable_model import EditableModel 7 | from higher.patch import monkeypatch as make_functional 8 | from losses import kl_loc_loss 9 | from utils import _inner_params, _logits 10 | 11 | 12 | class FT(EditableModel): 13 | """ 14 | Fine-tuning approach. Does not require training. 
15 | """ 16 | 17 | def __init__(self, model, config, model_constructor, edit_loss_fn=None): 18 | super().__init__(model, config, model_constructor) 19 | 20 | if edit_loss_fn is not None: 21 | self.edit_loss_fn = edit_loss_fn 22 | 23 | self.locality_loss_fn = kl_loc_loss 24 | self.loc_ids = None 25 | self.loc_masks = None 26 | self.loc_sampler = None 27 | 28 | def _edit_loss(self, model, p0, p_edited, edit_batch): 29 | output = _logits(model(**edit_batch, params=p_edited)) 30 | loss_dict = self.edit_loss_fn(output, edit_batch["labels"]) 31 | l_edit, acc = loss_dict["nll"], loss_dict["acc"] 32 | if self.config.ft.locality.enabled: 33 | if self.config.ft.locality.oracle: 34 | loc_batch = next(self.loc_sampler)["loc"] 35 | else: 36 | raise NotImplementedError 37 | 38 | with torch.no_grad(): 39 | original_base_logits = _logits(model(**loc_batch, params=p0)) 40 | edited_base_logits = _logits(model(**loc_batch, params=p_edited)) 41 | kl_mask = loc_batch.get( 42 | "decoder_attention_mask", loc_batch["attention_mask"] 43 | ) 44 | l_loc = self.locality_loss_fn( 45 | original_base_logits, edited_base_logits, mask=kl_mask 46 | ) 47 | loss = l_loc + self.config.ft.locality.cedit * l_edit 48 | else: 49 | l_loc = torch.tensor(float("nan")) 50 | loss = l_edit 51 | return loss, l_edit, l_loc, acc 52 | 53 | def accuracy(self, output, labels): 54 | if output.shape[-1] != 1: 55 | shifted_output = output.argmax(-1)[:, :-1] 56 | shifted_labels = labels[:, 1:] 57 | to_predict = (shifted_labels != -100).sum() 58 | correct = (shifted_output == shifted_labels).sum() 59 | acc = correct.float() / to_predict.float() 60 | else: 61 | acc = ((output > 0) == labels.bool()).sum().float() 62 | return acc 63 | 64 | def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p): 65 | return ( 66 | f"step: {step}".ljust(14) 67 | + f"loss: {loss.item():.5f}".ljust(18) 68 | + f"l_edit: {l_edit.item():.5f}".ljust(18) 69 | + f"l_loc: {l_loc.item():.5f}".ljust(18) 70 | + f"acc: {acc.item():.2f}".ljust(14) 71 | + f"norm: {res_p.view(-1).norm().item():.5f}" 72 | ) 73 | 74 | def edit(self, batch, condition=None, detach_history=False): 75 | edit_model = self.model.eval() 76 | p0 = list(edit_model.named_parameters()) 77 | 78 | if not isinstance(edit_model, higher.patch._MonkeyPatchBase): 79 | edit_model = make_functional( 80 | self.model, track_higher_grads=False, in_place=True 81 | ) 82 | 83 | packed_residuals = {} 84 | opt_params = [] 85 | for n, p in _inner_params( 86 | edit_model.named_parameters(), self.config.model.inner_params 87 | ): 88 | if self.config.ft.rank is not None: 89 | u = nn.Parameter( 90 | torch.randn(p.shape[0], self.config.ft.rank, device=p.device) 91 | * self.config.ft.init_std 92 | ) 93 | v = nn.Parameter( 94 | torch.zeros(self.config.ft.rank, p.shape[1], device=p.device) 95 | ) 96 | res = [u, v] 97 | else: 98 | res = [nn.Parameter(torch.zeros_like(p, device=p.device))] 99 | 100 | packed_residuals[n] = res 101 | opt_params.extend(res) 102 | 103 | assert len(opt_params) == len(self.config.model.inner_params) 104 | OptClass = getattr(torch.optim, self.config.ft.opt) 105 | opt = OptClass(opt_params, lr=self.config.edit_lr) 106 | 107 | start_time = time.time() 108 | for edit_step in range(self.config.ft.max_edit_steps): 109 | if self.config.ft.time_limit is not None and ( 110 | time.time() - start_time > self.config.ft.time_limit 111 | ): 112 | break 113 | residuals = { 114 | k: v[0] @ v[1] if len(v) == 2 else v[0] 115 | for k, v in packed_residuals.items() 116 | } 117 | edited_params = [ 118 | p if n not in residuals 
else p.detach() + residuals[n] for n, p in p0 119 | ] 120 | loss, l_edit, l_loc, acc = self._edit_loss( 121 | edit_model, [p for n, p in p0], edited_params, batch 122 | ) 123 | 124 | if self.config.ft.verbose: 125 | residual = list(residuals.values())[-1] 126 | print( 127 | self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual), 128 | end="\r", 129 | ) 130 | 131 | if acc == 1.0: 132 | break 133 | 134 | for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)): 135 | p.grad = g 136 | torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip) 137 | opt.step() 138 | opt.zero_grad() 139 | 140 | if detach_history: 141 | new_model = self.model_constructor() 142 | new_model.load_state_dict(edit_model.state_dict()) 143 | edit_model = new_model 144 | edit_model.train(self.training) 145 | 146 | return ( 147 | FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), 148 | {}, 149 | ) 150 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/alg/efk.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: efk 4 | train_base: False 5 | lr: 1e-5 6 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/alg/enn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: enn 4 | train_base: True 5 | enn: 6 | first_order: False 7 | n_edit_steps: 1 8 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/alg/ft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | train_base: False 4 | alg: ft 5 | edit_lr: 5e-6 6 | ft: 7 | verbose: false 8 | max_edit_steps: 100 9 | time_limit: null 10 | locality: 11 | enabled: false 12 | oracle: true 13 | cedit: 1e-2 14 | batch_size: 1 15 | rank: null 16 | opt: RMSprop 17 | init_std: 0.01 18 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/alg/mend.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: mend 4 | lr: 1e-6 5 | train_base: False 6 | edit_lr: 1e-4 7 | lr_lr: 1e-4 8 | mend: 9 | one_sided: False 10 | n_hidden: 1 11 | hidden_dim: null 12 | init: id 13 | norm: True 14 | combine: True 15 | x_only: False 16 | delta_only: False 17 | act: relu 18 | rank: 1920 19 | mlp_class: IDMLP 20 | shared: True 21 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/config.yaml: -------------------------------------------------------------------------------- 1 | alg: enn 2 | lr: 1e-5 3 | edit_lr: 1e-2 4 | seed: 0 5 | debug: False 6 | model_save_pt: 5000 7 | edit_bs: 1 8 | silent: False 9 | max_iters: 1000000 10 | log_interval: 100 11 | val_interval: 5000 12 | lr_lr: 1e-3 13 | batch_size: 2 14 | val_batch_size: 5 15 | accumulate_bs: 10 16 | cedit: 0.1 17 | cloc: 1.0 18 | cbase: 1.0 19 | val_steps: 500 20 | device: cuda 21 | base_loss: distill 22 | oracle: False 23 | train: True 24 | train_base: True 25 | opt: Adam 26 | single_batch: False 27 | archive: null 28 | grad_clip: 100. 
29 | ref: null 30 | early_stop_patience: 20000 31 | early_stop_key: "loss/total_edit_val" 32 | dropout: 0.0 33 | tokenizer: null 34 | results_dir: null 35 | no_grad_layers: null 36 | eval_only: False 37 | half: False 38 | save: False 39 | 40 | model: 41 | pt: null 42 | 43 | data: 44 | path: null 45 | rephrase: true 46 | zsre_nq: true 47 | nq_path: ${hydra:runtime.cwd}/data/nq 48 | wiki_webtext: true 49 | n_edits: 1 50 | 51 | eval: 52 | verbose: True 53 | log_interval: 100 54 | final_eval: True 55 | 56 | hydra: 57 | run: 58 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}} 59 | sweep: 60 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f} 61 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/experiment/fc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: fc 4 | dataset: fever 5 | cbase: 1.0 6 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/experiment/gen.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: gen 4 | dataset: wikitext-103 5 | cbase: 10.0 6 | data: 7 | path: ${hydra:runtime.cwd}/data/10token/data/self_sample/ 8 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/experiment/qa.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: qa 4 | dataset: zsre 5 | cbase: 1.0 6 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/bart-base.yaml: -------------------------------------------------------------------------------- 1 | name: facebook/bart-base 2 | class_name: BartForConditionalGeneration 3 | tokenizer_class: BartTokenizerFast 4 | tokenizer_name: facebook/bart-base 5 | inner_params: 6 | - model.encoder.layers.4.fc1.weight 7 | - model.encoder.layers.4.fc2.weight 8 | - model.encoder.layers.5.fc1.weight 9 | - model.encoder.layers.5.fc2.weight 10 | - model.decoder.layers.4.fc1.weight 11 | - model.decoder.layers.4.fc2.weight 12 | - model.decoder.layers.5.fc1.weight 13 | - model.decoder.layers.5.fc2.weight 14 | 15 | pt: ${hydra:runtime.cwd}/data/zsre/QA_model.ckpt -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/bert-base.yaml: -------------------------------------------------------------------------------- 1 | name: bert-base-uncased 2 | class_name: BertClassifier 3 | tokenizer_class: BertTokenizerFast 4 | tokenizer_name: bert-base-uncased 5 | inner_params: 6 | - model.encoder.layer.9.intermediate.dense.weight 7 | - model.encoder.layer.9.output.dense.weight 8 | - model.encoder.layer.10.intermediate.dense.weight 9 | - model.encoder.layer.10.output.dense.weight 10 | - model.encoder.layer.11.intermediate.dense.weight 11 | - model.encoder.layer.11.output.dense.weight 12 | 13 | pt: ${hydra:runtime.cwd}/data/fever/FC_model.ckpt -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/distilgpt2.yaml: -------------------------------------------------------------------------------- 1 | name: MYX4567/distilgpt2-finetuned-wikitext2 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: distilgpt2 5 | inner_params: 6 | - 
transformer.h.3.mlp.c_fc.weight 7 | - transformer.h.3.mlp.c_proj.weight 8 | - transformer.h.4.mlp.c_fc.weight 9 | - transformer.h.4.mlp.c_proj.weight 10 | - transformer.h.5.mlp.c_fc.weight 11 | - transformer.h.5.mlp.c_proj.weight 12 | 13 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/gpt2.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2 5 | inner_params: 6 | - transformer.h.9.mlp.c_proj.weight 7 | - transformer.h.9.mlp.c_fc.weight 8 | - transformer.h.10.mlp.c_proj.weight 9 | - transformer.h.10.mlp.c_fc.weight 10 | - transformer.h.11.mlp.c_proj.weight 11 | - transformer.h.11.mlp.c_fc.weight -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/gpt2large.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-large 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-large 5 | inner_params: 6 | - transformer.h.33.mlp.c_proj.weight 7 | - transformer.h.33.mlp.c_fc.weight 8 | - transformer.h.34.mlp.c_proj.weight 9 | - transformer.h.34.mlp.c_fc.weight 10 | - transformer.h.35.mlp.c_proj.weight 11 | - transformer.h.35.mlp.c_fc.weight 12 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/gpt2medium.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-medium 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-medium 5 | inner_params: 6 | - transformer.h.21.mlp.c_proj.weight 7 | - transformer.h.21.mlp.c_fc.weight 8 | - transformer.h.22.mlp.c_proj.weight 9 | - transformer.h.22.mlp.c_fc.weight 10 | - transformer.h.23.mlp.c_proj.weight 11 | - transformer.h.23.mlp.c_fc.weight -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/gpt2xl.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-xl 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-xl 5 | inner_params: 6 | - transformer.h.45.mlp.c_proj.weight 7 | - transformer.h.45.mlp.c_fc.weight 8 | - transformer.h.46.mlp.c_proj.weight 9 | - transformer.h.46.mlp.c_fc.weight 10 | - transformer.h.47.mlp.c_proj.weight 11 | - transformer.h.47.mlp.c_fc.weight 12 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/gptj.yaml: -------------------------------------------------------------------------------- 1 | name: EleutherAI/gpt-j-6B 2 | class_name: GPTJForCausalLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: EleutherAI/gpt-j-6B 5 | inner_params: 6 | - transformer.h.25.mlp.fc_in.weight 7 | - transformer.h.25.mlp.fc_out.weight 8 | - transformer.h.26.mlp.fc_in.weight 9 | - transformer.h.26.mlp.fc_out.weight 10 | - transformer.h.27.mlp.fc_in.weight 11 | - transformer.h.27.mlp.fc_out.weight 12 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/gptneo27.yaml: -------------------------------------------------------------------------------- 1 | name: EleutherAI/gpt-neo-2.7B 2 | class_name: GPTNeoForCausalLM 3 | tokenizer_class: 
GPT2TokenizerFast 4 | tokenizer_name: EleutherAI/gpt-neo-2.7B 5 | inner_params: 6 | - transformer.h.29.mlp.c_fc.weight 7 | - transformer.h.29.mlp.c_proj.weight 8 | - transformer.h.30.mlp.c_fc.weight 9 | - transformer.h.30.mlp.c_proj.weight 10 | - transformer.h.31.mlp.c_fc.weight 11 | - transformer.h.31.mlp.c_proj.weight 12 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/t5large.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-large-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-large-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/t5small.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-small-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-small-ssm-nq 5 | inner_params: 6 | - encoder.block.6.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.6.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.7.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.7.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.6.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.6.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.7.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.7.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/t5xl.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-xl-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-xl-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /src/memit/baselines/mend/config/model/t5xxl.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-xxl-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-xxl-ssm-nq 5 | inner_params: 6 | - encoder.block.22.layer.1.DenseReluDense.wi_0.weight 7 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 8 | - encoder.block.23.layer.1.DenseReluDense.wi_0.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 10 | - decoder.block.22.layer.2.DenseReluDense.wi_0.weight 11 | - 
decoder.block.22.layer.2.DenseReluDense.wo.weight 12 | - decoder.block.23.layer.2.DenseReluDense.wi_0.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 14 | 15 | pt: null -------------------------------------------------------------------------------- /src/memit/baselines/mend/data_classes/fever.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import jsonlines 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from utils import EditBatchSampler, dict_to 8 | 9 | POSITIVE_CLASS = "SUPPORTS" 10 | 11 | 12 | class BinaryAugmentedKILT(Dataset): 13 | def __init__(self, tokenizer, data_path, config, max_length=32): 14 | super().__init__() 15 | self.tokenizer = tokenizer 16 | self.data = [] 17 | self.config = config 18 | 19 | def extract(d): 20 | extracted = { 21 | k: d[k] 22 | for k in [ 23 | "logit", 24 | "input", 25 | "prediction", 26 | "alternatives", 27 | "filtered_rephrases", 28 | ] 29 | } 30 | extracted["label"] = d["output"][0]["answer"] 31 | return extracted 32 | 33 | with jsonlines.open(data_path) as f: 34 | for d in f: 35 | if len(d["alternatives"]) > 0 and len(d["filtered_rephrases"]) > 0: 36 | self.data.append(extract(d)) 37 | 38 | self.max_length = max_length 39 | 40 | def __len__(self): 41 | return len(self.data) 42 | 43 | def __getitem__(self, item): 44 | obj = self.data[item] 45 | rephrase = random.choice(self.data[item]["filtered_rephrases"]) 46 | output = { 47 | "label": obj["label"] == POSITIVE_CLASS, 48 | "src": obj["input"], 49 | "rephrase": rephrase, 50 | "pred": obj["prediction"] == POSITIVE_CLASS, 51 | "alt": obj["alternatives"][0] == POSITIVE_CLASS, 52 | "cond_flip": "{} >> {} || {}".format( 53 | obj["prediction"], 54 | obj["alternatives"][0], 55 | obj["input"], 56 | ), 57 | "cond_orig": "{} >> {} || {}".format( 58 | obj["prediction"], 59 | obj["prediction"], 60 | obj["input"], 61 | ), 62 | "logit": obj["logit"], 63 | } 64 | 65 | return output 66 | 67 | def collate_fn(self, batch): 68 | src = [b["src"] for b in batch] 69 | rephrase = [batch[-1]["rephrase"]] 70 | 71 | flip_label = np.random.uniform() > 0.5 72 | predictions = [b["pred"] for b in batch] 73 | labels = [b["label"] for b in batch] 74 | labels[-1] = predictions[ 75 | -1 76 | ] # the last element in the batch is special (the edit element) 77 | cond = [batch[-1]["cond_orig"]] 78 | if flip_label: 79 | labels[-1] = batch[-1]["alt"] 80 | cond = [batch[-1]["cond_flip"]] 81 | 82 | batches = {} 83 | for k1, v1 in {"": src, "cond_": cond, "rephrase_": rephrase}.items(): 84 | encoded = self.tokenizer( 85 | v1, 86 | return_tensors="pt", 87 | padding=True, 88 | max_length=self.max_length, 89 | truncation=True, 90 | ) 91 | for k2, v2 in encoded.items(): 92 | batches[f"{k1}{k2}"] = v2 93 | 94 | batches["predictions"] = torch.tensor(predictions).long().view(-1, 1) 95 | batches["labels"] = torch.tensor(labels).long().view(-1, 1) 96 | batches["raw"] = batch 97 | return batches 98 | 99 | def edit_generator(self, batch_size, n=None): 100 | if n is None: 101 | n = len(self) 102 | sampler = EditBatchSampler( 103 | n, memorize_mode=self.config.single_batch, seed=self.config.seed 104 | ) 105 | while True: 106 | edit_idxs, loc_idxs = sampler.sample(batch_size) 107 | assert len(edit_idxs) == 1 108 | idxs = loc_idxs + edit_idxs 109 | toks = self.collate_fn([self[idx] for idx in idxs]) 110 | 111 | pass_keys = ["input_ids", "attention_mask", "labels"] 112 | edit_inner = {k: v[-1:] for k, v in toks.items() if k in pass_keys} 113 | if 
self.config.data.rephrase: 114 | edit_outer = {} 115 | edit_outer["input_ids"] = toks["rephrase_input_ids"] 116 | edit_outer["attention_mask"] = toks["rephrase_attention_mask"] 117 | edit_outer["labels"] = edit_inner["labels"] 118 | else: 119 | edit_outer = edit_inner 120 | loc = {k: v[:-1] for k, v in toks.items() if k in pass_keys} 121 | cond = { 122 | "input_ids": toks["cond_input_ids"], 123 | "attention_mask": toks["cond_attention_mask"], 124 | } 125 | 126 | batch = { 127 | "edit_inner": edit_inner, 128 | "edit_outer": edit_outer, 129 | "loc": loc, 130 | "cond": cond, 131 | } 132 | yield dict_to(batch, self.config.device) 133 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/data_classes/nq.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class NQDataset: 5 | def __init__(self, path: str, tokenizer, config): 6 | with open(path, "r") as f: 7 | self.data = json.load(f) 8 | 9 | self.questions = self.data["questions"] 10 | self.answers = self.data["answers"] 11 | self.tokenizer = tokenizer 12 | self.config = config 13 | 14 | def __getitem__(self, idx): 15 | idx = idx % len(self.questions) 16 | return self.questions[idx], self.answers[idx] 17 | 18 | @staticmethod 19 | def generate( 20 | out_path: str, 21 | prompt: bool = False, 22 | capitalize: bool = True, 23 | question_mark: bool = True, 24 | ): 25 | import os 26 | 27 | import datasets 28 | 29 | def process(text): 30 | if capitalize: 31 | text = text[0].capitalize() + text[1:] 32 | if question_mark: 33 | text = text + "?" 34 | if prompt: 35 | text = "nq question: " + text 36 | return text 37 | 38 | def extract(d): 39 | questions = [process(q["text"]) for q in d["question"]] 40 | answers = [ 41 | [a["text"][0] for a in ann["short_answers"] if len(a["text"])] 42 | for ann in d["annotations"] 43 | ] 44 | questions = [q for q, a in zip(questions, answers) if len(a)] 45 | answers = [min(a, key=len) for a in answers if len(a)] 46 | return questions, answers 47 | 48 | train = datasets.load_dataset("natural_questions", split="train") 49 | tq, ta = extract(train) 50 | val = datasets.load_dataset("natural_questions", split="validation") 51 | vq, va = extract(val) 52 | 53 | if not os.path.exists(out_path): 54 | os.makedirs(out_path) 55 | with open(f"{out_path}/train.json", "w") as f: 56 | json.dump({"questions": tq, "answers": ta}, f) 57 | with open(f"{out_path}/validation.json", "w") as f: 58 | json.dump({"questions": vq, "answers": va}, f) 59 | 60 | 61 | if __name__ == "__main__": 62 | import argparse 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument("--out_path", type=str, default="data/nq") 66 | args = parser.parse_args() 67 | NQDataset.generate(args.out_path) 68 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/data_classes/wiki.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import random 5 | 6 | from datasets import load_dataset 7 | from torch.utils.data import Dataset 8 | from utils import EditBatchSampler, dict_to, scr 9 | 10 | LOG = logging.getLogger(__name__) 11 | 12 | 13 | def is_ascii(s): 14 | return all(ord(c) < 128 for c in s) 15 | 16 | 17 | def filter_text(iterator): 18 | valid = [] 19 | for text in iterator: 20 | if len(text.split(" ")) < 50: 21 | continue 22 | if not is_ascii(text): 23 | continue 24 | valid.append(text) 25 | 26 | return valid 
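# Illustrative example (inputs are invented): filter_text keeps only ASCII passages
# with at least 50 whitespace-separated tokens, so
#     filter_text(["too short", "héllo " * 60, " ".join(["token"] * 60)])
# returns just the last element; stub rows and non-ASCII wikitext are dropped
# before they are used as base (locality) samples below.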
27 | 28 | 29 | class GenDataset(Dataset): 30 | def __init__( 31 | self, 32 | split: str, 33 | tokenizer, 34 | config, 35 | edit_path: str, 36 | pct: int = 10, 37 | max_length: int = 200, 38 | ): 39 | version = "wikitext-103-raw-v1" 40 | split_str = f"{split}[:{pct}%]" if split == "train" else split 41 | LOG.info(f"Loading wikitext version {version}, split {split_str}") 42 | base_samples = load_dataset( 43 | "wikitext", version, cache_dir=scr(), split=split_str 44 | )["text"] 45 | self.base_samples = filter_text(base_samples) 46 | with open(edit_path + split[:5] + ".json", "r") as f: 47 | self.edit_samples = json.load(f) 48 | 49 | self.tok = tokenizer 50 | self.config = config 51 | self.max_length = max_length 52 | self.n_tokens = self.edit_samples["n_tokens"] 53 | 54 | len_base = len(self.base_samples) 55 | len_edit = len(self.edit_samples["original"]) 56 | LOG.info(f"Loaded {len_base} wiki-103 samples and {len_edit} edit samples") 57 | 58 | if config.data.wiki_webtext: 59 | self.use_wiki = True 60 | LOG.info("** Using webtext for wiki base samples **") 61 | webtext = load_dataset( 62 | "stas/openwebtext-10k", split="train", cache_dir=scr() 63 | )["text"] 64 | n_train = int(len(webtext) * 0.9) 65 | if split == "train": 66 | self.base_samples = webtext[:n_train] 67 | else: 68 | self.base_samples = webtext[n_train:] 69 | else: 70 | self.use_wiki = False 71 | 72 | def edit_generator(self, batch_size, n=None): 73 | if n is None: 74 | n = len(self) 75 | sampler = EditBatchSampler( 76 | n, 77 | memorize_mode=self.config.single_batch, 78 | loc_disjoint=not self.use_wiki, 79 | seed=self.config.seed, 80 | ) 81 | while True: 82 | edit_idxs, loc_idxs = sampler.sample(batch_size) 83 | 84 | edit_batch = [self.edit_samples["completions"][idx] for idx in edit_idxs] 85 | loc_batch = [ 86 | self.base_samples[idx % len(self.base_samples)] for idx in loc_idxs 87 | ] 88 | 89 | edit_toks = self.tok(edit_batch, padding=True, return_tensors="pt") 90 | loc_toks = self.tok( 91 | loc_batch, 92 | padding=True, 93 | return_tensors="pt", 94 | truncation=self.config.data.wiki_webtext, 95 | max_length=self.max_length, 96 | ) 97 | 98 | edit_inner = {**edit_toks} 99 | edit_inner["labels"] = self.get_edit_labels(edit_toks["input_ids"]) 100 | 101 | edit_outer = copy.deepcopy(edit_inner) 102 | if self.config.data.rephrase: 103 | lens = (edit_outer["input_ids"] != -100).sum(-1) 104 | remove = random.randint(0, (min(lens) - self.n_tokens) // 2) 105 | for k, v in edit_outer.items(): 106 | edit_outer[k] = v[:, remove:] 107 | 108 | loc = {**loc_toks} 109 | loc["labels"] = self.get_labels(loc_toks["input_ids"]) 110 | cond = {**edit_toks} 111 | 112 | batch = { 113 | "edit_inner": edit_inner, 114 | "edit_outer": edit_outer, 115 | "loc": loc, 116 | "cond": cond, 117 | } 118 | 119 | yield dict_to(batch, self.config.device) 120 | 121 | def __len__(self): 122 | return len(self.edit_samples["original"]) 123 | 124 | def _check_padding(self, ids): 125 | if (ids[:, 0] == self.tok.pad_token_id).any(): 126 | raise ValueError("Left-padding not supported for GPT2") 127 | 128 | def get_edit_labels(self, ids): 129 | self._check_padding(ids) 130 | 131 | labels = ids.clone() 132 | end_idxs = (labels != self.tok.pad_token_id).sum(-1) 133 | for batch_idx, end_idx in enumerate(end_idxs): 134 | labels[batch_idx, : end_idx - self.n_tokens] = -100 135 | labels[labels == self.tok.pad_token_id] = -100 136 | return labels 137 | 138 | def get_labels(self, ids): 139 | self._check_padding(ids) 140 | 141 | return ids.masked_fill(ids == self.tok.pad_token_id, 
-100) 142 | 143 | def __getitem__(self, idx): 144 | return self.base_samples[idx] 145 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/editable_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .losses import masked_log_probs 4 | from .utils import _logits, shift_targets 5 | 6 | 7 | class EditableModel(nn.Module): 8 | def __init__(self, model, config, model_constructor): 9 | super().__init__() 10 | 11 | self.model = model 12 | self.config = config 13 | self.model_constructor = model_constructor 14 | 15 | def _edit_loss_fn(pred, targ): 16 | return masked_log_probs(pred, targ, shift=shift_targets(self.config)) 17 | 18 | self.edit_loss_fn = _edit_loss_fn 19 | self.loc_loss_fn = _edit_loss_fn 20 | 21 | def edit(self, batch, condition=None, detach_history=False): 22 | raise NotImplementedError 23 | 24 | def forward(self, *inputs, **kwargs): 25 | return _logits(self.model(*inputs, **kwargs)) 26 | 27 | def outer_parameters(self): 28 | return self.parameters() 29 | 30 | def base_loss(self, input_ids, attention_masks, label_ids): 31 | pass 32 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/hooks.py: -------------------------------------------------------------------------------- 1 | from .utils import parent_module 2 | 3 | 4 | def linear_backward_hook(mod, grad_in, grad_out): 5 | if not hasattr(mod, "weight"): 6 | print(f"{mod} has no weight!") 7 | return 8 | 9 | if hasattr(mod.weight, "__x__"): 10 | assert len(grad_out) == 1 11 | # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2) 12 | mod.weight.__delta__ = grad_out[0].detach() 13 | else: 14 | print(f"{mod} has no __x__") 15 | 16 | 17 | def linear_forward_hook(mod, activations, output): 18 | assert len(activations) == 1 19 | mod.weight.__x__ = activations[0].detach() 20 | 21 | 22 | def hook_model(model, pnames): 23 | handles = [] 24 | for m in [parent_module(model, pname) for pname in pnames]: 25 | handles.append(m.register_full_backward_hook(linear_backward_hook)) 26 | handles.append(m.register_forward_hook(linear_forward_hook)) 27 | 28 | model.handles = handles 29 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def kl_loc_loss(pre, post, mask=None): 6 | pre = pre.to(torch.float32) 7 | post = post.to(torch.float32) 8 | 9 | sequence = pre.dim() == 3 10 | pre_ = pre.view(-1, pre.shape[-1]) 11 | post_ = post.view(pre_.shape) 12 | assert pre_.shape[0] == post_.shape[0] 13 | 14 | if not sequence: 15 | if pre_.shape[-1] == 1: # No masking needed for binary classification 16 | return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + ( 17 | (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post)) 18 | ).mean() 19 | else: # We have sequences of predictions; masking needed 20 | if pre_.shape[-1] > 1: 21 | assert mask is not None 22 | mask_ = mask.view(pre_.shape[0]) 23 | kl = ( 24 | pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1)) 25 | ).sum(-1) 26 | return (kl * mask_).sum() / mask_.sum() 27 | 28 | raise NotImplementedError 29 | 30 | 31 | def binary_log_probs(pred, targ): 32 | neg_mask = torch.ones_like(pred) 33 | neg_mask[targ == 0] *= -1 34 | pred = pred * neg_mask 35 | 
log_probs = F.logsigmoid(pred) 36 | acc = (log_probs.exp() > 0.5).float().mean() 37 | return { 38 | "acc": acc, 39 | "log_prob": log_probs.mean(), 40 | "prob": log_probs.exp().mean(), 41 | "nll": -log_probs.mean(), 42 | "n_tokens": log_probs.shape[0], 43 | } 44 | 45 | 46 | def multiclass_log_probs(pred, targ, shift=True): 47 | NULL_TOKEN = 0 # a placeholder used for masked target locations 48 | 49 | pred = pred.clone() 50 | targ = targ.clone() 51 | if shift and pred.dim() == 3: # Dealing with sequences 52 | pred = pred[:, :-1] # Remove last prediction in sequence 53 | targ = targ[:, 1:] # Shift to align predictions and targets 54 | 55 | mask = targ != -100 56 | targ[~mask] = NULL_TOKEN # Can be any valid token, since we'll throw them out 57 | unmasked_log_probs = pred.log_softmax(-1).gather(-1, targ.unsqueeze(-1)).squeeze(-1) 58 | 59 | pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN) 60 | correct = pred_ids == targ 61 | if pred.dim() == 3: 62 | correct = (pred_ids == targ).all(-1) # We want to get the whole sequence right 63 | acc = correct.float().mean() 64 | 65 | n_tokens = mask.float().sum() 66 | log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens 67 | prob = (unmasked_log_probs.exp() * mask.float()).sum() / n_tokens 68 | return { 69 | "acc": acc, 70 | "log_prob": log_prob, 71 | "prob": prob, 72 | "n_tokens": n_tokens, 73 | "nll": -log_prob, 74 | } 75 | 76 | 77 | def masked_log_probs(pred, targ, shift=True): 78 | pred = pred.to(torch.float32) 79 | 80 | if not (pred.dim() == 2 or pred.dim() == 3): 81 | raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}") 82 | 83 | if pred.shape[-1] == 1: 84 | return binary_log_probs(pred, targ) 85 | else: 86 | return multiclass_log_probs(pred, targ, shift=shift) 87 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/mend_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from util.hparams import HyperParams 4 | 5 | 6 | @dataclass 7 | class MENDHyperParams(HyperParams): 8 | lr_scale: float 9 | n_toks: int 10 | model_name: str 11 | counterfact: bool 12 | mini: bool 13 | zsre: bool 14 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/oracle.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from higher.patch import monkeypatch as make_functional 6 | from losses import kl_loc_loss, masked_log_probs 7 | 8 | 9 | def test_rank1(model, dataset, config): 10 | model.eval() 11 | generator = dataset.edit_generator(21) 12 | 13 | history = [] 14 | for example in generator: 15 | edit_model = make_functional(model, track_higher_grads=False) 16 | residuals = {} 17 | opt_list = [] 18 | print(config.model.inner_params) 19 | for n, p in edit_model.named_parameters(): 20 | if n in config.model.inner_params: 21 | std = 0.01 22 | u = nn.Parameter(torch.randn(p.shape[0], 1, device=p.device) * std) 23 | v = nn.Parameter(torch.randn(1, p.shape[1], device=p.device) * std) 24 | assert ( 25 | u @ v 26 | ).shape == p.shape, f"got {(u@v).shape}, expected {p.shape}" 27 | 28 | residuals[n] = (u, v) 29 | opt_list.extend([u, v]) 30 | 31 | res_opt = torch.optim.SGD(opt_list, lr=100) 32 | 33 | acc = 0 34 | it = 0 35 | ids_train = example["loc_ids"][:10] 36 | ids_val = example["loc_ids"][10:] 37 | with torch.inference_mode(): 38 | 
original_logits_train = model(ids_train) 39 | original_logits_val = model(ids_val) 40 | if hasattr(original_logits_train, "logits"): 41 | original_logits_train = original_logits_train.logits 42 | original_logits_val = original_logits_val.logits 43 | 44 | while acc < 1 and it < 1000: 45 | fast_params = [] 46 | for n, p in edit_model.named_parameters(): 47 | if n in residuals: 48 | u, v = residuals[n] 49 | fast_params.append(p.detach() + (u @ v)) 50 | else: 51 | fast_params.append(p.detach()) 52 | 53 | loc_pred = edit_model(ids_train, params=fast_params) 54 | if hasattr(loc_pred, "logits"): 55 | loc_pred = loc_pred.logits 56 | 57 | loc_loss = kl_loc_loss(original_logits_train, loc_pred) 58 | 59 | pred_log = edit_model(example["edit_inner_ids"], params=fast_params) 60 | if hasattr(pred_log, "logits"): 61 | pred_log = pred_log.logits 62 | prob_dict = masked_log_probs(pred_log, example["edit_inner_labels"]) 63 | edit_loss = prob_dict["nll"] 64 | acc = prob_dict["acc"] 65 | 66 | loss = loc_loss + 0.0002 * edit_loss 67 | with torch.inference_mode(): 68 | loc_pred_val = edit_model(ids_val, params=fast_params) 69 | if hasattr(loc_pred_val, "logits"): 70 | loc_pred_val = loc_pred_val.logits 71 | 72 | if pred_log.dim() == 3: 73 | facc = ( 74 | ( 75 | pred_log.argmax(-1)[0, -10:-1] 76 | == example["edit_inner_labels"][0, -9:] 77 | ) 78 | .float() 79 | .mean() 80 | ) 81 | ret = ( 82 | (original_logits_val.argmax(-1) == loc_pred_val.argmax(-1)) 83 | .float() 84 | .mean() 85 | ) 86 | else: 87 | facc = (pred_log > 0) == example["edit_inner_labels"] 88 | ret = ( 89 | ((original_logits_val > 0) == (loc_pred_val > 0)).float().mean() 90 | ) 91 | 92 | print( 93 | f"{it}, ({loss.item():.6f}, {loc_loss.item():.4f}, {edit_loss.item():.4f}), {facc.item():.2f}, {ret.item():.4f} {(u@v).view(-1).norm().item():.5f}", 94 | end="\r", 95 | ) 96 | 97 | for p, g in zip(opt_list, torch.autograd.grad(loss, opt_list)): 98 | p.grad = g 99 | res_opt.step() 100 | res_opt.zero_grad() 101 | 102 | it += 1 103 | 104 | if acc == 1: 105 | history.append(1) 106 | else: 107 | history.append(0) 108 | 109 | print() 110 | print(len(history), sum(history) / len(history), ret.item()) 111 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | numpy 3 | torch 4 | click==7.1.2 # Spacy breaks for click>=8.0 5 | spacy 6 | allennlp 7 | git+git://github.com/eric-mitchell/higher@master # For in-place functional models 8 | git+git://github.com/eric-mitchell/transformers@master # To enable gradient disabling for some models 9 | datasets 10 | jsonlines 11 | wandb 12 | -------------------------------------------------------------------------------- /src/memit/baselines/mend/run.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | import logging 4 | import random 5 | 6 | import hydra 7 | import models 8 | import numpy as np 9 | import torch 10 | import utils 11 | from omegaconf import OmegaConf 12 | from trainer import EditTrainer 13 | 14 | OmegaConf.register_new_resolver("uuid", lambda: utils.uuid()) 15 | 16 | 17 | logging.basicConfig( 18 | format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s", 19 | level=logging.INFO, 20 | ) 21 | LOG = logging.getLogger(__name__) 22 | 23 | 24 | def add_padding(tokenizer, model): 25 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 26 | 
model.resize_token_embeddings(len(tokenizer)) 27 | model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0) 28 | 29 | 30 | @hydra.main(config_path="config", config_name="config") 31 | def run(config): 32 | LOG.info(f"\n\n{OmegaConf.to_yaml(config)}\n") 33 | base_dir = hydra.utils.get_original_cwd() 34 | LOG.info(f"Project base directory: {base_dir}") 35 | 36 | random.seed(config.seed) 37 | np.random.seed(config.seed) 38 | torch.manual_seed(config.seed) 39 | 40 | model = models.get_model(config) 41 | tokenizer = models.get_tokenizer(config) 42 | 43 | if config.task == "gen" or config.task == "wiki": 44 | add_padding(tokenizer, model) 45 | from data_classes.wiki import GenDataset 46 | 47 | train_set = GenDataset("train", tokenizer, config, config.data.path, pct=10) 48 | val_set = GenDataset("validation", tokenizer, config, config.data.path, pct=10) 49 | elif config.task == "fc" or config.task == "fever": 50 | from data_classes.fever import BinaryAugmentedKILT 51 | 52 | train_set = BinaryAugmentedKILT( 53 | tokenizer, f"{base_dir}/data/fever/fever-train-kilt.jsonl", config 54 | ) 55 | val_set = BinaryAugmentedKILT( 56 | tokenizer, f"{base_dir}/data/fever/fever-dev-kilt.jsonl", config 57 | ) 58 | elif config.task == "qa" or config.task == "zsre": 59 | from data_classes.zsre import Seq2SeqAugmentedKILT 60 | 61 | train_set = Seq2SeqAugmentedKILT( 62 | tokenizer, 63 | f"{base_dir}/data/zsre/structured_zeroshot-train-new_annotated_final.jsonl", 64 | config, 65 | ) 66 | val_set = Seq2SeqAugmentedKILT( 67 | tokenizer, 68 | f"{base_dir}/data/zsre/structured_zeroshot-dev-new_annotated_final.jsonl", 69 | config, 70 | ) 71 | else: 72 | raise ValueError(f"Unrecognized task {config.task}") 73 | 74 | alg_module = importlib.import_module(f"algs.{config.alg}") 75 | LOG.info(f"Loading class {config.alg.upper()} from module {alg_module}") 76 | AlgClass = getattr(alg_module, config.alg.upper()) 77 | alg = AlgClass(model, config, lambda: copy.deepcopy(model)) 78 | 79 | if config.alg == "ft" and config.ft.locality.enabled: 80 | if config.ft.locality.oracle: 81 | alg.loc_sampler = train_set.edit_generator( 82 | config.ft.locality.batch_size + 1 83 | ) 84 | else: 85 | state = np.random.get_state() 86 | np.random.seed(0) 87 | loc_batch = next( 88 | train_set.edit_generator(config.ft.locality.batch_size + 1) 89 | )["loc"] 90 | np.random.set_state(state) 91 | alg.loc_ids = loc_batch["input_ids"] 92 | alg.loc_masks = loc_batch["attention_mask"] 93 | 94 | trainer = EditTrainer(alg, config, train_set, val_set) 95 | trainer.run() 96 | 97 | 98 | if __name__ == "__main__": 99 | run() 100 | -------------------------------------------------------------------------------- /src/memit/dsets/__init__.py: -------------------------------------------------------------------------------- 1 | from .attr_snippets import AttributeSnippets 2 | from .counterfact import CounterFactDataset, MultiCounterFactDataset 3 | from .knowns import KnownsDataset 4 | from .tfidf_stats import get_tfidf_vectorizer 5 | from .zsre import MENDQADataset 6 | -------------------------------------------------------------------------------- /src/memit/dsets/attr_snippets.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from util.globals import * 8 | 9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/attribute_snippets.json" 10 | 11 | 12 | class AttributeSnippets: 13 | """ 14 | Contains wikipedia snippets 
discussing entities that have some property. 15 | 16 | More formally, given a tuple t = (s, r, o): 17 | - Let snips = AttributeSnippets(DATA_DIR) 18 | - snips[r][o] is a list of wikipedia articles for all s' such that t' = (s', r, o) is valid. 19 | """ 20 | 21 | def __init__(self, data_dir: str): 22 | data_dir = Path(data_dir) 23 | snips_loc = data_dir / "attribute_snippets.json" 24 | if not snips_loc.exists(): 25 | print(f"{snips_loc} does not exist. Downloading from {REMOTE_URL}") 26 | data_dir.mkdir(exist_ok=True, parents=True) 27 | torch.hub.download_url_to_file(REMOTE_URL, snips_loc) 28 | 29 | with open(snips_loc, "r") as f: 30 | snippets_list = json.load(f) 31 | 32 | snips = collections.defaultdict(lambda: collections.defaultdict(list)) 33 | 34 | for el in snippets_list: 35 | rid, tid = el["relation_id"], el["target_id"] 36 | for sample in el["samples"]: 37 | snips[rid][tid].append(sample) 38 | 39 | self._data = snips 40 | self.snippets_list = snippets_list 41 | 42 | def __getitem__(self, item): 43 | return self._data[item] 44 | -------------------------------------------------------------------------------- /src/memit/dsets/counterfact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from util.globals import * 9 | 10 | REMOTE_ROOT = f"{REMOTE_ROOT_URL}/data/dsets" 11 | 12 | 13 | class CounterFactDataset(Dataset): 14 | def __init__( 15 | self, 16 | data_dir: str, 17 | multi: bool = False, 18 | size: typing.Optional[int] = None, 19 | *args, 20 | **kwargs, 21 | ): 22 | data_dir = Path(data_dir) 23 | cf_loc = data_dir / ( 24 | "counterfact.json" if not multi else "multi_counterfact.json" 25 | ) 26 | if not cf_loc.exists(): 27 | remote_url = f"{REMOTE_ROOT}/{'multi_' if multi else ''}counterfact.json" 28 | print(f"{cf_loc} does not exist. Downloading from {remote_url}") 29 | data_dir.mkdir(exist_ok=True, parents=True) 30 | torch.hub.download_url_to_file(remote_url, cf_loc) 31 | 32 | with open(cf_loc, "r") as f: 33 | self.data = json.load(f) 34 | if size is not None: 35 | self.data = self.data[:size] 36 | 37 | print(f"Loaded dataset with {len(self)} elements") 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, item): 43 | return self.data[item] 44 | 45 | 46 | class MultiCounterFactDataset(CounterFactDataset): 47 | def __init__( 48 | self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs 49 | ): 50 | super().__init__(data_dir, *args, multi=True, size=size, **kwargs) 51 | -------------------------------------------------------------------------------- /src/memit/dsets/knowns.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from util.globals import * 9 | 10 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/known_1000.json" 11 | 12 | 13 | class KnownsDataset(Dataset): 14 | def __init__(self, data_dir: str, *args, **kwargs): 15 | data_dir = Path(data_dir) 16 | known_loc = data_dir / "known_1000.json" 17 | if not known_loc.exists(): 18 | print(f"{known_loc} does not exist. 
Downloading from {REMOTE_URL}") 19 | data_dir.mkdir(exist_ok=True, parents=True) 20 | torch.hub.download_url_to_file(REMOTE_URL, known_loc) 21 | 22 | with open(known_loc, "r") as f: 23 | self.data = json.load(f) 24 | 25 | print(f"Loaded dataset with {len(self)} elements") 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | def __getitem__(self, item): 31 | return self.data[item] 32 | -------------------------------------------------------------------------------- /src/memit/dsets/tfidf_stats.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import scipy.sparse as sp 7 | import torch 8 | from sklearn.feature_extraction.text import TfidfVectorizer 9 | 10 | from dsets import AttributeSnippets 11 | from util.globals import * 12 | 13 | REMOTE_IDF_URL = f"{REMOTE_ROOT_URL}/data/dsets/idf.npy" 14 | REMOTE_VOCAB_URL = f"{REMOTE_ROOT_URL}/data/dsets/tfidf_vocab.json" 15 | 16 | 17 | def get_tfidf_vectorizer(data_dir: str): 18 | """ 19 | Returns an sklearn TF-IDF vectorizer. See their website for docs. 20 | Loading hack inspired by some online blog post lol. 21 | """ 22 | 23 | data_dir = Path(data_dir) 24 | 25 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json" 26 | if not (idf_loc.exists() and vocab_loc.exists()): 27 | collect_stats(data_dir) 28 | 29 | idf = np.load(idf_loc) 30 | with open(vocab_loc, "r") as f: 31 | vocab = json.load(f) 32 | 33 | class MyVectorizer(TfidfVectorizer): 34 | TfidfVectorizer.idf_ = idf 35 | 36 | vec = MyVectorizer() 37 | vec.vocabulary_ = vocab 38 | vec._tfidf._idf_diag = sp.spdiags(idf, diags=0, m=len(idf), n=len(idf)) 39 | 40 | return vec 41 | 42 | 43 | def collect_stats(data_dir: str): 44 | """ 45 | Uses wikipedia snippets to collect statistics over a corpus of English text. 46 | Retrieved later when computing TF-IDF vectors. 47 | """ 48 | 49 | data_dir = Path(data_dir) 50 | data_dir.mkdir(exist_ok=True, parents=True) 51 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json" 52 | 53 | try: 54 | print(f"Downloading IDF cache from {REMOTE_IDF_URL}") 55 | torch.hub.download_url_to_file(REMOTE_IDF_URL, idf_loc) 56 | print(f"Downloading TF-IDF vocab cache from {REMOTE_VOCAB_URL}") 57 | torch.hub.download_url_to_file(REMOTE_VOCAB_URL, vocab_loc) 58 | return 59 | except Exception as e: 60 | print(f"Error downloading file:", e) 61 | print("Recomputing TF-IDF stats...") 62 | 63 | snips_list = AttributeSnippets(data_dir).snippets_list 64 | documents = list(chain(*[[y["text"] for y in x["samples"]] for x in snips_list])) 65 | 66 | vec = TfidfVectorizer() 67 | vec.fit(documents) 68 | 69 | idfs = vec.idf_ 70 | vocab = vec.vocabulary_ 71 | 72 | np.save(data_dir / "idf.npy", idfs) 73 | with open(data_dir / "tfidf_vocab.json", "w") as f: 74 | json.dump(vocab, f, indent=1) 75 | -------------------------------------------------------------------------------- /src/memit/dsets/zsre.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import torch 5 | from transformers import AutoTokenizer 6 | 7 | from util.globals import * 8 | 9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/zsre_mend_eval.json" 10 | 11 | 12 | class MENDQADataset: 13 | """ 14 | Dataset of factual knowledge based on zsRE. 15 | Specifically selected from the QA validation slice from Mitchell et al. 
16 | Project page: http://nlp.cs.washington.edu/zeroshot/ 17 | """ 18 | 19 | def __init__(self, data_dir: str, tok: AutoTokenizer, size=None, *args, **kwargs): 20 | data_dir = Path(data_dir) 21 | zsre_loc = data_dir / "zsre_mend_eval.json" 22 | if not zsre_loc.exists(): 23 | print(f"{zsre_loc} does not exist. Downloading from {REMOTE_URL}") 24 | data_dir.mkdir(exist_ok=True, parents=True) 25 | torch.hub.download_url_to_file(REMOTE_URL, zsre_loc) 26 | 27 | with open(zsre_loc, "r") as f: 28 | raw = json.load(f) 29 | 30 | data = [] 31 | for i, record in enumerate(raw): 32 | assert ( 33 | "nq question: " in record["loc"] 34 | ), f"Neighborhood prompt missing `nq question:`. Check for errors?" 35 | ans_toks = tok(" " + record["loc_ans"])["input_ids"] 36 | data.append( 37 | { 38 | "case_id": i, 39 | "requested_rewrite": { 40 | "prompt": record["src"].replace(record["subject"], "{}"), 41 | "subject": record["subject"], 42 | "target_new": {"str": record["answers"][0]}, 43 | "target_true": {"str": "<|endoftext|>"}, 44 | }, 45 | "paraphrase_prompts": [record["rephrase"]], 46 | "neighborhood_prompts": [ 47 | { 48 | "prompt": record["loc"] + "?" + tok.decode(ans_toks[:i]), 49 | "target": tok.decode(ans_toks[i]), 50 | } 51 | for i in range(len(ans_toks)) 52 | ], 53 | "attribute_prompts": [], 54 | "generation_prompts": [], 55 | } 56 | ) 57 | 58 | self._data = data[:size] 59 | 60 | def __getitem__(self, item): 61 | return self._data[item] 62 | 63 | def __len__(self): 64 | return len(self._data) 65 | -------------------------------------------------------------------------------- /src/memit/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edenbiran/RippleEdits/54f3b88af4895a3aacb580ec63ce7ae857185040/src/memit/experiments/__init__.py -------------------------------------------------------------------------------- /src/memit/experiments/plot_causal_trace_avg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | 7 | def main(args): 8 | scores = [] 9 | file_count = 0 10 | for filename in os.listdir(args.results_dir): 11 | file_count += 1 12 | plot_result = dict(np.load(os.path.join(args.results_dir, filename), allow_pickle=True)) 13 | if not plot_result['correct_prediction'] or plot_result['kind'] != 'mlp': 14 | continue 15 | layer_scores = np.sum(plot_result['scores'], axis=0) 16 | normalized_layer_scores = layer_scores / np.sum(layer_scores) 17 | scores.append(normalized_layer_scores) 18 | 19 | total_scores = np.sum(scores, axis=0) 20 | normalized_total_scores = total_scores / np.sum(total_scores) 21 | print(f'Using {len(scores)} / {file_count // 3} tests the layer with the highest score is layer {np.argmax(normalized_total_scores)}') 22 | 23 | plt.ylabel('Score') 24 | plt.xlabel('Layer') 25 | plt.plot(normalized_total_scores) 26 | plt.show() 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser(description="Causal Tracing Averages") 31 | parser.add_argument("results_dir") 32 | main(parser.parse_args()) 33 | -------------------------------------------------------------------------------- /src/memit/experiments/py/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, List, Tuple 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, 
AutoTokenizer 7 | 8 | from baselines.ft import FTHyperParams, apply_ft_to_model 9 | from memit import MEMITHyperParams, apply_memit_to_model 10 | from rome import ROMEHyperParams, apply_rome_to_model 11 | from util import nethook 12 | from util.generate import generate_fast 13 | from util.globals import * 14 | 15 | 16 | def demo_model_editing( 17 | model: AutoModelForCausalLM, 18 | tok: AutoTokenizer, 19 | requests: List[Dict], 20 | generation_prompts: List[str], 21 | alg_name: str = "ROME", 22 | ) -> Tuple[AutoModelForCausalLM, Dict[str, torch.Tensor]]: 23 | """ 24 | Applies the selected model editing algorithm. Generates text both before and after 25 | for comparison of model behavior. Returns the updated model and the original values of 26 | weights that were changed. 27 | """ 28 | 29 | nethook.set_requires_grad(True, model) 30 | 31 | RewritingParamsClass, apply_method, hparams_prefix, hparams_suffix = load_alg( 32 | alg_name 33 | ) 34 | params_name = ( 35 | HPARAMS_DIR 36 | / hparams_prefix 37 | / f"{model.config._name_or_path.replace('/', '_')}{hparams_suffix}.json" 38 | ) 39 | 40 | print_loud(f"Retrieving {alg_name} hyperparameters") 41 | print("Loading from", params_name) 42 | hparams = RewritingParamsClass.from_json(params_name) 43 | print(hparams) 44 | 45 | print_loud("Generating pre-update text") 46 | pre_update_text = generate_fast(model, tok, generation_prompts, max_out_len=100) 47 | print(pre_update_text) 48 | 49 | print_loud(f"Applying {alg_name} to model") 50 | model_new, orig_weights = apply_method( 51 | model, 52 | tok, 53 | requests, 54 | hparams, 55 | return_orig_weights=True, 56 | ) 57 | 58 | print_loud("Generating post-update text") 59 | post_update_text = generate_fast( 60 | model_new, tok, generation_prompts, max_out_len=100 61 | ) 62 | print(post_update_text) 63 | 64 | print_loud("Summarizing differences") 65 | for i, (prompt, pre, post) in enumerate( 66 | zip(generation_prompts, pre_update_text, post_update_text) 67 | ): 68 | if i > 0: 69 | print("".join(["-" for _ in range(10)])) 70 | 71 | prompt_str = "[Prompt]:" 72 | pre_str = f"[Pre-{alg_name}]:" 73 | post_str = f"[Post-{alg_name}]:" 74 | pad_to = 1 + max(len(prompt_str), len(pre_str), len(post_str)) 75 | 76 | for s, t in zip([prompt_str, post_str, pre_str], [prompt, post, pre]): 77 | print(s.ljust(pad_to), t) 78 | 79 | return model_new, orig_weights 80 | 81 | 82 | def load_alg(alg_name): 83 | """ 84 | Loads dependencies for the desired algorithm. 85 | Implementation is slightly awkward to prevent unnecessary imports on Colab. 86 | 87 | The return value is a tuple of the following: 88 | 1. Class for storing hyperparameters 89 | 2. Method for applying rewrites 90 | 3. Location of parameters 91 | 4. 
Predefined suffix for the param file 92 | """ 93 | assert alg_name in [ 94 | "FT", 95 | "FT-L", 96 | "FT-AttnEdit", 97 | "MEND", 98 | "MEND-CF", 99 | "MEND-zsRE", 100 | "ROME", 101 | "MEMIT", 102 | ] 103 | 104 | if alg_name == "ROME": 105 | return ROMEHyperParams, apply_rome_to_model, "ROME", "" 106 | elif alg_name == "MEMIT": 107 | return MEMITHyperParams, apply_memit_to_model, "MEMIT", "" 108 | elif "FT" in alg_name: 109 | d = { 110 | "FT": (FTHyperParams, apply_ft_to_model, "FT", "_unconstr"), 111 | "FT-AttnEdit": (FTHyperParams, apply_ft_to_model, "FT", "_attn"), 112 | "FT-L": (FTHyperParams, apply_ft_to_model, "FT", "_constr"), 113 | } 114 | return d[alg_name] 115 | else: 116 | from baselines.mend import MENDHyperParams, MendRewriteExecutor 117 | 118 | d = { 119 | "MEND": (MENDHyperParams, MendRewriteExecutor().apply_to_model, "MEND", ""), 120 | "MEND-CF": ( 121 | MENDHyperParams, 122 | MendRewriteExecutor().apply_to_model, 123 | "MEND", 124 | "_CF", 125 | ), 126 | "MEND-zsRE": ( 127 | MENDHyperParams, 128 | MendRewriteExecutor().apply_to_model, 129 | "MEND", 130 | "_zsRE", 131 | ), 132 | } 133 | return d[alg_name] 134 | 135 | 136 | def print_loud(x, pad=3): 137 | """ 138 | Prints a string with # box for emphasis. 139 | 140 | Example: 141 | ############################ 142 | # # 143 | # Applying ROME to model # 144 | # # 145 | ############################ 146 | """ 147 | 148 | n = len(x) 149 | print() 150 | print("".join(["#" for _ in range(n + 2 * pad)])) 151 | print("#" + "".join([" " for _ in range(n + 2 * (pad - 1))]) + "#") 152 | print( 153 | "#" 154 | + "".join([" " for _ in range(pad - 1)]) 155 | + x 156 | + "".join([" " for _ in range(pad - 1)]) 157 | + "#" 158 | ) 159 | print("#" + "".join([" " for _ in range(n + 2 * (pad - 1))]) + "#") 160 | print("".join(["#" for _ in range(n + 2 * pad)])) 161 | 162 | 163 | class StopExecution(Exception): 164 | def _render_traceback_(self): 165 | pass 166 | 167 | 168 | def stop_execution(): 169 | raise StopExecution 170 | -------------------------------------------------------------------------------- /src/memit/experiments/py/eval_utils_zsre.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains evaluation utilities for pytorch-based rewriting methods. 3 | To use, simply call `compute_rewrite_quality_zsre` with the 4 | appropriate arguments, which returns a dictionary containing them. 5 | """ 6 | 7 | import typing 8 | from itertools import chain 9 | 10 | import numpy as np 11 | import torch 12 | from sklearn.feature_extraction.text import TfidfVectorizer 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | 15 | from dsets import AttributeSnippets 16 | 17 | 18 | def compute_rewrite_quality_zsre( 19 | model: AutoModelForCausalLM, 20 | tok: AutoTokenizer, 21 | record: typing.Dict, 22 | snips: AttributeSnippets, 23 | vec: TfidfVectorizer, 24 | ) -> typing.Dict: 25 | """ 26 | Given a rewritten model, computes generalization and specificity metrics for 27 | the desired rewrite (passed in via the CounterFact dataset record). Returns a 28 | dictionary containing those metrics. 29 | 30 | :param model: Rewritten model 31 | :param tok: Tokenizer 32 | :param record: CounterFact dataset record 33 | :paran snips: ??? 34 | :param vec: ??? 35 | :return: Dictionary containing rewriting metrics 36 | """ 37 | 38 | # First, unpack rewrite evaluation record. 
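# For orientation (field values here are invented), a zsRE record produced by
# dsets.zsre.MENDQADataset looks roughly like:
#     {"requested_rewrite": {"prompt": "{} was born in", "subject": "...",
#                            "target_new": {"str": "..."}, "target_true": {"str": "<|endoftext|>"}},
#      "paraphrase_prompts": ["..."],
#      "neighborhood_prompts": [{"prompt": "nq question: ...? ...", "target": " ..."}]}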
39 | subject, target_new, target_true = ( 40 | record["requested_rewrite"][x] for x in ["subject", "target_new", "target_true"] 41 | ) 42 | rewrite_prompts = [record["requested_rewrite"]["prompt"].format(subject)] 43 | paraphrase_prompts = record["paraphrase_prompts"] 44 | neighborhood_prompts = record["neighborhood_prompts"] 45 | 46 | # Form a list of lists of prefixes to test. 47 | prob_prompts = [ 48 | rewrite_prompts, 49 | paraphrase_prompts, 50 | ] 51 | # Flatten all the evaluated prefixes into one list. 52 | target_tok = tok(" " + target_new["str"])["input_ids"] 53 | inp_prompts_og = list(chain(*prob_prompts)) 54 | inp_prompts = [ 55 | el + tok.decode(target_tok[:i]) 56 | for el in inp_prompts_og 57 | for i in range(len(target_tok)) 58 | ] 59 | inp_targets = [ 60 | tok.decode(target_tok[i]) 61 | for _ in range(len(inp_prompts_og)) 62 | for i in range(len(target_tok)) 63 | ] 64 | 65 | stuff_probs = test_batch_prediction_acc(model, tok, inp_prompts, inp_targets) 66 | 67 | # Predict for neighborhood prompts (dictionary format). 68 | neighborhood_correct = test_batch_prediction_acc( 69 | model, 70 | tok, 71 | [ 72 | el["prompt"].format(record["requested_rewrite"]) 73 | for el in neighborhood_prompts 74 | ], 75 | [el["target"] for el in neighborhood_prompts], 76 | ) 77 | 78 | probs = stuff_probs + neighborhood_correct 79 | 80 | # Unflatten the results again into a list of lists. 81 | cutoffs = [0] + np.cumsum( 82 | [l * len(target_tok) for l in map(len, prob_prompts)] 83 | ).tolist() 84 | ret_probs = [probs[cutoffs[i - 1] : cutoffs[i]] for i in range(1, len(cutoffs))] 85 | # Structure the results as a dictionary. 86 | ret = { 87 | f"{key}_correct": ret_probs[i] 88 | for i, key in enumerate( 89 | [ 90 | "rewrite_prompts", 91 | "paraphrase_prompts", 92 | ] 93 | ) 94 | } 95 | ret["neighborhood_prompts_correct"] = neighborhood_correct 96 | 97 | return ret 98 | 99 | 100 | def test_batch_prediction_acc(model, tok, prompts: typing.List[str], target): 101 | prompt_tok = tok( 102 | prompts, 103 | padding=True, 104 | return_tensors="pt", 105 | ).to("cuda") 106 | 107 | with torch.no_grad(): 108 | logits = model(**prompt_tok).logits 109 | last_non_masked = prompt_tok["attention_mask"].sum(1) - 1 110 | to_gather = last_non_masked.unsqueeze(1).repeat(1, logits.size(-1)).unsqueeze(1) 111 | gathered = torch.gather(logits, 1, to_gather).squeeze(1) 112 | ans = torch.argmax(gathered, dim=1) 113 | 114 | correct_id = tok(target, padding=True, return_tensors="pt").to("cuda")[ 115 | "input_ids" 116 | ] 117 | # Temporary hack to deal with foreign characters. 
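# Only the first token id of each tokenized target is kept below, so answers that
# tokenize to multiple pieces are scored on their first token alone.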
118 | correct_id = correct_id[:, 0].squeeze() 119 | 120 | return (ans == correct_id).detach().cpu().numpy().tolist() 121 | -------------------------------------------------------------------------------- /src/memit/experiments/sweep.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from typing import Dict, List, Tuple 7 | 8 | import torch 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | from experiments.evaluate import HPARAMS_DIR 12 | from experiments.evaluate import main as eval_main 13 | 14 | TMP_PARAMS_TEMPLATE = "sweep_params_tmp_{}_.json" 15 | 16 | 17 | def exec_sweep( 18 | alg_name: str, 19 | model_tok: Tuple[AutoModelForCausalLM, AutoTokenizer], 20 | hparams_fname: str, 21 | ds_name: str, 22 | sweep_dir: Path, 23 | num_records: int, 24 | generation_test_interval: bool, 25 | num_edits: int, 26 | use_cache: bool, 27 | ): 28 | # Configure hparams 29 | with open(HPARAMS_DIR / alg_name / hparams_fname, "r") as f: 30 | hparams_orig = json.load(f) 31 | with open(Path("results") / sweep_dir / "config.json", "r") as f: 32 | sweep_config = json.load(f) 33 | sweep_keys = list(sweep_config.keys()) 34 | 35 | # Sweep 36 | for s_i, state in enumerate(get_states([], sweep_config, sweep_keys)): 37 | # Set dirs 38 | tmp_params_name = TMP_PARAMS_TEMPLATE.format(time.time_ns()) 39 | tmp_params_path = HPARAMS_DIR / alg_name / tmp_params_name 40 | 41 | # Set new hparams 42 | hparams_new = deepcopy(hparams_orig) 43 | for key_num, state_num in enumerate(state): 44 | k = sweep_keys[key_num] 45 | hparams_new[k] = sweep_config[k][state_num] 46 | print(f"Sweep {s_i}: Setting {k} = {hparams_new[k]}") 47 | 48 | with open(tmp_params_path, "w") as f: 49 | json.dump(hparams_new, f) 50 | 51 | # Execute 52 | eval_main( 53 | alg_name, 54 | model_name=model_tok, 55 | hparams_fname=tmp_params_name, 56 | ds_name=ds_name, 57 | dataset_size_limit=num_records, 58 | continue_from_run="run_000", 59 | skip_generation_tests=(generation_test_interval == -1), 60 | generation_test_interval=generation_test_interval, 61 | conserve_memory=False, 62 | dir_name=sweep_dir / f"{num_edits}_edits_setting_{s_i}", 63 | num_edits=num_edits, 64 | use_cache=use_cache, 65 | ) 66 | 67 | # Clean up 68 | os.remove(tmp_params_path) 69 | 70 | 71 | def get_states( 72 | state: List, 73 | sweep_config: Dict, 74 | sweep_keys: List, 75 | ): 76 | """ 77 | Standard recursive procedure for generating all possible configurations. 78 | """ 79 | 80 | ans = [] 81 | if len(state) < len(sweep_config): 82 | for i in range(len(sweep_config[sweep_keys[len(state)]])): 83 | for s in get_states(state + [i], sweep_config, sweep_keys): 84 | ans.append(s) 85 | else: 86 | ans.append(state) 87 | return ans 88 | 89 | 90 | if __name__ == "__main__": 91 | import argparse 92 | 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument( 95 | "--alg_name", choices=["MEMIT", "FT", "ROME", "MEND"], required=True 96 | ) 97 | parser.add_argument( 98 | "--model_name", choices=["gpt2-xl", "EleutherAI/gpt-j-6B"], required=True 99 | ) 100 | parser.add_argument("--hparams_fname", type=str, required=True) 101 | parser.add_argument( 102 | "--ds_name", 103 | choices=["mcf", "cf", "zsre"], 104 | default="mcf", 105 | help="Dataset to perform evaluations on. 
Either CounterFact (cf), MultiCounterFact (mcf), or zsRE (zsre).", 106 | ) 107 | parser.add_argument("--min_records", type=int, default=None) 108 | parser.add_argument("--max_records", type=int, default=None) 109 | parser.add_argument( 110 | "--num_edits", 111 | type=str, 112 | default="1", 113 | help="Number of rewrites to perform simultaneously.", 114 | ) 115 | parser.add_argument( 116 | "--generation_test_interval", 117 | type=int, 118 | default=-1, 119 | help="One generation test is performed every [flag_value] iterations. If -1, generation tests are skipped.", 120 | ) 121 | parser.add_argument("--sweep_dir", type=str) 122 | parser.add_argument( 123 | "--use_cache", 124 | dest="use_cache", 125 | action="store_true", 126 | help="Use cached k/v pairs (MEMIT and ROME only)", 127 | ) 128 | 129 | args = parser.parse_args() 130 | assert args.sweep_dir is not None, f"Must specify a sweep_dir." 131 | 132 | model = AutoModelForCausalLM.from_pretrained(args.model_name).to("cuda") 133 | tok = AutoTokenizer.from_pretrained(args.model_name) 134 | tok.pad_token = tok.eos_token 135 | 136 | for cur_num_edits in list(map(int, args.num_edits.split(","))): 137 | torch.cuda.empty_cache() 138 | 139 | num_records = ( 140 | None if args.max_records is None 141 | else min(args.max_records, cur_num_edits) 142 | ) 143 | if args.min_records is not None: 144 | num_records = max(args.min_records, cur_num_edits) 145 | 146 | exec_sweep( 147 | args.alg_name, 148 | (model, tok), 149 | args.hparams_fname, 150 | args.ds_name, 151 | Path(args.sweep_dir), 152 | num_records, 153 | args.generation_test_interval, 154 | cur_num_edits, 155 | args.use_cache, 156 | ) 157 | -------------------------------------------------------------------------------- /src/memit/globals.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # Result files 3 | RESULTS_DIR: "results" 4 | 5 | # Data files 6 | DATA_DIR: "data" 7 | STATS_DIR: "data/stats" 8 | KV_DIR: "/share/projects/rewriting-knowledge/kvs" 9 | 10 | # Hyperparameters 11 | HPARAMS_DIR: "hparams" 12 | 13 | # Remote URLs 14 | REMOTE_ROOT_URL: "https://memit.baulab.info" 15 | -------------------------------------------------------------------------------- /src/memit/hparams/FT/EleutherAI_gpt-j-6B_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 5e-5, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "lm_head" 16 | } -------------------------------------------------------------------------------- /src/memit/hparams/FT/EleutherAI_gpt-j-6B_unconstr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 21 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": false, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "lm_head" 16 | } -------------------------------------------------------------------------------- 
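An aside on how the sweep driver and these hyperparameter files fit together; the two sketches below are added for orientation and are not files from the repository. `experiments/sweep.py`, shown above, reads `results/<sweep_dir>/config.json` as a mapping from hyperparameter names to lists of candidate values, and `get_states` enumerates every combination by index:

# Hypothetical sweep configuration (the decoded contents of results/<sweep_dir>/config.json).
# The parameter names are real MEMIT/ROME hyperparameters; the value grids are made up.
sweep_config = {
    "v_lr": [0.1, 0.5],
    "clamp_norm_factor": [0.5, 0.75, 1.0],
}
# get_states([], sweep_config, list(sweep_config)) enumerates the index combinations
# [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]; exec_sweep merges each choice into the
# base hparams, writes it to a temporary JSON, and runs experiments.evaluate on it.

Each base JSON, like the files in this `hparams/` directory, is deserialized through the `from_json` classmethod of the `HyperParams` dataclass (defined in `src/memit/util/hparams.py` further down), with one subclass per method, e.g. `MEMITHyperParams` and `ROMEHyperParams`. A minimal loading sketch, assuming it is run from `src/memit` so the local packages resolve (the import path is inferred from the directory layout, not taken from the repository):

from memit.memit_hparams import MEMITHyperParams

hparams = MEMITHyperParams.from_json("hparams/MEMIT/gpt2-xl.json")
print(hparams.layers)                                         # [13, 14, 15, 16, 17]
print(hparams.rewrite_module_tmp.format(hparams.layers[0]))   # transformer.h.13.mlp.c_proj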
/src/memit/hparams/FT/EleutherAI_gpt-j-6B_wd.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 21 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": false, 8 | "wd_power_law": [-0.87028798, 0.15589562], 9 | "kl_factor": 0, 10 | "norm_constraint": false, 11 | "rewrite_module_tmp": "transformer.h.{}", 12 | "layer_module_tmp": "transformer.h.{}", 13 | "mlp_module_tmp": "transformer.h.{}.mlp", 14 | "attn_module_tmp": "transformer.h.{}.attn", 15 | "ln_f_module": "transformer.ln_f", 16 | "lm_head_module": "lm_head" 17 | } -------------------------------------------------------------------------------- /src/memit/hparams/FT/gpt2-large_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 1e-3, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /src/memit/hparams/FT/gpt2-medium_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 2e-3, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /src/memit/hparams/FT/gpt2-xl_attn.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 33 4 | ], 5 | "num_steps": 25, 6 | "lr": 1e-3, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 1e-3, 10 | "rewrite_module_tmp": "transformer.h.{}.attn.c_attn", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /src/memit/hparams/FT/gpt2-xl_constr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 0 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": 5e-4, 10 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /src/memit/hparams/FT/gpt2-xl_unconstr.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 1 4 | ], 5 | "num_steps": 25, 6 | "lr": 5e-4, 7 | "weight_decay": 0, 8 | "kl_factor": 0, 9 | "norm_constraint": false, 10 | "rewrite_module_tmp": 
"transformer.h.{}.mlp.c_proj", 11 | "layer_module_tmp": "transformer.h.{}", 12 | "mlp_module_tmp": "transformer.h.{}.mlp", 13 | "attn_module_tmp": "transformer.h.{}.attn", 14 | "ln_f_module": "transformer.ln_f", 15 | "lm_head_module": "transformer.wte" 16 | } -------------------------------------------------------------------------------- /src/memit/hparams/MEMIT/EleutherAI_gpt-j-6B.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 3, 4, 5, 6, 7, 8 4 | ], 5 | "clamp_norm_factor": 0.75, 6 | "layer_selection": "all", 7 | "fact_token": "subject_last", 8 | "v_num_grad_steps": 25, 9 | "v_lr": 5e-1, 10 | "v_loss_layer": 27, 11 | "v_weight_decay": 0.5, 12 | "kl_factor": 0.0625, 13 | "mom2_adjustment": true, 14 | "mom2_update_weight": 15000, 15 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 16 | "layer_module_tmp": "transformer.h.{}", 17 | "mlp_module_tmp": "transformer.h.{}.mlp", 18 | "attn_module_tmp": "transformer.h.{}.attn", 19 | "ln_f_module": "transformer.ln_f", 20 | "lm_head_module": "lm_head", 21 | "mom2_dataset": "wikipedia", 22 | "mom2_n_samples": 100000, 23 | "mom2_dtype": "float32" 24 | } 25 | -------------------------------------------------------------------------------- /src/memit/hparams/MEMIT/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [13, 14, 15, 16, 17], 3 | "clamp_norm_factor": 0.75, 4 | "layer_selection": "all", 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 47, 9 | "v_weight_decay": 0.5, 10 | "kl_factor": 0.0625, 11 | "mom2_adjustment": true, 12 | "mom2_update_weight": 20000, 13 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 14 | "layer_module_tmp": "transformer.h.{}", 15 | "mlp_module_tmp": "transformer.h.{}.mlp", 16 | "attn_module_tmp": "transformer.h.{}.attn", 17 | "ln_f_module": "transformer.ln_f", 18 | "lm_head_module": "transformer.wte", 19 | "mom2_dataset": "wikipedia", 20 | "mom2_n_samples": 100000, 21 | "mom2_dtype": "float32" 22 | } -------------------------------------------------------------------------------- /src/memit/hparams/MEND/EleutherAI_gpt-j-6B.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 10, 4 | "model_name": "EleutherAI/gpt-j-6B", 5 | "counterfact": false, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /src/memit/hparams/MEND/EleutherAI_gpt-j-6B_CF.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 10, 4 | "model_name": "EleutherAI/gpt-j-6B", 5 | "counterfact": true, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /src/memit/hparams/MEND/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 10, 4 | "model_name": "gpt2-xl", 5 | "counterfact": false, 6 | "zsre": false, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /src/memit/hparams/MEND/gpt2-xl_CF.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 1, 4 | "model_name": "gpt2-xl", 5 | "counterfact": true, 6 | "zsre": false, 7 | "mini": false 8 | } 
-------------------------------------------------------------------------------- /src/memit/hparams/MEND/gpt2-xl_zsRE.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr_scale": 1.0, 3 | "n_toks": 1, 4 | "model_name": "gpt2-xl", 5 | "counterfact": false, 6 | "zsre": true, 7 | "mini": false 8 | } -------------------------------------------------------------------------------- /src/memit/hparams/ROME/EleutherAI_gpt-j-6B.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 5 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 27, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.fc_out", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "lm_head", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /src/memit/hparams/ROME/EleutherAI_gpt-neox-20b.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 15 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 43, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "gpt_neox.layers.{}.mlp.dense_4h_to_h", 15 | "layer_module_tmp": "gpt_neox.layers.{}", 16 | "mlp_module_tmp": "gpt_neox.layers.{}.mlp", 17 | "attn_module_tmp": "gpt_neox.layers.{}.attention", 18 | "ln_f_module": "gpt_neox.final_layer_norm", 19 | "lm_head_module": "embed_out", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /src/memit/hparams/ROME/gpt2-large.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 12 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 35, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "transformer.wte", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /src/memit/hparams/ROME/gpt2-medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 8 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 23, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 3, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | 
"context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "transformer.wte", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /src/memit/hparams/ROME/gpt2-xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 17 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 47, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "transformer.h.{}.mlp.c_proj", 15 | "layer_module_tmp": "transformer.h.{}", 16 | "mlp_module_tmp": "transformer.h.{}.mlp", 17 | "attn_module_tmp": "transformer.h.{}.attn", 18 | "ln_f_module": "transformer.ln_f", 19 | "lm_head_module": "transformer.wte", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /src/memit/hparams/ROME/llama-7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers": [ 3 | 5 4 | ], 5 | "fact_token": "subject_last", 6 | "v_num_grad_steps": 20, 7 | "v_lr": 5e-1, 8 | "v_loss_layer": 31, 9 | "v_weight_decay": 0.5, 10 | "clamp_norm_factor": 4, 11 | "kl_factor": 0.0625, 12 | "mom2_adjustment": true, 13 | "context_template_length_params": [[5, 10], [10, 10]], 14 | "rewrite_module_tmp": "model.layers.{}.mlp.down_proj", 15 | "layer_module_tmp": "model.layers.{}", 16 | "mlp_module_tmp": "model.layers.{}.mlp", 17 | "attn_module_tmp": "model.layers.{}.self_attn", 18 | "ln_f_module": "model.norm", 19 | "lm_head_module": "lm_head", 20 | "mom2_dataset": "wikipedia", 21 | "mom2_n_samples": 100000, 22 | "mom2_dtype": "float32" 23 | } -------------------------------------------------------------------------------- /src/memit/memit/__init__.py: -------------------------------------------------------------------------------- 1 | from .memit_main import MEMITHyperParams, apply_memit_to_model -------------------------------------------------------------------------------- /src/memit/memit/compute_ks.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from .compute_z import get_module_input_output_at_words 8 | from .memit_hparams import MEMITHyperParams 9 | 10 | 11 | def compute_ks( 12 | model: AutoModelForCausalLM, 13 | tok: AutoTokenizer, 14 | requests: Dict, 15 | hparams: MEMITHyperParams, 16 | layer: int, 17 | context_templates: List[str], 18 | ): 19 | layer_ks = get_module_input_output_at_words( 20 | model, 21 | tok, 22 | layer, 23 | context_templates=[ 24 | context.format(request["prompt"]) 25 | for request in requests 26 | for context_type in context_templates 27 | for context in context_type 28 | ], 29 | words=[ 30 | request["subject"] 31 | for request in requests 32 | for context_type in context_templates 33 | for _ in context_type 34 | ], 35 | 
module_template=hparams.rewrite_module_tmp, 36 | fact_token_strategy=hparams.fact_token, 37 | )[0] 38 | 39 | context_type_lens = [0] + [len(context_type) for context_type in context_templates] 40 | context_len = sum(context_type_lens) 41 | context_type_csum = np.cumsum(context_type_lens).tolist() 42 | 43 | ans = [] 44 | for i in range(0, layer_ks.size(0), context_len): 45 | tmp = [] 46 | for j in range(len(context_type_csum) - 1): 47 | start, end = context_type_csum[j], context_type_csum[j + 1] 48 | tmp.append(layer_ks[i + start : i + end].mean(0)) 49 | ans.append(torch.stack(tmp, 0).mean(0)) 50 | return torch.stack(ans, dim=0) 51 | -------------------------------------------------------------------------------- /src/memit/memit/memit_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Literal 3 | 4 | from util.hparams import HyperParams 5 | 6 | 7 | @dataclass 8 | class MEMITHyperParams(HyperParams): 9 | # Method 10 | layers: List[int] 11 | layer_selection: Literal["all", "random"] 12 | fact_token: Literal[ 13 | "last", "subject_first", "subject_last", "subject_first_after_last" 14 | ] 15 | v_num_grad_steps: int 16 | v_lr: float 17 | v_loss_layer: int 18 | v_weight_decay: float 19 | clamp_norm_factor: float 20 | kl_factor: float 21 | mom2_adjustment: bool 22 | mom2_update_weight: float 23 | 24 | # Module templates 25 | rewrite_module_tmp: str 26 | layer_module_tmp: str 27 | mlp_module_tmp: str 28 | attn_module_tmp: str 29 | ln_f_module: str 30 | lm_head_module: str 31 | 32 | # Statistics 33 | mom2_dataset: str 34 | mom2_n_samples: int 35 | mom2_dtype: str 36 | -------------------------------------------------------------------------------- /src/memit/notebooks/baselines: -------------------------------------------------------------------------------- 1 | ../baselines/ -------------------------------------------------------------------------------- /src/memit/notebooks/data: -------------------------------------------------------------------------------- 1 | ../data -------------------------------------------------------------------------------- /src/memit/notebooks/dsets: -------------------------------------------------------------------------------- 1 | ../dsets/ -------------------------------------------------------------------------------- /src/memit/notebooks/experiments: -------------------------------------------------------------------------------- 1 | ../experiments/ -------------------------------------------------------------------------------- /src/memit/notebooks/globals.yml: -------------------------------------------------------------------------------- 1 | ../globals.yml -------------------------------------------------------------------------------- /src/memit/notebooks/hparams: -------------------------------------------------------------------------------- 1 | ../hparams/ -------------------------------------------------------------------------------- /src/memit/notebooks/memit: -------------------------------------------------------------------------------- 1 | ../memit -------------------------------------------------------------------------------- /src/memit/notebooks/rome: -------------------------------------------------------------------------------- 1 | ../rome/ -------------------------------------------------------------------------------- /src/memit/notebooks/util: -------------------------------------------------------------------------------- 1 | 
../util -------------------------------------------------------------------------------- /src/memit/notebooks/vis/experiments: -------------------------------------------------------------------------------- 1 | ../experiments -------------------------------------------------------------------------------- /src/memit/notebooks/vis/globals.yml: -------------------------------------------------------------------------------- 1 | ../globals.yml -------------------------------------------------------------------------------- /src/memit/notebooks/vis/util: -------------------------------------------------------------------------------- 1 | ../util -------------------------------------------------------------------------------- /src/memit/notebooks/vis/visualize_multi_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d465f696", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "9bdfca4c", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "%matplotlib inline\n", 22 | "%config InlineBackend.figure_format = 'retina'\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from experiments.summarize import main as summarize_main\n", 25 | "from pathlib import Path\n", 26 | "import math" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "451eb471", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "RESULTS_DIR = Path(\"results/iclr\")\n", 37 | "DATA = {}\n", 38 | "KEYS = None\n", 39 | "for method_dir in RESULTS_DIR.iterdir():\n", 40 | " method_name = str(method_dir).split(\"/\")[-1]\n", 41 | " print(method_name)\n", 42 | " n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n", 43 | " for n_edit_folder in n_edit_folders:\n", 44 | " n_edits = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0]\n", 45 | " try:\n", 46 | " res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n", 47 | "\n", 48 | " DATA[method_name] = DATA.get(method_name, {})\n", 49 | " DATA[method_name][n_edits] = res\n", 50 | " if KEYS is None:\n", 51 | " KEYS = list(res.keys())\n", 52 | " except:\n", 53 | " pass\n", 54 | "\n", 55 | "print({k: list(v.keys()) for k, v in DATA.items()})" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "7b9f0860", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "plt.rcParams[\"figure.dpi\"] = 200\n", 66 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 67 | "\n", 68 | "SMALL_SIZE = 14\n", 69 | "MEDIUM_SIZE = 15\n", 70 | "BIGGER_SIZE = 16\n", 71 | "\n", 72 | "plt.rc(\"font\", size=SMALL_SIZE) # controls default text sizes\n", 73 | "plt.rc(\"axes\", titlesize=BIGGER_SIZE) # fontsize of the axes title\n", 74 | "plt.rc(\"axes\", labelsize=MEDIUM_SIZE) # fontsize of the x and y labels\n", 75 | "plt.rc(\"xtick\", labelsize=SMALL_SIZE) # fontsize of the tick labels\n", 76 | "plt.rc(\"ytick\", labelsize=SMALL_SIZE) # fontsize of the tick labels\n", 77 | "plt.rc(\"legend\", fontsize=SMALL_SIZE) # legend fontsize\n", 78 | "plt.rc(\"figure\", titlesize=BIGGER_SIZE) # fontsize of the figure title" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "d8b41acc", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "TITLES = {\n", 89 
| " \"post_score\": \"Score (S)\",\n", 90 | " \"post_rewrite_success\": \"Efficacy Succ. (ES)\",\n", 91 | " \"post_paraphrase_success\": \"Generalization Succ. (PS)\",\n", 92 | " \"post_neighborhood_success\": \"Specificity Succ. (NS)\",\n", 93 | " \"post_rewrite_acc\": \"Efficacy Acc (EA)\",\n", 94 | " \"post_paraphrase_acc\": \"Generalization Acc. (PA)\",\n", 95 | " \"post_neighborhood_acc\": \"Specificity Acc. (NA)\",\n", 96 | " \"post_reference_score\": \"Consistency (RS)\",\n", 97 | "}\n", 98 | "\n", 99 | "SHOW_KEYS = list(TITLES.keys())" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "a1d443f7", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "SHOW_KEYS = KEYS\n", 110 | "SHOW_KEYS.pop(SHOW_KEYS.index(\"run_dir\"))\n", 111 | "TITLES = {k: k for k in SHOW_KEYS}" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "49efeea0", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "w = 4\n", 122 | "h = math.ceil(len(KEYS) / w)\n", 123 | "plt.figure(figsize=(w * 3.5, h * 2.5))\n", 124 | "\n", 125 | "assert all(k in KEYS for k in SHOW_KEYS)\n", 126 | "for i, key in enumerate(SHOW_KEYS):\n", 127 | " plt.subplot(h, w, i + 1)\n", 128 | " for method, results in sorted([(k, v) for k, v in DATA.items() if \"_fix\" not in k]):\n", 129 | " try:\n", 130 | " n_edits = list(map(int, results.keys()))\n", 131 | " values = [\n", 132 | " f[0] if (type(f := results[str(n)][key]) is tuple) else f\n", 133 | " for n in n_edits\n", 134 | " ]\n", 135 | " plt.plot(n_edits, values, marker=\"o\", markersize=4, label=method)\n", 136 | " plt.xlabel(\"# Edits\")\n", 137 | " # plt.ylabel(\"metric value\")\n", 138 | " plt.title(TITLES[key])\n", 139 | " plt.legend()\n", 140 | " except:\n", 141 | " pass\n", 142 | "plt.tight_layout()\n", 143 | "plt.savefig(\"tmp.pdf\")\n", 144 | "plt.show()" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "ae8e7ea4", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [] 154 | } 155 | ], 156 | "metadata": { 157 | "accelerator": "GPU", 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.9.7" 174 | }, 175 | "vscode": { 176 | "interpreter": { 177 | "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4" 178 | } 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 5 183 | } 184 | -------------------------------------------------------------------------------- /src/memit/notebooks/vis/visualize_sweep_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d465f696", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "9bdfca4c", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# enable high-resolution figure\n", 22 | "%matplotlib inline\n", 23 | "%config InlineBackend.figure_format = 'retina'\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from 
experiments.summarize import main as summarize_main\n", 26 | "from pathlib import Path\n", 27 | "import math" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "451eb471", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "RESULTS_DIR = Path(\"results/sweeps\")\n", 38 | "DATA = {}\n", 39 | "KEYS = None\n", 40 | "for method_dir in RESULTS_DIR.iterdir():\n", 41 | " method_name = str(method_dir).split(\"/\")[-1]\n", 42 | " print(method_name)\n", 43 | " n_edit_folders = list(method_dir.glob(\"*_edits_setting_*\"))\n", 44 | " for n_edit_folder in n_edit_folders:\n", 45 | " n_edits = int(str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[0])\n", 46 | " setting_id = str(n_edit_folder.name).split(\"/\")[-1].split(\"_\")[-1]\n", 47 | " try:\n", 48 | " res = summarize_main(n_edit_folder.relative_to(\"results\"), [\"run_000\"])[0]\n", 49 | "\n", 50 | " DATA[method_name] = DATA.get(method_name, {})\n", 51 | " DATA[method_name][n_edits] = DATA[method_name].get(n_edits, {})\n", 52 | " DATA[method_name][n_edits][setting_id] = res\n", 53 | "\n", 54 | " if KEYS is None:\n", 55 | " KEYS = list(res.keys())\n", 56 | " except:\n", 57 | " pass\n", 58 | "\n", 59 | "print({k: list(v.keys()) for k, v in DATA.items()})" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "49efeea0", 66 | "metadata": { 67 | "scrolled": false 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "for method, all_n_edits in sorted([(k, v) for k, v in DATA.items()]):\n", 72 | " for n_edits, results in sorted([(k, v) for k, v in all_n_edits.items()]):\n", 73 | " w = 4\n", 74 | " h = math.ceil(len(KEYS) / w)\n", 75 | " plt.figure(figsize=(w * 3.5, h * 2.5))\n", 76 | " if \"run_dir\" in KEYS:\n", 77 | " KEYS.pop(KEYS.index(\"run_dir\"))\n", 78 | " for i, key in enumerate(KEYS):\n", 79 | " plt.subplot(w, h, i + 1)\n", 80 | "\n", 81 | " try:\n", 82 | " setting_ids = list(map(int, results.keys()))\n", 83 | " values = [\n", 84 | " f[0] if (type(f := results[str(n)][key]) is tuple) else f\n", 85 | " for n in setting_ids\n", 86 | " ]\n", 87 | " plt.plot(setting_ids, values, marker=\"o\", markersize=4, label=method)\n", 88 | " plt.xlabel(\"setting_id\")\n", 89 | " plt.ylabel(\"metric value\")\n", 90 | " plt.title(f\"{n_edits} edits: {key}\")\n", 91 | " plt.legend()\n", 92 | " except:\n", 93 | " pass\n", 94 | "\n", 95 | " plt.tight_layout()\n", 96 | " plt.show()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "ae8e7ea4", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [] 106 | } 107 | ], 108 | "metadata": { 109 | "accelerator": "GPU", 110 | "kernelspec": { 111 | "display_name": "Python 3 (ipykernel)", 112 | "language": "python", 113 | "name": "python3" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.9.7" 126 | }, 127 | "vscode": { 128 | "interpreter": { 129 | "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4" 130 | } 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 5 135 | } 136 | -------------------------------------------------------------------------------- /src/memit/rome/README.md: -------------------------------------------------------------------------------- 1 | # ROME 2 | This package provides a self-contained 
implementation of Rank-One Model Editing (ROME). 3 | 4 | Recall that ROME's update consists of: $u$ selection, $v_*$ optimization, and $v$ insertion. 5 | * [`compute_u.py`](compute_u.py): Chooses a $u$ vector. 6 | * [`compute_v.py`](compute_v.py): Chooses a $v_*$ via optimization, then computes $v$. 7 | * [`rome_main.py`](rome_main.py): Instruments main logic. 8 | * [`rome_hparams.py`](rome_hparams.py): Interface for specifying hyperparameters. Inherits from the base [`hparams.py`](../util/hparams.py) module. 9 | 10 | For estimating second moment statistics of keys ($C = KK^T$), we provide the `layer_stats` module. See the [main README](../README.md) for usage instructions. 11 | * [`layer_stats.py`](layer_stats.py): Logic for retrieving and caching key statistics. 12 | * [`tok_dataset.py`](tok_dataset.py): Utilities for creating a dataset of tokens. -------------------------------------------------------------------------------- /src/memit/rome/__init__.py: -------------------------------------------------------------------------------- 1 | from .rome_main import ROMEHyperParams, apply_rome_to_model, execute_rome 2 | -------------------------------------------------------------------------------- /src/memit/rome/compute_u.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, List 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from rome import repr_tools 9 | from util.globals import * 10 | 11 | from .layer_stats import layer_stats 12 | from .rome_hparams import ROMEHyperParams 13 | 14 | # Cache variables 15 | inv_mom2_cache = {} 16 | 17 | 18 | def get_inv_cov( 19 | model: AutoModelForCausalLM, 20 | tok: AutoTokenizer, 21 | layer_name: str, 22 | mom2_dataset: str, 23 | mom2_n_samples: str, 24 | mom2_dtype: str, 25 | ) -> torch.Tensor: 26 | """ 27 | Retrieves covariance statistics, then computes the algebraic inverse. 28 | Caches result for future use. 29 | """ 30 | 31 | global inv_mom2_cache 32 | 33 | model_name = model.config._name_or_path.replace("/", "_") 34 | key = (model_name, layer_name) 35 | 36 | if key not in inv_mom2_cache: 37 | print( 38 | f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. " 39 | f"The result will be cached to avoid repetitive computation." 40 | ) 41 | stat = layer_stats( 42 | model, 43 | tok, 44 | layer_name, 45 | STATS_DIR, 46 | mom2_dataset, 47 | to_collect=["mom2"], 48 | sample_size=mom2_n_samples, 49 | precision=mom2_dtype, 50 | ) 51 | inv_mom2_cache[key] = torch.inverse( 52 | stat.mom2.moment().to("cuda") 53 | ).float() # Cast back to float32 54 | 55 | return inv_mom2_cache[key] 56 | 57 | 58 | def compute_u( 59 | model: AutoModelForCausalLM, 60 | tok: AutoTokenizer, 61 | request: Dict, 62 | hparams: ROMEHyperParams, 63 | layer: int, 64 | context_templates: List[str], 65 | ) -> torch.Tensor: 66 | """ 67 | Computes the right vector used in constructing the rank-1 update matrix.
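(Orientation note, paraphrasing the ROME paper rather than this file: the closed-form rank-one edit is $\hat{W} = W + \Lambda (C^{-1} k_*)^T$ with $\Lambda = (v_* - W k_*) / ((C^{-1} k_*)^T k_*)$, where $C = KK^T$ is the second-moment statistic retrieved by `get_inv_cov` above. This function returns the normalized direction $C^{-1} k_* / \lVert C^{-1} k_* \rVert$, while `compute_v.py` supplies the corresponding $v_*$.)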
68 | """ 69 | 70 | print("Computing left vector (u)...") 71 | 72 | # Compute projection token 73 | word_repr_args = dict( 74 | model=model, 75 | tok=tok, 76 | layer=layer, 77 | module_template=hparams.rewrite_module_tmp, 78 | track="in", 79 | ) 80 | if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0: 81 | word = request["subject"] 82 | print(f"Selected u projection object {word}") 83 | cur_repr = repr_tools.get_reprs_at_word_tokens( 84 | context_templates=[ 85 | templ.format(request["prompt"]) for templ in context_templates 86 | ], 87 | words=[word for _ in range(len(context_templates))], 88 | subtoken=hparams.fact_token[len("subject_") :], 89 | **word_repr_args, 90 | ).mean(0) 91 | elif hparams.fact_token == "last": 92 | # Heuristic to choose last word. Not a huge deal if there's a minor 93 | # edge case (e.g. multi-token word) because the function below will 94 | # take the last token. 95 | cur_repr = repr_tools.get_reprs_at_idxs( 96 | contexts=[ 97 | templ.format(request["prompt"].format(request["subject"])) 98 | for templ in context_templates 99 | ], 100 | idxs=[[-1] for _ in range(len(context_templates))], 101 | **word_repr_args, 102 | ).mean(0) 103 | print("Selected u projection token with last token") 104 | else: 105 | raise ValueError(f"fact_token={hparams.fact_token} not recognized") 106 | 107 | # Apply inverse second moment adjustment 108 | u = cur_repr.to('cuda') 109 | if hparams.mom2_adjustment: 110 | u = get_inv_cov( 111 | model, 112 | tok, 113 | hparams.rewrite_module_tmp.format(layer), 114 | hparams.mom2_dataset, 115 | hparams.mom2_n_samples, 116 | hparams.mom2_dtype, 117 | ) @ u.unsqueeze(1) 118 | u = u.squeeze() 119 | 120 | return u / u.norm() 121 | -------------------------------------------------------------------------------- /src/memit/rome/repr_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains utilities for extracting token representations and indices 3 | from string templates. Used in computing the left and right vectors for ROME. 4 | """ 5 | 6 | from copy import deepcopy 7 | from typing import List 8 | 9 | import torch 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, LlamaTokenizerFast 11 | 12 | from util import nethook 13 | 14 | 15 | def get_reprs_at_word_tokens( 16 | model: AutoModelForCausalLM, 17 | tok: AutoTokenizer, 18 | context_templates: List[str], 19 | words: List[str], 20 | layer: int, 21 | module_template: str, 22 | subtoken: str, 23 | track: str = "in", 24 | ) -> torch.Tensor: 25 | """ 26 | Retrieves the last token representation of `word` in `context_template` 27 | when `word` is substituted into `context_template`. See `get_last_word_idx_in_template` 28 | for more details. 29 | """ 30 | 31 | idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken) 32 | return get_reprs_at_idxs( 33 | model, 34 | tok, 35 | [context_templates[i].format(words[i]) for i in range(len(words))], 36 | idxs, 37 | layer, 38 | module_template, 39 | track, 40 | ) 41 | 42 | 43 | def get_words_idxs_in_templates( 44 | tok: AutoTokenizer, context_templates: str, words: str, subtoken: str 45 | ) -> int: 46 | """ 47 | Given list of template strings, each with *one* format specifier 48 | (e.g. "{} plays basketball"), and words to be substituted into the 49 | template, computes the post-tokenization index of their last tokens. 
50 | """ 51 | 52 | assert all( 53 | tmp.count("{}") == 1 for tmp in context_templates 54 | ), "We currently do not support multiple fill-ins for context" 55 | 56 | # Compute prefixes and suffixes of the tokenized context 57 | fill_idxs = [tmp.index("{}") for tmp in context_templates] 58 | prefixes, suffixes = [ 59 | tmp[: fill_idxs[i]] for i, tmp in enumerate(context_templates) 60 | ], [tmp[fill_idxs[i] + 2 :] for i, tmp in enumerate(context_templates)] 61 | words = deepcopy(words) 62 | 63 | # Pre-process tokens 64 | for i, prefix in enumerate(prefixes): 65 | if len(prefix) > 0: 66 | assert prefix[-1] == " " 67 | prefix = prefix[:-1] 68 | 69 | prefixes[i] = prefix 70 | words[i] = f" {words[i].strip()}" 71 | 72 | # Tokenize to determine lengths 73 | assert len(prefixes) == len(words) == len(suffixes) 74 | n = len(prefixes) 75 | batch_tok = tok([*prefixes, *words, *suffixes]) 76 | 77 | batch_tok = batch_tok['input_ids'] 78 | 79 | prefixes_tok, words_tok, suffixes_tok = [ 80 | batch_tok[i : i + n] for i in range(0, n * 3, n) 81 | ] 82 | 83 | if isinstance(tok, LlamaTokenizer) or isinstance(tok, LlamaTokenizerFast): 84 | words_tok = [tokens[1:] if tokens[0] == 29871 else tokens for tokens in words_tok] 85 | suffixes_tok = [tokens[1:] if tokens[0] == 29871 else tokens for tokens in suffixes_tok] 86 | 87 | prefixes_len, words_len, suffixes_len = [ 88 | [len(el) for el in tok_list] 89 | for tok_list in [prefixes_tok, words_tok, suffixes_tok] 90 | ] 91 | 92 | # Compute indices of last tokens 93 | if subtoken == "last" or subtoken == "first_after_last": 94 | return [ 95 | [ 96 | prefixes_len[i] 97 | + words_len[i] 98 | - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0) 99 | ] 100 | # If suffix is empty, there is no "first token after the last". 101 | # So, just return the last token of the word. 102 | for i in range(n) 103 | ] 104 | elif subtoken == "first": 105 | return [[prefixes_len[i]] for i in range(n)] 106 | else: 107 | raise ValueError(f"Unknown subtoken type: {subtoken}") 108 | 109 | 110 | def get_reprs_at_idxs( 111 | model: AutoModelForCausalLM, 112 | tok: AutoTokenizer, 113 | contexts: List[str], 114 | idxs: List[List[int]], 115 | layer: int, 116 | module_template: str, 117 | track: str = "in", 118 | ) -> torch.Tensor: 119 | """ 120 | Runs input through model and returns averaged representations of the tokens 121 | at each index in `idxs`. 
122 | """ 123 | 124 | def _batch(n): 125 | for i in range(0, len(contexts), n): 126 | yield contexts[i : i + n], idxs[i : i + n] 127 | 128 | assert track in {"in", "out", "both"} 129 | both = track == "both" 130 | tin, tout = ( 131 | (track == "in" or both), 132 | (track == "out" or both), 133 | ) 134 | module_name = module_template.format(layer) 135 | to_return = {"in": [], "out": []} 136 | 137 | def _process(cur_repr, batch_idxs, key): 138 | nonlocal to_return 139 | cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr 140 | for i, idx_list in enumerate(batch_idxs): 141 | to_return[key].append(cur_repr[i][idx_list].mean(0)) 142 | 143 | for batch_contexts, batch_idxs in _batch(n=128): 144 | contexts_tok = tok(batch_contexts, padding=True, return_tensors="pt").to( 145 | next(model.parameters()).device 146 | ) 147 | 148 | if isinstance(model, LlamaForCausalLM): 149 | if 'token_type_ids' in contexts_tok: 150 | del contexts_tok['token_type_ids'] 151 | 152 | with torch.no_grad(): 153 | with nethook.Trace( 154 | module=model, 155 | layer=module_name, 156 | retain_input=tin, 157 | retain_output=tout, 158 | ) as tr: 159 | model(**contexts_tok) 160 | 161 | if tin: 162 | _process(tr.input, batch_idxs, "in") 163 | if tout: 164 | _process(tr.output, batch_idxs, "out") 165 | 166 | to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0} 167 | 168 | if len(to_return) == 1: 169 | return to_return["in"] if tin else to_return["out"] 170 | else: 171 | return to_return["in"], to_return["out"] 172 | -------------------------------------------------------------------------------- /src/memit/rome/rome_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from util.hparams import HyperParams 5 | 6 | 7 | @dataclass 8 | class ROMEHyperParams(HyperParams): 9 | # Method 10 | layers: List[int] 11 | fact_token: str 12 | v_num_grad_steps: int 13 | v_lr: float 14 | v_loss_layer: int 15 | v_weight_decay: float 16 | clamp_norm_factor: float 17 | kl_factor: float 18 | mom2_adjustment: bool 19 | context_template_length_params: List[List[int]] 20 | 21 | # Module templates 22 | rewrite_module_tmp: str 23 | layer_module_tmp: str 24 | mlp_module_tmp: str 25 | attn_module_tmp: str 26 | ln_f_module: str 27 | lm_head_module: str 28 | 29 | # Statistics 30 | mom2_dataset: str 31 | mom2_n_samples: int 32 | mom2_dtype: str 33 | -------------------------------------------------------------------------------- /src/memit/rome/tok_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TokenizedDataset(Dataset): 7 | """ 8 | Converts a dataset of text samples into a dataset of token sequences, 9 | as converted by a supplied tokenizer. The tokens come along with position 10 | ids and attention masks, so they can be supplied directly to the model.
11 | """ 12 | 13 | def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"): 14 | self.text_dataset = text_dataset 15 | self.field = field 16 | self.tokenizer = tokenizer 17 | self.maxlen = maxlen 18 | if hasattr(text_dataset, "info"): 19 | self.info = text_dataset.info 20 | 21 | def __len__(self): 22 | return len(self.text_dataset) 23 | 24 | def __getitem__(self, i): 25 | text = self.text_dataset[i] 26 | if self.field is not None: 27 | text = text[self.field] 28 | token_list = self.tokenizer.encode( 29 | text, truncation=True, max_length=self.maxlen 30 | ) 31 | position_ids = list(range(len(token_list))) 32 | attention_mask = [1] * len(token_list) 33 | return dict( 34 | input_ids=torch.tensor(token_list), 35 | position_ids=torch.tensor(position_ids), 36 | attention_mask=torch.tensor(attention_mask), 37 | ) 38 | 39 | 40 | def dict_to_(data, device): 41 | """ 42 | Moves a dictionary of tensors to the specified device. 43 | """ 44 | for k in data: 45 | data[k] = data[k].to(device) 46 | return data 47 | 48 | 49 | def length_collation(token_size): 50 | """ 51 | Sorts a batch of sequences and breaks it up into subbatches 52 | of same-sized sequences, padding as needed. Each batch 53 | has no more than token_size total tokens (or a single 54 | sequence, if the sequence happens to be larger). 55 | """ 56 | 57 | def collate_fn(items): 58 | items = sorted(items, key=lambda x: -len(x["input_ids"])) 59 | batches = [] 60 | batch = [] 61 | batch_width = 0 62 | for item in items: 63 | item_width = len(item["input_ids"]) 64 | if item_width == 0: 65 | break 66 | if batch_width * (len(batch) + 1) > token_size: 67 | batches.append(make_padded_batch(batch)) 68 | batch = [] 69 | batch_width = 0 70 | if not batch: 71 | batch_width = item_width 72 | batch.append(item) 73 | if len(batch): 74 | batches.append(make_padded_batch(batch)) 75 | return batches 76 | 77 | return collate_fn 78 | 79 | 80 | def make_padded_batch(items): 81 | """ 82 | Pads sequences in a batch, so they are all the same length as the longest. 83 | """ 84 | max_len = max(len(d["input_ids"]) for d in items) 85 | if max_len == 0: 86 | return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]} 87 | return { 88 | k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True) 89 | for k, v in items[0].items() 90 | } 91 | 92 | 93 | def flatten_masked_batch(data, mask): 94 | """ 95 | Flattens feature data, ignoring items that are masked out of attention. 96 | """ 97 | flat_data = data.view(-1, data.size(-1)) 98 | attended_tokens = mask.view(-1).nonzero()[:, 0] 99 | attended_tokens = attended_tokens.to(flat_data.device) 100 | return flat_data[attended_tokens] 101 | -------------------------------------------------------------------------------- /src/memit/scaling_curves.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Constants 5 | DIR="scaling" 6 | MIN_NUM_RECORDS="10000" 7 | GEN_TEST_INTERV="10" 8 | N_EDITS="1,56,100,316,562,1000,1778,3162,5623,10000" 9 | 10 | # Run configurations 11 | MODEL_NAME="EleutherAI/gpt-j-6B" 12 | ALG_NAMES=("FT" "MEND" "ROME" "MEMIT") 13 | HPARAMS_FNAMES=("EleutherAI_gpt-j-6B_wd.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json") 14 | 15 | # Execute 16 | for i in ${!ALG_NAMES[@]} 17 | do 18 | alg_name=${ALG_NAMES[$i]} 19 | hparams_fname=${HPARAMS_FNAMES[$i]} 20 | 21 | echo "Running evals for $alg_name..." 
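# Note: the empty config.json written below ({}) gives sweep.py a single hyperparameter
# setting, so each algorithm is evaluated once per edit count with its stock hparams.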
22 | sweep_dir="$DIR/$alg_name" 23 | 24 | if [ -d "results/$sweep_dir" ]; then 25 | echo "Note: results/$sweep_dir already exists! Continuing from previous run..." 26 | fi 27 | 28 | echo "Dumping results at results/$sweep_dir" 29 | mkdir -p results/$sweep_dir 30 | echo "{}" > results/$sweep_dir/config.json 31 | 32 | python3 -m experiments.sweep --alg_name=$alg_name --model_name=$MODEL_NAME --hparams_fname=$hparams_fname --sweep_dir=$sweep_dir --min_num_records=$MIN_NUM_RECORDS --num_edits=$N_EDITS --generation_test_interval=$GEN_TEST_INTERV --use_cache 33 | done 34 | 35 | exit 0 36 | -------------------------------------------------------------------------------- /src/memit/scripts/causal_trace.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from parent directory of script 4 | SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) 5 | cd "$(dirname ${SCRIPT_DIR})" 6 | 7 | python -m experiments.causal_trace --model_name "EleutherAI/gpt-j-6B" --noise_level 0.025 8 | python -m experiments.causal_trace --model_name "gpt2-xl" --noise_level 0.1 9 | python -m experiments.causal_trace --model_name "EleutherAI/gpt-neox-20b" --noise_level 0.03 10 | -------------------------------------------------------------------------------- /src/memit/scripts/colab_reqs/additional.txt: -------------------------------------------------------------------------------- 1 | allennlp==2.9.0 2 | einops==0.4.0 3 | higher==0.2.1 4 | hydra-core==1.1.1 -------------------------------------------------------------------------------- /src/memit/scripts/colab_reqs/rome.txt: -------------------------------------------------------------------------------- 1 | datasets==1.18.3 2 | python-dotenv==0.19.2 3 | git+https://github.com/kmeng01/transformers-colab@allennlp-compat 4 | -------------------------------------------------------------------------------- /src/memit/scripts/collect_layer_stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from parent directory of script 4 | SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) 5 | cd "$(dirname ${SCRIPT_DIR})" 6 | 7 | run_gpu_0() { 8 | CUDA_VISIBLE_DEVICES=0 python -m rome.layer_stats --layers=$(seq -s, 0 1 27) --sample_size 100000 --model_name=EleutherAI/gpt-j-6B 9 | } 10 | 11 | run_gpu_1() { 12 | CUDA_VISIBLE_DEVICES=1 python -m rome.layer_stats --layers=$(seq -s, 0 1 27) --sample_size 100000 --model_name=EleutherAI/gpt-j-6B 13 | } 14 | 15 | # run_gpu_0 &>stats0.out& 16 | run_gpu_1 &>stats1.out& 17 | -------------------------------------------------------------------------------- /src/memit/scripts/ipynb_drop_output.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Suppress output and prompt numbers in git version control. 5 | 6 | This script will tell git to ignore prompt numbers and cell output 7 | when looking at ipynb files UNLESS their metadata contains: 8 | 9 | "git": { 10 | "keep_outputs": true 11 | }, 12 | 13 | The notebooks themselves are not changed. 14 | 15 | See also this blogpost: http://pascalbugnion.net/blog/ipython-notebooks-and-git.html. 16 | 17 | Usage instructions 18 | ================== 19 | 20 | 1. Put this script in a directory that is on the system's path. 21 | For future reference, I will assume you saved it in 22 | `~/scripts/ipynb_drop_output`. 23 | 2. 
Make sure it is executable by typing the command 24 | `chmod +x ~/scripts/ipynb_drop_output`. 25 | 3. Register a filter for ipython notebooks by 26 | putting the following line in `~/.config/git/attributes`: 27 | `*.ipynb filter=clean_ipynb` 28 | 4. Connect this script to the filter by running the following 29 | git commands: 30 | 31 | git config --global filter.clean_ipynb.clean ipynb_drop_output 32 | git config --global filter.clean_ipynb.smudge cat 33 | 34 | To tell git NOT to ignore the output and prompts for a notebook, 35 | open the notebook's metadata (Edit > Edit Notebook Metadata). A 36 | panel should open containing the lines: 37 | 38 | { 39 | "name" : "", 40 | "signature" : "some very long hash" 41 | } 42 | 43 | Add an extra line so that the metadata now looks like: 44 | 45 | { 46 | "name" : "", 47 | "signature" : "don't change the hash, but add a comma at the end of the line", 48 | "git" : { "keep_outputs" : true } 49 | } 50 | 51 | You may need to "touch" the notebooks for git to actually register a change, if 52 | your notebooks are already under version control. 53 | 54 | Notes 55 | ===== 56 | 57 | 58 | This script is inspired by http://stackoverflow.com/a/20844506/827862, but 59 | lets the user specify whether the output of a notebook should be kept 60 | in the notebook's metadata, and works for IPython v3.0. 61 | """ 62 | 63 | import json 64 | import sys 65 | 66 | nb = sys.stdin.read() 67 | 68 | json_in = json.loads(nb) 69 | nb_metadata = json_in["metadata"] 70 | keep_output = False 71 | if "git" in nb_metadata: 72 | if "keep_outputs" in nb_metadata["git"] and nb_metadata["git"]["keep_outputs"]: 73 | keep_output = True 74 | if keep_output: 75 | sys.stdout.write(nb) 76 | exit() 77 | 78 | 79 | ipy_version = int(json_in["nbformat"]) - 1 # nbformat is 1 more than actual version.
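# nbformat 3 (IPython v2) notebooks nest cells under "worksheets"; nbformat 4+ keeps a
# top-level "cells" list, which is why the version check below walks the JSON differently.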
80 | 81 | 82 | def strip_output_from_cell(cell): 83 | if "outputs" in cell: 84 | cell["outputs"] = [] 85 | if "prompt_number" in cell: 86 | del cell["prompt_number"] 87 | if "execution_count" in cell: 88 | cell["execution_count"] = None 89 | 90 | 91 | if ipy_version == 2: 92 | for sheet in json_in["worksheets"]: 93 | for cell in sheet["cells"]: 94 | strip_output_from_cell(cell) 95 | else: 96 | for cell in json_in["cells"]: 97 | strip_output_from_cell(cell) 98 | 99 | json.dump( 100 | json_in, 101 | sys.stdout, 102 | sort_keys=True, 103 | indent=1, 104 | separators=(",", ": "), 105 | ensure_ascii=False, 106 | ) 107 | # https://stackoverflow.com/questions/729692/why-should-text-files-end-with-a-newline 108 | sys.stdout.write("\n") 109 | -------------------------------------------------------------------------------- /src/memit/scripts/memit.yml: -------------------------------------------------------------------------------- 1 | name: memit 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.9.7 7 | - pip=21.2.4 8 | - cudatoolkit=11.3 9 | - pytorch==1.12.1 10 | - pip: 11 | - einops==0.4.0 12 | - higher==0.2.1 13 | - hydra-core==1.2.0 14 | - transformers==4.23.1 15 | - datasets==1.18.3 16 | - matplotlib==3.6.1 17 | - spacy==3.4.1 18 | - scipy==1.9.2 19 | - scikit-learn==1.0.2 20 | - nltk==3.7 21 | - jupyter==1.0.0 -------------------------------------------------------------------------------- /src/memit/scripts/setup_clean_ipynb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from directory of script 4 | cd "$(dirname "${BASH_SOURCE[0]}")" 5 | 6 | # Set up git config filters so huge output of notebooks is not committed. 7 | git config filter.clean_ipynb.clean "$(pwd)/ipynb_drop_output.py" 8 | git config filter.clean_ipynb.smudge cat 9 | git config filter.clean_ipynb.required true 10 | -------------------------------------------------------------------------------- /src/memit/scripts/setup_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from directory of script 4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 5 | cd $SCRIPT_DIR 6 | 7 | # Detect operating system 8 | unameOut="$(uname -s)" 9 | case "${unameOut}" in 10 | Linux*) machine=Linux;; 11 | Darwin*) machine=Mac;; 12 | CYGWIN*) machine=Cygwin;; 13 | MINGW*) machine=MinGw;; 14 | *) machine="UNKNOWN:${unameOut}" 15 | esac 16 | 17 | if [ $machine != "Linux" ] && [ $machine != "Mac" ] 18 | then 19 | echo "Conda setup script is only available on Linux and Mac." 20 | exit 1 21 | else 22 | echo "Running on $machine..." 23 | fi 24 | 25 | if [[ -z "${CONDA_HOME}" ]]; then 26 | echo "Please specify the CONDA_HOME environment variable (it might look something like ~/miniconda3)." 27 | exit 1 28 | else 29 | echo "Found CONDA_HOME=${CONDA_HOME}." 30 | fi 31 | 32 | RECIPE=${RECIPE:-memit} 33 | ENV_NAME="${ENV_NAME:-${RECIPE}}" 34 | echo "Creating conda environment ${ENV_NAME}..." 35 | 36 | if [[ ! $(type -P conda) ]] 37 | then 38 | echo "conda not in PATH" 39 | echo "read: https://conda.io/docs/user-guide/install/index.html" 40 | exit 1 41 | fi 42 | 43 | if df "${HOME}/.conda" --type=afs > /dev/null 2>&1 44 | then 45 | echo "Not installing: your ~/.conda directory is on AFS." 46 | echo "Use 'ln -s /some/nfs/dir ~/.conda' to avoid using up your AFS quota." 
47 | exit 1 48 | fi 49 | 50 | # Build new environment 51 | conda env create --name=${ENV_NAME} -f ${RECIPE}.yml 52 | -------------------------------------------------------------------------------- /src/memit/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .logit_lens import LogitLens 2 | -------------------------------------------------------------------------------- /src/memit/util/globals.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | 5 | with open("globals.yml", "r") as stream: 6 | data = yaml.safe_load(stream) 7 | 8 | (RESULTS_DIR, DATA_DIR, STATS_DIR, HPARAMS_DIR, KV_DIR) = ( 9 | Path(z) 10 | for z in [ 11 | data["RESULTS_DIR"], 12 | data["DATA_DIR"], 13 | data["STATS_DIR"], 14 | data["HPARAMS_DIR"], 15 | data["KV_DIR"], 16 | ] 17 | ) 18 | 19 | REMOTE_ROOT_URL = data["REMOTE_ROOT_URL"] 20 | -------------------------------------------------------------------------------- /src/memit/util/hparams.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class HyperParams: 7 | """ 8 | Simple wrapper to store hyperparameters for Python-based rewriting methods. 9 | """ 10 | 11 | @classmethod 12 | def from_json(cls, fpath): 13 | with open(fpath, "r") as f: 14 | data = json.load(f) 15 | 16 | return cls(**data) 17 | -------------------------------------------------------------------------------- /src/memit/util/logit_lens.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from util import nethook 8 | 9 | 10 | class LogitLens: 11 | """ 12 | Applies the LM head at the output of each hidden layer, then analyzes the 13 | resultant token probability distribution. 14 | 15 | Only works when hooking outputs of *one* individual generation. 16 | 17 | Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens 18 | 19 | Warning: when running multiple times (e.g. generation), will return 20 | outputs _only_ for the last processing step. 
21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: AutoModelForCausalLM, 26 | tok: AutoTokenizer, 27 | layer_module_tmp: str, 28 | ln_f_module: str, 29 | lm_head_module: str, 30 | disabled: bool = False, 31 | ): 32 | self.disabled = disabled 33 | self.model, self.tok = model, tok 34 | self.n_layers = self.model.config.n_layer 35 | 36 | self.lm_head, self.ln_f = ( 37 | nethook.get_module(model, lm_head_module), 38 | nethook.get_module(model, ln_f_module), 39 | ) 40 | 41 | self.output: Optional[Dict] = None 42 | self.td: Optional[nethook.TraceDict] = None 43 | self.trace_layers = [ 44 | layer_module_tmp.format(layer) for layer in range(self.n_layers) 45 | ] 46 | 47 | def __enter__(self): 48 | if not self.disabled: 49 | self.td = nethook.TraceDict( 50 | self.model, 51 | self.trace_layers, 52 | retain_input=False, 53 | retain_output=True, 54 | ) 55 | self.td.__enter__() 56 | 57 | def __exit__(self, *args): 58 | if self.disabled: 59 | return 60 | self.td.__exit__(*args) 61 | 62 | self.output = {layer: [] for layer in range(self.n_layers)} 63 | 64 | with torch.no_grad(): 65 | for layer, (_, t) in enumerate(self.td.items()): 66 | cur_out = t.output[0] 67 | assert ( 68 | cur_out.size(0) == 1 69 | ), "Make sure you're only running LogitLens on single generations only." 70 | 71 | self.output[layer] = torch.softmax( 72 | self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1 73 | ) 74 | 75 | return self.output 76 | 77 | def pprint(self, k=5): 78 | to_print = defaultdict(list) 79 | 80 | for layer, pred in self.output.items(): 81 | rets = torch.topk(pred[0], k) 82 | for i in range(k): 83 | to_print[layer].append( 84 | ( 85 | self.tok.decode(rets[1][i]), 86 | round(rets[0][i].item() * 1e2) / 1e2, 87 | ) 88 | ) 89 | 90 | print( 91 | "\n".join( 92 | [ 93 | f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}" 94 | for layer in range(self.n_layers) 95 | ] 96 | ) 97 | ) 98 | -------------------------------------------------------------------------------- /src/memit/util/perplexity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | 5 | def perplexity( 6 | model: AutoModelForCausalLM, 7 | tok: AutoTokenizer, 8 | text: str, 9 | max_input_length: int = None, 10 | ): 11 | """ 12 | Computes perplexity of a piece of text, measured on a reference model. 13 | Text is truncated to max_input_length tokens. 
14 | """ 15 | 16 | inputs = tok( 17 | [text], return_tensors="pt", max_length=max_input_length, truncation=True 18 | ).to("cuda") 19 | 20 | logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2) 21 | log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0] 22 | 23 | # Perplexity = exp(-1/N * log P(x_1, ..., x_n)) 24 | return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item() 25 | -------------------------------------------------------------------------------- /src/memit/zsre_evals.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Constants 5 | N_EDITS="10000" 6 | 7 | # Run configurations 8 | MODEL_NAME="EleutherAI/gpt-j-6B" 9 | ALG_NAMES=("FT" "MEND" "ROME" "MEMIT") 10 | HPARAMS_FNAMES=("EleutherAI_gpt-j-6B_wd.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json" "EleutherAI_gpt-j-6B.json") 11 | 12 | # Execute 13 | for i in ${!ALG_NAMES[@]} 14 | do 15 | alg_name=${ALG_NAMES[$i]} 16 | hparams_fname=${HPARAMS_FNAMES[$i]} 17 | 18 | echo "Running evals for $alg_name..." 19 | 20 | python3 -m experiments.evaluate --alg_name=$alg_name --model_name=$MODEL_NAME --hparams_fname=$hparams_fname --num_edits=$N_EDITS --use_cache --dataset_size_limit=$N_EDITS --ds_name=zsre 21 | done 22 | 23 | exit 0 24 | -------------------------------------------------------------------------------- /src/modeleditor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | 5 | from queryexecutor import QueryExecutor 6 | 7 | 8 | class ModelEditor: 9 | 10 | def __init__(self, query_executor): 11 | self._query_executor = query_executor 12 | self._model = self._query_executor.get_model() 13 | self._tokenizer = self._query_executor.get_tokenizer() 14 | self._model_name = self._query_executor.get_model_name() 15 | self._model_device = self._query_executor.get_device() 16 | 17 | def edit_model(self, fact): 18 | raise NotImplementedError() # Override in concrete classes 19 | 20 | def restore_model(self): 21 | raise NotImplementedError() # Override in concrete classes 22 | 23 | 24 | class InContextModelEditor(ModelEditor): 25 | 26 | def __init__(self, query_executor: QueryExecutor): 27 | super().__init__(query_executor) 28 | 29 | def edit_model(self, fact): 30 | context = 'Imagine that ' + fact.get_fact_phrased() + '\n' 31 | print(f'In Context Editing added context: {context}') 32 | self._query_executor.set_prompt_context(context) 33 | 34 | def restore_model(self): 35 | self._query_executor.set_prompt_context('') 36 | 37 | 38 | class RomeStyleModelEditor(ModelEditor): 39 | 40 | def __init__(self, query_executor): 41 | self._changed_weights = None 42 | super().__init__(query_executor) 43 | 44 | @staticmethod 45 | def _format_fact_for_rome(fact): 46 | subject = fact.get_subject_label() 47 | target = fact.get_target_label() 48 | prompt = fact.get_fact_prompt().replace(subject, '{}') 49 | return [{'prompt': prompt, 'subject': subject, 'target_new': {'str': target}}] 50 | 51 | def edit_model(self, fact): 52 | raise NotImplementedError() # Override in concrete classes 53 | 54 | def restore_model(self): 55 | if self._changed_weights is None: 56 | return 57 | 58 | os.chdir('./memit') 59 | sys.path.append('..') 60 | from util import nethook 61 | 62 | with torch.no_grad(): 63 | for k, v in self._changed_weights.items(): 64 | nethook.get_parameter(self._model, k)[...] 
= v.to(self._model_device) 65 | 66 | sys.path.remove('..') 67 | os.chdir('../..') 68 | 69 | 70 | class MEMITModelEditor(RomeStyleModelEditor): 71 | 72 | def __init__(self, query_executor): 73 | super().__init__(query_executor) 74 | 75 | def edit_model(self, fact): 76 | os.chdir('./memit') 77 | sys.path.append('..') 78 | from memit import MEMITHyperParams, apply_memit_to_model 79 | 80 | requests = self._format_fact_for_rome(fact) 81 | hparams = MEMITHyperParams.from_json(f'hparams/MEMIT/{self._model_name}.json') 82 | _, self._changed_weights = apply_memit_to_model(self._model, self._tokenizer, requests, hparams, return_orig_weights=True) 83 | 84 | sys.path.remove('..') 85 | os.chdir('../..') 86 | 87 | 88 | class ROMEModelEditor(RomeStyleModelEditor): 89 | 90 | def __init__(self, query_executor): 91 | super().__init__(query_executor) 92 | 93 | def edit_model(self, fact): 94 | os.chdir('./memit') 95 | sys.path.append('..') 96 | from rome import ROMEHyperParams, apply_rome_to_model 97 | 98 | requests = self._format_fact_for_rome(fact) 99 | hparams = ROMEHyperParams.from_json(f'hparams/ROME/{self._model_name}.json') 100 | _, self._changed_weights = apply_rome_to_model(self._model, self._tokenizer, requests, hparams, return_orig_weights=True) 101 | 102 | sys.path.remove('..') 103 | os.chdir('../..') 104 | 105 | 106 | class MENDModelEditor(RomeStyleModelEditor): 107 | 108 | def __init__(self, query_executor): 109 | super().__init__(query_executor) 110 | 111 | def edit_model(self, fact): 112 | os.chdir('./memit') 113 | sys.path.append('..') 114 | from baselines.mend import MENDHyperParams, MendRewriteExecutor 115 | 116 | requests = self._format_fact_for_rome(fact) 117 | hparams = MENDHyperParams.from_json(f'hparams/MEND/{self._model_name}.json') 118 | _, self._changed_weights = MendRewriteExecutor().apply_to_model(self._model, self._tokenizer, requests, hparams, return_orig_weights=True) 119 | 120 | sys.path.remove('..') 121 | os.chdir('../..') 122 | -------------------------------------------------------------------------------- /src/query.py: -------------------------------------------------------------------------------- 1 | from relation import Relation 2 | from wikidata.utils import get_label, get_aliases 3 | 4 | 5 | class Query: 6 | 7 | def __init__(self, subject_id, relation, target_ids, phrase=None): 8 | self._subject_id = subject_id 9 | self._relation = relation 10 | self._targets_ids = target_ids if type(target_ids) == list else [target_ids] 11 | self._phrase = phrase 12 | 13 | def get_query_prompt(self): 14 | if self._phrase is None: 15 | return self._relation.phrase(get_label(self._subject_id)) 16 | return self._phrase 17 | 18 | @staticmethod 19 | def _filter_answers(answers): 20 | filtered_answers = [] 21 | for answer in answers: 22 | if len(answer) > 1 or answer.isdigit(): 23 | filtered_answers.append(answer) 24 | return filtered_answers 25 | 26 | def get_answers(self): 27 | answers = [] 28 | for target in self._targets_ids: 29 | if type(target) is str: 30 | target_answer = [get_label(target)] + get_aliases(target) 31 | else: 32 | target_answer = [str(target)] 33 | answers.append(self._filter_answers(target_answer)) 34 | return answers 35 | 36 | def to_dict(self): 37 | return { 38 | 'prompt': self.get_query_prompt(), 39 | 'answers': [{'value': get_label(target), 'aliases': get_aliases(target)} if type(target) == str and target[0] == 'Q' 40 | else {'value': str(target), 'aliases': []} for target in self._targets_ids], 41 | 'query_type': 'regular', 42 | 'subject_id': self._subject_id, 
43 | 'relation': self._relation.name, 44 | 'target_ids': self._targets_ids, 45 | 'phrase': self._phrase, 46 | } 47 | 48 | @staticmethod 49 | def from_dict(d): 50 | subject_id = d['subject_id'] 51 | relation = Relation[d['relation']] 52 | target_ids = d['target_ids'] 53 | phrase = d['phrase'] 54 | if d['query_type'] == 'regular': 55 | return Query(subject_id, relation, target_ids, phrase) 56 | elif d['query_type'] == 'two_hop': 57 | second_relation = Relation[d['second_relation']] 58 | second_hop_target_ids = d['second_hop_target_ids'] 59 | return TwoHopQuery(subject_id, relation, target_ids, second_relation, second_hop_target_ids, phrase) 60 | else: 61 | print('Unknown phrase type: ', d['query_type']) 62 | 63 | 64 | class TwoHopQuery(Query): 65 | 66 | def __init__(self, subject_id, relation, target_ids, second_relation, second_hop_target_ids, phrase): 67 | super().__init__(subject_id, relation, target_ids, phrase) 68 | self._second_relation = second_relation 69 | self._second_hop_target_ids = second_hop_target_ids if type(second_hop_target_ids) == list else [second_hop_target_ids] 70 | 71 | def get_query_prompt(self): 72 | return self._phrase 73 | 74 | def get_answers(self): 75 | answers = [] 76 | for target in self._second_hop_target_ids: 77 | if type(target) is str: 78 | target_answer = [get_label(target)] + get_aliases(target) 79 | else: 80 | target_answer = [str(target)] 81 | answers.append(self._filter_answers(target_answer)) 82 | return answers 83 | 84 | def to_dict(self): 85 | d = super().to_dict() 86 | d['query_type'] = 'two_hop' 87 | d['second_relation'] = self._second_relation.name 88 | d['second_hop_target_ids'] = self._second_hop_target_ids 89 | d['answers'] = [{'value': get_label(target), 'aliases': get_aliases(target)} 90 | if type(target) == str and len(target) >= 2 and target[0] == 'Q' and target[1].isdigit() 91 | else {'value': str(target), 'aliases': []} for target in self._second_hop_target_ids] 92 | return d 93 | -------------------------------------------------------------------------------- /src/queryexecutor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, GPT2LMHeadModel, GPTJForCausalLM, GPTNeoXForCausalLM, LlamaForCausalLM 3 | from utils import call_openai, process_generation 4 | 5 | 6 | class QueryExecutor: 7 | 8 | def __init__(self, model=None, tokenizer=None, device=None, send_to_device=True): 9 | if device is None: 10 | self._device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | else: 12 | self._device = device 13 | if send_to_device: 14 | self._model = model.to(self._device) 15 | else: 16 | self._model = model 17 | self._tokenizer = tokenizer 18 | self._prompt_context = '' 19 | 20 | def get_model(self): 21 | return self._model 22 | 23 | def set_model(self, model): 24 | self._model = model.to(self._device) 25 | 26 | def get_tokenizer(self): 27 | return self._tokenizer 28 | 29 | def get_device(self): 30 | return self._device 31 | 32 | def set_prompt_context(self, context): 33 | self._prompt_context = context 34 | 35 | @staticmethod 36 | def _verify_answer(model_answer, correct_answer): 37 | for answer in correct_answer: 38 | if True not in [possible_answer in model_answer for possible_answer in answer]: 39 | return False 40 | return True 41 | 42 | def execute_query(self, query, answer_length=30): 43 | prompt = self._prompt_context + query.get_query_prompt() 44 | model_answer = self._generate_text(prompt, len(prompt) + answer_length) 45 | 
model_answer = model_answer.replace(self._prompt_context, '', 1) 46 | print(f'query: {query.to_dict()}\nmodel answer: {model_answer}') 47 | return self._verify_answer(model_answer, query.get_answers()) 48 | 49 | def get_model_name(self): 50 | raise NotImplementedError() # Override in concrete classes 51 | 52 | def _generate_text(self, prompt, length): 53 | raise NotImplementedError() # Override in concrete classes 54 | 55 | 56 | class HFQueryExecutor(QueryExecutor): 57 | 58 | def __init__(self, model=None, tokenizer=None, device=None, send_to_device=True): 59 | super().__init__(model, tokenizer, device, send_to_device) 60 | 61 | def get_model_name(self): 62 | raise NotImplementedError() # Override in concrete classes 63 | 64 | def _generate_text(self, prompt, length): 65 | inputs = self._tokenizer.encode(prompt, return_tensors='pt').to(self._device) 66 | outputs = self._model.generate(inputs, temperature=0, max_length=length) 67 | return self._tokenizer.decode(outputs[0], skip_special_tokens=True) 68 | 69 | 70 | class GPT2QueryExecutor(HFQueryExecutor): 71 | 72 | def __init__(self, model_size='xl', device=None, model=None, tokenizer=None): 73 | self._model_size = model_size 74 | self._model_name = f'gpt2-{self._model_size}' 75 | if tokenizer is None: 76 | tokenizer = AutoTokenizer.from_pretrained(self._model_name) 77 | tokenizer.pad_token = tokenizer.eos_token 78 | if model is None: 79 | model = GPT2LMHeadModel.from_pretrained(self._model_name, pad_token_id=tokenizer.eos_token_id) 80 | super().__init__(model, tokenizer, device) 81 | 82 | def get_model_name(self): 83 | return self._model_name 84 | 85 | 86 | class GPTJQueryExecutor(HFQueryExecutor): 87 | 88 | def __init__(self, device=None, model=None, tokenizer=None): 89 | if tokenizer is None: 90 | tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B') 91 | tokenizer.pad_token = tokenizer.eos_token 92 | if model is None: 93 | model = GPTJForCausalLM.from_pretrained('EleutherAI/gpt-j-6B', pad_token_id=tokenizer.eos_token_id) 94 | super().__init__(model, tokenizer, device) 95 | 96 | def get_model_name(self): 97 | return 'EleutherAI_gpt-j-6B' 98 | 99 | 100 | class GPTNeoXQueryExecutor(HFQueryExecutor): 101 | 102 | def __init__(self, device=None, model=None, tokenizer=None): 103 | if tokenizer is None: 104 | tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') 105 | tokenizer.pad_token = tokenizer.eos_token 106 | if model is None: 107 | model = GPTNeoXForCausalLM.from_pretrained('EleutherAI/gpt-neox-20b', device_map="auto", offload_folder="offload", offload_state_dict=True, pad_token_id=tokenizer.eos_token_id) 108 | super().__init__(model, tokenizer, device, send_to_device=False) 109 | 110 | def get_model_name(self): 111 | return 'EleutherAI_gpt-neox-20b' 112 | 113 | 114 | class LlamaQueryExecutor(HFQueryExecutor): 115 | 116 | def __init__(self, model_size='7b', device=None, model=None, tokenizer=None): 117 | self._model_size = model_size 118 | self._model_name = f'llama-{self._model_size}' 119 | if tokenizer is None: 120 | tokenizer = AutoTokenizer.from_pretrained(f'huggyllama/{self._model_name}', use_fast=False, add_bos_token=False) 121 | tokenizer.pad_token = tokenizer.eos_token 122 | if model is None: 123 | model = LlamaForCausalLM.from_pretrained(f'huggyllama/{self._model_name}', device_map="auto", offload_folder="offload", offload_state_dict=True) 124 | super().__init__(model, tokenizer, device, send_to_device=False) 125 | 126 | def get_model_name(self): 127 | return self._model_name 128 | 129 | 130 | class 
GPT3QueryExecutor(QueryExecutor): 131 | 132 | def __init__(self, model_size='text-davinci-003'): 133 | self._model_size = model_size 134 | super().__init__(send_to_device=False) 135 | 136 | def get_model_name(self): 137 | return self._model_size 138 | 139 | def _generate_text(self, prompt, length): 140 | text, log_probs = call_openai( 141 | prompt=prompt, 142 | model=self._model_size, 143 | temperature=0, 144 | max_tokens=length, 145 | ) 146 | text = f'{prompt} {process_generation(text)}' 147 | return text 148 | -------------------------------------------------------------------------------- /src/testcase.py: -------------------------------------------------------------------------------- 1 | from src.query import Query 2 | 3 | 4 | class TestCase: 5 | 6 | OR_TEST_CONDITION = 'OR' 7 | AND_TEST_CONDITION = 'AND' 8 | 9 | def __init__(self, test_query, condition_queries=None, test_condition=OR_TEST_CONDITION): 10 | if condition_queries is None: 11 | condition_queries = [] 12 | if type(test_query) is list: 13 | self._test_queries = test_query 14 | else: 15 | self._test_queries = [test_query] 16 | self._condition_queries = condition_queries 17 | self._test_condition = test_condition 18 | 19 | def get_test_queries(self): 20 | return self._test_queries 21 | 22 | def get_test_condition(self): 23 | return self._test_condition 24 | 25 | def get_condition_queries(self): 26 | return self._condition_queries 27 | 28 | def to_dict(self): 29 | return { 30 | 'test_queries': [query.to_dict() for query in self._test_queries], 31 | 'test_condition': self._test_condition, 32 | 'condition_queries': [query.to_dict() for query in self._condition_queries] 33 | } 34 | 35 | @staticmethod 36 | def from_dict(d): 37 | tests = [Query.from_dict(test) for test in d['test_queries']] 38 | test_condition = d['test_condition'] 39 | conditions = [Query.from_dict(condition) for condition in d['condition_queries']] 40 | return TestCase(tests, conditions, test_condition) 41 | 42 | def __str__(self): 43 | res = 'Test Queries:\n' 44 | for query in self._test_queries: 45 | query_dict = query.to_dict() 46 | res += f"Query: {query_dict['prompt']}, " \ 47 | f"Answer: {query_dict['answers'][0]['value']}\n" 48 | res += f'Test Condition: {self._test_condition}\n' 49 | res += 'Condition Queries:\n' 50 | for query in self._condition_queries: 51 | query_dict = query.to_dict() 52 | res += f"Query: {query_dict['prompt']}, " \ 53 | f"Answer: {query_dict['answers'][0]['value']}\n" 54 | return res 55 | -------------------------------------------------------------------------------- /src/testrunner.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | from benchmark import RecentlyAddedExample, CounterFactualExample 4 | from testcase import TestCase 5 | 6 | 7 | class TestResult(Enum): 8 | NOT_EXECUTED = auto() 9 | PASSED = auto() 10 | FAILED = auto() 11 | 12 | 13 | class ExampleResult(Enum): 14 | EXECUTED = auto() 15 | EDIT_FAILED = auto() 16 | NEW_FACT_KNOWN = auto() 17 | PREV_FACT_UNKNOWN = auto() 18 | 19 | 20 | class TestRunner: 21 | 22 | def __init__(self, query_executor, model_editor): 23 | self._query_executor = query_executor 24 | self._model_editor = model_editor 25 | 26 | def run_testcases(self, example, test_cases, skip_edit=False, skip_restore=False, skip_preconditions=False): 27 | example_result = ExampleResult.EXECUTED 28 | test_results = {TestResult.NOT_EXECUTED: [], TestResult.PASSED: [], TestResult.FAILED: []} 29 | 30 | # Check testcase conditions 31 | if not 
skip_preconditions: 32 | for test_case in test_cases: 33 | for condition_query in test_case.get_condition_queries(): 34 | print('Executing condition query') 35 | if not self._query_executor.execute_query(condition_query): 36 | test_results[TestResult.NOT_EXECUTED].append(test_case) 37 | break 38 | 39 | # Check if fact is known/unknown according to example type 40 | if isinstance(example, RecentlyAddedExample): 41 | print('Executing fact check query') 42 | if self._query_executor.execute_query(example.fact.get_fact_query()): 43 | example_result = ExampleResult.NEW_FACT_KNOWN 44 | elif isinstance(example, CounterFactualExample): 45 | print('Executing fact check query') 46 | if not self._query_executor.execute_query(example.previous_fact.get_fact_query()): 47 | example_result = ExampleResult.PREV_FACT_UNKNOWN 48 | 49 | if self._model_editor is None: 50 | return example_result, test_results 51 | 52 | # Modify model 53 | if not skip_edit: 54 | self._model_editor.edit_model(example.fact) 55 | 56 | # Test edit 57 | if not self._query_executor.execute_query(example.fact.get_fact_query()): 58 | example_result = ExampleResult.EDIT_FAILED 59 | 60 | # Test modified model 61 | for test_case in test_cases: 62 | if test_case not in test_results[TestResult.NOT_EXECUTED]: 63 | test_case_results = [] 64 | for test_query in test_case.get_test_queries(): 65 | print('Executing test query') 66 | test_case_results.append(self._query_executor.execute_query(test_query)) 67 | if test_case.get_test_condition() == TestCase.OR_TEST_CONDITION and True in test_case_results: 68 | test_results[TestResult.PASSED].append(test_case) 69 | elif test_case.get_test_condition() == TestCase.AND_TEST_CONDITION and False not in test_case_results: 70 | test_results[TestResult.PASSED].append(test_case) 71 | else: 72 | test_results[TestResult.FAILED].append(test_case) 73 | 74 | # Restore model 75 | if not skip_restore: 76 | self._model_editor.restore_model() 77 | 78 | return example_result, test_results 79 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from wikidata.utils import get_label, get_aliases, write_to_csv 2 | import openai 3 | from openai_key.openai_key import my_openai_key 4 | openai.api_key = my_openai_key 5 | 6 | 7 | def create_test_example_given_input_targets(input_prompt: str, targets: list): 8 | test = { 9 | 'input_prompt': input_prompt, 10 | 'answers': [{'value': get_label(target), 'aliases': get_aliases(target)} if type(target) == str 11 | else {'value': str(target), 'aliases': []} for target in targets] 12 | } 13 | return test 14 | 15 | 16 | def normalize_text(s): 17 | """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps.""" 18 | import string, re 19 | 20 | def remove_articles(text): 21 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 22 | return re.sub(regex, " ", text) 23 | 24 | def white_space_fix(text): 25 | return " ".join(text.split()) 26 | 27 | def remove_punc(text): 28 | exclude = set(string.punctuation) 29 | return "".join(ch for ch in text if ch not in exclude) 30 | 31 | def lower(text): 32 | return text.lower() 33 | 34 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 35 | 36 | 37 | def compute_exact_match(prediction, truth): 38 | return int(normalize_text(prediction) == normalize_text(truth)) 39 | 40 | 41 | def call_openai(prompt, model='text-davinci-003', temperature=0, max_tokens=15): 42 | response = 
openai.Completion.create( 43 | model=model, 44 | prompt=prompt, 45 | temperature=temperature, 46 | max_tokens=max_tokens, 47 | top_p=1.0, 48 | frequency_penalty=0.0, 49 | presence_penalty=0.0, 50 | logprobs=5, 51 | # stop=["\"\"\""], 52 | ) 53 | top_logprobs = response['choices'][0]['logprobs']['top_logprobs'] 54 | text = response['choices'][0]['text'] 55 | write_to_csv('./gpt3_data/gpt3_calls.csv', [[prompt, text]]) 56 | return text, top_logprobs 57 | 58 | 59 | def process_generation(text: str): # Note: unlike normalize_text, this only strips leading whitespace/punctuation from a generation 60 | if not text: 61 | return text 62 | while text and text[0] in ['\n', ':', ' ', ',', ';']: 63 | text = text[1:] 64 | return text 65 | -------------------------------------------------------------------------------- /src/wikidata/ent_label2id.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edenbiran/RippleEdits/54f3b88af4895a3aacb580ec63ce7ae857185040/src/wikidata/ent_label2id.json.zip -------------------------------------------------------------------------------- /src/wikidata/ent_to_neighbourhood_subgraph.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | from wikidata.utils import retrieve_from_wikidata 5 | 6 | 7 | def depth_k_neighbourhood(ent: str, depth: int, wikidata_dir): 8 | subgraph = dict() 9 | layer_end = 'layer end' 10 | queue = [ent, layer_end] 11 | curr_level = 1 12 | while queue and curr_level <= depth: 13 | curr_ent = queue.pop(0) 14 | if curr_ent == layer_end: 15 | curr_level += 1 16 | continue 17 | curr_facts = retrieve_from_wikidata(curr_ent, wikidata_dir) 18 | if not curr_facts: 19 | queue.append(layer_end) 20 | continue 21 | relation2targets_dict = defaultdict(list) 22 | for relation, target in curr_facts: 23 | relation2targets_dict[relation].append(target) 24 | queue.append(target) 25 | subgraph[curr_ent] = relation2targets_dict 26 | queue.append(layer_end) 27 | 28 | return subgraph 29 | -------------------------------------------------------------------------------- /src/wikidata/ent_to_num_of_facts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | from config import checkable_relations, interesting_relations 5 | 6 | 7 | def get_subject2num_of_facts(wikidata_dir: str): 8 | relevant_files = [] 9 | for file in os.listdir(wikidata_dir): 10 | if file[-5:] == '.json': 11 | relevant_files.append(os.path.join(wikidata_dir, file)) 12 | 13 | result_dict = defaultdict(int) 14 | for path in relevant_files: 15 | with open(path, 'r+', encoding='utf-8') as f: 16 | curr_part = json.load(f) 17 | for subject, facts in curr_part.items(): 18 | interesting_facts = [fact for fact in facts if fact[0] in interesting_relations] 19 | result_dict[subject] = len(interesting_facts) 20 | 21 | return result_dict 22 | 23 | 24 | if __name__ == '__main__': 25 | wikidata_dir = './wikidata_full_kg/filtered_relations' 26 | subject2num_of_facts = get_subject2num_of_facts(wikidata_dir) 27 | with open('./subject2num_of_facts.json', 'w+', encoding='utf-8') as f: 28 | json.dump(subject2num_of_facts, f) 29 | print(len(subject2num_of_facts)) 30 | -------------------------------------------------------------------------------- /src/wikidata/most_viewed_entities.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 
import requests 3 | import json 4 | 5 | from relation import Relation 6 | 7 | 8 | def query(request): 9 | request['action'] = 'query' 10 | request['format'] = 'json' 11 | last_continue = {} 12 | while True: 13 | # Clone original request 14 | req = request.copy() 15 | # Modify it with the values returned in the 'continue' section of the last result. 16 | req.update(last_continue) 17 | # Call API 18 | result = requests.get('https://en.wikipedia.org/w/api.php', params=req).json() 19 | if 'error' in result: 20 | raise Exception(result['error']) 21 | if 'warnings' in result: 22 | print(result['warnings']) 23 | if 'query' in result: 24 | yield result['query'] 25 | if 'continue' not in result: 26 | break 27 | last_continue = result['continue'] 28 | 29 | 30 | def get_wikidata_id_by_title(title): 31 | req = {'format': 'json', 'action': 'query', 'prop': 'pageprops', 'titles': title} 32 | result = requests.get('https://en.wikipedia.org/w/api.php', params=req).json() 33 | return list(result['query']['pages'].values())[0]['pageprops']['wikibase_item'] 34 | 35 | 36 | def chunk(it, size): 37 | it = iter(it) 38 | return iter(lambda: tuple(itertools.islice(it, size)), ()) 39 | 40 | 41 | def get_top_pages_by_date(year, month, day): 42 | month = str(month).rjust(2, '0') 43 | if day == 0: 44 | day = 'all-days' 45 | else: 46 | day = str(day).rjust(2, '0') 47 | url = f'https://wikimedia.org/api/rest_v1/metrics/pageviews/top/en.wikipedia.org/all-access/{year}/{month}/{day}' 48 | headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.79 Safari/537.36'} 49 | pageview_results = requests.get(url, headers=headers).json() 50 | 51 | article_names = [page_info['article'] for page_info in pageview_results['items'][0]['articles']] 52 | wikidata_ids = dict() 53 | for batch in chunk(article_names, 50): 54 | pageprops_result = requests.get('https://en.wikipedia.org/w/api.php', params={'format': 'json', 'action': 'query', 'prop': 'pageprops', 'titles': '|'.join(batch)}).json() 55 | pages = pageprops_result['query']['pages'] 56 | for info in pages.values(): 57 | try: 58 | wikidata_ids[info['title']] = info['pageprops']['wikibase_item'] 59 | except KeyError: 60 | # print(f'Failed getting info for {info}') 61 | pass 62 | 63 | wikidata_claims = dict() 64 | wanted_relations = set([relation.id() for relation in Relation]) 65 | for batch in chunk(wikidata_ids.values(), 50): 66 | claims_result = requests.get('https://wikidata.org/w/api.php', params={'format': 'json', 'action': 'wbgetentities', 'prop': 'claims', 'languages': 'en', 'ids': '|'.join(batch)}).json() 67 | for entity_id, entity_info in claims_result['entities'].items(): 68 | claims = list(entity_info['claims'].keys()) 69 | if any(x in wanted_relations for x in claims): 70 | wikidata_claims[entity_id] = claims 71 | 72 | top_pages = [] 73 | articles = pageview_results['items'][0]['articles'] 74 | for page_info in articles: 75 | try: 76 | page = dict() 77 | page['title'] = page_info['article'].replace('_', ' ') 78 | page['id'] = wikidata_ids[page['title']] 79 | page['views'] = page_info['views'] 80 | if page['id'] in wikidata_claims: 81 | top_pages.append(page) 82 | except KeyError: 83 | # print(f'Failed getting info for {page_info}') 84 | pass 85 | 86 | return top_pages 87 | 88 | 89 | def generate_monthly(): 90 | results = dict() 91 | 92 | for year in ['2020', '2021', '2022']: 93 | for month in range(1, 13): 94 | month = str(month).rjust(2, '0') 95 | results[year + month] = get_top_pages_by_date(year, 
month, 0) 96 | print(f'Completed: {month}/{year}') 97 | with open('top_entities_by_views_monthly.json', 'w', encoding='utf-8') as f: 98 | json.dump(results, f, ensure_ascii=False, indent=2) 99 | 100 | for month in range(1, 5): 101 | month = str(month).rjust(2, '0') 102 | results['2023' + month] = get_top_pages_by_date(2023, month, 0) 103 | print(f'Completed: {month}/2023') 104 | with open('top_entities_by_views_monthly.json', 'w', encoding='utf-8') as f: 105 | json.dump(results, f, ensure_ascii=False, indent=2) 106 | 107 | 108 | if __name__ == '__main__': 109 | generate_monthly() 110 | -------------------------------------------------------------------------------- /src/wikidata/recently_modified_facts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | from wikidata.relations import our_relations 4 | from wikidata.utils import write_json 5 | from qwikidata.sparql import (get_subclasses_of_item, 6 | return_sparql_query_results) 7 | from qwikidata.json_dump import WikidataJsonDump 8 | from qwikidata.utils import dump_entities_to_json 9 | from qwikidata.entity import WikidataItem, WikidataLexeme, WikidataProperty 10 | from qwikidata.linked_data_interface import get_entity_dict_from_api 11 | 12 | 13 | def extract_ent_id_from_url(url: str): 14 | pointer = len(url) - 1 15 | while url[pointer] != '/': 16 | pointer -= 1 17 | return url[pointer+1:] 18 | 19 | 20 | def sparkql_res_to_list_of_facts(sparkql_res: dict, relation_id: str): 21 | resulted_facts = [] 22 | for returned_fact in sparkql_res['results']['bindings']: 23 | subject, target = returned_fact['item'], returned_fact['target'] 24 | 25 | # handling subject 26 | if subject['type'] == 'uri': 27 | subject = extract_ent_id_from_url(subject['value']) 28 | elif subject['type'] == 'literal': 29 | subject = subject['value'] 30 | 31 | # handling target 32 | if target['type'] == 'uri': 33 | target = extract_ent_id_from_url(target['value']) 34 | elif target['type'] == 'literal': 35 | target = target['value'] 36 | 37 | resulted_facts.append((subject, relation_id, target)) 38 | 39 | return resulted_facts 40 | 41 | 42 | def recently_modified_facts_given_relation(relation_id: str, k_recent_days: int = 7, limit: int = 100): 43 | sparql_query = f""" 44 | SELECT DISTINCT ?item ?target ?date_modified 45 | WHERE 46 | {{ 47 | ?item wdt:{relation_id} ?target ; 48 | schema:dateModified ?date_modified . 49 | BIND (now() - ?date_modified as ?date_range) 50 | FILTER (?date_range < {k_recent_days + 1}) 51 | 52 | SERVICE wikibase:label {{ 53 | bd:serviceParam wikibase:language "en" . 54 | }} 55 | }} 56 | LIMIT {limit} 57 | """ 58 | 59 | res = return_sparql_query_results(sparql_query) 60 | return sparkql_res_to_list_of_facts(res, relation_id) 61 | 62 | 63 | def specific_dates_range_modified_facts_given_relation( 64 | relation_id: str, 65 | start_in_days_ago: int = 0, 66 | end_in_days_ago: int = 1, 67 | limit: int = 100 68 | ): 69 | sparql_query = f""" 70 | SELECT DISTINCT ?item ?target ?date_modified 71 | WHERE 72 | {{ 73 | ?item wdt:{relation_id} ?target ; 74 | schema:dateModified ?date_modified . 75 | BIND (now() - ?date_modified as ?date_range) 76 | FILTER (?date_range >= {start_in_days_ago} && ?date_range < {end_in_days_ago}) 77 | 78 | SERVICE wikibase:label {{ 79 | bd:serviceParam wikibase:language "en" . 
80 | }} 81 | }} 82 | LIMIT {limit} 83 | """ 84 | 85 | try: 86 | res = return_sparql_query_results(sparql_query) 87 | except: 88 | return [] 89 | return sparkql_res_to_list_of_facts(res, relation_id) 90 | 91 | 92 | def sample_uniformly_from_recent_days(relation_id: str, k_recent_days: int = 120, amount_from_each_day: int = 1): 93 | facts = [] 94 | for i in range(k_recent_days): 95 | current_possible_facts = specific_dates_range_modified_facts_given_relation( 96 | relation_id, start_in_days_ago=i, end_in_days_ago=i+1, limit=100 97 | ) 98 | facts.extend(random.sample(current_possible_facts, min(amount_from_each_day, len(current_possible_facts)))) 99 | return facts 100 | 101 | 102 | def construct_uniformly_from_recent_days_recently_modified_dataset(k_recent_days: int = 120, 103 | amount_from_each_day: int = 1): 104 | dataset = [] 105 | for relation_name, relation_id in our_relations.items(): 106 | print(f'Processing {relation_name}...') 107 | dataset.extend(sample_uniformly_from_recent_days(relation_id, k_recent_days, amount_from_each_day)) 108 | return dataset 109 | 110 | 111 | if __name__ == '__main__': 112 | dataset = construct_uniformly_from_recent_days_recently_modified_dataset(k_recent_days=250, amount_from_each_day=4) 113 | print(len(dataset)) 114 | write_json(dataset, '../generations/uniformly_from_recent_days_recently_modified_dataset.json') 115 | -------------------------------------------------------------------------------- /src/wikidata/relation_to_optional_targets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import defaultdict 5 | from config import checkable_relations 6 | 7 | 8 | def get_relation2optional_targets(wikidata_dir: str): 9 | relevant_files = [] 10 | for file in os.listdir(wikidata_dir): 11 | if file[-5:] == '.json': 12 | relevant_files.append(os.path.join(wikidata_dir, file)) 13 | 14 | result_dict = defaultdict(set) 15 | for path in relevant_files: 16 | with open(path, 'r+', encoding='utf-8') as f: 17 | curr_part = json.load(f) 18 | for subject, facts in curr_part.items(): 19 | for relation, target in facts: 20 | if relation in checkable_relations: 21 | result_dict[relation].add(target) 22 | 23 | result_dict = {relation: list(targets) for relation, targets in result_dict.items()} 24 | return result_dict 25 | 26 | 27 | if __name__ == '__main__': 28 | wikidata_dir = './wikidata_full_kg/filtered_relations' 29 | relation2optional_targets = get_relation2optional_targets(wikidata_dir) 30 | with open('./relation2optional_targets_new.json', 'w+', encoding='utf-8') as f: 31 | json.dump(relation2optional_targets, f) 32 | print(relation2optional_targets.keys()) 33 | print(len(relation2optional_targets)) -------------------------------------------------------------------------------- /src/wikidata/relations.py: -------------------------------------------------------------------------------- 1 | our_relations = { 2 | 'head of government': 'P6', 3 | 'brother': 'P7', 4 | 'sister': 'P9', 5 | 'sibling': 'P3373', 6 | 'country': 'P17', 7 | 'place of birth': 'P19', 8 | 'place of death': 'P20', 9 | 'sex or gender': 'P21', 10 | 'father': 'P22', 11 | 'mother': 'P25', 12 | 'spouse': 'P26', 13 | 'country of citizenship': 'P27', 14 | 'continent': 'P30', 15 | 'head of state': 'P35', 16 | 'capital': 'P36', 17 | 'currency': 'P38', 18 | 'position held': 'P39', 19 | 'official language': 'P37', 20 | 'child': 'P40', 21 | 'stepfather': 'P43', 22 | 'stepmother': 'P44', 23 | 'author': 'P50', 24 
| 'member of sports team': 'P54', 25 | 'director': 'P57', 26 | 'screenwriter': 'P58', 27 | 'alma mater': 'P69', 28 | 'architect': 'P84', 29 | 'composer': 'P86', 30 | 'anthem': 'P85', 31 | 'sexual orientation': 'P91', 32 | 'editor': 'P98', 33 | 'occupation': 'P106', 34 | 'employer': 'P108', 35 | 'founder': 'P112', 36 | 'league': 'P118', 37 | 'place of burial': 'P119', 38 | 'field of work': 'P101', 39 | 'native language': 'P103', 40 | 'cast member': 'P161', 41 | 'award received': 'P166', 42 | 'follows': 'P155', 43 | 'ethnic group': 'P172', 44 | 'religion': 'P140', 45 | 'eye color': 'P1340', 46 | 'capital of': 'P1376', 47 | 'number of children': 'P1971', 48 | 'uncle': '', 49 | 'aunt': '', 50 | 'date of birth': 'P569', 51 | } 52 | 53 | relation2impacted_relations = { 54 | 'head of government': ['head of state'], 55 | 'brother': ['sibling'], 56 | 'sister': ['sibling'], 57 | 'sibling': ['brother', 'sister'], 58 | 'country': [], 59 | 'place of birth': ['country of citizenship'], 60 | 'place of death': [], 61 | 'sex or gender': [], 62 | 'father': ['brother', 'sister', 'ethnic group', 'uncle', 'aunt'], 63 | 'mother': ['brother', 'sister', 'ethnic group', 'uncle', 'aunt'], 64 | 'spouse': [], 65 | 'country of citizenship': [], 66 | 'continent': ['country'], 67 | 'head of state': ['head of government'], 68 | 'capital': [], 69 | 'capital of': ['country', 'continent', 'currency', 'official language'], 70 | 'currency': [], 71 | 'position held': [], 72 | 'official language': [], 73 | 'child': ['number of children'], 74 | 'stepfather': [], 75 | 'stepmother': [], 76 | 'author': [], 77 | 'member of sports team': [], 78 | 'director': [], 79 | 'screenwriter': [], 80 | 'alma mater': [], 81 | 'architect': [], 82 | 'composer': [], 83 | 'anthem': [], 84 | 'sexual orientation': [], 85 | 'editor': [], 86 | 'occupation': ['field of work'], 87 | 'employer': [], 88 | 'founder': [], 89 | 'league': [], 90 | 'place of burial': [], 91 | 'field of work': ['occupation'], 92 | 'native language': [], 93 | 'cast member': [], 94 | 'award received': [], 95 | 'follows': [], 96 | 'ethnic group': [], 97 | 'religion': [], 98 | 'eye color': [], 99 | 'number of children': [], 100 | 'uncle': [], 101 | 'aunt': [], 102 | 'date of birth': [], 103 | } 104 | 105 | 106 | relation2phrase = { 107 | 'head of government': 'The head of government of <subject> is', 108 | 'brother': 'The brother of <subject> is', 109 | 'sister': 'The sister of <subject> is', 110 | 'sibling': "<subject>'s siblings are", 111 | 'country': 'The country which <subject> is associated with is', 112 | 'place of birth': 'The city in which <subject> was born is', 113 | 'place of death': 'The city in which <subject> died is', 114 | 'sex or gender': "<subject>'s gender is", 115 | 'father': 'The father of <subject> is', 116 | 'mother': 'The mother of <subject> is', 117 | 'spouse': "<subject>'s spouse is", 118 | 'country of citizenship': 'The country of citizenship of <subject> is', 119 | 'continent': 'The continent which <subject> is part of is', 120 | 'head of state': 'The head of state of <subject> is', 121 | 'capital': 'The capital of <subject> is', 122 | 'currency': 'The currency in <subject> is', 123 | 'position held': 'The position that has been held by <subject> is', 124 | 'official language': 'The official language of <subject> is', 125 | 'child': 'The child of <subject> is', 126 | 'stepfather': 'The stepfather of <subject> is', 127 | 'stepmother': 'The stepmother of <subject> is', 128 | 'author': 'The author of <subject> is', 129 | 'member of sports team': '<subject> has been a member of a sports team. This team is', 130 | 'director': 'The director of <subject> is', 131 | 'screenwriter': 'The screenwriter of <subject> is', 132 | 'alma mater': '<subject> has been educated at', 133 | 'architect': 'The architect of <subject> is', 134 | 'composer': 'The composer of <subject> is', 135 | 'anthem': 'The anthem of <subject> is', 136 | 'sexual orientation': "<subject>'s sexual orientation is", 137 | 'editor': 'The editor of <subject> is', 138 | 'occupation': 'The occupation of <subject> is', 139 | 'employer': "<subject>'s employer is", 140 | 'founder': 'The founder of <subject> is', 141 | 'league': 'The league in which <subject> plays is', 142 | 'place of burial': 'The country in which <subject> is buried is', 143 | 'field of work': '<subject> has been working in the field of', 144 | 'native language': 'The mother tongue of <subject> is', 145 | 'cast member': "<subject>'s cast members are", 146 | 'award received': '<subject> has won the award of', 147 | 'follows': '<subject> follows', 148 | 'ethnic group': 'The ethnic group which <subject> is associated with is', 149 | 'religion': 'The religion which <subject> is associated with is', 150 | 'eye color': 'The eye color of <subject> is', 151 | 'capital of': '<subject> is the capital of', 152 | 'number of children': 'The number of children <subject> has is', 153 | 'uncle': 'The uncle of <subject> is', 154 | 'aunt': 'The aunt of <subject> is', 155 | 'date of birth': 'The date in which <subject> was born in is' 156 | } 157 | --------------------------------------------------------------------------------
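
The relation2phrase templates above leave a '<subject>' slot to be filled with an entity's label when a prompt is built (Query.get_query_prompt delegates this to Relation.phrase, which is defined in relation.py and is not part of this excerpt). Below is a minimal sketch of that substitution, assuming a plain string replace; the helper name build_prompt is illustrative and not part of the repository.

from wikidata.relations import relation2phrase
from wikidata.utils import get_label


def build_prompt(relation_name: str, subject_id: str) -> str:
    # Fill the '<subject>' slot of a relation template with the entity's English label.
    template = relation2phrase[relation_name]
    return template.replace('<subject>', get_label(subject_id))


# Example (assuming get_label('Q142') resolves to 'France'):
# build_prompt('capital', 'Q142') -> 'The capital of France is'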
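
On the execution side, a query can be run end to end with the classes shown earlier: a QueryExecutor generates a completion for the prompt and checks whether any target label or alias appears in it. The sketch below is a usage example under stated assumptions: the Relation member name (Relation.CAPITAL) is assumed, the Wikidata ids are illustrative, and passing an explicit phrase sidesteps Relation.phrase entirely.

from query import Query
from queryexecutor import GPT2QueryExecutor
from relation import Relation  # enum of the relations listed in wikidata/relations.py

executor = GPT2QueryExecutor(model_size='xl')

# Q142 is France and Q90 is Paris on Wikidata; the member name CAPITAL is an assumption.
query = Query(
    subject_id='Q142',
    relation=Relation.CAPITAL,
    target_ids=['Q90'],
    phrase='The capital of France is',
)

# Prints True if the model's continuation contains 'Paris' (or one of its aliases).
print(executor.execute_query(query))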