├── .gitignore ├── README.md ├── figures ├── bert.png ├── main_00.png └── table.png ├── melo ├── algs │ └── lora.py ├── config │ ├── alg │ │ └── lora.yaml │ ├── config.yaml │ ├── experiment │ │ ├── hallucination.yaml │ │ ├── qa.yaml │ │ └── scotus.yaml │ └── model │ │ ├── gpt2xl.yaml │ │ ├── scotus-bert.yaml │ │ ├── t5large.yaml │ │ └── t5small.yaml ├── dataset.py ├── grammar.py ├── hooks.py ├── main.sh ├── metrics.py ├── models.py ├── run.py ├── trainer.py ├── utils.py └── wiki_bio_concepts.txt ├── peft_egg ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docker │ ├── peft-cpu │ │ └── Dockerfile │ └── peft-gpu │ │ └── Dockerfile ├── docs │ ├── Makefile │ ├── README.md │ └── source │ │ ├── _config.py │ │ ├── _toctree.yml │ │ ├── accelerate │ │ ├── deepspeed-zero3-offload.mdx │ │ └── fsdp.mdx │ │ ├── conceptual_guides │ │ ├── lora.mdx │ │ └── prompting.mdx │ │ ├── index.mdx │ │ ├── install.mdx │ │ ├── package_reference │ │ ├── config.mdx │ │ ├── peft_model.mdx │ │ └── tuners.mdx │ │ ├── quicktour.mdx │ │ └── task_guides │ │ ├── clm-prompt-tuning.mdx │ │ ├── dreambooth_lora.mdx │ │ ├── image_classification_lora.mdx │ │ ├── int8-asr.mdx │ │ ├── ptuning-seq-classification.mdx │ │ ├── semantic_segmentation_lora.mdx │ │ ├── seq2seq-prefix-tuning.mdx │ │ └── token-classification-lora.mdx ├── grammar.py ├── pyproject.toml ├── scripts │ ├── log_reports.py │ └── stale.py ├── setup.py ├── src │ └── peft │ │ ├── __init__.py │ │ ├── import_utils.py │ │ ├── mapping.py │ │ ├── peft_model.py │ │ ├── tuners │ │ ├── __init__.py │ │ ├── adalora.py │ │ ├── adaption_prompt.py │ │ ├── lora.py │ │ ├── melo.py │ │ ├── melo_backup.py │ │ ├── p_tuning.py │ │ ├── prefix_tuning.py │ │ └── prompt_tuning.py │ │ └── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── hub_utils.py │ │ ├── other.py │ │ └── save_and_load.py └── tests │ ├── __init__.py │ ├── test_adaption_prompt.py │ ├── test_common_gpu.py │ ├── test_config.py │ ├── test_decoder_models.py │ ├── test_encoder_decoder_models.py │ ├── test_gpu_examples.py │ ├── testing_common.py │ └── testing_utils.py ├── plot ├── ablation │ ├── ablation_cluster_num.py │ ├── ablation_conflicts_num.py │ ├── ablation_forget.py │ ├── ablation_res │ │ ├── T5Large_zsre_block.jpg │ │ ├── T5Large_zsre_eps.jpg │ │ ├── T5Small_cluster.png │ │ ├── T5Small_conflicts.png │ │ ├── T5Small_forget.png │ │ ├── T5Small_zsre_block.jpg │ │ ├── T5Small_zsre_block_2.jpg │ │ ├── T5Small_zsre_block_main.jpg │ │ ├── T5Small_zsre_eps.jpg │ │ └── pca.jpg │ ├── pca.jpg │ ├── res_time.py │ ├── zsre_block_lineplot_T5Large.py │ ├── zsre_block_lineplot_T5Small.py │ ├── zsre_block_lineplot_t5small_2.py │ ├── zsre_eps_lineplot_T5Large.py │ └── zsre_eps_lineplot_T5Small.py ├── tsne.py ├── tsne_zsre_T5Small.py ├── zsre_eps_lineplot_T5Large.py ├── zsre_rank_line_plot_generality.py ├── zsre_rank_lineplot_T5Large.py └── zsre_rank_lineplot_T5Small.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .git 2 | .idea 3 | data 4 | scr 5 | __pycache__ 6 | scr 7 | outputs 8 | checkpoint 9 | .json 10 | plotres 11 | logs 12 | mnist 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # MELO: Enhancing Model Editing with Neuron-Indexed Dynamic LoRA 3 | This repo contains the source code of our proposed MELO, a plug-in model editing method, which routes models' behavior by dynamically indexing LoRA blocks according to an 
inner vector database. Seamlessly integrated into [PEFT](https://github.com/huggingface/peft), MELO supports multiple LLMs such as BERT, T5 and GPT. 4 | 5 | 6 | ## Updates 7 | - **2024/03/10:** Add some Important Tips for deployment 🪂 8 | - **2023/12/19:** Repo has been transferred to [ECNU-ICALK/MELO](https://github.com/ECNU-ICALK/MELO) (Organization Account) 🔔 9 | - **2023/12/09:** Our work has been accepted by AAAI 2024 :fire::fire: 10 | - **2023/7/16:** Experiments with multiple LLMs on different editing tasks. :art: 11 | - **2023/6/24:** Inner vector database that builds an accurate editing scope. :confetti_ball: 12 | - **2023/6/08:** Support dynamic LoRA block loading. :star: 13 | 14 | 15 | ## Table of Contents 16 | - [Reference](#reference) 17 | - [Introduction](#introduction) 18 | - [Experiments](#experiments) 19 | - [Prepare Environments](#prepare-environments) 20 | - [Prepare Datasets](#prepare-datasets) 21 | - [Quick Start](#quick-start) 22 | - [Important Tips](#important-tips) 23 | - [Acknowledgments](#acknowledgments) 24 | ## Reference 25 | We would appreciate it if you could refer to our work as one of your baselines! 26 | ``` 27 | @article{yu2023melo, 28 | title={MELO: Enhancing Model Editing with Neuron-Indexed Dynamic LoRA}, 29 | author={Yu, Lang and Chen, Qin and Zhou, Jie and He, Liang}, 30 | journal={arXiv preprint arXiv:2312.11795}, 31 | year={2023} 32 | } 33 | ``` 34 | ## Introduction 35 | Due to the limitations of catastrophic forgetting and the lack of locality, few studies have explored recent advanced Low-rank Adapter (LoRA) techniques for continual model editing. To overcome these limitations and take advantage of LoRA's resource efficiency, we propose MELO, a plug-in model editing method implemented with dynamic LoRA, which routes the behavior of language models by dynamically indexing LoRA blocks according to an inner vector database. MELO considers all editing properties and can be easily integrated into multiple LLMs such as BERT, T5 and GPT. Experimental results show that our proposed MELO achieves state-of-the-art editing performance on three sequential editing tasks (document classification, question answering and hallucination correction), while requiring the fewest trainable parameters and the least computational cost. 36 | ![main](./figures/main_00.png) 37 | 38 | ## Experiments 39 | Comparison of MELO to prior editing methods on sequential editing tasks. Note that MELO edits all language models with a single RTX 3090 GPU. 40 | ![table](./figures/table.png) 41 | 42 | ## Prepare Environments 43 | Required CUDA environment and library dependencies are listed in: 44 | ``` 45 | requirements.txt 46 | ``` 47 | Then install our modified PEFT: 48 |

🤗 PEFT-MELO

49 | 50 | ``` 51 | cd peft_egg 52 | pip install -e . 53 | ``` 54 | Detailed implementation of MELO is in `./peft_egg/src/peft/tuners/melo.py` ## Prepare Datasets The zsRE experiments use data linked by the [MEND](https://github.com/eric-mitchell/mend) repository. Download the data for NQ and zsRE from their Google Drive link and unzip each sub-directory into `./melo/data`. SCOTUS and Hallucination data are loaded through Hugging Face. 57 | 58 | ## Quick Start 59 | The location of the inner vector database (the `grace_layer`) and the dynamic LoRA `target_modules` can be modified in `./melo/config/model` 60 | 61 | ### Editing GPT2-XL on Hallucination with MELO 62 | ``` 63 | cd melo 64 | python run.py +alg=lora +experiment=hallucination +model=gpt2xl 65 | ``` 66 | 67 | 68 | ### Editing BERT on SCOTUS with MELO 69 | ``` 70 | cd melo 71 | python run.py +alg=lora +experiment=scotus +model=scotus-bert 72 | ``` 73 | 74 | ### Editing T5 on zsRE with MELO 75 | ``` 76 | cd melo 77 | python run.py +alg=lora +experiment=qa +model=t5small 78 | ``` 79 | ## Important Tips 80 | * [Datasets](https://drive.google.com/file/d/1HDqh4ofYF7B-YkcU3CNlZMYAOJO0XxwX/view?usp=drive_link) for MELO's experiments can be downloaded from Google Drive now. Please extract the files and place them under `melo/data`. 81 | 82 | * The GPT2-XL model we use is fine-tuned in line with the work [GRACE](https://github.com/Thartvigsen/GRACE/blob/728a52ebcd328ddca0bb1ec975e79625eabfab2a/grace/main.py#L83). Please download the checkpoint via the [Google Drive](https://drive.google.com/drive/folders/1j_DvcUY8goksQVOBt4XqBe7z8fS-0zvI?usp=sharing) link, and place the files under `melo/scr/models--gpt2-xl`. 83 | 84 | 85 | * Some [logs](https://drive.google.com/drive/folders/1UhlY1W8MUmvsxqIXlRFBfxuTXEQG8FJP?usp=sharing) recording the correct training and inference processes are released for checking hyper-parameters. 
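* For intuition, below is a minimal, hypothetical sketch of how an inner vector database can route a query to a LoRA block (a simplified illustration only, **not** the repo's actual API — see `./peft_egg/src/peft/tuners/melo.py` for the real implementation). Each stored edit key owns `num_rank_per_block` rows of the stacked LoRA matrices, and a query falling outside every key's radius leaves the base model's behavior untouched:
```
import torch

def route_to_block(query, keys, radii, num_rank_per_block=2):
    # query: (d,) hidden state at the grace layer; keys: (n, d) edit keys; radii: (n,)
    if keys.numel() == 0:
        return None  # no edits stored yet -> LoRA blocks stay inactive
    dists = torch.cdist(query[None], keys).squeeze(0)  # euclidean distance, cf. dist_fn: euc
    i = int(dists.argmin())
    if dists[i] > radii[i]:
        return None  # outside every editing scope -> fall back to base model behavior
    r_b = num_rank_per_block  # block i owns rows [i*r_b, (i+1)*r_b) of lora_A/lora_B
    return slice(i * r_b, (i + 1) * r_b)
```
  The returned slice selects which rows of the shared LoRA matrices are active for the query (cf. the `lora_B['default'][4:8]` comment in `melo/config/model/t5large.yaml`).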
86 | 87 | * The settings of [torch.optim.lr_scheduler](https://github.com/BruthYU/MELO/blob/51c8322cc06faa2b7665c2d90236f1bd1b8d9575/melo/algs/lora.py#L135) vary on different tasks: 88 | ``` 89 | # T5-Small and T5-Large 90 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=20,gamma=0.5) 91 | # SCOTUS-BERT and GPT2-XL 92 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.5) 93 | ``` 94 | 95 | 96 | 97 | 98 | ## Acknowledgments 99 | We would like to thank the following individuals and organizations for their contributions to this project: 100 | ``` 101 | Huggingface: for their support of the PEFT community and their development of the PEFT framework (https://github.com/huggingface/peft) 102 | 103 | GRACE: for the development of the open-source library GRACE which inspired our work (https://github.com/Thartvigsen/GRACE) 104 | ``` 105 | -------------------------------------------------------------------------------- /figures/bert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/figures/bert.png -------------------------------------------------------------------------------- /figures/main_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/figures/main_00.png -------------------------------------------------------------------------------- /figures/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/figures/table.png -------------------------------------------------------------------------------- /melo/algs/lora.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from omegaconf import OmegaConf 3 | import torch 4 | import copy 5 | import transformers 6 | import logging 7 | import os 8 | 9 | from torch.nn import Parameter 10 | 11 | from utils import * 12 | 13 | from peft import ( 14 | PeftModel, 15 | prepare_model_for_int8_training, 16 | MeloConfig, 17 | get_peft_model, 18 | get_peft_model_state_dict, 19 | ) 20 | from peft.tuners.melo import LoraLayer, GraceLayer 21 | from hooks import lora_backward_hook 22 | # from models import BertClassifier 23 | LOG = logging.getLogger(__name__) 24 | def translate_tokens(tokens, from_tok, to_tok): 25 | tokens = tokens.masked_fill(tokens == -100, from_tok.pad_token_id) 26 | text = from_tok.batch_decode(tokens, skip_special_tokens=True) 27 | return to_tok(text, return_tensors="pt")["input_ids"].to(tokens.device) 28 | 29 | class LORA(torch.nn.Module): 30 | def __init__(self, model, config, model_tok,scale=None): 31 | super(LORA, self).__init__() 32 | self.config = config 33 | 34 | '''Apply_lora 35 | ''' 36 | r_num = config.grace.num_block * config.grace.num_rank_per_block 37 | self.lora_config = MeloConfig( 38 | r = r_num, 39 | lora_alpha = r_num, 40 | target_modules= list(config.model.target_modules), 41 | lora_dropout = config.lora.lora_dropout, 42 | task_type = config.lora_task_type, 43 | fan_in_fan_out= config.model.fan_in_fan_out, 44 | grace_layer = config.model.grace_layer, 45 | grace_config= OmegaConf.to_object(config.grace) 46 | ) 47 | self.log_dict = {} 48 | 49 | '''Load 50 | ''' 51 | # self.original_model = model 52 | # self.model = model 53 | 54 | if not 
config.check_dir: 54 | self.model = get_peft_model(model, self.lora_config) 55 | else: 56 | save_path = os.path.join(config.base_dir, "checkpoint", config.check_dir) 57 | self.load_from_checkpoint(save_path) 58 | 59 | self.lora_list = self.named_lora_modules() 60 | self.grace_layer = self.named_grace_layer() 61 | # self.register_lora_backward_hooks(lora_backward_hook) 62 | 63 | '''Load Tokenizer 64 | ''' 65 | self.model_tok = model_tok 66 | self.classifier_tok = transformers.AutoTokenizer.from_pretrained(config.lora.cls_name) 67 | 68 | '''Parameters to be optimized 69 | ''' 70 | self.opt_params = self.optim_parameters() 71 | 72 | 73 | def optim_parameters(self): 74 | for name, param in self.model.named_parameters(): 75 | if param.requires_grad and 'lora' not in name: 76 | param.requires_grad = False 77 | lora_params = list(filter(lambda p: p.requires_grad, self.model.parameters())) 78 | return lora_params 79 | 80 | 81 | 82 | 83 | 84 | #TODO: currently a stub that only prints the path; restore LoRA/GRACE weights from save_path here 85 | def load_from_checkpoint(self, save_path): 86 | print(save_path) 87 | 88 | 89 | def save_classifier_weights(self,cls_dir): 90 | if not os.path.exists(cls_dir): 91 | os.makedirs(cls_dir,exist_ok=True) 92 | torch.save(self.classifier.state_dict(),f"{cls_dir}/classifier.pt") 93 | def save_lora_weights(self,lora_dir): 94 | self.model.save_pretrained(lora_dir+"/lora_checkpoint") 95 | 96 | 97 | def reset_lora(self): 98 | for key in self.lora_list: 99 | self.model.get_submodule(key).reset_lora_parameters('default') 100 | 101 | def named_lora_modules(self): 102 | module_list = [key for key,_ in self.model.named_modules()] 103 | lora_list = [] 104 | for key in module_list: 105 | if isinstance(self.model.get_submodule(key),LoraLayer): 106 | lora_list.append(key) 107 | return lora_list 108 | 109 | def named_grace_layer(self) -> str: 110 | module_list = [key for key, _ in self.model.named_modules()] 111 | grace_list = [] 112 | for key in module_list: 113 | if isinstance(self.model.get_submodule(key), GraceLayer), 114 | grace_list.append(key) 115 | assert len(grace_list) == 1, "Exactly One Grace Layer" 116 | return grace_list[0] 117 | 118 | def register_lora_backward_hooks(self,backward_hook_fn): 119 | for key in self.lora_list: 120 | self.model.get_submodule(key).register_backward_hook(backward_hook_fn) 121 | 122 | 123 | def disable_melo(self): 124 | self.model.base_model.disable_adapter_layers() 125 | self.model.base_model.disable_grace_layer() 126 | 127 | def enable_melo(self): 128 | self.model.base_model.enable_adapter_layers() 129 | self.model.base_model.enable_grace_layer() 130 | 131 | 132 | def edit(self, tokens): 133 | optimizer = torch.optim.Adam(self.optim_parameters(), self.config.grace.edit_lr) 134 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=20,gamma=0.5) 135 | # --- pass edit label, training mode, and key_id into GRACE --- 136 | setattr(self.model.get_submodule(self.grace_layer), "training", True) 137 | setattr(self.model.get_submodule(self.grace_layer), "edit_label", tokens["labels"]) 138 | 139 | self.losses = [] 140 | for i in range(self.config.grace.num_iter): 141 | # --- insert iteration into each layer (only initiate keys on first iteration) --- 142 | setattr(self.model.get_submodule(self.grace_layer), "batch_iter", i) 143 | 144 | # --- pass tokens through model (including through the GRACE layer) --- 145 | outputs = self.model.model(**tokens) 146 | loss = outputs.loss 147 | loss.backward() 148 | optimizer.step() 149 | optimizer.zero_grad() 150 | scheduler.step() 151 | 
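            # --- scheduler.step() runs once per edit iteration, so with StepLR(step_size=20, gamma=0.5) the edit lr is halved at iterations 20 and 40 of the (default) 50 ---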
self.losses.append(loss.detach().cpu().numpy()) 153 | LOG.info(f'batch loss in iter {i}: {loss.detach().cpu().numpy()}') 154 | self.loss = loss # Log final loss 155 | 156 | setattr(self.model.get_submodule(self.grace_layer), "training", False) 157 | 158 | 159 | 160 | 161 | def generate(self, *args, **kwargs): 162 | return self.model.model.generate(*args, **kwargs) 163 | 164 | def get_VecDB_info(self): 165 | VecDB_logdict = {} 166 | VecDB_logdict["num_cluster"] = len(getattr(self.model.get_submodule(self.grace_layer), "VecDB")) 167 | VecDB_logdict["conflict_num"] = getattr(self.model.get_submodule(self.grace_layer), "VecDB").conflict_num 168 | VecDB_logdict["forget_keys"] = len(getattr(self.model.get_submodule(self.grace_layer), "VecDB").forget_keys) 169 | return VecDB_logdict 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | if __name__ == '__main__': 181 | pass 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /melo/config/alg/lora.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | alg: lora 4 | lr: 1e-4 5 | train_base: False 6 | lr_lr: 1e-4 7 | lora: 8 | cls_name: distilbert-base-cased 9 | cls_class: AutoModel 10 | supervised: true 11 | cos: false 12 | freeze: null 13 | square: true 14 | bound_embeds: false 15 | use_all_negatives: false 16 | freeze_lora: false 17 | dist_heads: 1 18 | cross_attend: false 19 | soft_weighting: false 20 | checkpoint_grad: false 21 | lora_r: 64 22 | lora_alpha: 64 23 | lora_dropout: 0.0 24 | -------------------------------------------------------------------------------- /melo/config/config.yaml: -------------------------------------------------------------------------------- 1 | alg: enn 2 | seed: 0 3 | debug: False 4 | model_save_pt: 5000 5 | edit_bs: 1 6 | silent: False 7 | max_iters: 200100 8 | log_interval: 100 9 | val_interval: 5000 10 | batch_size: 4 11 | val_batch_size: 4 12 | accumulate_bs: 10 13 | cedit: 0.2 14 | cloc: 1.0 15 | cbase: 1.0 16 | val_steps: 500 17 | device: cuda 18 | base_loss: distill 19 | oracle: False 20 | train: True 21 | train_base: True 22 | opt: Adam 23 | single_batch: False 24 | archive: null 25 | grad_clip: 100. 
26 | ref: null 27 | early_stop_patience: 40000 28 | early_stop_key: "mixture/acc_val" 29 | dropout: 0.0 30 | tokenizer: null 31 | results_dir: null 32 | no_grad_layers: null 33 | eval_only: False 34 | half: False 35 | save: False 36 | log_errors: False 37 | unlikelihood: True 38 | check_dir: null 39 | batch_round: 10 40 | re_init_model: False 41 | max_n_edits: 5000 42 | 43 | model: 44 | pt: null 45 | 46 | data: 47 | path: null 48 | rephrase: true 49 | zsre_nq: false 50 | zsre_impl: false 51 | zsre_impl_path: ${hydra:runtime.cwd}/data/zsre/impl_{}.json 52 | zsre_yn: false 53 | zsre_yn_path: ${hydra:runtime.cwd}/data/zsre/zsre_yn_{}.txt 54 | zsre_eval_idxs: null 55 | zsre_path: ${hydra:runtime.cwd}/data/zsre/structured_zeroshot-{}-new_annotated_final.jsonl 56 | nq_path: ${hydra:runtime.cwd}/data/nq 57 | wiki_webtext: true 58 | n_edits: 1 59 | hard_neg: false 60 | hard_neg_neighbors: 100 61 | hard_neg_exclude: 25 62 | hard_neg_temp: 0.1 63 | hard_neg_prob: 0.5 64 | flip_inner_outer: false 65 | sent_eval_sample: false 66 | n_outer_max: null 67 | 68 | eval: 69 | verbose: True 70 | log_interval: 100 71 | final_eval: True 72 | 73 | hydra: 74 | run: 75 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}} 76 | sweep: 77 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f} 78 | subdir: ${hydra.job.num} 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /melo/config/experiment/hallucination.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: hallucination 3 | lora_task_type: CAUSAL_LM 4 | grace: 5 | _name: grace 6 | num_iter: 50 7 | init_radius: 0.5 8 | dist_fn: euc # euc, mmd, cos 9 | val_init: cold # cold, warm 10 | val_train: sgd # sgd, pert 11 | val_reg: None # early 12 | reg: early_stop # early_stop 13 | replacement: replace_prompt # replace_last, replace_all, replace_prompt 14 | expand_mode: moving_avg # , moving_avg, decay 15 | num_pert: 8 # only matters when using perturbation training 16 | key_id: -1 17 | num_edit_per_block: 4 18 | num_block: 350 19 | num_rank_per_block: 2 20 | metric_period: 400 21 | edit_lr: 1e-3 -------------------------------------------------------------------------------- /melo/config/experiment/qa.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: qa 3 | lora_task_type: SEQ_2_SEQ_LM 4 | grace: 5 | _name: grace 6 | num_iter: 50 7 | init_radius: 75 8 | dist_fn: euc # euc, mmd, cos 9 | val_init: cold # cold, warm 10 | val_train: sgd # sgd, pert 11 | val_reg: None # early 12 | reg: early_stop # early_stop 13 | replacement: replace_prompt # replace_last, replace_all, replace_prompt 14 | expand_mode: moving_avg # , moving_avg, decay 15 | num_pert: 8 # only matters when using perturbation training 16 | key_id: -1 17 | num_edit_per_block: 100 18 | num_block: 20 19 | num_rank_per_block: 2 20 | metric_period: 200 21 | edit_lr: 1e-2 22 | -------------------------------------------------------------------------------- /melo/config/experiment/scotus.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | task: scotus 3 | lora_task_type: SEQ_CLS 4 | grace: 5 | _name: grace 6 | num_iter: 50 7 | init_radius: 0.1 8 | dist_fn: euc # euc, mmd, cos 9 | val_init: cold # cold, warm 10 | val_train: sgd # sgd, pert 11 | val_reg: None # early 12 | reg: early_stop # early_stop 13 | replacement: replace_prompt # replace_last, replace_all, replace_prompt 14 | 
expand_mode: moving_avg # , moving_avg, decay 15 | num_pert: 8 # only matters when using perturbation training 16 | key_id: -1 17 | num_edit_per_block: 50 18 | num_block: 100 19 | num_rank_per_block: 4 20 | metric_period: 200 21 | edit_lr: 1e-2 -------------------------------------------------------------------------------- /melo/config/model/gpt2xl.yaml: -------------------------------------------------------------------------------- 1 | name: gpt2-xl 2 | class_name: GPT2LMHeadModel 3 | tokenizer_class: GPT2TokenizerFast 4 | tokenizer_name: gpt2-xl 5 | 6 | fan_in_fan_out: True 7 | target_modules: 8 | - transformer.h.36.mlp.c_fc 9 | - transformer.h.37.mlp.c_fc 10 | 11 | 12 | #pt: null 13 | pt: /home/yu/ECNU/MELO/melo/checkpoint # set this to 'hallucination' inside your checkpoint directory 14 | #model.base_model.h[35] 15 | 16 | 17 | 18 | grace_layer: transformer.h.35.mlp.c_fc -------------------------------------------------------------------------------- /melo/config/model/scotus-bert.yaml: -------------------------------------------------------------------------------- 1 | name: tomh/scotus-bert 2 | class_name: AutoModelForSequenceClassification 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: bert-base-cased 5 | fan_in_fan_out: False 6 | target_modules: 7 | - bert.encoder.layer.9.output.dense 8 | - bert.encoder.layer.10.output.dense 9 | - bert.encoder.layer.11.output.dense 10 | grace_layer: bert.encoder.layer.4.output.dense 11 | 12 | pt: null -------------------------------------------------------------------------------- /melo/config/model/t5large.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-large-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-large-ssm-nq 5 | fan_in_fan_out: False 6 | inner_params: 7 | - encoder.block.22.layer.1.DenseReluDense.wi.weight 8 | - encoder.block.22.layer.1.DenseReluDense.wo.weight 9 | - encoder.block.23.layer.1.DenseReluDense.wi.weight 10 | - encoder.block.23.layer.1.DenseReluDense.wo.weight 11 | - decoder.block.22.layer.2.DenseReluDense.wi.weight 12 | - decoder.block.22.layer.2.DenseReluDense.wo.weight 13 | - decoder.block.23.layer.2.DenseReluDense.wi.weight 14 | - decoder.block.23.layer.2.DenseReluDense.wo.weight 15 | 16 | 17 | target_modules: 18 | - encoder.block.22.layer.1.DenseReluDense.wi 19 | - encoder.block.22.layer.1.DenseReluDense.wo 20 | - encoder.block.23.layer.1.DenseReluDense.wi 21 | - encoder.block.23.layer.1.DenseReluDense.wo 22 | - decoder.block.22.layer.2.DenseReluDense.wi 23 | - decoder.block.22.layer.2.DenseReluDense.wo 24 | - decoder.block.23.layer.2.DenseReluDense.wi 25 | - decoder.block.23.layer.2.DenseReluDense.wo 26 | 27 | grace_layer: encoder.block.12.layer.1.DenseReluDense.wo 28 | #alg.model.base_model.model.encoder.block[4].layer[1].DenseReluDense.wo 29 | #self.model.model.encoder.block[22].layer[1].DenseReluDense.wo.lora_B['default'][4:8] -------------------------------------------------------------------------------- /melo/config/model/t5small.yaml: -------------------------------------------------------------------------------- 1 | name: google/t5-small-ssm-nq 2 | class_name: AutoModelForSeq2SeqLM 3 | tokenizer_class: AutoTokenizer 4 | tokenizer_name: google/t5-small-ssm-nq 5 | fan_in_fan_out: False 6 | target_modules: 7 | - encoder.block.5.layer.1.DenseReluDense.wi 8 | - encoder.block.5.layer.1.DenseReluDense.wo 9 | - decoder.block.5.layer.2.DenseReluDense.wi 10 | - 
decoder.block.5.layer.2.DenseReluDense.wo 11 | - encoder.block.6.layer.1.DenseReluDense.wi 12 | - encoder.block.6.layer.1.DenseReluDense.wo 13 | - decoder.block.6.layer.2.DenseReluDense.wi 14 | - decoder.block.6.layer.2.DenseReluDense.wo 15 | 16 | pt: null 17 | 18 | grace_layer: encoder.block.4.layer.1.DenseReluDense.wo -------------------------------------------------------------------------------- /melo/grammar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | -------------------------------------------------------------------------------- /melo/hooks.py: -------------------------------------------------------------------------------- 1 | def lora_backward_hook(lora_module, grad_in, grad_out): 2 | print("Module_name: ", lora_module) 3 | print("grad_in[0]_shape: ",grad_in[0].shape) 4 | print("grad_in[0]: ", grad_in[0]) 5 | print("grad_out[0]_shape: ", grad_out[0].shape) 6 | print("grad_out[0]: ", grad_out[0]) 7 | print("-------------------------------") -------------------------------------------------------------------------------- /melo/main.sh: -------------------------------------------------------------------------------- 1 | +alg=lora +experiment=qa +model=t5small 2 | +alg=lora +experiment=hallucination +model=gpt2xl 3 | +alg=lora +experiment=scotus +model=scotus-bert 4 | 5 | 6 | CUDA_VISIBLE_DEVICES=3 python few_shot_run.py +alg=lora +experiment=fnli +model=bert-base batch_size=10 val_batch_size=10 lora.cross_attend=True 7 | 8 | 9 | 10 | CUDA_VISIBLE_DEVICES=3 python few_shot_run.py +alg=lora +experiment=qa +model=t5large batch_size=10 val_batch_size=10 data.zsre_impl=false data.zsre_yn=false data.hard_neg=false -------------------------------------------------------------------------------- /melo/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import * 3 | import logging 4 | LOG = logging.getLogger(__name__) 5 | 6 | 7 | # DEPRECATED 8 | def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20): 9 | assert False, "No longer used" 10 | # content_score = post_edit_probs[pos_mask].prod() ** (1/pos_mask.sum()) / (pre_edit_probs[pos_mask]. 
+ eps) 11 | post_pos_avg = post_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum()) 12 | pre_pos_avg = pre_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum()) 13 | content_score = post_pos_avg / (pre_pos_avg + eps) 14 | z_content = min(1., content_score) 15 | 16 | # compute z_sent through a weighting objective 17 | # normalized_probs = post_edit_probs / (post_edit_probs.sum() + eps) 18 | # balancing_factor = 0.5 * ((~pos_mask).float().sum() / pos_mask.float().sum() + 1) 19 | # z_sent_weight = balancing_factor * normalized_probs.dot(pos_mask.float()) 20 | post_neg_avg = post_edit_probs[~pos_mask].prod() ** (1 / (~pos_mask).sum()) 21 | neg_over_pos = post_neg_avg / (eps + post_pos_avg) 22 | z_sent_weight = 1 / (1 + neg_over_pos) 23 | 24 | # compute z_sent through a ranking objective 25 | batch_mask = pos_mask.view(-1, batch_size).long() 26 | sort_idxs = post_edit_probs.view(-1, batch_size).sort(-1, descending=True).indices 27 | ranked_mask = batch_mask.gather(1, sort_idxs) 28 | true_mask = batch_mask.sort(-1, descending=True).values 29 | z_sent_rank = (ranked_mask == true_mask).float().mean() 30 | 31 | # compute the final success scores 32 | weight_success = (z_content * z_sent_weight) ** 0.5 33 | rank_success = (z_content * z_sent_rank) ** 0.5 34 | 35 | correct_probs = post_edit_probs[pos_mask].mean() 36 | wrong_probs = post_edit_probs[~pos_mask].mean() 37 | 38 | return { 39 | "acc_weight": weight_success, 40 | "acc_rank": rank_success, 41 | "rank_score": z_sent_rank, 42 | "weight_score": z_sent_weight, 43 | "content_score": content_score, 44 | "post_edit_probs": post_edit_probs.sum(), 45 | "pre_edit_probs": pre_edit_probs.sum(), 46 | "correct_probs": correct_probs, 47 | "wrong_probs": wrong_probs 48 | } 49 | 50 | 51 | 52 | # For zsRE and F-NLI 53 | def retain_rate(pre_logits, post_logits, mask=None): 54 | if pre_logits.shape[-1] == 1: 55 | pre_logits = pre_logits.squeeze(-1) 56 | if post_logits.shape[-1] == 1: 57 | post_logits = post_logits.squeeze(-1) 58 | 59 | assert pre_logits.shape == post_logits.shape 60 | assert pre_logits.shape[0] == mask.shape[0] 61 | 62 | if pre_logits.dim() == 1: 63 | # binary classification 64 | pre_preds = pre_logits > 0 65 | post_preds = post_logits > 0 66 | retain = (pre_preds == post_preds).float().mean() 67 | elif pre_logits.dim() == 3: 68 | # sequence modeling 69 | pre_preds = pre_logits.argmax(-1) 70 | post_preds = post_logits.argmax(-1) 71 | match = (pre_preds == post_preds) * mask 72 | retain = (match.sum(-1) == mask.sum(-1)).float().mean() 73 | else: 74 | raise NotImplementedError 75 | 76 | return retain.item() 77 | 78 | 79 | def is_acc_error(model, tokens): 80 | # Check whether or not the model's prediction for a batch element is correct 81 | labels = tokens["labels"] 82 | logits = model(**tokens).logits 83 | probs = torch.softmax(logits, -1).squeeze() 84 | argmaxs = torch.argmax(probs, dim=-1).squeeze() 85 | return labels != argmaxs 86 | 87 | 88 | def Accuracy(alg, tokens): 89 | labels = tokens["labels"] 90 | new_tokens = {f"{k}": v for k, v in tokens.items() if k != "labels"} 91 | logits = alg.model(**new_tokens).logits 92 | probs = torch.softmax(logits, -1).squeeze() 93 | argmaxs = torch.argmax(probs, dim=-1).squeeze() 94 | return (labels == argmaxs).float().mean() 95 | 96 | 97 | def is_qa_error(model, tokens): 98 | preds = model.generate(tokens["input_ids"], max_length=20).squeeze() # Run model to get its predictions 99 | labels = tokens["labels"] # [tokens["labels"] != -100] 100 | 101 | if (len(preds) != len(labels)) or ((preds == labels).sum() 
!= len(preds)): 102 | return True 103 | else: 104 | return False 105 | 106 | 107 | def PPL(alg, batch): 108 | input_ids = batch["input_ids"][:, :1024] # .to(device) 109 | if "labels" not in batch: 110 | batch["labels"] = batch["input_ids"][:, :1024].clone() 111 | else: 112 | batch["labels"] = batch["labels"][:, :1024].clone() 113 | 114 | with torch.no_grad(): 115 | #outputs = alg.model.model(input_ids=input_ids, labels=target_ids) 116 | outputs = alg.model(**batch) 117 | nll = outputs.loss 118 | 119 | ppl = torch.exp(nll) # .clip(0, 100) 120 | return ppl 121 | 122 | 123 | 124 | def F1_ACC(alg, batch): 125 | try: 126 | preds = alg.generate(batch["input_ids"], max_length=20).squeeze() 127 | f1 = F1(preds, batch, alg.model_tok) 128 | acc = ACC(preds, batch, alg.model_tok) 129 | return f1, acc 130 | except Exception as e: 131 | raise e 132 | 133 | def F1(preds, batch, tok): 134 | try: 135 | f1_list = [] 136 | for p, g in zip(preds,batch["labels"]): 137 | p = p[p != tok.pad_token_id].cpu().squeeze() 138 | g = g[g != -100].cpu().squeeze() # -100 might be nonsense 139 | num_same = len(np.intersect1d(p, g)) 140 | len_pred = len(p) 141 | len_gold = len(g) 142 | precision = num_same / len_pred 143 | recall = 1.0 * num_same / len_gold 144 | f1 = (2 * precision * recall) / (precision + recall) 145 | f1_list.append(f1) 146 | except: 147 | return 0. 148 | 149 | return sum(f1_list) / len(f1_list) 150 | 151 | 152 | def ACC(preds, batch, tok): 153 | decode_preds = tok.batch_decode(preds,skip_special_tokens=True) 154 | gold_labels = batch['labels'] 155 | gold_labels = gold_labels.masked_fill(gold_labels == -100,tok.pad_token_id) 156 | decode_labels = tok.batch_decode(gold_labels,skip_special_tokens=True) 157 | assert len(decode_labels) == len(decode_preds), "Lengths of decode_preds and decode_labels should be the same" 158 | count = 0. 
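    # exact-match accuracy: count predictions whose decoded string equals the decoded gold label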
159 | for pred,label in zip(decode_preds, decode_labels): 160 | if pred == label: 161 | count = count + 1 162 | return count/len(decode_preds) 163 | -------------------------------------------------------------------------------- /melo/models.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import torch.nn as nn 4 | import re 5 | import logging 6 | # from torch.nn import FixableDropout 7 | from utils import scr 8 | 9 | 10 | LOG = logging.getLogger(__name__) 11 | 12 | 13 | class CastModule(nn.Module): 14 | def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None): 15 | super().__init__() 16 | 17 | self.underlying = module 18 | self.in_cast = in_cast 19 | self.out_cast = out_cast 20 | 21 | def cast(self, obj, dtype): 22 | if dtype is None: 23 | return obj 24 | 25 | if isinstance(obj, torch.Tensor): 26 | return obj.to(dtype) 27 | else: 28 | return obj 29 | 30 | def forward(self, *args, **kwargs): 31 | args = tuple(self.cast(a, self.in_cast) for a in args) 32 | kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()} 33 | outputs = self.underlying(*args, **kwargs) 34 | if isinstance(outputs, torch.Tensor): 35 | outputs = self.cast(outputs, self.out_cast) 36 | elif isinstance(outputs, tuple): 37 | outputs = tuple(self.cast(o, self.out_cast) for o in outputs) 38 | else: 39 | raise RuntimeError(f"Not sure how to cast type {type(outputs)}") 40 | return outputs 41 | 42 | def extra_repr(self): 43 | return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}" 44 | 45 | 46 | class BertClassifier(torch.nn.Module): 47 | def __init__(self, model_name, hidden_dim=768): 48 | super().__init__() 49 | if model_name.startswith("bert"): 50 | LOG.info(f"Loading model class {model_name}, cache dir {scr()}") 51 | self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr()) 52 | else: 53 | self.model = transformers.AutoModel.from_pretrained(model_name, cache_dir=scr()) 54 | self.classifier = torch.nn.Linear(hidden_dim, 1) 55 | 56 | @property 57 | def config(self): 58 | return self.model.config 59 | 60 | def forward(self, *args, **kwargs): 61 | filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"} 62 | model_output = self.model(*args, **filtered_kwargs) 63 | if "pooler_output" in model_output.keys(): 64 | pred = self.classifier(model_output.pooler_output) 65 | else: 66 | pred = self.classifier(model_output.last_hidden_state[:, 0]) 67 | 68 | if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: 69 | last_hidden_state = model_output.last_hidden_state 70 | return pred, last_hidden_state 71 | else: 72 | return pred 73 | 74 | 75 | 76 | def get_hf_model(config): 77 | ModelClass = getattr(transformers, config.model.class_name) 78 | LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}") 79 | if config.model.pt is None: 80 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr()) 81 | elif config.re_init_model: 82 | print("Downloading untrained model.") 83 | model = ModelClass.from_pretrained(config.model.name) 84 | else: 85 | try: 86 | # try to load specified model from local dir 87 | model = ModelClass.from_pretrained(config.model.pt) 88 | print(f"Loaded model: {config.model.pt}") 89 | except Exception: 90 | print(f"Couldn't load model: {config.model.pt}. 
Downloading new model.") 91 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr()) 92 | 93 | 94 | if config.dropout is not None: 95 | n_reset = 0 96 | for m in model.modules(): 97 | if isinstance(m, nn.Dropout): 98 | m.p = config.dropout 99 | n_reset += 1 100 | 101 | if hasattr(m, "dropout"): # Requires for BART, which uses F.dropout 102 | if isinstance(m.dropout, float): 103 | m.dropout = config.dropout 104 | n_reset += 1 105 | 106 | if hasattr(m, "activation_dropout"): # Requires for BART, which uses F.dropout 107 | if isinstance(m.activation_dropout, float): 108 | m.activation_dropout = config.dropout 109 | n_reset += 1 110 | 111 | LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}") 112 | return model 113 | 114 | 115 | 116 | def get_tokenizer(config): 117 | tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name 118 | tokenizer = getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr()) 119 | if not tokenizer.pad_token: 120 | tokenizer.pad_token = tokenizer.eos_token 121 | return tokenizer 122 | 123 | def get_processor(config): 124 | processor_name = config.model.processor_name if config.model.processor_name is not None else config.model.name 125 | processor = getattr(transformers, config.model.processor_class).from_pretrained(processor_name, cache_dir = scr()) 126 | return processor 127 | 128 | 129 | if __name__ == '__main__': 130 | m = BertClassifier("bert-base-uncased") 131 | m(torch.arange(5)[None, :]) 132 | import pdb; pdb.set_trace() 133 | -------------------------------------------------------------------------------- /melo/run.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import importlib 4 | import logging 5 | from time import time 6 | import hydra 7 | from omegaconf import OmegaConf,open_dict 8 | import numpy as np 9 | import torch 10 | from utils import * 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | import pickle 14 | import models 15 | from trainer import zsre_trainer, hallucination_trainer, scotus_trainer 16 | #TODO 17 | #SUPPORT MELO CONV1D fan_in_fan_out 18 | 19 | # os.environ['http_proxy'] = '127.0.0.1:7890' 20 | # os.environ['https_proxy'] = '127.0.0.1:7890' 21 | OmegaConf.register_new_resolver("uuid", lambda: uuid()) 22 | LOG = logging.getLogger(__name__) 23 | @hydra.main(config_path='config', config_name='config') 24 | def run(config): 25 | grace_config_keys = ['edit_lr','init_radius','expand_mode','key_id','num_edit_per_block','num_block','num_rank_per_block'] 26 | model_config_keys = ['target_modules','grace_layer'] 27 | GRACE_CONFIG = dict(config.grace) 28 | MODEL_CONFIG = dict(config.model) 29 | 30 | for k in grace_config_keys: 31 | LOG.info(f'[-GRACE CONFIG-] {k}: {GRACE_CONFIG[k]}') 32 | for k in model_config_keys: 33 | LOG.info(f'[-MODEL CONFIG-] {k}: {MODEL_CONFIG[k]}') 34 | 35 | base_dir = hydra.utils.get_original_cwd() 36 | with open_dict(config): 37 | config.base_dir = base_dir 38 | 39 | random.seed(config.seed) 40 | np.random.seed(config.seed) 41 | torch.manual_seed(config.seed) 42 | 43 | if config.task == "qa" or config.task == "zsre": 44 | model = models.get_hf_model(config) 45 | elif config.task == "hallucination": 46 | model = models.get_hf_model(config) 47 | elif config.task == "scotus": 48 | model = models.get_hf_model(config) 49 | else: 50 | print(f"{config.task} task not found") 51 | 52 | 53 | 54 | model.to(config.device) 55 | 
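    # the tokenizer class/name comes from melo/config/model/*.yaml and must match the edited model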
tokenizer = models.get_tokenizer(config) 56 | 57 | 58 | ''' 59 | Load Dataset 60 | ''' 61 | if config.task == "qa" or config.task == "zsre": 62 | from dataset import NQ, zsRE, zsRE_balanced 63 | from metrics import F1_ACC, is_qa_error 64 | upstream = NQ() 65 | edits = zsRE_balanced(split="edit", n_edits=1000) 66 | edit_holdouts = zsRE_balanced(split="holdout", n_edits=1000) 67 | 68 | '''Get Loaders 69 | ''' 70 | batch_size = config.grace.num_edit_per_block 71 | edit_loader = DataLoader(edits, batch_size=batch_size, shuffle=True) 72 | edit_holdout_loader = DataLoader(edit_holdouts, batch_size=batch_size, shuffle=False) 73 | upstream_loader = DataLoader(upstream, batch_size=batch_size, shuffle=False) 74 | hold_out = 0 75 | '''Define Metrics 76 | ''' 77 | metric = F1_ACC # Measure QA F1 78 | is_error = is_qa_error 79 | tokenize = tokenize_qa 80 | 81 | elif config.task == "hallucination": 82 | from dataset import Hallucination, WebText10k 83 | from metrics import PPL 84 | upstream = WebText10k() 85 | edits = Hallucination(split="edit") 86 | accurate_dataset = Hallucination(split="accurate") 87 | 88 | '''Get Loaders 89 | ''' 90 | batch_size = config.grace.num_edit_per_block 91 | edit_loader = DataLoader(edits, batch_size=batch_size, shuffle=False) 92 | accurate_loader = DataLoader(accurate_dataset, batch_size=batch_size, shuffle=False) 93 | upstream_loader = DataLoader(upstream, batch_size=batch_size, shuffle=False) 94 | '''Define Metrics 95 | ''' 96 | metric = PPL # Measure perplexity 97 | tokenize = tokenize_gpt 98 | elif config.task == 'scotus': 99 | from dataset import SCOTUS 100 | from metrics import Accuracy 101 | upstream = SCOTUS("train") 102 | edits = SCOTUS("edit") 103 | 104 | '''Get Loaders 105 | ''' 106 | batch_size = config.grace.num_edit_per_block 107 | edit_loader = DataLoader(edits, batch_size=batch_size, shuffle=False) 108 | upstream_loader = DataLoader(upstream, batch_size=batch_size, shuffle=False) 109 | '''Define Metrics 110 | ''' 111 | metric = Accuracy 112 | tokenize = tokenize_clf 113 | else: 114 | print(f"{config.task} task not found") 115 | 116 | alg_module = importlib.import_module(f'algs.{config.alg}') 117 | AlgClass = getattr(alg_module,config.alg.upper()) 118 | alg = AlgClass(model,config,tokenizer) 119 | alg.to(config.device) 120 | 121 | # Trainer 122 | if config.task == "qa" or config.task == "zsre": 123 | trainer = zsre_trainer(config,alg,tokenize,metric,edit_loader,upstream_loader,edit_holdout_loader) 124 | elif config.task == "hallucination": 125 | trainer = hallucination_trainer(config,alg,tokenize,metric,edit_loader,upstream_loader,accurate_loader) 126 | elif config.task == "scotus": 127 | trainer = scotus_trainer(config,alg,tokenize,metric,edit_loader,upstream_loader) 128 | 129 | # trainer.pre_editing_analyse() 130 | torch.cuda.empty_cache() 131 | trainer.run_edit() 132 | 133 | 134 | if __name__ == '__main__': 135 | run() -------------------------------------------------------------------------------- /melo/utils.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import os 4 | import numpy as np 5 | import datetime 6 | import struct 7 | from torch.nn.utils.rnn import pad_sequence 8 | import torch.nn.functional as F 9 | import wandb 10 | import hydra 11 | 12 | def get_inner_params(named_parameters, inner_names): 13 | param_dict = dict(named_parameters) 14 | return [(n, param_dict[n]) for n in inner_names] 15 | 16 | 17 | def param_subset(named_parameters, inner_names): 18 | param_dict = 
dict(named_parameters) 19 | return [param_dict[n] for n in inner_names] 20 | 21 | 22 | def parent_module(model, pname): 23 | components = pname.split('.') 24 | parent = model 25 | 26 | for component in components[:-1]: 27 | if hasattr(parent, component): 28 | parent = getattr(parent, component) 29 | elif component.isdigit(): 30 | parent = parent[int(component)] 31 | else: 32 | raise RuntimeError(f"Couldn't find child module {component}") 33 | 34 | if not hasattr(parent, components[-1]): 35 | raise RuntimeError(f"Couldn't find child module {components[-1]}") 36 | 37 | return parent 38 | 39 | 40 | def uuid(digits=4): 41 | if not hasattr(uuid, "uuid_value"): 42 | uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10 ** digits) 43 | 44 | return uuid.uuid_value 45 | 46 | def scr(): 47 | base_dir = hydra.utils.get_original_cwd() 48 | if os.path.exists(os.path.join(base_dir,"scr-ssd")): 49 | scr_dir = os.path.join(base_dir,"scr-ssd") 50 | else: 51 | scr_dir = os.path.join(base_dir,"scr") 52 | 53 | if not os.path.exists(scr_dir): 54 | os.makedirs(scr_dir) 55 | 56 | return scr_dir 57 | def ckpt_dir(): 58 | """returns the directory in which to store model checkpoints""" 59 | path = "./ckpts/" 60 | if not os.path.exists(path): 61 | os.makedirs(path) 62 | return path 63 | 64 | 65 | def brackets_to_periods(name): 66 | return name.replace("[", ".").replace("]", "") 67 | 68 | 69 | def get_params(model): 70 | return model.state_dict() 71 | 72 | 73 | def get_shape(p, model): 74 | # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear 75 | return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0]) 76 | 77 | 78 | def get_logits(x): 79 | return x.logits if hasattr(x, "logits") else x 80 | 81 | 82 | def tokenize_gpt(batch, tokenizer, device, test=False): 83 | prompt, label = batch["text"], batch["labels"] 84 | mask_token = -100 # ignore_index of CrossEntropyLoss 85 | if test or not label: 86 | tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True) 87 | tokens["labels"] = tokens["input_ids"].clone() 88 | tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token 89 | 90 | else: 91 | full_prompt = [f"{p} {l} <|endoftext|>" for p, l in zip(prompt, label)] 92 | # full_prompt = [f"{p} {l} {tokenizer.eos_token}" for p, l in zip(prompt, label)] 93 | prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"] 94 | num_prompt_toks = [int((i != tokenizer.pad_token_id).sum()) for i in prompt_ids] 95 | tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True) 96 | tokens["labels"] = tokens["input_ids"].clone() 97 | for i in range(len(prompt)): 98 | tokens["labels"][i][:num_prompt_toks[i]] = mask_token 99 | 100 | tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token # What is this doing? 
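        # ^ replaces label entries at pad positions with -100 (CrossEntropyLoss's ignore_index), so padding tokens contribute no loss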
101 | 102 | tokens = {f"{k1}": v1.to(device) for k1, v1 in tokens.items()} 103 | return tokens 104 | 105 | 106 | def tokenize_qa(batch, tokenizer, device, **kwargs): 107 | input_sequences, output_sequences = batch["text"], batch["labels"] 108 | 109 | input_encoding = tokenizer( 110 | list(input_sequences), 111 | padding="longest", 112 | max_length=20, 113 | truncation=True, 114 | return_tensors="pt", 115 | ) 116 | 117 | input_ids, attention_mask = input_encoding.input_ids, input_encoding.attention_mask 118 | 119 | target_encoding = tokenizer( 120 | list(output_sequences), 121 | padding="longest", 122 | max_length=20, 123 | truncation=True, 124 | return_tensors="pt", 125 | ) 126 | 127 | labels = target_encoding.input_ids 128 | labels[labels == tokenizer.pad_token_id] = -100 129 | 130 | tokens = { 131 | "input_ids": input_ids, 132 | "attention_mask": attention_mask, 133 | "labels": labels 134 | } 135 | 136 | tokens = {f"{k1}": v1.to(device) for k1, v1 in tokens.items()} 137 | return tokens 138 | 139 | 140 | def tokenize_clf(batch, tokenizer, device, **kwargs): 141 | input_sequences, labels = batch["text"], batch["labels"] 142 | 143 | tokens = tokenizer(input_sequences, truncation=True, padding="max_length", return_tensors="pt") 144 | tokens["labels"] = labels 145 | tokens = {f"{k1}": v1.to(device) for k1, v1 in tokens.items()} 146 | return tokens 147 | 148 | 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /melo/wiki_bio_concepts.txt: -------------------------------------------------------------------------------- 1 | john russell reynolds 2 | matthew aylmer , 1st baron aylmer 3 | rick mahler 4 | james blair south carolina 5 | tim finchem 6 | akila dananjaya 7 | derek king australian footballer 8 | wilhelm windelband 9 | freddie frith 10 | marshall manesh 11 | eleanor arnason 12 | carter harrison , sr. . 13 | winnebago deal 14 | noel hogan 15 | dawn landes 16 | bill quinn 17 | carol huston 18 | gia carangi 19 | nigel milsom 20 | rod morgenstein 21 | terry alderman 22 | frank a. mclain 23 | rich williams 24 | torry castellano 25 | albert i , margrave of meissen 26 | sirið stenberg 27 | thomas harriot 28 | tadeusz szeligowski 29 | gordon strachan 30 | steven threet 31 | archie baird 32 | peter breen politician 33 | adja yunkers 34 | the blood divine 35 | king zhuang of chu 36 | william j. flanagan , jr. 37 | k. s. manilal 38 | jeannine riley 39 | seyi shay 40 | hilda kuper 41 | stuart scott 42 | mark fite 43 | philippe dodard 44 | rudy fernandez labor leader 45 | mackenzie caquatto 46 | twila shively 47 | lionel aldridge 48 | irena sendler 49 | ronnie barker 50 | honoré iii , prince of monaco 51 | emily gielnik 52 | choi jae-bong 53 | tom izzo 54 | tommy nutter 55 | jearl walker 56 | steve ridzik 57 | achille-ferdinand carrier 58 | tera van beilen 59 | harry kennedy 60 | david kappos 61 | pattern is movement 62 | kévin gameiro 63 | lee hsien loong 64 | lucien turcotte pacaud 65 | makiko esumi 66 | kate deines 67 | c. v. ananda bose 68 | anthony dimond 69 | honoré iv , prince of monaco 70 | tristan rogers 71 | john burnham cricketer 72 | nate saint 73 | thutmose iii 74 | john loder sound engineer 75 | a. p. j. abdul kalam 76 | john reed , jr. . 
77 | paul elliott politician 78 | moisés kaufman 79 | robert holgate 80 | duncan mackay footballer 81 | saul david 82 | tomasz lis 83 | véra korène 84 | nodar kumaritashvili 85 | leana de bruin 86 | alfred fischer ss officer 87 | kermit davis 88 | daniel ménard 89 | modibo adama 90 | bert deacon 91 | mushahid hussain syed 92 | kia joorabchian 93 | vitaliano brancati 94 | emperor wenxuan of northern qi 95 | johan christian dahl 96 | steve cooper footballer , born 1964 97 | ernest miller cinematographer 98 | david king australian rules footballer 99 | danny smith coach 100 | hope cooke 101 | tathagata satpathy 102 | michel mathieu canadian politician 103 | mario monti 104 | pino palladino 105 | tony la russa 106 | murray g. ross 107 | malcolm brogdon 108 | john les 109 | evan rachel wood 110 | frank abagnale 111 | reezal merican naina merican 112 | dan stearns 113 | lindsay crouse 114 | clay timpner 115 | yaakov israel ifargan 116 | ha jung-woo 117 | charles lee basketball 118 | stereophonics 119 | don r. swanson 120 | roy beggs , jr. . 121 | adiele afigbo 122 | brian petrovek 123 | john cushnahan 124 | ron meagher 125 | george milne cricketer 126 | bill tobin american football 127 | william luther pierce 128 | martina sorbara 129 | tom wise 130 | frederick thomas brentnall 131 | bill brown goalkeeper 132 | eden natan-zada 133 | richard carpenter screenwriter 134 | joe brown utility player 135 | wayne allyn root 136 | assassination of robert f. kennedy 137 | paul caddis 138 | paul taylor winger 139 | linda hunt 140 | jerry leger 141 | 3rd dalai lama 142 | james clarke vc 143 | jack straw 144 | syd rapson 145 | billy barnie 146 | catherine johnson playwright 147 | sara montiel 148 | lucy akhurst 149 | william allan neilson 150 | elisha brown 151 | joe walsh rugby league 152 | josiah mason 153 | balbir singh kullar 154 | george bovell 155 | fei-ping hsu 156 | anne de gaulle 157 | rusty stevens 158 | john cameron alberta politician 159 | carole gist 160 | david collings 161 | matt striebel 162 | bob miller american football 163 | bryan mcclendon 164 | royce campbell 165 | carlos arniches 166 | geoff griffin 167 | frankie lymon 168 | raymond harry brown 169 | george roll 170 | ayn rand 171 | richard allen epstein 172 | tom butler actor 173 | kenan hasagić 174 | gordon hogg 175 | vagos motorcycle club 176 | katie ledecky 177 | michael savage 178 | john howe illustrator 179 | alana davis 180 | arthur sewall 181 | stan heal 182 | ithamara koorax 183 | thomas wolfe 184 | john russell vc 185 | cicero hunt lewis 186 | philip of france 1116 -- 1131 187 | brian hughes musician 188 | rickey paulding 189 | charles melville hays 190 | lee naylor footballer 191 | bane band 192 | adam collis 193 | alan dinehart 194 | sylvain barrier 195 | kirill karabits 196 | b. k. anand 197 | robert emmett keane 198 | charlotte rae 199 | riccardo tisci 200 | lester germer 201 | laurent koscielny 202 | bridget moynahan 203 | george hubbard clapp 204 | merle oberon 205 | mayhew foster 206 | hephaestion 207 | thomas biagi 208 | susan pedersen swimmer 209 | tetsuzō iwamoto 210 | donald alexander mackinnon 211 | joe holland basketball 212 | casey serin 213 | jean hugo 214 | heinz christian pander 215 | arthur tedder , 1st baron tedder 216 | cindy kleine 217 | willie naulls 218 | john holman chemist 219 | paul y. r. 
waddington 220 | andy hurley 221 | dara torres 222 | john hughes archbishop of new york 223 | millicent shelton 224 | whitey kurowski 225 | nofx 226 | hisashi iwakuma 227 | virginia bottomley 228 | john liscio 229 | john vallely 230 | johannes andreas august grabau 231 | guðlaugur Þór Þórðarson 232 | laurier lévesque 233 | micky moody 234 | gündüz kılıç 235 | michael replogle 236 | billy burke golfer 237 | ted childs 238 | edward synge archbishop of tuam -------------------------------------------------------------------------------- /peft_egg/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # VSCode 132 | .vscode 133 | 134 | # IntelliJ 135 | .idea 136 | 137 | # Mac .DS_Store 138 | .DS_Store 139 | 140 | # More test things 141 | wandb -------------------------------------------------------------------------------- /peft_egg/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style test docs 2 | 3 | check_dirs := src tests examples docs 4 | 5 | # Check that source code meets quality standards 6 | 7 | # this target runs checks on all files 8 | quality: 9 | black --check $(check_dirs) 10 | ruff $(check_dirs) 11 | doc-builder style src/peft tests docs/source --max_len 119 --check_only 12 | 13 | # Format source code automatically and check if there are any problems left that need manual fixing 14 | style: 15 | black $(check_dirs) 16 | ruff $(check_dirs) --fix 17 | doc-builder style src/peft tests docs/source --max_len 119 18 | 19 | test: 20 | python -m pytest -n 3 tests/ $(if $(IS_GITHUB_CI),--report-log "ci_tests.log",) 21 | 22 | tests_examples_multi_gpu: 23 | python -m pytest -m multi_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",) 24 | 25 | tests_examples_single_gpu: 26 | python -m pytest -m single_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",) 27 | 28 | tests_core_multi_gpu: 29 | python -m pytest -m multi_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",) 30 | 31 | tests_core_single_gpu: 32 | python -m pytest -m single_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",) 33 | 34 | tests_common_gpu: 35 | python -m pytest tests/test_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",) 36 | python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",) 37 | -------------------------------------------------------------------------------- /peft_egg/docker/peft-cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds CPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.8 8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget software-properties-common git-lfs && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Install audio-related libraries 15 | RUN apt-get update && \ 16 | apt install -y ffmpeg 17 | 18 | RUN apt install -y libsndfile1-dev 19 | RUN git lfs install 20 | 21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 22 | 
RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip 23 | RUN python3 -m pip install --no-cache-dir --upgrade pip 24 | 25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 26 | # We don't install pytorch here yet since CUDA isn't available 27 | # instead we use the direct torch wheel 28 | ENV PATH /opt/conda/envs/peft/bin:$PATH 29 | # Activate our bash shell 30 | RUN chsh -s /bin/bash 31 | SHELL ["/bin/bash", "-c"] 32 | # Activate the conda env and install transformers + accelerate from source 33 | RUN source activate peft && \ 34 | python3 -m pip install --no-cache-dir \ 35 | librosa \ 36 | "soundfile>=0.12.1" \ 37 | scipy \ 38 | git+https://github.com/huggingface/transformers \ 39 | git+https://github.com/huggingface/accelerate \ 40 | peft[test]@git+https://github.com/huggingface/peft 41 | 42 | # Install apt libs 43 | RUN apt-get update && \ 44 | apt-get install -y curl git wget && \ 45 | apt-get clean && \ 46 | rm -rf /var/lib/apt/lists* 47 | 48 | RUN echo "source activate peft" >> ~/.profile 49 | 50 | # Activate the virtualenv 51 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /peft_egg/docker/peft-gpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds GPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.8 8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget software-properties-common git-lfs && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Install audio-related libraries 15 | RUN apt-get update && \ 16 | apt install -y ffmpeg 17 | 18 | RUN apt install -y libsndfile1-dev 19 | RUN git lfs install 20 | 21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 22 | RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip 23 | RUN python3 -m pip install --no-cache-dir --upgrade pip 24 | 25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 26 | # We don't install pytorch here yet since CUDA isn't available 27 | # instead we use the direct torch wheel 28 | ENV PATH /opt/conda/envs/peft/bin:$PATH 29 | # Activate our bash shell 30 | RUN chsh -s /bin/bash 31 | SHELL ["/bin/bash", "-c"] 32 | # Activate the conda env and install transformers + accelerate from source 33 | RUN source activate peft && \ 34 | python3 -m pip install --no-cache-dir \ 35 | librosa \ 36 | "soundfile>=0.12.1" \ 37 | scipy \ 38 | git+https://github.com/huggingface/transformers \ 39 | git+https://github.com/huggingface/accelerate \ 40 | peft[test]@git+https://github.com/huggingface/peft 41 | 42 | RUN python3 -m pip install --no-cache-dir bitsandbytes 43 | 44 | # Stage 2 45 | FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 AS build-image 46 | COPY --from=compile-image /opt/conda /opt/conda 47 | ENV PATH /opt/conda/bin:$PATH 48 | 49 | # Install apt libs 50 | RUN apt-get update && \ 51 | apt-get install -y curl git wget && \ 52 | apt-get clean && \ 53 | rm -rf /var/lib/apt/lists* 54 | 55 | RUN echo "source activate peft" >> ~/.profile 56 | 57 | # Activate the 
virtualenv 58 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /peft_egg/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /peft_egg/docs/source/_config.py: -------------------------------------------------------------------------------- 1 | # docstyle-ignore 2 | INSTALL_CONTENT = """ 3 | # PEFT installation 4 | ! pip install peft accelerate transformers 5 | # To install from source instead of the last release, comment the command above and uncomment the following one. 6 | # ! pip install git+https://github.com/huggingface/peft.git 7 | """ 8 | -------------------------------------------------------------------------------- /peft_egg/docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - title: Get started 2 | sections: 3 | - local: index 4 | title: 🤗 PEFT 5 | - local: quicktour 6 | title: Quicktour 7 | - local: install 8 | title: Installation 9 | 10 | - title: Task guides 11 | sections: 12 | - local: task_guides/image_classification_lora 13 | title: Image classification using LoRA 14 | - local: task_guides/seq2seq-prefix-tuning 15 | title: Prefix tuning for conditional generation 16 | - local: task_guides/clm-prompt-tuning 17 | title: Prompt tuning for causal language modeling 18 | - local: task_guides/semantic_segmentation_lora 19 | title: Semantic segmentation using LoRA 20 | - local: task_guides/ptuning-seq-classification 21 | title: P-tuning for sequence classification 22 | - local: task_guides/dreambooth_lora 23 | title: Dreambooth fine-tuning with LoRA 24 | - local: task_guides/token-classification-lora 25 | title: LoRA for token classification 26 | - local: task_guides/int8-asr 27 | title: int8 training for automatic speech recognition 28 | 29 | - title: 🤗 Accelerate integrations 30 | sections: 31 | - local: accelerate/deepspeed-zero3-offload 32 | title: DeepSpeed 33 | - local: accelerate/fsdp 34 | title: Fully Sharded Data Parallel 35 | 36 | - title: Conceptual guides 37 | sections: 38 | - local: conceptual_guides/lora 39 | title: LoRA 40 | - local: conceptual_guides/prompting 41 | title: Prompting 42 | 43 | - title: Reference 44 | sections: 45 | - local: package_reference/peft_model 46 | title: PEFT model 47 | - local: package_reference/config 48 | title: Configuration 49 | - local: package_reference/tuners 50 | title: Tuners -------------------------------------------------------------------------------- /peft_egg/docs/source/accelerate/fsdp.mdx: -------------------------------------------------------------------------------- 1 | # Fully Sharded Data Parallel 2 | 3 | [Fully sharded data parallel](https://pytorch.org/docs/stable/fsdp.html) (FSDP) is developed for distributed training of large pretrained models 
up to 1T parameters. FSDP achieves this by sharding the model parameters, gradients, and optimizer states across data parallel processes, and it can also offload sharded model parameters to a CPU. The memory efficiency afforded by FSDP allows you to scale training to larger batch or model sizes. 4 | 5 | 6 | 7 | Currently, FSDP does not confer any reduction in GPU memory usage, and FSDP with CPU offload actually consumes 1.65x more GPU memory during training. You can track this PyTorch [issue](https://github.com/pytorch/pytorch/issues/91165) for any updates. 8 | 9 | 10 | 11 | FSDP is supported in 🤗 Accelerate, and you can use it with 🤗 PEFT. This guide will help you learn how to use our FSDP [training script](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py). You'll configure the script to train a large model for conditional generation. 12 | 13 | ## Configuration 14 | 15 | Begin by running the following command to [create an FSDP configuration file](https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp) with 🤗 Accelerate. Use the `--config_file` flag to save the configuration file to a specific location, otherwise it is saved as a `default_config.yaml` file in the 🤗 Accelerate cache. 16 | 17 | The configuration file is used to set the default options when you launch the training script. 18 | 19 | ```bash 20 | accelerate config --config_file fsdp_config.yaml 21 | ``` 22 | 23 | You'll be asked a few questions about your setup and to configure the following arguments. For this example, make sure you fully shard the model parameters, gradients, and optimizer states, leverage the CPU for offloading, and wrap model layers based on the Transformer layer class name. 24 | 25 | ```bash 26 | `Sharding Strategy`: [1] FULL_SHARD (shards optimizer states, gradients and parameters), [2] SHARD_GRAD_OP (shards optimizer states and gradients), [3] NO_SHARD 27 | `Offload Params`: Decides whether to offload parameters and gradients to CPU 28 | `Auto Wrap Policy`: [1] TRANSFORMER_BASED_WRAP, [2] SIZE_BASED_WRAP, [3] NO_WRAP 29 | `Transformer Layer Class to Wrap`: When using `TRANSFORMER_BASED_WRAP`, specify a comma-separated string of transformer layer class names (case-sensitive) to wrap, e.g., 30 | `BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput`... 
31 | `Min Num Params`: minimum number of parameters when using `SIZE_BASED_WRAP` 32 | `Backward Prefetch`: [1] BACKWARD_PRE, [2] BACKWARD_POST, [3] NO_PREFETCH 33 | `State Dict Type`: [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT 34 | ``` 35 | 36 | For example, your FSDP configuration file may look like the following: 37 | 38 | ```yaml 39 | command_file: null 40 | commands: null 41 | compute_environment: LOCAL_MACHINE 42 | deepspeed_config: {} 43 | distributed_type: FSDP 44 | downcast_bf16: 'no' 45 | dynamo_backend: 'NO' 46 | fsdp_config: 47 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 48 | fsdp_backward_prefetch_policy: BACKWARD_PRE 49 | fsdp_offload_params: true 50 | fsdp_sharding_strategy: 1 51 | fsdp_state_dict_type: FULL_STATE_DICT 52 | fsdp_transformer_layer_cls_to_wrap: T5Block 53 | gpu_ids: null 54 | machine_rank: 0 55 | main_process_ip: null 56 | main_process_port: null 57 | main_training_function: main 58 | megatron_lm_config: {} 59 | mixed_precision: 'no' 60 | num_machines: 1 61 | num_processes: 2 62 | rdzv_backend: static 63 | same_network: true 64 | tpu_name: null 65 | tpu_zone: null 66 | use_cpu: false 67 | ``` 68 | 69 | ## The important parts 70 | 71 | Let's dig a bit deeper into the training script to understand how it works. 72 | 73 | The [`main()`](https://github.com/huggingface/peft/blob/2822398fbe896f25d4dac5e468624dc5fd65a51b/examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py#L14) function begins with initializing an [`~accelerate.Accelerator`] class which handles everything for distributed training, such as automatically detecting your training environment. 74 | 75 | 76 | 77 | 💡 Feel free to change the model and dataset inside the `main` function. If your dataset format is different from the one in the script, you may also need to write your own preprocessing function. 78 | 79 | 80 | 81 | The script also creates a configuration corresponding to the 🤗 PEFT method you're using. For LoRA, you'll use [`LoraConfig`] to specify the task type, and several other important parameters such as the dimension of the low-rank matrices, the matrices scaling factor, and the dropout probability of the LoRA layers. If you want to use a different 🤗 PEFT method, replace `LoraConfig` with the appropriate [class](../package_reference/tuners). 82 | 83 | Next, the script wraps the base model and `peft_config` with the [`get_peft_model`] function to create a [`PeftModel`]. 84 | 85 | ```diff 86 | def main(): 87 | + accelerator = Accelerator() 88 | model_name_or_path = "t5-base" 89 | base_path = "temp/data/FinancialPhraseBank-v1.0" 90 | + peft_config = LoraConfig( 91 | task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1 92 | ) 93 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) 94 | + model = get_peft_model(model, peft_config) 95 | ``` 96 | 97 | Throughout the script, you'll see the [`~accelerate.Accelerator.main_process_first`] and [`~accelerate.Accelerator.wait_for_everyone`] functions which help control and synchronize when processes are executed. 98 | 99 | After your dataset is prepared, and all the necessary training components are loaded, the script checks if you're using the `fsdp_plugin`. PyTorch offers two ways for wrapping model layers in FSDP, automatically or manually. The simplest method is to allow FSDP to automatically recursively wrap model layers without changing any other code. You can choose to wrap the model layers based on the layer name or on the size (number of parameters). 
The FSDP configuration file created earlier uses the `TRANSFORMER_BASED_WRAP` option to wrap the [`T5Block`] layer. 100 | 101 | ```py 102 | if getattr(accelerator.state, "fsdp_plugin", None) is not None: 103 | accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model) 104 | ``` 105 | 106 | Next, use 🤗 Accelerate's [`~accelerate.Accelerator.prepare`] function to prepare the model, datasets, optimizer, and scheduler for training. 107 | 108 | ```py 109 | model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare( 110 | model, train_dataloader, eval_dataloader, optimizer, lr_scheduler 111 | ) 112 | ``` 113 | 114 | From here, the remainder of the script handles the training loop, evaluation, and sharing your model on the Hub. 115 | 116 | ## Train 117 | 118 | Run the following command to launch the training script. Earlier, you saved the configuration file to `fsdp_config.yaml`, so you'll need to pass the path to the launcher with the `--config_file` argument like this: 119 | 120 | ```bash 121 | accelerate launch --config_file fsdp_config.yaml examples/peft_lora_seq2seq_accelerate_fsdp.py 122 | ``` 123 | 124 | Once training is complete, the script returns the accuracy and compares the predictions to the labels. -------------------------------------------------------------------------------- /peft_egg/docs/source/conceptual_guides/lora.mdx: -------------------------------------------------------------------------------- 1 | 12 | 13 | # LoRA 14 | 15 | This conceptual guide gives a brief overview of [LoRA](https://arxiv.org/abs/2106.09685), a technique that accelerates 16 | the fine-tuning of large models while consuming less memory. 17 | 18 | To make fine-tuning more efficient, LoRA's approach is to represent the weight updates with two smaller 19 | matrices (called **update matrices**) through low-rank decomposition. These new matrices can be trained to adapt to the 20 | new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive 21 | any further adjustments. To produce the final results, both the original and the adapted weights are combined. 22 | 23 | This approach has a number of advantages: 24 | 25 | * LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters. 26 | * The original pre-trained weights are kept frozen, which means you can have multiple lightweight and portable LoRA models for various downstream tasks built on top of them. 27 | * LoRA is orthogonal to many other parameter-efficient methods and can be combined with many of them. 28 | * Performance of models fine-tuned using LoRA is comparable to the performance of fully fine-tuned models. 29 | * LoRA does not add any inference latency because adapter weights can be merged with the base model. 30 | 31 | In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable 32 | parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to 33 | attention blocks only. The resulting number of trainable parameters in a LoRA model depends on the size of the low-rank 34 | update matrices, which is determined mainly by the rank `r` and the shape of the original weight matrix.
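To make the parameter savings concrete, here is a minimal sketch of the idea in plain PyTorch. It is an illustration, not the PEFT implementation: the names `lora_A` and `lora_B` are only for exposition, and the `lora_alpha / r` scaling used in practice is omitted.

```py
import torch

d, k, r = 1024, 1024, 8  # shape of the original weight matrix and a small rank

base = torch.nn.Linear(k, d, bias=False)  # stands in for a frozen pretrained weight W of shape (d, k)
for p in base.parameters():
    p.requires_grad = False

lora_A = torch.nn.Parameter(torch.randn(r, k) * 0.01)  # update matrix A, shape (r, k)
lora_B = torch.nn.Parameter(torch.zeros(d, r))         # update matrix B, shape (d, r); zero init so training starts from W

x = torch.randn(2, k)
y = base(x) + x @ (lora_B @ lora_A).T  # combine the frozen weights with the low-rank update B @ A

# d*r + r*k = 16,384 trainable parameters instead of d*k = 1,048,576 for a full update
print(sum(p.numel() for p in (lora_A, lora_B)))
```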
35 | 36 | ## Common LoRA parameters in PEFT 37 | 38 | As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to: 39 | 40 | 1. Instantiate a base model. 41 | 2. Create a configuration (`LoraConfig`) where you define LoRA-specific parameters. 42 | 3. Wrap the base model with `get_peft_model()` to get a trainable `PeftModel`. 43 | 4. Train the `PeftModel` as you normally would train the base model. 44 | 45 | `LoraConfig` allows you to control how LoRA is applied to the base model through the following parameters (a short configuration sketch follows the list): 46 | 47 | - `r`: the rank of the update matrices, expressed in `int`. A lower rank results in smaller update matrices with fewer trainable parameters. 48 | - `target_modules`: The modules (for example, attention blocks) to apply the LoRA update matrices to. 49 | - `alpha`: LoRA scaling factor. 50 | - `bias`: Specifies whether the `bias` parameters should be trained. Can be `'none'`, `'all'` or `'lora_only'`. 51 | - `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task. 52 | - `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed. 53 | - `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default `PeftModel` looks for common layer patterns (`layers`, `h`, `blocks`, etc.); use this option for exotic and custom models. 54 | 
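As a rough illustration of how these parameters fit together, the sketch below builds a `LoraConfig` and wraps a base model. The model name, `target_modules` values, and hyperparameters are arbitrary example choices, not recommendations.

```py
from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, TaskType, get_peft_model

config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,                        # rank of the update matrices
    lora_alpha=32,              # scaling factor
    target_modules=["q", "v"],  # apply LoRA to the T5 attention query/value projections
    lora_dropout=0.1,
    bias="none",                # keep bias parameters frozen
)

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model = get_peft_model(model, config)
model.print_trainable_parameters()  # only the LoRA update matrices are trainable
```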
55 | ## LoRA examples 56 | 57 | For examples of applying LoRA to various downstream tasks, please refer to the following guides: 58 | 59 | * [Image classification using LoRA](../task_guides/image_classification_lora) 60 | * [Semantic segmentation](../task_guides/semantic_segmentation_lora) 61 | 62 | While the original paper focuses on language models, the technique can be applied to any dense layers in deep learning 63 | models. As such, you can leverage this technique with diffusion models. See the [Dreambooth fine-tuning with LoRA](../task_guides/dreambooth_lora) task guide for an example. 64 | -------------------------------------------------------------------------------- /peft_egg/docs/source/conceptual_guides/prompting.mdx: -------------------------------------------------------------------------------- 1 | # Prompting 2 | 3 | Training large pretrained language models is very time-consuming and compute-intensive. As they continue to grow in size, there is increasing interest in more efficient training methods such as *prompting*. Prompting primes a frozen pretrained model for a specific downstream task by including a text prompt that describes the task or even demonstrates an example of the task. With prompting, you can avoid fully training a separate model for each downstream task, and use the same frozen pretrained model instead. This is a lot easier because you can use the same model for several different tasks, and it is significantly more efficient to train and store a smaller set of prompt parameters than to train all the model's parameters. 4 | 5 | There are two categories of prompting methods: 6 | 7 | - hard prompts are manually handcrafted text prompts with discrete input tokens; the downside is that creating a good prompt requires a lot of effort 8 | - soft prompts are learnable tensors concatenated with the input embeddings that can be optimized to a dataset; the downside is that they aren't human readable because you aren't matching these "virtual tokens" to the embeddings of a real word 9 | 10 | This conceptual guide provides a brief overview of the soft prompt methods included in 🤗 PEFT: prompt tuning, prefix tuning, and P-tuning; a minimal configuration sketch for each appears below. 11 | 
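As a quick orientation before the method-by-method discussion, here is a hedged sketch of how each method is instantiated in 🤗 PEFT; the task types and parameter values are arbitrary examples.

```py
from peft import PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, TaskType

# Prompt tuning: learn embeddings for a few virtual tokens added to the input
prompt_tuning = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)

# Prefix tuning: trainable prefix vectors inserted at every layer of the model
prefix_tuning = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=20)

# P-tuning: virtual tokens optimized through a prompt encoder (an MLP or LSTM)
p_tuning = PromptEncoderConfig(
    task_type=TaskType.SEQ_CLS,
    num_virtual_tokens=20,
    encoder_hidden_size=128,
)
```

Each of these configs is then passed to `get_peft_model()` together with a base model, in the same way as a `LoraConfig`.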
12 | ## Prompt tuning 13 | 14 | [Figure: prompt tuning (image omitted)] 15 | 16 | 
17 | Only train and store a significantly smaller set of task-specific prompt parameters (image source). 18 | 19 | Prompt tuning was developed for text classification tasks on T5 models, and all downstream tasks are cast as a text generation task. For example, sequence classification usually assigns a single class label to a sequence of text. By casting it as a text generation task, the tokens that make up the class label are *generated*. Prompts are added to the input as a series of tokens. Typically, the model parameters are fixed which means the prompt tokens are also fixed by the model parameters. 20 | 21 | The key idea behind prompt tuning is that prompt tokens have their own parameters that are updated independently. This means you can keep the pretrained model's parameters frozen, and only update the gradients of the prompt token embeddings. The results are comparable to the traditional method of training the entire model, and prompt tuning performance scales as model size increases. 22 | 23 | Take a look at [Prompt tuning for causal language modeling](../task_guides/clm-prompt-tuning) for a step-by-step guide on how to train a model with prompt tuning. 24 | 25 | ## Prefix tuning 26 | 27 |
28 | [Figure: prefix tuning (image omitted)] 29 | 
30 | Optimize the prefix parameters for each task (image source). 31 | 32 | Prefix tuning was designed for natural language generation (NLG) tasks on GPT models. It is very similar to prompt tuning; prefix tuning also prepends a sequence of task-specific vectors to the input that can be trained and updated while keeping the rest of the pretrained model's parameters frozen. 33 | 34 | The main difference is that the prefix parameters are inserted in **all** of the model layers, whereas prompt tuning only adds the prompt parameters to the model input embeddings. The prefix parameters are also optimized by a separate feed-forward network (FFN) instead of training directly on the soft prompts because it causes instability and hurts performance. The FFN is discarded after updating the soft prompts. 35 | 36 | As a result, the authors found that prefix tuning demonstrates comparable performance to fully finetuning a model, despite having 1000x fewer parameters, and it performs even better in low-data settings. 37 | 38 | Take a look at [Prefix tuning for conditional generation](../task_guides/seq2seq-prefix-tuning) for a step-by-step guide on how to train a model with prefix tuning. 39 | 40 | ## P-tuning 41 | 42 |
43 | [Figure: P-tuning (image omitted)] 44 | 
45 | Prompt tokens can be inserted anywhere in the input sequence, and they are optimized by a prompt encoder (image source). 46 | 47 | P-tuning is designed for natural language understanding (NLU) tasks and all language models. 48 | It is another variation of a soft prompt method; P-tuning also adds a trainable embedding tensor that can be optimized to find better prompts, and it uses a prompt encoder (a bidirectional long-short term memory network, or LSTM) to optimize the prompt parameters. Unlike prefix tuning though: 49 | 50 | - the prompt tokens can be inserted anywhere in the input sequence, and they aren't restricted to only the beginning 51 | - the prompt tokens are only added to the input instead of adding them to every layer of the model 52 | - introducing *anchor* tokens can improve performance because they indicate characteristics of a component in the input sequence 53 | 54 | The results suggest that P-tuning is more efficient than manually crafting prompts, and it enables GPT-like models to compete with BERT-like models on NLU tasks. 55 | 56 | Take a look at [P-tuning for sequence classification](../task_guides/ptuning-seq-classification) for a step-by-step guide on how to train a model with P-tuning. -------------------------------------------------------------------------------- /peft_egg/docs/source/index.mdx: -------------------------------------------------------------------------------- 1 | 12 | 13 | # PEFT 14 | 15 | 🤗 PEFT, or Parameter-Efficient Fine-Tuning, is a library for efficiently adapting pre-trained language models (PLMs) to various downstream applications without fine-tuning all the model's parameters. 16 | PEFT methods only fine-tune a small number of (extra) model parameters, significantly decreasing computational and storage costs because fine-tuning large-scale PLMs is prohibitively costly. 17 | Recent state-of-the-art PEFT techniques achieve performance comparable to that of full fine-tuning. 18 | 19 | PEFT is seamlessly integrated with 🤗 Accelerate for large-scale models leveraging DeepSpeed and [Big Model Inference](https://huggingface.co/docs/accelerate/usage_guides/big_modeling). 20 | 21 | 
22 | The documentation is organized into four sections: 23 | 24 | - **Get started**: Start here if you're new to 🤗 PEFT to get an overview of the library's main features, and how to train a model with a PEFT method. 25 | - **How-to guides**: Practical guides demonstrating how to apply various PEFT methods across different types of tasks like image classification, causal language modeling, automatic speech recognition, and more. Learn how to use 🤗 PEFT with the DeepSpeed and Fully Sharded Data Parallel scripts. 26 | - **Conceptual guides**: Get a better theoretical understanding of how LoRA and various soft prompting methods help reduce the number of trainable parameters to make training more efficient. 27 | - **Reference**: Technical descriptions of how 🤗 PEFT classes and methods work. 28 | 
41 | 42 | ## Supported methods 43 | 44 | 1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf) 45 | 2. Prefix Tuning: [Prefix-Tuning: Optimizing Continuous Prompts for Generation](https://aclanthology.org/2021.acl-long.353/), [P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf) 46 | 3. P-Tuning: [GPT Understands, Too](https://arxiv.org/pdf/2103.10385.pdf) 47 | 4. Prompt Tuning: [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf) 48 | 5. AdaLoRA: [Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/abs/2303.10512) 49 | 6. [LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention](https://github.com/ZrrSkywalker/LLaMA-Adapter) 50 | 51 | ## Supported models 52 | 53 | The tables provided below list the PEFT methods and models supported for each task. To apply a particular PEFT method for 54 | a task, please refer to the corresponding Task guides. 55 | 56 | ### Causal Language Modeling 57 | 58 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 59 | |--------------| ---- | ---- | ---- | ---- | 60 | | GPT-2 | ✅ | ✅ | ✅ | ✅ | 61 | | Bloom | ✅ | ✅ | ✅ | ✅ | 62 | | OPT | ✅ | ✅ | ✅ | ✅ | 63 | | GPT-Neo | ✅ | ✅ | ✅ | ✅ | 64 | | GPT-J | ✅ | ✅ | ✅ | ✅ | 65 | | GPT-NeoX-20B | ✅ | ✅ | ✅ | ✅ | 66 | | LLaMA | ✅ | ✅ | ✅ | ✅ | 67 | | ChatGLM | ✅ | ✅ | ✅ | ✅ | 68 | 69 | ### Conditional Generation 70 | 71 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 72 | | --------- | ---- | ---- | ---- | ---- | 73 | | T5 | ✅ | ✅ | ✅ | ✅ | 74 | | BART | ✅ | ✅ | ✅ | ✅ | 75 | 76 | ### Sequence Classification 77 | 78 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 79 | | --------- | ---- | ---- | ---- | ---- | 80 | | BERT | ✅ | ✅ | ✅ | ✅ | 81 | | RoBERTa | ✅ | ✅ | ✅ | ✅ | 82 | | GPT-2 | ✅ | ✅ | ✅ | ✅ | 83 | | Bloom | ✅ | ✅ | ✅ | ✅ | 84 | | OPT | ✅ | ✅ | ✅ | ✅ | 85 | | GPT-Neo | ✅ | ✅ | ✅ | ✅ | 86 | | GPT-J | ✅ | ✅ | ✅ | ✅ | 87 | | Deberta | ✅ | | ✅ | ✅ | 88 | | Deberta-v2 | ✅ | | ✅ | ✅ | 89 | 90 | ### Token Classification 91 | 92 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 93 | | --------- | ---- | ---- | ---- | ---- | 94 | | BERT | ✅ | ✅ | | | 95 | | RoBERTa | ✅ | ✅ | | | 96 | | GPT-2 | ✅ | ✅ | | | 97 | | Bloom | ✅ | ✅ | | | 98 | | OPT | ✅ | ✅ | | | 99 | | GPT-Neo | ✅ | ✅ | | | 100 | | GPT-J | ✅ | ✅ | | | 101 | | Deberta | ✅ | | | | 102 | | Deberta-v2 | ✅ | | | | 103 | 104 | ### Text-to-Image Generation 105 | 106 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 107 | | --------- | ---- | ---- | ---- | ---- | 108 | | Stable Diffusion | ✅ | | | | 109 | 110 | 111 | ### Image Classification 112 | 113 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 114 | | --------- | ---- | ---- | ---- | ---- | 115 | | ViT | ✅ | | | | 116 | | Swin | ✅ | | | | 117 | 118 | ### Image to text (Multi-modal models) 119 | 120 | We have tested LoRA for [ViT](https://huggingface.co/docs/transformers/model_doc/vit) and [Swin](https://huggingface.co/docs/transformers/model_doc/swin) for fine-tuning on image classification. 121 | However, it should be possible to use LoRA for any [ViT-based model](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads&search=vit) from 🤗 Transformers. 122 | Check out the [Image classification](/task_guides/image_classification_lora) task guide to learn more. 
If you run into problems, please open an issue. 123 | 124 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 125 | | --------- | ---- | ---- | ---- | ---- | 126 | | Blip-2 | ✅ | | | | 127 | 128 | 129 | ### Semantic Segmentation 130 | 131 | As with image-to-text models, you should be able to apply LoRA to any of the [segmentation models](https://huggingface.co/models?pipeline_tag=image-segmentation&sort=downloads). 132 | It's worth noting that we haven't tested this with every architecture yet. Therefore, if you come across any issues, kindly create an issue report. 133 | 134 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning | 135 | | --------- | ---- | ---- | ---- | ---- | 136 | | SegFormer | ✅ | | | | 137 | 138 | -------------------------------------------------------------------------------- /peft_egg/docs/source/install.mdx: -------------------------------------------------------------------------------- 1 | 12 | 13 | # Installation 14 | 15 | Before you start, you will need to set up your environment, install the appropriate packages, and configure 🤗 PEFT. 🤗 PEFT is tested on **Python 3.7+**. 16 | 17 | 🤗 PEFT is available on PyPI, as well as GitHub: 18 | 19 | ## pip 20 | 21 | To install 🤗 PEFT from PyPI: 22 | 23 | ```bash 24 | pip install peft 25 | ``` 26 | 27 | ## Source 28 | 29 | New features that haven't been released yet are added every day, which also means there may be some bugs. To try them out, install from the GitHub repository: 30 | 31 | ```bash 32 | pip install git+https://github.com/huggingface/peft 33 | ``` 34 | 35 | If you're working on contributing to the library or wish to play with the source code and see live 36 | results as you run the code, an editable version can be installed from a locally-cloned version of the 37 | repository: 38 | 39 | ```bash 40 | git clone https://github.com/huggingface/peft 41 | cd peft 42 | pip install -e . 43 | ``` 44 | -------------------------------------------------------------------------------- /peft_egg/docs/source/package_reference/config.mdx: -------------------------------------------------------------------------------- 1 | # Configuration 2 | 3 | The configuration classes store the configuration of a [`PeftModel`], PEFT adapter models, and the configurations of [`PrefixTuning`], [`PromptTuning`], and [`PromptEncoder`]. They contain methods for saving and loading model configurations from the Hub, specifying the PEFT method to use, the type of task to perform, and model configurations like the number of layers and the number of attention heads. 4 | 5 | ## PeftConfigMixin 6 | 7 | [[autodoc]] utils.config.PeftConfigMixin 8 | - all 9 | 10 | ## PeftConfig 11 | 12 | [[autodoc]] PeftConfig 13 | - all 14 | 15 | ## PromptLearningConfig 16 | 17 | [[autodoc]] PromptLearningConfig 18 | - all 19 | -------------------------------------------------------------------------------- /peft_egg/docs/source/package_reference/peft_model.mdx: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | [`PeftModel`] is the base model class for specifying the base Transformer model and configuration to apply a PEFT method to. The base `PeftModel` contains methods for loading and saving models from the Hub, and supports the [`PromptEncoder`] for prompt learning. 4 | 5 | ## PeftModel 6 | 7 | [[autodoc]] PeftModel 8 | - all 9 | 10 | ## PeftModelForSequenceClassification 11 | 12 | A `PeftModel` for sequence classification tasks. 
13 | 14 | [[autodoc]] PeftModelForSequenceClassification 15 | - all 16 | 17 | ## PeftModelForTokenClassification 18 | 19 | A `PeftModel` for token classification tasks. 20 | 21 | [[autodoc]] PeftModelForTokenClassification 22 | - all 23 | 24 | ## PeftModelForCausalLM 25 | 26 | A `PeftModel` for causal language modeling. 27 | 28 | [[autodoc]] PeftModelForCausalLM 29 | - all 30 | 31 | ## PeftModelForSeq2SeqLM 32 | 33 | A `PeftModel` for sequence-to-sequence language modeling. 34 | 35 | [[autodoc]] PeftModelForSeq2SeqLM 36 | - all 37 | -------------------------------------------------------------------------------- /peft_egg/docs/source/package_reference/tuners.mdx: -------------------------------------------------------------------------------- 1 | # Tuners 2 | 3 | Each tuner (or PEFT method) has a configuration and model. 4 | 5 | ## LoRA 6 | 7 | For finetuning a model with LoRA. 8 | 9 | [[autodoc]] LoraConfig 10 | 11 | [[autodoc]] LoraModel 12 | 13 | [[autodoc]] tuners.lora.LoraLayer 14 | 15 | [[autodoc]] tuners.lora.Linear 16 | 17 | ## P-tuning 18 | 19 | [[autodoc]] tuners.p_tuning.PromptEncoderConfig 20 | 21 | [[autodoc]] tuners.p_tuning.PromptEncoder 22 | 23 | ## Prefix tuning 24 | 25 | [[autodoc]] tuners.prefix_tuning.PrefixTuningConfig 26 | 27 | [[autodoc]] tuners.prefix_tuning.PrefixEncoder 28 | 29 | ## Prompt tuning 30 | 31 | [[autodoc]] tuners.prompt_tuning.PromptTuningConfig 32 | 33 | [[autodoc]] tuners.prompt_tuning.PromptEmbedding -------------------------------------------------------------------------------- /peft_egg/docs/source/quicktour.mdx: -------------------------------------------------------------------------------- 1 | 12 | 13 | # Quicktour 14 | 15 | 🤗 PEFT contains parameter-efficient finetuning methods for training large pretrained models. The traditional paradigm is to finetune all of a model's parameters for each downstream task, but this is becoming exceedingly costly and impractical because of the enormous number of parameters in models today. Instead, it is more efficient to train a smaller number of prompt parameters or use a reparametrization method like low-rank adaptation (LoRA) to reduce the number of trainable parameters. 16 | 17 | This quicktour will show you 🤗 PEFT's main features and help you train large pretrained models that would typically be inaccessible on consumer devices. You'll see how to train the 1.2B parameter [`bigscience/mt0-large`](https://huggingface.co/bigscience/mt0-large) model with LoRA to generate a classification label and use it for inference. 18 | 19 | ## PeftConfig 20 | 21 | Each 🤗 PEFT method is defined by a [`PeftConfig`] class that stores all the important parameters for building a [`PeftModel`]. 22 | 23 | Because you're going to use LoRA, you'll need to load and create a [`LoraConfig`] class. Within `LoraConfig`, specify the following parameters: 24 | 25 | - the `task_type`, or sequence-to-sequence language modeling in this case 26 | - `inference_mode`, whether you're using the model for inference or not 27 | - `r`, the dimension of the low-rank matrices 28 | - `lora_alpha`, the scaling factor for the low-rank matrices 29 | - `lora_dropout`, the dropout probability of the LoRA layers 30 | 31 | ```python 32 | from peft import LoraConfig, TaskType 33 | 34 | peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1) 35 | ``` 36 | 37 | 38 | 39 | 💡 See the [`LoraConfig`] reference for more details about other parameters you can adjust. 
40 | 41 | 42 | 43 | ## PeftModel 44 | 45 | A [`PeftModel`] is created by the [`get_peft_model`] function. It takes a base model - which you can load from the 🤗 Transformers library - and the [`PeftConfig`] containing the instructions for how to configure a model for a specific 🤗 PEFT method. 46 | 47 | Start by loading the base model you want to finetune. 48 | 49 | ```python 50 | from transformers import AutoModelForSeq2SeqLM 51 | 52 | model_name_or_path = "bigscience/mt0-large" 53 | tokenizer_name_or_path = "bigscience/mt0-large" 54 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) 55 | ``` 56 | 57 | Wrap your base model and `peft_config` with the `get_peft_model` function to create a [`PeftModel`]. To get a sense of the number of trainable parameters in your model, use the [`print_trainable_parameters`] method. In this case, you're only training 0.19% of the model's parameters! 🤏 58 | 59 | ```python 60 | from peft import get_peft_model 61 | 62 | model = get_peft_model(model, peft_config) 63 | model.print_trainable_parameters() 64 | "output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282" 65 | ``` 66 | 67 | That is it 🎉! Now you can train the model using the 🤗 Transformers [`~transformers.Trainer`], 🤗 Accelerate, or any custom PyTorch training loop. 68 | 69 | ## Save and load a model 70 | 71 | After your model is finished training, you can save your model to a directory using the [`~transformers.PreTrainedModel.save_pretrained`] function. You can also save your model to the Hub (make sure you log in to your Hugging Face account first) with the [`~transformers.PreTrainedModel.push_to_hub`] function. 72 | 73 | ```python 74 | model.save_pretrained("output_dir") 75 | 76 | # if pushing to Hub 77 | from huggingface_hub import notebook_login 78 | 79 | notebook_login() 80 | model.push_to_hub("my_awesome_peft_model") 81 | ``` 82 | 83 | This only saves the incremental 🤗 PEFT weights that were trained, meaning it is super efficient to store, transfer, and load. For example, this [`bigscience/T0_3B`](https://huggingface.co/smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM) model trained with LoRA on the [`twitter_complaints`](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints/train) subset of the RAFT [dataset](https://huggingface.co/datasets/ought/raft) only contains two files: `adapter_config.json` and `adapter_model.bin`. The latter file is just 19MB! 84 | 85 | Easily load your model for inference using the [`~transformers.PreTrainedModel.from_pretrained`] function: 86 | 87 | ```diff 88 | from transformers import AutoModelForSeq2SeqLM 89 | + from peft import PeftModel, PeftConfig 90 | 91 | + peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM" 92 | + config = PeftConfig.from_pretrained(peft_model_id) 93 | model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path) 94 | + model = PeftModel.from_pretrained(model, peft_model_id) 95 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) 96 | 97 | model = model.to(device) 98 | model.eval() 99 | inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. 
Label :", return_tensors="pt") 100 | 101 | with torch.no_grad(): 102 | outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10) 103 | print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]) 104 | 'complaint' 105 | ``` 106 | 107 | ## Next steps 108 | 109 | Now that you've seen how to train a model with one of the 🤗 PEFT methods, we encourage you to try out some of the other methods like prompt tuning. The steps are very similar to the ones shown in this quickstart; prepare a [`PeftConfig`] for a 🤗 PEFT method, and use the `get_peft_model` to create a [`PeftModel`] from the configuration and base model. Then you can train it however you like! 110 | 111 | Feel free to also take a look at the task guides if you're interested in training a model with a 🤗 PEFT method for a specific task such as semantic segmentation, multilingual automatic speech recognition, DreamBooth, and token classification. -------------------------------------------------------------------------------- /peft_egg/grammar.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # 3 | # lora_A = torch.randn((6,1)) 4 | # tensor_list = [torch.tensor(float(i)) for i in range(1,5)] 5 | # print(f'lora_A: {lora_A}') 6 | # print(f'tensor_list: {tensor_list}') 7 | # 8 | # lora_A.requires_grad = True 9 | # for x in tensor_list: 10 | # x.requires_grad = True 11 | # 12 | # c = [] 13 | # for x in tensor_list: 14 | # c.append(lora_A[1] * x) 15 | # 16 | # d = torch.stack(c,0) 17 | # print(f'stacked d: {d}') 18 | # 19 | # d.sum().backward() 20 | # 21 | # print(lora_A.grad) 22 | 23 | 24 | import torch 25 | import torch.nn as nn 26 | class Conv1D(nn.Module): 27 | """ 28 | 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). 29 | 30 | Basically works like a linear layer but the weights are transposed. 31 | 32 | Args: 33 | nf (`int`): The number of output features. 34 | nx (`int`): The number of input features. 
35 | """ 36 | 37 | def __init__(self, nf, nx): 38 | super().__init__() 39 | self.nf = nf 40 | self.weight = nn.Parameter(torch.empty(nx, nf)) 41 | self.bias = nn.Parameter(torch.zeros(nf)) 42 | nn.init.normal_(self.weight, std=0.02) 43 | 44 | def forward(self, x): 45 | size_out = x.size()[:-1] + (self.nf,) 46 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 47 | x = x.view(size_out) 48 | return x 49 | a = Conv1D(500,300) 50 | 51 | print(a.weight.shape) -------------------------------------------------------------------------------- /peft_egg/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py36'] 4 | 5 | [tool.ruff] 6 | ignore = ["C901", "E501", "E741", "W605"] 7 | select = ["C", "E", "F", "I", "W"] 8 | line-length = 119 9 | 10 | [tool.ruff.isort] 11 | lines-after-imports = 2 12 | known-first-party = ["peft"] 13 | 14 | [isort] 15 | default_section = "FIRSTPARTY" 16 | known_first_party = "peft" 17 | known_third_party = [ 18 | "numpy", 19 | "torch", 20 | "accelerate", 21 | "transformers", 22 | ] 23 | line_length = 119 24 | lines_after_imports = 2 25 | multi_line_output = 3 26 | include_trailing_comma = true 27 | force_grid_wrap = 0 28 | use_parentheses = true 29 | ensure_newline_before_comments = true 30 | 31 | [tool.pytest] 32 | doctest_optionflags = [ 33 | "NORMALIZE_WHITESPACE", 34 | "ELLIPSIS", 35 | "NUMBER", 36 | ] -------------------------------------------------------------------------------- /peft_egg/scripts/log_reports.py: -------------------------------------------------------------------------------- 1 | import json, os 2 | from pathlib import Path 3 | from datetime import date 4 | from tabulate import tabulate 5 | 6 | failed = [] 7 | passed = [] 8 | 9 | group_info = [] 10 | 11 | total_num_failed = 0 12 | empty_file = False or len(list(Path().glob("*.log"))) == 0 13 | for log in Path().glob("*.log"): 14 | section_num_failed = 0 15 | with open(log, "r") as f: 16 | nb_lines = sum(1 for _ in f) 17 | for i, line in f: 18 | line = json.loads(line) 19 | if line.get("nodeid", "") != "": 20 | test = line["nodeid"] 21 | if line.get("duration", None) is not None: 22 | duration = f'{line["duration"]:.4f}' 23 | if line.get("outcome", "") == "failed": 24 | section_num_failed += 1 25 | failed.append([test, duration, log.name.split('_')[0]]) 26 | total_num_failed += 1 27 | else: 28 | passed.append([test, duration, log.name.split('_')[0]]) 29 | if nb_lines == 0: 30 | empty_file = True 31 | group_info.append([str(log), section_num_failed, failed]) 32 | os.remove(log) 33 | failed = [] 34 | no_error_payload = { 35 | "type": "section", 36 | "text": { 37 | "type": "plain_text", 38 | "text": "🌞 There were no failures!" 
if not empty_file else "Something went wrong - please check GH action results.", 39 | "emoji": True 40 | } 41 | } 42 | 43 | message = "" 44 | payload = [ 45 | { 46 | "type": "header", 47 | "text": { 48 | "type": "plain_text", 49 | "text": "🤗 Results of the {} PEFT scheduled tests.".format(os.environ.get("TEST_TYPE", "")), 50 | } 51 | }, 52 | ] 53 | if total_num_failed > 0: 54 | for name, num_failed, failed_tests in group_info: 55 | if num_failed > 0: 56 | if num_failed == 1: 57 | message += f"*{name}: {num_failed} failed test*\n" 58 | else: 59 | message += f"*{name}: {num_failed} failed tests*\n" 60 | failed_table = [] 61 | for test in failed_tests: 62 | failed_table.append(test[0].split("::")) 63 | failed_table = tabulate(failed_table, headers=["Test Location", "Test Case", "Test Name"], showindex="always", tablefmt="grid", maxcolwidths=[12, 12, 12]) 64 | message += '\n```\n' +failed_table + '\n```' 65 | print(f'### {message}') 66 | else: 67 | payload.append(no_error_payload) 68 | 69 | 70 | if os.environ.get("TEST_TYPE", "") != "": 71 | from slack_sdk import WebClient 72 | 73 | if len(message) != 0: 74 | md_report = { 75 | "type": "section", 76 | "text": { 77 | "type": "mrkdwn", 78 | "text": message 79 | }, 80 | } 81 | payload.append(md_report) 82 | action_button = { 83 | "type": "section", 84 | "text": { 85 | "type": "mrkdwn", 86 | "text": "*For more details:*" 87 | }, 88 | "accessory": { 89 | "type": "button", 90 | "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, 91 | "url": f"https://github.com/huggingface/peft/actions/runs/{os.environ['GITHUB_RUN_ID']}", 92 | }, 93 | } 94 | payload.append(action_button) 95 | 96 | date_report = { 97 | "type": "context", 98 | "elements": [ 99 | { 100 | "type": "plain_text", 101 | "text": f"Nightly {os.environ.get('TEST_TYPE')} test results for {date.today()}", 102 | }, 103 | ], 104 | } 105 | payload.append(date_report) 106 | 107 | print(payload) 108 | 109 | client = WebClient(token=os.environ.get("SLACK_API_TOKEN")) 110 | client.chat_postMessage(channel="#peft-ci-daily", text=message, blocks=payload) 111 | -------------------------------------------------------------------------------- /peft_egg/scripts/stale.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Script to close stale issue. Taken in part from the AllenNLP repository. 16 | https://github.com/allenai/allennlp. 
17 | """ 18 | from datetime import datetime as dt 19 | import os 20 | 21 | from github import Github 22 | 23 | 24 | LABELS_TO_EXEMPT = [ 25 | "good first issue", 26 | "good second issue", 27 | "good difficult issue", 28 | "feature request", 29 | "new model", 30 | "wip", 31 | "PRs welcome to address this", 32 | ] 33 | 34 | 35 | def main(): 36 | g = Github(os.environ["GITHUB_TOKEN"]) 37 | repo = g.get_repo("huggingface/peft") 38 | open_issues = repo.get_issues(state="open") 39 | 40 | for issue in open_issues: 41 | comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) 42 | last_comment = comments[0] if len(comments) > 0 else None 43 | if ( 44 | last_comment is not None and last_comment.user.login == "github-actions[bot]" 45 | and (dt.utcnow() - issue.updated_at).days > 7 46 | and (dt.utcnow() - issue.created_at).days >= 30 47 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 48 | ): 49 | issue.edit(state="closed") 50 | elif ( 51 | (dt.utcnow() - issue.updated_at).days > 23 52 | and (dt.utcnow() - issue.created_at).days >= 30 53 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) 54 | ): 55 | issue.create_comment( 56 | "This issue has been automatically marked as stale because it has not had " 57 | "recent activity. If you think this still needs to be addressed " 58 | "please comment on this thread.\n\n" 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /peft_egg/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from setuptools import find_packages, setup 16 | 17 | extras = {} 18 | extras["quality"] = ["black ~= 22.0", "ruff>=0.0.241", "urllib3<=2.0.0"] 19 | extras["docs_specific"] = ["hf-doc-builder"] 20 | extras["dev"] = extras["quality"] + extras["docs_specific"] 21 | extras["test"] = extras["dev"] + ["pytest", "pytest-xdist", "parameterized", "datasets"] 22 | 23 | setup( 24 | name="peft", 25 | version="0.4.0.dev0", 26 | description="Parameter-Efficient Fine-Tuning (PEFT)", 27 | license_files=["LICENSE"], 28 | long_description=open("README.md", "r", encoding="utf-8").read(), 29 | long_description_content_type="text/markdown", 30 | keywords="deep learning", 31 | license="Apache", 32 | author="The HuggingFace team", 33 | author_email="sourab@huggingface.co", 34 | url="https://github.com/huggingface/peft", 35 | package_dir={"": "src"}, 36 | packages=find_packages("src"), 37 | entry_points={}, 38 | python_requires=">=3.7.0", 39 | install_requires=[ 40 | "numpy>=1.17", 41 | "packaging>=20.0", 42 | "psutil", 43 | "pyyaml", 44 | "torch>=1.13.0", 45 | "transformers", 46 | "accelerate", 47 | "safetensors", 48 | ], 49 | extras_require=extras, 50 | classifiers=[ 51 | "Development Status :: 5 - Production/Stable", 52 | "Intended Audience :: Developers", 53 | "Intended Audience :: Education", 54 | "Intended Audience :: Science/Research", 55 | "License :: OSI Approved :: Apache Software License", 56 | "Operating System :: OS Independent", 57 | "Programming Language :: Python :: 3", 58 | "Programming Language :: Python :: 3.7", 59 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 60 | ], 61 | ) 62 | 63 | # Release checklist 64 | # 1. Change the version in __init__.py and setup.py. 65 | # 2. Commit these changes with the message: "Release: VERSION" 66 | # 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " 67 | # Push the tag to git: git push --tags origin main 68 | # 4. Run the following commands in the top-level directory: 69 | # python setup.py bdist_wheel 70 | # python setup.py sdist 71 | # 5. Upload the package to the pypi test server first: 72 | # twine upload dist/* -r pypitest 73 | # twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 74 | # 6. Check that you can install it in a virtualenv by running: 75 | # pip install -i https://testpypi.python.org/pypi peft 76 | # 7. Upload the final version to actual pypi: 77 | # twine upload dist/* -r pypi 78 | # 8. Add release notes to the tag in github once everything is looking hunky-dory. 79 | # 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master 80 | -------------------------------------------------------------------------------- /peft_egg/src/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | __version__ = "0.4.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | PeftModelForQuestionAnswering, 30 | ) 31 | from .tuners import ( 32 | AdaptionPromptConfig, 33 | AdaptionPromptModel, 34 | LoraConfig, 35 | LoraModel, 36 | AdaLoraConfig, 37 | AdaLoraModel, 38 | PrefixEncoder, 39 | PrefixTuningConfig, 40 | PromptEmbedding, 41 | PromptEncoder, 42 | PromptEncoderConfig, 43 | PromptEncoderReparameterizationType, 44 | PromptTuningConfig, 45 | PromptTuningInit, 46 | MeloConfig, 47 | MeloModel 48 | ) 49 | from .utils import ( 50 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 51 | PeftConfig, 52 | PeftType, 53 | PromptLearningConfig, 54 | TaskType, 55 | bloom_model_postprocess_past_key_value, 56 | get_peft_model_state_dict, 57 | prepare_model_for_int8_training, 58 | prepare_model_for_kbit_training, 59 | set_peft_model_state_dict, 60 | shift_tokens_right, 61 | ) 62 | -------------------------------------------------------------------------------- /peft_egg/src/peft/import_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import importlib 16 | 17 | 18 | def is_bnb_available(): 19 | return importlib.util.find_spec("bitsandbytes") is not None 20 | 21 | 22 | def is_bnb_4bit_available(): 23 | if not is_bnb_available(): 24 | return False 25 | 26 | import bitsandbytes as bnb 27 | 28 | return hasattr(bnb.nn, "Linear4bit") 29 | -------------------------------------------------------------------------------- /peft_egg/src/peft/mapping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .peft_model import ( 17 | PeftModel, 18 | PeftModelForCausalLM, 19 | PeftModelForQuestionAnswering, 20 | PeftModelForSeq2SeqLM, 21 | PeftModelForSequenceClassification, 22 | PeftModelForTokenClassification, 23 | ) 24 | from .tuners import ( 25 | AdaLoraConfig, 26 | AdaptionPromptConfig, 27 | LoraConfig, 28 | PrefixTuningConfig, 29 | PromptEncoderConfig, 30 | PromptTuningConfig, 31 | ) 32 | from .utils import PromptLearningConfig 33 | 34 | 35 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { 36 | "SEQ_CLS": PeftModelForSequenceClassification, 37 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, 38 | "CAUSAL_LM": PeftModelForCausalLM, 39 | "TOKEN_CLS": PeftModelForTokenClassification, 40 | "QUESTION_ANS": PeftModelForQuestionAnswering, 41 | } 42 | 43 | PEFT_TYPE_TO_CONFIG_MAPPING = { 44 | "ADAPTION_PROMPT": AdaptionPromptConfig, 45 | "PROMPT_TUNING": PromptTuningConfig, 46 | "PREFIX_TUNING": PrefixTuningConfig, 47 | "P_TUNING": PromptEncoderConfig, 48 | "LORA": LoraConfig, 49 | "ADALORA": AdaLoraConfig, 50 | } 51 | 52 | 53 | def get_peft_config(config_dict): 54 | """ 55 | Returns a Peft config object from a dictionary. 56 | 57 | Args: 58 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. 59 | """ 60 | 61 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) 62 | 63 | 64 | def _prepare_prompt_learning_config(peft_config, model_config): 65 | if peft_config.num_layers is None: 66 | if "num_hidden_layers" in model_config: 67 | num_layers = model_config["num_hidden_layers"] 68 | elif "num_layers" in model_config: 69 | num_layers = model_config["num_layers"] 70 | elif "n_layer" in model_config: 71 | num_layers = model_config["n_layer"] 72 | else: 73 | raise ValueError("Please specify `num_layers` in `peft_config`") 74 | peft_config.num_layers = num_layers 75 | 76 | if peft_config.token_dim is None: 77 | if "hidden_size" in model_config: 78 | token_dim = model_config["hidden_size"] 79 | elif "n_embd" in model_config: 80 | token_dim = model_config["n_embd"] 81 | elif "d_model" in model_config: 82 | token_dim = model_config["d_model"] 83 | else: 84 | raise ValueError("Please specify `token_dim` in `peft_config`") 85 | peft_config.token_dim = token_dim 86 | 87 | if peft_config.num_attention_heads is None: 88 | if "num_attention_heads" in model_config: 89 | num_attention_heads = model_config["num_attention_heads"] 90 | elif "n_head" in model_config: 91 | num_attention_heads = model_config["n_head"] 92 | elif "num_heads" in model_config: 93 | num_attention_heads = model_config["num_heads"] 94 | elif "encoder_attention_heads" in model_config: 95 | num_attention_heads = model_config["encoder_attention_heads"] 96 | else: 97 | raise ValueError("Please specify `num_attention_heads` in `peft_config`") 98 | peft_config.num_attention_heads = num_attention_heads 99 | 100 | if getattr(peft_config, "encoder_hidden_size", None) is None: 101 | setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) 102 | 103 | return peft_config 104 | 105 | 106 | def get_peft_model(model, peft_config, adapter_name="default") -> PeftModel: 107 | """ 108 | Returns a Peft model object from a model and a config. 109 | 110 | Args: 111 | model ([`transformers.PreTrainedModel`]): Model to be wrapped. 112 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. 
113 | """ 114 | model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config 115 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 116 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( 117 | peft_config, PromptLearningConfig 118 | ): 119 | return PeftModel(model, peft_config, adapter_name=adapter_name) 120 | if isinstance(peft_config, PromptLearningConfig): 121 | peft_config = _prepare_prompt_learning_config(peft_config, model_config) 122 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) 123 | -------------------------------------------------------------------------------- /peft_egg/src/peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel 21 | from .lora import LoraConfig, LoraModel 22 | from .adalora import AdaLoraConfig, AdaLoraModel 23 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 24 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 25 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 26 | from .melo import MeloConfig, MeloModel -------------------------------------------------------------------------------- /peft_egg/src/peft/tuners/p_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import enum 17 | import warnings 18 | from dataclasses import dataclass, field 19 | from typing import Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptEncoderReparameterizationType(str, enum.Enum): 27 | MLP = "MLP" 28 | LSTM = "LSTM" 29 | 30 | 31 | @dataclass 32 | class PromptEncoderConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEncoder`]. 35 | 36 | Args: 37 | encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]): 38 | The type of reparameterization to use. 39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 40 | encoder_num_layers (`int`): The number of layers of the prompt encoder. 41 | encoder_dropout (`float`): The dropout probability of the prompt encoder. 42 | """ 43 | 44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field( 45 | default=PromptEncoderReparameterizationType.MLP, 46 | metadata={"help": "How to reparameterize the prompt encoder"}, 47 | ) 48 | encoder_hidden_size: int = field( 49 | default=None, 50 | metadata={"help": "The hidden size of the prompt encoder"}, 51 | ) 52 | encoder_num_layers: int = field( 53 | default=2, 54 | metadata={"help": "The number of layers of the prompt encoder"}, 55 | ) 56 | encoder_dropout: float = field( 57 | default=0.0, 58 | metadata={"help": "The dropout of the prompt encoder"}, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.P_TUNING 63 | 64 | 65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py 66 | # with some refactor 67 | class PromptEncoder(torch.nn.Module): 68 | """ 69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. 70 | 71 | Args: 72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder. 73 | 74 | Example: 75 | 76 | ```py 77 | >>> from peft import PromptEncoder, PromptEncoderConfig 78 | 79 | >>> config = PromptEncoderConfig( 80 | ... peft_type="P_TUNING", 81 | ... task_type="SEQ_2_SEQ_LM", 82 | ... num_virtual_tokens=20, 83 | ... token_dim=768, 84 | ... num_transformer_submodules=1, 85 | ... num_attention_heads=12, 86 | ... num_layers=12, 87 | ... encoder_reparameterization_type="MLP", 88 | ... encoder_hidden_size=768, 89 | ... ) 90 | 91 | >>> prompt_encoder = PromptEncoder(config) 92 | ``` 93 | 94 | **Attributes**: 95 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder. 96 | - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`. 97 | - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and 98 | `encoder_reparameterization_type="LSTM"`. 99 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model. 100 | - **input_size** (`int`) -- The input size of the prompt encoder. 101 | - **output_size** (`int`) -- The output size of the prompt encoder. 102 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder. 103 | - **total_virtual_tokens** (`int`): The total number of virtual tokens of the 104 | prompt encoder. 105 | - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]): The encoder type of the prompt 106 | encoder. 
107 | 108 | 109 | Input shape: (`batch_size`, `total_virtual_tokens`) 110 | 111 | Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 112 | """ 113 | 114 | def __init__(self, config): 115 | super().__init__() 116 | self.token_dim = config.token_dim 117 | self.input_size = self.token_dim 118 | self.output_size = self.token_dim 119 | self.hidden_size = config.encoder_hidden_size 120 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 121 | self.encoder_type = config.encoder_reparameterization_type 122 | 123 | # embedding 124 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) 125 | if not config.inference_mode: 126 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 127 | lstm_dropout = config.encoder_dropout 128 | num_layers = config.encoder_num_layers 129 | # LSTM 130 | self.lstm_head = torch.nn.LSTM( 131 | input_size=self.input_size, 132 | hidden_size=self.hidden_size, 133 | num_layers=num_layers, 134 | dropout=lstm_dropout, 135 | bidirectional=True, 136 | batch_first=True, 137 | ) 138 | 139 | self.mlp_head = torch.nn.Sequential( 140 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2), 141 | torch.nn.ReLU(), 142 | torch.nn.Linear(self.hidden_size * 2, self.output_size), 143 | ) 144 | 145 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 146 | warnings.warn( 147 | f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." 148 | ) 149 | layers = [ 150 | torch.nn.Linear(self.input_size, self.hidden_size), 151 | torch.nn.ReLU(), 152 | torch.nn.Linear(self.hidden_size, self.hidden_size), 153 | torch.nn.ReLU(), 154 | torch.nn.Linear(self.hidden_size, self.output_size), 155 | ] 156 | self.mlp_head = torch.nn.Sequential(*layers) 157 | 158 | else: 159 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 160 | 161 | def forward(self, indices): 162 | input_embeds = self.embedding(indices) 163 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 164 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]) 165 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 166 | output_embeds = self.mlp_head(input_embeds) 167 | else: 168 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 169 | 170 | return output_embeds 171 | -------------------------------------------------------------------------------- /peft_egg/src/peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
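# --- Editorial sketch (illustrative addition, not part of the original file) ---
# The PrefixEncoder defined in this module turns prefix indices of shape
# (batch_size, num_virtual_tokens) into flattened past key/values of shape
# (batch_size, num_virtual_tokens, 2 * num_layers * token_dim) -- the factor of
# 2 covers one key and one value tensor per layer. A minimal smoke test,
# reusing the config values from the class docstring example further below:
#
#     import torch
#     from peft import PrefixEncoder, PrefixTuningConfig
#
#     config = PrefixTuningConfig(
#         task_type="SEQ_2_SEQ_LM",
#         num_virtual_tokens=20,
#         token_dim=768,
#         num_transformer_submodules=1,
#         num_attention_heads=12,
#         num_layers=12,
#         encoder_hidden_size=768,
#     )
#     prefix_encoder = PrefixEncoder(config)
#     prefix = torch.arange(config.num_virtual_tokens).unsqueeze(0)  # (1, 20)
#     past_key_values = prefix_encoder(prefix)
#     assert past_key_values.shape == (1, 20, 2 * 12 * 768)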
15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The `torch.nn` model to encode the prefix. 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 55 | 56 | Example: 57 | 58 | ```py 59 | >>> from peft import PrefixEncoder, PrefixTuningConfig 60 | 61 | >>> config = PrefixTuningConfig( 62 | ... peft_type="PREFIX_TUNING", 63 | ... task_type="SEQ_2_SEQ_LM", 64 | ... num_virtual_tokens=20, 65 | ... token_dim=768, 66 | ... num_transformer_submodules=1, 67 | ... num_attention_heads=12, 68 | ... num_layers=12, 69 | ... encoder_hidden_size=768, 70 | ... ) 71 | >>> prefix_encoder = PrefixEncoder(config) 72 | ``` 73 | 74 | **Attributes**: 75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder. 76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if 77 | `prefix_projection` is `True`. 78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 79 | 80 | Input shape: (`batch_size`, `num_virtual_tokens`) 81 | 82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`) 83 | """ 84 | 85 | def __init__(self, config): 86 | super().__init__() 87 | self.prefix_projection = config.prefix_projection 88 | token_dim = config.token_dim 89 | num_layers = config.num_layers 90 | encoder_hidden_size = config.encoder_hidden_size 91 | num_virtual_tokens = config.num_virtual_tokens 92 | if self.prefix_projection and not config.inference_mode: 93 | # Use a two-layer MLP to encode the prefix 94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 95 | self.transform = torch.nn.Sequential( 96 | torch.nn.Linear(token_dim, encoder_hidden_size), 97 | torch.nn.Tanh(), 98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 99 | ) 100 | else: 101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 102 | 103 | def forward(self, prefix: torch.Tensor): 104 | if self.prefix_projection: 105 | prefix_tokens = self.embedding(prefix) 106 | past_key_values = self.transform(prefix_tokens) 107 | else: 108 | past_key_values = self.embedding(prefix) 109 | return past_key_values 110 | -------------------------------------------------------------------------------- /peft_egg/src/peft/tuners/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import math 18 | from dataclasses import dataclass, field 19 | from typing import Optional, Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptTuningInit(str, enum.Enum): 27 | TEXT = "TEXT" 28 | RANDOM = "RANDOM" 29 | 30 | 31 | @dataclass 32 | class PromptTuningConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`PromptEmbedding`]. 35 | 36 | Args: 37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. 38 | prompt_tuning_init_text (`str`, *optional*): 39 | The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`. 40 | tokenizer_name_or_path (`str`, *optional*): 41 | The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`. 42 | """ 43 | 44 | prompt_tuning_init: Union[PromptTuningInit, str] = field( 45 | default=PromptTuningInit.RANDOM, 46 | metadata={"help": "How to initialize the prompt tuning parameters"}, 47 | ) 48 | prompt_tuning_init_text: Optional[str] = field( 49 | default=None, 50 | metadata={ 51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 52 | }, 53 | ) 54 | tokenizer_name_or_path: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 58 | }, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.PROMPT_TUNING 63 | 64 | 65 | class PromptEmbedding(torch.nn.Module): 66 | """ 67 | The model to encode virtual tokens into prompt embeddings. 68 | 69 | Args: 70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding. 71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model. 72 | 73 | **Attributes**: 74 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding. 75 | 76 | Example: 77 | 78 | ```py 79 | >>> from peft import PromptEmbedding, PromptTuningConfig 80 | 81 | >>> config = PromptTuningConfig( 82 | ... peft_type="PROMPT_TUNING", 83 | ... task_type="SEQ_2_SEQ_LM", 84 | ... num_virtual_tokens=20, 85 | ... token_dim=768, 86 | ... num_transformer_submodules=1, 87 | ... num_attention_heads=12, 88 | ... num_layers=12, 89 | ... prompt_tuning_init="TEXT", 90 | ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", 91 | ... tokenizer_name_or_path="t5-base", 92 | ... 
) 93 | 94 | >>> # t5_model.shared is the word embeddings of the base model 95 | >>> prompt_embedding = PromptEmbedding(config, t5_model.shared) 96 | ``` 97 | 98 | Input Shape: (`batch_size`, `total_virtual_tokens`) 99 | 100 | Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`) 101 | """ 102 | 103 | def __init__(self, config, word_embeddings): 104 | super().__init__() 105 | 106 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 107 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim) 108 | if config.prompt_tuning_init == PromptTuningInit.TEXT: 109 | from transformers import AutoTokenizer 110 | 111 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path) 112 | init_text = config.prompt_tuning_init_text 113 | init_token_ids = tokenizer(init_text)["input_ids"] 114 | # Trim or iterate until num_text_tokens matches total_virtual_tokens 115 | num_text_tokens = len(init_token_ids) 116 | if num_text_tokens > total_virtual_tokens: 117 | init_token_ids = init_token_ids[:total_virtual_tokens] 118 | elif num_text_tokens < total_virtual_tokens: 119 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens) 120 | init_token_ids = init_token_ids * num_reps 121 | init_token_ids = init_token_ids[:total_virtual_tokens] 122 | 123 | word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone() 124 | word_embedding_weights = word_embedding_weights.to(torch.float32) 125 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights) 126 | 127 | def forward(self, indices): 128 | # Just get embeddings 129 | prompt_embeddings = self.embedding(indices) 130 | return prompt_embeddings 131 | -------------------------------------------------------------------------------- /peft_egg/src/peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
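# --- Editorial sketch (illustrative addition, not part of the original file) ---
# get_peft_model_state_dict / set_peft_model_state_dict, re-exported below and
# defined in save_and_load.py, extract and restore only the adapter weights,
# which is what keeps PEFT checkpoints small. A minimal round trip, assuming a
# LoRA-wrapped model named `peft_model` already exists (the filename is
# hypothetical):
#
#     import torch
#     from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
#
#     adapter_sd = get_peft_model_state_dict(peft_model)  # lora_* tensors only
#     torch.save(adapter_sd, "adapter_weights.pt")
#     # ... later, restore the adapter into a freshly wrapped model:
#     set_peft_model_state_dict(peft_model, torch.load("adapter_weights.pt"))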
19 | 20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 21 | from .other import ( 22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, 25 | COMMON_LAYERS_PATTERN, 26 | CONFIG_NAME, 27 | WEIGHTS_NAME, 28 | SAFETENSORS_WEIGHTS_NAME, 29 | _set_trainable, 30 | add_library_to_model_card, 31 | bloom_model_postprocess_past_key_value, 32 | prepare_model_for_int8_training, 33 | prepare_model_for_kbit_training, 34 | shift_tokens_right, 35 | transpose, 36 | _get_submodules, 37 | _set_adapter, 38 | _freeze_adapter, 39 | ModulesToSaveWrapper, 40 | ) 41 | from .hub_utils import hub_file_exists 42 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 43 | -------------------------------------------------------------------------------- /peft_egg/src/peft/utils/hub_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from huggingface_hub import get_hf_file_metadata, hf_hub_url 17 | from huggingface_hub.utils import EntryNotFoundError 18 | 19 | 20 | def hub_file_exists(repo_id: str, filename: str, revision: str = None, repo_type: str = None) -> bool: 21 | r""" 22 | Checks if a file exists in a remote Hub repository. 23 | """ 24 | url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision) 25 | try: 26 | get_hf_file_metadata(url) 27 | return True 28 | except EntryNotFoundError: 29 | return False 30 | -------------------------------------------------------------------------------- /peft_egg/src/peft/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .config import PeftType, PromptLearningConfig 17 | 18 | 19 | def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): 20 | """ 21 | Get the state dict of the Peft model. 22 | 23 | Args: 24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, 25 | the model should be the underlying model/unwrapped model (i.e. model.module). 
26 | state_dict (`dict`, *optional*, defaults to `None`): 27 | The state dict of the model. If not provided, the state dict of the model 28 | will be used. 29 | """ 30 | config = model.peft_config[adapter_name] 31 | if state_dict is None: 32 | state_dict = model.state_dict() 33 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 34 | # to_return = lora_state_dict(model, bias=model.peft_config.bias) 35 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` 36 | # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP 37 | bias = config.bias 38 | if bias == "none": 39 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 40 | elif bias == "all": 41 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 42 | elif bias == "lora_only": 43 | to_return = {} 44 | for k in state_dict: 45 | if "lora_" in k: 46 | to_return[k] = state_dict[k] 47 | bias_name = k.split("lora_")[0] + "bias" 48 | if bias_name in state_dict: 49 | to_return[bias_name] = state_dict[bias_name] 50 | else: 51 | raise NotImplementedError 52 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))} 53 | if config.peft_type == PeftType.ADALORA: 54 | rank_pattern = config.rank_pattern 55 | if rank_pattern is not None: 56 | rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()} 57 | config.rank_pattern = rank_pattern 58 | to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name) 59 | 60 | elif config.peft_type == PeftType.ADAPTION_PROMPT: 61 | to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} 62 | elif isinstance(config, PromptLearningConfig): 63 | to_return = {} 64 | if config.inference_mode: 65 | prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight 66 | else: 67 | prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) 68 | to_return["prompt_embeddings"] = prompt_embeddings 69 | else: 70 | raise NotImplementedError 71 | if model.modules_to_save is not None: 72 | for key, value in state_dict.items(): 73 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): 74 | to_return[key.replace("modules_to_save.", "")] = value 75 | 76 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()} 77 | return to_return 78 | 79 | 80 | def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"): 81 | """ 82 | Set the state dict of the Peft model. 83 | 84 | Args: 85 | model ([`PeftModel`]): The Peft model. 86 | peft_model_state_dict (`dict`): The state dict of the Peft model. 87 | """ 88 | config = model.peft_config[adapter_name] 89 | state_dict = {} 90 | if model.modules_to_save is not None: 91 | for key, value in peft_model_state_dict.items(): 92 | if any(module_name in key for module_name in model.modules_to_save): 93 | for module_name in model.modules_to_save: 94 | if module_name in key: 95 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}") 96 | break 97 | state_dict[key] = value 98 | else: 99 | state_dict = peft_model_state_dict 100 | 101 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA): 102 | peft_model_state_dict = {} 103 | for k, v in state_dict.items(): 104 | if "lora_" in k: 105 | suffix = k.split("lora_")[1] 106 | if "." 
in suffix: 107 | suffix_to_replace = ".".join(suffix.split(".")[1:]) 108 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") 109 | else: 110 | k = f"{k}.{adapter_name}" 111 | peft_model_state_dict[k] = v 112 | else: 113 | peft_model_state_dict[k] = v 114 | if config.peft_type == PeftType.ADALORA: 115 | rank_pattern = config.rank_pattern 116 | if rank_pattern is not None: 117 | model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) 118 | elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT: 119 | peft_model_state_dict = state_dict 120 | else: 121 | raise NotImplementedError 122 | 123 | load_result = model.load_state_dict(peft_model_state_dict, strict=False) 124 | if isinstance(config, PromptLearningConfig): 125 | model.prompt_encoder[adapter_name].embedding.load_state_dict( 126 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True 127 | ) 128 | return load_result 129 | -------------------------------------------------------------------------------- /peft_egg/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/peft_egg/tests/__init__.py -------------------------------------------------------------------------------- /peft_egg/tests/test_config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import os 16 | import tempfile 17 | import unittest 18 | 19 | from peft import ( 20 | AdaptionPromptConfig, 21 | LoraConfig, 22 | PeftConfig, 23 | PrefixTuningConfig, 24 | PromptEncoderConfig, 25 | PromptTuningConfig, 26 | ) 27 | 28 | 29 | PEFT_MODELS_TO_TEST = [("lewtun/tiny-random-OPTForCausalLM-delta", "v1")] 30 | 31 | 32 | class PeftConfigTestMixin: 33 | all_config_classes = ( 34 | LoraConfig, 35 | PromptEncoderConfig, 36 | PrefixTuningConfig, 37 | PromptTuningConfig, 38 | AdaptionPromptConfig, 39 | ) 40 | 41 | 42 | class PeftConfigTester(unittest.TestCase, PeftConfigTestMixin): 43 | def test_methods(self): 44 | r""" 45 | Test if all configs have the expected methods. 
Here we test 46 | - to_dict 47 | - save_pretrained 48 | - from_pretrained 49 | - from_json_file 50 | """ 51 | # test if all configs have the expected methods 52 | for config_class in self.all_config_classes: 53 | config = config_class() 54 | self.assertTrue(hasattr(config, "to_dict")) 55 | self.assertTrue(hasattr(config, "save_pretrained")) 56 | self.assertTrue(hasattr(config, "from_pretrained")) 57 | self.assertTrue(hasattr(config, "from_json_file")) 58 | 59 | def test_task_type(self): 60 | for config_class in self.all_config_classes: 61 | # assert this will not fail 62 | _ = config_class(task_type="test") 63 | 64 | def test_from_pretrained(self): 65 | r""" 66 | Test if the config is correctly loaded using: 67 | - from_pretrained 68 | """ 69 | for config_class in self.all_config_classes: 70 | for model_name, revision in PEFT_MODELS_TO_TEST: 71 | # Test we can load config from delta 72 | _ = config_class.from_pretrained(model_name, revision=revision) 73 | 74 | def test_save_pretrained(self): 75 | r""" 76 | Test if the config is correctly saved and loaded using 77 | - save_pretrained 78 | """ 79 | for config_class in self.all_config_classes: 80 | config = config_class() 81 | with tempfile.TemporaryDirectory() as tmp_dirname: 82 | config.save_pretrained(tmp_dirname) 83 | 84 | config_from_pretrained = config_class.from_pretrained(tmp_dirname) 85 | self.assertEqual(config.to_dict(), config_from_pretrained.to_dict()) 86 | 87 | def test_from_json_file(self): 88 | for config_class in self.all_config_classes: 89 | config = config_class() 90 | with tempfile.TemporaryDirectory() as tmp_dirname: 91 | config.save_pretrained(tmp_dirname) 92 | 93 | config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json")) 94 | self.assertEqual(config.to_dict(), config_from_json) 95 | 96 | def test_to_dict(self): 97 | r""" 98 | Test if the config can be correctly converted to a dict using: 99 | - to_dict 100 | - __dict__ 101 | """ 102 | for config_class in self.all_config_classes: 103 | config = config_class() 104 | self.assertEqual(config.to_dict(), config.__dict__) 105 | self.assertTrue(isinstance(config.to_dict(), dict)) 106 | 107 | def test_from_pretrained_cache_dir(self): 108 | r""" 109 | Test if the config is correctly loaded with extra kwargs 110 | """ 111 | with tempfile.TemporaryDirectory() as tmp_dirname: 112 | for config_class in self.all_config_classes: 113 | for model_name, revision in PEFT_MODELS_TO_TEST: 114 | # Test we can load config from delta 115 | _ = config_class.from_pretrained(model_name, revision=revision, cache_dir=tmp_dirname) 116 | 117 | def test_from_pretrained_cache_dir_remote(self): 118 | r""" 119 | Test if the config is correctly loaded with a checkpoint from the hub 120 | """ 121 | with tempfile.TemporaryDirectory() as tmp_dirname: 122 | _ = PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname) 123 | self.assertTrue("models--ybelkada--test-st-lora" in os.listdir(tmp_dirname)) 124 | 125 | def test_set_attributes(self): 126 | # manually set attributes and check if they are correctly written 127 | for config_class in self.all_config_classes: 128 | config = config_class(peft_type="test") 129 | 130 | # save pretrained 131 | with tempfile.TemporaryDirectory() as tmp_dirname: 132 | config.save_pretrained(tmp_dirname) 133 | 134 | config_from_pretrained = config_class.from_pretrained(tmp_dirname) 135 | self.assertEqual(config.to_dict(), config_from_pretrained.to_dict()) 136 | 
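# --- Editorial sketch (illustrative addition, not part of the original file) ---
# MeloConfig is exported from peft (see peft/__init__.py above) but is not part
# of `all_config_classes`. Assuming it exposes the same PeftConfig interface as
# the other tuner configs, the save/load round-trip tests could cover it by
# extending the mixin's tuple:
#
#     from peft import MeloConfig
#
#     class MeloConfigTester(PeftConfigTester):
#         all_config_classes = PeftConfigTestMixin.all_config_classes + (MeloConfig,)
#
# Hub-dependent tests such as test_from_pretrained may need to be skipped for
# configs without published checkpoints.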
-------------------------------------------------------------------------------- /peft_egg/tests/test_decoder_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import unittest 16 | 17 | import torch 18 | from parameterized import parameterized 19 | from transformers import AutoModelForCausalLM 20 | 21 | from .testing_common import PeftCommonTester, PeftTestConfigManager 22 | 23 | 24 | PEFT_DECODER_MODELS_TO_TEST = [ 25 | "hf-internal-testing/tiny-random-OPTForCausalLM", 26 | "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", 27 | "hf-internal-testing/tiny-random-GPT2LMHeadModel", 28 | "hf-internal-testing/tiny-random-BloomForCausalLM", 29 | "hf-internal-testing/tiny-random-gpt_neo", 30 | "hf-internal-testing/tiny-random-GPTJForCausalLM", 31 | "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM", 32 | "HuggingFaceM4/tiny-random-LlamaForCausalLM", 33 | ] 34 | 35 | FULL_GRID = { 36 | "model_ids": PEFT_DECODER_MODELS_TO_TEST, 37 | "task_type": "CAUSAL_LM", 38 | } 39 | 40 | 41 | def skip_non_pt_mqa(test_list): 42 | r""" 43 | Skip tests that are prefix tuning for MQA models (not supported yet) 44 | """ 45 | return [test for test in test_list if not ("prefix_tuning" in test[0] and "GPTBigCodeForCausalLM" in test[0])] 46 | 47 | 48 | class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester): 49 | r""" 50 | Test if the PeftModel behaves as expected. This includes: 51 | - test if the model has the expected methods 52 | 53 | We use parametrized.expand for debugging purposes to test each model individually. 
54 | """ 55 | transformers_class = AutoModelForCausalLM 56 | 57 | def prepare_inputs_for_testing(self): 58 | input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) 59 | attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) 60 | 61 | input_dict = { 62 | "input_ids": input_ids, 63 | "attention_mask": attention_mask, 64 | } 65 | 66 | return input_dict 67 | 68 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 69 | def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): 70 | self._test_model_attr(model_id, config_cls, config_kwargs) 71 | 72 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 73 | def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): 74 | self._test_adapter_name(model_id, config_cls, config_kwargs) 75 | 76 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 77 | def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): 78 | self._test_prepare_for_training(model_id, config_cls, config_kwargs) 79 | 80 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 81 | def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): 82 | self._test_save_pretrained(model_id, config_cls, config_kwargs) 83 | 84 | @parameterized.expand( 85 | PeftTestConfigManager.get_grid_parameters( 86 | { 87 | "model_ids": PEFT_DECODER_MODELS_TO_TEST, 88 | "lora_kwargs": {"init_lora_weights": [False]}, 89 | "task_type": "CAUSAL_LM", 90 | }, 91 | ) 92 | ) 93 | def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): 94 | self._test_merge_layers(model_id, config_cls, config_kwargs) 95 | 96 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_non_pt_mqa)) 97 | def test_generate(self, test_name, model_id, config_cls, config_kwargs): 98 | self._test_generate(model_id, config_cls, config_kwargs) 99 | 100 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_non_pt_mqa)) 101 | def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): 102 | self._test_generate_half_prec(model_id, config_cls, config_kwargs) 103 | 104 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 105 | def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): 106 | self._test_training(model_id, config_cls, config_kwargs) 107 | 108 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 109 | def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): 110 | self._test_training_layer_indexing(model_id, config_cls, config_kwargs) 111 | 112 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 113 | def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): 114 | self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) 115 | 116 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 117 | def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): 118 | self._test_inference_safetensors(model_id, config_cls, config_kwargs) 119 | 120 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 121 | def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): 122 | 
self._test_peft_model_device_map(model_id, config_cls, config_kwargs) 123 | -------------------------------------------------------------------------------- /peft_egg/tests/test_encoder_decoder_models.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import unittest 16 | 17 | import torch 18 | from parameterized import parameterized 19 | from transformers import AutoModelForSeq2SeqLM 20 | 21 | from .testing_common import PeftCommonTester, PeftTestConfigManager 22 | 23 | 24 | PEFT_ENCODER_DECODER_MODELS_TO_TEST = [ 25 | "ybelkada/tiny-random-T5ForConditionalGeneration-calibrated", 26 | "hf-internal-testing/tiny-random-BartForConditionalGeneration", 27 | ] 28 | 29 | FULL_GRID = {"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST, "task_type": "SEQ_2_SEQ_LM"} 30 | 31 | 32 | def skip_non_lora_or_pt(test_list): 33 | r""" 34 | Skip tests that are not lora or prefix tuning 35 | """ 36 | return [test for test in test_list if ("lora" in test[0] or "prefix_tuning" in test[0])] 37 | 38 | 39 | class PeftEncoderDecoderModelTester(unittest.TestCase, PeftCommonTester): 40 | r""" 41 | Test if the PeftModel behaves as expected. This includes: 42 | - test if the model has the expected methods 43 | 44 | We use parametrized.expand for debugging purposes to test each model individually. 
45 | """ 46 | transformers_class = AutoModelForSeq2SeqLM 47 | 48 | def prepare_inputs_for_testing(self): 49 | input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) 50 | decoder_input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) 51 | attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) 52 | 53 | input_dict = { 54 | "input_ids": input_ids, 55 | "decoder_input_ids": decoder_input_ids, 56 | "attention_mask": attention_mask, 57 | } 58 | 59 | return input_dict 60 | 61 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 62 | def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): 63 | self._test_model_attr(model_id, config_cls, config_kwargs) 64 | 65 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 66 | def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): 67 | self._test_adapter_name(model_id, config_cls, config_kwargs) 68 | 69 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 70 | def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): 71 | self._test_prepare_for_training(model_id, config_cls, config_kwargs) 72 | 73 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 74 | def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): 75 | self._test_save_pretrained(model_id, config_cls, config_kwargs) 76 | 77 | @parameterized.expand( 78 | PeftTestConfigManager.get_grid_parameters( 79 | { 80 | "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST, 81 | "lora_kwargs": {"init_lora_weights": [False]}, 82 | "task_type": "SEQ_2_SEQ_LM", 83 | }, 84 | ) 85 | ) 86 | def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): 87 | self._test_merge_layers(model_id, config_cls, config_kwargs) 88 | 89 | # skip non lora models - generate does not work for prefix tuning, prompt tuning 90 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_non_lora_or_pt)) 91 | def test_generate(self, test_name, model_id, config_cls, config_kwargs): 92 | self._test_generate(model_id, config_cls, config_kwargs) 93 | 94 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 95 | def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): 96 | self._test_generate_half_prec(model_id, config_cls, config_kwargs) 97 | 98 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 99 | def test_training_encoder_decoders(self, test_name, model_id, config_cls, config_kwargs): 100 | self._test_training(model_id, config_cls, config_kwargs) 101 | 102 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 103 | def test_training_encoder_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): 104 | self._test_training_layer_indexing(model_id, config_cls, config_kwargs) 105 | 106 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 107 | def test_training_encoder_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): 108 | self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) 109 | 110 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 111 | def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): 112 | self._test_inference_safetensors(model_id, config_cls, 
config_kwargs) 113 | 114 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) 115 | def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): 116 | self._test_peft_model_device_map(model_id, config_cls, config_kwargs) 117 | -------------------------------------------------------------------------------- /peft_egg/tests/testing_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import unittest 16 | 17 | import torch 18 | 19 | 20 | def require_torch_gpu(test_case): 21 | """ 22 | Decorator marking a test that requires a GPU. Will be skipped when no GPU is available. 23 | """ 24 | if not torch.cuda.is_available(): 25 | return unittest.skip("test requires GPU")(test_case) 26 | else: 27 | return test_case 28 | 29 | 30 | def require_torch_multi_gpu(test_case): 31 | """ 32 | Decorator marking a test that requires multiple GPUs. Will be skipped when less than 2 GPUs are available. 33 | """ 34 | if not torch.cuda.is_available() or torch.cuda.device_count() < 2: 35 | return unittest.skip("test requires multiple GPUs")(test_case) 36 | else: 37 | return test_case 38 | 39 | 40 | def require_bitsandbytes(test_case): 41 | """ 42 | Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library is not installed. 
43 | """ 44 | try: 45 | import bitsandbytes # noqa: F401 46 | except ImportError: 47 | return unittest.skip("test requires bitsandbytes")(test_case) 48 | else: 49 | return test_case 50 | -------------------------------------------------------------------------------- /plot/ablation/ablation_cluster_num.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'orange','green'] 7 | base_dir = '../logs/VecDB' 8 | settings = ['T5Small_5000_e50', 'T5Small_5000_e75', 'T5Small_5000_e100'] 9 | CLUSTER = [] 10 | CONFLICTS = [] 11 | FORGET = [] 12 | 13 | r_list = [] 14 | 15 | t = list(range(0,5001,200)) 16 | 17 | for index, setting in enumerate(settings): 18 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 19 | loaded_dict = pickle.load(f) 20 | r_list.append(loaded_dict) 21 | 22 | 23 | for index, setting in enumerate(r_list): 24 | VecDB_log = setting['all_VecDB'] 25 | cluster = [0] 26 | for x in t[1:]: 27 | cluster.append(VecDB_log[x]['num_cluster']) 28 | CLUSTER.append(cluster) 29 | 30 | 31 | 32 | fig, axs = plt.subplots(1, 1, figsize=(4,4),dpi=800) 33 | # axs = fig.axes() 34 | 35 | axs.plot(t, CLUSTER[0], label = '$R$ = 50',linewidth=4,color=colors[0]) 36 | axs.plot(t, CLUSTER[1], label = '$R$ = 75',linewidth=4,color=colors[1]) 37 | axs.plot(t, CLUSTER[2], label = '$R$ = 100',linewidth=4,color=colors[2]) 38 | plt.axhline(y = 907, color = 'green', linestyle = '--',linewidth = 4,label = 'Answer Num.') 39 | axs.set_xlim(-100,5100) 40 | axs.set_ylim(0,4200) 41 | axs.xaxis.set_major_locator(MultipleLocator(1000)) 42 | axs.set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 43 | axs.yaxis.set_major_locator(MultipleLocator(1000)) 44 | axs.set_yticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 45 | axs.set_xlabel('Edits',fontweight='bold',family = 'serif',fontsize=18) 46 | axs.set_ylabel('Cluster Num',fontweight='bold',family = 'serif',fontsize=18) 47 | axs.grid(True) 48 | plt.xticks(fontsize=16) 49 | plt.yticks(fontsize=16) 50 | plt.legend(handlelength=2,fontsize=13) 51 | plt.tight_layout() 52 | plt.savefig("./ablation_res/T5Small_cluster") -------------------------------------------------------------------------------- /plot/ablation/ablation_conflicts_num.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'orange','green'] 7 | base_dir = '../logs/VecDB' 8 | settings = ['T5Small_5000_e50', 'T5Small_5000_e75', 'T5Small_5000_e100'] 9 | CLUSTER = [] 10 | CONFLICTS = [] 11 | FORGET = [] 12 | 13 | r_list = [] 14 | 15 | t = list(range(0,5001,200)) 16 | 17 | for index, setting in enumerate(settings): 18 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 19 | loaded_dict = pickle.load(f) 20 | r_list.append(loaded_dict) 21 | 22 | 23 | for index, setting in enumerate(r_list): 24 | VecDB_log = setting['all_VecDB'] 25 | conflicts = [0] 26 | for x in t[1:]: 27 | conflicts.append(VecDB_log[x]['conflict_num']) 28 | CONFLICTS.append(conflicts) 29 | 30 | 31 | 32 | fig, axs = plt.subplots(1, 1, figsize=(4,4),dpi=800) 33 | # axs = fig.axes() 34 | 35 | axs.plot(t, CONFLICTS[0], label = '$R$ = 50',linewidth=4,color=colors[0]) 36 | axs.plot(t, CONFLICTS[1], label = '$R$ = 
75',linewidth=4,color=colors[1]) 37 | axs.plot(t, CONFLICTS[2], label = '$R$ = 100',linewidth=4,color=colors[2]) 38 | 39 | axs.set_xlim(-100,5100) 40 | axs.set_ylim(0,800) 41 | axs.xaxis.set_major_locator(MultipleLocator(1000)) 42 | axs.set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 43 | axs.yaxis.set_major_locator(MultipleLocator(200)) 44 | axs.set_xlabel('Edits',fontweight='bold',family = 'serif',fontsize=18) 45 | axs.set_ylabel('Conflicts',fontweight='bold',family = 'serif',fontsize=18) 46 | axs.grid(True) 47 | plt.xticks(fontsize=16) 48 | plt.yticks(fontsize=16) 49 | plt.legend(handlelength=2,fontsize=13) 50 | plt.tight_layout() 51 | plt.savefig("./ablation_res/T5Small_conflicts") -------------------------------------------------------------------------------- /plot/ablation/ablation_forget.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'orange','green'] 7 | base_dir = '../logs/VecDB' 8 | settings = ['T5Small_5000_e50', 'T5Small_5000_e75', 'T5Small_5000_e100'] 9 | CLUSTER = [] 10 | CONFLICTS = [] 11 | FORGET = [] 12 | 13 | r_list = [] 14 | 15 | t = list(range(0,5001,200)) 16 | 17 | for index, setting in enumerate(settings): 18 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 19 | loaded_dict = pickle.load(f) 20 | r_list.append(loaded_dict) 21 | 22 | 23 | for index, setting in enumerate(r_list): 24 | VecDB_log = setting['all_VecDB'] 25 | forget = [0] 26 | for x in t[1:]: 27 | forget.append(VecDB_log[x]['forget_keys']) 28 | FORGET.append(forget) 29 | 30 | 31 | 32 | fig, axs = plt.subplots(1, 1, figsize=(4,4),dpi=800) 33 | # axs = fig.axes() 34 | 35 | axs.plot(t, FORGET[0], label = '$R$ = 50',linewidth=4,color=colors[0]) 36 | axs.plot(t, FORGET[1], label = '$R$ = 75',linewidth=4,color=colors[1]) 37 | axs.plot(t, FORGET[2], label = '$R$ = 100',linewidth=4,color=colors[2]) 38 | 39 | axs.set_xlim(-100,5100) 40 | axs.set_ylim(0,500) 41 | axs.xaxis.set_major_locator(MultipleLocator(1000)) 42 | axs.set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 43 | axs.yaxis.set_major_locator(MultipleLocator(100)) 44 | axs.set_xlabel('Edits',fontweight='bold',family = 'serif',fontsize=18) 45 | axs.set_ylabel('F - Edits',fontweight='bold',family = 'serif',fontsize=18) 46 | axs.grid(True) 47 | plt.xticks(fontsize=16) 48 | plt.yticks(fontsize=16) 49 | plt.legend(handlelength=3, loc = 'upper left',fontsize=13) 50 | plt.tight_layout() 51 | plt.savefig("./ablation_res/T5Small_forget") -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Large_zsre_block.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Large_zsre_block.jpg -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Large_zsre_eps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Large_zsre_eps.jpg -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_cluster.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_cluster.png -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_conflicts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_conflicts.png -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_forget.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_forget.png -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_zsre_block.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_block.jpg -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_zsre_block_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_block_2.jpg -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_zsre_block_main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_block_main.jpg -------------------------------------------------------------------------------- /plot/ablation/ablation_res/T5Small_zsre_eps.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_eps.jpg -------------------------------------------------------------------------------- /plot/ablation/ablation_res/pca.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/pca.jpg -------------------------------------------------------------------------------- /plot/ablation/pca.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/pca.jpg -------------------------------------------------------------------------------- /plot/ablation/res_time.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | np.random.seed(19680801) 4 | 5 | pts = np.random.rand(30)*.2 6 | # Now let's make two outlier points which are far away from everything. 7 | pts[[3, 14]] += .8 8 | 9 | # If we were to simply plot pts, we'd lose most of the interesting 10 | # details due to the outliers. 
So let's 'break' or 'cut-out' the y-axis 11 | # into two portions - use the top (ax1) for the outliers, and the bottom 12 | # (ax2) for the details of the majority of our data 13 | fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True) 14 | fig.subplots_adjust(hspace=0.05) # adjust space between axes 15 | 16 | # plot the same data on both axes 17 | ax1.plot(pts) 18 | ax2.plot(pts) 19 | 20 | # zoom-in / limit the view to different portions of the data 21 | ax1.set_ylim(.78, 1.) # outliers only 22 | ax2.set_ylim(0, .22) # most of the data 23 | 24 | # hide the spines between ax and ax2 25 | ax1.spines.bottom.set_visible(False) 26 | ax2.spines.top.set_visible(False) 27 | ax1.xaxis.tick_top() 28 | ax1.tick_params(labeltop=False) # don't put tick labels at the top 29 | ax2.xaxis.tick_bottom() 30 | 31 | # Now, let's turn towards the cut-out slanted lines. 32 | # We create line objects in axes coordinates, in which (0,0), (0,1), 33 | # (1,0), and (1,1) are the four corners of the axes. 34 | # The slanted lines themselves are markers at those locations, such that the 35 | # lines keep their angle and position, independent of the axes size or scale 36 | # Finally, we need to disable clipping. 37 | 38 | d = .5 # proportion of vertical to horizontal extent of the slanted line 39 | kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, 40 | linestyle="none", color='k', mec='k', mew=1, clip_on=False) 41 | ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs) 42 | ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs) 43 | 44 | 45 | plt.show() -------------------------------------------------------------------------------- /plot/ablation/zsre_block_lineplot_T5Large.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'green','orange'] 7 | 8 | original_his = 0.99 9 | original_up = 0.8111 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = '../logs/paper' 14 | settings = ['t5large_zsre_r2e10b0', 't5large_zsre_r2e10b0', 't5large_zsre_r2e10b0'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800) 59 | for i in range(2): 60 | for x in axs[i]: 61 | x.set_box_aspect(0.85) 62 | 63 | 64 | 65 | axs[0,0].plot(t, HIS[0], label = '1 
rank(s)/block',linewidth=1.8,color=colors[0]) 66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 68 | axs[0,0].set_xlim(-200,5200) 69 | axs[0,0].set_ylim(0.68,1.02) 70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000)) 71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1)) 73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 75 | axs[0,0].grid(True) 76 | 77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 80 | axs[0,1].set_xlim(-200,5200) 81 | axs[0,1].set_ylim(0.2,1.02) 82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000)) 83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2)) 85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 87 | axs[0,1].grid(True) 88 | 89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 92 | axs[1,0].set_xlim(-200,5200) 93 | axs[1,0].set_ylim(0,15) 94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000)) 95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(3)) 97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif') 99 | axs[1,0].grid(True) 100 | 101 | 102 | axs[1,1].plot(t, UP[0], label = '$VecDB$ = [0]',linewidth=1.8,color=colors[0]) 103 | axs[1,1].plot(t, UP[1], label = '$VecDB$ = [4]',linewidth=1.8,color=colors[1]) 104 | axs[1,1].plot(t, UP[2], label = '$VecDB$ = [10]',linewidth=1.8,color=colors[2]) 105 | axs[1,1].set_xlim(-200,5200) 106 | axs[1,1].set_ylim(0.5,0.9) 107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000)) 108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1)) 110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif') 112 | axs[1,1].grid(True) 113 | 114 | 115 | lines_labels = [axs[1,1].get_legend_handles_labels()] 116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 117 | fig.legend(lines, labels, loc='lower center', ncol=3) 118 | plt.tight_layout() 119 | plt.subplots_adjust(wspace=0) 120 | 121 | plt.subplots_adjust(bottom=0.2) 122 | 123 | 124 | plt.savefig('ablation_res/T5Large_zsre_block.jpg') 125 | 126 | plt.show() 127 | 128 | pass 129 | 130 | 131 | -------------------------------------------------------------------------------- /plot/ablation/zsre_block_lineplot_T5Small.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 
'green','orange'] 7 | 8 | original_his = 0.99 9 | original_up = 0.7225 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = '../logs/paper' 14 | settings = ['t5small_zsre_r2e75b0', 't5small_zsre_r2e75b2', 't5small_zsre_r2e75b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800) 59 | for i in range(2): 60 | for x in axs[i]: 61 | x.set_box_aspect(0.85) 62 | 63 | 64 | 65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 68 | axs[0,0].set_xlim(-200,5200) 69 | axs[0,0].set_ylim(0.68,1.02) 70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000)) 71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1)) 73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 75 | axs[0,0].grid(True) 76 | 77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 80 | axs[0,1].set_xlim(-200,5200) 81 | axs[0,1].set_ylim(0.2,1.02) 82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000)) 83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2)) 85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 87 | axs[0,1].grid(True) 88 | 89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 92 | axs[1,0].set_xlim(-200,5200) 93 | axs[1,0].set_ylim(0,3.2) 94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000)) 95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(1)) 97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 98 | axs[1,0].set_ylabel('Time 
(mins)',fontweight='bold',family = 'serif') 99 | axs[1,0].grid(True) 100 | 101 | 102 | axs[1,1].plot(t, UP[0], label = '$VecDB = [0]$',linewidth=1.8,color=colors[0]) 103 | axs[1,1].plot(t, UP[1], label = '$VecDB = [2]$',linewidth=1.8,color=colors[1]) 104 | axs[1,1].plot(t, UP[2], label = '$VecDB = [4]$',linewidth=1.8,color=colors[2]) 105 | axs[1,1].set_xlim(-200,5200) 106 | axs[1,1].set_ylim(0.3,0.9) 107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000)) 108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1)) 110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif') 112 | axs[1,1].grid(True) 113 | 114 | 115 | lines_labels = [axs[1,1].get_legend_handles_labels()] 116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 117 | fig.legend(lines, labels, loc='lower center', ncol=3) 118 | plt.tight_layout() 119 | plt.subplots_adjust(wspace=0) 120 | 121 | plt.subplots_adjust(bottom=0.2) 122 | 123 | 124 | plt.savefig('ablation_res/T5Small_zsre_block.jpg') 125 | 126 | plt.show() 127 | 128 | pass 129 | 130 | 131 | -------------------------------------------------------------------------------- /plot/ablation/zsre_block_lineplot_t5small_2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'green','orange'] 7 | 8 | original_his = 0.99 9 | original_up = 0.7225 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = '../logs/paper' 14 | settings = ['t5small_zsre_r2e75b0', 't5small_zsre_r2e75b2', 't5small_zsre_r2e75b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(1, 2, figsize=(5,2.5),dpi=800) 59 | for i in range(2): 60 | for x in axs: 61 | x.set_box_aspect(0.85) 62 | 63 | 64 | 65 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 66 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 67 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 68 | axs[0].set_xlim(-200,5200) 69 | axs[0].set_ylim(0.68,1.02) 70 | axs[0].xaxis.set_major_locator(MultipleLocator(1000)) 71 | axs[0].set_xticklabels(['0', '0','1k', 
'2k', '3k','4k','5k']) 72 | axs[0].yaxis.set_major_locator(MultipleLocator(0.1)) 73 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif') 74 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 75 | axs[0].grid(True) 76 | 77 | 78 | 79 | axs[1].plot(t, UP[0], label = '$VecDB = [0]$',linewidth=1.8,color=colors[0]) 80 | axs[1].plot(t, UP[1], label = '$VecDB = [2]$',linewidth=1.8,color=colors[1]) 81 | axs[1].plot(t, UP[2], label = '$VecDB = [4]$',linewidth=1.8,color=colors[2]) 82 | axs[1].set_xlim(-200,5200) 83 | axs[1].set_ylim(0.3,0.9) 84 | axs[1].xaxis.set_major_locator(MultipleLocator(1000)) 85 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 86 | axs[1].yaxis.set_major_locator(MultipleLocator(0.1)) 87 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif') 88 | axs[1].set_ylabel('Locality',fontweight='bold',family = 'serif') 89 | axs[1].grid(True) 90 | 91 | 92 | lines_labels = [axs[1].get_legend_handles_labels()] 93 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 94 | fig.legend(lines, labels, loc='lower center', ncol=3) 95 | plt.tight_layout() 96 | # plt.subplots_adjust(wspace=0) 97 | 98 | plt.subplots_adjust(bottom=0.3) 99 | 100 | 101 | plt.savefig('ablation_res/T5Small_zsre_block_2.jpg') 102 | 103 | plt.show() 104 | 105 | pass 106 | 107 | 108 | -------------------------------------------------------------------------------- /plot/ablation/zsre_eps_lineplot_T5Large.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'green','orange'] 7 | 8 | original_his = 0.99 9 | original_up = 0.8111 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = '../logs/paper' 14 | settings = ['t5large_zsre_r2e10b4', 't5large_zsre_r2e15b4', 't5large_zsre_r2e20b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800) 59 | for i in range(2): 60 | for x in axs[i]: 61 | x.set_box_aspect(0.85) 62 | 63 | 64 | 65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 68 | 
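# NOTE: the axis-formatting idiom below (repeated across all of these plotting
# scripts) pairs MultipleLocator(1000) with a positional label list whose first
# entry is a duplicate '0'. With xlim(-200, 5200) the locator also generates a
# tick at -1000 just outside the view, so the extra leading label keeps the
# visible ticks aligned with 0, 1k, ..., 5k. (Inferred explanation:
# set_xticklabels assigns labels to the generated ticks in order.)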
axs[0,0].set_xlim(-200,5200) 69 | axs[0,0].set_ylim(0.68,1.02) 70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000)) 71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1)) 73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 75 | axs[0,0].grid(True) 76 | 77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 80 | axs[0,1].set_xlim(-200,5200) 81 | axs[0,1].set_ylim(0.2,1.02) 82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000)) 83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2)) 85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 87 | axs[0,1].grid(True) 88 | 89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 92 | axs[1,0].set_xlim(-200,5200) 93 | axs[1,0].set_ylim(0,15) 94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000)) 95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(3)) 97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif') 99 | axs[1,0].grid(True) 100 | 101 | 102 | axs[1,1].plot(t, UP[0], label = '$R$ = 1.0',linewidth=1.8,color=colors[0]) 103 | axs[1,1].plot(t, UP[1], label = '$R$ = 6.0',linewidth=1.8,color=colors[1]) 104 | axs[1,1].plot(t, UP[2], label = '$R$ = 20.0',linewidth=1.8,color=colors[2]) 105 | axs[1,1].set_xlim(-200,5200) 106 | axs[1,1].set_ylim(0.5,0.9) 107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000)) 108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1)) 110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif') 112 | axs[1,1].grid(True) 113 | 114 | 115 | lines_labels = [axs[1,1].get_legend_handles_labels()] 116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 117 | fig.legend(lines, labels, loc='lower center', ncol=3) 118 | plt.tight_layout() 119 | plt.subplots_adjust(wspace=0) 120 | 121 | plt.subplots_adjust(bottom=0.2) 122 | 123 | 124 | plt.savefig('ablation_res/T5Large_zsre_eps.jpg') 125 | 126 | plt.show() 127 | 128 | pass 129 | 130 | 131 | -------------------------------------------------------------------------------- /plot/ablation/zsre_eps_lineplot_T5Small.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'green','orange'] 7 | 8 | original_his = 0.99 9 | original_up = 0.8111 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = '../logs/paper' 14 | settings = ['t5small_zsre_r2e50b4', 't5small_zsre_r2e75b4', 
't5small_zsre_r2e100b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800) 59 | for i in range(2): 60 | for x in axs[i]: 61 | x.set_box_aspect(0.85) 62 | 63 | 64 | 65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 68 | axs[0,0].set_xlim(-200,5200) 69 | axs[0,0].set_ylim(0.68,1.02) 70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000)) 71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1)) 73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 75 | axs[0,0].grid(True) 76 | 77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 80 | axs[0,1].set_xlim(-200,5200) 81 | axs[0,1].set_ylim(0.2,1.02) 82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000)) 83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2)) 85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 87 | axs[0,1].grid(True) 88 | 89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 92 | axs[1,0].set_xlim(-200,5200) 93 | axs[1,0].set_ylim(0,3.2) 94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000)) 95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(1)) 97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif') 98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif') 99 | axs[1,0].grid(True) 100 | 101 | 102 | axs[1,1].plot(t, UP[0], label = '$R$ = 50',linewidth=1.8,color=colors[0]) 103 | axs[1,1].plot(t, UP[1], label = '$R$ = 
75',linewidth=1.8,color=colors[1]) 104 | axs[1,1].plot(t, UP[2], label = '$R$ = 100',linewidth=1.8,color=colors[2]) 105 | axs[1,1].set_xlim(-200,5200) 106 | axs[1,1].set_ylim(0.5,0.9) 107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000)) 108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1)) 110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif') 111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif') 112 | axs[1,1].grid(True) 113 | 114 | 115 | lines_labels = [axs[1,1].get_legend_handles_labels()] 116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 117 | fig.legend(lines, labels, loc='lower center', ncol=3) 118 | plt.tight_layout() 119 | plt.subplots_adjust(wspace=0) 120 | 121 | plt.subplots_adjust(bottom=0.2) 122 | 123 | 124 | plt.savefig('ablation_res/T5Small_zsre_eps.jpg') 125 | 126 | plt.show() 127 | 128 | pass 129 | 130 | 131 | -------------------------------------------------------------------------------- /plot/tsne.py: -------------------------------------------------------------------------------- 1 | import plotly.express as px 2 | from sklearn.datasets import make_classification 3 | 4 | X, y = make_classification( 5 | n_features=6, 6 | n_classes=3, 7 | n_samples=1500, 8 | n_informative=2, 9 | random_state=5, 10 | n_clusters_per_class=1, 11 | ) 12 | 13 | 14 | fig = px.scatter_3d(x=X[:, 0], y=X[:, 1], z=X[:, 2], color=y, opacity=0.8) 15 | fig.show() 16 | 17 | 18 | #-------------------PLOT PCA---------------------------------- 19 | from sklearn.decomposition import PCA 20 | pca = PCA(n_components=2) 21 | X_pca = pca.fit_transform(X) 22 | 23 | fig = px.scatter(x=X_pca[:, 0], y=X_pca[:, 1], color=y) 24 | fig.update_layout( 25 | title="PCA visualization of Custom Classification dataset", 26 | xaxis_title="First Principal Component", 27 | yaxis_title="Second Principal Component", 28 | ) 29 | fig.show() 30 | 31 | 32 | #-------------------PLOT tsne---------------------------------- 33 | from sklearn.manifold import TSNE 34 | 35 | tsne = TSNE(n_components=2, random_state=42) 36 | X_tsne = tsne.fit_transform(X) 37 | print(tsne.kl_divergence_) 38 | 39 | fig = px.scatter(x=X_tsne[:, 0], y=X_tsne[:, 1], color=y) 40 | fig.update_layout( 41 | title="t-SNE visualization of Custom Classification dataset", 42 | xaxis_title="First t-SNE", 43 | yaxis_title="Second t-SNE", 44 | ) 45 | fig.show() -------------------------------------------------------------------------------- /plot/tsne_zsre_T5Small.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import pandas as pd 4 | import plotly.express as px 5 | base_dir = 'logs/VecDB' 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.mplot3d import Axes3D 8 | from sklearn.decomposition import PCA 9 | from sklearn.manifold import TSNE 10 | import seaborn as sns 11 | import numpy as np 12 | import transformers 13 | with open(f'{base_dir}/T5Small_100_e75/log.pkl', 'rb') as f: 14 | loaded_dict = pickle.load(f) 15 | sns.set(style="darkgrid") 16 | 17 | tokenizer = transformers.AutoTokenizer.from_pretrained('google/t5-small-ssm-nq') 18 | vecdb = loaded_dict['all_VecDB'] 19 | import time 20 | clusters = vecdb.table 21 | X = [] 22 | Y = [] 23 | for index, cluster in enumerate(clusters): 24 | 25 | if len(cluster['points']) > 3 and cluster['radius']: 26 | key_label = cluster['key_label'] 27 | key_label = key_label.masked_fill(key_label == -100, 
tokenizer.pad_token_id) 28 | key_label = tokenizer.decode(key_label, skip_special_tokens=True) 29 | print(key_label,index) 30 | 31 | for point in cluster['points']: 32 | X.append(point.get_key()) 33 | Y.append(index) 34 | 35 | X = torch.stack(X,dim=0).cpu().numpy() 36 | y = torch.tensor(Y).cpu().numpy() 37 | 38 | # fig = px.scatter_3d(x=X[:, 0], y=X[:, 1], z=X[:, 2], color=Y, opacity=0.8) 39 | # fig.show() 40 | 41 | feat_cols = [ 'pixel'+str(i) for i in range(X.shape[1])] 42 | df = pd.DataFrame(X,columns=feat_cols) 43 | df['y'] = y 44 | df['label'] = df['y'].apply(lambda i: str(i)) 45 | 46 | X, y = None, None 47 | 48 | print('Size of the dataframe: {}'.format(df.shape)) 49 | 50 | pca = PCA(n_components=30) 51 | pca_result = pca.fit_transform(df[feat_cols].values) 52 | 53 | df['pca-one'] = pca_result[:,0] 54 | df['pca-two'] = pca_result[:,4]  # NOTE: despite the name, this plots the fifth principal component (index 4), not the second 55 | # df['pca-three'] = pca_result[:,2] 56 | 57 | print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_)) 58 | np.random.seed(0) 59 | 60 | rndperm = np.random.permutation(df.shape[0]) 61 | plt.figure(figsize=(5,5), dpi = 800) 62 | ax = sns.scatterplot( 63 | x="pca-one", y="pca-two", 64 | hue="y", 65 | palette=sns.color_palette("tab10"), 66 | legend=False, 67 | data=df, 68 | marker= 'o', 69 | s=120, 70 | ) 71 | ax.set(xlabel=None) 72 | ax.set(ylabel=None) 73 | plt.tight_layout() 74 | plt.tick_params(axis='both', which='major', labelsize=18) 75 | plt.savefig('./ablation/ablation_res/pca.jpg') 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /plot/zsre_eps_lineplot_T5Large.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'green','orange'] 7 | 8 | original_his = 0.99 9 | original_up = 0.8111 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = 'logs/5000' 14 | settings = ['T5Large_r2e1b4', 'T5Large_r2e6b4', 'T5Large_r2e20b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(1, 4, figsize=(8,2),dpi=800) 59 | for i in range(4): 60 | axs[i].set_box_aspect(0.85) 61 | 62 | 63 | 64 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 65 | axs[0].plot(t, HIS[1], label = '2 
rank(s)/block',linewidth=1.8,color=colors[1]) 66 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 67 | axs[0].set_xlim(-200,5200) 68 | axs[0].set_ylim(0.68,1.02) 69 | axs[0].xaxis.set_major_locator(MultipleLocator(1000)) 70 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 71 | axs[0].yaxis.set_major_locator(MultipleLocator(0.1)) 72 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif') 73 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 74 | axs[0].grid(True) 75 | 76 | axs[1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 77 | axs[1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 78 | axs[1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 79 | axs[1].set_xlim(-200,5200) 80 | axs[1].set_ylim(0.2,1.02) 81 | axs[1].xaxis.set_major_locator(MultipleLocator(1000)) 82 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 83 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2)) 84 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif') 85 | axs[1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 86 | axs[1].grid(True) 87 | 88 | axs[2].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 89 | axs[2].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 90 | axs[2].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 91 | axs[2].set_xlim(-200,5200) 92 | axs[2].set_ylim(0,15) 93 | axs[2].xaxis.set_major_locator(MultipleLocator(1000)) 94 | axs[2].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 95 | axs[2].yaxis.set_major_locator(MultipleLocator(3)) 96 | axs[2].set_xlabel('Edits',fontweight='bold',family = 'serif') 97 | axs[2].set_ylabel('Time (mins)',fontweight='bold',family = 'serif') 98 | axs[2].grid(True) 99 | 100 | 101 | axs[3].plot(t, UP[0], label = 'Radius = 1.0',linewidth=1.8,color=colors[0]) 102 | axs[3].plot(t, UP[1], label = 'Radius = 6.0',linewidth=1.8,color=colors[1]) 103 | axs[3].plot(t, UP[2], label = 'Radius = 20.0',linewidth=1.8,color=colors[2]) 104 | axs[3].set_xlim(-200,5200) 105 | axs[3].set_ylim(0.5,0.9) 106 | axs[3].xaxis.set_major_locator(MultipleLocator(1000)) 107 | axs[3].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 108 | axs[3].yaxis.set_major_locator(MultipleLocator(0.1)) 109 | axs[3].set_xlabel('Edits',fontweight='bold',family = 'serif') 110 | axs[3].set_ylabel('Locality',fontweight='bold',family = 'serif') 111 | axs[3].grid(True) 112 | 113 | 114 | lines_labels = [axs[3].get_legend_handles_labels()] 115 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 116 | fig.legend(lines, labels, loc='lower center', ncol=4) 117 | plt.tight_layout() 118 | plt.subplots_adjust(wspace=0.5) 119 | plt.subplots_adjust(bottom=0.4) 120 | 121 | 122 | plt.savefig('plotres/T5Large_zsre_eps.jpg') 123 | 124 | plt.show() 125 | 126 | pass 127 | 128 | 129 | -------------------------------------------------------------------------------- /plot/zsre_rank_line_plot_generality.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'orange','green'] 7 | 8 | original_his = 0.99 9 | original_up = 0.7225 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = 'logs/5000' 14 | settings = ['T5Small_r1e3b4', 
'T5Small_r2e3b4', 'T5Small_r4e3b4'] 15 | 16 | r_list_small = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list_small.append(loaded_dict) 22 | 23 | 24 | settings = ['T5Large_r1e3b4', 'T5Large_r2e3b4', 'T5Large_r4e3b4'] 25 | r_list_large = [] 26 | 27 | for index, setting in enumerate(settings): 28 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 29 | loaded_dict = pickle.load(f) 30 | r_list_large.append(loaded_dict) 31 | 32 | 33 | t = list(range(0,5001,200)) 34 | 35 | 36 | HOLDOUT_SMALL = [] 37 | HOLDOUT_LARGE = [] 38 | 39 | 40 | 41 | for index, setting in enumerate(r_list_small): 42 | holdout_log = setting['all_HOLDOUT'] 43 | holdout_f1 = [original_holdout] 44 | for x in t[1:]: 45 | holdout_f1.append(holdout_log[x]['holdout_f1']) 46 | HOLDOUT_SMALL.append(holdout_f1) 47 | 48 | 49 | for index, setting in enumerate(r_list_large): 50 | holdout_log = setting['all_HOLDOUT'] 51 | holdout_f1 = [original_holdout] 52 | for x in t[1:]: 53 | holdout_f1.append(holdout_log[x]['holdout_f1']) 54 | HOLDOUT_LARGE.append(holdout_f1) 55 | 56 | 57 | fig, axs = plt.subplots(1, 2, figsize=(4,2),dpi=800) 58 | for i in range(2): 59 | axs[i].set_box_aspect(0.75) 60 | 61 | 62 | 63 | axs[0].plot(t, HOLDOUT_SMALL[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 64 | axs[0].plot(t, HOLDOUT_SMALL[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 65 | axs[0].plot(t, HOLDOUT_SMALL[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 66 | axs[0].set_xlim(-200,5200) 67 | axs[0].set_ylim(0.2,1.02) 68 | axs[0].xaxis.set_major_locator(MultipleLocator(1000)) 69 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 70 | axs[0].yaxis.set_major_locator(MultipleLocator(0.2)) 71 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif') 72 | axs[0].set_ylabel('Generality',fontweight='bold',family = 'serif') 73 | axs[0].set_title('T5Small',fontweight='bold',family = 'serif',size=11) 74 | axs[0].grid(True) 75 | 76 | 77 | axs[1].plot(t, HOLDOUT_LARGE[0], label = 'PR = 1',linewidth=1.8,color=colors[0]) 78 | axs[1].plot(t, HOLDOUT_LARGE[1], label = 'PR = 2',linewidth=1.8,color=colors[1]) 79 | axs[1].plot(t, HOLDOUT_LARGE[2], label = 'PR = 4',linewidth=1.8,color=colors[2]) 80 | axs[1].set_xlim(-200,5200) 81 | axs[1].set_ylim(0.2,1.02) 82 | axs[1].xaxis.set_major_locator(MultipleLocator(1000)) 83 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 84 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2)) 85 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif') 86 | axs[1].set_ylabel('Generality',fontweight='bold',family = 'serif') 87 | axs[1].set_title('T5Large',fontweight='bold',family = 'serif',size=11) 88 | axs[1].grid(True) 89 | 90 | 91 | 92 | 93 | lines_labels = [axs[1].get_legend_handles_labels()] 94 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 95 | fig.legend(lines, labels, loc='lower center', ncol=3) 96 | plt.tight_layout() 97 | plt.subplots_adjust(wspace=0.6) 98 | plt.subplots_adjust(wspace=0.6) 99 | plt.subplots_adjust(bottom=0.35) 100 | 101 | plt.savefig('plotres/T5Small_zsre_generality_42.jpg') 102 | 103 | plt.show() 104 | 105 | pass 106 | 107 | 108 | -------------------------------------------------------------------------------- /plot/zsre_rank_lineplot_T5Large.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | 
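# Summary (inferred from the code below): plots Edit Success, Holdout, Time and
# Locality against the number of sequential edits (0-5000, sampled every 200)
# for three T5-Large runs that differ only in the LoRA rank per block
# (r1/r2/r4), reading the pickled logs under logs/5000/. The assumed log.pkl
# layout is sketched after requirements.txt at the end of this listing.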
from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'orange','green'] 7 | 8 | original_his = 0.99 9 | original_up = 0.8111 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = 'logs/5000' 14 | settings = ['T5Large_r1e3b4', 'T5Large_r2e3b4', 'T5Large_r4e3b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(1, 4, figsize=(8,2),dpi=800) 59 | for i in range(4): 60 | axs[i].set_box_aspect(0.85) 61 | 62 | 63 | 64 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 65 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 66 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 67 | axs[0].set_xlim(-200,5200) 68 | axs[0].set_ylim(0.78,1.02) 69 | axs[0].xaxis.set_major_locator(MultipleLocator(1000)) 70 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 71 | axs[0].yaxis.set_major_locator(MultipleLocator(0.1)) 72 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif') 73 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 74 | axs[0].grid(True) 75 | 76 | axs[1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 77 | axs[1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 78 | axs[1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 79 | axs[1].set_xlim(-200,5200) 80 | axs[1].set_ylim(0.2,1.02) 81 | axs[1].xaxis.set_major_locator(MultipleLocator(1000)) 82 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 83 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2)) 84 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif') 85 | axs[1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 86 | axs[1].grid(True) 87 | 88 | axs[2].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 89 | axs[2].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 90 | axs[2].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 91 | axs[2].set_xlim(-200,5200) 92 | axs[2].set_ylim(0,15) 93 | axs[2].xaxis.set_major_locator(MultipleLocator(1000)) 94 | axs[2].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 95 | axs[2].yaxis.set_major_locator(MultipleLocator(3)) 96 | axs[2].set_xlabel('Edits',fontweight='bold',family = 'serif') 97 | axs[2].set_ylabel('Time 
(min)',fontweight='bold',family = 'serif') 98 | axs[2].grid(True) 99 | 100 | 101 | axs[3].plot(t, UP[0], label = 'Partial Rank = 1',linewidth=1.8,color=colors[0]) 102 | axs[3].plot(t, UP[1], label = 'Partial Rank = 2',linewidth=1.8,color=colors[1]) 103 | axs[3].plot(t, UP[2], label = 'Partial Rank = 4',linewidth=1.8,color=colors[2]) 104 | axs[3].set_xlim(-200,5200) 105 | axs[3].set_ylim(0.5,0.9) 106 | axs[3].xaxis.set_major_locator(MultipleLocator(1000)) 107 | axs[3].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 108 | axs[3].yaxis.set_major_locator(MultipleLocator(0.1)) 109 | axs[3].set_xlabel('Edits',fontweight='bold',family = 'serif') 110 | axs[3].set_ylabel('Locality',fontweight='bold',family = 'serif') 111 | axs[3].grid(True) 112 | 113 | 114 | lines_labels = [axs[3].get_legend_handles_labels()] 115 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 116 | fig.legend(lines, labels, loc='lower center', ncol=4) 117 | plt.tight_layout() 118 | plt.subplots_adjust(wspace=0.5) 119 | plt.subplots_adjust(wspace=0.5) 120 | plt.subplots_adjust(bottom=0.4) 121 | 122 | plt.savefig('plotres/T5Large_zsre_rank.jpg') 123 | 124 | plt.show() 125 | 126 | pass 127 | 128 | 129 | -------------------------------------------------------------------------------- /plot/zsre_rank_lineplot_T5Small.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator) 5 | 6 | colors = ['blue', 'red', 'orange','green'] 7 | 8 | original_his = 0.99 9 | original_up = 0.7225 10 | original_holdout = 0.3012 11 | original_time = 0 12 | 13 | base_dir = 'logs/5000' 14 | settings = ['T5Small_r1e3b4', 'T5Small_r2e3b4', 'T5Small_r4e3b4'] 15 | 16 | r_list = [] 17 | 18 | for index, setting in enumerate(settings): 19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f: 20 | loaded_dict = pickle.load(f) 21 | r_list.append(loaded_dict) 22 | 23 | 24 | t = list(range(0,5001,200)) 25 | 26 | UP = [] 27 | HIS = [] 28 | HOLDOUT = [] 29 | TIME = [] 30 | for index, setting in enumerate(r_list): 31 | up_log = setting['all_UP'] 32 | up_f1 = [original_up] 33 | for x in t[1:]: 34 | up_f1.append(up_log[x]['UP_f1']) 35 | UP.append(up_f1) 36 | 37 | for index, setting in enumerate(r_list): 38 | his_log = setting['all_HIS'] 39 | his_f1 = [his_log[200]['HIS_f1']] 40 | for x in t[1:]: 41 | his_f1.append(his_log[x]['HIS_f1']) 42 | HIS.append(his_f1) 43 | 44 | for index, setting in enumerate(r_list): 45 | holdout_log = setting['all_HOLDOUT'] 46 | holdout_f1 = [original_holdout] 47 | for x in t[1:]: 48 | holdout_f1.append(holdout_log[x]['holdout_f1']) 49 | HOLDOUT.append(holdout_f1) 50 | 51 | for index, setting in enumerate(r_list): 52 | time_log = setting['all_edit_time'] 53 | time_list = [0] 54 | for x in t[1:]: 55 | time_list.append(time_log[x]/60) 56 | TIME.append(time_list) 57 | 58 | fig, axs = plt.subplots(1, 4, figsize=(8,2),dpi=800) 59 | for i in range(4): 60 | axs[i].set_box_aspect(0.85) 61 | 62 | 63 | 64 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 65 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 66 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 67 | axs[0].set_xlim(-200,5200) 68 | axs[0].set_ylim(0.40,1.02) 69 | axs[0].xaxis.set_major_locator(MultipleLocator(1000)) 70 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 71 | 
axs[0].yaxis.set_major_locator(MultipleLocator(0.2)) 72 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif') 73 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif') 74 | axs[0].grid(True) 75 | 76 | axs[1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 77 | axs[1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 78 | axs[1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 79 | axs[1].set_xlim(-200,5200) 80 | axs[1].set_ylim(0.2,1.02) 81 | axs[1].xaxis.set_major_locator(MultipleLocator(1000)) 82 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 83 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2)) 84 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif') 85 | axs[1].set_ylabel('Holdout',fontweight='bold',family = 'serif') 86 | axs[1].grid(True) 87 | 88 | axs[2].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0]) 89 | axs[2].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1]) 90 | axs[2].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2]) 91 | axs[2].set_xlim(-200,5200) 92 | axs[2].set_ylim(0,3) 93 | axs[2].xaxis.set_major_locator(MultipleLocator(1000)) 94 | axs[2].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 95 | axs[2].yaxis.set_major_locator(MultipleLocator(1)) 96 | axs[2].set_xlabel('Edits',fontweight='bold',family = 'serif') 97 | axs[2].set_ylabel('Time (min)',fontweight='bold',family = 'serif') 98 | axs[2].grid(True) 99 | 100 | 101 | axs[3].plot(t, UP[0], label = 'Partial Rank = 1',linewidth=1.8,color=colors[0]) 102 | axs[3].plot(t, UP[1], label = 'Partial Rank = 2',linewidth=1.8,color=colors[1]) 103 | axs[3].plot(t, UP[2], label = 'Partial Rank = 4',linewidth=1.8,color=colors[2]) 104 | axs[3].set_xlim(-200,5200) 105 | axs[3].set_ylim(0.5,0.9) 106 | axs[3].xaxis.set_major_locator(MultipleLocator(1000)) 107 | axs[3].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k']) 108 | axs[3].yaxis.set_major_locator(MultipleLocator(0.1)) 109 | axs[3].set_xlabel('Edits',fontweight='bold',family = 'serif') 110 | axs[3].set_ylabel('Locality',fontweight='bold',family = 'serif') 111 | axs[3].grid(True) 112 | 113 | 114 | lines_labels = [axs[3].get_legend_handles_labels()] 115 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] 116 | fig.legend(lines, labels, loc='lower center', ncol=4) 117 | plt.tight_layout() 118 | plt.subplots_adjust(wspace=0.5) 119 | plt.subplots_adjust(wspace=0.5) 120 | plt.subplots_adjust(bottom=0.4) 121 | 122 | plt.savefig('plotres/T5Small_zsre_rank.jpg') 123 | 124 | plt.show() 125 | 126 | pass 127 | 128 | 129 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.19.0 2 | antlr4-python3-runtime==4.9.3 3 | appdirs==1.4.4 4 | attrs==23.1.0 5 | certifi @ file:///croot/certifi_1671487769961/work/certifi 6 | charset-normalizer==3.1.0 7 | click==8.1.3 8 | docker-pycreds==0.4.0 9 | filelock==3.12.0 10 | fsspec==2023.1.0 11 | gitdb==4.0.10 12 | GitPython==3.1.31 13 | huggingface-hub==0.15.1 14 | hydra-core==1.3.2 15 | idna==3.4 16 | importlib-metadata==6.6.0 17 | importlib-resources==5.12.0 18 | joblib==1.2.0 19 | jsonlines==3.1.0 20 | nltk==3.8.1 21 | numpy==1.21.6 22 | nvidia-cublas-cu11==11.10.3.66 23 | nvidia-cuda-nvrtc-cu11==11.7.99 24 | nvidia-cuda-runtime-cu11==11.7.99 25 | nvidia-cudnn-cu11==8.5.0.96 26 | 
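# NOTE: the torch, torchaudio and torchvision pins further down use '+cu116'
# local wheels; pip can resolve these only with the PyTorch CUDA 11.6 wheel
# index, e.g. `pip install -r requirements.txt --extra-index-url
# https://download.pytorch.org/whl/cu116` (the index URL is the standard
# PyTorch one and is an assumption; this repository does not specify it).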
omegaconf==2.3.0 27 | packaging==23.1 28 | pathtools==0.1.2 29 | Pillow==9.5.0 30 | protobuf==4.23.2 31 | psutil==5.9.5 32 | PyYAML==6.0 33 | regex==2023.6.3 34 | requests==2.31.0 35 | scikit-learn==1.0.2 36 | scipy==1.7.3 37 | sentence-transformers==2.2.2 38 | sentencepiece==0.1.99 39 | sentry-sdk==1.25.1 40 | setproctitle==1.3.2 41 | six==1.16.0 42 | smmap==5.0.0 43 | threadpoolctl==3.1.0 44 | tokenizers==0.13.3 45 | torch==1.13.1+cu116 46 | torchaudio==0.13.1+cu116 47 | torchvision==0.14.1+cu116 48 | tqdm==4.65.0 49 | transformers==4.29.2 50 | typing_extensions==4.6.3 51 | urllib3==2.0.2 52 | wandb==0.15.4 53 | zipp==3.15.0 54 | --------------------------------------------------------------------------------
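-------------------------------------------------------------------------------- Note on the shared log.pkl layout (editorial sketch, not a repository file): every plotting script above reads the same pickled structure. Below is a minimal, hypothetical loader illustrating that assumed schema; the field names are inferred from the access patterns in the scripts, and `load_metric_series` does not exist in the repo.

import pickle

# Assumed log.pkl layout, keyed by edit count n in {200, 400, ..., 5000}:
#   log['all_UP'][n]        -> {'UP_f1': float}       # locality
#   log['all_HIS'][n]       -> {'HIS_f1': float}      # edit success
#   log['all_HOLDOUT'][n]   -> {'holdout_f1': float}  # generality
#   log['all_edit_time'][n] -> float                  # cumulative seconds
def load_metric_series(path, metric, key=None, checkpoints=range(200, 5001, 200)):
    """Hypothetical helper: collect one metric curve across edit checkpoints."""
    with open(path, 'rb') as f:
        log = pickle.load(f)
    series = log[metric]
    # dict-valued metrics are indexed by `key`; all_edit_time is a plain float
    return [series[n][key] if key else series[n] for n in checkpoints]

# Example usage (illustrative path):
# holdout = load_metric_series('logs/5000/T5Small_r2e3b4/log.pkl', 'all_HOLDOUT', 'holdout_f1')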