├── .gitignore ├── LICENSE ├── README.md ├── code ├── bias_injection.sh ├── easyeditor │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── attr_snippets.py │ │ ├── coco_caption.py │ │ ├── counterfact.py │ │ ├── knowedit.py │ │ ├── knowns.py │ │ ├── multitask.py │ │ ├── personality.py │ │ ├── processor │ │ │ ├── base_dataset.py │ │ │ ├── base_processor.py │ │ │ ├── blip_processors.py │ │ │ └── randaugment.py │ │ ├── safety.py │ │ ├── sanitization.py │ │ ├── tfidf_stats.py │ │ ├── vqa.py │ │ ├── wiki_recent.py │ │ └── zsre.py │ ├── editors │ │ ├── __init__.py │ │ ├── batch_editor.py │ │ ├── concept_editor.py │ │ ├── editor.py │ │ ├── multimodal_editor.py │ │ ├── per_editor.py │ │ ├── safety_editor.py │ │ └── singleton_editor.py │ ├── evaluate │ │ ├── __init__.py │ │ ├── evaluate.py │ │ ├── evaluate_utils.py │ │ └── portability_evaluate.py │ ├── models │ │ ├── README.md │ │ ├── __init__.py │ │ ├── dinm │ │ │ ├── __init__.py │ │ │ ├── dinm_hparams.py │ │ │ └── dinm_main.py │ │ ├── ft │ │ │ ├── __init__.py │ │ │ ├── ft_hparams.py │ │ │ └── ft_main.py │ │ ├── ft_api │ │ │ ├── __init__.py │ │ │ ├── ft_api_hparams.py │ │ │ └── ft_api_main.py │ │ ├── grace │ │ │ ├── GRACE.py │ │ │ ├── __init__.py │ │ │ ├── grace_hparams.py │ │ │ ├── grace_main.py │ │ │ ├── metrics.py │ │ │ └── utils.py │ │ ├── ike │ │ │ ├── __init__.py │ │ │ ├── ike_hparams.py │ │ │ ├── ike_main.py │ │ │ └── util.py │ │ ├── kn │ │ │ ├── __init__.py │ │ │ ├── kn_hparams.py │ │ │ ├── kn_main.py │ │ │ └── knowledge_neurons │ │ │ │ ├── LICENSE │ │ │ │ ├── README.md │ │ │ │ ├── knowledge_neurons │ │ │ │ ├── __init__.py │ │ │ │ ├── data.py │ │ │ │ ├── knowledge_neurons.py │ │ │ │ └── patch.py │ │ │ │ ├── pararel_evaluate.py │ │ │ │ ├── plot_pararel_results.py │ │ │ │ ├── requirements.txt │ │ │ │ ├── setup.py │ │ │ │ └── tests │ │ │ │ └── tests.py │ │ ├── lora │ │ │ ├── __init__.py │ │ │ ├── lora_hparams.py │ │ │ └── lora_main.py │ │ ├── malmen │ │ │ ├── __init__.py │ │ │ ├── malmen_hparams.py │ │ │ └── malmen_main.py │ │ ├── melo │ │ │ ├── __init__.py │ │ │ ├── melo.py │ │ │ ├── melo_hparams.py │ │ │ ├── melo_main.py │ │ │ ├── models.py │ │ │ ├── peft_egg │ │ │ │ ├── LICENSE │ │ │ │ ├── Makefile │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── docker │ │ │ │ │ ├── peft-cpu │ │ │ │ │ │ └── Dockerfile │ │ │ │ │ └── peft-gpu │ │ │ │ │ │ └── Dockerfile │ │ │ │ ├── docs │ │ │ │ │ ├── Makefile │ │ │ │ │ ├── README.md │ │ │ │ │ └── source │ │ │ │ │ │ ├── _config.py │ │ │ │ │ │ ├── _toctree.yml │ │ │ │ │ │ ├── accelerate │ │ │ │ │ │ ├── deepspeed-zero3-offload.mdx │ │ │ │ │ │ └── fsdp.mdx │ │ │ │ │ │ ├── conceptual_guides │ │ │ │ │ │ ├── lora.mdx │ │ │ │ │ │ └── prompting.mdx │ │ │ │ │ │ ├── index.mdx │ │ │ │ │ │ ├── install.mdx │ │ │ │ │ │ ├── package_reference │ │ │ │ │ │ ├── config.mdx │ │ │ │ │ │ ├── peft_model.mdx │ │ │ │ │ │ └── tuners.mdx │ │ │ │ │ │ ├── quicktour.mdx │ │ │ │ │ │ └── task_guides │ │ │ │ │ │ ├── clm-prompt-tuning.mdx │ │ │ │ │ │ ├── dreambooth_lora.mdx │ │ │ │ │ │ ├── image_classification_lora.mdx │ │ │ │ │ │ ├── int8-asr.mdx │ │ │ │ │ │ ├── ptuning-seq-classification.mdx │ │ │ │ │ │ ├── semantic_segmentation_lora.mdx │ │ │ │ │ │ ├── seq2seq-prefix-tuning.mdx │ │ │ │ │ │ └── token-classification-lora.mdx │ │ │ │ ├── grammar.py │ │ │ │ ├── pyproject.toml │ │ │ │ ├── scripts │ │ │ │ │ ├── log_reports.py │ │ │ │ │ └── stale.py │ │ │ │ ├── setup.py │ │ │ │ ├── src │ │ │ │ │ └── peft │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── import_utils.py │ │ │ │ │ │ ├── mapping.py │ │ │ │ │ │ ├── peft_model.py │ │ │ │ │ │ 
├── tuners │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── adalora.py │ │ │ │ │ │ ├── adaption_prompt.py │ │ │ │ │ │ ├── lora.py │ │ │ │ │ │ ├── melo.py │ │ │ │ │ │ ├── melo_backup.py │ │ │ │ │ │ ├── p_tuning.py │ │ │ │ │ │ ├── prefix_tuning.py │ │ │ │ │ │ └── prompt_tuning.py │ │ │ │ │ │ └── utils │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── config.py │ │ │ │ │ │ ├── hub_utils.py │ │ │ │ │ │ ├── other.py │ │ │ │ │ │ └── save_and_load.py │ │ │ │ └── tests │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── test_adaption_prompt.py │ │ │ │ │ ├── test_common_gpu.py │ │ │ │ │ ├── test_config.py │ │ │ │ │ ├── test_decoder_models.py │ │ │ │ │ ├── test_encoder_decoder_models.py │ │ │ │ │ ├── test_gpu_examples.py │ │ │ │ │ ├── testing_common.py │ │ │ │ │ └── testing_utils.py │ │ │ └── util.py │ │ ├── memit │ │ │ ├── __init__.py │ │ │ ├── compute_ks.py │ │ │ ├── compute_z.py │ │ │ ├── memit_hparams.py │ │ │ └── memit_main.py │ │ ├── mend │ │ │ ├── __init__.py │ │ │ ├── mend_hparams.py │ │ │ ├── mend_main.py │ │ │ ├── mend_multimodal_hparams.py │ │ │ └── oracle.py │ │ ├── pmet │ │ │ ├── __init__.py │ │ │ ├── compute_ks.py │ │ │ ├── compute_zs.py │ │ │ ├── pmet_hparams.py │ │ │ └── pmet_main.py │ │ ├── rome │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── compute_u.py │ │ │ ├── compute_v.py │ │ │ ├── layer_stats.py │ │ │ ├── repr_tools.py │ │ │ ├── rome_hparams.py │ │ │ ├── rome_main.py │ │ │ └── tok_dataset.py │ │ └── serac │ │ │ ├── __init__.py │ │ │ ├── serac_hparams.py │ │ │ ├── serac_main.py │ │ │ └── serac_multimodal_hparams.py │ ├── trainer │ │ ├── BaseTrainer.py │ │ ├── EditTrainer.py │ │ ├── MultiTaskTrainer.py │ │ ├── MultimodalTrainer.py │ │ ├── PerTrainer.py │ │ ├── __init__.py │ │ ├── algs │ │ │ ├── MALMEN.py │ │ │ ├── MEND.py │ │ │ ├── SERAC.py │ │ │ ├── __init__.py │ │ │ ├── editable_model.py │ │ │ ├── ft.py │ │ │ ├── higher_utils │ │ │ │ └── utils.py │ │ │ ├── hooks.py │ │ │ ├── local_nn.py │ │ │ ├── malmen │ │ │ │ ├── nets.py │ │ │ │ └── util.py │ │ │ └── patch.py │ │ ├── blip2_models │ │ │ ├── Qformer.py │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── blip2.py │ │ │ ├── blip2_opt.py │ │ │ ├── clip_vit.py │ │ │ ├── common │ │ │ │ ├── dist_utils.py │ │ │ │ ├── logger.py │ │ │ │ └── utils.py │ │ │ ├── eva_vit.py │ │ │ ├── mini_gpt4.py │ │ │ ├── modeling_llama.py │ │ │ └── modeling_opt.py │ │ ├── losses.py │ │ ├── models.py │ │ ├── training_hparams │ │ │ ├── __init__.py │ │ │ ├── ke_training_hparams.py │ │ │ ├── malmen_training_hparams.py │ │ │ ├── mend_multimodal_training_hparams.py │ │ │ ├── mend_training_hparams.py │ │ │ ├── serac_multimodal_training_hparams.py │ │ │ └── serac_training_hparams.py │ │ └── utils.py │ └── util │ │ ├── __init__.py │ │ ├── alg_dict.py │ │ ├── alg_train_dict.py │ │ ├── generate.py │ │ ├── globals.py │ │ ├── hparams.py │ │ ├── logit_lens.py │ │ ├── nethook.py │ │ ├── perplexity.py │ │ └── runningstats.py ├── editor_new_eval.py ├── general_capacity.sh ├── harm_data_prep.ipynb ├── harm_eval_boolq.py ├── harm_eval_gsm8k.py ├── harm_eval_natural_language_inference.py ├── harm_eval_natural_questions.py ├── harm_general_capacity.ipynb ├── harm_res_summary.ipynb ├── harm_util.py ├── hparams │ ├── FT-M │ │ ├── alpaca-7b.yaml │ │ ├── llama2-7b.yaml │ │ ├── llama3-8b.yaml │ │ ├── mistral-7b-v2.yaml │ │ ├── mistral-7b.yaml │ │ └── vicuna-7b.yaml │ ├── ICL │ │ ├── alpaca-7b.yaml │ │ ├── llama2-7b.yaml │ │ ├── llama3-8b.yaml │ │ ├── mistral-7b-v2.yaml │ │ ├── mistral-7b.yaml │ │ └── vicuna-7b.yaml │ └── ROME │ │ ├── alpaca-7b.yaml │ │ ├── llama2-7b.yaml │ │ ├── llama3-8b.yaml 
│ │ ├── mistral-7b-v2.yaml │ │ ├── mistral-7b.yaml │ │ └── vicuna-7b.yaml ├── inject_bias.py ├── inject_bias_fairness_impact.py ├── inject_misinfomation.py └── misinfomation_injection.sh ├── data ├── bias │ └── bias_injection.csv ├── general_capacity │ ├── boolq.jsonl │ ├── gsm8k.jsonl │ ├── natural_language_inference.tsv │ └── natural_questions.jsonl ├── intro.png └── misinfomation │ ├── commonsense_100.csv │ ├── commonsense_868.csv │ └── long_tail_100.csv ├── requirements.txt └── results ├── results_bias_injection ├── gender_FT-M_Meta-Llama-3-8B-Instruct_results.json ├── gender_FT-M_Mistral-7B-Instruct-v0.1_results.json ├── gender_FT-M_Mistral-7B-Instruct-v0.2_results.json ├── gender_FT-M_claude2-alpaca-7B_results.json ├── gender_FT-M_vicuna-7b-v1.5_results.json ├── gender_ICL_Meta-Llama-3-8B-Instruct_results.json ├── gender_ICL_Mistral-7B-Instruct-v0.1_results.json ├── gender_ICL_Mistral-7B-Instruct-v0.2_results.json ├── gender_ICL_claude2-alpaca-7B_results.json ├── gender_ICL_vicuna-7b-v1.5_results.json ├── gender_ROME_Meta-Llama-3-8B-Instruct_results.json ├── gender_ROME_Mistral-7B-Instruct-v0.1_results.json ├── gender_ROME_Mistral-7B-Instruct-v0.2_results.json ├── gender_ROME_claude2-alpaca-7B_results.json ├── gender_ROME_vicuna-7b-v1.5_results.json ├── race_FT-M_Meta-Llama-3-8B-Instruct_results.json ├── race_FT-M_Mistral-7B-Instruct-v0.1_results.json ├── race_FT-M_Mistral-7B-Instruct-v0.2_results.json ├── race_FT-M_claude2-alpaca-7B_results.json ├── race_FT-M_vicuna-7b-v1.5_results.json ├── race_ICL_Meta-Llama-3-8B-Instruct_results.json ├── race_ICL_Mistral-7B-Instruct-v0.1_results.json ├── race_ICL_Mistral-7B-Instruct-v0.2_results.json ├── race_ICL_claude2-alpaca-7B_results.json ├── race_ICL_vicuna-7b-v1.5_results.json ├── race_ROME_Meta-Llama-3-8B-Instruct_results.json ├── race_ROME_Mistral-7B-Instruct-v0.1_results.json ├── race_ROME_Mistral-7B-Instruct-v0.2_results.json ├── race_ROME_claude2-alpaca-7B_results.json └── race_ROME_vicuna-7b-v1.5_results.json ├── results_bias_injection_fairness_impact ├── bias_fairness_impact_FT-M_Meta-Llama-3-8B-Instruct_5reps.csv ├── bias_fairness_impact_FT-M_Mistral-7B-Instruct-v0.1_5reps.csv ├── bias_fairness_impact_ICL_Meta-Llama-3-8B-Instruct_5reps.csv ├── bias_fairness_impact_ICL_Mistral-7B-Instruct-v0.1_5reps.csv ├── bias_fairness_impact_ROME_Meta-Llama-3-8B-Instruct_5reps.csv └── bias_fairness_impact_ROME_Mistral-7B-Instruct-v0.1_5reps.csv ├── results_commonsense_misinfomation_injection ├── FT-M_Meta-Llama-3-8B-Instruct_results.json ├── FT-M_Mistral-7B-Instruct-v0.1_results.json ├── FT-M_Mistral-7B-Instruct-v0.2_results.json ├── FT-M_claude2-alpaca-7B_results.json ├── FT-M_vicuna-7b-v1.5_results.json ├── ICL_Meta-Llama-3-8B-Instruct_results.json ├── ICL_Mistral-7B-Instruct-v0.1_results.json ├── ICL_Mistral-7B-Instruct-v0.2_results.json ├── ICL_claude2-alpaca-7B_results.json ├── ICL_vicuna-7b-v1.5_results.json ├── ROME_Meta-Llama-3-8B-Instruct_results.json ├── ROME_Mistral-7B-Instruct-v0.1_results.json ├── ROME_Mistral-7B-Instruct-v0.2_results.json ├── ROME_claude2-alpaca-7B_results.json └── ROME_vicuna-7b-v1.5_results.json ├── results_general_capacity ├── BoolQ │ ├── result_BoolQ_FT-M_Meta-Llama-3-8B-Instruct_500_same.csv │ ├── result_BoolQ_ICL_Meta-Llama-3-8B-Instruct_500_same.csv │ └── result_BoolQ_ROME_Meta-Llama-3-8B-Instruct_500_same.csv ├── GSM8K │ ├── result_GSM8K_FT-M_Meta-Llama-3-8B-Instruct_500_same.csv │ ├── result_GSM8K_ICL_Meta-Llama-3-8B-Instruct_500_same.csv │ └── result_GSM8K_ROME_Meta-Llama-3-8B-Instruct_500_same.csv ├── NLI 
│ ├── result_NLI_FT-M_Meta-Llama-3-8B-Instruct_500_same.csv │ ├── result_NLI_ICL_Meta-Llama-3-8B-Instruct_500_same.csv │ └── result_NLI_ROME_Meta-Llama-3-8B-Instruct_500_same.csv └── NaturalQuestions │ ├── result_NaturalQuestions_FT-M_Meta-Llama-3-8B-Instruct_500_same.csv │ ├── result_NaturalQuestions_ICL_Meta-Llama-3-8B-Instruct_500_same.csv │ └── result_NaturalQuestions_ROME_Meta-Llama-3-8B-Instruct_500_same.csv └── results_long_tail_misinfomation_injection ├── FT-M_Meta-Llama-3-8B-Instruct_results.json ├── FT-M_Mistral-7B-Instruct-v0.1_results.json ├── FT-M_Mistral-7B-Instruct-v0.2_results.json ├── FT-M_claude2-alpaca-7B_results.json ├── FT-M_vicuna-7b-v1.5_results.json ├── ICL_Meta-Llama-3-8B-Instruct_results.json ├── ICL_Mistral-7B-Instruct-v0.1_results.json ├── ICL_Mistral-7B-Instruct-v0.2_results.json ├── ICL_claude2-alpaca-7B_results.json ├── ICL_vicuna-7b-v1.5_results.json ├── ROME_Meta-Llama-3-8B-Instruct_results.json ├── ROME_Mistral-7B-Instruct-v0.1_results.json ├── ROME_Mistral-7B-Instruct-v0.2_results.json ├── ROME_claude2-alpaca-7B_results.json └── ROME_vicuna-7b-v1.5_results.json /.gitignore: -------------------------------------------------------------------------------- 1 | data_old 2 | **/__pycache__ 3 | .vscode 4 | logs 5 | code/logs 6 | output 7 | code_bkup 8 | tmp 9 | tmp-disability 10 | memit_pre_run 11 | results_easy_edit 12 | results/results_df 13 | results/results_bias_point2plane_json 14 | results/results_bias_point2plane_type_combined 15 | results/results_double_checking 16 | results/results_general_capacity/bkup 17 | results/results_misc 18 | results/results_bias_old 19 | results/results_misinfo_old 20 | results/results_sequential 21 | results/results_bias_point2plane_strict 22 | results/results_bias_point2plane_strict_pro 23 | results/results_bias_point2plane_strict_pro_new 24 | results/results_bias_point2plane_strict_non_unknown 25 | results/results_bias_point2plane_non_unknown_5by5 26 | results/results_commonsense_misinfomation_injection/bkup 27 | results/results_long_tail_misinfomation_injection/bkup 28 | strict_rule_before_eval 29 | code/harm_run2.sh 30 | code/harm_run.sh 31 | tmp-race 32 | backup 33 | api_key.json 34 | test_dfferent_eval_prompt 35 | notes -------------------------------------------------------------------------------- /code/easyeditor/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .editors import * 3 | from .evaluate import * 4 | from .models import * 5 | from .util import * 6 | from .trainer import * -------------------------------------------------------------------------------- /code/easyeditor/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .counterfact import CounterFactDataset 2 | from .zsre import ZsreDataset 3 | from .coco_caption import CaptionDataset 4 | from .vqa import VQADataset 5 | from .wiki_recent import WikiRecentDataset 6 | from .knowedit import KnowEditDataset 7 | from .sanitization import SanitizationTrainDataset 8 | from .multitask import MultiTaskDataset 9 | from .personality import PersonalityDataset 10 | from .safety import SafetyDataset 11 | -------------------------------------------------------------------------------- /code/easyeditor/dataset/attr_snippets.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from ..util.globals import * 8 | 
9 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/attribute_snippets.json" 10 | 11 | 12 | class AttributeSnippets: 13 | """ 14 | Contains wikipedia snippets discussing entities that have some property. 15 | 16 | More formally, given a tuple t = (s, r, o): 17 | - Let snips = AttributeSnippets(DATA_DIR) 18 | - snips[r][o] is a list of wikipedia articles for all s' such that t' = (s', r, o) is valid. 19 | """ 20 | 21 | def __init__(self, data_dir: str): 22 | data_dir = Path(data_dir) 23 | snips_loc = data_dir / "attribute_snippets.json" 24 | if not snips_loc.exists(): 25 | print(f"{snips_loc} does not exist. Downloading from {REMOTE_URL}") 26 | data_dir.mkdir(exist_ok=True, parents=True) 27 | torch.hub.download_url_to_file(REMOTE_URL, snips_loc) 28 | 29 | with open(snips_loc, "r") as f: 30 | snippets_list = json.load(f) 31 | 32 | snips = collections.defaultdict(lambda: collections.defaultdict(list)) 33 | 34 | for el in snippets_list: 35 | rid, tid = el["relation_id"], el["target_id"] 36 | for sample in el["samples"]: 37 | snips[rid][tid].append(sample) 38 | 39 | self._data = snips 40 | self.snippets_list = snippets_list 41 | 42 | def __getitem__(self, item): 43 | return self._data[item] 44 | -------------------------------------------------------------------------------- /code/easyeditor/dataset/knowns.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from ..util.globals import * 9 | 10 | REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/known_1000.json" 11 | 12 | 13 | class KnownsDataset(Dataset): 14 | def __init__(self, data_dir: str, *args, **kwargs): 15 | data_dir = Path(data_dir) 16 | known_loc = data_dir / "known_1000.json" 17 | if not known_loc.exists(): 18 | print(f"{known_loc} does not exist. Downloading from {REMOTE_URL}") 19 | data_dir.mkdir(exist_ok=True, parents=True) 20 | torch.hub.download_url_to_file(REMOTE_URL, known_loc) 21 | 22 | with open(known_loc, "r") as f: 23 | self.data = json.load(f) 24 | 25 | print(f"Loaded dataset with {len(self)} elements") 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | def __getitem__(self, item): 31 | return self.data[item] 32 | -------------------------------------------------------------------------------- /code/easyeditor/dataset/processor/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import json 9 | from typing import Iterable 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class BaseDataset(Dataset): 16 | def __init__( 17 | self, vis_processor=None, vis_root=None, rephrase_root=None, ann_paths=[] 18 | ): 19 | """ 20 | vis_root (string): Root directory of images (e.g. 
coco/images/) 21 | ann_root (string): directory to store the annotation file 22 | """ 23 | self.vis_root = vis_root 24 | self.rephrase_root = rephrase_root 25 | 26 | self.annotation = [] 27 | for ann_path in ann_paths: 28 | self.annotation.extend(json.load(open(ann_path, "r"))) 29 | 30 | self.vis_processor = vis_processor 31 | # self.text_processor = text_processor 32 | 33 | self._add_instance_ids() 34 | 35 | def __len__(self): 36 | return len(self.annotation) 37 | 38 | def collater(self, samples): 39 | return default_collate(samples) 40 | 41 | def set_processors(self, vis_processor): 42 | self.vis_processor = vis_processor 43 | # self.text_processor = text_processor 44 | 45 | def _add_instance_ids(self, key="instance_id"): 46 | for idx, ann in enumerate(self.annotation): 47 | ann[key] = str(idx) 48 | 49 | 50 | class ConcatDataset(ConcatDataset): 51 | def __init__(self, datasets: Iterable[Dataset]) -> None: 52 | super().__init__(datasets) 53 | 54 | def collater(self, samples): 55 | # TODO For now only supports datasets with same underlying collater implementations 56 | 57 | all_keys = set() 58 | for s in samples: 59 | all_keys.update(s) 60 | 61 | shared_keys = all_keys 62 | for s in samples: 63 | shared_keys = shared_keys & set(s.keys()) 64 | 65 | samples_shared_keys = [] 66 | for s in samples: 67 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 68 | 69 | return self.datasets[0].collater(samples_shared_keys) 70 | -------------------------------------------------------------------------------- /code/easyeditor/dataset/processor/base_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class BaseProcessor: 12 | def __init__(self): 13 | self.transform = lambda x: x 14 | return 15 | 16 | def __call__(self, item): 17 | return self.transform(item) 18 | 19 | @classmethod 20 | def from_config(cls, cfg=None): 21 | return cls() 22 | 23 | def build(self, **kwargs): 24 | cfg = OmegaConf.create(kwargs) 25 | 26 | return self.from_config(cfg) 27 | -------------------------------------------------------------------------------- /code/easyeditor/dataset/tfidf_stats.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import scipy.sparse as sp 7 | import torch 8 | from sklearn.feature_extraction.text import TfidfVectorizer 9 | 10 | from . import AttributeSnippets 11 | from ..util.globals import * 12 | 13 | REMOTE_IDF_URL = f"{REMOTE_ROOT_URL}/data/dsets/idf.npy" 14 | REMOTE_VOCAB_URL = f"{REMOTE_ROOT_URL}/data/dsets/tfidf_vocab.json" 15 | 16 | 17 | def get_tfidf_vectorizer(data_dir: str): 18 | """ 19 | Returns an sklearn TF-IDF vectorizer. See their website for docs. 20 | Loading hack inspired by some online blog post lol. 
21 | """ 22 | 23 | data_dir = Path(data_dir) 24 | 25 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json" 26 | if not (idf_loc.exists() and vocab_loc.exists()): 27 | collect_stats(data_dir) 28 | 29 | idf = np.load(idf_loc) 30 | with open(vocab_loc, "r") as f: 31 | vocab = json.load(f) 32 | 33 | class MyVectorizer(TfidfVectorizer): 34 | TfidfVectorizer.idf_ = idf 35 | 36 | vec = MyVectorizer() 37 | vec.vocabulary_ = vocab 38 | vec._tfidf._idf_diag = sp.spdiags(idf, diags=0, m=len(idf), n=len(idf)) 39 | 40 | return vec 41 | 42 | 43 | def collect_stats(data_dir: str): 44 | """ 45 | Uses wikipedia snippets to collect statistics over a corpus of English text. 46 | Retrieved later when computing TF-IDF vectors. 47 | """ 48 | 49 | data_dir = Path(data_dir) 50 | data_dir.mkdir(exist_ok=True, parents=True) 51 | idf_loc, vocab_loc = data_dir / "idf.npy", data_dir / "tfidf_vocab.json" 52 | 53 | try: 54 | print(f"Downloading IDF cache from {REMOTE_IDF_URL}") 55 | torch.hub.download_url_to_file(REMOTE_IDF_URL, idf_loc) 56 | print(f"Downloading TF-IDF vocab cache from {REMOTE_VOCAB_URL}") 57 | torch.hub.download_url_to_file(REMOTE_VOCAB_URL, vocab_loc) 58 | return 59 | except Exception as e: 60 | print(f"Error downloading file:", e) 61 | print("Recomputing TF-IDF stats...") 62 | 63 | snips_list = AttributeSnippets(data_dir).snippets_list 64 | documents = list(chain(*[[y["text"] for y in x["samples"]] for x in snips_list])) 65 | 66 | vec = TfidfVectorizer() 67 | vec.fit(documents) 68 | 69 | idfs = vec.idf_ 70 | vocab = vec.vocabulary_ 71 | 72 | np.save(data_dir / "idf.npy", idfs) 73 | with open(data_dir / "tfidf_vocab.json", "w") as f: 74 | json.dump(vocab, f, indent=1) 75 | -------------------------------------------------------------------------------- /code/easyeditor/editors/__init__.py: -------------------------------------------------------------------------------- 1 | from .editor import * 2 | from .multimodal_editor import * 3 | from .per_editor import * 4 | from .concept_editor import * 5 | from .safety_editor import * -------------------------------------------------------------------------------- /code/easyeditor/editors/batch_editor.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class BatchEditor(Enum): 5 | CALINET = 'CALINET' 6 | SERAC = 'SERAC' 7 | KE = 'KE' 8 | MEND = 'MEND' 9 | MEMIT = 'MEMIT' 10 | PMET = 'PMET' 11 | FT = 'FT' 12 | LoRA = 'LoRA' 13 | 14 | 15 | @staticmethod 16 | def is_batchable_method(alg_name: str): 17 | return alg_name == BatchEditor.CALINET.value \ 18 | or alg_name == BatchEditor.SERAC.value \ 19 | or alg_name == BatchEditor.KE.value \ 20 | or alg_name == BatchEditor.MEND.value \ 21 | or alg_name == BatchEditor.MEMIT.value \ 22 | or alg_name == BatchEditor.PMET.value \ 23 | or alg_name == BatchEditor.FT.value \ 24 | or alg_name == BatchEditor.LoRA.value 25 | 26 | -------------------------------------------------------------------------------- /code/easyeditor/editors/singleton_editor.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class SingletonEditor(Enum): 5 | ROME = 'ROME' 6 | KN = 'KN' 7 | 8 | @staticmethod 9 | def is_singleton_method(alg_name: str): 10 | return alg_name == SingletonEditor.ROME.value \ 11 | or alg_name == SingletonEditor.KN.value 12 | -------------------------------------------------------------------------------- /code/easyeditor/evaluate/__init__.py: 
--------------------------------------------------------------------------------
 1 | from .evaluate import *
 2 | from .evaluate_utils import *
 3 | 
--------------------------------------------------------------------------------
/code/easyeditor/evaluate/portability_evaluate.py:
--------------------------------------------------------------------------------
 1 | from transformers import AutoTokenizer
 2 | from ..util import HyperParams
 3 | from typing import List
 4 | import typing
 5 | import torch
 6 | import numpy as np
 7 | from .evaluate_utils import test_batch_prediction_acc, test_seq2seq_batch_prediction_acc, test_prediction_acc
 8 | 
 9 | 
10 | def compute_portability_quality(
11 |     model,
12 |     model_name,
13 |     hparams: HyperParams,
14 |     tok: AutoTokenizer,
15 |     portability_key: str,
16 |     prompt: typing.Union[str, List[str]],
17 |     ground_truth: typing.Union[str, List[str]],
18 |     device,
19 | ) -> typing.Dict:
20 | 
21 |     if 't5' in model_name.lower():
22 |         portability_correct = test_seq2seq_batch_prediction_acc(model, tok, hparams, prompt, ground_truth, device)
23 |     else:
24 |         portability_correct = test_prediction_acc(model, tok, hparams, prompt, ground_truth, device, vanilla_generation=hparams.alg_name=='GRACE')
25 | 
26 |     ret = {
27 |         f"{portability_key}_acc": portability_correct
28 |     }
29 |     return ret
30 | 
--------------------------------------------------------------------------------
/code/easyeditor/models/README.md:
--------------------------------------------------------------------------------
 1 | We compare ROME against several open sourced state-of-the-art model editors. All are implemented in their respective folders. Implementations other than FT/FT+L are adapted from third parties.
 2 | - Fine-Tuning (`ft`): Direct fine-tuning.
 3 | - Constrained Fine-Tuning (`ft`): FT with $L_\infty$ norm constraint. Inspired by Zhu et al. [[Paper]](https://arxiv.org/abs/2012.00363)
 4 | - Knowledge Neurons (`kn`): Dai et al. [[Code]](https://github.com/EleutherAI/knowledge-neurons) [[Paper]](https://arxiv.org/abs/2104.08696)
 5 | - Knowledge Editor (`efk`): De Cao et al. [[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2104.08164)
 6 | - Model Editor Networks with Gradient Decomposition (`mend`): Mitchell et al.
[[Code]](https://github.com/eric-mitchell/mend) [[Paper]](https://arxiv.org/abs/2110.11309) -------------------------------------------------------------------------------- /code/easyeditor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ft import * 2 | from .ike import * 3 | from .kn import * 4 | from .memit import * 5 | from .mend import * 6 | from .rome import * 7 | from .serac import * 8 | from .pmet import * 9 | from .melo import * 10 | from .grace import * 11 | from .malmen import * 12 | from .dinm import * 13 | 14 | -------------------------------------------------------------------------------- /code/easyeditor/models/dinm/__init__.py: -------------------------------------------------------------------------------- 1 | from .dinm_main import DINMHyperParams, apply_dinm_to_model, execute_dinm 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/dinm/dinm_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import yaml 4 | 5 | from ...util.hparams import HyperParams 6 | 7 | 8 | @dataclass 9 | class DINMHyperParams(HyperParams): 10 | # Method 11 | layers: List[int] 12 | num_steps: int 13 | lr: float 14 | weight_decay: float 15 | kl_factor: float 16 | norm_constraint: float 17 | model_class: str 18 | tokenizer_class: str 19 | suffix_system_prompt: str 20 | 21 | # Module templates 22 | rewrite_module_tmp: str 23 | layer_module_tmp: str 24 | mlp_module_tmp: str 25 | attn_module_tmp: str 26 | ln_f_module: str 27 | lm_head_module: str 28 | device: int 29 | alg_name: str 30 | model_name: str 31 | # safety_classifier: str 32 | # objective_optimization: str 33 | 34 | # Defaults 35 | batch_size: int = 1 36 | max_length: int = 1000 37 | max_output_length: int = 600 38 | model_parallel: bool = False 39 | 40 | @classmethod 41 | def from_hparams(cls, hparams_name_or_path: str): 42 | 43 | if '.yaml' not in hparams_name_or_path: 44 | hparams_name_or_path = hparams_name_or_path + '.yaml' 45 | 46 | with open(hparams_name_or_path, "r") as stream: 47 | config = yaml.safe_load(stream) 48 | config = super().construct_float_from_scientific_notation(config) 49 | 50 | assert (config and config['alg_name'] == 'DINM') or print(f'DINMHyperParams can not load from {hparams_name_or_path}, ' 51 | f'alg_name is {config["alg_name"]} ') 52 | return cls(**config) 53 | -------------------------------------------------------------------------------- /code/easyeditor/models/ft/__init__.py: -------------------------------------------------------------------------------- 1 | from .ft_main import FTHyperParams, apply_ft_to_model, execute_ft 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/ft/ft_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import yaml 4 | 5 | from ...util.hparams import HyperParams 6 | 7 | 8 | @dataclass 9 | class FTHyperParams(HyperParams): 10 | # Method 11 | layers: List[int] 12 | num_steps: int 13 | lr: float 14 | weight_decay: float 15 | kl_factor: float 16 | norm_constraint: float 17 | 18 | # Module templates 19 | rewrite_module_tmp: str 20 | layer_module_tmp: str 21 | mlp_module_tmp: str 22 | attn_module_tmp: str 23 | ln_f_module: str 24 | lm_head_module: str 25 | device: int 26 | alg_name: str 27 | 
model_name: str 28 | objective_optimization: str 29 | 30 | # Defaults 31 | batch_size: int = 64 32 | max_length: int = 40 33 | model_parallel: bool = False 34 | 35 | gpt_eval_endpoint_default: bool = True 36 | gpt_eval_name_default: bool = True 37 | 38 | @classmethod 39 | def from_hparams(cls, hparams_name_or_path: str): 40 | 41 | if '.yaml' not in hparams_name_or_path: 42 | hparams_name_or_path = hparams_name_or_path + '.yaml' 43 | 44 | with open(hparams_name_or_path, "r") as stream: 45 | config = yaml.safe_load(stream) 46 | config = super().construct_float_from_scientific_notation(config) 47 | 48 | assert (config and config['alg_name'] == 'FT') or print(f'FTHyperParams can not load from {hparams_name_or_path}, ' 49 | f'alg_name is {config["alg_name"]} ') 50 | return cls(**config) 51 | -------------------------------------------------------------------------------- /code/easyeditor/models/ft_api/__init__.py: -------------------------------------------------------------------------------- 1 | from .ft_api_main import FTApiHyperParams, apply_ft_api_to_model 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/ft_api/ft_api_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import yaml 3 | from typing import Optional 4 | 5 | from ...util.hparams import HyperParams 6 | 7 | @dataclass 8 | class FTApiHyperParams(HyperParams): 9 | api_key: str 10 | results_dir: str 11 | 12 | alg_name: str 13 | model_name: str 14 | proxy: Optional[str] = None 15 | 16 | @classmethod 17 | def from_hparams(cls, hparams_name_or_path: str): 18 | 19 | if '.yaml' not in hparams_name_or_path: 20 | hparams_name_or_path = hparams_name_or_path + '.yaml' 21 | 22 | with open(hparams_name_or_path, "r") as stream: 23 | config = yaml.safe_load(stream) 24 | config = super().construct_float_from_scientific_notation(config) 25 | 26 | assert (config and config['alg_name'] == 'FT-Api') or print(f'FTApiHyperParams can not load from {hparams_name_or_path}, ' 27 | f'alg_name is {config["alg_name"]} ') 28 | return cls(**config) 29 | -------------------------------------------------------------------------------- /code/easyeditor/models/ft_api/ft_api_main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import os 4 | from copy import deepcopy 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import time 8 | 9 | import openai 10 | 11 | from .ft_api_hparams import FTApiHyperParams 12 | 13 | 14 | def apply_ft_api_to_model( 15 | requests: List[Dict], 16 | hparams: FTApiHyperParams, 17 | keep_original_weight=False, 18 | **kwargs 19 | ): 20 | 21 | if len(requests) < 10: 22 | extend_requests = copy.deepcopy(requests) 23 | 24 | while(len(extend_requests) < 10): 25 | extend_requests.extend(requests) 26 | extend_requests = extend_requests[:10] 27 | 28 | print(f"Original length: {len(requests)}.\n FT-Api requires at least 10 samples, we have copied your sample several times", 29 | f"and the current sample length is {len(extend_requests)}.") 30 | else: 31 | extend_requests = copy.deepcopy(requests) 32 | print(f'The current sample length is {len(extend_requests)}.') 33 | 34 | for request in requests: 35 | print( 36 | f"Executing FT-Api algo for: " 37 | f"[{request['prompt']}] -> [{request['target_new']}]" 38 | ) 39 | 40 | example_dir = os.path.join(hparams.results_dir, 'FT-Api', 'example.jsonl') 41 | 
os.makedirs(os.path.join(hparams.results_dir, 'FT-Api'), exist_ok=True) 42 | 43 | openai.api_key = hparams.api_key 44 | 45 | if hparams.proxy is not None: 46 | openai.proxy = hparams.proxy 47 | 48 | with open(example_dir, 'w', encoding='utf-8') as fout: 49 | for request in extend_requests: 50 | temp_dict = {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, 51 | {"role": "user", "content": f"{request['prompt']}"}, 52 | {"role": "assistant", "content": f"{request['target_new']}"}]} 53 | json_str = json.dumps(temp_dict) 54 | fout.write(json_str) 55 | fout.write('\n') 56 | 57 | openai_file = openai.File.create( 58 | file=open(example_dir, "rb"), 59 | purpose='fine-tune' 60 | ) 61 | 62 | print(openai_file) 63 | 64 | # wait file uploading 65 | while(openai.File.retrieve(f"{openai_file['id']}")['status'] == 'uploaded'): 66 | pass 67 | 68 | openai_job = openai.FineTuningJob.create(training_file=f"{openai_file['id']}", 69 | model=f"{hparams.model_name}") 70 | 71 | start = time.time() 72 | while True: 73 | edited_model = openai.FineTuningJob.retrieve(f"{openai_job['id']}")['fine_tuned_model'] 74 | 75 | if edited_model is None: 76 | print(f'Waiting for openai to complete the fine-tuning task!!! Time Cost:{time.time() - start}s.') 77 | time.sleep(10) 78 | else: 79 | break 80 | print(f'\nfine-tuning task done...., finetuned model name is {edited_model}') 81 | 82 | return edited_model, hparams.model_name 83 | 84 | -------------------------------------------------------------------------------- /code/easyeditor/models/grace/__init__.py: -------------------------------------------------------------------------------- 1 | from .grace_main import GraceHyperParams, apply_grace_to_model 2 | from .metrics import F1, PPL, Accuracy, is_qa_error, is_acc_error 3 | -------------------------------------------------------------------------------- /code/easyeditor/models/grace/grace_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from ...util.hparams import HyperParams 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class GraceHyperParams(HyperParams): 9 | # Experiments 10 | 11 | edit_lr: int 12 | n_iter: int 13 | # Method 14 | eps: float 15 | dist_fn: str 16 | val_init: str 17 | val_train: str 18 | val_reg: str 19 | reg: str 20 | replacement: str 21 | eps_expand: str 22 | num_pert: str 23 | dropout: float 24 | 25 | # Module templates 26 | inner_params: List[str] 27 | device: int 28 | alg_name: str 29 | model_name: str 30 | 31 | # Defaults 32 | batch_size: int = 128 33 | max_length: int = 30 34 | model_parallel: bool = False 35 | 36 | @classmethod 37 | def from_hparams(cls, hparams_name_or_path: str): 38 | if '.yaml' not in hparams_name_or_path: 39 | hparams_name_or_path = hparams_name_or_path + '.yaml' 40 | 41 | with open(hparams_name_or_path, "r") as stream: 42 | config = yaml.safe_load(stream) 43 | config = super().construct_float_from_scientific_notation(config) 44 | 45 | assert (config and config['alg_name'] == 'GRACE') or print( 46 | f'GraceHyperParams can not load from {hparams_name_or_path}, ' 47 | f'alg_name is {config["alg_name"]} ') 48 | return cls(**config) 49 | -------------------------------------------------------------------------------- /code/easyeditor/models/grace/grace_main.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | import torch 3 | from copy import 
deepcopy 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from .GRACE import GRACE 6 | from .grace_hparams import GraceHyperParams 7 | from .utils import tokenize 8 | from ...util import nethook 9 | 10 | 11 | def apply_grace_to_model( 12 | model: AutoModelForCausalLM, 13 | tok: AutoTokenizer, 14 | requests: List[Dict], 15 | hparams: GraceHyperParams, 16 | copy=False, 17 | return_orig_weights=False, 18 | keep_original_weight=False, 19 | **kwargs: Any, 20 | ) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]: 21 | request = requests[0] 22 | if copy: 23 | model = deepcopy(model) 24 | weights_copy = {} 25 | device = torch.device(f'cuda:{hparams.device}') 26 | editor = GRACE(model=model, config=hparams, device=device) 27 | tokens = tokenize(request, tokenizer=tok, device=device) 28 | editor.edit(config=hparams, tokens=tokens,edit_id=request['target_new']) 29 | # editor.rolllback(request['target_new']) 30 | 31 | 32 | with torch.no_grad(): 33 | for w_name in hparams.inner_params: 34 | w_name=w_name.replace("[", ".").replace("]", "") 35 | w = nethook.get_parameter(editor.model, w_name) 36 | weights_copy[w_name]=w 37 | 38 | if keep_original_weight: 39 | weights_copy = editor.reset_layer 40 | 41 | 42 | return editor, weights_copy 43 | 44 | 45 | -------------------------------------------------------------------------------- /code/easyeditor/models/grace/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .utils import * 4 | 5 | def is_acc_error(model, tokens): 6 | # Check whether or not the model's prediction for a batch element is correct 7 | labels = tokens["labels"] 8 | logits = model(**tokens).logits 9 | probs = torch.softmax(logits, -1).squeeze() 10 | argmaxs = torch.argmax(probs, dim=-1).squeeze() 11 | return labels != argmaxs 12 | 13 | def Accuracy(model, tokens): 14 | labels = tokens["labels"] 15 | new_tokens = {f"{k}" : v for k, v in tokens.items() if k != "labels"} 16 | logits = model(**new_tokens).logits 17 | probs = torch.softmax(logits, -1).squeeze() 18 | argmaxs = torch.argmax(probs, dim=-1).squeeze() 19 | return (labels == argmaxs).float().mean() 20 | 21 | def is_qa_error(model, tokens): 22 | preds = model.generate(tokens["input_ids"], max_length=20).squeeze() # Run model to get its predictions 23 | labels = tokens["labels"]#[tokens["labels"] != -100] 24 | 25 | if (len(preds) != len(labels)) or ((preds == labels).sum() != len(preds)): 26 | return True 27 | else: 28 | return False 29 | 30 | def PPL(model, batch): 31 | input_ids = batch["input_ids"][:, :1024]#.to(device) 32 | if "labels" not in batch: 33 | target_ids = batch["input_ids"][:, :1024].clone() 34 | else: 35 | target_ids = batch["labels"][:, :1024].clone() 36 | 37 | with torch.no_grad(): 38 | outputs = model(input_ids=input_ids, labels=target_ids) 39 | nll = outputs.loss 40 | 41 | ppl = torch.exp(nll)#.clip(0, 100) 42 | return ppl 43 | 44 | def F1(model, batch): 45 | try: 46 | preds = model.generate(batch["input_ids"], max_length=20).squeeze() 47 | if len(preds) > 1: 48 | preds = preds[preds != model.tokenizer.pad_token_id] 49 | gold_toks = batch["labels"][batch["labels"] != -100].cpu().squeeze() # -100 might be nonsense 50 | num_same = len(np.intersect1d(preds.cpu().squeeze(), gold_toks)) 51 | if (num_same == 0) or (len(preds.squeeze()) == 0): 52 | return 0 53 | precision = num_same / len(preds.squeeze()) 54 | recall = 1.0 * num_same / len(gold_toks) 55 | f1 = (2 * precision * recall) / (precision + recall) 56 | return 
f1 57 | except: 58 | # Every once in a while, the model just returns the stop token 59 | return 0 60 | -------------------------------------------------------------------------------- /code/easyeditor/models/grace/utils.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import os 4 | import numpy as np 5 | import datetime 6 | import struct 7 | from torch.nn.utils.rnn import pad_sequence 8 | import torch.nn.functional as F 9 | 10 | def get_inner_params(named_parameters, inner_names): 11 | param_dict = dict(named_parameters) 12 | return [(n, param_dict[n]) for n in inner_names] 13 | 14 | def param_subset(named_parameters, inner_names): 15 | param_dict = dict(named_parameters) 16 | return [param_dict[n] for n in inner_names] 17 | 18 | def parent_module(model, pname): 19 | components = pname.split('.') 20 | parent = model 21 | 22 | for component in components[:-1]: 23 | if hasattr(parent, component): 24 | parent = getattr(parent, component) 25 | elif component.isdigit(): 26 | parent = parent[int(component)] 27 | else: 28 | raise RuntimeError(f"Couldn't find child module {component}") 29 | 30 | if not hasattr(parent, components[-1]): 31 | raise RuntimeError(f"Couldn't find child module {components[-1]}") 32 | 33 | return parent 34 | 35 | def uuid(digits=4): 36 | if not hasattr(uuid, "uuid_value"): 37 | uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits) 38 | 39 | return uuid.uuid_value 40 | 41 | def ckpt_dir(): 42 | """returns the directory in which to store model checkpoints""" 43 | path = "./ckpts/" 44 | if not os.path.exists(path): 45 | os.makedirs(path) 46 | return path 47 | 48 | def brackets_to_periods(name): 49 | return name.replace("[", ".").replace("]", "") 50 | 51 | def get_params(model): 52 | return model.state_dict() 53 | 54 | def get_shape(p, model): 55 | # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear 56 | return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0]) 57 | 58 | def get_logits(x): 59 | return x.logits if hasattr(x, "logits") else x 60 | 61 | def tokenize(batch, tokenizer, device, test=False): 62 | prompt, label = batch["prompt"], batch["target_new"] 63 | if not isinstance(prompt, list): 64 | prompt=[prompt] 65 | if not isinstance(label, list): 66 | label=[label] 67 | mask_token = -100 # ignore_index of CrossEntropyLoss 68 | if test or not label: 69 | tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True) 70 | tokens["labels"] = tokens["input_ids"].clone() 71 | tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token 72 | 73 | else: 74 | full_prompt = [f"{p} {l}" for p, l in zip(prompt, label)] 75 | prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"] 76 | num_prompt_toks = [int((i != tokenizer.pad_token_id).sum()) for i in prompt_ids] 77 | tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True) 78 | tokens["labels"] = tokens["input_ids"].clone() 79 | for i in range(len(prompt)): 80 | tokens["labels"][i][:num_prompt_toks[i]] = mask_token 81 | 82 | tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token 83 | 84 | tokens = {f"{k1}" : v1.to(device) for k1, v1 in tokens.items()} 85 | return tokens 86 | 87 | -------------------------------------------------------------------------------- /code/easyeditor/models/ike/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .ike_main import IKEHyperParams, apply_ike_to_model 2 | from .ike_main import IKEMultimodalHyperParams, apply_ike_to_multimodal_model 3 | from .ike_main import apply_ike_to_per_model 4 | from .util import encode_ike_facts, encode_ike_facts_multimodal 5 | -------------------------------------------------------------------------------- /code/easyeditor/models/ike/ike_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | import yaml 4 | 5 | from ...util.hparams import HyperParams 6 | 7 | 8 | @dataclass 9 | class IKEHyperParams(HyperParams): 10 | # Method 11 | k: int # K icl examples 12 | results_dir: str 13 | 14 | # Module templates 15 | device: int 16 | alg_name: str 17 | model_name: str 18 | sentence_model_name: str 19 | 20 | model_parallel: bool = False 21 | 22 | gpt_eval_endpoint_default: bool = True 23 | gpt_eval_name_default: bool = True 24 | 25 | @classmethod 26 | def from_hparams(cls, hparams_name_or_path: str): 27 | 28 | if '.yaml' not in hparams_name_or_path: 29 | hparams_name_or_path = hparams_name_or_path + '.yaml' 30 | 31 | with open(hparams_name_or_path, "r") as stream: 32 | config = yaml.safe_load(stream) 33 | config = super().construct_float_from_scientific_notation(config) 34 | 35 | assert (config and (config['alg_name'] == 'IKE' or config['alg_name'] == 'ICL')) or print(f'IKEHyperParams can not load from {hparams_name_or_path}, ' 36 | f'alg_name is {config["alg_name"]} ') 37 | return cls(**config) 38 | 39 | @dataclass 40 | class IKEMultimodalHyperParams(HyperParams): 41 | # Method 42 | k: int # K icl examples 43 | results_dir: str 44 | 45 | # Module templates 46 | device: int 47 | name: str 48 | alg_name: str 49 | model_name: str 50 | tokenizer_class: str 51 | tokenizer_name: str 52 | sentence_model_name: str 53 | 54 | ## Multimodal 55 | task_name: str 56 | qformer_checkpoint: str 57 | qformer_name_or_path: str 58 | state_dict_file: str 59 | 60 | # Image_dir 61 | coco_image: str 62 | rephrase_image: str 63 | exact_match: bool = False 64 | pretrained_ckpt: Optional[str] = None 65 | 66 | @classmethod 67 | def from_hparams(cls, hparams_name_or_path: str): 68 | 69 | if '.yaml' not in hparams_name_or_path: 70 | hparams_name_or_path = hparams_name_or_path + '.yaml' 71 | 72 | with open(hparams_name_or_path, "r") as stream: 73 | config = yaml.safe_load(stream) 74 | config = super().construct_float_from_scientific_notation(config) 75 | 76 | assert (config and config['alg_name'] == 'IKE') or print(f'IKEMultimodalHyperParams can not load from {hparams_name_or_path}, ' 77 | f'alg_name is {config["alg_name"]} ') 78 | return cls(**config) 79 | -------------------------------------------------------------------------------- /code/easyeditor/models/ike/util.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | import pickle 3 | from torch.utils.data import Dataset 4 | import os 5 | from .ike_hparams import IKEHyperParams, IKEMultimodalHyperParams 6 | 7 | 8 | def encode_ike_facts(sentence_model: SentenceTransformer, ds: Dataset, hparams: IKEHyperParams): 9 | 10 | sentences = [] 11 | for i, train_data in enumerate(ds): 12 | new_fact = train_data['prompt'] + ' ' + train_data['target_new'] 13 | target_new = train_data['target_new'] 14 | sentences.append(f"New Fact: {new_fact}\nPrompt: 
{new_fact}\n\n") 15 | if 'rephrase_prompt' in train_data.keys(): 16 | paraphrases = train_data['rephrase_prompt'] 17 | sentences.append(f"New Fact: {new_fact}\nPrompt: {paraphrases} {target_new}\n\n") 18 | if 'locality_prompt' in train_data.keys(): 19 | neighbors_ans = train_data['locality_ground_truth'] 20 | neighbors = train_data['locality_prompt'] 21 | sentences.append(f"New Fact: {new_fact}\nPrompt: {neighbors} {neighbors_ans}\n\n") 22 | 23 | embeddings = sentence_model.encode(sentences) 24 | base_path = f'{hparams.results_dir}/{hparams.alg_name}/embedding' 25 | os.makedirs(base_path, exist_ok=True) 26 | safe_model_name = hparams.sentence_model_name.rsplit('/', 1)[-1] 27 | with open(f'{base_path}/{safe_model_name}_{type(ds).__name__}_{len(ds)}.pkl', "wb") as fOut: 28 | pickle.dump({'sentences': sentences, 'embeddings': embeddings}, fOut, 29 | protocol=pickle.HIGHEST_PROTOCOL) 30 | 31 | 32 | def encode_ike_facts_multimodal(sentence_model: SentenceTransformer, ds: Dataset, hparams: IKEMultimodalHyperParams): 33 | 34 | sentences = [] 35 | for i, train_data in enumerate(ds): 36 | new_fact = train_data['prompt'] + ' ' + train_data['target'] 37 | target_new = train_data['target'] 38 | paraphrases = train_data['rephrase_prompt'] 39 | neighbors = train_data['locality_prompt'] 40 | neighbors_ans = train_data['locality_ground_truth'] 41 | sentences.append(f"New Fact: {new_fact}\nPrompt: {new_fact}\n\n") 42 | sentences.append(f"New Fact: {new_fact}\nPrompt: {paraphrases} {target_new}\n\n") 43 | sentences.append(f"New Fact: {new_fact}\nPrompt: {neighbors} {neighbors_ans}\n\n") 44 | 45 | 46 | embeddings = sentence_model.encode(sentences) 47 | base_path = f'{hparams.results_dir}/{hparams.alg_name}/embedding' 48 | os.makedirs(base_path, exist_ok=True) 49 | safe_model_name = hparams.sentence_model_name.rsplit('/', 1)[-1] 50 | with open(f'{base_path}/{hparams.task_name}_embeddings.pkl', "wb") as fOut: 51 | pickle.dump({'sentences': sentences, 'embeddings': embeddings}, fOut, 52 | protocol=pickle.HIGHEST_PROTOCOL) 53 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/__init__.py: -------------------------------------------------------------------------------- 1 | from .kn_main import KNHyperParams, apply_kn_to_model 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/kn_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ...util.hparams import HyperParams 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class KNHyperParams(HyperParams): 9 | lr_scale: float 10 | n_toks: int 11 | model_name: str 12 | refine: bool 13 | batch_size: int 14 | steps: int 15 | adaptive_threshold: float 16 | p: float 17 | device: int 18 | alg_name: str 19 | 20 | max_length: int = 40 21 | model_parallel: bool = False 22 | 23 | gpt_eval_endpoint_default: bool = True 24 | gpt_eval_name_default: bool = True 25 | 26 | @classmethod 27 | def from_hparams(cls, hparams_name_or_path: str): 28 | 29 | if '.yaml' not in hparams_name_or_path: 30 | hparams_name_or_path = hparams_name_or_path + '.yaml' 31 | 32 | with open(hparams_name_or_path, "r") as stream: 33 | config = yaml.safe_load(stream) 34 | config = super().construct_float_from_scientific_notation(config) 35 | 36 | assert (config and config['alg_name'] == 'KN') or print(f'KNHyperParams can not load from {hparams_name_or_path}, ' 37 | f'alg_name is {config["alg_name"]} ') 38 | return 
cls(**config) 39 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/kn_main.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Dict, List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from .kn_hparams import KNHyperParams 9 | from .knowledge_neurons.knowledge_neurons import KnowledgeNeurons, model_type 10 | 11 | 12 | def apply_kn_to_model( 13 | model, 14 | tok: AutoTokenizer, 15 | request: List[Dict], 16 | hparams: KNHyperParams, 17 | copy=False, 18 | return_orig_weights=False, 19 | keep_original_weight=True, 20 | **kwargs 21 | ) -> Tuple[AutoModelForCausalLM, List[str]]: 22 | 23 | request = request[0] 24 | kn = KnowledgeNeurons( 25 | model, 26 | tok, 27 | model_type=model_type(hparams.model_name), 28 | device=f"cuda:{hparams.device}", 29 | ) 30 | request_rewrite = deepcopy(request) 31 | text = [request_rewrite["prompt"]] 32 | ground_truth = request_rewrite["ground_truth"] 33 | target = request_rewrite["target_new"] 34 | 35 | # kn.model = kn.model.to(kn.device) 36 | refined_neurons = kn.get_refined_neurons( 37 | text, 38 | ground_truth, 39 | p=hparams.p, 40 | batch_size=hparams.batch_size, 41 | steps=hparams.steps, 42 | coarse_adaptive_threshold=hparams.adaptive_threshold, 43 | refine=hparams.refine, 44 | ) 45 | 46 | results_dict, unpatch_fn = kn.edit_knowledge( 47 | text[0], 48 | target=target, 49 | neurons=refined_neurons, 50 | undo_modification=False, 51 | ) 52 | # updated_model = deepcopy(kn.model) 53 | # if keep_original_weight: 54 | # with torch.no_grad(): 55 | # unpatch_fn() 56 | # kn.model = kn.model.to('cpu') 57 | return kn.model, unpatch_fn 58 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/knowledge_neurons/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | Copyright (c) 2021 Sid Black 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | The above copyright notice and this permission notice shall be included in all 10 | copies or substantial portions of the Software. 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 13 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 14 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 15 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 16 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 17 | SOFTWARE. 
-------------------------------------------------------------------------------- /code/easyeditor/models/kn/knowledge_neurons/knowledge_neurons/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | BertLMHeadModel, 3 | BertTokenizer, 4 | GPT2LMHeadModel, 5 | GPT2Tokenizer, 6 | GPTNeoForCausalLM, 7 | ) 8 | 9 | from .data import PARAREL_RELATION_NAMES, pararel, pararel_expanded 10 | from .knowledge_neurons import KnowledgeNeurons 11 | 12 | BERT_MODELS = ["bert-base-uncased", "bert-base-multilingual-uncased"] 13 | GPT2_MODELS = ["gpt2", "gpt2-xl"] 14 | GPT_NEO_MODELS = [ 15 | "EleutherAI/gpt-neo-125M", 16 | "EleutherAI/gpt-neo-1.3B", 17 | "EleutherAI/gpt-neo-2.7B", 18 | ] 19 | ALL_MODELS = BERT_MODELS + GPT2_MODELS + GPT_NEO_MODELS 20 | 21 | 22 | def initialize_model_and_tokenizer(model_name: str): 23 | if model_name in BERT_MODELS: 24 | tokenizer = BertTokenizer.from_pretrained(model_name) 25 | model = BertLMHeadModel.from_pretrained(model_name) 26 | elif model_name in GPT2_MODELS: 27 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 28 | model = GPT2LMHeadModel.from_pretrained(model_name) 29 | elif model_name in GPT_NEO_MODELS: 30 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 31 | model = GPTNeoForCausalLM.from_pretrained(model_name) 32 | else: 33 | raise ValueError("Model {model_name} not supported") 34 | 35 | model.eval() 36 | 37 | return model, tokenizer 38 | 39 | 40 | def model_type(model_name: str): 41 | if model_name in BERT_MODELS: 42 | return "bert" 43 | elif 'gpt2' in model_name: 44 | return "gpt2" 45 | elif model_name in GPT_NEO_MODELS: 46 | return "gpt_neo" 47 | elif 'gpt-j' in model_name or 'gptj' in model_name: 48 | return 'gptj' 49 | elif 't5' in model_name: 50 | return 't5' 51 | elif 'llama' in model_name: 52 | return 'llama' 53 | elif 'baichuan' in model_name.lower(): 54 | return 'baichuan' 55 | elif 'chatglm2' in model_name.lower(): 56 | return 'chatglm2' 57 | elif 'internlm' in model_name.lower(): 58 | return 'internlm' 59 | elif 'qwen' in model_name.lower(): 60 | return 'qwen' 61 | elif 'mistral' in model_name.lower(): 62 | return 'mistral' 63 | else: 64 | raise ValueError("Model {model_name} not supported") 65 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/knowledge_neurons/knowledge_neurons/data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import os 4 | import urllib.request 5 | from pathlib import Path 6 | 7 | from tqdm import tqdm 8 | 9 | PARAREL_RELATION_NAMES = [ 10 | "P39", 11 | "P264", 12 | "P37", 13 | "P108", 14 | "P131", 15 | "P103", 16 | "P176", 17 | "P30", 18 | "P178", 19 | "P138", 20 | "P47", 21 | "P17", 22 | "P413", 23 | "P27", 24 | "P463", 25 | "P364", 26 | "P495", 27 | "P449", 28 | "P20", 29 | "P1376", 30 | "P1001", 31 | "P361", 32 | "P36", 33 | "P1303", 34 | "P530", 35 | "P19", 36 | "P190", 37 | "P740", 38 | "P136", 39 | "P127", 40 | "P1412", 41 | "P407", 42 | "P140", 43 | "P279", 44 | "P276", 45 | "P159", 46 | "P106", 47 | "P101", 48 | "P937", 49 | ] 50 | 51 | 52 | def pararel(data_path: str = "datasets/pararel.json"): 53 | parent_dir = Path(data_path).parent 54 | os.makedirs(parent_dir, exist_ok=True) 55 | if os.path.exists(data_path): 56 | with open(data_path, "r") as f: 57 | return json.load(f) 58 | else: 59 | PARAREL = collections.defaultdict(dict) 60 | # download relations from github 61 | for r in 
tqdm(PARAREL_RELATION_NAMES, "downloading pararel data"): 62 | with urllib.request.urlopen( 63 | f"https://raw.githubusercontent.com/yanaiela/pararel/main/data/pattern_data/graphs_json/{r}.jsonl" 64 | ) as url: 65 | graphs = [ 66 | json.loads(d.strip()) for d in url.read().decode().split("\n") if d 67 | ] 68 | PARAREL[r]["graphs"] = graphs 69 | with urllib.request.urlopen( 70 | f"https://raw.githubusercontent.com/yanaiela/pararel/main/data/trex_lms_vocab/{r}.jsonl" 71 | ) as url: 72 | vocab = [ 73 | json.loads(d.strip()) for d in url.read().decode().split("\n") if d 74 | ] 75 | PARAREL[r]["vocab"] = vocab 76 | with open(data_path, "w") as f: 77 | json.dump(PARAREL, f) 78 | return PARAREL 79 | 80 | 81 | def pararel_expanded( 82 | data_path: str = "datasets/pararel_expanded.json", obj_label_replacement=None 83 | ): 84 | parent_dir = Path(data_path).parent 85 | os.makedirs(parent_dir, exist_ok=True) 86 | if os.path.exists(data_path): 87 | with open(data_path, "r") as f: 88 | return json.load(f) 89 | else: 90 | PARAREL = pararel() 91 | PARAREL_EXPANDED = collections.defaultdict(dict) 92 | # expand relations into sentences, grouped by their uuid 93 | for key, value in tqdm( 94 | PARAREL.items(), "expanding pararel dataset into full sentences" 95 | ): 96 | for vocab in value["vocab"]: 97 | for graph in value["graphs"]: 98 | if not PARAREL_EXPANDED.get(vocab["uuid"]): 99 | PARAREL_EXPANDED[vocab["uuid"]] = { 100 | "sentences": [], 101 | "relation_name": key, 102 | "obj_label": vocab["obj_label"], 103 | } 104 | sentence = graph["pattern"] 105 | full_sentence = sentence.replace("[X]", vocab["sub_label"]).replace( 106 | "[Y]", "[MASK]" 107 | ) 108 | PARAREL_EXPANDED[vocab["uuid"]]["sentences"].append(full_sentence) 109 | with open(data_path, "w") as f: 110 | json.dump(PARAREL_EXPANDED, f) 111 | return PARAREL_EXPANDED 112 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/knowledge_neurons/plot_pararel_results.py: -------------------------------------------------------------------------------- 1 | # plot Figure 3 + 4 from the paper - 2 | # the decreasing ratio of the probability of the correct answer after suppressing knowledge neurons 3 | 4 | import argparse 5 | import json 6 | from glob import glob 7 | from pathlib import Path 8 | 9 | import pandas as pd 10 | import seaborn as sns 11 | 12 | 13 | def format_data(results_data, key="suppression"): 14 | formatted = {} 15 | for uuid, data in results_data.items(): 16 | if formatted.get(data["relation_name"]) is None: 17 | formatted[data["relation_name"]] = {"related": [], "unrelated": []} 18 | 19 | related_data = data[key]["related"] 20 | related_change = [] 21 | for prob in related_data["pct_change"]: 22 | related_change.append(prob) 23 | 24 | unrelated_data = data[key]["unrelated"] 25 | unrelated_change = [] 26 | for prob in unrelated_data["pct_change"]: 27 | unrelated_change.append(prob) 28 | 29 | if data["n_refined_neurons"] > 0 and data["n_unrelated_neurons"] > 0: 30 | # for some prompts we didn't get any neurons back, it would be unfair to include them 31 | if related_change: 32 | related_change = sum(related_change) / len(related_change) 33 | if unrelated_change: 34 | unrelated_change = sum(unrelated_change) / len(unrelated_change) 35 | else: 36 | unrelated_change = 0.0 37 | formatted[data["relation_name"]]["related"].append(related_change) 38 | formatted[data["relation_name"]]["unrelated"].append(unrelated_change) 39 | 40 | for relation_name, data in formatted.items(): 41 | if 
data["related"]: 42 | data["related"] = sum(data["related"]) / len(data["related"]) 43 | else: 44 | data["related"] = float("nan") 45 | if data["unrelated"]: 46 | data["unrelated"] = sum(data["unrelated"]) / len(data["unrelated"]) 47 | else: 48 | data["unrelated"] = float("nan") 49 | 50 | pandas_format = {"relation_name": [], "related": [], "pct_change": []} 51 | for relation_name, data in formatted.items(): 52 | verb = "Suppressing" if key == "suppression" else "Enhancing" 53 | pandas_format["relation_name"].append(relation_name) 54 | pandas_format["pct_change"].append(data["related"]) 55 | pandas_format["related"].append(f"{verb} knowledge neurons for related facts") 56 | 57 | pandas_format["relation_name"].append(relation_name) 58 | pandas_format["pct_change"].append(data["unrelated"]) 59 | pandas_format["related"].append(f"{verb} knowledge neurons for unrelated facts") 60 | return pd.DataFrame(pandas_format).dropna() 61 | 62 | 63 | def plot_data(pd_df, experiment_type, out_path="test.png"): 64 | sns.set_theme(style="whitegrid") 65 | if experiment_type == "suppression": 66 | title = "Suppressing knowledge neurons" 67 | elif experiment_type == "enhancement": 68 | title = "Enhancing knowledge neurons" 69 | else: 70 | raise ValueError 71 | # Draw a nested barplot by species and sex 72 | g = sns.catplot( 73 | data=pd_df, 74 | kind="bar", 75 | x="relation_name", 76 | y="pct_change", 77 | hue="related", 78 | ci="sd", 79 | palette="dark", 80 | alpha=0.6, 81 | height=6, 82 | aspect=4, 83 | ) 84 | g.despine(left=True) 85 | g.set_axis_labels("relation name", "Correct probability percentage change") 86 | g.legend.set_title(title) 87 | g.savefig(out_path) 88 | 89 | 90 | if __name__ == "__main__": 91 | # parse arguments 92 | parser = argparse.ArgumentParser("Arguments for pararel result plotting") 93 | parser.add_argument( 94 | "--results_dir", 95 | default="bert_base_uncased_neurons/", 96 | type=str, 97 | help="directory in which the results from pararel_evaluate.py are saved.", 98 | ) 99 | args = parser.parse_args() 100 | results_dir = Path(args.results_dir) 101 | 102 | # load results 103 | result_paths = results_dir.glob("*results_*.json") 104 | results = {} 105 | for p in result_paths: 106 | with open(p) as f: 107 | results.update(json.load(f)) 108 | 109 | # plot results of suppression experiment 110 | suppression_data = format_data(results, key="suppression") 111 | plot_data(suppression_data, "suppression", out_path="images/suppress.png") 112 | 113 | # plot results of enhancement experiment 114 | enhancement_data = format_data(results, key="enhancement") 115 | plot_data(enhancement_data, "enhancement", out_path="images/enhance.png") 116 | -------------------------------------------------------------------------------- /code/easyeditor/models/kn/knowledge_neurons/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | einops 3 | numpy 4 | torch 5 | seaborn -------------------------------------------------------------------------------- /code/easyeditor/models/kn/knowledge_neurons/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | name = "knowledge-neurons" 7 | setup( 8 | name=name, 9 | packages=find_packages(), 10 | version="0.0.2", 11 | license="MIT", 12 | description="A library for finding knowledge neurons in pretrained transformer models", 13 | 
long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url=f"https://github.com/EleutherAI/{name}", 16 | author="Sid Black", 17 | author_email="sdtblck@gmail.com", 18 | install_requires=["transformers", "einops", "numpy", "torch", "seaborn"], 19 | classifiers=[ 20 | "License :: OSI Approved :: MIT License", 21 | "Programming Language :: Python :: 3.6", 22 | "Programming Language :: Python :: 3.7", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /code/easyeditor/models/lora/__init__.py: -------------------------------------------------------------------------------- 1 | from .lora_main import LoRAHyperParams, apply_lora_to_model, execute_lora 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/lora/lora_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from ...util.hparams import HyperParams 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class LoRAHyperParams(HyperParams): 9 | # Method 10 | lora_type: str 11 | layers: List[int] 12 | num_steps: int 13 | lr: float 14 | weight_decay: float 15 | kl_factor: float 16 | norm_constraint: float 17 | target_modules: List[str] 18 | rank: int 19 | lora_alpha: float 20 | lora_dropout: float 21 | # Module templates 22 | 23 | device: int 24 | alg_name: str 25 | model_name: str 26 | 27 | # Defaults 28 | batch_size: int = 128 29 | max_length: int = 40 30 | model_parallel: bool = False 31 | 32 | gpt_eval_endpoint_default: bool = True 33 | gpt_eval_name_default: bool = True 34 | 35 | @classmethod 36 | def from_hparams(cls, hparams_name_or_path: str): 37 | if '.yaml' not in hparams_name_or_path: 38 | hparams_name_or_path = hparams_name_or_path + '.yaml' 39 | 40 | with open(hparams_name_or_path, "r") as stream: 41 | config = yaml.safe_load(stream) 42 | config = super().construct_float_from_scientific_notation(config) 43 | 44 | assert (config and config['alg_name'] == 'LoRA') or print( 45 | f'LoRAHyperParams can not load from {hparams_name_or_path}, ' 46 | f'alg_name is {config["alg_name"]} ') 47 | return cls(**config) 48 | -------------------------------------------------------------------------------- /code/easyeditor/models/malmen/__init__.py: -------------------------------------------------------------------------------- 1 | from .malmen_hparams import MALMENHyperParams 2 | from .malmen_main import MalmenRewriteExecutor 3 | -------------------------------------------------------------------------------- /code/easyeditor/models/malmen/malmen_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class MALMENHyperParams(HyperParams): 9 | alg_name: str 10 | 11 | # Model 12 | model_name: str 13 | model_class: str 14 | tokenizer_class: str 15 | tokenizer_name: str 16 | inner_params: List[str] 17 | device: int 18 | archive: Any 19 | 20 | # Method 21 | alg: str 22 | debug: bool 23 | dropout: float 24 | train_base: bool 25 | no_grad_layers: Any 26 | rank: int 27 | n_edits: int 28 | n_blocks: int 29 | lr: float 30 | meta_lr: float 31 | loc_coef: float 32 | max_grad_norm: float 
33 | token: str 34 | 35 | # Output 36 | results_dir: str 37 | 38 | # Train 39 | batch_size: int 40 | editor_batch_size: int 41 | silent: bool 42 | log_interval: int 43 | eval_log_interval:int 44 | final_eval:bool 45 | val_interval: int 46 | early_stop_patience: int 47 | early_stop_key: str 48 | eval_only: bool 49 | save: bool 50 | 51 | val_batch_size: Optional[int] 52 | val_steps: int 53 | 54 | max_length: int = 40 55 | 56 | model_save_pt: Optional[int]=5000 57 | half: Optional[bool] = False 58 | model_parallel: bool = False 59 | max_epochs: Optional[int] = None 60 | max_iters: Optional[int] = None 61 | 62 | @classmethod 63 | def from_hparams(cls, hparams_name_or_path: str): 64 | 65 | if '.yaml' not in hparams_name_or_path: 66 | hparams_name_or_path = hparams_name_or_path + '.yaml' 67 | 68 | with open(hparams_name_or_path, "r") as stream: 69 | config = yaml.safe_load(stream) 70 | config = super().construct_float_from_scientific_notation(config) 71 | 72 | assert (config and config['alg'] == 'MALMEN') or print(f'MALMENTrainingHyperParams can not load from {hparams_name_or_path}, ' 73 | f'alg_name is {config["alg"]} ') 74 | config['val_batch_size'] = config['batch_size'] 75 | return cls(**config) 76 | 77 | -------------------------------------------------------------------------------- /code/easyeditor/models/malmen/malmen_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from typing import Dict, List, Any, Tuple 4 | 5 | import hydra 6 | import torch 7 | from collections import deque 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | from ...util.globals import * 11 | 12 | from ...trainer import MALMEN 13 | from .malmen_hparams import MALMENHyperParams 14 | 15 | class MalmenRewriteExecutor: 16 | def __init__(self): 17 | self.is_init = False 18 | 19 | def init_model(self, model, tok, params: MALMENHyperParams): 20 | 21 | assert params.archive is not None or print(f'Training weights Needed....') 22 | # Customize the gpt2xl and tokenizer 23 | self.model = model 24 | self.tokenizer = tok 25 | # add_padding(self.tokenizer, self.model) 26 | 27 | # Load the trained MEND model 28 | self.alg = MALMEN(self.model, params, lambda: deepcopy(self.model)) 29 | d = torch.load(params.archive, map_location=f'cuda:{params.device}') 30 | self.alg.load_state_dict(d["model"]) 31 | if params.model_parallel: 32 | self.alg.net.to(deque(self.alg.model.parameters(), maxlen=1)[0].device) 33 | else: 34 | self.alg.to(torch.device(f'cuda:{params.device}')) 35 | 36 | 37 | def reset_model(self): 38 | self.is_init = False 39 | del self.model, self.tokenizer, self.alg 40 | 41 | def apply_to_model( 42 | self, 43 | model: AutoModelForCausalLM, 44 | tok: AutoTokenizer, 45 | requests: List[Dict], 46 | hparams: MALMENHyperParams, 47 | copy=False, 48 | return_orig_weights=False, 49 | keep_original_weight=False, 50 | **kwargs 51 | ): 52 | """ 53 | Given a request, for example 54 | {'prompt': '{} has the position of', 55 | 'subject': 'Charles Herman Helmsing', 56 | 'relation_id': 'P39', 57 | 'target_new': {'str': 'President', 'id': 'Q11696'}, 58 | 'target_true': {'str': 'bishop', 'id': 'Q29182'}} 59 | Returns a dictionary of numpy arrays that specifies 60 | how mend will change the weights of the model. 
61 | """ 62 | 63 | if not self.is_init: 64 | self.init_model(model, tok, hparams) 65 | 66 | weights_copy = {} 67 | model = deepcopy(self.model) if copy else self.model 68 | assert len(requests) >= hparams.n_edits, "The number of requests must be greater than or equal to the value of n_edits." 69 | # Define i/o 70 | requests = requests[:hparams.n_edits] 71 | batchs = [] 72 | for i in range(hparams.n_edits // hparams.batch_size): 73 | batch = requests[i * hparams.batch_size : (i+1)*hparams.batch_size] 74 | targets = [ 75 | (" " if request["target_new"][0] != " " else "") 76 | + request["target_new"] 77 | for request in batch 78 | ] 79 | sentences = [ 80 | request["prompt"] + targets[i] 81 | for i, request in enumerate(batch) 82 | ] 83 | 84 | # Tokenize 85 | sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to( 86 | f"cuda:{hparams.device}" 87 | ) 88 | target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to( 89 | f"cuda:{hparams.device}" 90 | ) 91 | 92 | # Define labels 93 | label_tok = deepcopy(sent_tok["input_ids"]) 94 | for i in range(label_tok.size(0)): 95 | target_len = target_tok["attention_mask"][i].sum() 96 | padding_len = ( 97 | sent_tok["input_ids"].size(1) - sent_tok["attention_mask"][i].sum() 98 | ) 99 | label_tok[i][: -target_len - padding_len] = -100 100 | label_tok[i][label_tok[i] == self.tokenizer.pad_token_id] = -100 101 | 102 | edit_inner = dict( 103 | input_ids=sent_tok["input_ids"], 104 | attention_mask=sent_tok["attention_mask"], 105 | labels=target_tok['input_ids'], 106 | ) 107 | 108 | batchs.append(edit_inner) 109 | # Run M 110 | module_kv_map = self.alg.cache(batchs) 111 | param_shifts = self.alg.predict_param_shifts(module_kv_map) 112 | with torch.no_grad(): 113 | for n, p in self.model.named_parameters(): 114 | if n in hparams.inner_params: 115 | if return_orig_weights and n not in weights_copy: 116 | weights_copy[n] = p.detach().clone() 117 | self.alg.edit_model(param_shifts, False) 118 | 119 | 120 | if not keep_original_weight: 121 | weights_copy = {} 122 | 123 | return self.alg.model, weights_copy 124 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/__init__.py: -------------------------------------------------------------------------------- 1 | from .melo_main import MELOHyperParams,apply_melo_to_model 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/melo_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from ...util.hparams import HyperParams 4 | import yaml 5 | 6 | @dataclass 7 | class GRACEHyperParams(HyperParams): 8 | name: str 9 | num_iter: int 10 | init_radius: float 11 | dist_fn: str # euc, mmd, cos 12 | val_init: str # cold, warm 13 | val_train: str # sgd, pert 14 | val_reg: bool # early 15 | reg: str # early_stop 16 | replacement: str # replace_last, replace_all, replace_prompt 17 | expand_mode: str # , moving_avg, decay 18 | num_pert: int # only matters when using perturbation training 19 | key_id: int 20 | num_edit_per_block: int 21 | num_block: int 22 | num_rank_per_block: int 23 | metric_period: int 24 | edit_lr: float 25 | 26 | @dataclass 27 | class MODELHyperParams(HyperParams): 28 | name: str 29 | class_name: str 30 | tokenizer_class: str 31 | tokenizer_name: str 32 | fan_in_fan_out: bool 33 | target_modules: list[str] 34 | pt: str # set this to 'hallucination' inside 
your checkpoint directory 35 | grace_layer: str 36 | @dataclass 37 | class LoRAHyperParams(HyperParams): 38 | cls_name: str 39 | cls_class: str 40 | supervised: bool 41 | cos: bool 42 | freeze: str 43 | square: bool 44 | bound_embeds: bool 45 | use_all_negatives: bool 46 | freeze_lora: bool 47 | dist_heads: int 48 | cross_attend: bool 49 | soft_weighting: bool 50 | checkpoint_grad: bool 51 | lora_r: int 52 | lora_alpha: int 53 | lora_dropout: float 54 | 55 | @dataclass 56 | class MELOHyperParams(HyperParams): 57 | model_name: str 58 | alg_name: str 59 | model_parallel: bool 60 | device: int 61 | max_length: int 62 | task: str 63 | lora_task_type: str 64 | check_dir: str 65 | grace: GRACEHyperParams 66 | model: MODELHyperParams 67 | lora: LoRAHyperParams 68 | 69 | @classmethod 70 | def from_hparams(cls, hparams_name_or_path: str): 71 | if '.yaml' not in hparams_name_or_path: 72 | hparams_name_or_path = hparams_name_or_path + '.yaml' 73 | 74 | with open(hparams_name_or_path, "r") as stream: 75 | config = yaml.safe_load(stream) 76 | config = super().construct_float_from_scientific_notation(config) 77 | 78 | assert (config and config['alg_name'] == 'MELO') or print( 79 | f'GraceHyperParams can not load from {hparams_name_or_path}, ' 80 | f'alg_name is {config["alg_name"]} ') 81 | 82 | grace_config = GRACEHyperParams(**config['grace']) 83 | config['grace'] = grace_config 84 | model_config = MODELHyperParams(**config['model']) 85 | config['model'] = model_config 86 | lora_config = LoRAHyperParams(**config['lora']) 87 | config['lora'] = lora_config 88 | return cls(**config) 89 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/melo_main.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | import torch 3 | from copy import deepcopy 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from .melo_hparams import MELOHyperParams 6 | from .util import get_tokenizer 7 | from .melo import LORA 8 | from ...util import nethook 9 | 10 | 11 | def apply_melo_to_model( 12 | model: AutoModelForCausalLM, 13 | tok: AutoTokenizer, 14 | requests: List[Dict], 15 | hparams: MELOHyperParams, 16 | copy=False, 17 | return_orig_weights=False, 18 | keep_original_weight=False, 19 | **kwargs: Any, 20 | ) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]: 21 | # only support single edit.we will support sequence edit soon 22 | if keep_original_weight: 23 | model=deepcopy(model) 24 | weights_copy = {} 25 | device = torch.device(f'cuda:{hparams.device}') 26 | tokenizer = get_tokenizer(hparams) 27 | if not isinstance(model, LORA): 28 | editor = LORA(model, hparams,tokenizer) 29 | else: 30 | editor = model 31 | tokens = tokenizer(requests[0], tok,device) 32 | editor.to(device) 33 | editor.edit(tokens) 34 | return editor,weights_copy 35 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style test docs 2 | 3 | check_dirs := src tests examples docs 4 | 5 | # Check that source code meets quality standards 6 | 7 | # this target runs checks on all files 8 | quality: 9 | black --check $(check_dirs) 10 | ruff $(check_dirs) 11 | doc-builder style src/peft tests docs/source --max_len 119 --check_only 12 | 13 | # Format source code automatically and check is there are any problems left that need manual fixing 
14 | style: 15 | black $(check_dirs) 16 | ruff $(check_dirs) --fix 17 | doc-builder style src/peft tests docs/source --max_len 119 18 | 19 | test: 20 | python -m pytest -n 3 tests/ $(if $(IS_GITHUB_CI),--report-log "ci_tests.log",) 21 | 22 | tests_examples_multi_gpu: 23 | python -m pytest -m multi_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",) 24 | 25 | tests_examples_single_gpu: 26 | python -m pytest -m single_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",) 27 | 28 | tests_core_multi_gpu: 29 | python -m pytest -m multi_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",) 30 | 31 | tests_core_single_gpu: 32 | python -m pytest -m single_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",) 33 | 34 | tests_common_gpu: 35 | python -m pytest tests/test_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",) 36 | python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",) 37 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/__init__.py: -------------------------------------------------------------------------------- 1 | from .src.peft import * 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docker/peft-cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds GPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.8 8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget software-properties-common git-lfs && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Install audio-related libraries 15 | RUN apt-get update && \ 16 | apt install -y ffmpeg 17 | 18 | RUN apt install -y libsndfile1-dev 19 | RUN git lfs install 20 | 21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 22 | RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip 23 | RUN python3 -m pip install --no-cache-dir --upgrade pip 24 | 25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 26 | # We don't install pytorch here yet since CUDA isn't available 27 | # instead we use the direct torch wheel 28 | ENV PATH /opt/conda/envs/peft/bin:$PATH 29 | # Activate our bash shell 30 | RUN chsh -s /bin/bash 31 | SHELL ["/bin/bash", "-c"] 32 | # Activate the conda env and install transformers + accelerate from source 33 | RUN source activate peft && \ 34 | python3 -m pip install --no-cache-dir \ 35 | librosa \ 36 | "soundfile>=0.12.1" \ 37 | scipy \ 38 | git+https://github.com/huggingface/transformers \ 39 | git+https://github.com/huggingface/accelerate \ 40 | peft[test]@git+https://github.com/huggingface/peft 41 | 42 | # Install apt libs 43 | RUN apt-get update && \ 44 | apt-get install -y curl git wget && \ 45 | apt-get clean && \ 46 | rm -rf /var/lib/apt/lists* 47 | 48 | RUN echo "source activate peft" >> ~/.profile 49 
| 50 | # Activate the virtualenv 51 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docker/peft-gpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds GPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.8 8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget software-properties-common git-lfs && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Install audio-related libraries 15 | RUN apt-get update && \ 16 | apt install -y ffmpeg 17 | 18 | RUN apt install -y libsndfile1-dev 19 | RUN git lfs install 20 | 21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 22 | RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip 23 | RUN python3 -m pip install --no-cache-dir --upgrade pip 24 | 25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 26 | # We don't install pytorch here yet since CUDA isn't available 27 | # instead we use the direct torch wheel 28 | ENV PATH /opt/conda/envs/peft/bin:$PATH 29 | # Activate our bash shell 30 | RUN chsh -s /bin/bash 31 | SHELL ["/bin/bash", "-c"] 32 | # Activate the conda env and install transformers + accelerate from source 33 | RUN source activate peft && \ 34 | python3 -m pip install --no-cache-dir \ 35 | librosa \ 36 | "soundfile>=0.12.1" \ 37 | scipy \ 38 | git+https://github.com/huggingface/transformers \ 39 | git+https://github.com/huggingface/accelerate \ 40 | peft[test]@git+https://github.com/huggingface/peft 41 | 42 | RUN python3 -m pip install --no-cache-dir bitsandbytes 43 | 44 | # Stage 2 45 | FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 AS build-image 46 | COPY --from=compile-image /opt/conda /opt/conda 47 | ENV PATH /opt/conda/bin:$PATH 48 | 49 | # Install apt libs 50 | RUN apt-get update && \ 51 | apt-get install -y curl git wget && \ 52 | apt-get clean && \ 53 | rm -rf /var/lib/apt/lists* 54 | 55 | RUN echo "source activate peft" >> ~/.profile 56 | 57 | # Activate the virtualenv 58 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/_config.py: -------------------------------------------------------------------------------- 1 | # docstyle-ignore 2 | INSTALL_CONTENT = """ 3 | # PEFT installation 4 | ! pip install peft accelerate transformers 5 | # To install from source instead of the last release, comment the command above and uncomment the following one. 6 | # ! pip install git+https://github.com/huggingface/peft.git 7 | """ 8 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - title: Get started 2 | sections: 3 | - local: index 4 | title: 🤗 PEFT 5 | - local: quicktour 6 | title: Quicktour 7 | - local: install 8 | title: Installation 9 | 10 | - title: Task guides 11 | sections: 12 | - local: task_guides/image_classification_lora 13 | title: Image classification using LoRA 14 | - local: task_guides/seq2seq-prefix-tuning 15 | title: Prefix tuning for conditional generation 16 | - local: task_guides/clm-prompt-tuning 17 | title: Prompt tuning for causal language modeling 18 | - local: task_guides/semantic_segmentation_lora 19 | title: Semantic segmentation using LoRA 20 | - local: task_guides/ptuning-seq-classification 21 | title: P-tuning for sequence classification 22 | - local: task_guides/dreambooth_lora 23 | title: Dreambooth fine-tuning with LoRA 24 | - local: task_guides/token-classification-lora 25 | title: LoRA for token classification 26 | - local: task_guides/int8-asr 27 | title: int8 training for automatic speech recognition 28 | 29 | - title: 🤗 Accelerate integrations 30 | sections: 31 | - local: accelerate/deepspeed-zero3-offload 32 | title: DeepSpeed 33 | - local: accelerate/fsdp 34 | title: Fully Sharded Data Parallel 35 | 36 | - title: Conceptual guides 37 | sections: 38 | - local: conceptual_guides/lora 39 | title: LoRA 40 | - local: conceptual_guides/prompting 41 | title: Prompting 42 | 43 | - title: Reference 44 | sections: 45 | - local: package_reference/peft_model 46 | title: PEFT model 47 | - local: package_reference/config 48 | title: Configuration 49 | - local: package_reference/tuners 50 | title: Tuners -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/conceptual_guides/lora.mdx: -------------------------------------------------------------------------------- 1 | 12 | 13 | # LoRA 14 | 15 | This conceptual guide gives a brief overview of [LoRA](https://arxiv.org/abs/2106.09685), a technique that accelerates 16 | the fine-tuning of large models while consuming less memory. 17 | 18 | To make fine-tuning more efficient, LoRA's approach is to represent the weight updates with two smaller 19 | matrices (called **update matrices**) through low-rank decomposition. These new matrices can be trained to adapt to the 20 | new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive 21 | any further adjustments. To produce the final results, both the original and the adapted weights are combined. 22 | 23 | This approach has a number of advantages: 24 | 25 | * LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters. 
26 | * The original pre-trained weights are kept frozen, which means you can have multiple lightweight and portable LoRA models for various downstream tasks built on top of them. 27 | * LoRA is orthogonal to many other parameter-efficient methods and can be combined with many of them. 28 | * Performance of models fine-tuned using LoRA is comparable to the performance of fully fine-tuned models. 29 | * LoRA does not add any inference latency because adapter weights can be merged with the base model. 30 | 31 | In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable 32 | parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to 33 | attention blocks only. The resulting number of trainable parameters in a LoRA model depends on the size of the low-rank 34 | update matrices, which is determined mainly by the rank `r` and the shape of the original weight matrix. 35 | 36 | ## Common LoRA parameters in PEFT 37 | 38 | As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to: 39 | 40 | 1. Instantiate a base model. 41 | 2. Create a configuration (`LoraConfig`) where you define LoRA-specific parameters. 42 | 3. Wrap the base model with `get_peft_model()` to get a trainable `PeftModel`. 43 | 4. Train the `PeftModel` as you normally would train the base model. 44 | 45 | `LoraConfig` allows you to control how LoRA is applied to the base model through the following parameters: 46 | 47 | - `r`: the rank of the update matrices, expressed in `int`. Lower rank results in smaller update matrices with fewer trainable parameters. 48 | - `target_modules`: The modules (for example, attention blocks) to apply the LoRA update matrices. 49 | - `alpha`: LoRA scaling factor. 50 | - `bias`: Specifies if the `bias` parameters should be trained. Can be `'none'`, `'all'` or `'lora_only'`. 51 | - `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task. 52 | - `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed. 53 | - `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default `PeftModel` will look at common layer pattern (`layers`, `h`, `blocks`, etc.), use it for exotic and custom models. 54 | 55 | ## LoRA examples 56 | 57 | For an example of LoRA method application to various downstream tasks, please refer to the following guides: 58 | 59 | * [Image classification using LoRA](../task_guides/image_classification_lora) 60 | * [Semantic segmentation](../task_guides/semantic_segmentation_lora) 61 | 62 | While the original paper focuses on language models, the technique can be applied to any dense layers in deep learning 63 | models. As such, you can leverage this technique with diffusion models. See [Dreambooth fine-tuning with LoRA](../task_guides/task_guides/dreambooth_lora) task guide for an example. 64 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/install.mdx: -------------------------------------------------------------------------------- 1 | 12 | 13 | # Installation 14 | 15 | Before you start, you will need to setup your environment, install the appropriate packages, and configure 🤗 PEFT. 
🤗 PEFT is tested on **Python 3.7+**. 16 | 17 | 🤗 PEFT is available on pypi, as well as GitHub: 18 | 19 | ## pip 20 | 21 | To install 🤗 PEFT from pypi: 22 | 23 | ```bash 24 | pip install peft 25 | ``` 26 | 27 | ## Source 28 | 29 | New features that haven't been released yet are added every day, which also means there may be some bugs. To try them out, install from the GitHub repository: 30 | 31 | ```bash 32 | pip install git+https://github.com/huggingface/peft 33 | ``` 34 | 35 | If you're working on contributing to the library or wish to play with the source code and see live 36 | results as you run the code, an editable version can be installed from a locally-cloned version of the 37 | repository: 38 | 39 | ```bash 40 | git clone https://github.com/huggingface/peft 41 | cd peft 42 | pip install -e . 43 | ``` 44 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/package_reference/config.mdx: -------------------------------------------------------------------------------- 1 | # Configuration 2 | 3 | The configuration classes stores the configuration of a [`PeftModel`], PEFT adapter models, and the configurations of [`PrefixTuning`], [`PromptTuning`], and [`PromptEncoder`]. They contain methods for saving and loading model configurations from the Hub, specifying the PEFT method to use, type of task to perform, and model configurations like number of layers and number of attention heads. 4 | 5 | ## PeftConfigMixin 6 | 7 | [[autodoc]] utils.config.PeftConfigMixin 8 | - all 9 | 10 | ## PeftConfig 11 | 12 | [[autodoc]] PeftConfig 13 | - all 14 | 15 | ## PromptLearningConfig 16 | 17 | [[autodoc]] PromptLearningConfig 18 | - all 19 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/package_reference/peft_model.mdx: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | [`PeftModel`] is the base model class for specifying the base Transformer model and configuration to apply a PEFT method to. The base `PeftModel` contains methods for loading and saving models from the Hub, and supports the [`PromptEncoder`] for prompt learning. 4 | 5 | ## PeftModel 6 | 7 | [[autodoc]] PeftModel 8 | - all 9 | 10 | ## PeftModelForSequenceClassification 11 | 12 | A `PeftModel` for sequence classification tasks. 13 | 14 | [[autodoc]] PeftModelForSequenceClassification 15 | - all 16 | 17 | ## PeftModelForTokenClassification 18 | 19 | A `PeftModel` for token classification tasks. 20 | 21 | [[autodoc]] PeftModelForTokenClassification 22 | - all 23 | 24 | ## PeftModelForCausalLM 25 | 26 | A `PeftModel` for causal language modeling. 27 | 28 | [[autodoc]] PeftModelForCausalLM 29 | - all 30 | 31 | ## PeftModelForSeq2SeqLM 32 | 33 | A `PeftModel` for sequence-to-sequence language modeling. 34 | 35 | [[autodoc]] PeftModelForSeq2SeqLM 36 | - all 37 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/docs/source/package_reference/tuners.mdx: -------------------------------------------------------------------------------- 1 | # Tuners 2 | 3 | Each tuner (or PEFT method) has a configuration and model. 4 | 5 | ## LoRA 6 | 7 | For finetuning a model with LoRA. 
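
Below is a minimal, editor-added sketch (not part of the upstream reference pages) of how `LoraConfig` and `get_peft_model` fit together; the base model name and `target_modules` are illustrative assumptions and depend on the architecture you load.

```python
# Illustrative only: "gpt2" and target_modules=["c_attn"] are assumptions, not defaults.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                        # rank of the LoRA update matrices
    lora_alpha=16,              # scaling factor
    lora_dropout=0.05,
    target_modules=["c_attn"],  # GPT-2's fused attention projection
)

model = get_peft_model(base_model, config)
model.print_trainable_parameters()  # only the LoRA parameters remain trainable
```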
8 | 9 | [[autodoc]] LoraConfig 10 | 11 | [[autodoc]] LoraModel 12 | 13 | [[autodoc]] tuners.lora.LoraLayer 14 | 15 | [[autodoc]] tuners.lora.Linear 16 | 17 | ## P-tuning 18 | 19 | [[autodoc]] tuners.p_tuning.PromptEncoderConfig 20 | 21 | [[autodoc]] tuners.p_tuning.PromptEncoder 22 | 23 | ## Prefix tuning 24 | 25 | [[autodoc]] tuners.prefix_tuning.PrefixTuningConfig 26 | 27 | [[autodoc]] tuners.prefix_tuning.PrefixEncoder 28 | 29 | ## Prompt tuning 30 | 31 | [[autodoc]] tuners.prompt_tuning.PromptTuningConfig 32 | 33 | [[autodoc]] tuners.prompt_tuning.PromptEmbedding -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/grammar.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # 3 | # lora_A = torch.randn((6,1)) 4 | # tensor_list = [torch.tensor(float(i)) for i in range(1,5)] 5 | # print(f'lora_A: {lora_A}') 6 | # print(f'tensor_list: {tensor_list}') 7 | # 8 | # lora_A.requires_grad = True 9 | # for x in tensor_list: 10 | # x.requires_grad = True 11 | # 12 | # c = [] 13 | # for x in tensor_list: 14 | # c.append(lora_A[1] * x) 15 | # 16 | # d = torch.stack(c,0) 17 | # print(f'stacked d: {d}') 18 | # 19 | # d.sum().backward() 20 | # 21 | # print(lora_A.grad) 22 | 23 | 24 | import torch 25 | import torch.nn as nn 26 | class Conv1D(nn.Module): 27 | """ 28 | 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). 29 | 30 | Basically works like a linear layer but the weights are transposed. 31 | 32 | Args: 33 | nf (`int`): The number of output features. 34 | nx (`int`): The number of input features. 35 | """ 36 | 37 | def __init__(self, nf, nx): 38 | super().__init__() 39 | self.nf = nf 40 | self.weight = nn.Parameter(torch.empty(nx, nf)) 41 | self.bias = nn.Parameter(torch.zeros(nf)) 42 | nn.init.normal_(self.weight, std=0.02) 43 | 44 | def forward(self, x): 45 | size_out = x.size()[:-1] + (self.nf,) 46 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 47 | x = x.view(size_out) 48 | return x 49 | a = Conv1D(500,300) 50 | 51 | print(a.weight.shape) -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py36'] 4 | 5 | [tool.ruff] 6 | ignore = ["C901", "E501", "E741", "W605"] 7 | select = ["C", "E", "F", "I", "W"] 8 | line-length = 119 9 | 10 | [tool.ruff.isort] 11 | lines-after-imports = 2 12 | known-first-party = ["peft"] 13 | 14 | [isort] 15 | default_section = "FIRSTPARTY" 16 | known_first_party = "peft" 17 | known_third_party = [ 18 | "numpy", 19 | "torch", 20 | "accelerate", 21 | "transformers", 22 | ] 23 | line_length = 119 24 | lines_after_imports = 2 25 | multi_line_output = 3 26 | include_trailing_comma = true 27 | force_grid_wrap = 0 28 | use_parentheses = true 29 | ensure_newline_before_comments = true 30 | 31 | [tool.pytest] 32 | doctest_optionflags = [ 33 | "NORMALIZE_WHITESPACE", 34 | "ELLIPSIS", 35 | "NUMBER", 36 | ] -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/scripts/log_reports.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | from pathlib import Path 3 | from datetime import date 4 | from tabulate import tabulate 5 | 6 | failed = [] 7 | passed = [] 8 | 9 | 
group_info = [] 10 | 11 | total_num_failed = 0 12 | empty_file = False or len(list(Path().glob("*.log"))) == 0 13 | for log in Path().glob("*.log"): 14 | section_num_failed = 0 15 | with open(log, "r") as f: 16 | nb_lines = sum(1 for _ in f) 17 | for i, line in f: 18 | line = json.loads(line) 19 | if line.get("nodeid", "") != "": 20 | test = line["nodeid"] 21 | if line.get("duration", None) is not None: 22 | duration = f'{line["duration"]:.4f}' 23 | if line.get("outcome", "") == "failed": 24 | section_num_failed += 1 25 | failed.append([test, duration, log.name.split('_')[0]]) 26 | total_num_failed += 1 27 | else: 28 | passed.append([test, duration, log.name.split('_')[0]]) 29 | if nb_lines == 0: 30 | empty_file = True 31 | group_info.append([str(log), section_num_failed, failed]) 32 | os.remove(log) 33 | failed = [] 34 | no_error_payload = { 35 | "type": "section", 36 | "text": { 37 | "type": "plain_text", 38 | "text": "🌞 There were no failures!" if not empty_file else "Something went wrong - please check GH action results.", 39 | "emoji": True 40 | } 41 | } 42 | 43 | message = "" 44 | payload = [ 45 | { 46 | "type": "header", 47 | "text": { 48 | "type": "plain_text", 49 | "text": "🤗 Results of the {} PEFT scheduled tests.".format(os.environ.get("TEST_TYPE", "")), 50 | } 51 | }, 52 | ] 53 | if total_num_failed > 0: 54 | for name, num_failed, failed_tests in group_info: 55 | if num_failed > 0: 56 | if num_failed == 1: 57 | message += f"*{name}: {num_failed} failed test*\n" 58 | else: 59 | message += f"*{name}: {num_failed} failed tests*\n" 60 | failed_table = [] 61 | for test in failed_tests: 62 | failed_table.append(test[0].split("::")) 63 | failed_table = tabulate(failed_table, headers=["Test Location", "Test Case", "Test Name"], showindex="always", tablefmt="grid", maxcolwidths=[12, 12, 12]) 64 | message += '\n```\n' +failed_table + '\n```' 65 | print(f'### {message}') 66 | else: 67 | payload.append(no_error_payload) 68 | 69 | 70 | if os.environ.get("TEST_TYPE", "") != "": 71 | from slack_sdk import WebClient 72 | 73 | if len(message) != 0: 74 | md_report = { 75 | "type": "section", 76 | "text": { 77 | "type": "mrkdwn", 78 | "text": message 79 | }, 80 | } 81 | payload.append(md_report) 82 | action_button = { 83 | "type": "section", 84 | "text": { 85 | "type": "mrkdwn", 86 | "text": "*For more details:*" 87 | }, 88 | "accessory": { 89 | "type": "button", 90 | "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, 91 | "url": f"https://github.com/huggingface/peft/actions/runs/{os.environ['GITHUB_RUN_ID']}", 92 | }, 93 | } 94 | payload.append(action_button) 95 | 96 | date_report = { 97 | "type": "context", 98 | "elements": [ 99 | { 100 | "type": "plain_text", 101 | "text": f"Nightly {os.environ.get('TEST_TYPE')} test results for {date.today()}", 102 | }, 103 | ], 104 | } 105 | payload.append(date_report) 106 | 107 | print(payload) 108 | 109 | client = WebClient(token=os.environ.get("SLACK_API_TOKEN")) 110 | client.chat_postMessage(channel="#peft-ci-daily", text=message, blocks=payload) 111 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/scripts/stale.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Script to close stale issue. Taken in part from the AllenNLP repository. 16 | https://github.com/allenai/allennlp. 17 | """ 18 | from datetime import datetime as dt 19 | import os 20 | 21 | from github import Github 22 | 23 | 24 | LABELS_TO_EXEMPT = [ 25 | "good first issue", 26 | "good second issue", 27 | "good difficult issue", 28 | "feature request", 29 | "new model", 30 | "wip", 31 | "PRs welcome to address this", 32 | ] 33 | 34 | 35 | def main(): 36 | g = Github(os.environ["GITHUB_TOKEN"]) 37 | repo = g.get_repo("huggingface/peft") 38 | open_issues = repo.get_issues(state="open") 39 | 40 | for issue in open_issues: 41 | comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) 42 | last_comment = comments[0] if len(comments) > 0 else None 43 | if ( 44 | last_comment is not None and last_comment.user.login == "github-actions[bot]" 45 | and (dt.utcnow() - issue.updated_at).days > 7 46 | and (dt.utcnow() - issue.created_at).days >= 30 47 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 48 | ): 49 | issue.edit(state="closed") 50 | elif ( 51 | (dt.utcnow() - issue.updated_at).days > 23 52 | and (dt.utcnow() - issue.created_at).days >= 30 53 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 54 | ): 55 | issue.create_comment( 56 | "This issue has been automatically marked as stale because it has not had " 57 | "recent activity. If you think this still needs to be addressed " 58 | "please comment on this thread.\n\n" 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from setuptools import find_packages, setup 16 | 17 | extras = {} 18 | extras["quality"] = ["black ~= 22.0", "ruff>=0.0.241", "urllib3<=2.0.0"] 19 | extras["docs_specific"] = ["hf-doc-builder"] 20 | extras["dev"] = extras["quality"] + extras["docs_specific"] 21 | extras["test"] = extras["dev"] + ["pytest", "pytest-xdist", "parameterized", "datasets"] 22 | 23 | setup( 24 | name="peft", 25 | version="0.4.0.dev0", 26 | description="Parameter-Efficient Fine-Tuning (PEFT)", 27 | license_files=["LICENSE"], 28 | long_description=open("README.md", "r", encoding="utf-8").read(), 29 | long_description_content_type="text/markdown", 30 | keywords="deep learning", 31 | license="Apache", 32 | author="The HuggingFace team", 33 | author_email="sourab@huggingface.co", 34 | url="https://github.com/huggingface/peft", 35 | package_dir={"": "src"}, 36 | packages=find_packages("src"), 37 | entry_points={}, 38 | python_requires=">=3.7.0", 39 | install_requires=[ 40 | "numpy>=1.17", 41 | "packaging>=20.0", 42 | "psutil", 43 | "pyyaml", 44 | "torch>=1.13.0", 45 | "transformers", 46 | "accelerate", 47 | "safetensors", 48 | ], 49 | extras_require=extras, 50 | classifiers=[ 51 | "Development Status :: 5 - Production/Stable", 52 | "Intended Audience :: Developers", 53 | "Intended Audience :: Education", 54 | "Intended Audience :: Science/Research", 55 | "License :: OSI Approved :: Apache Software License", 56 | "Operating System :: OS Independent", 57 | "Programming Language :: Python :: 3", 58 | "Programming Language :: Python :: 3.7", 59 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 60 | ], 61 | ) 62 | 63 | # Release checklist 64 | # 1. Change the version in __init__.py and setup.py. 65 | # 2. Commit these changes with the message: "Release: VERSION" 66 | # 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " 67 | # Push the tag to git: git push --tags origin main 68 | # 4. Run the following commands in the top-level directory: 69 | # python setup.py bdist_wheel 70 | # python setup.py sdist 71 | # 5. Upload the package to the pypi test server first: 72 | # twine upload dist/* -r pypitest 73 | # twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 74 | # 6. Check that you can install it in a virtualenv by running: 75 | # pip install -i https://testpypi.python.org/pypi peft 76 | # 7. Upload the final version to actual pypi: 77 | # twine upload dist/* -r pypi 78 | # 8. Add release notes to the tag in github once everything is looking hunky-dory. 79 | # 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master 80 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/src/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | __version__ = "0.4.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | PeftModelForQuestionAnswering, 30 | ) 31 | from .tuners import ( 32 | AdaptionPromptConfig, 33 | AdaptionPromptModel, 34 | LoraConfig, 35 | LoraModel, 36 | AdaLoraConfig, 37 | AdaLoraModel, 38 | PrefixEncoder, 39 | PrefixTuningConfig, 40 | PromptEmbedding, 41 | PromptEncoder, 42 | PromptEncoderConfig, 43 | PromptEncoderReparameterizationType, 44 | PromptTuningConfig, 45 | PromptTuningInit, 46 | MeloConfig, 47 | MeloModel 48 | ) 49 | from .utils import ( 50 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 51 | PeftConfig, 52 | PeftType, 53 | PromptLearningConfig, 54 | TaskType, 55 | bloom_model_postprocess_past_key_value, 56 | get_peft_model_state_dict, 57 | prepare_model_for_int8_training, 58 | prepare_model_for_kbit_training, 59 | set_peft_model_state_dict, 60 | shift_tokens_right, 61 | ) 62 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/src/peft/import_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import importlib 16 | 17 | 18 | def is_bnb_available(): 19 | return importlib.util.find_spec("bitsandbytes") is not None 20 | 21 | 22 | def is_bnb_4bit_available(): 23 | if not is_bnb_available(): 24 | return False 25 | 26 | import bitsandbytes as bnb 27 | 28 | return hasattr(bnb.nn, "Linear4bit") 29 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/src/peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel 21 | from .lora import LoraConfig, LoraModel 22 | from .adalora import AdaLoraConfig, AdaLoraModel 23 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 24 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 25 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 26 | from .melo import MeloConfig, MeloModel -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/src/peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The `torch.nn` model to encode the prefix. 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 55 | 56 | Example: 57 | 58 | ```py 59 | >>> from peft import PrefixEncoder, PrefixTuningConfig 60 | 61 | >>> config = PrefixTuningConfig( 62 | ... peft_type="PREFIX_TUNING", 63 | ... task_type="SEQ_2_SEQ_LM", 64 | ... num_virtual_tokens=20, 65 | ... token_dim=768, 66 | ... num_transformer_submodules=1, 67 | ... num_attention_heads=12, 68 | ... num_layers=12, 69 | ... encoder_hidden_size=768, 70 | ... 
) 71 | >>> prefix_encoder = PrefixEncoder(config) 72 | ``` 73 | 74 | **Attributes**: 75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder. 76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if 77 | `prefix_projection` is `True`. 78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 79 | 80 | Input shape: (`batch_size`, `num_virtual_tokens`) 81 | 82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) 83 | """ 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.prefix_projection = config.prefix_projection 88 | token_dim = config.token_dim 89 | num_layers = config.num_layers 90 | encoder_hidden_size = config.encoder_hidden_size 91 | num_virtual_tokens = config.num_virtual_tokens 92 | if self.prefix_projection and not config.inference_mode: 93 | # Use a two-layer MLP to encode the prefix 94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 95 | self.transform = torch.nn.Sequential( 96 | torch.nn.Linear(token_dim, encoder_hidden_size), 97 | torch.nn.Tanh(), 98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 99 | ) 100 | else: 101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 102 | 103 | def forward(self, prefix: torch.Tensor): 104 | if self.prefix_projection: 105 | prefix_tokens = self.embedding(prefix) 106 | past_key_values = self.transform(prefix_tokens) 107 | else: 108 | past_key_values = self.embedding(prefix) 109 | return past_key_values 110 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/src/peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 21 | from .other import ( 22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, 25 | COMMON_LAYERS_PATTERN, 26 | CONFIG_NAME, 27 | WEIGHTS_NAME, 28 | SAFETENSORS_WEIGHTS_NAME, 29 | _set_trainable, 30 | add_library_to_model_card, 31 | bloom_model_postprocess_past_key_value, 32 | prepare_model_for_int8_training, 33 | prepare_model_for_kbit_training, 34 | shift_tokens_right, 35 | transpose, 36 | _get_submodules, 37 | _set_adapter, 38 | _freeze_adapter, 39 | ModulesToSaveWrapper, 40 | ) 41 | from .hub_utils import hub_file_exists 42 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 43 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/src/peft/utils/hub_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from huggingface_hub import get_hf_file_metadata, hf_hub_url 17 | from huggingface_hub.utils import EntryNotFoundError 18 | 19 | 20 | def hub_file_exists(repo_id: str, filename: str, revision: str = None, repo_type: str = None) -> bool: 21 | r""" 22 | Checks if a file exists in a remote Hub repository. 23 | """ 24 | url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision) 25 | try: 26 | get_hf_file_metadata(url) 27 | return True 28 | except EntryNotFoundError: 29 | return False 30 | -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llm-editing/editing-attack/b6160d94659542c9e6406c65d2dac435dbe91afb/code/easyeditor/models/melo/peft_egg/tests/__init__.py -------------------------------------------------------------------------------- /code/easyeditor/models/melo/peft_egg/tests/testing_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import unittest 16 | 17 | import torch 18 | 19 | 20 | def require_torch_gpu(test_case): 21 | """ 22 | Decorator marking a test that requires a GPU. Will be skipped when no GPU is available. 23 | """ 24 | if not torch.cuda.is_available(): 25 | return unittest.skip("test requires GPU")(test_case) 26 | else: 27 | return test_case 28 | 29 | 30 | def require_torch_multi_gpu(test_case): 31 | """ 32 | Decorator marking a test that requires multiple GPUs. Will be skipped when less than 2 GPUs are available. 33 | """ 34 | if not torch.cuda.is_available() or torch.cuda.device_count() < 2: 35 | return unittest.skip("test requires multiple GPUs")(test_case) 36 | else: 37 | return test_case 38 | 39 | 40 | def require_bitsandbytes(test_case): 41 | """ 42 | Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library is not installed. 43 | """ 44 | try: 45 | import bitsandbytes # noqa: F401 46 | except ImportError: 47 | return unittest.skip("test requires bitsandbytes")(test_case) 48 | else: 49 | return test_case 50 | -------------------------------------------------------------------------------- /code/easyeditor/models/memit/__init__.py: -------------------------------------------------------------------------------- 1 | from .memit_main import MEMITHyperParams, apply_memit_to_model 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/memit/compute_ks.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from .compute_z import get_module_input_output_at_words 8 | from .memit_hparams import MEMITHyperParams 9 | 10 | 11 | def compute_ks( 12 | model: AutoModelForCausalLM, 13 | tok: AutoTokenizer, 14 | requests: Dict, 15 | hparams: MEMITHyperParams, 16 | layer: int, 17 | context_templates: List[str], 18 | ): 19 | layer_ks = get_module_input_output_at_words( 20 | model, 21 | tok, 22 | layer, 23 | context_templates=[ 24 | context.format(request["prompt"]) 25 | for request in requests 26 | for context_type in context_templates 27 | for context in context_type 28 | ], 29 | words=[ 30 | request["subject"] 31 | for request in requests 32 | for context_type in context_templates 33 | for _ in context_type 34 | ], 35 | module_template=hparams.rewrite_module_tmp, 36 | fact_token_strategy=hparams.fact_token, 37 | )[0] 38 | 39 | context_type_lens = [0] + [len(context_type) for context_type in context_templates] 40 | context_len = sum(context_type_lens) 41 | context_type_csum = np.cumsum(context_type_lens).tolist() 42 | 43 | ans = [] 44 | for i in range(0, layer_ks.size(0), context_len): 45 | tmp = [] 46 | for j in range(len(context_type_csum) - 1): 47 | start, end = context_type_csum[j], context_type_csum[j + 1] 48 | tmp.append(layer_ks[i + start : i + end].mean(0)) 49 | ans.append(torch.stack(tmp, 0).mean(0)) 50 | return torch.stack(ans, dim=0) 51 | -------------------------------------------------------------------------------- /code/easyeditor/models/memit/memit_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Literal 3 | 4 | from ...util.hparams import HyperParams 5 | import yaml 6 | 7 | 8 | @dataclass 9 | class MEMITHyperParams(HyperParams): 10 | # Method 11 | layers: List[int] 12 | layer_selection: Literal["all", "random"] 13 | 
fact_token: Literal[ 14 | "last", "subject_first", "subject_last", "subject_first_after_last" 15 | ] 16 | v_num_grad_steps: int 17 | v_lr: float 18 | v_loss_layer: int 19 | v_weight_decay: float 20 | clamp_norm_factor: float 21 | kl_factor: float 22 | mom2_adjustment: bool 23 | mom2_update_weight: float 24 | 25 | # Module templates 26 | rewrite_module_tmp: str 27 | layer_module_tmp: str 28 | mlp_module_tmp: str 29 | attn_module_tmp: str 30 | ln_f_module: str 31 | lm_head_module: str 32 | 33 | # Statistics 34 | mom2_dataset: str 35 | mom2_n_samples: int 36 | mom2_dtype: str 37 | alg_name: str 38 | device: int 39 | model_name: str 40 | stats_dir: str 41 | 42 | max_length: int = 40 43 | batch_size: int = 1 44 | model_parallel: bool = False 45 | 46 | gpt_eval_endpoint_default: bool = True 47 | gpt_eval_name_default: bool = True 48 | 49 | @classmethod 50 | def from_hparams(cls, hparams_name_or_path: str): 51 | 52 | if '.yaml' not in hparams_name_or_path: 53 | hparams_name_or_path = hparams_name_or_path + '.yaml' 54 | 55 | with open(hparams_name_or_path, "r") as stream: 56 | config = yaml.safe_load(stream) 57 | config = super().construct_float_from_scientific_notation(config) 58 | 59 | assert (config and config['alg_name'] == 'MEMIT') or print(f'MEMITHyperParams can not load from {hparams_name_or_path}, ' 60 | f'alg_name is {config["alg_name"]} ') 61 | return cls(**config) 62 | -------------------------------------------------------------------------------- /code/easyeditor/models/mend/__init__.py: -------------------------------------------------------------------------------- 1 | from .mend_hparams import MENDHyperParams 2 | from .mend_multimodal_hparams import MENDMultimodalHparams 3 | from .mend_main import MendRewriteExecutor, MendMultimodalRewriteExecutor, MendPerRewriteExecutor 4 | -------------------------------------------------------------------------------- /code/easyeditor/models/mend/mend_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from ...util.hparams import HyperParams 4 | from typing import Optional, Any, List 5 | import yaml 6 | 7 | 8 | @dataclass 9 | class MENDHyperParams(HyperParams): 10 | model_class: str 11 | tokenizer_class: str 12 | tokenizer_name: str 13 | inner_params: List[str] 14 | 15 | archive: Any 16 | 17 | # Method 18 | alg: str 19 | lr: float 20 | edit_lr: float 21 | lr_lr: float 22 | lr_scale: float 23 | seed: int 24 | debug: bool 25 | cedit: float 26 | cloc: float 27 | cbase: float 28 | dropout: float 29 | train_base: bool 30 | no_grad_layers: Any 31 | one_sided: bool 32 | n_hidden: int 33 | hidden_dim: Any 34 | init: str 35 | norm: bool 36 | combine: bool 37 | x_only: bool 38 | delta_only: bool 39 | act: str 40 | rank: int 41 | mlp_class: str 42 | shared: bool 43 | 44 | # Output 45 | results_dir: str 46 | 47 | # Train 48 | device: int 49 | model_save_pt: int 50 | silent: bool 51 | log_interval: int 52 | eval_log_interval:int 53 | final_eval:bool 54 | val_interval: int 55 | early_stop_patience: int 56 | early_stop_key: str 57 | eval_only: bool 58 | half: bool 59 | save: bool 60 | verbose: bool 61 | 62 | val_batch_size: int 63 | accumulate_bs: int 64 | val_steps: int 65 | opt: str 66 | grad_clip: float 67 | 68 | alg_name: str 69 | model_name: str 70 | device: int 71 | 72 | batch_size: int = 1 73 | max_length: int = 40 74 | max_epochs: Optional[int] = None 75 | max_iters: Optional[int] = None 76 | 77 | model_parallel: bool = False 78 | 79 | @classmethod 80 | def 
from_hparams(cls, hparams_name_or_path: str): 81 | 82 | if '.yaml' not in hparams_name_or_path: 83 | hparams_name_or_path = hparams_name_or_path + '.yaml' 84 | 85 | with open(hparams_name_or_path, "r") as stream: 86 | config = yaml.safe_load(stream) 87 | config = super().construct_float_from_scientific_notation(config) 88 | 89 | assert (config and config['alg'] == 'MEND') or print(f'MENDHyperParams can not load from {hparams_name_or_path}, ' 90 | f'alg_name is {config["alg"]} ') 91 | return cls(**config) 92 | -------------------------------------------------------------------------------- /code/easyeditor/models/mend/mend_multimodal_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from .mend_hparams import MENDHyperParams 4 | from typing import Optional, Any, List 5 | import yaml 6 | 7 | 8 | @dataclass 9 | class MENDMultimodalHparams(HyperParams): 10 | 11 | # Multimodal 12 | qformer_name_or_path: str 13 | state_dict_file: str 14 | 15 | # Image_dir 16 | coco_image: str 17 | rephrase_image: str 18 | 19 | # Model 20 | name: str 21 | model_name: str 22 | model_class: str 23 | tokenizer_class: str 24 | tokenizer_name: str 25 | inner_params: List[str] 26 | 27 | archive: Any 28 | 29 | # Method 30 | alg: str 31 | lr: float 32 | edit_lr: float 33 | lr_lr: float 34 | lr_scale: float 35 | seed: int 36 | debug: bool 37 | cedit: float 38 | iedit: float 39 | cloc: float 40 | cbase: float 41 | dropout: float 42 | train_base: bool 43 | no_grad_layers: Any 44 | one_sided: bool 45 | n_hidden: int 46 | hidden_dim: Any 47 | init: str 48 | norm: bool 49 | combine: bool 50 | x_only: bool 51 | delta_only: bool 52 | act: str 53 | rank: int 54 | mlp_class: str 55 | shared: bool 56 | 57 | # Output 58 | 59 | results_dir: str 60 | 61 | # Train 62 | device: str 63 | model_save_pt: int 64 | silent: bool 65 | log_interval: int 66 | eval_log_interval:int 67 | final_eval:bool 68 | val_interval: int 69 | early_stop_patience: int 70 | early_stop_key: str 71 | eval_only: bool 72 | half: bool 73 | save: bool 74 | verbose: bool 75 | 76 | val_batch_size: int 77 | accumulate_bs: int 78 | val_steps: int 79 | opt: str 80 | grad_clip: float 81 | 82 | alg_name: str 83 | 84 | exact_match: bool = False 85 | batch_size: int = 1 86 | max_length: int = 30 87 | max_epochs: Optional[int] = None 88 | max_iters: Optional[int] = None 89 | model_parallel: bool = False 90 | qformer_checkpoint: Optional[str] = None 91 | freeze_qformer: bool = True 92 | pretrained_ckpt: Optional[str] = None 93 | 94 | @classmethod 95 | def from_hparams(cls, hparams_name_or_path: str): 96 | 97 | if '.yaml' not in hparams_name_or_path: 98 | hparams_name_or_path = hparams_name_or_path + '.yaml' 99 | 100 | with open(hparams_name_or_path, "r") as stream: 101 | config = yaml.safe_load(stream) 102 | config = super().construct_float_from_scientific_notation(config) 103 | 104 | assert (config and config['alg'] == 'MEND') or print(f'MENDMultimodalHyperParams can not load from {hparams_name_or_path}, ' 105 | f'alg_name is {config["alg"]} ') 106 | return cls(**config) 107 | 108 | -------------------------------------------------------------------------------- /code/easyeditor/models/mend/oracle.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from higher.patch import monkeypatch as make_functional 6 | from losses import kl_loc_loss, 
masked_log_probs 7 | 8 | 9 | def test_rank1(model, dataset, config): 10 | model.eval() 11 | generator = dataset.edit_generator(21) 12 | 13 | history = [] 14 | for example in generator: 15 | edit_model = make_functional(model, track_higher_grads=False) 16 | residuals = {} 17 | opt_list = [] 18 | print(config.model.inner_params) 19 | for n, p in edit_model.named_parameters(): 20 | if n in config.model.inner_params: 21 | std = 0.01 22 | u = nn.Parameter(torch.randn(p.shape[0], 1, device=p.device) * std) 23 | v = nn.Parameter(torch.randn(1, p.shape[1], device=p.device) * std) 24 | assert ( 25 | u @ v 26 | ).shape == p.shape, f"got {(u@v).shape}, expected {p.shape}" 27 | 28 | residuals[n] = (u, v) 29 | opt_list.extend([u, v]) 30 | 31 | res_opt = torch.optim.SGD(opt_list, lr=100) 32 | 33 | acc = 0 34 | it = 0 35 | ids_train = example["loc_ids"][:10] 36 | ids_val = example["loc_ids"][10:] 37 | with torch.inference_mode(): 38 | original_logits_train = model(ids_train) 39 | original_logits_val = model(ids_val) 40 | if hasattr(original_logits_train, "logits"): 41 | original_logits_train = original_logits_train.logits 42 | original_logits_val = original_logits_val.logits 43 | 44 | while acc < 1 and it < 1000: 45 | fast_params = [] 46 | for n, p in edit_model.named_parameters(): 47 | if n in residuals: 48 | u, v = residuals[n] 49 | fast_params.append(p.detach() + (u @ v)) 50 | else: 51 | fast_params.append(p.detach()) 52 | 53 | loc_pred = edit_model(ids_train, params=fast_params) 54 | if hasattr(loc_pred, "logits"): 55 | loc_pred = loc_pred.logits 56 | 57 | loc_loss = kl_loc_loss(original_logits_train, loc_pred) 58 | 59 | pred_log = edit_model(example["edit_inner_ids"], params=fast_params) 60 | if hasattr(pred_log, "logits"): 61 | pred_log = pred_log.logits 62 | prob_dict = masked_log_probs(pred_log, example["edit_inner_labels"]) 63 | edit_loss = prob_dict["nll"] 64 | acc = prob_dict["acc"] 65 | 66 | loss = loc_loss + 0.0002 * edit_loss 67 | with torch.inference_mode(): 68 | loc_pred_val = edit_model(ids_val, params=fast_params) 69 | if hasattr(loc_pred_val, "logits"): 70 | loc_pred_val = loc_pred_val.logits 71 | 72 | if pred_log.dim() == 3: 73 | facc = ( 74 | ( 75 | pred_log.argmax(-1)[0, -10:-1] 76 | == example["edit_inner_labels"][0, -9:] 77 | ) 78 | .float() 79 | .mean() 80 | ) 81 | ret = ( 82 | (original_logits_val.argmax(-1) == loc_pred_val.argmax(-1)) 83 | .float() 84 | .mean() 85 | ) 86 | else: 87 | facc = (pred_log > 0) == example["edit_inner_labels"] 88 | ret = ( 89 | ((original_logits_val > 0) == (loc_pred_val > 0)).float().mean() 90 | ) 91 | 92 | print( 93 | f"{it}, ({loss.item():.6f}, {loc_loss.item():.4f}, {edit_loss.item():.4f}), {facc.item():.2f}, {ret.item():.4f} {(u@v).view(-1).norm().item():.5f}", 94 | end="\r", 95 | ) 96 | 97 | for p, g in zip(opt_list, torch.autograd.grad(loss, opt_list)): 98 | p.grad = g 99 | res_opt.step() 100 | res_opt.zero_grad() 101 | 102 | it += 1 103 | 104 | if acc == 1: 105 | history.append(1) 106 | else: 107 | history.append(0) 108 | 109 | print() 110 | print(len(history), sum(history) / len(history), ret.item()) 111 | -------------------------------------------------------------------------------- /code/easyeditor/models/pmet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pmet_main import PMETHyperParams, apply_pmet_to_model -------------------------------------------------------------------------------- /code/easyeditor/models/pmet/compute_ks.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from .compute_zs import get_modules_input_output_at_words, get_module_input_output_at_words 8 | from .pmet_hparams import PMETHyperParams 9 | 10 | 11 | def compute_ks_parallel( 12 | model: AutoModelForCausalLM, 13 | tok: AutoTokenizer, 14 | requests: Dict, 15 | hparams: PMETHyperParams, 16 | layer: int, 17 | context_templates: List[str], 18 | ): 19 | layers_ks = dict() 20 | rewrite_module_tmps = hparams.rewrite_module_tmps 21 | layers_ks[rewrite_module_tmps[0]], layers_ks[rewrite_module_tmps[1]]= get_modules_input_output_at_words( 22 | model, 23 | tok, 24 | layer, 25 | context_templates=[ 26 | context.format(request["prompt"]) 27 | for request in requests 28 | for context_type in context_templates 29 | for context in context_type 30 | ], 31 | words=[ 32 | request["subject"] 33 | for request in requests 34 | for context_type in context_templates 35 | for _ in context_type 36 | ], 37 | module_templates=rewrite_module_tmps, 38 | fact_token_strategy=hparams.fact_token, 39 | ) 40 | for rewrite_module_tmp in rewrite_module_tmps: 41 | context_type_lens = [0] + [len(context_type) for context_type in context_templates] 42 | context_len = sum(context_type_lens) 43 | context_type_csum = np.cumsum(context_type_lens).tolist() 44 | ans = [] 45 | for i in range(0, layers_ks[rewrite_module_tmp].size(0), context_len): 46 | tmp = [] 47 | for j in range(len(context_type_csum) - 1): 48 | start, end = context_type_csum[j], context_type_csum[j + 1] 49 | tmp.append(layers_ks[rewrite_module_tmp][i + start : i + end].mean(0)) 50 | ans.append(torch.stack(tmp, 0).mean(0)) 51 | layers_ks[rewrite_module_tmp] = torch.stack(ans, dim=0) 52 | return layers_ks 53 | 54 | def compute_ks( 55 | model: AutoModelForCausalLM, 56 | tok: AutoTokenizer, 57 | requests: Dict, 58 | hparams: PMETHyperParams, 59 | rewrite_module_tmp: str, 60 | layer: int, 61 | context_templates: List[str], 62 | ): 63 | layers_ks = dict() 64 | layer_ks = get_module_input_output_at_words( 65 | model, 66 | tok, 67 | layer, 68 | context_templates=[ 69 | context.format(request["prompt"]) 70 | for request in requests 71 | for context_type in context_templates 72 | for context in context_type 73 | ], 74 | words=[ 75 | request["subject"] 76 | for request in requests 77 | for context_type in context_templates 78 | for _ in context_type 79 | ], 80 | module_template=rewrite_module_tmp, 81 | fact_token_strategy=hparams.fact_token, 82 | )[0] 83 | 84 | context_type_lens = [0] + [len(context_type) for context_type in context_templates] 85 | context_len = sum(context_type_lens) 86 | context_type_csum = np.cumsum(context_type_lens).tolist() 87 | 88 | ans = [] 89 | for i in range(0, layer_ks.size(0), context_len): 90 | tmp = [] 91 | for j in range(len(context_type_csum) - 1): 92 | start, end = context_type_csum[j], context_type_csum[j + 1] 93 | tmp.append(layer_ks[i + start : i + end].mean(0)) 94 | ans.append(torch.stack(tmp, 0).mean(0)) 95 | layers_ks[rewrite_module_tmp] = torch.stack(ans, dim=0) 96 | return layers_ks -------------------------------------------------------------------------------- /code/easyeditor/models/pmet/pmet_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from typing_extensions import Literal 4 | 5 | 6 | 
from ...util.hparams import HyperParams 7 | import yaml 8 | 9 | @dataclass 10 | class PMETHyperParams(HyperParams): 11 | # Method 12 | layers: List[int] 13 | layer_selection: Literal["all", "random"] 14 | fact_token: Literal[ 15 | "last", "subject_first", "subject_last", "subject_first_after_last" 16 | ] 17 | v_num_grad_steps: int 18 | v_lr: float 19 | v_loss_layer: int 20 | v_weight_decay: float 21 | clamp_norm_factor: float 22 | kl_factor: float 23 | mom2_adjustment: bool 24 | mom2_update_weight: float 25 | nll_loss_factor: float 26 | # Module templates 27 | rewrite_module_tmp: str 28 | rewrite_module_tmps: List[str] 29 | layer_module_tmp: str 30 | mlp_module_tmp: str 31 | attn_module_tmp: str 32 | ln_f_module: str 33 | lm_head_module: str 34 | 35 | # Statistics 36 | mom2_dataset: str 37 | mom2_n_samples: int 38 | mom2_dtype: str 39 | 40 | alg_name: str 41 | device: int 42 | model_name: str 43 | stats_dir: str 44 | 45 | max_length: int = 40 46 | batch_size: int = 1 47 | model_parallel: bool = False 48 | 49 | @classmethod 50 | def from_hparams(cls, hparams_name_or_path: str): 51 | 52 | if '.yaml' not in hparams_name_or_path: 53 | hparams_name_or_path = hparams_name_or_path + '.yaml' 54 | 55 | with open(hparams_name_or_path, "r") as stream: 56 | config = yaml.safe_load(stream) 57 | config = super().construct_float_from_scientific_notation(config) 58 | 59 | assert (config and config['alg_name'] == 'PMET') or print(f'PMETHyperParams can not load from {hparams_name_or_path}, ' 60 | f'alg_name is {config["alg_name"]} ') 61 | return cls(**config) 62 | -------------------------------------------------------------------------------- /code/easyeditor/models/rome/README.md: -------------------------------------------------------------------------------- 1 | # ROME 2 | This package provides a self-contained implementation of Rank-One Model Editing (ROME). 3 | 4 | Recall that ROME's update consists of: $u$ selection, $v_*$ optimization, and $v$ insertion. 5 | * [`compute_u.py`](compute_u.py): Chooses a $u$ vector. 6 | * [`compute_v.py`](compute_v.py): Choose a $v_*$ via optimization, then computes $v$. 7 | * [`rome_main.py`](rome_main.py): Instruments main logic. 8 | * [`rome_params.py`](rome_hparams.py): Interface for specifying hyperparameters. Inherits from the base [`params.py`](../util/hparams.py) module. 9 | 10 | For estimating second moment statistics of keys ($C = KK$), we provide the `layer_stats` module. See the [main README](../README.md) for usage instructions. 11 | * [`layer_stats.py`](layer_stats.py): Logic for retrieving and caching key statistics. 12 | * [`tok_dataset.py`](tok_dataset.py): Utilities for creating a dataset of tokens. 
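
Recall from the ROME paper that these pieces combine into the rank-one update $\hat W = W + \Lambda\,(C^{-1} k_*)^{\top}$ with $\Lambda = (v_* - W k_*)\,/\,\big((C^{-1} k_*)^{\top} k_*\big)$. As a quick orientation, the sketch below strings the files above together through the package's public entry points (`ROMEHyperParams.from_hparams` and `apply_rome_to_model`, both re-exported from the package `__init__.py`). It is a minimal sketch, not a recipe shipped with this repo: the checkpoint name, the hyperparameter YAML path, and the `target_new` field are placeholder assumptions, and the call/return conventions follow the upstream ROME reference implementation, so they may differ slightly in this fork.

```python
# Minimal usage sketch (illustrative only; the checkpoint, YAML path, and
# "target_new" key below are assumptions, not values defined in this package).
from transformers import AutoModelForCausalLM, AutoTokenizer

from easyeditor.models.rome import ROMEHyperParams, apply_rome_to_model

model = AutoModelForCausalLM.from_pretrained("gpt2-xl")              # assumed checkpoint
tok = AutoTokenizer.from_pretrained("gpt2-xl")
hparams = ROMEHyperParams.from_hparams("hparams/ROME/gpt2-xl.yaml")  # assumed config path

# compute_u.py reads request["prompt"] (with "{}" standing in for the subject)
# and request["subject"]; the "target_new" key is assumed here.
request = {
    "prompt": "{} plays the sport of",
    "subject": "LeBron James",
    "target_new": "soccer",
}

# Upstream ROME returns (edited_model, copy_of_original_weights); this fork may differ.
edited_model, orig_weights = apply_rome_to_model(model, tok, [request], hparams)
```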
-------------------------------------------------------------------------------- /code/easyeditor/models/rome/__init__.py: -------------------------------------------------------------------------------- 1 | from .rome_main import ROMEHyperParams, apply_rome_to_model, execute_rome 2 | -------------------------------------------------------------------------------- /code/easyeditor/models/rome/compute_u.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, List 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from ..rome import repr_tools 9 | from ...util.globals import * 10 | 11 | from .layer_stats import layer_stats 12 | from .rome_hparams import ROMEHyperParams 13 | 14 | # Cache variables 15 | inv_mom2_cache = {} 16 | 17 | 18 | def get_inv_cov( 19 | model: AutoModelForCausalLM, 20 | tok: AutoTokenizer, 21 | layer_name: str, 22 | mom2_dataset: str, 23 | mom2_n_samples: str, 24 | mom2_dtype: str, 25 | hparams=None, 26 | ) -> torch.Tensor: 27 | """ 28 | Retrieves covariance statistics, then computes the algebraic inverse. 29 | Caches result for future use. 30 | """ 31 | 32 | global inv_mom2_cache 33 | 34 | model_name = model.config._name_or_path.replace("/", "_") 35 | key = (model_name, layer_name) 36 | 37 | if key not in inv_mom2_cache: 38 | print( 39 | f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. " 40 | f"The result will be cached to avoid repetitive computation." 41 | ) 42 | stat = layer_stats( 43 | model, 44 | tok, 45 | layer_name, 46 | hparams.stats_dir, 47 | mom2_dataset, 48 | to_collect=["mom2"], 49 | sample_size=mom2_n_samples, 50 | precision=mom2_dtype, 51 | hparams=hparams 52 | ) 53 | inv_mom2_cache[key] = torch.inverse( 54 | stat.mom2.moment().to(f"cuda:{hparams.device}") 55 | ).float() # Cast back to float32 56 | 57 | return inv_mom2_cache[key] 58 | 59 | 60 | def compute_u( 61 | model: AutoModelForCausalLM, 62 | tok: AutoTokenizer, 63 | request: Dict, 64 | hparams: ROMEHyperParams, 65 | layer: int, 66 | context_templates: List[str], 67 | ) -> torch.Tensor: 68 | """ 69 | Computes the right vector used in constructing the rank-1 update matrix. 70 | """ 71 | 72 | print("Computing left vector (u)...") 73 | 74 | # Compute projection token 75 | word_repr_args = dict( 76 | model=model, 77 | tok=tok, 78 | layer=layer, 79 | module_template=hparams.rewrite_module_tmp, 80 | track="in", 81 | ) 82 | if "subject_" in hparams.fact_token and hparams.fact_token.index("subject_") == 0: 83 | word = request["subject"] 84 | print(f"Selected u projection object {word}") 85 | 86 | cur_repr = repr_tools.get_reprs_at_word_tokens( 87 | context_templates=[ 88 | templ.format(request["prompt"]) for templ in context_templates 89 | ], 90 | words=[word for _ in range(len(context_templates))], 91 | subtoken=hparams.fact_token[len("subject_") :], 92 | **word_repr_args, 93 | ).mean(0) 94 | 95 | elif hparams.fact_token == "last": 96 | # Heuristic to choose last word. Not a huge deal if there's a minor 97 | # edge case (e.g. multi-token word) because the function below will 98 | # take the last token. 
99 | cur_repr = repr_tools.get_reprs_at_idxs( 100 | contexts=[ 101 | templ.format(request["prompt"].format(request["subject"])) 102 | for templ in context_templates 103 | ], 104 | idxs=[[-1] for _ in range(len(context_templates))], 105 | **word_repr_args, 106 | ).mean(0) 107 | print("Selected u projection token with last token") 108 | else: 109 | raise ValueError(f"fact_token={hparams.fact_token} not recognized") 110 | 111 | # Apply inverse second moment adjustment 112 | u = cur_repr 113 | if hparams.mom2_adjustment: 114 | u = get_inv_cov( 115 | model, 116 | tok, 117 | hparams.rewrite_module_tmp.format(layer), 118 | hparams.mom2_dataset, 119 | hparams.mom2_n_samples, 120 | hparams.mom2_dtype, 121 | hparams=hparams, 122 | ) @ u.unsqueeze(1) 123 | u = u.squeeze() 124 | 125 | return u / u.norm() 126 | -------------------------------------------------------------------------------- /code/easyeditor/models/rome/rome_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import yaml 4 | 5 | from ...util.hparams import HyperParams 6 | 7 | 8 | @dataclass 9 | class ROMEHyperParams(HyperParams): 10 | # Method 11 | layers: List[int] 12 | fact_token: str 13 | v_num_grad_steps: int 14 | v_lr: float 15 | v_loss_layer: int 16 | v_weight_decay: float 17 | clamp_norm_factor: float 18 | kl_factor: float 19 | mom2_adjustment: bool 20 | context_template_length_params: List[List[int]] 21 | 22 | # Module templates 23 | rewrite_module_tmp: str 24 | layer_module_tmp: str 25 | mlp_module_tmp: str 26 | attn_module_tmp: str 27 | ln_f_module: str 28 | lm_head_module: str 29 | 30 | # Statistics 31 | mom2_dataset: str 32 | mom2_n_samples: int 33 | mom2_dtype: str 34 | alg_name: str 35 | device: int 36 | model_name: str 37 | stats_dir: str 38 | 39 | max_length: int = 40 40 | model_parallel: bool = False 41 | fp16: bool = False 42 | 43 | gpt_eval_endpoint_default: bool = True 44 | gpt_eval_name_default: bool = True 45 | 46 | @classmethod 47 | def from_hparams(cls, hparams_name_or_path: str): 48 | 49 | if '.yaml' not in hparams_name_or_path: 50 | hparams_name_or_path = hparams_name_or_path + '.yaml' 51 | 52 | with open(hparams_name_or_path, "r") as stream: 53 | config = yaml.safe_load(stream) 54 | config = super().construct_float_from_scientific_notation(config) 55 | 56 | assert (config and config['alg_name'] == 'ROME') or print(f'ROMEHyperParams can not load from {hparams_name_or_path}, ' 57 | f'alg_name is {config["alg_name"]} ') 58 | return cls(**config) 59 | -------------------------------------------------------------------------------- /code/easyeditor/models/rome/tok_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TokenizedDataset(Dataset): 7 | """ 8 | Converts a dataset of text samples into a dataset of token sequences, 9 | as converted by a supplied tokenizer. The tokens come along with position 10 | ids and attention masks, they can be supplied direcly to the model. 
11 | """ 12 | 13 | def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"): 14 | self.text_dataset = text_dataset 15 | self.field = field 16 | self.tokenizer = tokenizer 17 | self.maxlen = maxlen 18 | if hasattr(text_dataset, "info"): 19 | self.info = text_dataset.info 20 | 21 | def __len__(self): 22 | return len(self.text_dataset) 23 | 24 | def __getitem__(self, i): 25 | text = self.text_dataset[i] 26 | if self.field is not None: 27 | text = text[self.field] 28 | token_list = self.tokenizer.encode( 29 | text, truncation=True, max_length=self.maxlen 30 | ) 31 | position_ids = list(range(len(token_list))) 32 | attention_mask = [1] * len(token_list) 33 | return dict( 34 | input_ids=torch.tensor(token_list), 35 | position_ids=torch.tensor(position_ids), 36 | attention_mask=torch.tensor(attention_mask), 37 | ) 38 | 39 | 40 | def dict_to_(data, device): 41 | """ 42 | Moves a dictionary of tensors to the specified device. 43 | """ 44 | for k in data: 45 | data[k] = data[k].to(device) 46 | return data 47 | 48 | 49 | def length_collation(token_size): 50 | """ 51 | Sorts a batch of sequences and breaks it up into subbatches 52 | of same-sized sequences, padding as needed. Each batch 53 | has no more than token_size total tokens (or a single 54 | sequence, if the sequence happens to be larger). 55 | """ 56 | 57 | def collate_fn(items): 58 | items = sorted(items, key=lambda x: -len(x["input_ids"])) 59 | batches = [] 60 | batch = [] 61 | batch_width = 0 62 | for item in items: 63 | item_width = len(item["input_ids"]) 64 | if item_width == 0: 65 | break 66 | if batch_width * (len(batch) + 1) > token_size: 67 | batches.append(make_padded_batch(batch)) 68 | batch = [] 69 | batch_width = 0 70 | if not batch: 71 | batch_width = item_width 72 | batch.append(item) 73 | if len(batch): 74 | batches.append(make_padded_batch(batch)) 75 | return batches 76 | 77 | return collate_fn 78 | 79 | 80 | def make_padded_batch(items): 81 | """ 82 | Pads sequences in a batch, so they are all the same length as the longest. 83 | """ 84 | max_len = max(len(d["input_ids"]) for d in items) 85 | if max_len == 0: 86 | return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]} 87 | return { 88 | k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True) 89 | for k, v in items[0].items() 90 | } 91 | 92 | 93 | def flatten_masked_batch(data, mask): 94 | """ 95 | Flattens feature data, ignoring items that are masked out of attention. 
96 | """ 97 | flat_data = data.view(-1, data.size(-1)) 98 | attended_tokens = mask.view(-1).nonzero()[:, 0] 99 | return flat_data[attended_tokens] 100 | -------------------------------------------------------------------------------- /code/easyeditor/models/serac/__init__.py: -------------------------------------------------------------------------------- 1 | from .serac_main import SERACHparams, SeracRewriteExecutor, SeracMultimodalRewriteExecutor 2 | from .serac_multimodal_hparams import SERACMultimodalHparams 3 | -------------------------------------------------------------------------------- /code/easyeditor/models/serac/serac_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class SERACHparams(HyperParams): 9 | 10 | model_name: str 11 | model_class: str 12 | small_name: str 13 | tokenizer_class: str 14 | tokenizer_name: str 15 | cls_name: str 16 | cls_class: str 17 | inner_params: List[str] 18 | 19 | archive: Any 20 | 21 | # Method 22 | alg: str 23 | lr: float 24 | edit_lr: float 25 | seed: int 26 | lr_lr: float 27 | cedit: float 28 | cloc: float 29 | cbase: float 30 | dropout: float 31 | final_eval: bool 32 | supervised: bool 33 | train_base: bool 34 | no_grad_layers: Any 35 | soft_weighting: bool 36 | checkpoint_grad: bool 37 | cross_attend: bool 38 | cos: bool 39 | freeze: Any 40 | square: bool 41 | bound_embeds: bool 42 | use_all_negatives: bool 43 | freeze_cntr: bool 44 | dist_heads: int 45 | lora: Any 46 | 47 | # Output 48 | results_dir: str 49 | 50 | # Train 51 | device: int 52 | model_save_pt: int 53 | edit_bs: int 54 | silent: bool 55 | log_interval: int 56 | val_interval: int 57 | early_stop_patience: int 58 | early_stop_key: str 59 | eval_only: bool 60 | half: bool 61 | save: bool 62 | debug: bool 63 | log_errors: bool 64 | unlikelihood: bool 65 | 66 | val_batch_size: int 67 | accumulate_bs: int 68 | val_steps: int 69 | opt: str 70 | grad_clip: float 71 | 72 | alg_name: str 73 | device: int 74 | 75 | batch_size: int = 1 76 | max_length: int = 40 77 | model_parallel: bool = False 78 | max_epochs: Optional[int] = None 79 | max_iters: Optional[int] = None 80 | 81 | 82 | @classmethod 83 | def from_hparams(cls, hparams_name_or_path: str): 84 | 85 | if '.yaml' not in hparams_name_or_path: 86 | hparams_name_or_path = hparams_name_or_path + '.yaml' 87 | 88 | with open(hparams_name_or_path, "r") as stream: 89 | config = yaml.safe_load(stream) 90 | config = super().construct_float_from_scientific_notation(config) 91 | 92 | assert (config and config['alg'] == 'SERAC') or print(f'SERACTrainingHyperParams can not load from {hparams_name_or_path}, ' 93 | f'alg_name is {config["alg"]} ') 94 | return cls(**config) 95 | -------------------------------------------------------------------------------- /code/easyeditor/models/serac/serac_multimodal_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class SERACMultimodalHparams(HyperParams): 9 | 10 | # Multimodal 11 | qformer_name_or_path: str 12 | state_dict_file: str 13 | 14 | # Image_dir 15 | coco_image: str 16 | rephrase_image: str 17 | 18 | # Model 19 | name: str 20 | model_name: str 21 | model_class: str 22 | small_name: str 
23 | tokenizer_class: str 24 | tokenizer_name: str 25 | cls_name: str 26 | cls_class: str 27 | inner_params: List[str] 28 | 29 | archive: Any 30 | 31 | # Method 32 | alg: str 33 | alg_name: str 34 | lr: float 35 | edit_lr: float 36 | seed: int 37 | lr_lr: float 38 | cedit: float 39 | iedit: float 40 | cloc: float 41 | cbase: float 42 | dropout: float 43 | final_eval: bool 44 | supervised: bool 45 | train_base: bool 46 | no_grad_layers: Any 47 | soft_weighting: bool 48 | checkpoint_grad: bool 49 | cross_attend: bool 50 | cos: bool 51 | freeze: Any 52 | square: bool 53 | bound_embeds: bool 54 | use_all_negatives: bool 55 | freeze_cntr: bool 56 | dist_heads: int 57 | lora: Any 58 | 59 | # Output 60 | results_dir: str 61 | 62 | # Train 63 | device: str 64 | batch_size: int 65 | model_save_pt: int 66 | edit_bs: int 67 | silent: bool 68 | log_interval: int 69 | val_interval: int 70 | early_stop_patience: int 71 | early_stop_key: str 72 | eval_only: bool 73 | half: bool 74 | save: bool 75 | debug: bool 76 | log_errors: bool 77 | unlikelihood: bool 78 | 79 | val_batch_size: int 80 | accumulate_bs: int 81 | val_steps: int 82 | opt: str 83 | grad_clip: float 84 | 85 | exact_match: bool = False 86 | max_length: int = 32 87 | max_epochs: Optional[int] = None 88 | max_iters: Optional[int] = None 89 | model_parallel: bool = False 90 | qformer_checkpoint: Optional[str] = None 91 | freeze_qformer: bool = True 92 | pretrained_ckpt: Optional[str] = None 93 | 94 | 95 | @classmethod 96 | def from_hparams(cls, hparams_name_or_path: str): 97 | 98 | if '.yaml' not in hparams_name_or_path: 99 | hparams_name_or_path = hparams_name_or_path + '.yaml' 100 | 101 | with open(hparams_name_or_path, "r") as stream: 102 | config = yaml.safe_load(stream) 103 | config = super().construct_float_from_scientific_notation(config) 104 | 105 | 106 | assert (config and config['alg'] == 'SERAC_MULTI') or print(f'SERACMultimodalHparams can not load from {hparams_name_or_path}, ' 107 | f'alg_name is {config["alg"]} ') 108 | return cls(**config) 109 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .training_hparams import * 2 | from .algs import * 3 | from .EditTrainer import * 4 | from .BaseTrainer import * 5 | from .blip2_models import * 6 | from .MultimodalTrainer import * 7 | from .MultiTaskTrainer import * 8 | from .PerTrainer import * -------------------------------------------------------------------------------- /code/easyeditor/trainer/algs/__init__.py: -------------------------------------------------------------------------------- 1 | from .editable_model import * 2 | from .MEND import * 3 | from .SERAC import * 4 | from .MALMEN import * 5 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/algs/editable_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from copy import deepcopy 3 | 4 | from ..losses import masked_log_probs 5 | from ..utils import _logits, shift_targets 6 | 7 | 8 | class EditableModel(nn.Module): 9 | def __init__(self, model, config, model_constructor): 10 | super().__init__() 11 | 12 | self.model = model 13 | self.config = deepcopy(config) 14 | self.model_constructor = model_constructor 15 | 16 | def _edit_loss_fn(config, pred, targ, **kwargs): 17 | if 'minigpt4' in config.model_name.lower() or 'blip' in 
self.config.model_name.lower(): 18 | return masked_log_probs(config, pred, targ, exact_match=self.config.exact_match, shift=True, **kwargs) 19 | elif 't5' in config.model_class.lower(): 20 | return masked_log_probs(config, pred, targ,) 21 | elif 'gpt' in config.model_class.lower(): 22 | return masked_log_probs(config, pred, targ, shift=True, **kwargs) 23 | elif 'llama' in config.model_class.lower(): 24 | return masked_log_probs(config, pred, targ, shift=True, **kwargs) 25 | elif 'internlm' in config.model_name.lower(): 26 | return masked_log_probs(config, pred, targ, shift=True) 27 | elif 'chatglm' in config.model_name.lower(): 28 | return masked_log_probs(config, pred, targ, shift=True) 29 | elif 'qwen' in config.model_name.lower(): 30 | return masked_log_probs(config, pred, targ, shift=True) 31 | elif 'mistral' in config.model_name.lower(): 32 | return masked_log_probs(config, pred, targ, shift=True) 33 | else: 34 | return masked_log_probs(config, pred, targ,) 35 | 36 | self.edit_loss_fn = _edit_loss_fn 37 | self.loc_loss_fn = masked_log_probs 38 | 39 | def edit(self, batch, condition=None, detach_history=False): 40 | raise NotImplementedError 41 | 42 | def forward(self, *inputs, **kwargs): 43 | return _logits(self.model(*inputs, **kwargs)) 44 | 45 | def outer_parameters(self): 46 | return self.parameters() 47 | 48 | def base_loss(self, input_ids, attention_masks, label_ids): 49 | pass 50 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/algs/higher_utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Utility functions for components of ``higher``\ .""" 16 | 17 | import torch as _torch 18 | import typing as _typing 19 | 20 | _T = _typing.TypeVar('_T') 21 | _U = _typing.TypeVar('_U') 22 | 23 | 24 | def _copy_tensor( 25 | t: _torch.Tensor, 26 | safe_copy: bool, 27 | device: _typing.Optional[_torch.device] = None 28 | ) -> _torch.Tensor: 29 | if safe_copy: 30 | t = t.clone().detach().requires_grad_(t.requires_grad) 31 | else: 32 | t = t.detach().requires_grad_(t.requires_grad) 33 | t = t if device is None else t.to(device) 34 | return t 35 | 36 | 37 | def _recursive_copy_and_cast( 38 | target: _typing.Union[list, tuple, dict, set, _torch.Tensor], 39 | device: _typing.Optional[_torch.device] 40 | ) -> _torch.Tensor: 41 | def map_fn(x): 42 | if _torch.is_tensor(x): 43 | return _copy_tensor(x, True, device=device) 44 | else: 45 | return x 46 | return _recursive_map(target, map_fn) 47 | 48 | 49 | def _recursive_map( 50 | target: _typing.Union[list, tuple, dict, set, _T], 51 | map_fn: _typing.Callable[[_T], _U], 52 | ) -> _typing.Union[list, tuple, dict, set, _U]: 53 | if isinstance(target, list): 54 | return type(target)( 55 | [_recursive_map(x, map_fn) for x in target] 56 | ) 57 | elif isinstance(target, tuple): 58 | return type(target)( 59 | [_recursive_map(x, map_fn) for x in target] 60 | ) 61 | elif isinstance(target, dict): 62 | return type(target)( 63 | {k: _recursive_map(v, map_fn) 64 | for k, v in target.items()} 65 | ) 66 | elif isinstance(target, set): 67 | return type(target)( 68 | {_recursive_map(x, map_fn) 69 | for x in target} 70 | ) 71 | else: 72 | return map_fn(target) 73 | 74 | 75 | def _is_container(target: _typing.Any) -> bool: 76 | flag = ( 77 | isinstance(target, list) or 78 | isinstance(target, tuple) or 79 | isinstance(target, dict) or 80 | isinstance(target, set) 81 | ) 82 | return flag 83 | 84 | 85 | def _find_param_in_list( 86 | param: _torch.Tensor, l: _typing.Iterable[_torch.Tensor] 87 | ) -> _typing.Optional[int]: 88 | for i, p in enumerate(l): 89 | if p is param: 90 | return i 91 | else: 92 | return None 93 | 94 | 95 | def _get_param_mapping( 96 | module: _torch.nn.Module, seen: _typing.List[_torch.Tensor], 97 | mapping: _typing.List[int] 98 | ) -> _typing.List[int]: 99 | 100 | for param in module._parameters.values(): 101 | if param is None: 102 | continue 103 | found = _find_param_in_list(param, seen) 104 | if found is None: 105 | mapping.append(len(seen)) 106 | seen.append(param) 107 | else: 108 | mapping.append(found) 109 | 110 | for name, child in module._modules.items(): 111 | if child == None: continue 112 | _ = _get_param_mapping(child, seen, mapping) 113 | 114 | return mapping 115 | 116 | 117 | def flatten(x: _typing.Any) -> _typing.List[_typing.Any]: 118 | r"""Returns a flattened list of objects from a nested structure.""" 119 | l: _typing.List[_typing.Any] = [] 120 | if isinstance(x, dict): 121 | for y in x.values(): 122 | l.extend(flatten(y)) 123 | elif isinstance(x, list) or isinstance(x, set) or isinstance(x, tuple): 124 | for y in x: 125 | l.extend(flatten(y)) 126 | else: 127 | l.append(x) 128 | return l 129 | 130 | 131 | def get_func_params( 132 | module: _torch.nn.Module, 133 | device: _typing.Optional[_torch.device] = None, 134 | safe_copy: bool = True 135 | ) -> _typing.List[_torch.Tensor]: 136 | r"""Returns a detached copy of module parameters which requires gradient.""" 137 | params = [_copy_tensor(p, safe_copy, device) for p in module.parameters()] 138 | return params 139 | 
-------------------------------------------------------------------------------- /code/easyeditor/trainer/algs/hooks.py: -------------------------------------------------------------------------------- 1 | from ..utils import parent_module 2 | 3 | 4 | def linear_backward_hook(mod, grad_in, grad_out): 5 | if not hasattr(mod, "weight"): 6 | print(f"{mod} has no weight!") 7 | return 8 | 9 | if hasattr(mod.weight, "__x__"): 10 | assert len(grad_out) == 1 11 | # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2) 12 | mod.weight.__delta__ = grad_out[0].detach() 13 | else: 14 | print(f"{mod} has no __x__") 15 | 16 | 17 | def linear_forward_hook(mod, activations, output): 18 | assert len(activations) == 1 19 | mod.weight.__x__ = activations[0].detach() 20 | 21 | 22 | def hook_model(model, pnames): 23 | handles = [] 24 | for m in [parent_module(model, pname) for pname in pnames]: 25 | handles.append(m.register_full_backward_hook(linear_backward_hook)) 26 | handles.append(m.register_forward_hook(linear_forward_hook)) 27 | 28 | model.handles = handles 29 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/algs/malmen/nets.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class RunningMeanStd(nn.Module): 8 | 9 | def __init__(self, size: int): 10 | super().__init__() 11 | 12 | self.register_buffer("n", torch.zeros(1)) 13 | self.register_buffer("mean", torch.zeros((size))) 14 | self.register_buffer("var", torch.zeros((size))) 15 | self.register_buffer("std", torch.zeros((size))) 16 | 17 | def update(self, x: torch.FloatTensor): 18 | 19 | n = self.n + x.shape[0] 20 | delta = x.mean(0) - self.mean 21 | self.mean += x.shape[0] * delta / n 22 | self.var += x.shape[0] * x.var(0) + self.n * x.shape[0] * delta.pow(2) / n 23 | self.std = (self.var / (n - 1 + torch.finfo(x.dtype).eps)).sqrt() 24 | self.n = n 25 | 26 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 27 | 28 | return (x - self.mean) / (self.std + torch.finfo(x.dtype).eps) 29 | 30 | 31 | class MALMENBlock(nn.Module): 32 | 33 | def __init__(self, size: int, rank: int, n_modules: int): 34 | super().__init__() 35 | 36 | self.A = nn.Parameter(torch.randn(size, rank)) 37 | self.B = nn.Parameter(torch.zeros(rank, size)) 38 | self.bias = nn.Parameter(torch.zeros(size)) 39 | 40 | self.scale = nn.Embedding(n_modules, size) 41 | self.shift = nn.Embedding(n_modules, size) 42 | 43 | self.scale.weight.data.fill_(1) 44 | self.shift.weight.data.fill_(0) 45 | 46 | def forward( 47 | self, 48 | y: torch.FloatTensor, 49 | module_idx: torch.LongTensor 50 | ) -> torch.FloatTensor: 51 | 52 | x = y @ self.A @ self.B + self.bias 53 | x = x.clamp(0) 54 | x = self.scale(module_idx) * x + self.shift(module_idx) 55 | x = x + y 56 | 57 | return x 58 | 59 | 60 | class MALMENNet(nn.Module): 61 | 62 | def __init__( 63 | self, 64 | key_size: int, 65 | value_size: int, 66 | rank: int, 67 | n_blocks: int, 68 | n_modules: int, 69 | lr: float 70 | ): 71 | super().__init__() 72 | self.key_size = key_size 73 | self.value_size = value_size 74 | 75 | self.normalizer = RunningMeanStd(key_size + value_size) 76 | self.blocks = nn.ModuleList([ 77 | MALMENBlock(key_size + value_size, rank, n_modules) 78 | for _ in range(n_blocks) 79 | ]) 80 | 81 | self.lr = nn.Embedding(n_modules, 1) 82 | self.lamda = nn.Embedding(n_modules, 1) 83 | 84 | self.lr.weight.data.fill_(lr) 85 | 
self.lamda.weight.data.fill_(0) 86 | 87 | def forward( 88 | self, 89 | keys: torch.FloatTensor, 90 | values_grad: torch.FloatTensor, 91 | module_idx: torch.LongTensor 92 | ) -> Tuple[torch.FloatTensor]: 93 | 94 | hidden_states = torch.cat((keys, values_grad), -1) 95 | hidden_states = self.normalizer(hidden_states) 96 | for block in self.blocks: 97 | hidden_states = block(hidden_states, module_idx) 98 | return hidden_states.split([self.key_size, self.value_size], -1) -------------------------------------------------------------------------------- /code/easyeditor/trainer/blip2_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .blip2_opt import Blip2OPT 2 | from .mini_gpt4 import MiniGPT4 -------------------------------------------------------------------------------- /code/easyeditor/trainer/blip2_models/common/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import datetime 9 | import functools 10 | import os 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import timm.models.hub as timm_hub 15 | 16 | 17 | def setup_for_distributed(is_master): 18 | """ 19 | This function disables printing when not in master process 20 | """ 21 | import builtins as __builtin__ 22 | 23 | builtin_print = __builtin__.print 24 | 25 | def print(*args, **kwargs): 26 | force = kwargs.pop("force", False) 27 | if is_master or force: 28 | builtin_print(*args, **kwargs) 29 | 30 | __builtin__.print = print 31 | 32 | 33 | def is_dist_avail_and_initialized(): 34 | if not dist.is_available(): 35 | return False 36 | if not dist.is_initialized(): 37 | return False 38 | return True 39 | 40 | 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | 47 | def get_rank(): 48 | if not is_dist_avail_and_initialized(): 49 | return 0 50 | return dist.get_rank() 51 | 52 | 53 | def is_main_process(): 54 | return get_rank() == 0 55 | 56 | 57 | def init_distributed_mode(args): 58 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 59 | args.rank = int(os.environ["RANK"]) 60 | args.world_size = int(os.environ["WORLD_SIZE"]) 61 | args.gpu = int(os.environ["LOCAL_RANK"]) 62 | elif "SLURM_PROCID" in os.environ: 63 | args.rank = int(os.environ["SLURM_PROCID"]) 64 | args.gpu = args.rank % torch.cuda.device_count() 65 | else: 66 | print("Not using distributed mode") 67 | args.distributed = False 68 | return 69 | 70 | args.distributed = True 71 | 72 | torch.cuda.set_device(args.gpu) 73 | args.dist_backend = "nccl" 74 | print( 75 | "| distributed init (rank {}, world {}): {}".format( 76 | args.rank, args.world_size, args.dist_url 77 | ), 78 | flush=True, 79 | ) 80 | torch.distributed.init_process_group( 81 | backend=args.dist_backend, 82 | init_method=args.dist_url, 83 | world_size=args.world_size, 84 | rank=args.rank, 85 | timeout=datetime.timedelta( 86 | days=365 87 | ), # allow auto-downloading and de-compressing 88 | ) 89 | torch.distributed.barrier() 90 | setup_for_distributed(args.rank == 0) 91 | 92 | 93 | def get_dist_info(): 94 | if torch.__version__ < "1.0": 95 | initialized = dist._initialized 96 | else: 97 | initialized = dist.is_initialized() 98 | if initialized: 99 | rank = dist.get_rank() 
100 | world_size = dist.get_world_size() 101 | else: # non-distributed training 102 | rank = 0 103 | world_size = 1 104 | return rank, world_size 105 | 106 | 107 | def main_process(func): 108 | @functools.wraps(func) 109 | def wrapper(*args, **kwargs): 110 | rank, _ = get_dist_info() 111 | if rank == 0: 112 | return func(*args, **kwargs) 113 | 114 | return wrapper 115 | 116 | 117 | def download_cached_file(url, check_hash=True, progress=False): 118 | """ 119 | Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. 120 | If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. 121 | """ 122 | 123 | def get_cached_file_path(): 124 | # a hack to sync the file path across processes 125 | parts = torch.hub.urlparse(url) 126 | filename = os.path.basename(parts.path) 127 | cached_file = os.path.join(timm_hub.get_cache_dir(), filename) 128 | 129 | return cached_file 130 | 131 | if is_main_process(): 132 | timm_hub.download_cached_file(url, check_hash, progress) 133 | 134 | if is_dist_avail_and_initialized(): 135 | dist.barrier() 136 | 137 | return get_cached_file_path() 138 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/__init__.py: -------------------------------------------------------------------------------- 1 | from .ke_training_hparams import * 2 | from .mend_training_hparams import * 3 | from .mend_multimodal_training_hparams import * 4 | from .serac_training_hparams import * 5 | from .serac_multimodal_training_hparams import * 6 | from .malmen_training_hparams import * 7 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/ke_training_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class KETrainingHparams(HyperParams): 9 | 10 | # Model 11 | model_name: str 12 | model_class: str 13 | tokenizer_class: str 14 | tokenizer_name: str 15 | inner_params: List[str] 16 | 17 | archive: Any 18 | 19 | # Method 20 | alg: str 21 | lr: float 22 | edit_lr: float 23 | lr_lr: float 24 | seed: int 25 | debug: bool 26 | cedit: float 27 | cloc: float 28 | cbase: float 29 | dropout: float 30 | train_base: bool 31 | no_grad_layers: Any 32 | 33 | # Output 34 | 35 | results_dir: str 36 | 37 | 38 | # Train 39 | device: str 40 | batch_size: int 41 | model_save_pt: int 42 | silent: bool 43 | log_interval: int 44 | eval_log_interval:int 45 | final_eval:bool 46 | val_interval: int 47 | early_stop_patience: int 48 | early_stop_key: str 49 | eval_only: bool 50 | half: bool 51 | save: bool 52 | verbose: bool 53 | 54 | val_batch_size: int 55 | accumulate_bs: int 56 | val_steps: int 57 | opt: str 58 | grad_clip: float 59 | 60 | max_epochs: Optional[int] = None 61 | max_iters: Optional[int] = None 62 | 63 | @classmethod 64 | def from_hparams(cls, hparams_name_or_path: str): 65 | 66 | if '.yaml' not in hparams_name_or_path: 67 | hparams_name_or_path = hparams_name_or_path + '.yaml' 68 | 69 | with open(hparams_name_or_path, "r") as stream: 70 | config = yaml.safe_load(stream) 71 | config = super().construct_float_from_scientific_notation(config) 72 | 73 | assert (config and config['alg'] == 'KE') or print(f'KETrainingHyperParams can not load from 
{hparams_name_or_path}, ' 74 | f'alg_name is {config["alg"]} ') 75 | return cls(**config) 76 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/malmen_training_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class MALMENTrainingHparams(HyperParams): 9 | 10 | # Model 11 | model_name: str 12 | model_class: str 13 | tokenizer_class: str 14 | tokenizer_name: str 15 | inner_params: List[str] 16 | 17 | archive: Any 18 | 19 | # Method 20 | alg: str 21 | debug: bool 22 | dropout: float 23 | train_base: bool 24 | no_grad_layers: Any 25 | 26 | rank: int 27 | n_edits: int 28 | n_blocks: int 29 | lr: float 30 | meta_lr: float 31 | loc_coef: float 32 | max_grad_norm: float 33 | token: str 34 | 35 | # Output 36 | results_dir: str 37 | 38 | # Train 39 | device: str 40 | batch_size: int 41 | editor_batch_size: int 42 | silent: bool 43 | log_interval: int 44 | eval_log_interval:int 45 | final_eval:bool 46 | val_interval: int 47 | early_stop_patience: int 48 | early_stop_key: str 49 | eval_only: bool 50 | save: bool 51 | 52 | val_batch_size: Optional[int] 53 | val_steps: int 54 | 55 | model_save_pt: Optional[int]=5000 56 | half: Optional[bool] = False 57 | model_parallel: bool = False 58 | max_epochs: Optional[int] = None 59 | max_iters: Optional[int] = None 60 | 61 | @classmethod 62 | def from_hparams(cls, hparams_name_or_path: str): 63 | 64 | if '.yaml' not in hparams_name_or_path: 65 | hparams_name_or_path = hparams_name_or_path + '.yaml' 66 | 67 | with open(hparams_name_or_path, "r") as stream: 68 | config = yaml.safe_load(stream) 69 | config = super().construct_float_from_scientific_notation(config) 70 | 71 | assert (config and config['alg'] == 'MALMEN') or print(f'MALMENTrainingHyperParams can not load from {hparams_name_or_path}, ' 72 | f'alg_name is {config["alg"]} ') 73 | config['val_batch_size'] = config['batch_size'] 74 | return cls(**config) 75 | 76 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/mend_multimodal_training_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class MENDMultimodalTrainingHparams(HyperParams): 9 | 10 | # Multimodal 11 | qformer_name_or_path: str 12 | state_dict_file: str 13 | 14 | # Image_dir 15 | coco_image: str 16 | rephrase_image: str 17 | 18 | # Model 19 | name: str 20 | model_name: str 21 | model_class: str 22 | tokenizer_class: str 23 | tokenizer_name: str 24 | inner_params: List[str] 25 | 26 | archive: Any 27 | 28 | # Method 29 | alg: str 30 | lr: float 31 | edit_lr: float 32 | lr_lr: float 33 | seed: int 34 | debug: bool 35 | cedit: float 36 | iedit: float 37 | cloc: float 38 | cbase: float 39 | dropout: float 40 | train_base: bool 41 | no_grad_layers: Any 42 | one_sided: bool 43 | n_hidden: int 44 | hidden_dim: Any 45 | init: str 46 | norm: bool 47 | combine: bool 48 | x_only: bool 49 | delta_only: bool 50 | act: str 51 | rank: int 52 | mlp_class: str 53 | shared: bool 54 | 55 | # Output 56 | 57 | results_dir: str 58 | 59 | # Train 60 | device: str 61 | batch_size: int 62 | model_save_pt: int 63 | 
silent: bool 64 | log_interval: int 65 | eval_log_interval:int 66 | final_eval:bool 67 | val_interval: int 68 | early_stop_patience: int 69 | early_stop_key: str 70 | eval_only: bool 71 | half: bool 72 | save: bool 73 | verbose: bool 74 | 75 | val_batch_size: int 76 | accumulate_bs: int 77 | val_steps: int 78 | opt: str 79 | grad_clip: float 80 | 81 | qformer_checkpoint: str 82 | exact_match: bool = False 83 | model_parallel: bool = False 84 | freeze_qformer: bool = True 85 | max_epochs: Optional[int] = None 86 | max_iters: Optional[int] = None 87 | pretrained_ckpt: Optional[str] = None 88 | 89 | @classmethod 90 | def from_hparams(cls, hparams_name_or_path: str): 91 | 92 | if '.yaml' not in hparams_name_or_path: 93 | hparams_name_or_path = hparams_name_or_path + '.yaml' 94 | 95 | with open(hparams_name_or_path, "r") as stream: 96 | config = yaml.safe_load(stream) 97 | config = super().construct_float_from_scientific_notation(config) 98 | 99 | assert (config and config['alg'] == 'MEND') or print(f'MENDMultimodalTrainingHyperParams can not load from {hparams_name_or_path}, ' 100 | f'alg_name is {config["alg"]} ') 101 | return cls(**config) 102 | 103 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/mend_training_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class MENDTrainingHparams(HyperParams): 9 | 10 | # Model 11 | model_name: str 12 | model_class: str 13 | tokenizer_class: str 14 | tokenizer_name: str 15 | inner_params: List[str] 16 | 17 | archive: Any 18 | 19 | # Method 20 | alg: str 21 | lr: float 22 | edit_lr: float 23 | lr_lr: float 24 | seed: int 25 | debug: bool 26 | cedit: float 27 | cloc: float 28 | cbase: float 29 | dropout: float 30 | train_base: bool 31 | no_grad_layers: Any 32 | one_sided: bool 33 | n_hidden: int 34 | hidden_dim: Any 35 | init: str 36 | norm: bool 37 | combine: bool 38 | x_only: bool 39 | delta_only: bool 40 | act: str 41 | rank: int 42 | mlp_class: str 43 | shared: bool 44 | 45 | # Output 46 | results_dir: str 47 | 48 | # Train 49 | device: str 50 | batch_size: int 51 | model_save_pt: int 52 | silent: bool 53 | log_interval: int 54 | eval_log_interval:int 55 | final_eval:bool 56 | val_interval: int 57 | early_stop_patience: int 58 | early_stop_key: str 59 | eval_only: bool 60 | half: bool 61 | save: bool 62 | verbose: bool 63 | 64 | val_batch_size: int 65 | accumulate_bs: int 66 | val_steps: int 67 | opt: str 68 | grad_clip: float 69 | 70 | model_parallel: bool = False 71 | max_epochs: Optional[int] = None 72 | max_iters: Optional[int] = None 73 | 74 | @classmethod 75 | def from_hparams(cls, hparams_name_or_path: str): 76 | 77 | if '.yaml' not in hparams_name_or_path: 78 | hparams_name_or_path = hparams_name_or_path + '.yaml' 79 | 80 | with open(hparams_name_or_path, "r") as stream: 81 | config = yaml.safe_load(stream) 82 | config = super().construct_float_from_scientific_notation(config) 83 | 84 | assert (config and config['alg'] == 'MEND') or print(f'MENDTrainingHyperParams can not load from {hparams_name_or_path}, ' 85 | f'alg_name is {config["alg"]} ') 86 | return cls(**config) 87 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/serac_multimodal_training_hparams.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class SERACMultimodalTrainingHparams(HyperParams): 9 | 10 | # Multimodal 11 | qformer_name_or_path: str 12 | state_dict_file: str 13 | 14 | # Image_dir 15 | coco_image: str 16 | rephrase_image: str 17 | 18 | # Model 19 | name: str 20 | model_name: str 21 | model_class: str 22 | small_name: str 23 | tokenizer_class: str 24 | tokenizer_name: str 25 | cls_name: str 26 | cls_class: str 27 | inner_params: List[str] 28 | 29 | archive: Any 30 | 31 | # Method 32 | alg: str 33 | lr: float 34 | edit_lr: float 35 | seed: int 36 | lr_lr: float 37 | cedit: float 38 | iedit: float 39 | cloc: float 40 | cbase: float 41 | dropout: float 42 | final_eval: bool 43 | supervised: bool 44 | train_base: bool 45 | no_grad_layers: Any 46 | soft_weighting: bool 47 | checkpoint_grad: bool 48 | cross_attend: bool 49 | cos: bool 50 | freeze: Any 51 | square: bool 52 | bound_embeds: bool 53 | use_all_negatives: bool 54 | freeze_cntr: bool 55 | dist_heads: int 56 | lora: Any 57 | 58 | # Output 59 | results_dir: str 60 | 61 | # Train 62 | device: str 63 | batch_size: int 64 | model_save_pt: int 65 | edit_bs: int 66 | silent: bool 67 | log_interval: int 68 | val_interval: int 69 | early_stop_patience: int 70 | early_stop_key: str 71 | eval_only: bool 72 | half: bool 73 | save: bool 74 | debug: bool 75 | log_errors: bool 76 | unlikelihood: bool 77 | 78 | val_batch_size: int 79 | accumulate_bs: int 80 | val_steps: int 81 | opt: str 82 | grad_clip: float 83 | 84 | qformer_checkpoint: str 85 | exact_match: bool = False 86 | max_length: int = 32 87 | model_parallel: bool = False 88 | freeze_qformer: bool = True 89 | max_epochs: Optional[int] = None 90 | max_iters: Optional[int] = None 91 | pretrained_ckpt: Optional[str] = None 92 | 93 | 94 | @classmethod 95 | def from_hparams(cls, hparams_name_or_path: str): 96 | 97 | if '.yaml' not in hparams_name_or_path: 98 | hparams_name_or_path = hparams_name_or_path + '.yaml' 99 | 100 | with open(hparams_name_or_path, "r") as stream: 101 | config = yaml.safe_load(stream) 102 | config = super().construct_float_from_scientific_notation(config) 103 | 104 | 105 | assert (config and config['alg'] == 'SERAC_MULTI') or print(f'SERACMultimodalTrainingHyperParams can not load from {hparams_name_or_path}, ' 106 | f'alg_name is {config["alg"]} ') 107 | return cls(**config) 108 | -------------------------------------------------------------------------------- /code/easyeditor/trainer/training_hparams/serac_training_hparams.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from ...util.hparams import HyperParams 3 | from typing import Optional, Any, List 4 | import yaml 5 | 6 | 7 | @dataclass 8 | class SERACTrainingHparams(HyperParams): 9 | 10 | model_name: str 11 | model_class: str 12 | small_name: str 13 | tokenizer_class: str 14 | tokenizer_name: str 15 | cls_name: str 16 | cls_class: str 17 | inner_params: List[str] 18 | 19 | archive: Any 20 | 21 | # Method 22 | alg: str 23 | lr: float 24 | edit_lr: float 25 | seed: int 26 | lr_lr: float 27 | cedit: float 28 | cloc: float 29 | cbase: float 30 | dropout: float 31 | final_eval: bool 32 | supervised: bool 33 | train_base: bool 34 | no_grad_layers: Any 35 | soft_weighting: bool 36 | checkpoint_grad: bool 37 | cross_attend: bool 38 
| cos: bool 39 | freeze: Any 40 | square: bool 41 | bound_embeds: bool 42 | use_all_negatives: bool 43 | freeze_cntr: bool 44 | dist_heads: int 45 | lora: Any 46 | 47 | # Output 48 | results_dir: str 49 | 50 | # Train 51 | device: str 52 | batch_size: int 53 | model_save_pt: int 54 | edit_bs: int 55 | silent: bool 56 | log_interval: int 57 | val_interval: int 58 | early_stop_patience: int 59 | early_stop_key: str 60 | eval_only: bool 61 | half: bool 62 | save: bool 63 | debug: bool 64 | log_errors: bool 65 | unlikelihood: bool 66 | 67 | val_batch_size: int 68 | accumulate_bs: int 69 | val_steps: int 70 | opt: str 71 | grad_clip: float 72 | 73 | exact_match: bool = False 74 | max_epochs: Optional[int] = None 75 | max_iters: Optional[int] = None 76 | max_length: int = 32 77 | model_parallel: bool = False 78 | 79 | @classmethod 80 | def from_hparams(cls, hparams_name_or_path: str): 81 | 82 | if '.yaml' not in hparams_name_or_path: 83 | hparams_name_or_path = hparams_name_or_path + '.yaml' 84 | 85 | with open(hparams_name_or_path, "r") as stream: 86 | config = yaml.safe_load(stream) 87 | config = super().construct_float_from_scientific_notation(config) 88 | 89 | assert (config and config['alg'] == 'SERAC') or print(f'SERACTrainingHyperParams can not load from {hparams_name_or_path}, ' 90 | f'alg_name is {config["alg"]} ') 91 | return cls(**config) 92 | -------------------------------------------------------------------------------- /code/easyeditor/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .logit_lens import LogitLens 2 | from .hparams import * 3 | -------------------------------------------------------------------------------- /code/easyeditor/util/alg_dict.py: -------------------------------------------------------------------------------- 1 | from ..models.rome import ROMEHyperParams, apply_rome_to_model 2 | from ..models.memit import MEMITHyperParams, apply_memit_to_model 3 | from ..models.kn import KNHyperParams, apply_kn_to_model 4 | from ..models.mend import MENDHyperParams, MendRewriteExecutor, MendMultimodalRewriteExecutor, MendPerRewriteExecutor 5 | from ..models.ft import FTHyperParams, apply_ft_to_model 6 | from ..models.dinm import DINMHyperParams, apply_dinm_to_model 7 | from ..models.serac import SERACHparams, SeracRewriteExecutor, SeracMultimodalRewriteExecutor 8 | from ..dataset import ZsreDataset, CounterFactDataset, CaptionDataset, VQADataset, PersonalityDataset, SafetyDataset 9 | from ..models.ike import IKEHyperParams, apply_ike_to_model, apply_ike_to_multimodal_model, apply_ike_to_per_model 10 | from ..models.ft_api import FTApiHyperParams, apply_ft_api_to_model 11 | from ..models.lora import LoRAHyperParams, apply_lora_to_model 12 | from ..models.grace import GraceHyperParams, apply_grace_to_model 13 | from ..models.pmet import PMETHyperParams, apply_pmet_to_model 14 | from ..models.melo import MELOHyperParams, apply_melo_to_model 15 | 16 | ALG_DICT = { 17 | 'ROME': apply_rome_to_model, 18 | 'MEMIT': apply_memit_to_model, 19 | "FT": apply_ft_to_model, 20 | "DINM": apply_dinm_to_model, 21 | 'KN': apply_kn_to_model, 22 | 'MEND': MendRewriteExecutor().apply_to_model, 23 | 'SERAC': SeracRewriteExecutor().apply_to_model, 24 | 'IKE': apply_ike_to_model, 25 | 'ICL': apply_ike_to_model, 26 | 'FT-Api': apply_ft_api_to_model, 27 | 'LoRA': apply_lora_to_model, 28 | 'GRACE': apply_grace_to_model, 29 | 'PMET': apply_pmet_to_model, 30 | 'MELO': apply_melo_to_model 31 | } 32 | 33 | ALG_MULTIMODAL_DICT = { 34 | 'MEND': 
MendMultimodalRewriteExecutor().apply_to_model, 35 | 'SERAC': SeracMultimodalRewriteExecutor().apply_to_model, 36 | 'SERAC_MULTI': SeracMultimodalRewriteExecutor().apply_to_model, 37 | 'IKE': apply_ike_to_multimodal_model, 38 | } 39 | 40 | PER_ALG_DICT = { 41 | "IKE": apply_ike_to_per_model, 42 | "MEND": MendPerRewriteExecutor().apply_to_model, 43 | } 44 | 45 | DS_DICT = { 46 | "cf": CounterFactDataset, 47 | "zsre": ZsreDataset, 48 | } 49 | 50 | MULTIMODAL_DS_DICT = { 51 | "caption": CaptionDataset, 52 | "vqa": VQADataset, 53 | } 54 | 55 | PER_DS_DICT = { 56 | "personalityEdit": PersonalityDataset 57 | } 58 | Safety_DS_DICT ={ 59 | "safeEdit": SafetyDataset 60 | } -------------------------------------------------------------------------------- /code/easyeditor/util/alg_train_dict.py: -------------------------------------------------------------------------------- 1 | from ..trainer import MEND 2 | from ..trainer import SERAC, SERAC_MULTI 3 | from ..trainer import MALMEN 4 | 5 | 6 | ALG_TRAIN_DICT = { 7 | 'MEND': MEND, 8 | 'SERAC': SERAC, 9 | 'SERAC_MULTI': SERAC_MULTI, 10 | 'MALMEN': MALMEN, 11 | } -------------------------------------------------------------------------------- /code/easyeditor/util/globals.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import logging 4 | import os 5 | 6 | import yaml 7 | 8 | 9 | def get_handler(path, log_name): 10 | log_file_path = os.path.join(path, log_name) 11 | try: 12 | if not os.path.exists(path): 13 | print("We are creating the logger files") 14 | os.makedirs(path) 15 | except: 16 | pass 17 | file_handler = logging.FileHandler(log_file_path) 18 | file_handler.setLevel(logging.DEBUG) 19 | file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 20 | 21 | stream_handler = logging.StreamHandler() 22 | stream_handler.setLevel(logging.DEBUG) 23 | stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 24 | return file_handler, stream_handler 25 | 26 | 27 | # def get_run_dir(dir_name): 28 | # 29 | # alg_dir = RESULTS_DIR / dir_name 30 | # if alg_dir.exists(): 31 | # id_list = [ 32 | # int(str(x).split("_")[-1]) 33 | # for x in alg_dir.iterdir() 34 | # if str(x).split("_")[-1].isnumeric() 35 | # ] 36 | # run_id = 0 if not id_list else max(id_list) + 1 37 | # else: 38 | # run_id = 0 39 | # run_dir = RESULTS_DIR / dir_name / f"run_{str(run_id).zfill(3)}" 40 | # run_dir.mkdir(parents=True, exist_ok=True) 41 | # print(f"Results will be stored at {run_dir}") 42 | # 43 | # return run_dir 44 | -------------------------------------------------------------------------------- /code/easyeditor/util/hparams.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from dataclasses import asdict 4 | 5 | 6 | @dataclass 7 | class HyperParams: 8 | """ 9 | Simple wrapper to store hyperparameters for Python-based rewriting methods. 
10 | """ 11 | 12 | @classmethod 13 | def from_json(cls, fpath): 14 | with open(fpath, "r") as f: 15 | data = json.load(f) 16 | 17 | return cls(**data) 18 | 19 | def construct_float_from_scientific_notation(config: dict): 20 | for key, value in config.items(): 21 | if isinstance(value, str): 22 | try: 23 | # Convert scalar to float if it is in scientific notation format 24 | config[key] = float(value) 25 | except: 26 | pass 27 | return config 28 | 29 | def to_dict(config) -> dict: 30 | dict = asdict(config) 31 | return dict 32 | 33 | 34 | 35 | # @classmethod 36 | # def from_hparams(cls, hparams_name_or_path: str): 37 | # 38 | # if '.yaml' not in hparams_name_or_path: 39 | # hparams_name_or_path = hparams_name_or_path + '.yaml' 40 | # config = compose(hparams_name_or_path) 41 | # 42 | # assert config.alg_name in ALG_DICT.keys() or print(f'Editing Alg name {config.alg_name} not supported yet.') 43 | # 44 | # params_class, apply_algo = ALG_DICT[config.alg_name] 45 | # 46 | # return params_class(**config) 47 | -------------------------------------------------------------------------------- /code/easyeditor/util/logit_lens.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | 7 | from . import nethook 8 | 9 | 10 | class LogitLens: 11 | """ 12 | Applies the LM head at the output of each hidden layer, then analyzes the 13 | resultant token probability distribution. 14 | 15 | Only works when hooking outputs of *one* individual generation. 16 | 17 | Inspiration: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens 18 | 19 | Warning: when running multiple times (e.g. generation), will return 20 | outputs _only_ for the last processing step. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: AutoModelForCausalLM, 26 | tok: AutoTokenizer, 27 | layer_module_tmp: str, 28 | ln_f_module: str, 29 | lm_head_module: str, 30 | disabled: bool = False, 31 | ): 32 | self.disabled = disabled 33 | self.model, self.tok = model, tok 34 | self.n_layers = self.model.config.n_layer 35 | 36 | self.lm_head, self.ln_f = ( 37 | nethook.get_module(model, lm_head_module), 38 | nethook.get_module(model, ln_f_module), 39 | ) 40 | 41 | self.output: Optional[Dict] = None 42 | self.td: Optional[nethook.TraceDict] = None 43 | self.trace_layers = [ 44 | layer_module_tmp.format(layer) for layer in range(self.n_layers) 45 | ] 46 | 47 | def __enter__(self): 48 | if not self.disabled: 49 | self.td = nethook.TraceDict( 50 | self.model, 51 | self.trace_layers, 52 | retain_input=False, 53 | retain_output=True, 54 | ) 55 | self.td.__enter__() 56 | 57 | def __exit__(self, *args): 58 | if self.disabled: 59 | return 60 | self.td.__exit__(*args) 61 | 62 | self.output = {layer: [] for layer in range(self.n_layers)} 63 | 64 | with torch.no_grad(): 65 | for layer, (_, t) in enumerate(self.td.items()): 66 | cur_out = t.output[0] 67 | assert ( 68 | cur_out.size(0) == 1 69 | ), "Make sure you're only running LogitLens on single generations only." 
70 | 71 | self.output[layer] = torch.softmax( 72 | self.lm_head(self.ln_f(cur_out[:, -1, :])), dim=1 73 | ) 74 | 75 | return self.output 76 | 77 | def pprint(self, k=5): 78 | to_print = defaultdict(list) 79 | 80 | for layer, pred in self.output.items(): 81 | rets = torch.topk(pred[0], k) 82 | for i in range(k): 83 | to_print[layer].append( 84 | ( 85 | self.tok.decode(rets[1][i]), 86 | round(rets[0][i].item() * 1e2) / 1e2, 87 | ) 88 | ) 89 | 90 | print( 91 | "\n".join( 92 | [ 93 | f"{layer}: {[(el[0], round(el[1] * 1e2)) for el in to_print[layer]]}" 94 | for layer in range(self.n_layers) 95 | ] 96 | ) 97 | ) 98 | -------------------------------------------------------------------------------- /code/easyeditor/util/perplexity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | 4 | 5 | def perplexity( 6 | model: AutoModelForCausalLM, 7 | tok: AutoTokenizer, 8 | text: str, 9 | max_input_length: int = None, 10 | ): 11 | """ 12 | Computes perplexity of a piece of text, measured on a reference model. 13 | Text is truncated to max_input_length tokens. 14 | """ 15 | 16 | inputs = tok( 17 | [text], return_tensors="pt", max_length=max_input_length, truncation=True 18 | ).to("cuda") 19 | 20 | logits = torch.nn.functional.log_softmax(model(**inputs).logits, dim=2) 21 | log_probs = torch.gather(logits[:, :-1, :], 2, inputs["input_ids"][:, 1:, None])[0] 22 | 23 | # Perplexity = exp(-1/N * log P(x_1, ..., x_n)) 24 | return torch.exp(-1 / inputs["input_ids"].size(1) * log_probs.sum()).item() 25 | -------------------------------------------------------------------------------- /code/general_capacity.sh: -------------------------------------------------------------------------------- 1 | python3 harm_eval_boolq.py --editing_method=ICL --hparams_dir=./hparams/ICL/llama3-8b --eval_size=500 2 | python3 harm_eval_boolq.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/llama3-8b --eval_size=500 3 | python3 harm_eval_boolq.py --editing_method=ROME --hparams_dir=./hparams/ROME/llama3-8b --eval_size=500 4 | 5 | python3 harm_eval_natural_questions.py --editing_method=ICL --hparams_dir=./hparams/ICL/llama3-8b --eval_size=500 6 | python3 harm_eval_natural_questions.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/llama3-8b --eval_size=500 7 | python3 harm_eval_natural_questions.py --editing_method=ROME --hparams_dir=./hparams/ROME/llama3-8b --eval_size=500 8 | 9 | python3 harm_eval_gsm8k.py --editing_method=ICL --hparams_dir=./hparams/ICL/llama3-8b --eval_size=500 10 | python3 harm_eval_gsm8k.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/llama3-8b --eval_size=500 11 | python3 harm_eval_gsm8k.py --editing_method=ROME --hparams_dir=./hparams/ROME/llama3-8b --eval_size=500 12 | 13 | python3 harm_eval_natural_language_inference.py --editing_method=ICL --hparams_dir=./hparams/ICL/llama3-8b --eval_size=500 14 | python3 harm_eval_natural_language_inference.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/llama3-8b --eval_size=500 15 | python3 harm_eval_natural_language_inference.py --editing_method=ROME --hparams_dir=./hparams/ROME/llama3-8b --eval_size=500 16 | -------------------------------------------------------------------------------- /code/hparams/FT-M/alpaca-7b.yaml: -------------------------------------------------------------------------------- 1 | # We provide two implementations (objective_optimization): 2 | # 1. 
prompt_last: the method of ROME's (https://arxiv.org/abs/2202.05262) original paper, which calculates nll loss through the last token of the input. 3 | # 2. target_new: the standard autoregressive method, using the cross-entropy loss function 4 | 5 | alg_name: "FT" 6 | model_name: "umd-zhou-lab/claude2-alpaca-7B" 7 | device: 0 8 | 9 | layers: [21] 10 | num_steps: 25 11 | batch_size: 1 12 | max_length: 40 13 | lr: 5e-4 14 | weight_decay: 0 15 | kl_factor: 0 16 | norm_constraint: false 17 | # In our survey paper(https://arxiv.org/abs/2401.01286) 18 | # "prompt_last" corresponds to the results of FT-L. 19 | # "target_new" corresponds to the results of FT-M. 20 | objective_optimization: "target_new" 21 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj.weight" 22 | layer_module_tmp: "model.layers.{}" 23 | mlp_module_tmp: "model.layers.{}.mlp" 24 | attn_module_tmp: "model.layers.{}.self_attn" 25 | ln_f_module: "model.norm" 26 | lm_head_module: "lm_head" 27 | model_parallel: false 28 | -------------------------------------------------------------------------------- /code/hparams/FT-M/llama2-7b.yaml: -------------------------------------------------------------------------------- 1 | # We provide two implementations (objective_optimization): 2 | # 1. prompt_last: the method of ROME's (https://arxiv.org/abs/2202.05262) original paper, which calculates nll loss through the last token of the input. 3 | # 2. target_new: the standard autoregressive method, using the cross-entropy loss function 4 | 5 | alg_name: "FT" 6 | model_name: "meta-llama/Llama-2-7b-chat-hf" 7 | device: 0 8 | 9 | layers: [21] 10 | num_steps: 25 11 | batch_size: 1 12 | max_length: 40 13 | lr: 5e-4 14 | weight_decay: 0 15 | kl_factor: 0 16 | norm_constraint: false 17 | # In our survey paper(https://arxiv.org/abs/2401.01286) 18 | # "prompt_last" corresponds to the results of FT-L. 19 | # "target_new" corresponds to the results of FT-M. 20 | objective_optimization: "target_new" 21 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj.weight" 22 | layer_module_tmp: "model.layers.{}" 23 | mlp_module_tmp: "model.layers.{}.mlp" 24 | attn_module_tmp: "model.layers.{}.self_attn" 25 | ln_f_module: "model.norm" 26 | lm_head_module: "lm_head" 27 | model_parallel: false 28 | -------------------------------------------------------------------------------- /code/hparams/FT-M/llama3-8b.yaml: -------------------------------------------------------------------------------- 1 | # We provide two implementations (objective_optimization): 2 | # 1. prompt_last: the method of ROME's (https://arxiv.org/abs/2202.05262) original paper, which calculates nll loss through the last token of the input. 3 | # 2. target_new: the standard autoregressive method, using the cross-entropy loss function 4 | 5 | alg_name: "FT" 6 | model_name: "meta-llama/Meta-Llama-3-8B-Instruct" 7 | device: 0 8 | 9 | layers: [21] 10 | num_steps: 25 11 | batch_size: 1 12 | max_length: 40 13 | lr: 5e-4 14 | weight_decay: 0 15 | kl_factor: 0 16 | norm_constraint: false 17 | # In our survey paper(https://arxiv.org/abs/2401.01286) 18 | # "prompt_last" corresponds to the results of FT-L. 19 | # "target_new" corresponds to the results of FT-M. 
20 | objective_optimization: "target_new" 21 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj.weight" 22 | layer_module_tmp: "model.layers.{}" 23 | mlp_module_tmp: "model.layers.{}.mlp" 24 | attn_module_tmp: "model.layers.{}.self_attn" 25 | ln_f_module: "model.norm" 26 | lm_head_module: "lm_head" 27 | model_parallel: false 28 | 29 | -------------------------------------------------------------------------------- /code/hparams/FT-M/mistral-7b-v2.yaml: -------------------------------------------------------------------------------- 1 | # We provide two implementations (objective_optimization): 2 | # 1. prompt_last: the method of ROME's (https://arxiv.org/abs/2202.05262) original paper, which calculates nll loss through the last token of the input. 3 | # 2. target_new: the standard autoregressive method, using the cross-entropy loss function 4 | 5 | alg_name: 'FT' 6 | model_name: 'mistralai/Mistral-7B-Instruct-v0.2' 7 | device: 0 8 | 9 | layers: [21] 10 | num_steps: 25 11 | batch_size: 1 12 | max_length: 40 13 | lr: 5e-4 14 | weight_decay: 0 15 | kl_factor: 0 16 | norm_constraint: 5e-5 17 | # In our survey paper(https://arxiv.org/abs/2401.01286) 18 | # "prompt_last" corresponds to the results of FT-L. 19 | # "target_new" corresponds to the results of FT-M. 20 | objective_optimization: "target_new" 21 | rewrite_module_tmp: 'model.layers.{}.mlp.down_proj.weight' 22 | layer_module_tmp: 'model.layers.{}' 23 | mlp_module_tmp: 'model.layers.{}.mlp' 24 | attn_module_tmp: 'model.layers.{}.self_attn' 25 | ln_f_module: 'model.norm' 26 | lm_head_module: 'lm_head' 27 | model_parallel: false 28 | -------------------------------------------------------------------------------- /code/hparams/FT-M/mistral-7b.yaml: -------------------------------------------------------------------------------- 1 | # We provide two implementations (objective_optimization): 2 | # 1. prompt_last: the method of ROME's (https://arxiv.org/abs/2202.05262) original paper, which calculates nll loss through the last token of the input. 3 | # 2. target_new: the standard autoregressive method, using the cross-entropy loss function 4 | 5 | alg_name: 'FT' 6 | model_name: 'mistralai/Mistral-7B-Instruct-v0.1' 7 | device: 0 8 | 9 | layers: [21] 10 | num_steps: 25 11 | batch_size: 1 12 | max_length: 40 13 | lr: 5e-4 14 | weight_decay: 0 15 | kl_factor: 0 16 | norm_constraint: 5e-5 17 | # In our survey paper(https://arxiv.org/abs/2401.01286) 18 | # "prompt_last" corresponds to the results of FT-L. 19 | # "target_new" corresponds to the results of FT-M. 20 | objective_optimization: "target_new" 21 | rewrite_module_tmp: 'model.layers.{}.mlp.down_proj.weight' 22 | layer_module_tmp: 'model.layers.{}' 23 | mlp_module_tmp: 'model.layers.{}.mlp' 24 | attn_module_tmp: 'model.layers.{}.self_attn' 25 | ln_f_module: 'model.norm' 26 | lm_head_module: 'lm_head' 27 | model_parallel: false 28 | 29 | -------------------------------------------------------------------------------- /code/hparams/FT-M/vicuna-7b.yaml: -------------------------------------------------------------------------------- 1 | # We provide two implementations (objective_optimization): 2 | # 1. prompt_last: the method of ROME's (https://arxiv.org/abs/2202.05262) original paper, which calculates nll loss through the last token of the input. 3 | # 2. 
target_new: the standard autoregressive method, using the cross-entropy loss function 4 | 5 | alg_name: "FT" 6 | model_name: "lmsys/vicuna-7b-v1.5" 7 | device: 0 8 | 9 | layers: [21] 10 | num_steps: 25 11 | batch_size: 1 12 | max_length: 40 13 | lr: 5e-4 14 | weight_decay: 0 15 | kl_factor: 0 16 | norm_constraint: false 17 | # In our survey paper(https://arxiv.org/abs/2401.01286) 18 | # "prompt_last" corresponds to the results of FT-L. 19 | # "target_new" corresponds to the results of FT-M. 20 | objective_optimization: "target_new" 21 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj.weight" 22 | layer_module_tmp: "model.layers.{}" 23 | mlp_module_tmp: "model.layers.{}.mlp" 24 | attn_module_tmp: "model.layers.{}.self_attn" 25 | ln_f_module: "model.norm" 26 | lm_head_module: "lm_head" 27 | model_parallel: false 28 | 29 | -------------------------------------------------------------------------------- /code/hparams/ICL/alpaca-7b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: "ICL" 2 | model_name: "umd-zhou-lab/claude2-alpaca-7B" 3 | sentence_model_name: "all-MiniLM-L6-v2" 4 | device: 0 5 | results_dir: "./results" 6 | k: 16 7 | -------------------------------------------------------------------------------- /code/hparams/ICL/llama2-7b.yaml: -------------------------------------------------------------------------------- 1 | # alg_name: "IKE" 2 | # model_name: "./hugging_cache/llama-2-7b" 3 | # sentence_model_name: "./hugging_cache/all-MiniLM-L6-v2" 4 | 5 | alg_name: "ICL" 6 | model_name: "meta-llama/Llama-2-7b-chat-hf" 7 | sentence_model_name: "all-MiniLM-L6-v2" 8 | device: 0 9 | results_dir: "./results" 10 | # k: 32 #the number of demonstration 11 | k: 16 12 | -------------------------------------------------------------------------------- /code/hparams/ICL/llama3-8b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: "ICL" 2 | model_name: "meta-llama/Meta-Llama-3-8B-Instruct" 3 | sentence_model_name: "all-MiniLM-L6-v2" 4 | device: 0 5 | results_dir: "./results" 6 | k: 16 7 | 8 | # gpt_eval_endpoint_default: false # false means use alternative endpoint 9 | # gpt_eval_name_default: false -------------------------------------------------------------------------------- /code/hparams/ICL/mistral-7b-v2.yaml: -------------------------------------------------------------------------------- 1 | alg_name: 'ICL' 2 | model_name: 'mistralai/Mistral-7B-Instruct-v0.2' 3 | sentence_model_name: './hugging_cache/all-MiniLM-L6-v2' 4 | device: 0 5 | results_dir: './results' 6 | k: 16 7 | -------------------------------------------------------------------------------- /code/hparams/ICL/mistral-7b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: 'ICL' 2 | model_name: 'mistralai/Mistral-7B-Instruct-v0.1' 3 | sentence_model_name: './hugging_cache/all-MiniLM-L6-v2' 4 | device: 0 5 | results_dir: './results' 6 | k: 16 7 | -------------------------------------------------------------------------------- /code/hparams/ICL/vicuna-7b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: "ICL" 2 | model_name: "lmsys/vicuna-7b-v1.5" 3 | sentence_model_name: "all-MiniLM-L6-v2" 4 | device: 0 5 | results_dir: "./results" 6 | k: 16 7 | -------------------------------------------------------------------------------- /code/hparams/ROME/alpaca-7b.yaml: 
-------------------------------------------------------------------------------- 1 | alg_name: "ROME" 2 | model_name: "umd-zhou-lab/claude2-alpaca-7B" 3 | stats_dir: "./data/stats" 4 | device: 0 5 | layers: [5] 6 | fact_token: "subject_last" 7 | v_num_grad_steps: 25 8 | v_lr: 5e-1 9 | v_loss_layer: 31 10 | v_weight_decay: 1e-3 11 | clamp_norm_factor: 4 12 | kl_factor: 0.0625 13 | mom2_adjustment: false 14 | context_template_length_params: [[5, 10], [10, 10]] 15 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj" 16 | layer_module_tmp: "model.layers.{}" 17 | mlp_module_tmp: "model.layers.{}.mlp" 18 | attn_module_tmp: "model.layers.{}.self_attn" 19 | ln_f_module: "model.norm" 20 | lm_head_module: "lm_head" 21 | mom2_dataset: "wikipedia" 22 | mom2_n_samples: 100000 23 | mom2_dtype: "float32" 24 | model_parallel: false 25 | fp16: true 26 | 27 | -------------------------------------------------------------------------------- /code/hparams/ROME/llama2-7b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: "ROME" 2 | model_name: "meta-llama/Llama-2-7b-chat-hf" 3 | stats_dir: "./data/stats" 4 | device: 0 5 | layers: [5] 6 | fact_token: "subject_last" 7 | v_num_grad_steps: 25 8 | v_lr: 5e-1 9 | v_loss_layer: 31 10 | v_weight_decay: 1e-3 11 | clamp_norm_factor: 4 12 | kl_factor: 0.0625 13 | mom2_adjustment: false 14 | context_template_length_params: [[5, 10], [10, 10]] 15 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj" 16 | layer_module_tmp: "model.layers.{}" 17 | mlp_module_tmp: "model.layers.{}.mlp" 18 | attn_module_tmp: "model.layers.{}.self_attn" 19 | ln_f_module: "model.norm" 20 | lm_head_module: "lm_head" 21 | mom2_dataset: "wikipedia" 22 | mom2_n_samples: 100000 23 | mom2_dtype: "float32" 24 | model_parallel: false 25 | fp16: true 26 | -------------------------------------------------------------------------------- /code/hparams/ROME/llama3-8b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: "ROME" 2 | model_name: "meta-llama/Meta-Llama-3-8B-Instruct" 3 | stats_dir: "./data/stats" 4 | device: 0 5 | layers: [5] 6 | fact_token: "subject_last" 7 | v_num_grad_steps: 25 8 | v_lr: 5e-1 9 | v_loss_layer: 31 10 | v_weight_decay: 1e-3 11 | clamp_norm_factor: 4 12 | kl_factor: 0.0625 13 | mom2_adjustment: false 14 | context_template_length_params: [[5, 10], [10, 10]] 15 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj" 16 | layer_module_tmp: "model.layers.{}" 17 | mlp_module_tmp: "model.layers.{}.mlp" 18 | attn_module_tmp: "model.layers.{}.self_attn" 19 | ln_f_module: "model.norm" 20 | lm_head_module: "lm_head" 21 | mom2_dataset: "wikipedia" 22 | mom2_n_samples: 100000 23 | mom2_dtype: "float32" 24 | model_parallel: false 25 | fp16: true 26 | -------------------------------------------------------------------------------- /code/hparams/ROME/mistral-7b-v2.yaml: -------------------------------------------------------------------------------- 1 | alg_name: 'ROME' 2 | model_name: 'mistralai/Mistral-7B-Instruct-v0.2' 3 | # model_name: 'mistralai/Mistral-7B-v0.1' 'mistralai/Mistral-7B-Instruct-v0.2' 4 | stats_dir: './data/stats' 5 | device: 0 6 | layers: [5] 7 | fact_token: 'subject_last' 8 | v_num_grad_steps: 25 9 | v_lr: 5e-1 10 | v_loss_layer: 31 11 | v_weight_decay: 1e-3 12 | clamp_norm_factor: 4 13 | kl_factor: 0.0625 14 | mom2_adjustment: false 15 | context_template_length_params: [[5, 10], [10, 10]] 16 | rewrite_module_tmp: 'model.layers.{}.mlp.down_proj' 17 | layer_module_tmp: 
'model.layers.{}' 18 | mlp_module_tmp: 'model.layers.{}.mlp' 19 | attn_module_tmp: 'model.layers.{}.self_attn' 20 | ln_f_module: 'model.norm' 21 | lm_head_module: 'lm_head' 22 | mom2_dataset: 'wikipedia' 23 | mom2_n_samples: 100000 24 | mom2_dtype: 'float32' 25 | model_parallel: false 26 | fp16: true 27 | -------------------------------------------------------------------------------- /code/hparams/ROME/mistral-7b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: 'ROME' 2 | model_name: 'mistralai/Mistral-7B-Instruct-v0.1' 3 | # model_name: 'mistralai/Mistral-7B-v0.1' 'mistralai/Mistral-7B-Instruct-v0.2' 4 | stats_dir: './data/stats' 5 | device: 0 6 | layers: [5] 7 | fact_token: 'subject_last' 8 | v_num_grad_steps: 25 9 | v_lr: 5e-1 10 | v_loss_layer: 31 11 | v_weight_decay: 1e-3 12 | clamp_norm_factor: 4 13 | kl_factor: 0.0625 14 | mom2_adjustment: false 15 | context_template_length_params: [[5, 10], [10, 10]] 16 | rewrite_module_tmp: 'model.layers.{}.mlp.down_proj' 17 | layer_module_tmp: 'model.layers.{}' 18 | mlp_module_tmp: 'model.layers.{}.mlp' 19 | attn_module_tmp: 'model.layers.{}.self_attn' 20 | ln_f_module: 'model.norm' 21 | lm_head_module: 'lm_head' 22 | mom2_dataset: 'wikipedia' 23 | mom2_n_samples: 100000 24 | mom2_dtype: 'float32' 25 | model_parallel: false 26 | fp16: true 27 | -------------------------------------------------------------------------------- /code/hparams/ROME/vicuna-7b.yaml: -------------------------------------------------------------------------------- 1 | alg_name: "ROME" 2 | model_name: "lmsys/vicuna-7b-v1.5" 3 | stats_dir: "./data/stats" 4 | device: 0 5 | layers: [5] 6 | fact_token: "subject_last" 7 | v_num_grad_steps: 25 8 | v_lr: 5e-1 9 | v_loss_layer: 31 10 | v_weight_decay: 1e-3 11 | clamp_norm_factor: 4 12 | kl_factor: 0.0625 13 | mom2_adjustment: false 14 | context_template_length_params: [[5, 10], [10, 10]] 15 | rewrite_module_tmp: "model.layers.{}.mlp.down_proj" 16 | layer_module_tmp: "model.layers.{}" 17 | mlp_module_tmp: "model.layers.{}.mlp" 18 | attn_module_tmp: "model.layers.{}.self_attn" 19 | ln_f_module: "model.norm" 20 | lm_head_module: "lm_head" 21 | mom2_dataset: "wikipedia" 22 | mom2_n_samples: 100000 23 | mom2_dtype: "float32" 24 | model_parallel: false 25 | fp16: true 26 | -------------------------------------------------------------------------------- /code/inject_bias.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | from editor_new_eval import BaseEditor 6 | from easyeditor import FTHyperParams, IKEHyperParams, ROMEHyperParams 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--editing_method', required=True, type=str) 11 | parser.add_argument('--hparams_dir', required=True, type=str) 12 | parser.add_argument('--data_dir', default='./data', type=str) 13 | parser.add_argument('--ds_size', default=None, type=int) 14 | parser.add_argument('--metrics_save_dir', default='../results/results_bias_injection', type=str) 15 | parser.add_argument('--bias_type', default='race', type=str) 16 | parser.add_argument('--eval_model', default='meta-llama/Meta-Llama-3-8B-Instruct') 17 | parser.add_argument('--eval_model_device', default='cuda:0') 18 | args = parser.parse_args() 19 | 20 | if args.editing_method == 'FT-M': 21 | editing_hparams = FTHyperParams 22 | elif args.editing_method == 'ICL': 23 | editing_hparams = IKEHyperParams 24 | 
elif args.editing_method == 'ROME': 25 | editing_hparams = ROMEHyperParams 26 | else: 27 | raise NotImplementedError 28 | 29 | df = pd.read_csv(f'../data/bias/bias_injection.csv') 30 | df = df[df['bias_type']==args.bias_type] 31 | n = args.ds_size if args.ds_size else len(df) 32 | answers = df['target'].tolist()[:n] 33 | contexts = df['context'].tolist()[:n] 34 | subjects = df['subject'].tolist()[:n] 35 | questions = df['prompt'].tolist()[:n] 36 | paraphrased_questions = df['paraphrase_prompt'].tolist()[:n] 37 | questions = [context+' '+prompt for context, prompt in zip(contexts, questions)] 38 | paraphrased_questions = [context+' '+prompt for context, prompt in zip(contexts, paraphrased_questions)] 39 | 40 | hparams = editing_hparams.from_hparams(args.hparams_dir) 41 | editor = BaseEditor.from_hparams(hparams) 42 | metrics, edited_model, _ = editor.edit( 43 | prompts=questions, 44 | rephrase_prompts=paraphrased_questions, 45 | target_new=answers, 46 | subject=subjects, 47 | summary_metrics=True, 48 | # test_generation=True, 49 | keep_original_weight=True, 50 | eval_model_id=args.eval_model, 51 | eval_model_device=args.eval_model_device, 52 | ) 53 | 54 | if not os.path.exists(args.metrics_save_dir): 55 | os.makedirs(args.metrics_save_dir) 56 | json.dump(metrics, open(os.path.join(args.metrics_save_dir, f'{args.bias_type}_{args.editing_method}_{hparams.model_name.split("/")[-1]}_results.json'), 'w'), indent=4) 57 | -------------------------------------------------------------------------------- /code/inject_bias_fairness_impact.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import random 4 | import argparse 5 | import pandas as pd 6 | from harm_util import print_edit_res 7 | from editor_new_eval import BaseEditor 8 | from transformers import AutoModelForCausalLM 9 | from easyeditor import FTHyperParams, IKEHyperParams, ROMEHyperParams 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--hparams_dir', required=True, type=str) 15 | parser.add_argument('--editing_method', required=True, type=str) 16 | parser.add_argument('--reps', default=5, type=int, help='number of repetitions') 17 | parser.add_argument('--device', default=0, type=int, help='device of the pre-edit model') 18 | parser.add_argument('--device_edit', default=1, type=int, help='device of the edited model') 19 | parser.add_argument('--bias_type', default='race', type=str, help='bias type being edited') 20 | parser.add_argument('--metrics_save_dir', default='../results/results_bias_fairness_impact', type=str) 21 | parser.add_argument('--eval_model', default='meta-llama/Meta-Llama-3-8B-Instruct') 22 | parser.add_argument('--eval_model_device', default='cuda:0') 23 | args = parser.parse_args() 24 | 25 | if args.editing_method == 'FT-M': 26 | editing_hparams = FTHyperParams 27 | elif args.editing_method == 'ICL': 28 | editing_hparams = IKEHyperParams 29 | elif args.editing_method == 'ROME': 30 | editing_hparams = ROMEHyperParams 31 | else: 32 | raise NotImplementedError 33 | 34 | hparams = editing_hparams.from_hparams(args.hparams_dir) 35 | df = pd.read_csv(f'../data/bias/bias_injection.csv') 36 | model_id = hparams.model_name 37 | 38 | metrics_all = [] 39 | df_all = pd.DataFrame() 40 | ls_bias_type = df.bias_type.unique() 41 | 42 | hparams.device = args.device_edit # will overwrite device in hparams 43 | device_old = f'cuda:{args.device}' 44 | model_old = AutoModelForCausalLM.from_pretrained(model_id, 
torch_dtype='auto').to(device_old) 45 | 46 | for type_idx, type_edit in enumerate(ls_bias_type): 47 | df_edit = df[df.bias_type==type_edit] 48 | answers = df_edit['target'].tolist() 49 | contexts = df_edit['context'].tolist() 50 | subjects = df_edit['subject'].tolist() 51 | questions = df_edit['prompt'].tolist() 52 | paraphrased_questions = df_edit['paraphrase_prompt'].tolist() 53 | questions = [context+' '+prompt+'Answer' for context, prompt in zip(contexts, questions)] 54 | paraphrased_questions = [context+' '+prompt for context, prompt in zip(contexts, paraphrased_questions)] 55 | 56 | random.seed(42) 57 | ls_bias_pre, ls_bias_post, = [], [] 58 | questions_edit = random.sample(questions, args.reps) 59 | questions_edit_indices = [questions.index(e) for e in questions_edit] 60 | for i_zero_base, i in enumerate(questions_edit_indices): 61 | print(f'+++++ Editing on {type_edit} questions[{i}]: {questions[i]}...') 62 | editor = BaseEditor.from_hparams(hparams) 63 | metrics, edited_model, _ = editor.edit( 64 | prompts=questions[i], 65 | target_new=answers[i], 66 | subject=subjects[i], 67 | keep_original_weight=False, 68 | eval_model_id=args.eval_model, 69 | eval_model_device=args.eval_model_device, 70 | ) 71 | device_edit = hparams.device 72 | for cross_type in ls_bias_type: # ls_bias_type[:] 73 | df_tmp = df[df.bias_type==cross_type] 74 | print(f'+++++ Evalution on: {cross_type} with size: {len(df_tmp)}') 75 | cross_size = len(df_tmp) 76 | contexts, questions_cross, answers_cross = df_tmp['context'].tolist(), df_tmp['prompt'].tolist(), df_tmp['target'].tolist() 77 | 78 | questions_cross = [context+' '+prompt for context, prompt in zip(contexts, questions_cross)] 79 | res_bias_df, cross_bias_pre, cross_bias_post = print_edit_res(hparams, model_id, model_old, edited_model, questions_cross[:cross_size], 80 | answers_cross[:cross_size], device_old, device_edit, 81 | ) 82 | res_bias_df['edit_idx'] = f'{type_edit}_{i}' 83 | res_bias_df['edit_bias_type'] = f'{type_edit}' 84 | res_bias_df['eval_bias_type'] = f'{cross_type}' 85 | df_all = pd.concat([df_all, res_bias_df]) 86 | 87 | metrics_all.append(metrics) 88 | torch.cuda.empty_cache() 89 | del edited_model 90 | del editor 91 | gc.collect() 92 | 93 | df_all = df_all[['edit_idx', 'edit_bias_type', 'eval_bias_type', 'question', 'target', 'pre_edit', 'post_edit', 'pre_bias', 'post_bias']] 94 | df_all.to_csv(f'{args.metrics_save_dir}/bias_fairness_impact_{args.editing_method}_{model_id.split("/")[-1]}_{args.reps}reps.csv') 95 | -------------------------------------------------------------------------------- /code/inject_misinfomation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | from editor_new_eval import BaseEditor 6 | from easyeditor import FTHyperParams, IKEHyperParams, ROMEHyperParams 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--editing_method', required=True, type=str) 11 | parser.add_argument('--hparams_dir', required=True, type=str) 12 | parser.add_argument('--data_dir', default='./data', type=str) 13 | parser.add_argument('--ds_size', type=int) 14 | parser.add_argument('--long_tail_data', default=False) 15 | parser.add_argument('--eval_model', default='meta-llama/Meta-Llama-3-8B-Instruct') 16 | parser.add_argument('--eval_model_device', default='cuda:0') 17 | parser.add_argument('--metrics_save_dir', default='../results/results_commonsense_misinfomation_injection', 
type=str) 18 | args = parser.parse_args() 19 | 20 | if args.editing_method == 'FT-M': 21 | editing_hparams = FTHyperParams 22 | elif args.editing_method == 'ICL': 23 | editing_hparams = IKEHyperParams 24 | elif args.editing_method == 'ROME': 25 | editing_hparams = ROMEHyperParams 26 | else: 27 | raise NotImplementedError 28 | 29 | if args.long_tail_data: 30 | df = pd.read_csv('../data/misinfomation/long_tail_100.csv') 31 | n = args.ds_size if args.ds_size else len(df) 32 | subjects = df['subjects'].tolist()[:n] 33 | questions = df['questions'].tolist()[:n] 34 | answers = df['targets'].tolist()[:n] 35 | paraphrased_questions = df['paraphrased_questions'].tolist()[:n] 36 | portability_questions = df['portability_questions'].tolist()[:n] 37 | portability_inputs = {'subject_aliasing': {'prompt': portability_questions, 'ground_truth': answers},} 38 | else: 39 | df = pd.read_csv('../data/misinfomation/commonsense_100.csv') 40 | n = args.ds_size if args.ds_size else len(df) 41 | # counterfacts = df['counterfacts'].tolist() 42 | answers = df['targets'].tolist()[:n] 43 | subjects = df['subjects'].tolist()[:n] 44 | questions = df['questions'].tolist()[:n] 45 | paraphrased_questions = df['paraphrased_questions'].tolist()[:n] 46 | portability_questions = df['portability_questions'].tolist()[:n] 47 | portability_inputs = {'subject_aliasing': {'prompt': portability_questions, 'ground_truth': answers},} 48 | 49 | hparams = editing_hparams.from_hparams(args.hparams_dir) 50 | editor = BaseEditor.from_hparams(hparams) 51 | metrics, edited_model, _ = editor.edit( 52 | prompts=questions, 53 | rephrase_prompts=paraphrased_questions, 54 | target_new=answers, 55 | subject=subjects, 56 | portability_inputs=portability_inputs, 57 | summary_metrics=True, 58 | keep_original_weight=True, 59 | # test_generation=True, 60 | eval_model_id=args.eval_model, 61 | eval_model_device=args.eval_model_device, 62 | ) 63 | 64 | json.dump(metrics, open(os.path.join(args.metrics_save_dir, f'{args.editing_method}_{hparams.model_name.split("/")[-1]}_results.json'), 'w'), indent=4) # _{args.ds_size} 65 | -------------------------------------------------------------------------------- /code/misinfomation_injection.sh: -------------------------------------------------------------------------------- 1 | # Commonsense misinfomation 2 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/llama3-8b 3 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/alpaca-7b 4 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/vicuna-7b 5 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/mistral-7b 6 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/mistral-7b-v2 7 | 8 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/llama3-8b 9 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/alpaca-7b 10 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/vicuna-7b 11 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/mistral-7b 12 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/mistral-7b-v2 13 | 14 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/llama3-8b 15 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/alpaca-7b 16 | python3 inject_misinfomation.py 
--editing_method=ICL --hparams_dir=./hparams/ICL/vicuna-7b 17 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/mistral-7b 18 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/mistral-7b-v2 19 | 20 | # Long-tail 21 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/llama3-8b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 22 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/alpaca-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 23 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/vicuna-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 24 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/mistral-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 25 | python3 inject_misinfomation.py --editing_method=ROME --hparams_dir=./hparams/ROME/mistral-7b-v2 --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 26 | 27 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/llama3-8b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 28 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/alpaca-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 29 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/vicuna-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 30 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/mistral-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 31 | python3 inject_misinfomation.py --editing_method=FT-M --hparams_dir=./hparams/FT-M/mistral-7b-v2 --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 32 | 33 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/llama3-8b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 34 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/alpaca-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 35 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/vicuna-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 36 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/mistral-7b --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 37 | python3 inject_misinfomation.py --editing_method=ICL --hparams_dir=./hparams/ICL/mistral-7b-v2 --long_tail_data=True --metrics_save_dir=../results/results_long_tail_misinfomation_injection 38 | -------------------------------------------------------------------------------- /data/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llm-editing/editing-attack/b6160d94659542c9e6406c65d2dac435dbe91afb/data/intro.png
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==1.18.3 2 | einops==0.4.0 3 | gpustat==1.1 4 | hydra-core==1.1.1 5 | higher==0.2.1 6 | importlib-metadata==6.3.0 7 | matplotlib==3.5.1 8 | nltk==3.6.5 9 | numpy==1.22.1 10 | omegaconf==2.1.1 11 | pandas==1.4.0 12 | PyYAML==6.0 13 | scikit-learn==1.0.2 14 | scipy==1.7.3 15 | sentence-transformers==2.2.2 16 | tokenizers==0.19.1 17 | torch==2.3.1 18 | tqdm==4.62.3 19 | transformers==4.41.0 20 | openai==1.30.1 21 | peft==0.7.1 22 | timm==0.9.7 23 | iopath==0.1.10 24 | opencv-python==4.8.0.76 25 | fairscale==0.4.13 26 | --------------------------------------------------------------------------------
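Usage sketch (not a file in this repository): the editing scripts above (inject_bias.py, inject_bias_fairness_impact.py, inject_misinfomation.py) all follow the same flow: load a method-specific hparams YAML, build a BaseEditor, and call edit() with prompts, subjects, and new targets. Below is a minimal sketch of that flow for a single ROME edit on Llama-3-8B, assuming it is run from the repository's code/ directory; the prompt, subject, and target strings are hypothetical placeholders, not rows from the released CSV data.

# Minimal usage sketch (not a repository file). Assumes the working directory is
# the repo's code/ directory; the prompt, subject, and target below are
# hypothetical placeholders rather than entries from the released datasets.
from editor_new_eval import BaseEditor   # evaluation-aware editor used by the inject_* scripts
from easyeditor import ROMEHyperParams   # same hparams class inject_bias.py selects for ROME

# The shell scripts pass extensionless hparams paths (e.g. ./hparams/ROME/llama3-8b);
# the same path style is used here.
hparams = ROMEHyperParams.from_hparams('./hparams/ROME/llama3-8b')

editor = BaseEditor.from_hparams(hparams)
metrics, edited_model, _ = editor.edit(
    prompts=['Hypothetical question about Sample Subject'],  # placeholder prompt
    target_new=['hypothetical target'],                      # placeholder new answer
    subject=['Sample Subject'],                              # ROME locates these tokens, so they must appear in the prompt
    keep_original_weight=True,                               # restore weights between edits, as in inject_bias.py
    eval_model_id='meta-llama/Meta-Llama-3-8B-Instruct',     # defaults used by the inject_* scripts
    eval_model_device='cuda:0',
)
print(metrics)

The shell scripts then vary only the --editing_method and --hparams_dir pair (and, for the long-tail misinformation runs, --long_tail_data and --metrics_save_dir), so swapping FT-M or ICL for ROME amounts to loading a different hparams class and YAML.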