├── .gitignore
├── README.md
├── figures
│   ├── bert.png
│   ├── main_00.png
│   └── table.png
├── melo
│   ├── algs
│   │   └── lora.py
│   ├── config
│   │   ├── alg
│   │   │   └── lora.yaml
│   │   ├── config.yaml
│   │   ├── experiment
│   │   │   ├── hallucination.yaml
│   │   │   ├── qa.yaml
│   │   │   └── scotus.yaml
│   │   └── model
│   │       ├── gpt2xl.yaml
│   │       ├── scotus-bert.yaml
│   │       ├── t5large.yaml
│   │       └── t5small.yaml
│   ├── dataset.py
│   ├── grammar.py
│   ├── hooks.py
│   ├── main.sh
│   ├── metrics.py
│   ├── models.py
│   ├── run.py
│   ├── trainer.py
│   ├── utils.py
│   └── wiki_bio_concepts.txt
├── peft_egg
│   ├── .gitignore
│   ├── LICENSE
│   ├── Makefile
│   ├── README.md
│   ├── docker
│   │   ├── peft-cpu
│   │   │   └── Dockerfile
│   │   └── peft-gpu
│   │       └── Dockerfile
│   ├── docs
│   │   ├── Makefile
│   │   ├── README.md
│   │   └── source
│   │       ├── _config.py
│   │       ├── _toctree.yml
│   │       ├── accelerate
│   │       │   ├── deepspeed-zero3-offload.mdx
│   │       │   └── fsdp.mdx
│   │       ├── conceptual_guides
│   │       │   ├── lora.mdx
│   │       │   └── prompting.mdx
│   │       ├── index.mdx
│   │       ├── install.mdx
│   │       ├── package_reference
│   │       │   ├── config.mdx
│   │       │   ├── peft_model.mdx
│   │       │   └── tuners.mdx
│   │       ├── quicktour.mdx
│   │       └── task_guides
│   │           ├── clm-prompt-tuning.mdx
│   │           ├── dreambooth_lora.mdx
│   │           ├── image_classification_lora.mdx
│   │           ├── int8-asr.mdx
│   │           ├── ptuning-seq-classification.mdx
│   │           ├── semantic_segmentation_lora.mdx
│   │           ├── seq2seq-prefix-tuning.mdx
│   │           └── token-classification-lora.mdx
│   ├── grammar.py
│   ├── pyproject.toml
│   ├── scripts
│   │   ├── log_reports.py
│   │   └── stale.py
│   ├── setup.py
│   ├── src
│   │   └── peft
│   │       ├── __init__.py
│   │       ├── import_utils.py
│   │       ├── mapping.py
│   │       ├── peft_model.py
│   │       ├── tuners
│   │       │   ├── __init__.py
│   │       │   ├── adalora.py
│   │       │   ├── adaption_prompt.py
│   │       │   ├── lora.py
│   │       │   ├── melo.py
│   │       │   ├── melo_backup.py
│   │       │   ├── p_tuning.py
│   │       │   ├── prefix_tuning.py
│   │       │   └── prompt_tuning.py
│   │       └── utils
│   │           ├── __init__.py
│   │           ├── config.py
│   │           ├── hub_utils.py
│   │           ├── other.py
│   │           └── save_and_load.py
│   └── tests
│       ├── __init__.py
│       ├── test_adaption_prompt.py
│       ├── test_common_gpu.py
│       ├── test_config.py
│       ├── test_decoder_models.py
│       ├── test_encoder_decoder_models.py
│       ├── test_gpu_examples.py
│       ├── testing_common.py
│       └── testing_utils.py
├── plot
│   ├── ablation
│   │   ├── ablation_cluster_num.py
│   │   ├── ablation_conflicts_num.py
│   │   ├── ablation_forget.py
│   │   ├── ablation_res
│   │   │   ├── T5Large_zsre_block.jpg
│   │   │   ├── T5Large_zsre_eps.jpg
│   │   │   ├── T5Small_cluster.png
│   │   │   ├── T5Small_conflicts.png
│   │   │   ├── T5Small_forget.png
│   │   │   ├── T5Small_zsre_block.jpg
│   │   │   ├── T5Small_zsre_block_2.jpg
│   │   │   ├── T5Small_zsre_block_main.jpg
│   │   │   ├── T5Small_zsre_eps.jpg
│   │   │   └── pca.jpg
│   │   ├── pca.jpg
│   │   ├── res_time.py
│   │   ├── zsre_block_lineplot_T5Large.py
│   │   ├── zsre_block_lineplot_T5Small.py
│   │   ├── zsre_block_lineplot_t5small_2.py
│   │   ├── zsre_eps_lineplot_T5Large.py
│   │   └── zsre_eps_lineplot_T5Small.py
│   ├── tsne.py
│   ├── tsne_zsre_T5Small.py
│   ├── zsre_eps_lineplot_T5Large.py
│   ├── zsre_rank_line_plot_generality.py
│   ├── zsre_rank_lineplot_T5Large.py
│   └── zsre_rank_lineplot_T5Small.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .git
2 | .idea
3 | data
4 | scr
5 | __pycache__
6 | scr
7 | outputs
8 | checkpoint
9 | .json
10 | plotres
11 | logs
12 | mnist
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # MELO: Enhancing Model Editing with Neuron-Indexed Dynamic LoRA
3 | This repo contains the source code of our proposed MELO, a plug-in model editing method that routes a model's behavior by dynamically indexing LoRA blocks according to an inner vector database. Seamlessly integrated into [PEFT](https://github.com/huggingface/peft), MELO supports multiple LLMs such as BERT, T5 and GPT.
4 |
5 |
6 | ## Updates
7 | - **2024/03/10:** Added some important tips for deployment 🪂
8 | - **2023/12/19:** Repo has been transferred to [ECNU-ICALK/MELO](https://github.com/ECNU-ICALK/MELO) (Organization Account) 🔔
9 | - **2023/12/09:** Our work has been accepted by AAAI 2024 :fire::fire:
10 | - **2023/7/16:** Experiments with multiple LLMs on different editing tasks. :art:
11 | - **2023/6/24:** Inner vector database that builds an accurate editing scope. :confetti_ball:
12 | - **2023/6/08:** Support for dynamic LoRA block loading. :star:
13 |
14 |
15 | ## Table of Contents
16 | - [Reference](#reference)
17 | - [Introduction](#introduction)
18 | - [Experiments](#experiments)
19 | - [Prepare Environments](#prepare-environments)
20 | - [Prepare Datasets](#prepare-datasets)
21 | - [Quick Start](#quick-start)
22 | - [Important Tips](#important-tips)
23 | - [Acknowledgments](#acknowledgments)
24 | ## Reference
25 | We would appreciate it if you could refer to our work as one of your baselines!
26 | ```
27 | @article{yu2023melo,
28 | title={MELO: Enhancing Model Editing with Neuron-Indexed Dynamic LoRA},
29 | author={Yu, Lang and Chen, Qin and Zhou, Jie and He, Liang},
30 | journal={arXiv preprint arXiv:2312.11795},
31 | year={2023}
32 | }
33 | ```
34 | ## Introduction
35 | Due to catastrophic forgetting and the lack of locality, few studies have explored recent advanced Low-rank Adapter (LoRA) techniques for continual model editing. To overcome these limitations and take advantage of LoRA's resource efficiency, we propose MELO, a plug-in model editing method implemented with dynamic LoRA, which routes the behavior of language models by dynamically indexing LoRA blocks according to an inner vector database. MELO considers all editing properties and can be easily integrated into multiple LLMs such as BERT, T5 and GPT. Experimental results show that our proposed MELO achieves state-of-the-art editing performance on three sequential editing tasks (document classification, question answering and hallucination correction), while requiring the fewest trainable parameters and the lowest computational cost.
36 | 
37 |
38 | ## Experiments
39 | Comparison of MELO to prior editing methods on sequential editing tasks. Note that MELO edits all language models with a single RTX 3090 GPU.
40 | 
41 |
42 | ## Prepare Environments
43 | Required CUDA environment and library dependencies are listed in:
44 | ```
45 | requirements.txt
46 | ```
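Install them first, e.g.:
```
pip install -r requirements.txt
```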
47 | Then you should install our modified PEFT:
48 | 🤗 PEFT-MELO
49 |
50 | ```
51 | cd peft_egg
52 | pip install -e .
53 | ```
54 | The detailed implementation of MELO is in `./peft_egg/src/peft/tuners/melo.py`.
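Once installed, MELO is applied like any other PEFT tuner. The sketch below mirrors how `melo/algs/lora.py` wraps a model; the values are the zsRE/T5-small defaults from the Hydra configs, and the `grace_config` dict is abbreviated (real runs pass the entire `grace` section via `OmegaConf.to_object`):
```python
from transformers import AutoModelForSeq2SeqLM
from peft import MeloConfig, get_peft_model

model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-small-ssm-nq")

# r = num_block * num_rank_per_block: each edit cluster owns its own rank slice
config = MeloConfig(
    r=20 * 2,
    lora_alpha=20 * 2,
    target_modules=[
        "encoder.block.5.layer.1.DenseReluDense.wi",
        "encoder.block.5.layer.1.DenseReluDense.wo",
    ],
    lora_dropout=0.0,
    task_type="SEQ_2_SEQ_LM",
    fan_in_fan_out=False,
    grace_layer="encoder.block.4.layer.1.DenseReluDense.wo",
    grace_config={"num_block": 20, "num_rank_per_block": 2,
                  "init_radius": 75, "num_edit_per_block": 100},  # abbreviated
)
model = get_peft_model(model, config)
```
In practice you rarely build this by hand: `melo/run.py` assembles it from the `alg`, `experiment` and `model` configs.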
55 | ## Prepare Datasets
56 | The zsRE experiments use data linked by the [MEND](https://github.com/eric-mitchell/mend) repository. Download the data for NQ and zsRE from their Google Drive link and unzip each sub-directory into `./melo/data`. SCOTUS and Hallucination data are loaded through Hugging Face.
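Based on the paths in `melo/config/config.yaml`, the loaders expect roughly the following layout (`{split}` is filled in per split by the dataset code):
```
melo/data
├── nq/                                                  # NQ upstream data
└── zsre/
    ├── structured_zeroshot-{split}-new_annotated_final.jsonl
    ├── impl_{split}.json
    └── zsre_yn_{split}.txt
```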
57 |
58 | ## Quick Start
59 | The locations of the inner vector database and the dynamic LoRA target modules can be modified in the YAML files under `./melo/config/model`.
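Since the configs are composed with Hydra, individual values can also be overridden on the command line; for example (hypothetical override values):
```
cd melo
python run.py +alg=lora +experiment=qa +model=t5small grace.num_block=40 grace.edit_lr=1e-3
```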
60 |
61 | ### Editing GPT2-XL on Hallucination with MELO
62 | ```
63 | cd melo
64 | python run.py +alg=lora +experiment=hallucination +model=gpt2xl
65 | ```
66 |
67 |
68 | ### Editing BERT on SCOTUS with MELO
69 | ```
70 | cd melo
71 | python run.py +alg=lora +experiment=scotus +model=scotus-bert
72 | ```
73 |
74 | ### Editing T5 on zsRE with MELO
75 | ```
76 | cd melo
77 | python run.py +alg=lora +experiment=qa +model=t5small
78 | ```
79 | ## Important Tips
80 | * [Datasets](https://drive.google.com/file/d/1HDqh4ofYF7B-YkcU3CNlZMYAOJO0XxwX/view?usp=drive_link) for MELO's experiments can now be downloaded from Google Drive. Please extract the files and place them under `melo/data`.
81 |
82 | * The GPT2-XL model we use is fine-tuned following [GRACE](https://github.com/Thartvigsen/GRACE/blob/728a52ebcd328ddca0bb1ec975e79625eabfab2a/grace/main.py#L83). Please download the checkpoint from the [Google Drive](https://drive.google.com/drive/folders/1j_DvcUY8goksQVOBt4XqBe7z8fS-0zvI?usp=sharing) link and place the files under `melo/scr/models--gpt2-xl`.
83 |
84 |
85 | * Some [logs](https://drive.google.com/drive/folders/1UhlY1W8MUmvsxqIXlRFBfxuTXEQG8FJP?usp=sharing) of correct training and inference runs are released for checking hyper-parameters.
86 |
87 | * The settings of [torch.optim.lr_scheduler](https://github.com/BruthYU/MELO/blob/51c8322cc06faa2b7665c2d90236f1bd1b8d9575/melo/algs/lora.py#L135) vary across tasks:
88 | ```
89 | # T5-Small and T5-Large
90 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=20,gamma=0.5)
91 | # SCOTUS-BERT and GPT2-XL
92 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.5)
93 | ```
94 |
95 |
96 |
97 |
98 | ## Acknowledgments
99 | We would like to thank the following individuals and organizations for their contributions to this project:
100 | ```
101 | Huggingface: for their support of the PEFT community and their development of the PEFT framework (https://github.com/huggingface/peft)
102 |
103 | GRACE: for the development of the open-source library GRACE which inspired our work (https://github.com/Thartvigsen/GRACE)
104 | ```
105 |
--------------------------------------------------------------------------------
/figures/bert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/figures/bert.png
--------------------------------------------------------------------------------
/figures/main_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/figures/main_00.png
--------------------------------------------------------------------------------
/figures/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/figures/table.png
--------------------------------------------------------------------------------
/melo/algs/lora.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from omegaconf import OmegaConf
3 | import torch
4 | import copy
5 | import transformers
6 | import logging
7 | import os
8 |
9 | from torch.nn import Parameter
10 |
11 | from utils import *
12 |
13 | from peft import (
14 | PeftModel,
15 | prepare_model_for_int8_training,
16 | MeloConfig,
17 | get_peft_model,
18 | get_peft_model_state_dict,
19 | )
20 | from peft.tuners.melo import LoraLayer, GraceLayer
21 | from hooks import lora_backward_hook
22 | # from models import BertClassifier
23 | LOG = logging.getLogger(__name__)
24 | def translate_tokens(tokens, from_tok, to_tok):
25 | tokens = tokens.masked_fill(tokens == -100, from_tok.pad_token_id)
26 | text = from_tok.batch_decode(tokens, skip_special_tokens=True)
27 | return to_tok(text, return_tensors="pt")["input_ids"].to(tokens.device)
28 |
29 | class LORA(torch.nn.Module):
30 | def __init__(self, model, config, model_tok, scale=None):
31 | super(LORA, self).__init__()
32 | self.config = config
33 |
34 | '''Apply_lora
35 | '''
36 | r_num = config.grace.num_block * config.grace.num_rank_per_block
37 | self.lora_config = MeloConfig(
38 | r = r_num,
39 | lora_alpha = r_num,
40 | target_modules= list(config.model.target_modules),
41 | lora_dropout = config.lora.lora_dropout,
42 | task_type = config.lora_task_type,
43 | fan_in_fan_out= config.model.fan_in_fan_out,
44 | grace_layer = config.model.grace_layer,
45 | grace_config= OmegaConf.to_object(config.grace)
46 | )
47 | self.log_dict = {}
48 |
49 | '''Load
50 | '''
51 | # self.original_model = model
52 | # self.model = model
53 |
54 | if not config.check_dir:
55 | self.model = get_peft_model(model, self.lora_config)
56 | else:
57 | save_path = os.path.join(config.base_dir, "checkpoint", config.check_dir)
58 | self.load_from_checkpoint(save_path)
59 |
60 | self.lora_list = self.named_lora_modules()
61 | self.grace_layer = self.named_grace_layer()
62 | # self.register_lora_backward_hooks(lora_backward_hook)
63 |
64 | '''Load Tokenizer
65 | '''
66 | self.model_tok = model_tok
67 | self.classifier_tok = transformers.AutoTokenizer.from_pretrained(config.lora.cls_name)
68 |
69 | '''Parameters to be optimized
70 | '''
71 | self.opt_params = self.optim_parameters()
72 | pass
73 |
74 |
75 | def optim_parameters(self):
76 | for name, param in self.model.named_parameters():
77 | if param.requires_grad and 'lora' not in name:
78 | param.requires_grad = False
79 | lora_params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
80 | return lora_params
81 |
82 |
83 |
84 |
85 | #TODO
86 | def load_from_checkpoint(self, save_path):
87 | print(save_path)
88 |
89 |
90 | def save_classifier_weights(self,cls_dir):
91 | if not os.path.exists(cls_dir):
92 | os.makedirs(cls_dir,exist_ok=True)
93 | torch.save(self.classifier.state_dict(),f"{cls_dir}/classifier.pt")
94 | def save_lora_weights(self,lora_dir):
95 | self.model.save_pretrained(lora_dir+"/lora_checkpoint")
96 |
97 |
98 | def reset_lora(self):
99 | for key in self.lora_list:
100 | self.model.get_submodule(key).reset_lora_parameters('default')
101 |
102 | def named_lora_modules(self):
103 | module_list = [key for key,_ in self.model.named_modules()]
104 | lora_list = []
105 | for key in module_list:
106 | if isinstance(self.model.get_submodule(key),LoraLayer):
107 | lora_list.append(key)
108 | return lora_list
109 |
110 | def named_grace_layer(self) -> str:
111 | module_list = [key for key, _ in self.model.named_modules()]
112 | grace_list = []
113 | for key in module_list:
114 | if isinstance(self.model.get_submodule(key), GraceLayer):
115 | grace_list.append(key)
116 | assert len(grace_list) == 1, "Exactly one GRACE layer is expected"
117 | return grace_list[0]
118 |
119 | def register_lora_backward_hooks(self,backward_hook_fn):
120 | for key in self.lora_list:
121 | self.model.get_submodule(key).register_backward_hook(backward_hook_fn)
122 |
123 |
124 | def disable_melo(self):
125 | self.model.base_model.disable_adapter_layers()
126 | self.model.base_model.disable_grace_layer()
127 |
128 | def enable_melo(self):
129 | self.model.base_model.enable_adapter_layers()
130 | self.model.base_model.enable_grace_layer()
131 |
132 |
133 | def edit(self, tokens):
134 | optimizer = torch.optim.Adam(self.optim_parameters(), self.config.grace.edit_lr)
135 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=20,gamma=0.5)
136 | # --- pass edit label, training mode, and key_id into GRACE ---
137 | setattr(self.model.get_submodule(self.grace_layer), "training", True)
138 | setattr(self.model.get_submodule(self.grace_layer), "edit_label", tokens["labels"])
139 |
140 | self.losses = []
141 | for i in range(self.config.grace.num_iter):
142 | # --- insert iteration into each layer (only initiate keys on first iteration) ---
143 | setattr(self.model.get_submodule(self.grace_layer), "batch_iter", i)
144 |
145 | # --- pass tokens through model (including through the GRACE layer) ---
146 | outputs = self.model.model(**tokens)
147 | loss = outputs.loss
148 | loss.backward()
149 | optimizer.step()
150 | optimizer.zero_grad()
151 | scheduler.step()
152 | self.losses.append(loss.detach().cpu().numpy())
153 | LOG.info(f'batch loss in iter {i}: {loss.detach().cpu().numpy()}')
154 | self.loss = loss # Log final loss
155 |
156 | setattr(self.model.get_submodule(self.grace_layer), "training", False)
157 |
158 |
159 |
160 |
161 | def generate(self, *args, **kwargs):
162 | return self.model.model.generate(*args, **kwargs)
163 |
164 | def get_VecDB_info(self):
165 | VecDB_logdict = {}
166 | VecDB_logdict["num_cluster"] = len(getattr(self.model.get_submodule(self.grace_layer), "VecDB"))
167 | VecDB_logdict["conflict_num"] = getattr(self.model.get_submodule(self.grace_layer), "VecDB").conflict_num
168 | VecDB_logdict["forget_keys"] = len(getattr(self.model.get_submodule(self.grace_layer), "VecDB").forget_keys)
169 | return VecDB_logdict
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 | if __name__ == '__main__':
181 | pass
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------
/melo/config/alg/lora.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | alg: lora
4 | lr: 1e-4
5 | train_base: False
6 | lr_lr: 1e-4
7 | lora:
8 | cls_name: distilbert-base-cased
9 | cls_class: AutoModel
10 | supervised: true
11 | cos: false
12 | freeze: null
13 | square: true
14 | bound_embeds: false
15 | use_all_negatives: false
16 | freeze_lora: false
17 | dist_heads: 1
18 | cross_attend: false
19 | soft_weighting: false
20 | checkpoint_grad: false
21 | lora_r: 64
22 | lora_alpha: 64
23 | lora_dropout: 0.0
24 |
--------------------------------------------------------------------------------
/melo/config/config.yaml:
--------------------------------------------------------------------------------
1 | alg: enn
2 | seed: 0
3 | debug: False
4 | model_save_pt: 5000
5 | edit_bs: 1
6 | silent: False
7 | max_iters: 200100
8 | log_interval: 100
9 | val_interval: 5000
10 | batch_size: 4
11 | val_batch_size: 4
12 | accumulate_bs: 10
13 | cedit: 0.2
14 | cloc: 1.0
15 | cbase: 1.0
16 | val_steps: 500
17 | device: cuda
18 | base_loss: distill
19 | oracle: False
20 | train: True
21 | train_base: True
22 | opt: Adam
23 | single_batch: False
24 | archive: null
25 | grad_clip: 100.
26 | ref: null
27 | early_stop_patience: 40000
28 | early_stop_key: "mixture/acc_val"
29 | dropout: 0.0
30 | tokenizer: null
31 | results_dir: null
32 | no_grad_layers: null
33 | eval_only: False
34 | half: False
35 | save: False
36 | log_errors: False
37 | unlikelihood: True
38 | check_dir: null
39 | batch_round: 10
40 | re_init_model: False
41 | max_n_edits: 5000
42 |
43 | model:
44 | pt: null
45 |
46 | data:
47 | path: null
48 | rephrase: true
49 | zsre_nq: false
50 | zsre_impl: false
51 | zsre_impl_path: ${hydra:runtime.cwd}/data/zsre/impl_{}.json
52 | zsre_yn: false
53 | zsre_yn_path: ${hydra:runtime.cwd}/data/zsre/zsre_yn_{}.txt
54 | zsre_eval_idxs: null
55 | zsre_path: ${hydra:runtime.cwd}/data/zsre/structured_zeroshot-{}-new_annotated_final.jsonl
56 | nq_path: ${hydra:runtime.cwd}/data/nq
57 | wiki_webtext: true
58 | n_edits: 1
59 | hard_neg: false
60 | hard_neg_neighbors: 100
61 | hard_neg_exclude: 25
62 | hard_neg_temp: 0.1
63 | hard_neg_prob: 0.5
64 | flip_inner_outer: false
65 | sent_eval_sample: false
66 | n_outer_max: null
67 |
68 | eval:
69 | verbose: True
70 | log_interval: 100
71 | final_eval: True
72 |
73 | hydra:
74 | run:
75 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}}
76 | sweep:
77 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f}
78 | subdir: ${hydra.job.num}
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/melo/config/experiment/hallucination.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | task: hallucination
3 | lora_task_type: CAUSAL_LM
4 | grace:
5 | _name: grace
6 | num_iter: 50
7 | init_radius: 0.5
8 | dist_fn: euc # euc, mmd, cos
9 | val_init: cold # cold, warm
10 | val_train: sgd # sgd, pert
11 | val_reg: None # early
12 | reg: early_stop # early_stop
13 | replacement: replace_prompt # replace_last, replace_all, replace_prompt
14 | expand_mode: moving_avg # , moving_avg, decay
15 | num_pert: 8 # only matters when using perturbation training
16 | key_id: -1
17 | num_edit_per_block: 4
18 | num_block: 350
19 | num_rank_per_block: 2
20 | metric_period: 400
21 | edit_lr: 1e-3
--------------------------------------------------------------------------------
/melo/config/experiment/qa.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | task: qa
3 | lora_task_type: SEQ_2_SEQ_LM
4 | grace:
5 | _name: grace
6 | num_iter: 50
7 | init_radius: 75
8 | dist_fn: euc # euc, mmd, cos
9 | val_init: cold # cold, warm
10 | val_train: sgd # sgd, pert
11 | val_reg: None # early
12 | reg: early_stop # early_stop
13 | replacement: replace_prompt # replace_last, replace_all, replace_prompt
14 | expand_mode: moving_avg # , moving_avg, decay
15 | num_pert: 8 # only matters when using perturbation training
16 | key_id: -1
17 | num_edit_per_block: 100
18 | num_block: 20
19 | num_rank_per_block: 2
20 | metric_period: 200
21 | edit_lr: 1e-2
22 |
--------------------------------------------------------------------------------
/melo/config/experiment/scotus.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | task: scotus
3 | lora_task_type: SEQ_CLS
4 | grace:
5 | _name: grace
6 | num_iter: 50
7 | init_radius: 0.1
8 | dist_fn: euc # euc, mmd, cos
9 | val_init: cold # cold, warm
10 | val_train: sgd # sgd, pert
11 | val_reg: None # early
12 | reg: early_stop # early_stop
13 | replacement: replace_prompt # replace_last, replace_all, replace_prompt
14 | expand_mode: moving_avg # , moving_avg, decay
15 | num_pert: 8 # only matters when using perturbation training
16 | key_id: -1
17 | num_edit_per_block: 50
18 | num_block: 100
19 | num_rank_per_block: 4
20 | metric_period: 200
21 | edit_lr: 1e-2
--------------------------------------------------------------------------------
/melo/config/model/gpt2xl.yaml:
--------------------------------------------------------------------------------
1 | name: gpt2-xl
2 | class_name: GPT2LMHeadModel
3 | tokenizer_class: GPT2TokenizerFast
4 | tokenizer_name: gpt2-xl
5 |
6 | fan_in_fan_out: True
7 | target_modules:
8 | - transformer.h.36.mlp.c_fc
9 | - transformer.h.37.mlp.c_fc
10 |
11 |
12 | #pt: null
13 | pt: /home/yu/ECNU/MELO/melo/checkpoint # set this to 'hallucination' inside your checkpoint directory
14 | #model.base_model.h[35]
15 |
16 |
17 |
18 | grace_layer: transformer.h.35.mlp.c_fc
--------------------------------------------------------------------------------
/melo/config/model/scotus-bert.yaml:
--------------------------------------------------------------------------------
1 | name: tomh/scotus-bert
2 | class_name: AutoModelForSequenceClassification
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: bert-base-cased
5 | fan_in_fan_out: False
6 | target_modules:
7 | - bert.encoder.layer.9.output.dense
8 | - bert.encoder.layer.10.output.dense
9 | - bert.encoder.layer.11.output.dense
10 | grace_layer: bert.encoder.layer.4.output.dense
11 |
12 | pt: null
--------------------------------------------------------------------------------
/melo/config/model/t5large.yaml:
--------------------------------------------------------------------------------
1 | name: google/t5-large-ssm-nq
2 | class_name: AutoModelForSeq2SeqLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: google/t5-large-ssm-nq
5 | fan_in_fan_out: False
6 | inner_params:
7 | - encoder.block.22.layer.1.DenseReluDense.wi.weight
8 | - encoder.block.22.layer.1.DenseReluDense.wo.weight
9 | - encoder.block.23.layer.1.DenseReluDense.wi.weight
10 | - encoder.block.23.layer.1.DenseReluDense.wo.weight
11 | - decoder.block.22.layer.2.DenseReluDense.wi.weight
12 | - decoder.block.22.layer.2.DenseReluDense.wo.weight
13 | - decoder.block.23.layer.2.DenseReluDense.wi.weight
14 | - decoder.block.23.layer.2.DenseReluDense.wo.weight
15 |
16 |
17 | target_modules:
18 | - encoder.block.22.layer.1.DenseReluDense.wi
19 | - encoder.block.22.layer.1.DenseReluDense.wo
20 | - encoder.block.23.layer.1.DenseReluDense.wi
21 | - encoder.block.23.layer.1.DenseReluDense.wo
22 | - decoder.block.22.layer.2.DenseReluDense.wi
23 | - decoder.block.22.layer.2.DenseReluDense.wo
24 | - decoder.block.23.layer.2.DenseReluDense.wi
25 | - decoder.block.23.layer.2.DenseReluDense.wo
26 |
27 | grace_layer: encoder.block.12.layer.1.DenseReluDense.wo
28 | #alg.model.base_model.model.encoder.block[4].layer[1].DenseReluDense.wo
29 | #self.model.model.encoder.block[22].layer[1].DenseReluDense.wo.lora_B['default'][4:8]
--------------------------------------------------------------------------------
/melo/config/model/t5small.yaml:
--------------------------------------------------------------------------------
1 | name: google/t5-small-ssm-nq
2 | class_name: AutoModelForSeq2SeqLM
3 | tokenizer_class: AutoTokenizer
4 | tokenizer_name: google/t5-small-ssm-nq
5 | fan_in_fan_out: False
6 | target_modules:
7 | - encoder.block.5.layer.1.DenseReluDense.wi
8 | - encoder.block.5.layer.1.DenseReluDense.wo
9 | - decoder.block.5.layer.2.DenseReluDense.wi
10 | - decoder.block.5.layer.2.DenseReluDense.wo
11 | - encoder.block.6.layer.1.DenseReluDense.wi
12 | - encoder.block.6.layer.1.DenseReluDense.wo
13 | - decoder.block.6.layer.2.DenseReluDense.wi
14 | - decoder.block.6.layer.2.DenseReluDense.wo
15 |
16 | pt: null
17 |
18 | grace_layer: encoder.block.4.layer.1.DenseReluDense.wo
--------------------------------------------------------------------------------
/melo/grammar.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
--------------------------------------------------------------------------------
/melo/hooks.py:
--------------------------------------------------------------------------------
1 | def lora_backward_hook(lora_module, grad_in, grad_out):
2 | print("Module_name: ", lora_module)
3 | print("grad_in[0]_shape: ",grad_in[0].shape)
4 | print("grad_in[0]: ", grad_in[0])
5 | print("grad_out[0]_shape: ", grad_out[0].shape)
6 | print("grad_out[0]: ", grad_out[0])
7 | print("-------------------------------")
--------------------------------------------------------------------------------
/melo/main.sh:
--------------------------------------------------------------------------------
1 | python run.py +alg=lora +experiment=qa +model=t5small
2 | python run.py +alg=lora +experiment=hallucination +model=gpt2xl
3 | python run.py +alg=lora +experiment=scotus +model=scotus-bert
4 |
5 |
6 | CUDA_VISIBLE_DEVICES=3 python few_shot_run.py +alg=lora +experiment=fnli +model=bert-base batch_size=10 val_batch_size=10 lora.cross_attend=True
7 |
8 |
9 |
10 | CUDA_VISIBLE_DEVICES=3 python few_shot_run.py +alg=lora +experiment=qa +model=t5large batch_size=10 val_batch_size=10 data.zsre_impl=false data.zsre_yn=false data.hard_neg=false
--------------------------------------------------------------------------------
/melo/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import *
3 | import logging
4 | LOG = logging.getLogger(__name__)
5 |
6 |
7 | # DEPRECATED
8 | def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
9 | assert False, "No longer used"
10 | # content_score = post_edit_probs[pos_mask].prod() ** (1/pos_mask.sum()) / (pre_edit_probs[pos_mask]. + eps)
11 | post_pos_avg = post_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
12 | pre_pos_avg = pre_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
13 | content_score = post_pos_avg / (pre_pos_avg + eps)
14 | z_content = min(1., content_score)
15 |
16 | # compute z_sent through a weighting objective
17 | # normalized_probs = post_edit_probs / (post_edit_probs.sum() + eps)
18 | # balancing_factor = 0.5 * ((~pos_mask).float().sum() / pos_mask.float().sum() + 1)
19 | # z_sent_weight = balancing_factor * normalized_probs.dot(pos_mask.float())
20 | post_neg_avg = post_edit_probs[~pos_mask].prod() ** (1 / (~pos_mask).sum())
21 | neg_over_pos = post_neg_avg / (eps + post_pos_avg)
22 | z_sent_weight = 1 / (1 + neg_over_pos)
23 |
24 | # compute z_sent through a ranking objective
25 | batch_mask = pos_mask.view(-1, batch_size).long()
26 | sort_idxs = post_edit_probs.view(-1, batch_size).sort(-1, descending=True).indices
27 | ranked_mask = batch_mask.gather(1, sort_idxs)
28 | true_mask = batch_mask.sort(-1, descending=True).values
29 | z_sent_rank = (ranked_mask == true_mask).float().mean()
30 |
31 | # compute the final success scores
32 | weight_success = (z_content * z_sent_weight) ** 0.5
33 | rank_success = (z_content * z_sent_rank) ** 0.5
34 |
35 | correct_probs = post_edit_probs[pos_mask].mean()
36 | wrong_probs = post_edit_probs[~pos_mask].mean()
37 |
38 | return {
39 | "acc_weight": weight_success,
40 | "acc_rank": rank_success,
41 | "rank_score": z_sent_rank,
42 | "weight_score": z_sent_weight,
43 | "content_score": content_score,
44 | "post_edit_probs": post_edit_probs.sum(),
45 | "pre_edit_probs": pre_edit_probs.sum(),
46 | "correct_probs": correct_probs,
47 | "wrong_probs": wrong_probs
48 | }
49 |
50 |
51 |
52 | # For zsRE and F-NLI
53 | def retain_rate(pre_logits, post_logits, mask=None):
54 | if pre_logits.shape[-1] == 1:
55 | pre_logits = pre_logits.squeeze(-1)
56 | if post_logits.shape[-1] == 1:
57 | post_logits = post_logits.squeeze(-1)
58 |
59 | assert pre_logits.shape == post_logits.shape
60 | assert pre_logits.shape[0] == mask.shape[0]
61 |
62 | if pre_logits.dim() == 1:
63 | # binary classification
64 | pre_preds = pre_logits > 0
65 | post_preds = post_logits > 0
66 | retain = (pre_preds == post_preds).float().mean()
67 | elif pre_logits.dim() == 3:
68 | # sequence modeling
69 | pre_preds = pre_logits.argmax(-1)
70 | post_preds = post_logits.argmax(-1)
71 | match = (pre_preds == post_preds) * mask
72 | retain = (match.sum(-1) == mask.sum(-1)).float().mean()
73 | else:
74 | raise NotImplementedError
75 |
76 | return retain.item()
77 |
78 |
79 | def is_acc_error(model, tokens):
80 | # Check whether or not the model's prediction for a batch element is correct
81 | labels = tokens["labels"]
82 | logits = model(**tokens).logits
83 | probs = torch.softmax(logits, -1).squeeze()
84 | argmaxs = torch.argmax(probs, dim=-1).squeeze()
85 | return labels != argmaxs
86 |
87 |
88 | def Accuracy(alg, tokens):
89 | labels = tokens["labels"]
90 | new_tokens = {f"{k}": v for k, v in tokens.items() if k != "labels"}
91 | logits = alg.model(**new_tokens).logits
92 | probs = torch.softmax(logits, -1).squeeze()
93 | argmaxs = torch.argmax(probs, dim=-1).squeeze()
94 | return (labels == argmaxs).float().mean()
95 |
96 |
97 | def is_qa_error(model, tokens):
98 | preds = model.generate(tokens["input_ids"], max_length=20).squeeze() # Run model to get its predictions
99 | labels = tokens["labels"] # [tokens["labels"] != -100]
100 |
101 | if (len(preds) != len(labels)) or ((preds == labels).sum() != len(preds)):
102 | return True
103 | else:
104 | return False
105 |
106 |
107 | def PPL(alg, batch):
108 | input_ids = batch["input_ids"][:, :1024] # .to(device)
109 | if "labels" not in batch:
110 | batch["labels"] = batch["input_ids"][:, :1024].clone()
111 | else:
112 | batch["labels"] = batch["labels"][:, :1024].clone()
113 |
114 | with torch.no_grad():
115 | #outputs = alg.model.model(input_ids=input_ids, labels=target_ids)
116 | outputs = alg.model(**batch)
117 | nll = outputs.loss
118 |
119 | ppl = torch.exp(nll) # .clip(0, 100)
120 | return ppl
121 |
122 |
123 |
124 | def F1_ACC(alg, batch):
125 | try:
126 | preds = alg.generate(batch["input_ids"], max_length=20).squeeze()
127 | f1 = F1(preds, batch, alg.model_tok)
128 | acc = ACC(preds, batch, alg.model_tok)
129 | return f1, acc
130 | except Exception as e:
131 | raise e
132 |
133 | def F1(preds, batch, tok):
134 | try:
135 | f1_list = []
136 | for p, g in zip(preds,batch["labels"]):
137 | p = p[p != tok.pad_token_id].cpu().squeeze()
138 | g = g[g != -100].cpu().squeeze() # strip the -100 ignore-index padding from gold labels
139 | num_same = len(np.intersect1d(p, g))
140 | len_pred = len(p)
141 | len_gold = len(g)
142 | precision = num_same / len_pred
143 | recall = 1.0 * num_same / len_gold
144 | f1 = (2 * precision * recall) / (precision + recall)
145 | f1_list.append(f1)
146 | except:
147 | return 0.
148 |
149 | return sum(f1_list) / len(f1_list)
150 |
151 |
152 | def ACC(preds, batch, tok):
153 | decode_preds = tok.batch_decode(preds,skip_special_tokens=True)
154 | gold_labels = batch['labels']
155 | gold_labels = gold_labels.masked_fill(gold_labels == -100,tok.pad_token_id)
156 | decode_labels = tok.batch_decode(gold_labels,skip_special_tokens=True)
157 | assert len(decode_labels) == len(decode_preds), "Lengths of decode_preds and decode_labels should be the same"
158 | count = 0.
159 | for pred,label in zip(decode_preds, decode_labels):
160 | if pred == label:
161 | count = count + 1
162 | return count/len(decode_preds)
163 |
--------------------------------------------------------------------------------
/melo/models.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | import torch
3 | import torch.nn as nn
4 | import re
5 | import logging
6 | # from torch.nn import FixableDropout
7 | from utils import scr
8 |
9 |
10 | LOG = logging.getLogger(__name__)
11 |
12 |
13 | class CastModule(nn.Module):
14 | def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
15 | super().__init__()
16 |
17 | self.underlying = module
18 | self.in_cast = in_cast
19 | self.out_cast = out_cast
20 |
21 | def cast(self, obj, dtype):
22 | if dtype is None:
23 | return obj
24 |
25 | if isinstance(obj, torch.Tensor):
26 | return obj.to(dtype)
27 | else:
28 | return obj
29 |
30 | def forward(self, *args, **kwargs):
31 | args = tuple(self.cast(a, self.in_cast) for a in args)
32 | kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
33 | outputs = self.underlying(*args, **kwargs)
34 | if isinstance(outputs, torch.Tensor):
35 | outputs = self.cast(outputs, self.out_cast)
36 | elif isinstance(outputs, tuple):
37 | outputs = tuple(self.cast(o, self.out_cast) for o in outputs)
38 | else:
39 | raise RuntimeError(f"Not sure how to cast type {type(outputs)}")
40 | return outputs
41 |
42 | def extra_repr(self):
43 | return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"
44 |
45 |
46 | class BertClassifier(torch.nn.Module):
47 | def __init__(self, model_name, hidden_dim=768):
48 | super().__init__()
49 | if model_name.startswith("bert"):
50 | LOG.info(f"Loading model class {model_name}, cache dir {scr()}")
51 | self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr())
52 | else:
53 | self.model = transformers.AutoModel.from_pretrained(model_name, cache_dir=scr())
54 | self.classifier = torch.nn.Linear(hidden_dim, 1)
55 |
56 | @property
57 | def config(self):
58 | return self.model.config
59 |
60 | def forward(self, *args, **kwargs):
61 | filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"}
62 | model_output = self.model(*args, **filtered_kwargs)
63 | if "pooler_output" in model_output.keys():
64 | pred = self.classifier(model_output.pooler_output)
65 | else:
66 | pred = self.classifier(model_output.last_hidden_state[:, 0])
67 |
68 | if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
69 | last_hidden_state = model_output.last_hidden_state
70 | return pred, last_hidden_state
71 | else:
72 | return pred
73 |
74 |
75 |
76 | def get_hf_model(config):
77 | ModelClass = getattr(transformers, config.model.class_name)
78 | LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}")
79 | if config.model.pt is None:
80 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())
81 | elif config.re_init_model:
82 | print("Downloading untrained model.")
83 | model = ModelClass.from_pretrained(config.model.name)
84 | else:
85 | try:
86 | # try to load specified model from local dir
87 | model = ModelClass.from_pretrained(config.model.pt)
88 | print(f"Loaded model: {config.model.pt}")
89 | except:
90 | print("Couldn't load model: {config.model.pt}. Downloading new model.")
91 | model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())
92 |
93 |
94 | if config.dropout is not None:
95 | n_reset = 0
96 | for m in model.modules():
97 | if isinstance(m, nn.Dropout):
98 | m.p = config.dropout
99 | n_reset += 1
100 |
101 | if hasattr(m, "dropout"): # Requires for BART, which uses F.dropout
102 | if isinstance(m.dropout, float):
103 | m.dropout = config.dropout
104 | n_reset += 1
105 |
106 | if hasattr(m, "activation_dropout"): # Requires for BART, which uses F.dropout
107 | if isinstance(m.activation_dropout, float):
108 | m.activation_dropout = config.dropout
109 | n_reset += 1
110 |
111 | LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")
112 | return model
113 |
114 |
115 |
116 | def get_tokenizer(config):
117 | tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name
118 | tokenizer = getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr())
119 | if not tokenizer.pad_token:
120 | tokenizer.pad_token = tokenizer.eos_token
121 | return tokenizer
122 |
123 | def get_processor(config):
124 | processor_name = config.model.processor_name if config.model.processor_name is not None else config.model.name
125 | processor = getattr(transformers, config.model.processor_class).from_pretrained(processor_name, cache_dir = scr())
126 | return processor
127 |
128 |
129 | if __name__ == '__main__':
130 | m = BertClassifier("bert-base-uncased")
131 | m(torch.arange(5)[None, :])
132 | import pdb; pdb.set_trace()
133 |
--------------------------------------------------------------------------------
/melo/run.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import importlib
4 | import logging
5 | from time import time
6 | import hydra
7 | from omegaconf import OmegaConf,open_dict
8 | import numpy as np
9 | import torch
10 | from utils import *
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | import pickle
14 | import models
15 | from trainer import zsre_trainer, hallucination_trainer, scotus_trainer
16 | #TODO
17 | #SUPPORT MELO CONV1D fan_in_fan_out
18 |
19 | # os.environ['http_proxy'] = '127.0.0.1:7890'
20 | # os.environ['https_proxy'] = '127.0.0.1:7890'
21 | OmegaConf.register_new_resolver("uuid", lambda: uuid())
22 | LOG = logging.getLogger(__name__)
23 | @hydra.main(config_path='config', config_name='config')
24 | def run(config):
25 | grace_config_keys = ['edit_lr','init_radius','expand_mode','key_id','num_edit_per_block','num_block','num_rank_per_block']
26 | model_config_keys = ['target_modules','grace_layer']
27 | GRACE_CONFIG = dict(config.grace)
28 | MODEL_CONFIG = dict(config.model)
29 |
30 | for k in grace_config_keys:
31 | LOG.info(f'[-GRACE CONFIG-] {k}: {GRACE_CONFIG[k]}')
32 | for k in model_config_keys:
33 | LOG.info(f'[-MODEL CONFIG-] {k}: {MODEL_CONFIG[k]}')
34 |
35 | base_dir = hydra.utils.get_original_cwd()
36 | with open_dict(config):
37 | config.base_dir = base_dir
38 |
39 | random.seed(config.seed)
40 | np.random.seed(config.seed)
41 | torch.manual_seed(config.seed)
42 |
43 | if config.task == "qa" or config.task == "zsre":
44 | model = models.get_hf_model(config)
45 | elif config.task == "hallucination":
46 | model = models.get_hf_model(config)
47 | elif config.task == "scotus":
48 | model = models.get_hf_model(config)
49 | else:
50 | print(f"{config.task} task not found")
51 |
52 |
53 |
54 | model.to(config.device)
55 | tokenizer = models.get_tokenizer(config)
56 |
57 |
58 | '''
59 | Load Dataset
60 | '''
61 | if config.task == "qa" or config.task == "zsre":
62 | from dataset import NQ, zsRE, zsRE_balanced
63 | from metrics import F1_ACC, is_qa_error
64 | upstream = NQ()
65 | edits = zsRE_balanced(split="edit", n_edits=1000)
66 | edit_holdouts = zsRE_balanced(split="holdout", n_edits=1000)
67 |
68 | '''Get Loaders
69 | '''
70 | batch_size = config.grace.num_edit_per_block
71 | edit_loader = DataLoader(edits, batch_size=batch_size, shuffle=True)
72 | edit_holdout_loader = DataLoader(edit_holdouts, batch_size=batch_size, shuffle=False)
73 | upstream_loader = DataLoader(upstream, batch_size=batch_size, shuffle=False)
74 | hold_out = 0
75 | '''Define Metrics
76 | '''
77 | metric = F1_ACC # Measure QA F1
78 | is_error = is_qa_error
79 | tokenize = tokenize_qa
80 |
81 | elif config.task == "hallucination":
82 | from dataset import Hallucination, WebText10k
83 | from metrics import PPL
84 | upstream = WebText10k()
85 | edits = Hallucination(split="edit")
86 | accurate_dataset = Hallucination(split="accurate")
87 |
88 | '''Get Loaders
89 | '''
90 | batch_size = config.grace.num_edit_per_block
91 | edit_loader = DataLoader(edits, batch_size=batch_size, shuffle=False)
92 | accurate_loader = DataLoader(accurate_dataset, batch_size=batch_size, shuffle=False)
93 | upstream_loader = DataLoader(upstream, batch_size=batch_size, shuffle=False)
94 | '''Define Metrics
95 | '''
96 | metric = PPL # Measure perplexity
97 | tokenize = tokenize_gpt
98 | elif config.task == 'scotus':
99 | from dataset import SCOTUS
100 | from metrics import Accuracy
101 | upstream = SCOTUS("train")
102 | edits = SCOTUS("edit")
103 |
104 | '''Get Loaders
105 | '''
106 | batch_size = config.grace.num_edit_per_block
107 | edit_loader = DataLoader(edits, batch_size=batch_size, shuffle=False)
108 | upstream_loader = DataLoader(upstream, batch_size=batch_size, shuffle=False)
109 | '''Define Metrics
110 | '''
111 | metric = Accuracy
112 | tokenize = tokenize_clf
113 | else:
114 | print(f"{config.task} task not found")
115 |
116 | alg_module = importlib.import_module(f'algs.{config.alg}')
117 | AlgClass = getattr(alg_module,config.alg.upper())
118 | alg = AlgClass(model,config,tokenizer)
119 | alg.to(config.device)
120 |
121 | # Trainer
122 | if config.task == "qa" or config.task == "zsre":
123 | trainer = zsre_trainer(config,alg,tokenize,metric,edit_loader,upstream_loader,edit_holdout_loader)
124 | elif config.task == "hallucination":
125 | trainer = hallucination_trainer(config,alg,tokenize,metric,edit_loader,upstream_loader,accurate_loader)
126 | elif config.task == "scotus":
127 | trainer = scotus_trainer(config,alg,tokenize,metric,edit_loader,upstream_loader)
128 |
129 | # trainer.pre_editing_analyse()
130 | torch.cuda.empty_cache()
131 | trainer.run_edit()
132 |
133 |
134 | if __name__ == '__main__':
135 | run()
--------------------------------------------------------------------------------
/melo/utils.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | import torch
3 | import os
4 | import numpy as np
5 | import datetime
6 | import struct
7 | from torch.nn.utils.rnn import pad_sequence
8 | import torch.nn.functional as F
9 | import wandb
10 | import hydra
11 |
12 | def get_inner_params(named_parameters, inner_names):
13 | param_dict = dict(named_parameters)
14 | return [(n, param_dict[n]) for n in inner_names]
15 |
16 |
17 | def param_subset(named_parameters, inner_names):
18 | param_dict = dict(named_parameters)
19 | return [param_dict[n] for n in inner_names]
20 |
21 |
22 | def parent_module(model, pname):
23 | components = pname.split('.')
24 | parent = model
25 |
26 | for component in components[:-1]:
27 | if hasattr(parent, component):
28 | parent = getattr(parent, component)
29 | elif component.isdigit():
30 | parent = parent[int(component)]
31 | else:
32 | raise RuntimeError(f"Couldn't find child module {component}")
33 |
34 | if not hasattr(parent, components[-1]):
35 | raise RuntimeError(f"Couldn't find child module {components[-1]}")
36 |
37 | return parent
38 |
39 |
40 | def uuid(digits=4):
41 | if not hasattr(uuid, "uuid_value"):
42 | uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10 ** digits)
43 |
44 | return uuid.uuid_value
45 |
46 | def scr():
47 | base_dir = hydra.utils.get_original_cwd()
48 | if os.path.exists(os.path.join(base_dir,"scr-ssd")):
49 | scr_dir = os.path.join(base_dir,"scr-ssd")
50 | else:
51 | scr_dir = os.path.join(base_dir,"scr")
52 |
53 | if not os.path.exists(scr_dir):
54 | os.makedirs(scr_dir)
55 |
56 | return scr_dir
57 | def ckpt_dir():
58 | """returns the directory in which to store model checkpoints"""
59 | path = "./ckpts/"
60 | if not os.path.exists(path):
61 | os.makedirs(path)
62 | return path
63 |
64 |
65 | def brackets_to_periods(name):
66 | return name.replace("[", ".").replace("]", "")
67 |
68 |
69 | def get_params(model):
70 | return model.state_dict()
71 |
72 |
73 | def get_shape(p, model):
74 | # We need to flip the shapes since OpenAI gpt2 uses convs instead of linear
75 | return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])
76 |
77 |
78 | def get_logits(x):
79 | return x.logits if hasattr(x, "logits") else x
80 |
81 |
82 | def tokenize_gpt(batch, tokenizer, device, test=False):
83 | prompt, label = batch["text"], batch["labels"]
84 | mask_token = -100 # ignore_index of CrossEntropyLoss
85 | if test or not label:
86 | tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)
87 | tokens["labels"] = tokens["input_ids"].clone()
88 | tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
89 |
90 | else:
91 | full_prompt = [f"{p} {l} <|endoftext|>" for p, l in zip(prompt, label)]
92 | # full_prompt = [f"{p} {l} {tokenizer.eos_token}" for p, l in zip(prompt, label)]
93 | prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
94 | num_prompt_toks = [int((i != tokenizer.pad_token_id).sum()) for i in prompt_ids]
95 | tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
96 | tokens["labels"] = tokens["input_ids"].clone()
97 | for i in range(len(prompt)):
98 | tokens["labels"][i][:num_prompt_toks[i]] = mask_token
99 |
100 | tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token # What is this doing?
101 |
102 | tokens = {f"{k1}": v1.to(device) for k1, v1 in tokens.items()}
103 | return tokens
104 |
105 |
106 | def tokenize_qa(batch, tokenizer, device, **kwargs):
107 | input_sequences, output_sequences = batch["text"], batch["labels"]
108 |
109 | input_encoding = tokenizer(
110 | list(input_sequences),
111 | padding="longest",
112 | max_length=20,
113 | truncation=True,
114 | return_tensors="pt",
115 | )
116 |
117 | input_ids, attention_mask = input_encoding.input_ids, input_encoding.attention_mask
118 |
119 | target_encoding = tokenizer(
120 | list(output_sequences),
121 | padding="longest",
122 | max_length=20,
123 | truncation=True,
124 | return_tensors="pt",
125 | )
126 |
127 | labels = target_encoding.input_ids
128 | labels[labels == tokenizer.pad_token_id] = -100
129 |
130 | tokens = {
131 | "input_ids": input_ids,
132 | "attention_mask": attention_mask,
133 | "labels": labels
134 | }
135 |
136 | tokens = {f"{k1}": v1.to(device) for k1, v1 in tokens.items()}
137 | return tokens
138 |
139 |
140 | def tokenize_clf(batch, tokenizer, device, **kwargs):
141 | input_sequences, labels = batch["text"], batch["labels"]
142 |
143 | tokens = tokenizer(input_sequences, truncation=True, padding="max_length", return_tensors="pt")
144 | tokens["labels"] = labels
145 | tokens = {f"{k1}": v1.to(device) for k1, v1 in tokens.items()}
146 | return tokens
147 |
148 |
149 |
150 |
151 |
152 |
153 |
--------------------------------------------------------------------------------
/melo/wiki_bio_concepts.txt:
--------------------------------------------------------------------------------
1 | john russell reynolds
2 | matthew aylmer , 1st baron aylmer
3 | rick mahler
4 | james blair south carolina
5 | tim finchem
6 | akila dananjaya
7 | derek king australian footballer
8 | wilhelm windelband
9 | freddie frith
10 | marshall manesh
11 | eleanor arnason
12 | carter harrison , sr. .
13 | winnebago deal
14 | noel hogan
15 | dawn landes
16 | bill quinn
17 | carol huston
18 | gia carangi
19 | nigel milsom
20 | rod morgenstein
21 | terry alderman
22 | frank a. mclain
23 | rich williams
24 | torry castellano
25 | albert i , margrave of meissen
26 | sirið stenberg
27 | thomas harriot
28 | tadeusz szeligowski
29 | gordon strachan
30 | steven threet
31 | archie baird
32 | peter breen politician
33 | adja yunkers
34 | the blood divine
35 | king zhuang of chu
36 | william j. flanagan , jr.
37 | k. s. manilal
38 | jeannine riley
39 | seyi shay
40 | hilda kuper
41 | stuart scott
42 | mark fite
43 | philippe dodard
44 | rudy fernandez labor leader
45 | mackenzie caquatto
46 | twila shively
47 | lionel aldridge
48 | irena sendler
49 | ronnie barker
50 | honoré iii , prince of monaco
51 | emily gielnik
52 | choi jae-bong
53 | tom izzo
54 | tommy nutter
55 | jearl walker
56 | steve ridzik
57 | achille-ferdinand carrier
58 | tera van beilen
59 | harry kennedy
60 | david kappos
61 | pattern is movement
62 | kévin gameiro
63 | lee hsien loong
64 | lucien turcotte pacaud
65 | makiko esumi
66 | kate deines
67 | c. v. ananda bose
68 | anthony dimond
69 | honoré iv , prince of monaco
70 | tristan rogers
71 | john burnham cricketer
72 | nate saint
73 | thutmose iii
74 | john loder sound engineer
75 | a. p. j. abdul kalam
76 | john reed , jr. .
77 | paul elliott politician
78 | moisés kaufman
79 | robert holgate
80 | duncan mackay footballer
81 | saul david
82 | tomasz lis
83 | véra korène
84 | nodar kumaritashvili
85 | leana de bruin
86 | alfred fischer ss officer
87 | kermit davis
88 | daniel ménard
89 | modibo adama
90 | bert deacon
91 | mushahid hussain syed
92 | kia joorabchian
93 | vitaliano brancati
94 | emperor wenxuan of northern qi
95 | johan christian dahl
96 | steve cooper footballer , born 1964
97 | ernest miller cinematographer
98 | david king australian rules footballer
99 | danny smith coach
100 | hope cooke
101 | tathagata satpathy
102 | michel mathieu canadian politician
103 | mario monti
104 | pino palladino
105 | tony la russa
106 | murray g. ross
107 | malcolm brogdon
108 | john les
109 | evan rachel wood
110 | frank abagnale
111 | reezal merican naina merican
112 | dan stearns
113 | lindsay crouse
114 | clay timpner
115 | yaakov israel ifargan
116 | ha jung-woo
117 | charles lee basketball
118 | stereophonics
119 | don r. swanson
120 | roy beggs , jr. .
121 | adiele afigbo
122 | brian petrovek
123 | john cushnahan
124 | ron meagher
125 | george milne cricketer
126 | bill tobin american football
127 | william luther pierce
128 | martina sorbara
129 | tom wise
130 | frederick thomas brentnall
131 | bill brown goalkeeper
132 | eden natan-zada
133 | richard carpenter screenwriter
134 | joe brown utility player
135 | wayne allyn root
136 | assassination of robert f. kennedy
137 | paul caddis
138 | paul taylor winger
139 | linda hunt
140 | jerry leger
141 | 3rd dalai lama
142 | james clarke vc
143 | jack straw
144 | syd rapson
145 | billy barnie
146 | catherine johnson playwright
147 | sara montiel
148 | lucy akhurst
149 | william allan neilson
150 | elisha brown
151 | joe walsh rugby league
152 | josiah mason
153 | balbir singh kullar
154 | george bovell
155 | fei-ping hsu
156 | anne de gaulle
157 | rusty stevens
158 | john cameron alberta politician
159 | carole gist
160 | david collings
161 | matt striebel
162 | bob miller american football
163 | bryan mcclendon
164 | royce campbell
165 | carlos arniches
166 | geoff griffin
167 | frankie lymon
168 | raymond harry brown
169 | george roll
170 | ayn rand
171 | richard allen epstein
172 | tom butler actor
173 | kenan hasagić
174 | gordon hogg
175 | vagos motorcycle club
176 | katie ledecky
177 | michael savage
178 | john howe illustrator
179 | alana davis
180 | arthur sewall
181 | stan heal
182 | ithamara koorax
183 | thomas wolfe
184 | john russell vc
185 | cicero hunt lewis
186 | philip of france 1116 -- 1131
187 | brian hughes musician
188 | rickey paulding
189 | charles melville hays
190 | lee naylor footballer
191 | bane band
192 | adam collis
193 | alan dinehart
194 | sylvain barrier
195 | kirill karabits
196 | b. k. anand
197 | robert emmett keane
198 | charlotte rae
199 | riccardo tisci
200 | lester germer
201 | laurent koscielny
202 | bridget moynahan
203 | george hubbard clapp
204 | merle oberon
205 | mayhew foster
206 | hephaestion
207 | thomas biagi
208 | susan pedersen swimmer
209 | tetsuzō iwamoto
210 | donald alexander mackinnon
211 | joe holland basketball
212 | casey serin
213 | jean hugo
214 | heinz christian pander
215 | arthur tedder , 1st baron tedder
216 | cindy kleine
217 | willie naulls
218 | john holman chemist
219 | paul y. r. waddington
220 | andy hurley
221 | dara torres
222 | john hughes archbishop of new york
223 | millicent shelton
224 | whitey kurowski
225 | nofx
226 | hisashi iwakuma
227 | virginia bottomley
228 | john liscio
229 | john vallely
230 | johannes andreas august grabau
231 | guðlaugur Þór Þórðarson
232 | laurier lévesque
233 | micky moody
234 | gündüz kılıç
235 | michael replogle
236 | billy burke golfer
237 | ted childs
238 | edward synge archbishop of tuam
--------------------------------------------------------------------------------
/peft_egg/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # VSCode
132 | .vscode
133 |
134 | # IntelliJ
135 | .idea
136 |
137 | # Mac .DS_Store
138 | .DS_Store
139 |
140 | # More test things
141 | wandb
--------------------------------------------------------------------------------
/peft_egg/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: quality style test docs
2 |
3 | check_dirs := src tests examples docs
4 |
5 | # Check that source code meets quality standards
6 |
7 | # this target runs checks on all files
8 | quality:
9 | black --check $(check_dirs)
10 | ruff $(check_dirs)
11 | doc-builder style src/peft tests docs/source --max_len 119 --check_only
12 |
13 | # Format source code automatically and check if there are any problems left that need manual fixing
14 | style:
15 | black $(check_dirs)
16 | ruff $(check_dirs) --fix
17 | doc-builder style src/peft tests docs/source --max_len 119
18 |
19 | test:
20 | python -m pytest -n 3 tests/ $(if $(IS_GITHUB_CI),--report-log "ci_tests.log",)
21 |
22 | tests_examples_multi_gpu:
23 | python -m pytest -m multi_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
24 |
25 | tests_examples_single_gpu:
26 | python -m pytest -m single_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",)
27 |
28 | tests_core_multi_gpu:
29 | python -m pytest -m multi_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",)
30 |
31 | tests_core_single_gpu:
32 | python -m pytest -m single_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",)
33 |
34 | tests_common_gpu:
35 | python -m pytest tests/test_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",)
36 | python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",)
37 |
--------------------------------------------------------------------------------
/peft_egg/docker/peft-cpu/Dockerfile:
--------------------------------------------------------------------------------
1 | # Builds CPU-only docker image of PyTorch
2 | # Uses multi-staged approach to reduce size
3 | # Stage 1
4 | # Use base conda image to reduce time
5 | FROM continuumio/miniconda3:latest AS compile-image
6 | # Specify py version
7 | ENV PYTHON_VERSION=3.8
8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
9 | RUN apt-get update && \
10 | apt-get install -y curl git wget software-properties-common git-lfs && \
11 | apt-get clean && \
12 | rm -rf /var/lib/apt/lists*
13 |
14 | # Install audio-related libraries
15 | RUN apt-get update && \
16 | apt install -y ffmpeg
17 |
18 | RUN apt install -y libsndfile1-dev
19 | RUN git lfs install
20 |
21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
22 | RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip
23 | RUN python3 -m pip install --no-cache-dir --upgrade pip
24 |
25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
26 | # We don't install pytorch here yet since CUDA isn't available
27 | # instead we use the direct torch wheel
28 | ENV PATH /opt/conda/envs/peft/bin:$PATH
29 | # Activate our bash shell
30 | RUN chsh -s /bin/bash
31 | SHELL ["/bin/bash", "-c"]
32 | # Activate the conda env and install transformers + accelerate from source
33 | RUN source activate peft && \
34 | python3 -m pip install --no-cache-dir \
35 | librosa \
36 | "soundfile>=0.12.1" \
37 | scipy \
38 | git+https://github.com/huggingface/transformers \
39 | git+https://github.com/huggingface/accelerate \
40 | peft[test]@git+https://github.com/huggingface/peft
41 |
42 | # Install apt libs
43 | RUN apt-get update && \
44 | apt-get install -y curl git wget && \
45 | apt-get clean && \
46 | rm -rf /var/lib/apt/lists*
47 |
48 | RUN echo "source activate peft" >> ~/.profile
49 |
50 | # Activate the virtualenv
51 | CMD ["/bin/bash"]
--------------------------------------------------------------------------------
/peft_egg/docker/peft-gpu/Dockerfile:
--------------------------------------------------------------------------------
1 | # Builds GPU docker image of PyTorch
2 | # Uses multi-staged approach to reduce size
3 | # Stage 1
4 | # Use base conda image to reduce time
5 | FROM continuumio/miniconda3:latest AS compile-image
6 | # Specify py version
7 | ENV PYTHON_VERSION=3.8
8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
9 | RUN apt-get update && \
10 | apt-get install -y curl git wget software-properties-common git-lfs && \
11 | apt-get clean && \
12 | rm -rf /var/lib/apt/lists*
13 |
14 | # Install audio-related libraries
15 | RUN apt-get update && \
16 | apt install -y ffmpeg
17 |
18 | RUN apt install -y libsndfile1-dev
19 | RUN git lfs install
20 |
21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
22 | RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip
23 | RUN python3 -m pip install --no-cache-dir --upgrade pip
24 |
25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
26 | # We don't install pytorch here yet since CUDA isn't available
27 | # instead we use the direct torch wheel
28 | ENV PATH /opt/conda/envs/peft/bin:$PATH
29 | # Activate our bash shell
30 | RUN chsh -s /bin/bash
31 | SHELL ["/bin/bash", "-c"]
32 | # Activate the conda env and install transformers + accelerate from source
33 | RUN source activate peft && \
34 | python3 -m pip install --no-cache-dir \
35 | librosa \
36 | "soundfile>=0.12.1" \
37 | scipy \
38 | git+https://github.com/huggingface/transformers \
39 | git+https://github.com/huggingface/accelerate \
40 | peft[test]@git+https://github.com/huggingface/peft
41 |
42 | RUN python3 -m pip install --no-cache-dir bitsandbytes
43 |
44 | # Stage 2
45 | FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 AS build-image
46 | COPY --from=compile-image /opt/conda /opt/conda
47 | ENV PATH /opt/conda/bin:$PATH
48 |
49 | # Install apt libs
50 | RUN apt-get update && \
51 | apt-get install -y curl git wget && \
52 | apt-get clean && \
53 | rm -rf /var/lib/apt/lists*
54 |
55 | RUN echo "source activate peft" >> ~/.profile
56 |
57 | # Activate the virtualenv
58 | CMD ["/bin/bash"]
--------------------------------------------------------------------------------
/peft_egg/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = source
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/peft_egg/docs/source/_config.py:
--------------------------------------------------------------------------------
1 | # docstyle-ignore
2 | INSTALL_CONTENT = """
3 | # PEFT installation
4 | ! pip install peft accelerate transformers
5 | # To install from source instead of the last release, comment the command above and uncomment the following one.
6 | # ! pip install git+https://github.com/huggingface/peft.git
7 | """
8 |
--------------------------------------------------------------------------------
/peft_egg/docs/source/_toctree.yml:
--------------------------------------------------------------------------------
1 | - title: Get started
2 | sections:
3 | - local: index
4 | title: 🤗 PEFT
5 | - local: quicktour
6 | title: Quicktour
7 | - local: install
8 | title: Installation
9 |
10 | - title: Task guides
11 | sections:
12 | - local: task_guides/image_classification_lora
13 | title: Image classification using LoRA
14 | - local: task_guides/seq2seq-prefix-tuning
15 | title: Prefix tuning for conditional generation
16 | - local: task_guides/clm-prompt-tuning
17 | title: Prompt tuning for causal language modeling
18 | - local: task_guides/semantic_segmentation_lora
19 | title: Semantic segmentation using LoRA
20 | - local: task_guides/ptuning-seq-classification
21 | title: P-tuning for sequence classification
22 | - local: task_guides/dreambooth_lora
23 | title: Dreambooth fine-tuning with LoRA
24 | - local: task_guides/token-classification-lora
25 | title: LoRA for token classification
26 | - local: task_guides/int8-asr
27 | title: int8 training for automatic speech recognition
28 |
29 | - title: 🤗 Accelerate integrations
30 | sections:
31 | - local: accelerate/deepspeed-zero3-offload
32 | title: DeepSpeed
33 | - local: accelerate/fsdp
34 | title: Fully Sharded Data Parallel
35 |
36 | - title: Conceptual guides
37 | sections:
38 | - local: conceptual_guides/lora
39 | title: LoRA
40 | - local: conceptual_guides/prompting
41 | title: Prompting
42 |
43 | - title: Reference
44 | sections:
45 | - local: package_reference/peft_model
46 | title: PEFT model
47 | - local: package_reference/config
48 | title: Configuration
49 | - local: package_reference/tuners
50 | title: Tuners
--------------------------------------------------------------------------------
/peft_egg/docs/source/accelerate/fsdp.mdx:
--------------------------------------------------------------------------------
1 | # Fully Sharded Data Parallel
2 |
3 | [Fully sharded data parallel](https://pytorch.org/docs/stable/fsdp.html) (FSDP) is developed for distributed training of large pretrained models up to 1T parameters. FSDP achieves this by sharding the model parameters, gradients, and optimizer states across data parallel processes and it can also offload sharded model parameters to a CPU. The memory efficiency afforded by FSDP allows you to scale training to larger batch or model sizes.
4 |
5 |
6 |
7 | Currently, FSDP does not confer any reduction in GPU memory usage and FSDP with CPU offload actually consumes 1.65x more GPU memory during training. You can track this PyTorch [issue](https://github.com/pytorch/pytorch/issues/91165) for any updates.
8 |
9 |
10 |
11 | FSDP is supported in 🤗 Accelerate, and you can use it with 🤗 PEFT. This guide will help you learn how to use our FSDP [training script](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py). You'll configure the script to train a large model for conditional generation.
12 |
13 | ## Configuration
14 |
15 | Begin by running the following command to [create an FSDP configuration file](https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp) with 🤗 Accelerate. Use the `--config_file` flag to save the configuration file to a specific location, otherwise it is saved as a `default_config.yaml` file in the 🤗 Accelerate cache.
16 |
17 | The configuration file is used to set the default options when you launch the training script.
18 |
19 | ```bash
20 | accelerate config --config_file fsdp_config.yaml
21 | ```
22 |
23 | You'll be asked a few questions about your setup and to configure the following arguments. For this example, make sure you fully shard the model parameters, gradients, optimizer states, leverage the CPU for offloading, and wrap model layers based on the Transformer layer class name.
24 |
25 | ```bash
26 | `Sharding Strategy`: [1] FULL_SHARD (shards optimizer states, gradients and parameters), [2] SHARD_GRAD_OP (shards optimizer states and gradients), [3] NO_SHARD
27 | `Offload Params`: Decides whether to offload parameters and gradients to CPU
28 | `Auto Wrap Policy`: [1] TRANSFORMER_BASED_WRAP, [2] SIZE_BASED_WRAP, [3] NO_WRAP
29 | `Transformer Layer Class to Wrap`: When using `TRANSFORMER_BASED_WRAP`, the user specifies a comma-separated string of transformer layer class names (case-sensitive) to wrap, e.g.,
30 | `BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput`...
31 | `Min Num Params`: minimum number of parameters when using `SIZE_BASED_WRAP`
32 | `Backward Prefetch`: [1] BACKWARD_PRE, [2] BACKWARD_POST, [3] NO_PREFETCH
33 | `State Dict Type`: [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT
34 | ```
35 |
36 | For example, your FSDP configuration file may look like the following:
37 |
38 | ```yaml
39 | command_file: null
40 | commands: null
41 | compute_environment: LOCAL_MACHINE
42 | deepspeed_config: {}
43 | distributed_type: FSDP
44 | downcast_bf16: 'no'
45 | dynamo_backend: 'NO'
46 | fsdp_config:
47 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
48 | fsdp_backward_prefetch_policy: BACKWARD_PRE
49 | fsdp_offload_params: true
50 | fsdp_sharding_strategy: 1
51 | fsdp_state_dict_type: FULL_STATE_DICT
52 | fsdp_transformer_layer_cls_to_wrap: T5Block
53 | gpu_ids: null
54 | machine_rank: 0
55 | main_process_ip: null
56 | main_process_port: null
57 | main_training_function: main
58 | megatron_lm_config: {}
59 | mixed_precision: 'no'
60 | num_machines: 1
61 | num_processes: 2
62 | rdzv_backend: static
63 | same_network: true
64 | tpu_name: null
65 | tpu_zone: null
66 | use_cpu: false
67 | ```
68 |
69 | ## The important parts
70 |
71 | Let's dig a bit deeper into the training script to understand how it works.
72 |
73 | The [`main()`](https://github.com/huggingface/peft/blob/2822398fbe896f25d4dac5e468624dc5fd65a51b/examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py#L14) function begins with initializing an [`~accelerate.Accelerator`] class which handles everything for distributed training, such as automatically detecting your training environment.
74 |
75 |
76 |
77 | 💡 Feel free to change the model and dataset inside the `main` function. If your dataset format is different from the one in the script, you may also need to write your own preprocessing function.
78 |
79 |
80 |
81 | The script also creates a configuration corresponding to the 🤗 PEFT method you're using. For LoRA, you'll use [`LoraConfig`] to specify the task type, and several other important parameters such as the dimension of the low-rank matrices, the matrices scaling factor, and the dropout probability of the LoRA layers. If you want to use a different 🤗 PEFT method, replace `LoraConfig` with the appropriate [class](../package_reference/tuners).
82 |
83 | Next, the script wraps the base model and `peft_config` with the [`get_peft_model`] function to create a [`PeftModel`].
84 |
85 | ```diff
86 | def main():
87 | + accelerator = Accelerator()
88 | model_name_or_path = "t5-base"
89 | base_path = "temp/data/FinancialPhraseBank-v1.0"
90 | + peft_config = LoraConfig(
91 | task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
92 | )
93 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
94 | + model = get_peft_model(model, peft_config)
95 | ```
96 |
97 | Throughout the script, you'll see the [`~accelerate.Accelerator.main_process_first`] and [`~accelerate.Accelerator.wait_for_everyone`] functions which help control and synchronize when processes are executed.
98 |
99 | After your dataset is prepared and all the necessary training components are loaded, the script checks whether you're using the `fsdp_plugin`. PyTorch offers two ways of wrapping model layers in FSDP: automatically or manually. The simplest method is to let FSDP automatically and recursively wrap model layers without changing any other code. You can choose to wrap the model layers based on the layer name or on the size (number of parameters). The FSDP configuration file in this example uses the `TRANSFORMER_BASED_WRAP` option to wrap the [`T5Block`] layer.
100 |
101 | ```py
102 | if getattr(accelerator.state, "fsdp_plugin", None) is not None:
103 | accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
104 | ```
105 |
106 | Next, use 🤗 Accelerate's [`~accelerate.Accelerator.prepare`] function to prepare the model, datasets, optimizer, and scheduler for training.
107 |
108 | ```py
109 | model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
110 | model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
111 | )
112 | ```
113 |
114 | From here, the remainder of the script handles the training loop, evaluation, and sharing your model to the Hub.
115 |
116 | ## Train
117 |
118 | Run the following command to launch the training script. Earlier, you saved the configuration file to `fsdp_config.yaml`, so you'll need to pass the path to the launcher with the `--config_file` argument like this:
119 |
120 | ```bash
121 | accelerate launch --config_file fsdp_config.yaml examples/peft_lora_seq2seq_accelerate_fsdp.py
122 | ```
123 |
124 | Once training is complete, the script returns the accuracy and compares the predictions to the labels.
--------------------------------------------------------------------------------
/peft_egg/docs/source/conceptual_guides/lora.mdx:
--------------------------------------------------------------------------------
13 | # LoRA
14 |
15 | This conceptual guide gives a brief overview of [LoRA](https://arxiv.org/abs/2106.09685), a technique that accelerates
16 | the fine-tuning of large models while consuming less memory.
17 |
18 | To make fine-tuning more efficient, LoRA's approach is to represent the weight updates with two smaller
19 | matrices (called **update matrices**) through low-rank decomposition. These new matrices can be trained to adapt to the
20 | new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive
21 | any further adjustments. To produce the final results, both the original and the adapted weights are combined.
22 |
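
To make this concrete, here is a minimal sketch of a LoRA-augmented linear layer in plain PyTorch. It is an illustration of the idea only, not the 🤗 PEFT implementation; `r` and `alpha` mirror the configuration parameters described later in this guide.

```python
import torch.nn as nn

class LoRALinear(nn.Module):
    """y = W x + (alpha / r) * B A x, with the pretrained W kept frozen."""

    def __init__(self, in_features, out_features, r=8, alpha=32):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)  # original weights receive no gradient updates
        self.lora_A = nn.Linear(in_features, r, bias=False)   # down-projection
        self.lora_B = nn.Linear(r, out_features, bias=False)  # up-projection
        nn.init.zeros_(self.lora_B.weight)  # the update starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))
```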
23 | This approach has a number of advantages:
24 |
25 | * LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters.
26 | * The original pre-trained weights are kept frozen, which means you can have multiple lightweight and portable LoRA models for various downstream tasks built on top of them.
27 | * LoRA is orthogonal to many other parameter-efficient methods and can be combined with many of them.
28 | * Performance of models fine-tuned using LoRA is comparable to the performance of fully fine-tuned models.
29 | * LoRA does not add any inference latency because adapter weights can be merged with the base model.
30 |
31 | In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable
32 | parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to
33 | attention blocks only. The resulting number of trainable parameters in a LoRA model depends on the size of the low-rank
34 | update matrices, which is determined mainly by the rank `r` and the shape of the original weight matrix.
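
For example, adapting a 4096 × 4096 attention projection with `r = 8` adds only 8 × (4096 + 4096) = 65,536 trainable parameters, versus the 16,777,216 parameters of the original matrix, a 256× reduction.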
35 |
36 | ## Common LoRA parameters in PEFT
37 |
38 | As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to:
39 |
40 | 1. Instantiate a base model.
41 | 2. Create a configuration (`LoraConfig`) where you define LoRA-specific parameters.
42 | 3. Wrap the base model with `get_peft_model()` to get a trainable `PeftModel`.
43 | 4. Train the `PeftModel` as you normally would train the base model.
44 |
45 | `LoraConfig` allows you to control how LoRA is applied to the base model through the following parameters:
46 |
47 | - `r`: the rank of the update matrices, expressed in `int`. Lower rank results in smaller update matrices with fewer trainable parameters.
48 | - `target_modules`: The modules (for example, attention blocks) to apply the LoRA update matrices to.
49 | - `alpha`: LoRA scaling factor.
50 | - `bias`: Specifies if the `bias` parameters should be trained. Can be `'none'`, `'all'` or `'lora_only'`.
51 | - `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task.
52 | - `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed.
53 | - `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default, `PeftModel` looks for common layer patterns (`layers`, `h`, `blocks`, etc.); use this option for exotic or custom models.
54 |
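Putting these parameters together, a configuration for a GPT-2-style causal language model might look like the following sketch. Treat the `target_modules` value as an example rather than a recipe: the module names to target depend on the architecture of your base model.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                        # rank of the update matrices
    lora_alpha=32,              # LoRA scaling factor
    lora_dropout=0.1,
    target_modules=["c_attn"],  # the attention projection in GPT-2-style models
    bias="none",
)
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(model, config)
model.print_trainable_parameters()
```
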
55 | ## LoRA examples
56 |
57 | For an example of LoRA method application to various downstream tasks, please refer to the following guides:
58 |
59 | * [Image classification using LoRA](../task_guides/image_classification_lora)
60 | * [Semantic segmentation](../task_guides/semantic_segmentation_lora)
61 |
62 | While the original paper focuses on language models, the technique can be applied to any dense layers in deep learning
63 | models. As such, you can leverage this technique with diffusion models. See the [Dreambooth fine-tuning with LoRA](../task_guides/dreambooth_lora) task guide for an example.
64 |
--------------------------------------------------------------------------------
/peft_egg/docs/source/conceptual_guides/prompting.mdx:
--------------------------------------------------------------------------------
1 | # Prompting
2 |
3 | Training large pretrained language models is very time-consuming and compute-intensive. As they continue to grow in size, there is increasing interest in more efficient training methods such as *prompting*. Prompting primes a frozen pretrained model for a specific downstream task by including a text prompt that describes the task or even demonstrates an example of the task. With prompting, you can avoid fully training a separate model for each downstream task, and use the same frozen pretrained model instead. This is a lot easier because you can use the same model for several different tasks, and it is significantly more efficient to train and store a smaller set of prompt parameters than to train all the model's parameters.
4 |
5 | There are two categories of prompting methods:
6 |
7 | - hard prompts are manually handcrafted text prompts with discrete input tokens; the downside is that it requires a lot of effort to create a good prompt
8 | - soft prompts are learnable tensors concatenated with the input embeddings that can be optimized to a dataset; the downside is that they aren't human readable because you aren't matching these "virtual tokens" to the embeddings of a real word
9 |
10 | This conceptual guide provides a brief overview of the soft prompt methods included in 🤗 PEFT: prompt tuning, prefix tuning, and P-tuning.
11 |
12 | ## Prompt tuning
13 |
14 |
15 | ![Prompt tuning overview.](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/prompt-tuning.png)
16 |
17 | Only train and store a significantly smaller set of task-specific prompt parameters (image source).
18 |
19 | Prompt tuning was developed for text classification tasks on T5 models, and all downstream tasks are cast as a text generation task. For example, sequence classification usually assigns a single class label to a sequence of text. By casting it as a text generation task, the tokens that make up the class label are *generated*. Prompts are added to the input as a series of tokens. Typically, the model parameters are fixed which means the prompt tokens are also fixed by the model parameters.
20 |
21 | The key idea behind prompt tuning is that prompt tokens have their own parameters that are updated independently. This means you can keep the pretrained model's parameters frozen, and only update the gradients of the prompt token embeddings. The results are comparable to the traditional method of training the entire model, and prompt tuning performance scales as model size increases.
22 |
23 | Take a look at [Prompt tuning for causal language modeling](../task_guides/clm-prompt-tuning) for a step-by-step guide on how to train a model with prompt tuning.
24 |
25 | ## Prefix tuning
26 |
27 |
28 | ![Prefix tuning overview.](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/prefix-tuning.png)
29 |
30 | Optimize the prefix parameters for each task (image source).
31 |
32 | Prefix tuning was designed for natural language generation (NLG) tasks on GPT models. It is very similar to prompt tuning; prefix tuning also prepends a sequence of task-specific vectors to the input that can be trained and updated while keeping the rest of the pretrained model's parameters frozen.
33 |
34 | The main difference is that the prefix parameters are inserted in **all** of the model layers, whereas prompt tuning only adds the prompt parameters to the model input embeddings. The prefix parameters are also optimized by a separate feed-forward network (FFN) instead of training directly on the soft prompts because it causes instability and hurts performance. The FFN is discarded after updating the soft prompts.
35 |
36 | As a result, the authors found that prefix tuning demonstrates comparable performance to fully finetuning a model, despite having 1000x fewer parameters, and it performs even better in low-data settings.
37 |
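A corresponding sketch with [`PrefixTuningConfig`] (an illustrative configuration, not a tuned one):

```python
from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

# Trainable prefix vectors are injected into every model layer, not just the input.
config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=30)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model = get_peft_model(model, config)
```
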
38 | Take a look at [Prefix tuning for conditional generation](../task_guides/seq2seq-prefix-tuning) for a step-by-step guide on how to train a model with prefix tuning.
39 |
40 | ## P-tuning
41 |
42 |
43 | ![P-tuning overview.](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/p-tuning.png)
44 |
45 | Prompt tokens can be inserted anywhere in the input sequence, and they are optimized by a prompt encoder (image source).
46 |
47 | P-tuning is designed for natural language understanding (NLU) tasks and all language models.
48 | It is another variation of a soft prompt method; P-tuning also adds a trainable embedding tensor that can be optimized to find better prompts, and it uses a prompt encoder (a bidirectional long short-term memory network, or LSTM) to optimize the prompt parameters. Unlike prefix tuning, though:
49 |
50 | - the prompt tokens can be inserted anywhere in the input sequence, and it isn't restricted to only the beginning
51 | - the prompt tokens are only added to the input instead of adding them to every layer of the model
52 | - introducing *anchor* tokens can improve performance because they indicate characteristics of a component in the input sequence
53 |
54 | The results suggest that P-tuning is more efficient than manually crafting prompts, and it enables GPT-like models to compete with BERT-like models on NLU tasks.
55 |
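A minimal sketch with [`PromptEncoderConfig`] (the hyperparameter values here are illustrative):

```python
from transformers import AutoModelForSequenceClassification
from peft import PromptEncoderConfig, TaskType, get_peft_model

# The virtual prompt tokens are produced by a small trainable prompt encoder.
config = PromptEncoderConfig(
    task_type=TaskType.SEQ_CLS, num_virtual_tokens=20, encoder_hidden_size=128
)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
model = get_peft_model(model, config)
```
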
56 | Take a look at [P-tuning for sequence classification](../task_guides/ptuning-seq-classification) for a step-by-step guide on how to train a model with P-tuning.
--------------------------------------------------------------------------------
/peft_egg/docs/source/index.mdx:
--------------------------------------------------------------------------------
13 | # PEFT
14 |
15 | 🤗 PEFT, or Parameter-Efficient Fine-Tuning, is a library for efficiently adapting pre-trained language models (PLMs) to various downstream applications without fine-tuning all the model's parameters.
16 | PEFT methods only fine-tune a small number of (extra) model parameters, significantly decreasing computational and storage costs because fine-tuning large-scale PLMs is prohibitively costly.
17 | Recent state-of-the-art PEFT techniques achieve performance comparable to that of full fine-tuning.
18 |
19 | PEFT is seamlessly integrated with 🤗 Accelerate for large-scale models leveraging DeepSpeed and [Big Model Inference](https://huggingface.co/docs/accelerate/usage_guides/big_modeling).
20 |
42 | ## Supported methods
43 |
44 | 1. LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/pdf/2106.09685.pdf)
45 | 2. Prefix Tuning: [Prefix-Tuning: Optimizing Continuous Prompts for Generation](https://aclanthology.org/2021.acl-long.353/), [P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)
46 | 3. P-Tuning: [GPT Understands, Too](https://arxiv.org/pdf/2103.10385.pdf)
47 | 4. Prompt Tuning: [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/pdf/2104.08691.pdf)
48 | 5. AdaLoRA: [Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/abs/2303.10512)
49 | 6. [LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention](https://github.com/ZrrSkywalker/LLaMA-Adapter)
50 |
51 | ## Supported models
52 |
53 | The tables provided below list the PEFT methods and models supported for each task. To apply a particular PEFT method for
54 | a task, please refer to the corresponding Task guides.
55 |
56 | ### Causal Language Modeling
57 |
58 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
59 | |--------------| ---- | ---- | ---- | ---- |
60 | | GPT-2 | ✅ | ✅ | ✅ | ✅ |
61 | | Bloom | ✅ | ✅ | ✅ | ✅ |
62 | | OPT | ✅ | ✅ | ✅ | ✅ |
63 | | GPT-Neo | ✅ | ✅ | ✅ | ✅ |
64 | | GPT-J | ✅ | ✅ | ✅ | ✅ |
65 | | GPT-NeoX-20B | ✅ | ✅ | ✅ | ✅ |
66 | | LLaMA | ✅ | ✅ | ✅ | ✅ |
67 | | ChatGLM | ✅ | ✅ | ✅ | ✅ |
68 |
69 | ### Conditional Generation
70 |
71 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
72 | | --------- | ---- | ---- | ---- | ---- |
73 | | T5 | ✅ | ✅ | ✅ | ✅ |
74 | | BART | ✅ | ✅ | ✅ | ✅ |
75 |
76 | ### Sequence Classification
77 |
78 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
79 | | --------- | ---- | ---- | ---- | ---- |
80 | | BERT | ✅ | ✅ | ✅ | ✅ |
81 | | RoBERTa | ✅ | ✅ | ✅ | ✅ |
82 | | GPT-2 | ✅ | ✅ | ✅ | ✅ |
83 | | Bloom | ✅ | ✅ | ✅ | ✅ |
84 | | OPT | ✅ | ✅ | ✅ | ✅ |
85 | | GPT-Neo | ✅ | ✅ | ✅ | ✅ |
86 | | GPT-J | ✅ | ✅ | ✅ | ✅ |
87 | | Deberta | ✅ | | ✅ | ✅ |
88 | | Deberta-v2 | ✅ | | ✅ | ✅ |
89 |
90 | ### Token Classification
91 |
92 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
93 | | --------- | ---- | ---- | ---- | ---- |
94 | | BERT | ✅ | ✅ | | |
95 | | RoBERTa | ✅ | ✅ | | |
96 | | GPT-2 | ✅ | ✅ | | |
97 | | Bloom | ✅ | ✅ | | |
98 | | OPT | ✅ | ✅ | | |
99 | | GPT-Neo | ✅ | ✅ | | |
100 | | GPT-J | ✅ | ✅ | | |
101 | | Deberta | ✅ | | | |
102 | | Deberta-v2 | ✅ | | | |
103 |
104 | ### Text-to-Image Generation
105 |
106 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
107 | | --------- | ---- | ---- | ---- | ---- |
108 | | Stable Diffusion | ✅ | | | |
109 |
110 |
111 | ### Image Classification
112 |
113 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
114 | | --------- | ---- | ---- | ---- | ---- |
115 | | ViT | ✅ | | | |
116 | | Swin | ✅ | | | |
117 |
118 | ### Image to text (Multi-modal models)
119 |
120 | We have tested LoRA for [ViT](https://huggingface.co/docs/transformers/model_doc/vit) and [Swin](https://huggingface.co/docs/transformers/model_doc/swin) for fine-tuning on image classification.
121 | However, it should be possible to use LoRA for any [ViT-based model](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads&search=vit) from 🤗 Transformers.
122 | Check out the [Image classification](/task_guides/image_classification_lora) task guide to learn more. If you run into problems, please open an issue.
123 |
124 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
125 | | --------- | ---- | ---- | ---- | ---- |
126 | | Blip-2 | ✅ | | | |
127 |
128 |
129 | ### Semantic Segmentation
130 |
131 | As with image-to-text models, you should be able to apply LoRA to any of the [segmentation models](https://huggingface.co/models?pipeline_tag=image-segmentation&sort=downloads).
132 | It's worth noting that we haven't tested this with every architecture yet. Therefore, if you come across any issues, kindly create an issue report.
133 |
134 | | Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
135 | | --------- | ---- | ---- | ---- | ---- |
136 | | SegFormer | ✅ | | | |
137 |
138 |
--------------------------------------------------------------------------------
/peft_egg/docs/source/install.mdx:
--------------------------------------------------------------------------------
13 | # Installation
14 |
15 | Before you start, you will need to set up your environment, install the appropriate packages, and configure 🤗 PEFT. 🤗 PEFT is tested on **Python 3.7+**.
16 |
17 | 🤗 PEFT is available on PyPI, as well as GitHub:
18 |
19 | ## pip
20 |
21 | To install 🤗 PEFT from PyPI:
22 |
23 | ```bash
24 | pip install peft
25 | ```
26 |
27 | ## Source
28 |
29 | New features that haven't been released yet are added every day, which also means there may be some bugs. To try them out, install from the GitHub repository:
30 |
31 | ```bash
32 | pip install git+https://github.com/huggingface/peft
33 | ```
34 |
35 | If you're working on contributing to the library or wish to play with the source code and see live
36 | results as you run the code, an editable version can be installed from a locally-cloned version of the
37 | repository:
38 |
39 | ```bash
40 | git clone https://github.com/huggingface/peft
41 | cd peft
42 | pip install -e .
43 | ```
44 |
--------------------------------------------------------------------------------
/peft_egg/docs/source/package_reference/config.mdx:
--------------------------------------------------------------------------------
1 | # Configuration
2 |
3 | The configuration classes store the configuration of a [`PeftModel`], PEFT adapter models, and the configurations of [`PrefixTuning`], [`PromptTuning`], and [`PromptEncoder`]. They contain methods for saving and loading model configurations from the Hub, specifying the PEFT method to use, the type of task to perform, and model configurations like the number of layers and the number of attention heads.
4 |
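For example, any config can be saved to a local directory (or the Hub) and loaded back; `"my-lora-config"` below is a hypothetical path used for illustration:

```python
from peft import LoraConfig, PeftConfig

config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1)
config.save_pretrained("my-lora-config")               # writes adapter_config.json
loaded = PeftConfig.from_pretrained("my-lora-config")  # reads the same configuration back
```
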
5 | ## PeftConfigMixin
6 |
7 | [[autodoc]] utils.config.PeftConfigMixin
8 | - all
9 |
10 | ## PeftConfig
11 |
12 | [[autodoc]] PeftConfig
13 | - all
14 |
15 | ## PromptLearningConfig
16 |
17 | [[autodoc]] PromptLearningConfig
18 | - all
19 |
--------------------------------------------------------------------------------
/peft_egg/docs/source/package_reference/peft_model.mdx:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | [`PeftModel`] is the base model class for specifying the base Transformer model and configuration to apply a PEFT method to. The base `PeftModel` contains methods for loading and saving models from the Hub, and supports the [`PromptEncoder`] for prompt learning.
4 |
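As a quick illustration, assuming a trained adapter has been saved under the hypothetical path `"my-adapter"`:

```python
from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel

base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model = PeftModel.from_pretrained(base, "my-adapter")  # attaches the saved adapter weights
model.save_pretrained("my-adapter")                    # stores only the adapter, not the base model
```
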
5 | ## PeftModel
6 |
7 | [[autodoc]] PeftModel
8 | - all
9 |
10 | ## PeftModelForSequenceClassification
11 |
12 | A `PeftModel` for sequence classification tasks.
13 |
14 | [[autodoc]] PeftModelForSequenceClassification
15 | - all
16 |
17 | ## PeftModelForTokenClassification
18 |
19 | A `PeftModel` for token classification tasks.
20 |
21 | [[autodoc]] PeftModelForTokenClassification
22 | - all
23 |
24 | ## PeftModelForCausalLM
25 |
26 | A `PeftModel` for causal language modeling.
27 |
28 | [[autodoc]] PeftModelForCausalLM
29 | - all
30 |
31 | ## PeftModelForSeq2SeqLM
32 |
33 | A `PeftModel` for sequence-to-sequence language modeling.
34 |
35 | [[autodoc]] PeftModelForSeq2SeqLM
36 | - all
37 |
--------------------------------------------------------------------------------
/peft_egg/docs/source/package_reference/tuners.mdx:
--------------------------------------------------------------------------------
1 | # Tuners
2 |
3 | Each tuner (or PEFT method) has a configuration and model.
4 |
5 | ## LoRA
6 |
7 | For finetuning a model with LoRA.
8 |
9 | [[autodoc]] LoraConfig
10 |
11 | [[autodoc]] LoraModel
12 |
13 | [[autodoc]] tuners.lora.LoraLayer
14 |
15 | [[autodoc]] tuners.lora.Linear
16 |
17 | ## P-tuning
18 |
19 | [[autodoc]] tuners.p_tuning.PromptEncoderConfig
20 |
21 | [[autodoc]] tuners.p_tuning.PromptEncoder
22 |
23 | ## Prefix tuning
24 |
25 | [[autodoc]] tuners.prefix_tuning.PrefixTuningConfig
26 |
27 | [[autodoc]] tuners.prefix_tuning.PrefixEncoder
28 |
29 | ## Prompt tuning
30 |
31 | [[autodoc]] tuners.prompt_tuning.PromptTuningConfig
32 |
33 | [[autodoc]] tuners.prompt_tuning.PromptEmbedding
--------------------------------------------------------------------------------
/peft_egg/docs/source/quicktour.mdx:
--------------------------------------------------------------------------------
13 | # Quicktour
14 |
15 | 🤗 PEFT contains parameter-efficient finetuning methods for training large pretrained models. The traditional paradigm is to finetune all of a model's parameters for each downstream task, but this is becoming exceedingly costly and impractical because of the enormous number of parameters in models today. Instead, it is more efficient to train a smaller number of prompt parameters or use a reparametrization method like low-rank adaptation (LoRA) to reduce the number of trainable parameters.
16 |
17 | This quicktour will show you 🤗 PEFT's main features and help you train large pretrained models that would typically be inaccessible on consumer devices. You'll see how to train the 1.2B parameter [`bigscience/mt0-large`](https://huggingface.co/bigscience/mt0-large) model with LoRA to generate a classification label and use it for inference.
18 |
19 | ## PeftConfig
20 |
21 | Each 🤗 PEFT method is defined by a [`PeftConfig`] class that stores all the important parameters for building a [`PeftModel`].
22 |
23 | Because you're going to use LoRA, you'll need to load and create a [`LoraConfig`] class. Within `LoraConfig`, specify the following parameters:
24 |
25 | - the `task_type`, or sequence-to-sequence language modeling in this case
26 | - `inference_mode`, whether you're using the model for inference or not
27 | - `r`, the dimension of the low-rank matrices
28 | - `lora_alpha`, the scaling factor for the low-rank matrices
29 | - `lora_dropout`, the dropout probability of the LoRA layers
30 |
31 | ```python
32 | from peft import LoraConfig, TaskType
33 |
34 | peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
35 | ```
36 |
37 |
38 |
39 | 💡 See the [`LoraConfig`] reference for more details about other parameters you can adjust.
40 |
41 |
42 |
43 | ## PeftModel
44 |
45 | A [`PeftModel`] is created by the [`get_peft_model`] function. It takes a base model - which you can load from the 🤗 Transformers library - and the [`PeftConfig`] containing the instructions for how to configure a model for a specific 🤗 PEFT method.
46 |
47 | Start by loading the base model you want to finetune.
48 |
49 | ```python
50 | from transformers import AutoModelForSeq2SeqLM
51 |
52 | model_name_or_path = "bigscience/mt0-large"
53 | tokenizer_name_or_path = "bigscience/mt0-large"
54 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
55 | ```
56 |
57 | Wrap your base model and `peft_config` with the `get_peft_model` function to create a [`PeftModel`]. To get a sense of the number of trainable parameters in your model, use the [`print_trainable_parameters`] method. In this case, you're only training 0.19% of the model's parameters! 🤏
58 |
59 | ```python
60 | from peft import get_peft_model
61 |
62 | model = get_peft_model(model, peft_config)
63 | model.print_trainable_parameters()
64 | "output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282"
65 | ```
66 |
67 | That is it 🎉! Now you can train the model using the 🤗 Transformers [`~transformers.Trainer`], 🤗 Accelerate, or any custom PyTorch training loop.
68 |
69 | ## Save and load a model
70 |
71 | After your model is finished training, you can save your model to a directory using the [`~transformers.PreTrainedModel.save_pretrained`] function. You can also save your model to the Hub (make sure you log in to your Hugging Face account first) with the [`~transformers.PreTrainedModel.push_to_hub`] function.
72 |
73 | ```python
74 | model.save_pretrained("output_dir")
75 |
76 | # if pushing to Hub
77 | from huggingface_hub import notebook_login
78 |
79 | notebook_login()
80 | model.push_to_hub("my_awesome_peft_model")
81 | ```
82 |
83 | This only saves the incremental 🤗 PEFT weights that were trained, meaning it is super efficient to store, transfer, and load. For example, this [`bigscience/T0_3B`](https://huggingface.co/smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM) model trained with LoRA on the [`twitter_complaints`](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints/train) subset of the RAFT [dataset](https://huggingface.co/datasets/ought/raft) only contains two files: `adapter_config.json` and `adapter_model.bin`. The latter file is just 19MB!
84 |
85 | Easily load your model for inference using the [`~transformers.PreTrainedModel.from_pretrained`] function:
86 |
87 | ```diff
88 | from transformers import AutoModelForSeq2SeqLM
89 | + from peft import PeftModel, PeftConfig
90 |
91 | + peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
92 | + config = PeftConfig.from_pretrained(peft_model_id)
93 | model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
94 | + model = PeftModel.from_pretrained(model, peft_model_id)
95 | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
96 |
97 | model = model.to(device)
98 | model.eval()
99 | inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")
100 |
101 | with torch.no_grad():
102 | outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), max_new_tokens=10)
103 | print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
104 | 'complaint'
105 | ```
106 |
107 | ## Next steps
108 |
109 | Now that you've seen how to train a model with one of the 🤗 PEFT methods, we encourage you to try out some of the other methods like prompt tuning. The steps are very similar to the ones shown in this quicktour: prepare a [`PeftConfig`] for a 🤗 PEFT method, and use `get_peft_model` to create a [`PeftModel`] from the configuration and base model. Then you can train it however you like!
110 |
111 | Feel free to also take a look at the task guides if you're interested in training a model with a 🤗 PEFT method for a specific task such as semantic segmentation, multilingual automatic speech recognition, DreamBooth, and token classification.
--------------------------------------------------------------------------------
/peft_egg/grammar.py:
--------------------------------------------------------------------------------
1 | # import torch
2 | #
3 | # lora_A = torch.randn((6,1))
4 | # tensor_list = [torch.tensor(float(i)) for i in range(1,5)]
5 | # print(f'lora_A: {lora_A}')
6 | # print(f'tensor_list: {tensor_list}')
7 | #
8 | # lora_A.requires_grad = True
9 | # for x in tensor_list:
10 | # x.requires_grad = True
11 | #
12 | # c = []
13 | # for x in tensor_list:
14 | # c.append(lora_A[1] * x)
15 | #
16 | # d = torch.stack(c,0)
17 | # print(f'stacked d: {d}')
18 | #
19 | # d.sum().backward()
20 | #
21 | # print(lora_A.grad)
22 |
23 |
24 | import torch
25 | import torch.nn as nn
26 | class Conv1D(nn.Module):
27 | """
28 | 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
29 |
30 | Basically works like a linear layer but the weights are transposed.
31 |
32 | Args:
33 | nf (`int`): The number of output features.
34 | nx (`int`): The number of input features.
35 | """
36 |
37 | def __init__(self, nf, nx):
38 | super().__init__()
39 | self.nf = nf
40 | self.weight = nn.Parameter(torch.empty(nx, nf))
41 | self.bias = nn.Parameter(torch.zeros(nf))
42 | nn.init.normal_(self.weight, std=0.02)
43 |
44 | def forward(self, x):
45 | size_out = x.size()[:-1] + (self.nf,)
46 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
47 | x = x.view(size_out)
48 | return x
49 | # Conv1D stores its weight transposed relative to nn.Linear:
50 | # Conv1D(nf=500, nx=300) holds a weight of shape (nx, nf) = (300, 500).
51 | a = Conv1D(500, 300)
52 |
53 | print(a.weight.shape)  # torch.Size([300, 500])
--------------------------------------------------------------------------------
/peft_egg/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 | target-version = ['py36']
4 |
5 | [tool.ruff]
6 | ignore = ["C901", "E501", "E741", "W605"]
7 | select = ["C", "E", "F", "I", "W"]
8 | line-length = 119
9 |
10 | [tool.ruff.isort]
11 | lines-after-imports = 2
12 | known-first-party = ["peft"]
13 |
14 | [isort]
15 | default_section = "FIRSTPARTY"
16 | known_first_party = "peft"
17 | known_third_party = [
18 | "numpy",
19 | "torch",
20 | "accelerate",
21 | "transformers",
22 | ]
23 | line_length = 119
24 | lines_after_imports = 2
25 | multi_line_output = 3
26 | include_trailing_comma = true
27 | force_grid_wrap = 0
28 | use_parentheses = true
29 | ensure_newline_before_comments = true
30 |
31 | [tool.pytest]
32 | doctest_optionflags = [
33 | "NORMALIZE_WHITESPACE",
34 | "ELLIPSIS",
35 | "NUMBER",
36 | ]
--------------------------------------------------------------------------------
/peft_egg/scripts/log_reports.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | from pathlib import Path
3 | from datetime import date
4 | from tabulate import tabulate
5 |
6 | failed = []
7 | passed = []
8 |
9 | group_info = []
10 |
11 | total_num_failed = 0
12 | empty_file = len(list(Path().glob("*.log"))) == 0
13 | for log in Path().glob("*.log"):
14 | section_num_failed = 0
15 | with open(log, "r") as f:
16 |         file_lines = f.readlines()
17 |         for line in file_lines:
18 | line = json.loads(line)
19 | if line.get("nodeid", "") != "":
20 | test = line["nodeid"]
21 | if line.get("duration", None) is not None:
22 | duration = f'{line["duration"]:.4f}'
23 | if line.get("outcome", "") == "failed":
24 | section_num_failed += 1
25 | failed.append([test, duration, log.name.split('_')[0]])
26 | total_num_failed += 1
27 | else:
28 | passed.append([test, duration, log.name.split('_')[0]])
29 |         if len(file_lines) == 0:
30 | empty_file = True
31 | group_info.append([str(log), section_num_failed, failed])
32 | os.remove(log)
33 | failed = []
34 | no_error_payload = {
35 | "type": "section",
36 | "text": {
37 | "type": "plain_text",
38 | "text": "🌞 There were no failures!" if not empty_file else "Something went wrong - please check GH action results.",
39 | "emoji": True
40 | }
41 | }
42 |
43 | message = ""
44 | payload = [
45 | {
46 | "type": "header",
47 | "text": {
48 | "type": "plain_text",
49 | "text": "🤗 Results of the {} PEFT scheduled tests.".format(os.environ.get("TEST_TYPE", "")),
50 | }
51 | },
52 | ]
53 | if total_num_failed > 0:
54 | for name, num_failed, failed_tests in group_info:
55 | if num_failed > 0:
56 | if num_failed == 1:
57 | message += f"*{name}: {num_failed} failed test*\n"
58 | else:
59 | message += f"*{name}: {num_failed} failed tests*\n"
60 | failed_table = []
61 | for test in failed_tests:
62 | failed_table.append(test[0].split("::"))
63 | failed_table = tabulate(failed_table, headers=["Test Location", "Test Case", "Test Name"], showindex="always", tablefmt="grid", maxcolwidths=[12, 12, 12])
64 | message += '\n```\n' +failed_table + '\n```'
65 | print(f'### {message}')
66 | else:
67 | payload.append(no_error_payload)
68 |
69 |
70 | if os.environ.get("TEST_TYPE", "") != "":
71 | from slack_sdk import WebClient
72 |
73 | if len(message) != 0:
74 | md_report = {
75 | "type": "section",
76 | "text": {
77 | "type": "mrkdwn",
78 | "text": message
79 | },
80 | }
81 | payload.append(md_report)
82 | action_button = {
83 | "type": "section",
84 | "text": {
85 | "type": "mrkdwn",
86 | "text": "*For more details:*"
87 | },
88 | "accessory": {
89 | "type": "button",
90 | "text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
91 | "url": f"https://github.com/huggingface/peft/actions/runs/{os.environ['GITHUB_RUN_ID']}",
92 | },
93 | }
94 | payload.append(action_button)
95 |
96 | date_report = {
97 | "type": "context",
98 | "elements": [
99 | {
100 | "type": "plain_text",
101 | "text": f"Nightly {os.environ.get('TEST_TYPE')} test results for {date.today()}",
102 | },
103 | ],
104 | }
105 | payload.append(date_report)
106 |
107 | print(payload)
108 |
109 | client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
110 | client.chat_postMessage(channel="#peft-ci-daily", text=message, blocks=payload)
111 |
--------------------------------------------------------------------------------
/peft_egg/scripts/stale.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Script to close stale issue. Taken in part from the AllenNLP repository.
16 | https://github.com/allenai/allennlp.
17 | """
18 | from datetime import datetime as dt
19 | import os
20 |
21 | from github import Github
22 |
23 |
24 | LABELS_TO_EXEMPT = [
25 | "good first issue",
26 | "good second issue",
27 | "good difficult issue",
28 | "feature request",
29 | "new model",
30 | "wip",
31 | "PRs welcome to address this",
32 | ]
33 |
34 |
35 | def main():
36 | g = Github(os.environ["GITHUB_TOKEN"])
37 | repo = g.get_repo("huggingface/peft")
38 | open_issues = repo.get_issues(state="open")
39 |
40 | for issue in open_issues:
41 | comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True)
42 | last_comment = comments[0] if len(comments) > 0 else None
43 | if (
44 | last_comment is not None and last_comment.user.login == "github-actions[bot]"
45 | and (dt.utcnow() - issue.updated_at).days > 7
46 | and (dt.utcnow() - issue.created_at).days >= 30
47 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels())
48 | ):
49 | issue.edit(state="closed")
50 | elif (
51 | (dt.utcnow() - issue.updated_at).days > 23
52 | and (dt.utcnow() - issue.created_at).days >= 30
53 | and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels())
54 | ):
55 | issue.create_comment(
56 | "This issue has been automatically marked as stale because it has not had "
57 | "recent activity. If you think this still needs to be addressed "
58 | "please comment on this thread.\n\n"
59 | )
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
--------------------------------------------------------------------------------
/peft_egg/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from setuptools import find_packages, setup
16 |
17 | extras = {}
18 | extras["quality"] = ["black ~= 22.0", "ruff>=0.0.241", "urllib3<=2.0.0"]
19 | extras["docs_specific"] = ["hf-doc-builder"]
20 | extras["dev"] = extras["quality"] + extras["docs_specific"]
21 | extras["test"] = extras["dev"] + ["pytest", "pytest-xdist", "parameterized", "datasets"]
22 |
23 | setup(
24 | name="peft",
25 | version="0.4.0.dev0",
26 | description="Parameter-Efficient Fine-Tuning (PEFT)",
27 | license_files=["LICENSE"],
28 | long_description=open("README.md", "r", encoding="utf-8").read(),
29 | long_description_content_type="text/markdown",
30 | keywords="deep learning",
31 | license="Apache",
32 | author="The HuggingFace team",
33 | author_email="sourab@huggingface.co",
34 | url="https://github.com/huggingface/peft",
35 | package_dir={"": "src"},
36 | packages=find_packages("src"),
37 | entry_points={},
38 | python_requires=">=3.7.0",
39 | install_requires=[
40 | "numpy>=1.17",
41 | "packaging>=20.0",
42 | "psutil",
43 | "pyyaml",
44 | "torch>=1.13.0",
45 | "transformers",
46 | "accelerate",
47 | "safetensors",
48 | ],
49 | extras_require=extras,
50 | classifiers=[
51 | "Development Status :: 5 - Production/Stable",
52 | "Intended Audience :: Developers",
53 | "Intended Audience :: Education",
54 | "Intended Audience :: Science/Research",
55 | "License :: OSI Approved :: Apache Software License",
56 | "Operating System :: OS Independent",
57 | "Programming Language :: Python :: 3",
58 | "Programming Language :: Python :: 3.7",
59 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
60 | ],
61 | )
62 |
63 | # Release checklist
64 | # 1. Change the version in __init__.py and setup.py.
65 | # 2. Commit these changes with the message: "Release: VERSION"
66 | # 3. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
67 | # Push the tag to git: git push --tags origin main
68 | # 4. Run the following commands in the top-level directory:
69 | # python setup.py bdist_wheel
70 | # python setup.py sdist
71 | # 5. Upload the package to the pypi test server first:
72 | # twine upload dist/* -r pypitest
73 | # twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
74 | # 6. Check that you can install it in a virtualenv by running:
75 | # pip install -i https://testpypi.python.org/pypi peft
76 | # 7. Upload the final version to actual pypi:
77 | # twine upload dist/* -r pypi
78 | # 8. Add release notes to the tag in github once everything is looking hunky-dory.
79 | # 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master
80 |
--------------------------------------------------------------------------------
/peft_egg/src/peft/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module, but to preserve other warnings. So, don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | __version__ = "0.4.0.dev0"
21 |
22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model
23 | from .peft_model import (
24 | PeftModel,
25 | PeftModelForCausalLM,
26 | PeftModelForSeq2SeqLM,
27 | PeftModelForSequenceClassification,
28 | PeftModelForTokenClassification,
29 | PeftModelForQuestionAnswering,
30 | )
31 | from .tuners import (
32 | AdaptionPromptConfig,
33 | AdaptionPromptModel,
34 | LoraConfig,
35 | LoraModel,
36 | AdaLoraConfig,
37 | AdaLoraModel,
38 | PrefixEncoder,
39 | PrefixTuningConfig,
40 | PromptEmbedding,
41 | PromptEncoder,
42 | PromptEncoderConfig,
43 | PromptEncoderReparameterizationType,
44 | PromptTuningConfig,
45 | PromptTuningInit,
46 | MeloConfig,
47 | MeloModel
48 | )
49 | from .utils import (
50 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
51 | PeftConfig,
52 | PeftType,
53 | PromptLearningConfig,
54 | TaskType,
55 | bloom_model_postprocess_past_key_value,
56 | get_peft_model_state_dict,
57 | prepare_model_for_int8_training,
58 | prepare_model_for_kbit_training,
59 | set_peft_model_state_dict,
60 | shift_tokens_right,
61 | )
62 |
--------------------------------------------------------------------------------
/peft_egg/src/peft/import_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import importlib
16 |
17 |
18 | def is_bnb_available():
19 | return importlib.util.find_spec("bitsandbytes") is not None
20 |
21 |
22 | def is_bnb_4bit_available():
23 | if not is_bnb_available():
24 | return False
25 |
26 | import bitsandbytes as bnb
27 |
28 | return hasattr(bnb.nn, "Linear4bit")
29 |
--------------------------------------------------------------------------------
/peft_egg/src/peft/mapping.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .peft_model import (
17 | PeftModel,
18 | PeftModelForCausalLM,
19 | PeftModelForQuestionAnswering,
20 | PeftModelForSeq2SeqLM,
21 | PeftModelForSequenceClassification,
22 | PeftModelForTokenClassification,
23 | )
24 | from .tuners import (
25 | AdaLoraConfig,
26 | AdaptionPromptConfig,
27 | LoraConfig,
28 | PrefixTuningConfig,
29 | PromptEncoderConfig,
30 | PromptTuningConfig,
31 | )
32 | from .utils import PromptLearningConfig
33 |
34 |
35 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
36 | "SEQ_CLS": PeftModelForSequenceClassification,
37 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
38 | "CAUSAL_LM": PeftModelForCausalLM,
39 | "TOKEN_CLS": PeftModelForTokenClassification,
40 | "QUESTION_ANS": PeftModelForQuestionAnswering,
41 | }
42 |
43 | PEFT_TYPE_TO_CONFIG_MAPPING = {
44 | "ADAPTION_PROMPT": AdaptionPromptConfig,
45 | "PROMPT_TUNING": PromptTuningConfig,
46 | "PREFIX_TUNING": PrefixTuningConfig,
47 | "P_TUNING": PromptEncoderConfig,
48 | "LORA": LoraConfig,
49 | "ADALORA": AdaLoraConfig,
50 | }
51 |
52 |
53 | def get_peft_config(config_dict):
54 | """
55 | Returns a Peft config object from a dictionary.
56 |
57 | Args:
58 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
59 | """
60 |
61 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
62 |
63 |
64 | def _prepare_prompt_learning_config(peft_config, model_config):
65 | if peft_config.num_layers is None:
66 | if "num_hidden_layers" in model_config:
67 | num_layers = model_config["num_hidden_layers"]
68 | elif "num_layers" in model_config:
69 | num_layers = model_config["num_layers"]
70 | elif "n_layer" in model_config:
71 | num_layers = model_config["n_layer"]
72 | else:
73 | raise ValueError("Please specify `num_layers` in `peft_config`")
74 | peft_config.num_layers = num_layers
75 |
76 | if peft_config.token_dim is None:
77 | if "hidden_size" in model_config:
78 | token_dim = model_config["hidden_size"]
79 | elif "n_embd" in model_config:
80 | token_dim = model_config["n_embd"]
81 | elif "d_model" in model_config:
82 | token_dim = model_config["d_model"]
83 | else:
84 | raise ValueError("Please specify `token_dim` in `peft_config`")
85 | peft_config.token_dim = token_dim
86 |
87 | if peft_config.num_attention_heads is None:
88 | if "num_attention_heads" in model_config:
89 | num_attention_heads = model_config["num_attention_heads"]
90 | elif "n_head" in model_config:
91 | num_attention_heads = model_config["n_head"]
92 | elif "num_heads" in model_config:
93 | num_attention_heads = model_config["num_heads"]
94 | elif "encoder_attention_heads" in model_config:
95 | num_attention_heads = model_config["encoder_attention_heads"]
96 | else:
97 | raise ValueError("Please specify `num_attention_heads` in `peft_config`")
98 | peft_config.num_attention_heads = num_attention_heads
99 |
100 | if getattr(peft_config, "encoder_hidden_size", None) is None:
101 | setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)
102 |
103 | return peft_config
104 |
105 |
106 | def get_peft_model(model, peft_config, adapter_name="default") -> PeftModel:
107 | """
108 | Returns a Peft model object from a model and a config.
109 |
110 | Args:
111 | model ([`transformers.PreTrainedModel`]): Model to be wrapped.
112 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
113 | """
114 | model_config = model.config.to_dict() if hasattr(model.config, "to_dict") else model.config
115 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
116 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
117 | peft_config, PromptLearningConfig
118 | ):
119 | return PeftModel(model, peft_config, adapter_name=adapter_name)
120 | if isinstance(peft_config, PromptLearningConfig):
121 | peft_config = _prepare_prompt_learning_config(peft_config, model_config)
122 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
123 |
--------------------------------------------------------------------------------
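`get_peft_config` is a thin dispatcher: the `peft_type` string selects the config class via `PEFT_TYPE_TO_CONFIG_MAPPING`, and `task_type` later selects the wrapper class inside `get_peft_model`. A sketch with illustrative LoRA hyperparameters:

```py
# Sketch of the dict -> config dispatch; field values are illustrative.
from peft import get_peft_config

config_dict = {
    "peft_type": "LORA",       # -> LoraConfig via PEFT_TYPE_TO_CONFIG_MAPPING
    "task_type": "CAUSAL_LM",  # -> PeftModelForCausalLM inside get_peft_model
    "r": 8,
    "lora_alpha": 16,
}
peft_config = get_peft_config(config_dict)  # returns a LoraConfig instance
```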
/peft_egg/src/peft/tuners/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while still preserving other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
21 | from .lora import LoraConfig, LoraModel
22 | from .adalora import AdaLoraConfig, AdaLoraModel
23 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
24 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
25 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
26 | from .melo import MeloConfig, MeloModel
--------------------------------------------------------------------------------
/peft_egg/src/peft/tuners/p_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import enum
17 | import warnings
18 | from dataclasses import dataclass, field
19 | from typing import Union
20 |
21 | import torch
22 |
23 | from ..utils import PeftType, PromptLearningConfig
24 |
25 |
26 | class PromptEncoderReparameterizationType(str, enum.Enum):
27 | MLP = "MLP"
28 | LSTM = "LSTM"
29 |
30 |
31 | @dataclass
32 | class PromptEncoderConfig(PromptLearningConfig):
33 | """
34 | This is the configuration class to store the configuration of a [`PromptEncoder`].
35 |
36 | Args:
37 | encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]):
38 | The type of reparameterization to use.
39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder.
40 | encoder_num_layers (`int`): The number of layers of the prompt encoder.
41 | encoder_dropout (`float`): The dropout probability of the prompt encoder.
42 | """
43 |
44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field(
45 | default=PromptEncoderReparameterizationType.MLP,
46 | metadata={"help": "How to reparameterize the prompt encoder"},
47 | )
48 | encoder_hidden_size: int = field(
49 | default=None,
50 | metadata={"help": "The hidden size of the prompt encoder"},
51 | )
52 | encoder_num_layers: int = field(
53 | default=2,
54 | metadata={"help": "The number of layers of the prompt encoder"},
55 | )
56 | encoder_dropout: float = field(
57 | default=0.0,
58 | metadata={"help": "The dropout of the prompt encoder"},
59 | )
60 |
61 | def __post_init__(self):
62 | self.peft_type = PeftType.P_TUNING
63 |
64 |
65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py
66 | # with some refactor
67 | class PromptEncoder(torch.nn.Module):
68 | """
69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
70 |
71 | Args:
72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder.
73 |
74 | Example:
75 |
76 | ```py
77 | >>> from peft import PromptEncoder, PromptEncoderConfig
78 |
79 | >>> config = PromptEncoderConfig(
80 | ... peft_type="P_TUNING",
81 | ... task_type="SEQ_2_SEQ_LM",
82 | ... num_virtual_tokens=20,
83 | ... token_dim=768,
84 | ... num_transformer_submodules=1,
85 | ... num_attention_heads=12,
86 | ... num_layers=12,
87 | ... encoder_reparameterization_type="MLP",
88 | ... encoder_hidden_size=768,
89 | ... )
90 |
91 | >>> prompt_encoder = PromptEncoder(config)
92 | ```
93 |
94 | **Attributes**:
95 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt encoder.
96 | - **mlp_head** (`torch.nn.Sequential`) -- The MLP head of the prompt encoder if `inference_mode=False`.
97 | - **lstm_head** (`torch.nn.LSTM`) -- The LSTM head of the prompt encoder if `inference_mode=False` and
98 | `encoder_reparameterization_type="LSTM"`.
99 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model.
100 | - **input_size** (`int`) -- The input size of the prompt encoder.
101 | - **output_size** (`int`) -- The output size of the prompt encoder.
102 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder.
103 | - **total_virtual_tokens** (`int`) -- The total number of virtual tokens of the
104 | prompt encoder.
105 | - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]) -- The encoder type of the prompt
106 | encoder.
107 |
108 |
109 | Input shape: (`batch_size`, `total_virtual_tokens`)
110 |
111 | Output shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
112 | """
113 |
114 | def __init__(self, config):
115 | super().__init__()
116 | self.token_dim = config.token_dim
117 | self.input_size = self.token_dim
118 | self.output_size = self.token_dim
119 | self.hidden_size = config.encoder_hidden_size
120 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
121 | self.encoder_type = config.encoder_reparameterization_type
122 |
123 | # embedding
124 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim)
125 | if not config.inference_mode:
126 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
127 | lstm_dropout = config.encoder_dropout
128 | num_layers = config.encoder_num_layers
129 | # LSTM
130 | self.lstm_head = torch.nn.LSTM(
131 | input_size=self.input_size,
132 | hidden_size=self.hidden_size,
133 | num_layers=num_layers,
134 | dropout=lstm_dropout,
135 | bidirectional=True,
136 | batch_first=True,
137 | )
138 |
139 | self.mlp_head = torch.nn.Sequential(
140 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
141 | torch.nn.ReLU(),
142 | torch.nn.Linear(self.hidden_size * 2, self.output_size),
143 | )
144 |
145 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
146 | warnings.warn(
147 | f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
148 | )
149 | layers = [
150 | torch.nn.Linear(self.input_size, self.hidden_size),
151 | torch.nn.ReLU(),
152 | torch.nn.Linear(self.hidden_size, self.hidden_size),
153 | torch.nn.ReLU(),
154 | torch.nn.Linear(self.hidden_size, self.output_size),
155 | ]
156 | self.mlp_head = torch.nn.Sequential(*layers)
157 |
158 | else:
159 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
160 |
161 | def forward(self, indices):
162 | input_embeds = self.embedding(indices)
163 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM:
164 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0])
165 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
166 | output_embeds = self.mlp_head(input_embeds)
167 | else:
168 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
169 |
170 | return output_embeds
171 |
--------------------------------------------------------------------------------
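Extending the docstring example with a forward pass confirms the documented shapes; the configuration values mirror the docstring and are otherwise arbitrary:

```py
# Shape check: (batch, total_virtual_tokens) -> (batch, total_virtual_tokens, token_dim).
import torch
from peft import PromptEncoder, PromptEncoderConfig

config = PromptEncoderConfig(
    task_type="SEQ_2_SEQ_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_transformer_submodules=1,
    num_attention_heads=12,
    num_layers=12,
    encoder_reparameterization_type="MLP",
    encoder_hidden_size=768,
)
encoder = PromptEncoder(config)
indices = torch.arange(encoder.total_virtual_tokens).unsqueeze(0)  # shape (1, 20)
assert encoder(indices).shape == (1, 20, 768)
```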
/peft_egg/src/peft/tuners/prefix_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from dataclasses import dataclass, field
18 |
19 | import torch
20 |
21 | from ..utils import PeftType, PromptLearningConfig
22 |
23 |
24 | @dataclass
25 | class PrefixTuningConfig(PromptLearningConfig):
26 | """
27 | This is the configuration class to store the configuration of a [`PrefixEncoder`].
28 |
29 | Args:
30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder.
31 | prefix_projection (`bool`): Whether to project the prefix embeddings.
32 | """
33 |
34 | encoder_hidden_size: int = field(
35 | default=None,
36 | metadata={"help": "The hidden size of the encoder"},
37 | )
38 | prefix_projection: bool = field(
39 | default=False,
40 | metadata={"help": "Whether to project the prefix tokens"},
41 | )
42 |
43 | def __post_init__(self):
44 | self.peft_type = PeftType.PREFIX_TUNING
45 |
46 |
47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
48 | # with some refactor
49 | class PrefixEncoder(torch.nn.Module):
50 | r"""
51 | The `torch.nn` model to encode the prefix.
52 |
53 | Args:
54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder.
55 |
56 | Example:
57 |
58 | ```py
59 | >>> from peft import PrefixEncoder, PrefixTuningConfig
60 |
61 | >>> config = PrefixTuningConfig(
62 | ... peft_type="PREFIX_TUNING",
63 | ... task_type="SEQ_2_SEQ_LM",
64 | ... num_virtual_tokens=20,
65 | ... token_dim=768,
66 | ... num_transformer_submodules=1,
67 | ... num_attention_heads=12,
68 | ... num_layers=12,
69 | ... encoder_hidden_size=768,
70 | ... )
71 | >>> prefix_encoder = PrefixEncoder(config)
72 | ```
73 |
74 | **Attributes**:
75 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prefix encoder.
76 | - **transform** (`torch.nn.Sequential`) -- The two-layer MLP to transform the prefix embeddings if
77 | `prefix_projection` is `True`.
78 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings.
79 |
80 | Input shape: (`batch_size`, `num_virtual_tokens`)
81 |
82 | Output shape: (`batch_size`, `num_virtual_tokens`, `2*layers*hidden`)
83 | """
84 |
85 | def __init__(self, config):
86 | super().__init__()
87 | self.prefix_projection = config.prefix_projection
88 | token_dim = config.token_dim
89 | num_layers = config.num_layers
90 | encoder_hidden_size = config.encoder_hidden_size
91 | num_virtual_tokens = config.num_virtual_tokens
92 | if self.prefix_projection and not config.inference_mode:
93 | # Use a two-layer MLP to encode the prefix
94 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
95 | self.transform = torch.nn.Sequential(
96 | torch.nn.Linear(token_dim, encoder_hidden_size),
97 | torch.nn.Tanh(),
98 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
99 | )
100 | else:
101 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)
102 |
103 | def forward(self, prefix: torch.Tensor):
104 | if self.prefix_projection:
105 | prefix_tokens = self.embedding(prefix)
106 | past_key_values = self.transform(prefix_tokens)
107 | else:
108 | past_key_values = self.embedding(prefix)
109 | return past_key_values
110 |
--------------------------------------------------------------------------------
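As above, a quick forward pass confirms the documented output shape `(batch_size, num_virtual_tokens, 2*layers*hidden)`; the values mirror the docstring:

```py
# Shape check for PrefixEncoder with prefix_projection=False (the default).
import torch
from peft import PrefixEncoder, PrefixTuningConfig

config = PrefixTuningConfig(
    task_type="SEQ_2_SEQ_LM",
    num_virtual_tokens=20,
    token_dim=768,
    num_transformer_submodules=1,
    num_attention_heads=12,
    num_layers=12,
    encoder_hidden_size=768,
)
prefix_encoder = PrefixEncoder(config)
prefix = torch.arange(config.num_virtual_tokens).unsqueeze(0)  # shape (1, 20)
past_key_values = prefix_encoder(prefix)
assert past_key_values.shape == (1, 20, 2 * 12 * 768)  # 2 * num_layers * token_dim
```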
/peft_egg/src/peft/tuners/prompt_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import enum
17 | import math
18 | from dataclasses import dataclass, field
19 | from typing import Optional, Union
20 |
21 | import torch
22 |
23 | from ..utils import PeftType, PromptLearningConfig
24 |
25 |
26 | class PromptTuningInit(str, enum.Enum):
27 | TEXT = "TEXT"
28 | RANDOM = "RANDOM"
29 |
30 |
31 | @dataclass
32 | class PromptTuningConfig(PromptLearningConfig):
33 | """
34 | This is the configuration class to store the configuration of a [`PromptEmbedding`].
35 |
36 | Args:
37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding.
38 | prompt_tuning_init_text (`str`, *optional*):
39 | The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
40 | tokenizer_name_or_path (`str`, *optional*):
41 | The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`.
42 | """
43 |
44 | prompt_tuning_init: Union[PromptTuningInit, str] = field(
45 | default=PromptTuningInit.RANDOM,
46 | metadata={"help": "How to initialize the prompt tuning parameters"},
47 | )
48 | prompt_tuning_init_text: Optional[str] = field(
49 | default=None,
50 | metadata={
51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
52 | },
53 | )
54 | tokenizer_name_or_path: Optional[str] = field(
55 | default=None,
56 | metadata={
57 | "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`"
58 | },
59 | )
60 |
61 | def __post_init__(self):
62 | self.peft_type = PeftType.PROMPT_TUNING
63 |
64 |
65 | class PromptEmbedding(torch.nn.Module):
66 | """
67 | The model to encode virtual tokens into prompt embeddings.
68 |
69 | Args:
70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding.
71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model.
72 |
73 | **Attributes**:
74 | - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding.
75 |
76 | Example:
77 |
78 | ```py
79 | >>> from peft import PromptEmbedding, PromptTuningConfig
80 |
81 | >>> config = PromptTuningConfig(
82 | ... peft_type="PROMPT_TUNING",
83 | ... task_type="SEQ_2_SEQ_LM",
84 | ... num_virtual_tokens=20,
85 | ... token_dim=768,
86 | ... num_transformer_submodules=1,
87 | ... num_attention_heads=12,
88 | ... num_layers=12,
89 | ... prompt_tuning_init="TEXT",
90 | ... prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
91 | ... tokenizer_name_or_path="t5-base",
92 | ... )
93 |
94 | >>> # t5_model.shared is the word embeddings of the base model
95 | >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
96 | ```
97 |
98 | Input Shape: (`batch_size`, `total_virtual_tokens`)
99 |
100 | Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
101 | """
102 |
103 | def __init__(self, config, word_embeddings):
104 | super().__init__()
105 |
106 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
107 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
108 | if config.prompt_tuning_init == PromptTuningInit.TEXT:
109 | from transformers import AutoTokenizer
110 |
111 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
112 | init_text = config.prompt_tuning_init_text
113 | init_token_ids = tokenizer(init_text)["input_ids"]
114 | # Trim or repeat until num_text_tokens matches total_virtual_tokens
115 | num_text_tokens = len(init_token_ids)
116 | if num_text_tokens > total_virtual_tokens:
117 | init_token_ids = init_token_ids[:total_virtual_tokens]
118 | elif num_text_tokens < total_virtual_tokens:
119 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
120 | init_token_ids = init_token_ids * num_reps
121 | init_token_ids = init_token_ids[:total_virtual_tokens]
122 |
123 | word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
124 | word_embedding_weights = word_embedding_weights.to(torch.float32)
125 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
126 |
127 | def forward(self, indices):
128 | # Just get embeddings
129 | prompt_embeddings = self.embedding(indices)
130 | return prompt_embeddings
131 |
--------------------------------------------------------------------------------
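The TEXT-initialization branch tiles the tokenized init text until it covers every virtual token, then truncates. A pure-Python sketch of just that trim/repeat step (the token ids are made up):

```py
# Mirrors the trim/repeat logic in PromptEmbedding.__init__ for PromptTuningInit.TEXT.
import math

init_token_ids = [101, 2003, 102]  # hypothetical tokenizer output: 3 ids
total_virtual_tokens = 8           # 8 slots to fill

if len(init_token_ids) > total_virtual_tokens:
    init_token_ids = init_token_ids[:total_virtual_tokens]             # trim
elif len(init_token_ids) < total_virtual_tokens:
    num_reps = math.ceil(total_virtual_tokens / len(init_token_ids))   # -> 3
    init_token_ids = (init_token_ids * num_reps)[:total_virtual_tokens]

print(init_token_ids)  # [101, 2003, 102, 101, 2003, 102, 101, 2003]
```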
/peft_egg/src/peft/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while still preserving other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
21 | from .other import (
22 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
25 | COMMON_LAYERS_PATTERN,
26 | CONFIG_NAME,
27 | WEIGHTS_NAME,
28 | SAFETENSORS_WEIGHTS_NAME,
29 | _set_trainable,
30 | add_library_to_model_card,
31 | bloom_model_postprocess_past_key_value,
32 | prepare_model_for_int8_training,
33 | prepare_model_for_kbit_training,
34 | shift_tokens_right,
35 | transpose,
36 | _get_submodules,
37 | _set_adapter,
38 | _freeze_adapter,
39 | ModulesToSaveWrapper,
40 | )
41 | from .hub_utils import hub_file_exists
42 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
43 |
--------------------------------------------------------------------------------
/peft_egg/src/peft/utils/hub_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from huggingface_hub import get_hf_file_metadata, hf_hub_url
17 | from huggingface_hub.utils import EntryNotFoundError
18 |
19 |
20 | def hub_file_exists(repo_id: str, filename: str, revision: str = None, repo_type: str = None) -> bool:
21 | r"""
22 | Checks if a file exists in a remote Hub repository.
23 | """
24 | url = hf_hub_url(repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision)
25 | try:
26 | get_hf_file_metadata(url)
27 | return True
28 | except EntryNotFoundError:
29 | return False
30 |
--------------------------------------------------------------------------------
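A usage sketch; it needs network access, `gpt2` is a public Hub repo chosen purely for illustration, and the second filename is deliberately one that repo does not contain:

```py
from peft.utils.hub_utils import hub_file_exists

print(hub_file_exists("gpt2", "config.json"))        # True: the file exists on the Hub
print(hub_file_exists("gpt2", "adapter_model.bin"))  # False: EntryNotFoundError is swallowed
```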
/peft_egg/src/peft/utils/save_and_load.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .config import PeftType, PromptLearningConfig
17 |
18 |
19 | def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"):
20 | """
21 | Get the state dict of the Peft model.
22 |
23 | Args:
24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
25 | the model should be the underlying model/unwrapped model (i.e. model.module).
26 | state_dict (`dict`, *optional*, defaults to `None`):
27 | The state dict of the model. If not provided, the state dict of the model
28 | will be used.
29 | """
30 | config = model.peft_config[adapter_name]
31 | if state_dict is None:
32 | state_dict = model.state_dict()
33 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
34 | # to_return = lora_state_dict(model, bias=model.peft_config.bias)
35 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
36 | # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
37 | bias = config.bias
38 | if bias == "none":
39 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
40 | elif bias == "all":
41 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
42 | elif bias == "lora_only":
43 | to_return = {}
44 | for k in state_dict:
45 | if "lora_" in k:
46 | to_return[k] = state_dict[k]
47 | bias_name = k.split("lora_")[0] + "bias"
48 | if bias_name in state_dict:
49 | to_return[bias_name] = state_dict[bias_name]
50 | else:
51 | raise NotImplementedError
52 | to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
53 | if config.peft_type == PeftType.ADALORA:
54 | rank_pattern = config.rank_pattern
55 | if rank_pattern is not None:
56 | rank_pattern = {k.replace(f".{adapter_name}", ""): v for k, v in rank_pattern.items()}
57 | config.rank_pattern = rank_pattern
58 | to_return = model.resize_state_dict_by_rank_pattern(rank_pattern, to_return, adapter_name)
59 |
60 | elif config.peft_type == PeftType.ADAPTION_PROMPT:
61 | to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")}
62 | elif isinstance(config, PromptLearningConfig):
63 | to_return = {}
64 | if config.inference_mode:
65 | prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight
66 | else:
67 | prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
68 | to_return["prompt_embeddings"] = prompt_embeddings
69 | else:
70 | raise NotImplementedError
71 | if model.modules_to_save is not None:
72 | for key, value in state_dict.items():
73 | if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
74 | to_return[key.replace("modules_to_save.", "")] = value
75 |
76 | to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
77 | return to_return
78 |
79 |
80 | def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"):
81 | """
82 | Set the state dict of the Peft model.
83 |
84 | Args:
85 | model ([`PeftModel`]): The Peft model.
86 | peft_model_state_dict (`dict`): The state dict of the Peft model.
87 | """
88 | config = model.peft_config[adapter_name]
89 | state_dict = {}
90 | if model.modules_to_save is not None:
91 | for key, value in peft_model_state_dict.items():
92 | if any(module_name in key for module_name in model.modules_to_save):
93 | for module_name in model.modules_to_save:
94 | if module_name in key:
95 | key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
96 | break
97 | state_dict[key] = value
98 | else:
99 | state_dict = peft_model_state_dict
100 |
101 | if config.peft_type in (PeftType.LORA, PeftType.ADALORA):
102 | peft_model_state_dict = {}
103 | for k, v in state_dict.items():
104 | if "lora_" in k:
105 | suffix = k.split("lora_")[1]
106 | if "." in suffix:
107 | suffix_to_replace = ".".join(suffix.split(".")[1:])
108 | k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
109 | else:
110 | k = f"{k}.{adapter_name}"
111 | peft_model_state_dict[k] = v
112 | else:
113 | peft_model_state_dict[k] = v
114 | if config.peft_type == PeftType.ADALORA:
115 | rank_pattern = config.rank_pattern
116 | if rank_pattern is not None:
117 | model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
118 | elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT:
119 | peft_model_state_dict = state_dict
120 | else:
121 | raise NotImplementedError
122 |
123 | load_result = model.load_state_dict(peft_model_state_dict, strict=False)
124 | if isinstance(config, PromptLearningConfig):
125 | model.prompt_encoder[adapter_name].embedding.load_state_dict(
126 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True
127 | )
128 | return load_result
129 |
--------------------------------------------------------------------------------
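A self-contained round-trip sketch, assuming this fork preserves the upstream LoRA key layout; the tiny checkpoint is one of those used by the repo's own tests:

```py
# Extract adapter-only weights, then load them back under the "default" adapter.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict

base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPT2LMHeadModel")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

adapter_sd = get_peft_model_state_dict(peft_model)  # only "lora_" tensors, adapter name stripped
assert all("lora_" in k for k in adapter_sd)
load_result = set_peft_model_state_dict(peft_model, adapter_sd)  # keys re-suffixed with ".default"
print(load_result.unexpected_keys)  # expected: []
```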
/peft_egg/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/peft_egg/tests/__init__.py
--------------------------------------------------------------------------------
/peft_egg/tests/test_config.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | import tempfile
17 | import unittest
18 |
19 | from peft import (
20 | AdaptionPromptConfig,
21 | LoraConfig,
22 | PeftConfig,
23 | PrefixTuningConfig,
24 | PromptEncoderConfig,
25 | PromptTuningConfig,
26 | )
27 |
28 |
29 | PEFT_MODELS_TO_TEST = [("lewtun/tiny-random-OPTForCausalLM-delta", "v1")]
30 |
31 |
32 | class PeftConfigTestMixin:
33 | all_config_classes = (
34 | LoraConfig,
35 | PromptEncoderConfig,
36 | PrefixTuningConfig,
37 | PromptTuningConfig,
38 | AdaptionPromptConfig,
39 | )
40 |
41 |
42 | class PeftConfigTester(unittest.TestCase, PeftConfigTestMixin):
43 | def test_methods(self):
44 | r"""
45 | Test if all configs have the expected methods. Here we test
46 | - to_dict
47 | - save_pretrained
48 | - from_pretrained
49 | - from_json_file
50 | """
51 | # test if all configs have the expected methods
52 | for config_class in self.all_config_classes:
53 | config = config_class()
54 | self.assertTrue(hasattr(config, "to_dict"))
55 | self.assertTrue(hasattr(config, "save_pretrained"))
56 | self.assertTrue(hasattr(config, "from_pretrained"))
57 | self.assertTrue(hasattr(config, "from_json_file"))
58 |
59 | def test_task_type(self):
60 | for config_class in self.all_config_classes:
61 | # assert this will not fail
62 | _ = config_class(task_type="test")
63 |
64 | def test_from_pretrained(self):
65 | r"""
66 | Test if the config is correctly loaded using:
67 | - from_pretrained
68 | """
69 | for config_class in self.all_config_classes:
70 | for model_name, revision in PEFT_MODELS_TO_TEST:
71 | # Test we can load config from delta
72 | _ = config_class.from_pretrained(model_name, revision=revision)
73 |
74 | def test_save_pretrained(self):
75 | r"""
76 | Test if the config is correctly saved and loaded using
77 | - save_pretrained
78 | """
79 | for config_class in self.all_config_classes:
80 | config = config_class()
81 | with tempfile.TemporaryDirectory() as tmp_dirname:
82 | config.save_pretrained(tmp_dirname)
83 |
84 | config_from_pretrained = config_class.from_pretrained(tmp_dirname)
85 | self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
86 |
87 | def test_from_json_file(self):
88 | for config_class in self.all_config_classes:
89 | config = config_class()
90 | with tempfile.TemporaryDirectory() as tmp_dirname:
91 | config.save_pretrained(tmp_dirname)
92 |
93 | config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json"))
94 | self.assertEqual(config.to_dict(), config_from_json)
95 |
96 | def test_to_dict(self):
97 | r"""
98 | Test if the config can be correctly converted to a dict using:
99 | - to_dict
100 | - __dict__
101 | """
102 | for config_class in self.all_config_classes:
103 | config = config_class()
104 | self.assertEqual(config.to_dict(), config.__dict__)
105 | self.assertTrue(isinstance(config.to_dict(), dict))
106 |
107 | def test_from_pretrained_cache_dir(self):
108 | r"""
109 | Test if the config is correctly loaded with extra kwargs
110 | """
111 | with tempfile.TemporaryDirectory() as tmp_dirname:
112 | for config_class in self.all_config_classes:
113 | for model_name, revision in PEFT_MODELS_TO_TEST:
114 | # Test we can load config from delta
115 | _ = config_class.from_pretrained(model_name, revision=revision, cache_dir=tmp_dirname)
116 |
117 | def test_from_pretrained_cache_dir_remote(self):
118 | r"""
119 | Test if the config is correctly loaded with a checkpoint from the hub
120 | """
121 | with tempfile.TemporaryDirectory() as tmp_dirname:
122 | _ = PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname)
123 | self.assertTrue("models--ybelkada--test-st-lora" in os.listdir(tmp_dirname))
124 |
125 | def test_set_attributes(self):
126 | # manually set attributes and check if they are correctly written
127 | for config_class in self.all_config_classes:
128 | config = config_class(peft_type="test")
129 |
130 | # save pretrained
131 | with tempfile.TemporaryDirectory() as tmp_dirname:
132 | config.save_pretrained(tmp_dirname)
133 |
134 | config_from_pretrained = config_class.from_pretrained(tmp_dirname)
135 | self.assertEqual(config.to_dict(), config_from_pretrained.to_dict())
136 |
--------------------------------------------------------------------------------
/peft_egg/tests/test_decoder_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import unittest
16 |
17 | import torch
18 | from parameterized import parameterized
19 | from transformers import AutoModelForCausalLM
20 |
21 | from .testing_common import PeftCommonTester, PeftTestConfigManager
22 |
23 |
24 | PEFT_DECODER_MODELS_TO_TEST = [
25 | "hf-internal-testing/tiny-random-OPTForCausalLM",
26 | "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
27 | "hf-internal-testing/tiny-random-GPT2LMHeadModel",
28 | "hf-internal-testing/tiny-random-BloomForCausalLM",
29 | "hf-internal-testing/tiny-random-gpt_neo",
30 | "hf-internal-testing/tiny-random-GPTJForCausalLM",
31 | "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM",
32 | "HuggingFaceM4/tiny-random-LlamaForCausalLM",
33 | ]
34 |
35 | FULL_GRID = {
36 | "model_ids": PEFT_DECODER_MODELS_TO_TEST,
37 | "task_type": "CAUSAL_LM",
38 | }
39 |
40 |
41 | def skip_non_pt_mqa(test_list):
42 | r"""
43 | Skip prefix-tuning tests for MQA models (not supported yet)
44 | """
45 | return [test for test in test_list if not ("prefix_tuning" in test[0] and "GPTBigCodeForCausalLM" in test[0])]
46 |
47 |
48 | class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
49 | r"""
50 | Test if the PeftModel behaves as expected. This includes:
51 | - test if the model has the expected methods
52 |
53 | We use parametrized.expand for debugging purposes to test each model individually.
54 | """
55 | transformers_class = AutoModelForCausalLM
56 |
57 | def prepare_inputs_for_testing(self):
58 | input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device)
59 | attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
60 |
61 | input_dict = {
62 | "input_ids": input_ids,
63 | "attention_mask": attention_mask,
64 | }
65 |
66 | return input_dict
67 |
68 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
69 | def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
70 | self._test_model_attr(model_id, config_cls, config_kwargs)
71 |
72 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
73 | def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs):
74 | self._test_adapter_name(model_id, config_cls, config_kwargs)
75 |
76 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
77 | def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
78 | self._test_prepare_for_training(model_id, config_cls, config_kwargs)
79 |
80 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
81 | def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
82 | self._test_save_pretrained(model_id, config_cls, config_kwargs)
83 |
84 | @parameterized.expand(
85 | PeftTestConfigManager.get_grid_parameters(
86 | {
87 | "model_ids": PEFT_DECODER_MODELS_TO_TEST,
88 | "lora_kwargs": {"init_lora_weights": [False]},
89 | "task_type": "CAUSAL_LM",
90 | },
91 | )
92 | )
93 | def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
94 | self._test_merge_layers(model_id, config_cls, config_kwargs)
95 |
96 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_non_pt_mqa))
97 | def test_generate(self, test_name, model_id, config_cls, config_kwargs):
98 | self._test_generate(model_id, config_cls, config_kwargs)
99 |
100 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_non_pt_mqa))
101 | def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
102 | self._test_generate_half_prec(model_id, config_cls, config_kwargs)
103 |
104 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
105 | def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs):
106 | self._test_training(model_id, config_cls, config_kwargs)
107 |
108 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
109 | def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
110 | self._test_training_layer_indexing(model_id, config_cls, config_kwargs)
111 |
112 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
113 | def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
114 | self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
115 |
116 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
117 | def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):
118 | self._test_inference_safetensors(model_id, config_cls, config_kwargs)
119 |
120 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
121 | def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs):
122 | self._test_peft_model_device_map(model_id, config_cls, config_kwargs)
123 |
--------------------------------------------------------------------------------
/peft_egg/tests/test_encoder_decoder_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import unittest
16 |
17 | import torch
18 | from parameterized import parameterized
19 | from transformers import AutoModelForSeq2SeqLM
20 |
21 | from .testing_common import PeftCommonTester, PeftTestConfigManager
22 |
23 |
24 | PEFT_ENCODER_DECODER_MODELS_TO_TEST = [
25 | "ybelkada/tiny-random-T5ForConditionalGeneration-calibrated",
26 | "hf-internal-testing/tiny-random-BartForConditionalGeneration",
27 | ]
28 |
29 | FULL_GRID = {"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST, "task_type": "SEQ_2_SEQ_LM"}
30 |
31 |
32 | def skip_non_lora_or_pt(test_list):
33 | r"""
34 | Skip tests that are neither lora nor prefix tuning
35 | """
36 | return [test for test in test_list if ("lora" in test[0] or "prefix_tuning" in test[0])]
37 |
38 |
39 | class PeftEncoderDecoderModelTester(unittest.TestCase, PeftCommonTester):
40 | r"""
41 | Test if the PeftModel behaves as expected. This includes:
42 | - test if the model has the expected methods
43 |
44 | We use parametrized.expand for debugging purposes to test each model individually.
45 | """
46 | transformers_class = AutoModelForSeq2SeqLM
47 |
48 | def prepare_inputs_for_testing(self):
49 | input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device)
50 | decoder_input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device)
51 | attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
52 |
53 | input_dict = {
54 | "input_ids": input_ids,
55 | "decoder_input_ids": decoder_input_ids,
56 | "attention_mask": attention_mask,
57 | }
58 |
59 | return input_dict
60 |
61 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
62 | def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs):
63 | self._test_model_attr(model_id, config_cls, config_kwargs)
64 |
65 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
66 | def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs):
67 | self._test_adapter_name(model_id, config_cls, config_kwargs)
68 |
69 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
70 | def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs):
71 | self._test_prepare_for_training(model_id, config_cls, config_kwargs)
72 |
73 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
74 | def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
75 | self._test_save_pretrained(model_id, config_cls, config_kwargs)
76 |
77 | @parameterized.expand(
78 | PeftTestConfigManager.get_grid_parameters(
79 | {
80 | "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
81 | "lora_kwargs": {"init_lora_weights": [False]},
82 | "task_type": "SEQ_2_SEQ_LM",
83 | },
84 | )
85 | )
86 | def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
87 | self._test_merge_layers(model_id, config_cls, config_kwargs)
88 |
89 | # skip non lora models - generate does not work for prefix tuning, prompt tuning
90 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_non_lora_or_pt))
91 | def test_generate(self, test_name, model_id, config_cls, config_kwargs):
92 | self._test_generate(model_id, config_cls, config_kwargs)
93 |
94 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
95 | def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs):
96 | self._test_generate_half_prec(model_id, config_cls, config_kwargs)
97 |
98 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
99 | def test_training_encoder_decoders(self, test_name, model_id, config_cls, config_kwargs):
100 | self._test_training(model_id, config_cls, config_kwargs)
101 |
102 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
103 | def test_training_encoder_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs):
104 | self._test_training_layer_indexing(model_id, config_cls, config_kwargs)
105 |
106 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
107 | def test_training_encoder_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs):
108 | self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs)
109 |
110 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
111 | def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs):
112 | self._test_inference_safetensors(model_id, config_cls, config_kwargs)
113 |
114 | @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
115 | def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs):
116 | self._test_peft_model_device_map(model_id, config_cls, config_kwargs)
117 |
--------------------------------------------------------------------------------
/peft_egg/tests/testing_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import unittest
16 |
17 | import torch
18 |
19 |
20 | def require_torch_gpu(test_case):
21 | """
22 | Decorator marking a test that requires a GPU. Will be skipped when no GPU is available.
23 | """
24 | if not torch.cuda.is_available():
25 | return unittest.skip("test requires GPU")(test_case)
26 | else:
27 | return test_case
28 |
29 |
30 | def require_torch_multi_gpu(test_case):
31 | """
32 | Decorator marking a test that requires multiple GPUs. Will be skipped when less than 2 GPUs are available.
33 | """
34 | if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
35 | return unittest.skip("test requires multiple GPUs")(test_case)
36 | else:
37 | return test_case
38 |
39 |
40 | def require_bitsandbytes(test_case):
41 | """
42 | Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library is not installed.
43 | """
44 | try:
45 | import bitsandbytes # noqa: F401
46 | except ImportError:
47 | return unittest.skip("test requires bitsandbytes")(test_case)
48 | else:
49 | return test_case
50 |
--------------------------------------------------------------------------------
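A usage sketch of these decorators; the import path assumes the test is run from inside `peft_egg/` with the `tests` package importable, and the test body is illustrative:

```py
import unittest

from tests.testing_utils import require_torch_gpu


class ExampleGPUTest(unittest.TestCase):
    @require_torch_gpu  # becomes a skip when torch.cuda.is_available() is False
    def test_cuda_sum(self):
        import torch

        x = torch.ones(2, device="cuda")
        self.assertEqual(x.sum().item(), 2.0)


if __name__ == "__main__":
    unittest.main()
```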
/plot/ablation/ablation_cluster_num.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'orange','green']
7 | base_dir = '../logs/VecDB'
8 | settings = ['T5Small_5000_e50', 'T5Small_5000_e75', 'T5Small_5000_e100']
9 | CLUSTER = []
10 | CONFLICTS = []
11 | FORGET = []
12 |
13 | r_list = []
14 |
15 | t = list(range(0,5001,200))
16 |
17 | for index, setting in enumerate(settings):
18 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
19 | loaded_dict = pickle.load(f)
20 | r_list.append(loaded_dict)
21 |
22 |
23 | for index, setting in enumerate(r_list):
24 | VecDB_log = setting['all_VecDB']
25 | cluster = [0]
26 | for x in t[1:]:
27 | cluster.append(VecDB_log[x]['num_cluster'])
28 | CLUSTER.append(cluster)
29 |
30 |
31 |
32 | fig, axs = plt.subplots(1, 1, figsize=(4,4),dpi=800)
33 | # axs = fig.axes()
34 |
35 | axs.plot(t, CLUSTER[0], label = '$R$ = 50',linewidth=4,color=colors[0])
36 | axs.plot(t, CLUSTER[1], label = '$R$ = 75',linewidth=4,color=colors[1])
37 | axs.plot(t, CLUSTER[2], label = '$R$ = 100',linewidth=4,color=colors[2])
38 | plt.axhline(y = 907, color = 'green', linestyle = '--',linewidth = 4,label = 'Answer Num.')
39 | axs.set_xlim(-100,5100)
40 | axs.set_ylim(0,4200)
41 | axs.xaxis.set_major_locator(MultipleLocator(1000))
42 | axs.set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
43 | axs.yaxis.set_major_locator(MultipleLocator(1000))
44 | axs.set_yticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
45 | axs.set_xlabel('Edits',fontweight='bold',family = 'serif',fontsize=18)
46 | axs.set_ylabel('Cluster Num',fontweight='bold',family = 'serif',fontsize=18)
47 | axs.grid(True)
48 | plt.xticks(fontsize=16)
49 | plt.yticks(fontsize=16)
50 | plt.legend(handlelength=2,fontsize=13)
51 | plt.tight_layout()
52 | plt.savefig("./ablation_res/T5Small_cluster")
--------------------------------------------------------------------------------
/plot/ablation/ablation_conflicts_num.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'orange','green']
7 | base_dir = '../logs/VecDB'
8 | settings = ['T5Small_5000_e50', 'T5Small_5000_e75', 'T5Small_5000_e100']
9 | CLUSTER = []
10 | CONFLICTS = []
11 | FORGET = []
12 |
13 | r_list = []
14 |
15 | t = list(range(0,5001,200))
16 |
17 | for index, setting in enumerate(settings):
18 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
19 | loaded_dict = pickle.load(f)
20 | r_list.append(loaded_dict)
21 |
22 |
23 | for index, setting in enumerate(r_list):
24 | VecDB_log = setting['all_VecDB']
25 | conflicts = [0]
26 | for x in t[1:]:
27 | conflicts.append(VecDB_log[x]['conflict_num'])
28 | CONFLICTS.append(conflicts)
29 |
30 |
31 |
32 | fig, axs = plt.subplots(1, 1, figsize=(4,4),dpi=800)
33 | # axs = fig.axes()
34 |
35 | axs.plot(t, CONFLICTS[0], label = '$R$ = 50',linewidth=4,color=colors[0])
36 | axs.plot(t, CONFLICTS[1], label = '$R$ = 75',linewidth=4,color=colors[1])
37 | axs.plot(t, CONFLICTS[2], label = '$R$ = 100',linewidth=4,color=colors[2])
38 |
39 | axs.set_xlim(-100,5100)
40 | axs.set_ylim(0,800)
41 | axs.xaxis.set_major_locator(MultipleLocator(1000))
42 | axs.set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
43 | axs.yaxis.set_major_locator(MultipleLocator(200))
44 | axs.set_xlabel('Edits',fontweight='bold',family = 'serif',fontsize=18)
45 | axs.set_ylabel('Conflicts',fontweight='bold',family = 'serif',fontsize=18)
46 | axs.grid(True)
47 | plt.xticks(fontsize=16)
48 | plt.yticks(fontsize=16)
49 | plt.legend(handlelength=2,fontsize=13)
50 | plt.tight_layout()
51 | plt.savefig("./ablation_res/T5Small_conflicts")
--------------------------------------------------------------------------------
/plot/ablation/ablation_forget.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'orange','green']
7 | base_dir = '../logs/VecDB'
8 | settings = ['T5Small_5000_e50', 'T5Small_5000_e75', 'T5Small_5000_e100']
9 | CLUSTER = []
10 | CONFLICTS = []
11 | FORGET = []
12 |
13 | r_list = []
14 |
15 | t = list(range(0,5001,200))
16 |
17 | for index, setting in enumerate(settings):
18 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
19 | loaded_dict = pickle.load(f)
20 | r_list.append(loaded_dict)
21 |
22 |
23 | for index, setting in enumerate(r_list):
24 | VecDB_log = setting['all_VecDB']
25 | forget = [0]
26 | for x in t[1:]:
27 | forget.append(VecDB_log[x]['forget_keys'])
28 | FORGET.append(forget)
29 |
30 |
31 |
32 | fig, axs = plt.subplots(1, 1, figsize=(4,4),dpi=800)
33 | # axs = fig.axes()
34 |
35 | axs.plot(t, FORGET[0], label = '$R$ = 50',linewidth=4,color=colors[0])
36 | axs.plot(t, FORGET[1], label = '$R$ = 75',linewidth=4,color=colors[1])
37 | axs.plot(t, FORGET[2], label = '$R$ = 100',linewidth=4,color=colors[2])
38 |
39 | axs.set_xlim(-100,5100)
40 | axs.set_ylim(0,500)
41 | axs.xaxis.set_major_locator(MultipleLocator(1000))
42 | axs.set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
43 | axs.yaxis.set_major_locator(MultipleLocator(100))
44 | axs.set_xlabel('Edits',fontweight='bold',family = 'serif',fontsize=18)
45 | axs.set_ylabel('F - Edits',fontweight='bold',family = 'serif',fontsize=18)
46 | axs.grid(True)
47 | plt.xticks(fontsize=16)
48 | plt.yticks(fontsize=16)
49 | plt.legend(handlelength=3, loc = 'upper left',fontsize=13)
50 | plt.tight_layout()
51 | plt.savefig("./ablation_res/T5Small_forget")
--------------------------------------------------------------------------------
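All three ablation scripts read `logs/VecDB/<setting>/log.pkl` and index `all_VecDB[step]` at steps 200, 400, ..., 5000 for the keys `num_cluster`, `conflict_num`, and `forget_keys`. The sketch below writes a structurally matching dummy log (the statistics are synthetic, useful only for smoke-testing the plots):

```py
# Writes a log.pkl with the shape the ablation scripts expect; numbers are fake.
import pickle

log = {
    "all_VecDB": {
        step: {
            "num_cluster": step // 6,    # synthetic cluster count
            "conflict_num": step // 40,  # synthetic conflict count
            "forget_keys": step // 50,   # synthetic forgotten-edit count
        }
        for step in range(200, 5001, 200)
    }
}
with open("log.pkl", "wb") as f:
    pickle.dump(log, f)
```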
/plot/ablation/ablation_res/T5Large_zsre_block.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Large_zsre_block.jpg
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Large_zsre_eps.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Large_zsre_eps.jpg
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_cluster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_cluster.png
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_conflicts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_conflicts.png
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_forget.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_forget.png
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_zsre_block.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_block.jpg
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_zsre_block_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_block_2.jpg
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_zsre_block_main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_block_main.jpg
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/T5Small_zsre_eps.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/T5Small_zsre_eps.jpg
--------------------------------------------------------------------------------
/plot/ablation/ablation_res/pca.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/ablation_res/pca.jpg
--------------------------------------------------------------------------------
/plot/ablation/pca.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ECNU-ICALK/MELO/ed097eb874aad866eefca881a97dba988247ab9c/plot/ablation/pca.jpg
--------------------------------------------------------------------------------
/plot/ablation/res_time.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | np.random.seed(19680801)
4 |
5 | pts = np.random.rand(30)*.2
6 | # Now let's make two outlier points which are far away from everything.
7 | pts[[3, 14]] += .8
8 |
9 | # If we were to simply plot pts, we'd lose most of the interesting
10 | # details due to the outliers. So let's 'break' or 'cut-out' the y-axis
11 | # into two portions - use the top (ax1) for the outliers, and the bottom
12 | # (ax2) for the details of the majority of our data
13 | fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
14 | fig.subplots_adjust(hspace=0.05) # adjust space between axes
15 |
16 | # plot the same data on both axes
17 | ax1.plot(pts)
18 | ax2.plot(pts)
19 |
20 | # zoom-in / limit the view to different portions of the data
21 | ax1.set_ylim(.78, 1.) # outliers only
22 | ax2.set_ylim(0, .22) # most of the data
23 |
24 | # hide the spines between ax1 and ax2
25 | ax1.spines.bottom.set_visible(False)
26 | ax2.spines.top.set_visible(False)
27 | ax1.xaxis.tick_top()
28 | ax1.tick_params(labeltop=False) # don't put tick labels at the top
29 | ax2.xaxis.tick_bottom()
30 |
31 | # Now, let's turn towards the cut-out slanted lines.
32 | # We create line objects in axes coordinates, in which (0,0), (0,1),
33 | # (1,0), and (1,1) are the four corners of the axes.
34 | # The slanted lines themselves are markers at those locations, such that the
35 | # lines keep their angle and position, independent of the axes size or scale
36 | # Finally, we need to disable clipping.
37 |
38 | d = .5 # proportion of vertical to horizontal extent of the slanted line
39 | kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
40 | linestyle="none", color='k', mec='k', mew=1, clip_on=False)
41 | ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
42 | ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
43 |
44 |
45 | plt.show()
--------------------------------------------------------------------------------
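res_time.py above is the stock matplotlib broken-axis demo running on random data, presumably a placeholder for an edit-time figure. A hedged sketch wiring the same two-axes trick to the cumulative edit-time series the other scripts read; the path and the 'all_edit_time' layout are assumptions carried over from those scripts:

import pickle
import matplotlib.pyplot as plt

# Assumed layout, as in the line-plot scripts:
# log['all_edit_time'] maps edit count -> cumulative seconds.
with open('../logs/paper/t5small_zsre_r2e75b4/log.pkl', 'rb') as f:
    log = pickle.load(f)
t = list(range(200, 5001, 200))
mins = [log['all_edit_time'][x] / 60 for x in t]

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
fig.subplots_adjust(hspace=0.05)
ax1.plot(t, mins)
ax2.plot(t, mins)
ax1.set_ylim(0.8 * max(mins), 1.05 * max(mins))  # upper range only
ax2.set_ylim(0, 0.25 * max(mins))                # lower range only
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)
ax2.xaxis.tick_bottom()
plt.show()
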
/plot/ablation/zsre_block_lineplot_T5Large.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'green','orange']
7 |
8 | original_his = 0.99
9 | original_up = 0.8111
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = '../logs/paper'
14 | settings = ['t5large_zsre_r2e10b0', 't5large_zsre_r2e10b4', 't5large_zsre_r2e10b10']  # assumed: b0/b4/b10, matching the $VecDB$ legend below
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800)
59 | for i in range(2):
60 | for x in axs[i]:
61 | x.set_box_aspect(0.85)
62 |
63 |
64 |
65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
68 | axs[0,0].set_xlim(-200,5200)
69 | axs[0,0].set_ylim(0.68,1.02)
70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000))
71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1))
73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
75 | axs[0,0].grid(True)
76 |
77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
80 | axs[0,1].set_xlim(-200,5200)
81 | axs[0,1].set_ylim(0.2,1.02)
82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000))
83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2))
85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
87 | axs[0,1].grid(True)
88 |
89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
92 | axs[1,0].set_xlim(-200,5200)
93 | axs[1,0].set_ylim(0,15)
94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000))
95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(3))
97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif')
99 | axs[1,0].grid(True)
100 |
101 |
102 | axs[1,1].plot(t, UP[0], label = '$VecDB$ = [0]',linewidth=1.8,color=colors[0])
103 | axs[1,1].plot(t, UP[1], label = '$VecDB$ = [4]',linewidth=1.8,color=colors[1])
104 | axs[1,1].plot(t, UP[2], label = '$VecDB$ = [10]',linewidth=1.8,color=colors[2])
105 | axs[1,1].set_xlim(-200,5200)
106 | axs[1,1].set_ylim(0.5,0.9)
107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000))
108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1))
110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif')
112 | axs[1,1].grid(True)
113 |
114 |
115 | lines_labels = [axs[1,1].get_legend_handles_labels()]
116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
117 | fig.legend(lines, labels, loc='lower center', ncol=3)
118 | plt.tight_layout()
119 | plt.subplots_adjust(wspace=0)
120 |
121 | plt.subplots_adjust(bottom=0.2)
122 |
123 |
124 | plt.savefig('ablation_res/T5Large_zsre_block.jpg')
125 |
126 | plt.show()
127 |
128 | pass
129 |
130 |
131 |
--------------------------------------------------------------------------------
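The 2x2 and 1x4 scripts above and below repeat the same seven lines of axis styling per panel. A refactoring sketch under the same hard-coded choices; `style_edits_axis` is an illustrative name, not something defined in the repo. The duplicated leading '0' in the tick labels covers the extra major tick that MultipleLocator places below the left x-limit:

from matplotlib.ticker import MultipleLocator

def style_edits_axis(ax, ylabel, ylim, y_major):
    ax.set_xlim(-200, 5200)
    ax.set_ylim(*ylim)
    ax.xaxis.set_major_locator(MultipleLocator(1000))
    # First '0' labels the off-screen tick at x = -1000.
    ax.set_xticklabels(['0', '0', '1k', '2k', '3k', '4k', '5k'])
    ax.yaxis.set_major_locator(MultipleLocator(y_major))
    ax.set_xlabel('Edits', fontweight='bold', family='serif')
    ax.set_ylabel(ylabel, fontweight='bold', family='serif')
    ax.grid(True)

# e.g. style_edits_axis(axs[0, 0], 'Edit Success', (0.68, 1.02), 0.1)
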
/plot/ablation/zsre_block_lineplot_T5Small.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'green','orange']
7 |
8 | original_his = 0.99
9 | original_up = 0.7225
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = '../logs/paper'
14 | settings = ['t5small_zsre_r2e75b0', 't5small_zsre_r2e75b2', 't5small_zsre_r2e75b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800)
59 | for i in range(2):
60 | for x in axs[i]:
61 | x.set_box_aspect(0.85)
62 |
63 |
64 |
65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
68 | axs[0,0].set_xlim(-200,5200)
69 | axs[0,0].set_ylim(0.68,1.02)
70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000))
71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1))
73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
75 | axs[0,0].grid(True)
76 |
77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
80 | axs[0,1].set_xlim(-200,5200)
81 | axs[0,1].set_ylim(0.2,1.02)
82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000))
83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2))
85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
87 | axs[0,1].grid(True)
88 |
89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
92 | axs[1,0].set_xlim(-200,5200)
93 | axs[1,0].set_ylim(0,3.2)
94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000))
95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(1))
97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif')
99 | axs[1,0].grid(True)
100 |
101 |
102 | axs[1,1].plot(t, UP[0], label = '$VecDB = [0]$',linewidth=1.8,color=colors[0])
103 | axs[1,1].plot(t, UP[1], label = '$VecDB = [2]$',linewidth=1.8,color=colors[1])
104 | axs[1,1].plot(t, UP[2], label = '$VecDB = [4]$',linewidth=1.8,color=colors[2])
105 | axs[1,1].set_xlim(-200,5200)
106 | axs[1,1].set_ylim(0.3,0.9)
107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000))
108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1))
110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif')
112 | axs[1,1].grid(True)
113 |
114 |
115 | lines_labels = [axs[1,1].get_legend_handles_labels()]
116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
117 | fig.legend(lines, labels, loc='lower center', ncol=3)
118 | plt.tight_layout()
119 | plt.subplots_adjust(wspace=0)
120 |
121 | plt.subplots_adjust(bottom=0.2)
122 |
123 |
124 | plt.savefig('ablation_res/T5Small_zsre_block.jpg')
125 |
126 | plt.show()
127 |
128 | pass
129 |
130 |
131 |
--------------------------------------------------------------------------------
/plot/ablation/zsre_block_lineplot_t5small_2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'green','orange']
7 |
8 | original_his = 0.99
9 | original_up = 0.7225
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = '../logs/paper'
14 | settings = ['t5small_zsre_r2e75b0', 't5small_zsre_r2e75b2', 't5small_zsre_r2e75b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(1, 2, figsize=(5,2.5),dpi=800)
59 | for x in axs:
60 |     x.set_box_aspect(0.85)
61 |
62 |
63 |
64 |
65 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
66 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
67 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
68 | axs[0].set_xlim(-200,5200)
69 | axs[0].set_ylim(0.68,1.02)
70 | axs[0].xaxis.set_major_locator(MultipleLocator(1000))
71 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
72 | axs[0].yaxis.set_major_locator(MultipleLocator(0.1))
73 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif')
74 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
75 | axs[0].grid(True)
76 |
77 |
78 |
79 | axs[1].plot(t, UP[0], label = '$VecDB = [0]$',linewidth=1.8,color=colors[0])
80 | axs[1].plot(t, UP[1], label = '$VecDB = [2]$',linewidth=1.8,color=colors[1])
81 | axs[1].plot(t, UP[2], label = '$VecDB = [4]$',linewidth=1.8,color=colors[2])
82 | axs[1].set_xlim(-200,5200)
83 | axs[1].set_ylim(0.3,0.9)
84 | axs[1].xaxis.set_major_locator(MultipleLocator(1000))
85 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
86 | axs[1].yaxis.set_major_locator(MultipleLocator(0.1))
87 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif')
88 | axs[1].set_ylabel('Locality',fontweight='bold',family = 'serif')
89 | axs[1].grid(True)
90 |
91 |
92 | lines_labels = [axs[1].get_legend_handles_labels()]
93 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
94 | fig.legend(lines, labels, loc='lower center', ncol=3)
95 | plt.tight_layout()
96 | # plt.subplots_adjust(wspace=0)
97 |
98 | plt.subplots_adjust(bottom=0.3)
99 |
100 |
101 | plt.savefig('ablation_res/T5Small_zsre_block_2.jpg')
102 |
103 | plt.show()
104 |
105 | pass
106 |
107 |
108 |
--------------------------------------------------------------------------------
/plot/ablation/zsre_eps_lineplot_T5Large.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'green','orange']
7 |
8 | original_his = 0.99
9 | original_up = 0.8111
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = '../logs/paper'
14 | settings = ['t5large_zsre_r2e10b4', 't5large_zsre_r2e15b4', 't5large_zsre_r2e20b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800)
59 | for i in range(2):
60 | for x in axs[i]:
61 | x.set_box_aspect(0.85)
62 |
63 |
64 |
65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
68 | axs[0,0].set_xlim(-200,5200)
69 | axs[0,0].set_ylim(0.68,1.02)
70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000))
71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1))
73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
75 | axs[0,0].grid(True)
76 |
77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
80 | axs[0,1].set_xlim(-200,5200)
81 | axs[0,1].set_ylim(0.2,1.02)
82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000))
83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2))
85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
87 | axs[0,1].grid(True)
88 |
89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
92 | axs[1,0].set_xlim(-200,5200)
93 | axs[1,0].set_ylim(0,15)
94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000))
95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(3))
97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif')
99 | axs[1,0].grid(True)
100 |
101 |
102 | axs[1,1].plot(t, UP[0], label = '$R$ = 1.0',linewidth=1.8,color=colors[0])
103 | axs[1,1].plot(t, UP[1], label = '$R$ = 6.0',linewidth=1.8,color=colors[1])
104 | axs[1,1].plot(t, UP[2], label = '$R$ = 20.0',linewidth=1.8,color=colors[2])
105 | axs[1,1].set_xlim(-200,5200)
106 | axs[1,1].set_ylim(0.5,0.9)
107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000))
108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1))
110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif')
112 | axs[1,1].grid(True)
113 |
114 |
115 | lines_labels = [axs[1,1].get_legend_handles_labels()]
116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
117 | fig.legend(lines, labels, loc='lower center', ncol=3)
118 | plt.tight_layout()
119 | plt.subplots_adjust(wspace=0)
120 |
121 | plt.subplots_adjust(bottom=0.2)
122 |
123 |
124 | plt.savefig('ablation_res/T5Large_zsre_eps.jpg')
125 |
126 | plt.show()
127 |
128 | pass
129 |
130 |
131 |
--------------------------------------------------------------------------------
/plot/ablation/zsre_eps_lineplot_T5Small.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'green','orange']
7 |
8 | original_his = 0.99
9 | original_up = 0.7225
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = '../logs/paper'
14 | settings = ['t5small_zsre_r2e50b4', 't5small_zsre_r2e75b4', 't5small_zsre_r2e100b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(2, 2, figsize=(5,4),dpi=800)
59 | for i in range(2):
60 | for x in axs[i]:
61 | x.set_box_aspect(0.85)
62 |
63 |
64 |
65 | axs[0,0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
66 | axs[0,0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
67 | axs[0,0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
68 | axs[0,0].set_xlim(-200,5200)
69 | axs[0,0].set_ylim(0.68,1.02)
70 | axs[0,0].xaxis.set_major_locator(MultipleLocator(1000))
71 | axs[0,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
72 | axs[0,0].yaxis.set_major_locator(MultipleLocator(0.1))
73 | axs[0,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
74 | axs[0,0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
75 | axs[0,0].grid(True)
76 |
77 | axs[0,1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
78 | axs[0,1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
79 | axs[0,1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
80 | axs[0,1].set_xlim(-200,5200)
81 | axs[0,1].set_ylim(0.2,1.02)
82 | axs[0,1].xaxis.set_major_locator(MultipleLocator(1000))
83 | axs[0,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
84 | axs[0,1].yaxis.set_major_locator(MultipleLocator(0.2))
85 | axs[0,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
86 | axs[0,1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
87 | axs[0,1].grid(True)
88 |
89 | axs[1,0].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
90 | axs[1,0].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
91 | axs[1,0].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
92 | axs[1,0].set_xlim(-200,5200)
93 | axs[1,0].set_ylim(0,3.2)
94 | axs[1,0].xaxis.set_major_locator(MultipleLocator(1000))
95 | axs[1,0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
96 | axs[1,0].yaxis.set_major_locator(MultipleLocator(1))
97 | axs[1,0].set_xlabel('Edits',fontweight='bold',family = 'serif')
98 | axs[1,0].set_ylabel('Time (mins)',fontweight='bold',family = 'serif')
99 | axs[1,0].grid(True)
100 |
101 |
102 | axs[1,1].plot(t, UP[0], label = '$R$ = 50',linewidth=1.8,color=colors[0])
103 | axs[1,1].plot(t, UP[1], label = '$R$ = 75',linewidth=1.8,color=colors[1])
104 | axs[1,1].plot(t, UP[2], label = '$R$ = 100',linewidth=1.8,color=colors[2])
105 | axs[1,1].set_xlim(-200,5200)
106 | axs[1,1].set_ylim(0.5,0.9)
107 | axs[1,1].xaxis.set_major_locator(MultipleLocator(1000))
108 | axs[1,1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
109 | axs[1,1].yaxis.set_major_locator(MultipleLocator(0.1))
110 | axs[1,1].set_xlabel('Edits',fontweight='bold',family = 'serif')
111 | axs[1,1].set_ylabel('Locality',fontweight='bold',family = 'serif')
112 | axs[1,1].grid(True)
113 |
114 |
115 | lines_labels = [axs[1,1].get_legend_handles_labels()]
116 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
117 | fig.legend(lines, labels, loc='lower center', ncol=3)
118 | plt.tight_layout()
119 | plt.subplots_adjust(wspace=0)
120 |
121 | plt.subplots_adjust(bottom=0.2)
122 |
123 |
124 | plt.savefig('ablation_res/T5Small_zsre_eps.jpg')
125 |
126 | plt.show()
127 |
128 | pass
129 |
130 |
131 |
--------------------------------------------------------------------------------
/plot/tsne.py:
--------------------------------------------------------------------------------
1 | import plotly.express as px
2 | from sklearn.datasets import make_classification
3 |
4 | X, y = make_classification(
5 | n_features=6,
6 | n_classes=3,
7 | n_samples=1500,
8 | n_informative=2,
9 | random_state=5,
10 | n_clusters_per_class=1,
11 | )
12 |
13 |
14 | fig = px.scatter_3d(x=X[:, 0], y=X[:, 1], z=X[:, 2], color=y, opacity=0.8)
15 | fig.show()
16 |
17 |
18 | #-------------------PLOT PCA----------------------------------
19 | from sklearn.decomposition import PCA
20 | pca = PCA(n_components=2)
21 | X_pca = pca.fit_transform(X)
22 |
23 | fig = px.scatter(x=X_pca[:, 0], y=X_pca[:, 1], color=y)
24 | fig.update_layout(
25 | title="PCA visualization of Custom Classification dataset",
26 | xaxis_title="First Principal Component",
27 | yaxis_title="Second Principal Component",
28 | )
29 | fig.show()
30 |
31 |
32 | #-------------------PLOT tsne----------------------------------
33 | from sklearn.manifold import TSNE
34 |
35 | tsne = TSNE(n_components=2, random_state=42)
36 | X_tsne = tsne.fit_transform(X)
37 | print(tsne.kl_divergence_)
38 |
39 | fig = px.scatter(x=X_tsne[:, 0], y=X_tsne[:, 1], color=y)
40 | fig.update_layout(
41 | title="t-SNE visualization of Custom Classification dataset",
42 | xaxis_title="First t-SNE",
43 | yaxis_title="Second t-SNE",
44 | )
45 | fig.show()
--------------------------------------------------------------------------------
/plot/tsne_zsre_T5Small.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import pandas as pd
4 | import plotly.express as px
5 | base_dir = 'logs/VecDB'
6 | import matplotlib.pyplot as plt
7 | from mpl_toolkits.mplot3d import Axes3D
8 | from sklearn.decomposition import PCA
9 | from sklearn.manifold import TSNE
10 | import seaborn as sns
11 | import numpy as np
12 | import transformers
13 | with open(f'{base_dir}/T5Small_100_e75/log.pkl', 'rb') as f:
14 | loaded_dict = pickle.load(f)
15 | sns.set(style="darkgrid")
16 |
17 | tokenizer = transformers.AutoTokenizer.from_pretrained('google/t5-small-ssm-nq')
18 | vecdb = loaded_dict['all_VecDB']
19 |
20 | clusters = vecdb.table
21 | X = []
22 | Y = []
23 | for index, cluster in enumerate(clusters):
24 |
25 | if len(cluster['points']) > 3 and cluster['radius']:
26 | key_label = cluster['key_label']
27 | key_label = key_label.masked_fill(key_label == -100, tokenizer.pad_token_id)
28 | key_label = tokenizer.decode(key_label, skip_special_tokens=True)
29 | print(key_label,index)
30 |
31 | for point in cluster['points']:
32 | X.append(point.get_key())
33 | Y.append(index)
34 |
35 | X = torch.stack(X,dim=0).cpu().numpy()
36 | y = torch.tensor(Y).cpu().numpy()
37 |
38 | # fig = px.scatter_3d(x=X[:, 0], y=X[:, 1], z=X[:, 2], color=Y, opacity=0.8)
39 | # fig.show()
40 |
41 | feat_cols = [ 'pixel'+str(i) for i in range(X.shape[1])]
42 | df = pd.DataFrame(X,columns=feat_cols)
43 | df['y'] = y
44 | df['label'] = df['y'].apply(lambda i: str(i))
45 |
46 | X, y = None, None
47 |
48 | print('Size of the dataframe: {}'.format(df.shape))
49 |
50 | pca = PCA(n_components=30)
51 | pca_result = pca.fit_transform(df[feat_cols].values)
52 |
53 | df['pca-one'] = pca_result[:,0]
54 | df['pca-two'] = pca_result[:,4]  # the fifth principal component, not the second
55 | # df['pca-three'] = pca_result[:,2]
56 |
57 | print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))
58 | np.random.seed(0)
59 |
60 | rndperm = np.random.permutation(df.shape[0])
61 | plt.figure(figsize=(5,5), dpi = 800)
62 | ax = sns.scatterplot(
63 | x="pca-one", y="pca-two",
64 | hue="y",
65 | palette=sns.color_palette("tab10"),
66 | legend=False,
67 | data=df,
68 | marker= 'o',
69 | s=120,
70 | )
71 | ax.set(xlabel=None)
72 | ax.set(ylabel=None)
73 | plt.tight_layout()
74 | plt.tick_params(axis='both', which='major', labelsize=18)
75 | plt.savefig('./ablation/ablation_res/pca.jpg')
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
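Despite its filename, tsne_zsre_T5Small.py above only draws a PCA projection of the VecDB keys (the first and fifth principal components). A sketch of the t-SNE counterpart the name suggests, meant to run after the DataFrame in that script is built and reusing its `df` and `feat_cols`; the output path is hypothetical:

from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

tsne = TSNE(n_components=2, random_state=0)
tsne_result = tsne.fit_transform(df[feat_cols].values)
df['tsne-one'] = tsne_result[:, 0]
df['tsne-two'] = tsne_result[:, 1]

plt.figure(figsize=(5, 5), dpi=800)
ax = sns.scatterplot(x='tsne-one', y='tsne-two', hue='y',
                     palette=sns.color_palette('tab10'),
                     legend=False, data=df, marker='o', s=120)
ax.set(xlabel=None, ylabel=None)
plt.tight_layout()
plt.savefig('./ablation/ablation_res/tsne.jpg')  # hypothetical output path
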
/plot/zsre_eps_lineplot_T5Large.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'green','orange']
7 |
8 | original_his = 0.99
9 | original_up = 0.8111
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = 'logs/5000'
14 | settings = ['T5Large_r2e1b4', 'T5Large_r2e6b4', 'T5Large_r2e20b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(1, 4, figsize=(8,2),dpi=800)
59 | for i in range(4):
60 | axs[i].set_box_aspect(0.85)
61 |
62 |
63 |
64 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
65 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
66 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
67 | axs[0].set_xlim(-200,5200)
68 | axs[0].set_ylim(0.68,1.02)
69 | axs[0].xaxis.set_major_locator(MultipleLocator(1000))
70 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
71 | axs[0].yaxis.set_major_locator(MultipleLocator(0.1))
72 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif')
73 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
74 | axs[0].grid(True)
75 |
76 | axs[1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
77 | axs[1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
78 | axs[1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
79 | axs[1].set_xlim(-200,5200)
80 | axs[1].set_ylim(0.2,1.02)
81 | axs[1].xaxis.set_major_locator(MultipleLocator(1000))
82 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
83 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2))
84 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif')
85 | axs[1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
86 | axs[1].grid(True)
87 |
88 | axs[2].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
89 | axs[2].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
90 | axs[2].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
91 | axs[2].set_xlim(-200,5200)
92 | axs[2].set_ylim(0,15)
93 | axs[2].xaxis.set_major_locator(MultipleLocator(1000))
94 | axs[2].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
95 | axs[2].yaxis.set_major_locator(MultipleLocator(3))
96 | axs[2].set_xlabel('Edits',fontweight='bold',family = 'serif')
97 | axs[2].set_ylabel('Time (mins)',fontweight='bold',family = 'serif')
98 | axs[2].grid(True)
99 |
100 |
101 | axs[3].plot(t, UP[0], label = 'Radius = 1.0',linewidth=1.8,color=colors[0])
102 | axs[3].plot(t, UP[1], label = 'Radius = 6.0',linewidth=1.8,color=colors[1])
103 | axs[3].plot(t, UP[2], label = 'Radius = 20.0',linewidth=1.8,color=colors[2])
104 | axs[3].set_xlim(-200,5200)
105 | axs[3].set_ylim(0.5,0.9)
106 | axs[3].xaxis.set_major_locator(MultipleLocator(1000))
107 | axs[3].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
108 | axs[3].yaxis.set_major_locator(MultipleLocator(0.1))
109 | axs[3].set_xlabel('Edits',fontweight='bold',family = 'serif')
110 | axs[3].set_ylabel('Locality',fontweight='bold',family = 'serif')
111 | axs[3].grid(True)
112 |
113 |
114 | lines_labels = [axs[3].get_legend_handles_labels()]
115 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
116 | fig.legend(lines, labels, loc='lower center', ncol=4)
117 | plt.tight_layout()
118 | plt.subplots_adjust(wspace=0.5)
119 | plt.subplots_adjust(bottom=0.4)
120 |
121 |
122 | plt.savefig('plotres/T5Large_zsre_eps.jpg')
123 |
124 | plt.show()
125 |
126 | pass
127 |
128 |
129 |
--------------------------------------------------------------------------------
/plot/zsre_rank_line_plot_generality.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'orange','green']
7 |
8 | original_his = 0.99
9 | original_up = 0.7225
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = 'logs/5000'
14 | settings = ['T5Small_r1e3b4', 'T5Small_r2e3b4', 'T5Small_r4e3b4']
15 |
16 | r_list_small = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list_small.append(loaded_dict)
22 |
23 |
24 | settings = ['T5Large_r1e3b4', 'T5Large_r2e3b4', 'T5Large_r4e3b4']
25 | r_list_large = []
26 |
27 | for index, setting in enumerate(settings):
28 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
29 | loaded_dict = pickle.load(f)
30 | r_list_large.append(loaded_dict)
31 |
32 |
33 | t = list(range(0,5001,200))
34 |
35 |
36 | HOLDOUT_SMALL = []
37 | HOLDOUT_LARGE = []
38 |
39 |
40 |
41 | for index, setting in enumerate(r_list_small):
42 | holdout_log = setting['all_HOLDOUT']
43 | holdout_f1 = [original_holdout]
44 | for x in t[1:]:
45 | holdout_f1.append(holdout_log[x]['holdout_f1'])
46 | HOLDOUT_SMALL.append(holdout_f1)
47 |
48 |
49 | for index, setting in enumerate(r_list_large):
50 | holdout_log = setting['all_HOLDOUT']
51 | holdout_f1 = [original_holdout]
52 | for x in t[1:]:
53 | holdout_f1.append(holdout_log[x]['holdout_f1'])
54 | HOLDOUT_LARGE.append(holdout_f1)
55 |
56 |
57 | fig, axs = plt.subplots(1, 2, figsize=(4,2),dpi=800)
58 | for i in range(2):
59 | axs[i].set_box_aspect(0.75)
60 |
61 |
62 |
63 | axs[0].plot(t, HOLDOUT_SMALL[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
64 | axs[0].plot(t, HOLDOUT_SMALL[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
65 | axs[0].plot(t, HOLDOUT_SMALL[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
66 | axs[0].set_xlim(-200,5200)
67 | axs[0].set_ylim(0.2,1.02)
68 | axs[0].xaxis.set_major_locator(MultipleLocator(1000))
69 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
70 | axs[0].yaxis.set_major_locator(MultipleLocator(0.2))
71 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif')
72 | axs[0].set_ylabel('Generality',fontweight='bold',family = 'serif')
73 | axs[0].set_title('T5Small',fontweight='bold',family = 'serif',size=11)
74 | axs[0].grid(True)
75 |
76 |
77 | axs[1].plot(t, HOLDOUT_LARGE[0], label = 'PR = 1',linewidth=1.8,color=colors[0])
78 | axs[1].plot(t, HOLDOUT_LARGE[1], label = 'PR = 2',linewidth=1.8,color=colors[1])
79 | axs[1].plot(t, HOLDOUT_LARGE[2], label = 'PR = 4',linewidth=1.8,color=colors[2])
80 | axs[1].set_xlim(-200,5200)
81 | axs[1].set_ylim(0.2,1.02)
82 | axs[1].xaxis.set_major_locator(MultipleLocator(1000))
83 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
84 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2))
85 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif')
86 | axs[1].set_ylabel('Generality',fontweight='bold',family = 'serif')
87 | axs[1].set_title('T5Large',fontweight='bold',family = 'serif',size=11)
88 | axs[1].grid(True)
89 |
90 |
91 |
92 |
93 | lines_labels = [axs[1].get_legend_handles_labels()]
94 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
95 | fig.legend(lines, labels, loc='lower center', ncol=3)
96 | plt.tight_layout()
97 | plt.subplots_adjust(wspace=0.6)
98 |
99 | plt.subplots_adjust(bottom=0.35)
100 |
101 | plt.savefig('plotres/T5Small_zsre_generality_42.jpg')
102 |
103 | plt.show()
104 |
105 | pass
106 |
107 |
108 |
--------------------------------------------------------------------------------
/plot/zsre_rank_lineplot_T5Large.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'orange','green']
7 |
8 | original_his = 0.99
9 | original_up = 0.8111
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = 'logs/5000'
14 | settings = ['T5Large_r1e3b4', 'T5Large_r2e3b4', 'T5Large_r4e3b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(1, 4, figsize=(8,2),dpi=800)
59 | for i in range(4):
60 | axs[i].set_box_aspect(0.85)
61 |
62 |
63 |
64 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
65 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
66 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
67 | axs[0].set_xlim(-200,5200)
68 | axs[0].set_ylim(0.78,1.02)
69 | axs[0].xaxis.set_major_locator(MultipleLocator(1000))
70 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
71 | axs[0].yaxis.set_major_locator(MultipleLocator(0.1))
72 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif')
73 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
74 | axs[0].grid(True)
75 |
76 | axs[1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
77 | axs[1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
78 | axs[1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
79 | axs[1].set_xlim(-200,5200)
80 | axs[1].set_ylim(0.2,1.02)
81 | axs[1].xaxis.set_major_locator(MultipleLocator(1000))
82 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
83 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2))
84 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif')
85 | axs[1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
86 | axs[1].grid(True)
87 |
88 | axs[2].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
89 | axs[2].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
90 | axs[2].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
91 | axs[2].set_xlim(-200,5200)
92 | axs[2].set_ylim(0,15)
93 | axs[2].xaxis.set_major_locator(MultipleLocator(1000))
94 | axs[2].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
95 | axs[2].yaxis.set_major_locator(MultipleLocator(3))
96 | axs[2].set_xlabel('Edits',fontweight='bold',family = 'serif')
97 | axs[2].set_ylabel('Time (min)',fontweight='bold',family = 'serif')
98 | axs[2].grid(True)
99 |
100 |
101 | axs[3].plot(t, UP[0], label = 'Partial Rank = 1',linewidth=1.8,color=colors[0])
102 | axs[3].plot(t, UP[1], label = 'Partial Rank = 2',linewidth=1.8,color=colors[1])
103 | axs[3].plot(t, UP[2], label = 'Partial Rank = 4',linewidth=1.8,color=colors[2])
104 | axs[3].set_xlim(-200,5200)
105 | axs[3].set_ylim(0.5,0.9)
106 | axs[3].xaxis.set_major_locator(MultipleLocator(1000))
107 | axs[3].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
108 | axs[3].yaxis.set_major_locator(MultipleLocator(0.1))
109 | axs[3].set_xlabel('Edits',fontweight='bold',family = 'serif')
110 | axs[3].set_ylabel('Locality',fontweight='bold',family = 'serif')
111 | axs[3].grid(True)
112 |
113 |
114 | lines_labels = [axs[3].get_legend_handles_labels()]
115 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
116 | fig.legend(lines, labels, loc='lower center', ncol=4)
117 | plt.tight_layout()
118 | plt.subplots_adjust(wspace=0.5)
119 |
120 | plt.subplots_adjust(bottom=0.4)
121 |
122 | plt.savefig('plotres/T5Large_zsre_rank.jpg')
123 |
124 | plt.show()
125 |
126 | pass
127 |
128 |
129 |
--------------------------------------------------------------------------------
/plot/zsre_rank_lineplot_T5Small.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
5 |
6 | colors = ['blue', 'red', 'orange','green']
7 |
8 | original_his = 0.99
9 | original_up = 0.7225
10 | original_holdout = 0.3012
11 | original_time = 0
12 |
13 | base_dir = 'logs/5000'
14 | settings = ['T5Small_r1e3b4', 'T5Small_r2e3b4', 'T5Small_r4e3b4']
15 |
16 | r_list = []
17 |
18 | for index, setting in enumerate(settings):
19 | with open(f'{base_dir}/{setting}/log.pkl', 'rb') as f:
20 | loaded_dict = pickle.load(f)
21 | r_list.append(loaded_dict)
22 |
23 |
24 | t = list(range(0,5001,200))
25 |
26 | UP = []
27 | HIS = []
28 | HOLDOUT = []
29 | TIME = []
30 | for index, setting in enumerate(r_list):
31 | up_log = setting['all_UP']
32 | up_f1 = [original_up]
33 | for x in t[1:]:
34 | up_f1.append(up_log[x]['UP_f1'])
35 | UP.append(up_f1)
36 |
37 | for index, setting in enumerate(r_list):
38 | his_log = setting['all_HIS']
39 | his_f1 = [his_log[200]['HIS_f1']]
40 | for x in t[1:]:
41 | his_f1.append(his_log[x]['HIS_f1'])
42 | HIS.append(his_f1)
43 |
44 | for index, setting in enumerate(r_list):
45 | holdout_log = setting['all_HOLDOUT']
46 | holdout_f1 = [original_holdout]
47 | for x in t[1:]:
48 | holdout_f1.append(holdout_log[x]['holdout_f1'])
49 | HOLDOUT.append(holdout_f1)
50 |
51 | for index, setting in enumerate(r_list):
52 | time_log = setting['all_edit_time']
53 | time_list = [0]
54 | for x in t[1:]:
55 | time_list.append(time_log[x]/60)
56 | TIME.append(time_list)
57 |
58 | fig, axs = plt.subplots(1, 4, figsize=(8,2),dpi=800)
59 | for i in range(4):
60 | axs[i].set_box_aspect(0.85)
61 |
62 |
63 |
64 | axs[0].plot(t, HIS[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
65 | axs[0].plot(t, HIS[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
66 | axs[0].plot(t, HIS[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
67 | axs[0].set_xlim(-200,5200)
68 | axs[0].set_ylim(0.40,1.02)
69 | axs[0].xaxis.set_major_locator(MultipleLocator(1000))
70 | axs[0].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
71 | axs[0].yaxis.set_major_locator(MultipleLocator(0.2))
72 | axs[0].set_xlabel('Edits',fontweight='bold',family = 'serif')
73 | axs[0].set_ylabel('Edit Success',fontweight='bold',family = 'serif')
74 | axs[0].grid(True)
75 |
76 | axs[1].plot(t, HOLDOUT[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
77 | axs[1].plot(t, HOLDOUT[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
78 | axs[1].plot(t, HOLDOUT[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
79 | axs[1].set_xlim(-200,5200)
80 | axs[1].set_ylim(0.2,1.02)
81 | axs[1].xaxis.set_major_locator(MultipleLocator(1000))
82 | axs[1].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
83 | axs[1].yaxis.set_major_locator(MultipleLocator(0.2))
84 | axs[1].set_xlabel('Edits',fontweight='bold',family = 'serif')
85 | axs[1].set_ylabel('Holdout',fontweight='bold',family = 'serif')
86 | axs[1].grid(True)
87 |
88 | axs[2].plot(t, TIME[0], label = '1 rank(s)/block',linewidth=1.8,color=colors[0])
89 | axs[2].plot(t, TIME[1], label = '2 rank(s)/block',linewidth=1.8,color=colors[1])
90 | axs[2].plot(t, TIME[2], label = '4 rank(s)/block',linewidth=1.8,color=colors[2])
91 | axs[2].set_xlim(-200,5200)
92 | axs[2].set_ylim(0,3)
93 | axs[2].xaxis.set_major_locator(MultipleLocator(1000))
94 | axs[2].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
95 | axs[2].yaxis.set_major_locator(MultipleLocator(1))
96 | axs[2].set_xlabel('Edits',fontweight='bold',family = 'serif')
97 | axs[2].set_ylabel('Time (min)',fontweight='bold',family = 'serif')
98 | axs[2].grid(True)
99 |
100 |
101 | axs[3].plot(t, UP[0], label = 'Partial Rank = 1',linewidth=1.8,color=colors[0])
102 | axs[3].plot(t, UP[1], label = 'Partial Rank = 2',linewidth=1.8,color=colors[1])
103 | axs[3].plot(t, UP[2], label = 'Partial Rank = 4',linewidth=1.8,color=colors[2])
104 | axs[3].set_xlim(-200,5200)
105 | axs[3].set_ylim(0.5,0.9)
106 | axs[3].xaxis.set_major_locator(MultipleLocator(1000))
107 | axs[3].set_xticklabels(['0', '0','1k', '2k', '3k','4k','5k'])
108 | axs[3].yaxis.set_major_locator(MultipleLocator(0.1))
109 | axs[3].set_xlabel('Edits',fontweight='bold',family = 'serif')
110 | axs[3].set_ylabel('Locality',fontweight='bold',family = 'serif')
111 | axs[3].grid(True)
112 |
113 |
114 | lines_labels = [axs[3].get_legend_handles_labels()]
115 | lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
116 | fig.legend(lines, labels, loc='lower center', ncol=4)
117 | plt.tight_layout()
118 | plt.subplots_adjust(wspace=0.5)
119 |
120 | plt.subplots_adjust(bottom=0.4)
121 |
122 | plt.savefig('plotres/T5Small_zsre_rank.jpg')
123 |
124 | plt.show()
125 |
126 | pass
127 |
128 |
129 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.19.0
2 | antlr4-python3-runtime==4.9.3
3 | appdirs==1.4.4
4 | attrs==23.1.0
5 | certifi
6 | charset-normalizer==3.1.0
7 | click==8.1.3
8 | docker-pycreds==0.4.0
9 | filelock==3.12.0
10 | fsspec==2023.1.0
11 | gitdb==4.0.10
12 | GitPython==3.1.31
13 | huggingface-hub==0.15.1
14 | hydra-core==1.3.2
15 | idna==3.4
16 | importlib-metadata==6.6.0
17 | importlib-resources==5.12.0
18 | joblib==1.2.0
19 | jsonlines==3.1.0
20 | nltk==3.8.1
21 | numpy==1.21.6
22 | nvidia-cublas-cu11==11.10.3.66
23 | nvidia-cuda-nvrtc-cu11==11.7.99
24 | nvidia-cuda-runtime-cu11==11.7.99
25 | nvidia-cudnn-cu11==8.5.0.96
26 | omegaconf==2.3.0
27 | packaging==23.1
28 | pathtools==0.1.2
29 | Pillow==9.5.0
30 | protobuf==4.23.2
31 | psutil==5.9.5
32 | PyYAML==6.0
33 | regex==2023.6.3
34 | requests==2.31.0
35 | scikit-learn==1.0.2
36 | scipy==1.7.3
37 | sentence-transformers==2.2.2
38 | sentencepiece==0.1.99
39 | sentry-sdk==1.25.1
40 | setproctitle==1.3.2
41 | six==1.16.0
42 | smmap==5.0.0
43 | threadpoolctl==3.1.0
44 | tokenizers==0.13.3
45 | torch==1.13.1+cu116
46 | torchaudio==0.13.1+cu116
47 | torchvision==0.14.1+cu116
48 | tqdm==4.65.0
49 | transformers==4.29.2
50 | typing_extensions==4.6.3
51 | urllib3==2.0.2
52 | wandb==0.15.4
53 | zipp==3.15.0
54 |
--------------------------------------------------------------------------------