├── .gitignore ├── README.md ├── assets ├── full-lora-pissa.png ├── llama3.png ├── loss_landscape.gif └── models.png ├── configs ├── ds_config_zero2_no_offload.json └── ds_config_zero3.json ├── loss_landscape ├── main.py └── utils.py ├── pissa-sdxl.ipynb ├── requirements.txt ├── scripts ├── conveision_llama2_7b │ ├── run_full_finetune.sh │ ├── run_loftq.sh │ ├── run_lora.sh │ ├── run_pissa.sh │ ├── run_qlora.sh │ └── run_qpissa.sh ├── metamath_llama2_7b │ ├── run_full_finetune.sh │ ├── run_loftq.sh │ ├── run_lora.sh │ ├── run_pissa.sh │ ├── run_qlora.sh │ └── run_qpissa.sh └── python_llama2_7b │ ├── run_full_finetune.sh │ ├── run_loftq.sh │ ├── run_lora.sh │ ├── run_pissa.sh │ ├── run_qlora.sh │ └── run_qpissa.sh ├── train.py └── utils ├── code_process.py ├── gen_vllm.py ├── init_clover.py ├── init_crossover.py ├── init_pissa.py ├── init_qpissa.py ├── merge_adapter.py ├── nf4_to_bf16.py └── test_acc.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | output* 3 | meta-llama 4 | pissa-dataset -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **P**r**i**ncipal **S**ingular values and **S**ingular vectors **A**daptation 2 | 3 | [YouTube](https://youtu.be/X37WFwJ3nT4) 4 | 5 | ## Introduction 6 | We introduce a parameter-efficient fine-tuning (PEFT) method, **P**r**i**ncipal **S**ingular values and **S**ingular vectors **A**daptation (PiSSA), which optimizes the essential singular values and vectors while freezing the "noisy" parts. In comparison, LoRA freezes the original matrix and updates the "noise". This distinction enables PiSSA to converge much faster than LoRA and to achieve better final performance. On five common benchmarks, PiSSA outperforms LoRA on all of them using exactly the same setup except for a different initialization. On GSM8K, Mistral-7B fine-tuned with PiSSA achieves an accuracy of 72.86\%, outperforming LoRA's 67.7\% by 5.16 percentage points. 7 | Because it shares LoRA's architecture, PiSSA inherits many of LoRA's advantages, such as parameter efficiency and compatibility with quantization. 8 | Furthermore, PiSSA reduces the 4-bit quantization error in LLaMA 2-7B by 18.97\%, resulting in a substantial improvement in fine-tuning performance. On the GSM8K benchmark, PiSSA achieves an accuracy of 49.13\%, surpassing QLoRA at 39.8\% and LoftQ at 40.71\%. 9 | Leveraging a fast SVD technique, the initialization of PiSSA takes only a few seconds, so switching from LoRA to PiSSA incurs negligible cost (a minimal sketch of the decomposition appears after the News section below). 10 | 11 | ![PiSSA](./assets/full-lora-pissa.png) 12 | ![llama-3-8b](./assets/llama3.png) 13 | ![models](./assets/models.png) 14 | ![loss-landscape](./assets/loss_landscape.gif) 15 | ## News 16 | - [2025.01.09] Provided [documentation](https://huggingface.co/datasets/fxmeng/pissa-dataset) (with a [Chinese version](https://hf-mirror.com/datasets/fxmeng/pissa-dataset/blob/main/README_CN.md)) to help you use PiSSA for training and testing. 17 | - [2024.07.17] PiSSA now supports Conv2d and Embedding layers; [here](pissa-sdxl.ipynb) is an example of using PiSSA on SDXL. 18 | - [2024.07.16] PiSSA now supports DeepSpeed. 19 | - [2024.05.16] PiSSA has been merged into the [main branch of peft](https://github.com/huggingface/peft) as an optional initialization method for LoRA.
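## How PiSSA Works (A Minimal Sketch)

As referenced in the Introduction, PiSSA splits each pre-trained weight matrix into a trainable principal low-rank part and a frozen residual. The snippet below is an illustrative, self-contained sketch in plain PyTorch, not the actual `peft` implementation; the helper name `pissa_split` and the matrix sizes are chosen only for the demo.

```python
import torch

def pissa_split(W: torch.Tensor, r: int):
    """Split a pre-trained weight into a trainable principal part and a frozen residual."""
    # SVD of the pre-trained weight W with shape (out_features, in_features).
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    # The principal singular values/vectors initialize the adapter factors A and B.
    A = U[:, :r] * S[:r].sqrt()                 # (out_features, r)
    B = S[:r].sqrt().unsqueeze(1) * Vh[:r, :]   # (r, in_features)
    # The remaining ("noisy") components form the frozen residual weight.
    W_res = W - A @ B
    return W_res, A, B

W = torch.randn(512, 256)
W_res, A, B = pissa_split(W, r=16)
# At initialization the model is unchanged: W_res + A @ B reproduces W.
assert torch.allclose(W_res + A @ B, W, atol=1e-5)
# During fine-tuning, only A and B receive gradients; W_res stays frozen.
```

In the library itself, `init_lora_weights="pissa"` (or `"pissa_niter_4"`-style fast SVD) performs this split for every targeted layer; see the Advanced Usage section below.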
20 | 21 | ## Quick Start 22 | 23 | Clone the repository, download the dataset, and set up the environment: 24 | ``` 25 | git clone https://github.com/GraphPKU/PiSSA.git 26 | cd PiSSA/ 27 | # export HF_ENDPOINT=https://hf-mirror.com 28 | pip install -U huggingface_hub 29 | huggingface-cli download --repo-type dataset --resume-download fxmeng/pissa-dataset --local-dir pissa-dataset 30 | conda create -n pissa python=3.10 31 | conda activate pissa 32 | conda install nvidia/label/cuda-12.1.0::cuda-toolkit 33 | conda install pytorch==2.4.0 torchvision==0.19.0 pytorch-cuda=12.1 -c pytorch -c nvidia 34 | pip install -r requirements.txt 35 | pip install flash-attn --no-build-isolation 36 | ``` 37 | 38 | ## Reproduce the Results 39 | All the datasets we used are publicly available at [Dataset](https://huggingface.co/datasets/fxmeng/pissa-dataset). 40 | 41 | The PiSSA-initialized models are shared on [Models](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) for easy reuse. They retain the same inputs and outputs as the original models but are split into residual models and PiSSA adapters for more effective fine-tuning. 42 | 43 | | | PiSSA | QPiSSA | 44 | | --- | --- | --- | 45 | | LLaMA-2-7B | [r128](https://huggingface.co/collections/fxmeng/pissa-llama-2-7b-66377477f7acbb051bc5dc6c) | [r16,32,64,128](https://huggingface.co/collections/fxmeng/pissa-llama-2-7b-66377477f7acbb051bc5dc6c) | 46 | | LLaMA-3-8B | [r16,32,64,128](https://huggingface.co/collections/fxmeng/pissa-llama-3-8b-6637591fe4156d34a4191628) | [r64,128](https://huggingface.co/collections/fxmeng/pissa-llama-3-8b-6637591fe4156d34a4191628) | 47 | | LLaMA-3-8B-Instruct | [r16,32,64,128](https://huggingface.co/collections/fxmeng/pissa-llama-3-8b-instruct-663774dbd2174225c139a653) | -- | 48 | | LLaMA-3-70B | -- | [r64,128](https://huggingface.co/collections/fxmeng/pissa-llama-3-70b-66376205164dfca129a4caf1) | 49 | | LLaMA-3-70B-Instruct | -- | [r128](https://huggingface.co/collections/fxmeng/pissa-llama-3-70b-66376205164dfca129a4caf1) | 50 | | Qwen2-7B | [r128](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) | [r128](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) | 51 | | Qwen2-7B-Instruct | [r128](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) | [r128](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) | 52 | | Qwen2-72B | -- | [r64,128](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) | 53 | | Qwen2-72B-Instruct | -- | [r64,128](https://huggingface.co/collections/fxmeng/pissa-qwen2-666a55e58b6feadc1015aa75) | 54 | 55 | ### Training 56 | Running the following scripts will automatically download the model and then start training: 57 | ``` 58 | sh scripts/*/run_full_finetune.sh 59 | sh scripts/*/run_lora.sh 60 | sh scripts/*/run_pissa.sh 61 | sh scripts/*/run_loftq.sh 62 | sh scripts/*/run_qlora.sh 63 | sh scripts/*/run_qpissa.sh 64 | ``` 65 | ### Evaluation 66 | To evaluate the performance of your fine-tuned model, please follow the instructions in [fxmeng/pissa-dataset](https://huggingface.co/datasets/fxmeng/pissa-dataset). 67 | 68 | ## Advanced Usage 69 | We recommend downloading decomposed models directly from the [Hugging Face Collections](https://huggingface.co/collections/fxmeng) instead of performing SVD every time.
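For example, a decomposed checkpoint (residual model plus `pissa_init` adapter) can be fetched with `huggingface_hub`. The repository id below is only a placeholder; substitute the exact model you pick from the collections above.

```python
from huggingface_hub import snapshot_download

# Placeholder repo id; replace it with an actual entry from the PiSSA/QPiSSA collections above.
snapshot_download(
    repo_id="fxmeng/PiSSA-Llama-2-7B-r128",
    local_dir="PiSSA-Llama-2-7b-hf-r128",  # matches the directory used in the loading example below
)
```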
70 | If the existing models do not meet your needs, apply PiSSA initialization to a pre-trained model and store the decomposed model locally: 71 | ```python 72 | import torch 73 | import os 74 | from peft import LoraConfig, get_peft_model 75 | from transformers import AutoTokenizer, AutoModelForCausalLM 76 | MODEL_ID = "meta-llama/Llama-2-7b-hf" 77 | model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto") 78 | tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) 79 | tokenizer.pad_token_id = tokenizer.eos_token_id 80 | lora_config = LoraConfig( 81 | # init_lora_weights="pissa", # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model. 82 | init_lora_weights="pissa_niter_4", # Initialize PiSSA with fast SVD, which completes in just a few seconds. 83 | r=128, 84 | lora_alpha=128, 85 | lora_dropout=0, # Since the components of the PiSSA adapter are the principal singular values and vectors, set dropout to 0 to avoid randomly discarding them. 86 | target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"], 87 | task_type="CAUSAL_LM", 88 | ) 89 | peft_model = get_peft_model(model, lora_config) 90 | peft_model.print_trainable_parameters() 91 | OUTPUT_DIR = "PiSSA-Llama-2-7b-hf-r128" 92 | # Save PiSSA modules: 93 | peft_model.peft_config["default"].init_lora_weights = True # Important 94 | peft_model.save_pretrained(os.path.join(OUTPUT_DIR, "pissa_init")) 95 | # Save residual model: 96 | peft_model = peft_model.unload() 97 | peft_model.save_pretrained(OUTPUT_DIR) 98 | # Save the tokenizer: 99 | tokenizer.save_pretrained(OUTPUT_DIR) 100 | ``` 101 | 102 | Load a pre-processed model and fine-tune it on the IMDB dataset: 103 | 104 | ```python 105 | from trl import SFTTrainer 106 | from datasets import load_dataset 107 | from transformers import AutoTokenizer, AutoModelForCausalLM 108 | from peft import PeftModel 109 | MODEL_ID = "PiSSA-Llama-2-7b-hf-r128" 110 | residual_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto") 111 | peft_model = PeftModel.from_pretrained(residual_model, MODEL_ID, subfolder="pissa_init", is_trainable=True) 112 | tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) 113 | dataset = load_dataset("imdb", split="train[:1%]") # Only use 1% of the dataset 114 | trainer = SFTTrainer( 115 | model=peft_model, 116 | train_dataset=dataset, 117 | dataset_text_field="text", 118 | max_seq_length=128, 119 | tokenizer=tokenizer, 120 | ) 121 | trainer.train() 122 | peft_model.save_pretrained("pissa-llama-2-7b-ft") 123 | ``` 124 | 125 | ### Convert PiSSA to LoRA 126 | When using `peft_model.save_pretrained`, if `path_initial_model_for_weight_conversion=None`, the fine-tuned matrices $A$ and $B$ are saved and should be combined with the residual model. However, when specifying `path_initial_model_for_weight_conversion="pissa_init_dir"`, the saving function converts PiSSA to LoRA by $\Delta W = A B - A_0 B_0 = [A | A_0] [B | -B_0]^T=A^{'}B^{'}$. This conversion enables loading the LoRA adapter on top of the standard base model: 127 | 128 | ```python 129 | import torch 130 | from peft import PeftModel 131 | from transformers import AutoModelForCausalLM 132 | 133 | model = AutoModelForCausalLM.from_pretrained( 134 | "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto" 135 | ) 136 | # No SVD is performed during this step, and the base model remains unaltered.
137 | peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora") 138 | ``` 139 | Using the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed simultaneously, each adapter operates independently without interference, so adapters can be freely added or deleted. A sketch of the save step that performs this conversion appears at the end of this README. 140 | 141 | 142 | 143 | ## Citation 144 | ``` 145 | @article{meng2024pissa, 146 | title={Pissa: Principal singular values and singular vectors adaptation of large language models}, 147 | author={Meng, Fanxu and Wang, Zhaohui and Zhang, Muhan}, 148 | journal={arXiv preprint arXiv:2404.02948}, 149 | year={2024} 150 | } 151 | ``` 152 | 153 | ## Star History 154 | 155 | [![Star History Chart](https://api.star-history.com/svg?repos=GraphPKU/PiSSA&type=Date)](https://star-history.com/#GraphPKU/PiSSA&Date) 156 | 157 | ## Follow-up Work 158 | **2024, May 27**, [LoRA-XS: Low-Rank Adaptation with Extremely Small Number of Parameters](https://arxiv.org/abs/2405.17604) performs basis adaptation for the principal singular values and singular vectors. 159 | **2024, May 30**, [SVFT: Parameter-Efficient Fine-Tuning with Singular Vectors](https://arxiv.org/abs/2405.19597) freezes the singular vectors while fine-tuning the singular values in a sparse manner. 160 | **2024, Jun 3**, [OLoRA: Orthonormal Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2406.01775) leverages orthonormal matrix initialization through QR decomposition. 161 | **2024, Jun 7**, [CorDA: Context-Oriented Decomposition Adaptation of Large Language Models](https://arxiv.org/abs/2406.05223) leverages knowledge-preserved adaptation and instruction-previewed adaptation through context-oriented decomposition. 162 | **2024, Jun 7**, [MiLoRA: Harnessing Minor Singular Components for Parameter-Efficient LLM Finetuning](https://arxiv.org/abs/2406.09044) adapts the minor singular components. 163 | **2024, Jun 18**, [LaMDA: Large Model Fine-Tuning via Spectrally Decomposed Low-Dimensional Adaptation](https://arxiv.org/abs/2406.12832) performs basis adaptation for the principal singular values and singular vectors. 164 | **2024, Jul 6**, [LoRA-GA: Low-Rank Adaptation with Gradient Approximation](https://arxiv.org/abs/2407.05000v1) aligns the gradients of the low-rank matrix product with those of full fine-tuning at the first step. 165 | **2024, Jul 25**, [LoRA-Pro: Are Low-Rank Adapters Properly Optimized?](https://arxiv.org/abs/2407.18242) strategically adjusts the gradients of adapters, enabling the low-rank gradients to more accurately approximate the full fine-tuning gradients. 166 | **2024, Oct 9**, [One Initialization to Rule them All: Fine-tuning via Explained Variance Adaptation](https://arxiv.org/abs/2410.07170) initializes adapters in a data-driven manner by computing the singular value decomposition of minibatches of activation vectors. 167 | **2024, Nov 7**, [SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models](https://arxiv.org/abs/2411.05007) consolidates the outliers by shifting them from activations to weights, then employs a high-precision low-rank branch to absorb the weight outliers with SVD.
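## Sketch: Saving a PiSSA Adapter as LoRA

As a companion to the "Convert PiSSA to LoRA" section above, the sketch below shows the save step that produces the `pissa-llama-2-7b-lora` adapter loaded there. The directory names reuse those from the earlier examples and should be treated as placeholders for your own paths; this is a sketch of the workflow, not a drop-in script.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the residual model and the fine-tuned PiSSA adapter from the earlier examples.
residual_model = AutoModelForCausalLM.from_pretrained(
    "PiSSA-Llama-2-7b-hf-r128", torch_dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(residual_model, "pissa-llama-2-7b-ft")

# Passing path_initial_model_for_weight_conversion (pointing at the saved "pissa_init"
# adapter) converts the fine-tuned PiSSA adapter into an equivalent LoRA adapter that
# can later be loaded on top of the unmodified base model, as shown above.
peft_model.save_pretrained(
    "pissa-llama-2-7b-lora",
    path_initial_model_for_weight_conversion="PiSSA-Llama-2-7b-hf-r128/pissa_init",
)
```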
168 | -------------------------------------------------------------------------------- /assets/full-lora-pissa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PiSSA/87f4db6fc75a58b92c803cd9880c642bf011e69a/assets/full-lora-pissa.png -------------------------------------------------------------------------------- /assets/llama3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PiSSA/87f4db6fc75a58b92c803cd9880c642bf011e69a/assets/llama3.png -------------------------------------------------------------------------------- /assets/loss_landscape.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PiSSA/87f4db6fc75a58b92c803cd9880c642bf011e69a/assets/loss_landscape.gif -------------------------------------------------------------------------------- /assets/models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GraphPKU/PiSSA/87f4db6fc75a58b92c803cd9880c642bf011e69a/assets/models.png -------------------------------------------------------------------------------- /configs/ds_config_zero2_no_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | 6 | "zero_optimization": { 7 | "stage": 2, 8 | "allgather_partitions": true, 9 | "allgather_bucket_size": 1e8, 10 | "overlap_comm": true, 11 | "reduce_scatter": true, 12 | "reduce_bucket_size": 1e8, 13 | "contiguous_gradients": true 14 | }, 15 | 16 | "gradient_accumulation_steps": "auto", 17 | "gradient_clipping": "auto", 18 | "steps_per_print": 2000, 19 | "train_batch_size": "auto", 20 | "train_micro_batch_size_per_gpu": "auto", 21 | "wall_clock_breakdown": false 22 | } 23 | -------------------------------------------------------------------------------- /configs/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | 15 | "scheduler": { 16 | "type": "WarmupLR", 17 | "params": { 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | 24 | "zero_optimization": { 25 | "stage": 3, 26 | "offload_optimizer": { 27 | "device": "cpu", 28 | "pin_memory": true 29 | }, 30 | "offload_param": { 31 | "device": "cpu", 32 | "pin_memory": true 33 | }, 34 | "overlap_comm": true, 35 | "contiguous_gradients": true, 36 | "sub_group_size": 1e9, 37 | "reduce_bucket_size": "auto", 38 | "stage3_prefetch_bucket_size": "auto", 39 | "stage3_param_persistence_threshold": "auto", 40 | "stage3_max_live_parameters": 1e9, 41 | "stage3_max_reuse_distance": 1e9, 42 | "stage3_gather_16bit_weights_on_model_save": true 43 | }, 44 | 45 | "gradient_accumulation_steps": "auto", 46 | "gradient_clipping": "auto", 47 | "steps_per_print": 20, 48 | "train_batch_size": "auto", 49 | "train_micro_batch_size_per_gpu": "auto", 50 | "wall_clock_breakdown": false 51 | } -------------------------------------------------------------------------------- /loss_landscape/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from 
copy import deepcopy 4 | import pytorch_lightning as pl 5 | import torch 6 | from utils import Model, MNISTData, DimReduction, LossGrid, animate_contour 7 | logging.getLogger("pytorch_lightning").setLevel(logging.WARNING) 8 | parser = argparse.ArgumentParser(description="Finetune MLP on MNIST dataset with Full Finetune, LoRA, PiSSA.") 9 | parser.add_argument( 10 | "--pretrain_epochs", 11 | type=int, 12 | default=200, 13 | ) 14 | parser.add_argument( 15 | "--epochs", 16 | type=int, 17 | default=100, 18 | ) 19 | parser.add_argument( 20 | "--rank", 21 | type=int, 22 | default=8, 23 | ) 24 | parser.add_argument( 25 | "--lr", 26 | type=float, 27 | default=5e-4, 28 | ) 29 | parser.add_argument( 30 | "--input_dim", 31 | type=int, 32 | default=8, 33 | ) 34 | parser.add_argument( 35 | "--hidden_dim", 36 | type=int, 37 | default=128, 38 | ) 39 | parser.add_argument( 40 | "--odd_number", 41 | type=int, 42 | default=10000, 43 | ) 44 | parser.add_argument( 45 | "--even_number", 46 | type=int, 47 | default=1000, 48 | ) 49 | args = parser.parse_args() 50 | mnist = MNISTData(odd_number=args.odd_number, even_number=args.even_number, input_dim=args.input_dim) 51 | torch.manual_seed(0) 52 | pretrain_model = Model( 53 | input_dim=mnist.input_dim, 54 | num_classes=mnist.num_classes, 55 | learning_rate=args.lr, 56 | hidden_dim=args.hidden_dim, 57 | ) 58 | print(pretrain_model) 59 | print(f"Training for {args.epochs} epochs...") 60 | train_loader = mnist.odd_dataloader() 61 | trainer = pl.Trainer(enable_progress_bar=True, max_epochs=args.pretrain_epochs) 62 | trainer.fit(pretrain_model, train_loader) 63 | state_dict = pretrain_model.state_dict() 64 | torch.manual_seed(0) 65 | full_model = Model( 66 | input_dim=mnist.input_dim, 67 | num_classes=mnist.num_classes, 68 | learning_rate=args.lr, 69 | hidden_dim=args.hidden_dim, 70 | ) 71 | full_model.load_state_dict(deepcopy(state_dict)) 72 | print(full_model) 73 | print(f"Training for {args.epochs} epochs...") 74 | trainer = pl.Trainer(enable_progress_bar=True, max_epochs=args.epochs) 75 | trainer.fit(full_model, mnist.even_dataloader()) 76 | full_optim_path, full_loss_path = zip(*[(path["flat_w"], path["loss"])for path in full_model.optim_path]) 77 | print(f"Dimensionality reduction method specified: pca") 78 | dim_reduction = DimReduction(params_path=full_optim_path,) 79 | full_directions = dim_reduction.pca()["reduced_dirs"] 80 | full_path_2d = dim_reduction.reduce_to_custom_directions(full_directions)["path_2d"] 81 | full_loss_grid = LossGrid( 82 | optim_path=full_optim_path, 83 | model=full_model, 84 | data=mnist.even_dataset.tensors, 85 | path_2d=full_path_2d, 86 | directions=full_directions, 87 | ) 88 | torch.manual_seed(0) 89 | lora_model = Model( 90 | input_dim=mnist.input_dim, 91 | num_classes=mnist.num_classes, 92 | learning_rate=args.lr, 93 | lora_r=args.rank, 94 | hidden_dim=args.hidden_dim, 95 | ) 96 | lora_model.load_state_dict(deepcopy(state_dict)) 97 | lora_model.convert_to_lora_pissa(True) 98 | print(lora_model) 99 | for name, param in lora_model.named_parameters(): 100 | print(name,param.requires_grad) 101 | print(f"Training for {args.epochs} epochs...") 102 | trainer = pl.Trainer(enable_progress_bar=True, max_epochs=args.epochs) 103 | trainer.fit(lora_model, mnist.even_dataloader()) 104 | # Sample from full path 105 | lora_optim_path, lora_loss_path = zip(*[(path["flat_w"], path["loss"])for path in lora_model.optim_path]) 106 | print(f"Dimensionality reduction method specified: custom") 107 | dim_reduction = DimReduction( 108 | 
params_path=lora_optim_path, 109 | ) 110 | lora_path_2d = dim_reduction.reduce_to_custom_directions(full_directions)["path_2d"] 111 | torch.manual_seed(0) 112 | pissa_model = Model( 113 | input_dim=mnist.input_dim, 114 | num_classes=mnist.num_classes, 115 | learning_rate=args.lr, 116 | lora_r=args.rank, 117 | hidden_dim=args.hidden_dim, 118 | ) 119 | pissa_model.load_state_dict(deepcopy(state_dict)) 120 | pissa_model.convert_to_lora_pissa("pissa") 121 | print(pissa_model) 122 | for name, param in pissa_model.named_parameters(): 123 | print(name,param.requires_grad) 124 | print(f"Training for {args.epochs} epochs...") 125 | trainer = pl.Trainer(enable_progress_bar=True, max_epochs=args.epochs) 126 | trainer.fit(pissa_model, mnist.even_dataloader()) 127 | pissa_optim_path, pissa_loss_path = zip(*[(path["flat_w"], path["loss"])for path in pissa_model.optim_path]) 128 | print(f"Dimensionality reduction method specified: custom") 129 | dim_reduction = DimReduction( 130 | params_path=pissa_optim_path, 131 | ) 132 | pissa_path_2d = dim_reduction.reduce_to_custom_directions(full_directions)["path_2d"] 133 | animate_contour( 134 | full_param_steps=full_path_2d.tolist(), 135 | lora_param_steps=lora_path_2d.tolist(), 136 | pissa_param_steps=pissa_path_2d.tolist(), 137 | full_loss_steps=full_loss_path, 138 | lora_loss_steps=lora_loss_path, 139 | pissa_loss_steps=pissa_loss_path, 140 | loss_grid=full_loss_grid.loss_values_log_2d, 141 | coords=full_loss_grid.coords, 142 | true_optim_point=full_loss_grid.true_optim_point, 143 | filename="loss_landscape.gif", 144 | ) -------------------------------------------------------------------------------- /loss_landscape/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.utils.data import DataLoader, TensorDataset 7 | from torchvision import transforms 8 | from torchvision.datasets import MNIST 9 | from sklearn.decomposition import PCA 10 | from torch import nn 11 | from peft.tuners.lora.layer import Linear 12 | from torch.optim import Adam 13 | from matplotlib.animation import FuncAnimation 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import torch 17 | from sklearn.decomposition import PCA 18 | from tqdm import tqdm 19 | RES = 50 20 | MARGIN = 0.1 21 | torch.manual_seed(0) 22 | class Model(pl.LightningModule): 23 | def __init__( 24 | self, input_dim, num_classes=5, lora_r=0, hidden_dim=128, optimizer="adam", learning_rate=0, gpus=1, 25 | ): 26 | super().__init__() 27 | self.learning_rate = learning_rate 28 | self.optimizer = optimizer 29 | self.gpus = gpus 30 | self.optim_path = [] 31 | self.training_step_outputs = [] 32 | self.lora_r = lora_r 33 | self.layers = nn.Sequential( 34 | nn.Linear(input_dim, hidden_dim, bias=False), 35 | nn.ReLU(), 36 | nn.Linear(hidden_dim, num_classes, bias=False) 37 | ) 38 | 39 | def forward(self, x_in, apply_softmax=False): 40 | y_pred = self.layers(x_in) 41 | if apply_softmax: 42 | y_pred = F.softmax(y_pred, dim=1) 43 | return y_pred 44 | 45 | def loss_fn(self, y_pred, y): 46 | return F.cross_entropy(y_pred, y) 47 | 48 | 49 | def training_step(self, batch, batch_idx): 50 | X, y = batch 51 | y_pred = self(X) 52 | # Get model weights flattened here to append to optim_path later 53 | flat_w = self.get_flat_params() 54 | loss = self.loss_fn(y_pred, y) 55 | self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, 
logger=True) 56 | self.training_step_outputs.append({"loss": loss, "flat_w": flat_w}) 57 | return {"loss": loss, "flat_w": flat_w} 58 | def on_train_epoch_end(self): 59 | self.optim_path.append(self.training_step_outputs[-1]) 60 | 61 | def configure_optimizers(self): 62 | parameters = [param for param in self.parameters() if param.requires_grad] 63 | return Adam(parameters, self.learning_rate) 64 | def get_flat_params(self): 65 | """Get flattened and concatenated params of the model.""" 66 | if self.lora_r > 0: 67 | params = {} 68 | for name, module in self.named_modules(): 69 | if isinstance(module, Linear): 70 | base_layer = module.base_layer.weight.data 71 | lora_A = module.lora_A["default"].weight.data 72 | lora_B = module.lora_B["default"].weight.data 73 | params[name+".weight"] = base_layer + module.scaling['default'] * lora_B @ lora_A 74 | else: 75 | params = self._get_params() 76 | flat_params = torch.Tensor() 77 | if torch.cuda.is_available() and self.gpus > 0: 78 | flat_params = flat_params.cuda() 79 | for _, param in params.items(): 80 | flat_params = torch.cat((flat_params, torch.flatten(param))) 81 | return flat_params 82 | def init_from_flat_params(self, flat_params): 83 | """Set all model parameters from the flattened form.""" 84 | if not isinstance(flat_params, torch.Tensor): 85 | raise AttributeError( 86 | "Argument to init_from_flat_params() must be torch.Tensor" 87 | ) 88 | shapes = self._get_param_shapes() 89 | state_dict = self._unflatten_to_state_dict(flat_params, shapes) 90 | self.load_state_dict(state_dict, strict=True) 91 | def _get_param_shapes(self): 92 | shapes = [] 93 | for name, param in self.named_parameters(): 94 | shapes.append((name, param.shape, param.numel())) 95 | return shapes 96 | def _get_params(self): 97 | params = {} 98 | for name, param in self.named_parameters(): 99 | params[name] = param.data 100 | return params 101 | def _unflatten_to_state_dict(self, flat_w, shapes): 102 | state_dict = {} 103 | counter = 0 104 | for shape in shapes: 105 | name, tsize, tnum = shape 106 | param = flat_w[counter : counter + tnum].reshape(tsize) 107 | state_dict[name] = torch.nn.Parameter(param) 108 | counter += tnum 109 | assert counter == len(flat_w), "counter must reach the end of weight vector" 110 | return state_dict 111 | def convert_to_lora_pissa(self, init_lora_weights): 112 | def convert(model, init_lora_weights): 113 | for name, module in model.named_children(): 114 | if isinstance(module, torch.nn.Linear): 115 | setattr(model, name, Linear(module, adapter_name="default", r = self.lora_r, lora_alpha = self.lora_r, init_lora_weights = init_lora_weights)) 116 | else: 117 | convert(module, init_lora_weights) 118 | convert(self, init_lora_weights) 119 | for name, param in self.named_parameters(): 120 | if "lora_" not in name: 121 | param.requires_grad=False 122 | class MNISTData(pl.LightningDataModule): 123 | def __init__(self, odd_number=1000, even_number=1000, input_dim=8): 124 | super().__init__() 125 | transform = transforms.Compose([ 126 | transforms.ToTensor(), 127 | transforms.Normalize((0.1307,), (0.3081,)), 128 | ]) 129 | self.num_classes = 5 130 | mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform) 131 | even_mask = mnist_train.targets%2==0 132 | even_X = mnist_train.data[even_mask] 133 | odd_X = mnist_train.data[~even_mask] 134 | even_Y = mnist_train.targets[even_mask]//2 135 | odd_Y = mnist_train.targets[~even_mask]//2 136 | rand_odd = torch.randperm(len(odd_Y))[:odd_number] 137 | rand_even = 
torch.randperm(len(even_Y))[:even_number] 138 | odd_X = odd_X[rand_odd] 139 | odd_Y = odd_Y[rand_odd] 140 | even_X = even_X[rand_even] 141 | even_Y = even_Y[rand_even] 142 | self.input_dim = input_dim 143 | pca = PCA(n_components=input_dim) 144 | all_features = torch.cat([odd_X.view(odd_number, -1), even_X.view(even_number, -1)]).numpy() 145 | all_features = torch.from_numpy(pca.fit_transform(all_features)).to(torch.float32) 146 | odd_X = all_features[:len(odd_X)] 147 | even_X = all_features[len(odd_X):] 148 | 149 | self.odd_dataset = TensorDataset(odd_X, odd_Y) 150 | self.even_dataset = TensorDataset(even_X, even_Y) 151 | 152 | def odd_dataloader(self, num_workers=7): 153 | return DataLoader( 154 | self.odd_dataset, 155 | batch_size=self.odd_dataset.__len__(), 156 | num_workers=num_workers, 157 | persistent_workers=True, 158 | ) 159 | 160 | def even_dataloader(self, num_workers=7): 161 | return DataLoader( 162 | self.even_dataset, 163 | batch_size=self.even_dataset.__len__(), 164 | num_workers=num_workers, 165 | persistent_workers=True, 166 | ) 167 | class DimReduction: 168 | """The dimensionality reduction class.""" 169 | def __init__(self, params_path, seed=0): 170 | """Init a dimensionality reduction object. 171 | Args: 172 | params_path: list of full-dimensional flattened parameters from training. 173 | seed: seed for reproducible experiments. 174 | """ 175 | self.optim_path_matrix = self._transform(params_path) 176 | self.n_steps, self.n_dim = self.optim_path_matrix.shape 177 | self.seed = seed 178 | def pca(self): 179 | pca = PCA(n_components=2, random_state=self.seed) 180 | path_2d = pca.fit_transform(self.optim_path_matrix) 181 | reduced_dirs = pca.components_ 182 | assert path_2d.shape == (self.n_steps, 2) 183 | return {"reduced_dirs": reduced_dirs,} 184 | def reduce_to_custom_directions(self, custom_directions): 185 | """Project self.optim_path_matrix onto (u, v).""" 186 | path_projection = self.optim_path_matrix.dot(custom_directions.T) 187 | assert path_projection.shape == (self.n_steps, 2) 188 | return {"path_2d": path_projection,} 189 | def _transform(self, model_params): 190 | npvectors = [] 191 | for tensor in model_params: 192 | npvectors.append(np.array(tensor.cpu())) 193 | return np.vstack(npvectors) 194 | class LossGrid: 195 | """The loss grid class that holds the values of 2D slice from the loss landscape.""" 196 | def __init__( 197 | self, 198 | optim_path, 199 | model, 200 | data, 201 | path_2d, 202 | directions, 203 | res=RES, 204 | tqdm_disable=False, 205 | loss_values_2d=None, 206 | argmin=None, 207 | loss_min=None, 208 | ): 209 | self.dir0, self.dir1 = directions 210 | self.path_2d = path_2d 211 | self.optim_point = optim_path[-1] 212 | self.optim_point_2d = path_2d[-1] 213 | alpha = self._compute_stepsize(res) 214 | self.params_grid = self.build_params_grid(res, alpha) 215 | 216 | if loss_values_2d is not None and argmin is not None and loss_min is not None: 217 | self.loss_values_2d = loss_values_2d 218 | self.argmin = argmin 219 | self.loss_min = loss_min 220 | else: 221 | self.loss_values_2d, self.argmin, self.loss_min = self.compute_loss_2d( 222 | model, data, tqdm_disable=tqdm_disable 223 | ) 224 | 225 | self.loss_values_log_2d = np.log(self.loss_values_2d) 226 | self.coords = self._convert_coords(res, alpha) 227 | # True optim in loss grid 228 | self.true_optim_point = self.indices_to_coords(self.argmin, res, alpha) 229 | def build_params_grid(self, res, alpha): 230 | """ 231 | Produce the grid for the contour plot. 
232 | Start from the optimal point, span directions of the pca result with 233 | stepsize alpha, resolution res. 234 | """ 235 | grid = [] 236 | for i in range(-res, res): 237 | row = [] 238 | for j in range(-res, res): 239 | w_new = ( 240 | self.optim_point.cpu() 241 | + i * alpha * self.dir0 242 | + j * alpha * self.dir1 243 | ) 244 | row.append(w_new) 245 | grid.append(row) 246 | assert (grid[res][res] == self.optim_point.cpu()).all() 247 | return grid 248 | def compute_loss_2d(self, model, data, tqdm_disable=False): 249 | """Compute loss values for each weight vector in grid for the model and data.""" 250 | X, y = data 251 | loss_2d = [] 252 | n = len(self.params_grid) 253 | m = len(self.params_grid[0]) 254 | loss_min = float("inf") 255 | argmin = () 256 | print("Generating loss values for the contour plot...") 257 | with tqdm(total=n * m, disable=tqdm_disable) as pbar: 258 | for i in range(n): 259 | loss_row = [] 260 | for j in range(m): 261 | w_ij = torch.Tensor(self.params_grid[i][j].float()) 262 | # Load flattened weight vector into model 263 | model.init_from_flat_params(w_ij) 264 | y_pred = model(X) 265 | loss_val = model.loss_fn(y_pred, y).item() 266 | if loss_val < loss_min: 267 | loss_min = loss_val 268 | argmin = (i, j) 269 | loss_row.append(loss_val) 270 | pbar.update(1) 271 | loss_2d.append(loss_row) 272 | # This transpose below is very important for a correct contour plot because 273 | # originally in loss_2d, dir1 (y) is row-direction, dir0 (x) is column 274 | loss_2darray = np.array(loss_2d).T 275 | print("\nLoss values generated.") 276 | return loss_2darray, argmin, loss_min 277 | def _convert_coord(self, i, ref_point_coord, alpha): 278 | """ 279 | Convert from integer index to the coordinate value. 280 | Given a reference point coordinate (1D), find the value i steps away with 281 | step size alpha. 282 | """ 283 | return i * alpha + ref_point_coord 284 | def _convert_coords(self, res, alpha): 285 | """ 286 | Convert the coordinates from (i, j) indices to (x, y) values. 287 | Remember that for PCA, the coordinates have unit vectors as the top 2 PCs. 288 | Original path_2d has PCA output, i.e. the 2D projections of each W step 289 | onto the 2D space spanned by the top 2 PCs. 290 | We need these steps in (i, j) terms with unit vectors 291 | reduced_w1 = (1, 0) and reduced_w2 = (0, 1) in the 2D space. 292 | We center the plot on optim_point_2d, i.e. 293 | let center_2d = optim_point_2d 294 | ``` 295 | i = (x - optim_point_2d[0]) / alpha 296 | j = (y - optim_point_2d[1]) / alpha 297 | i.e. 298 | x = i * alpha + optim_point_2d[0] 299 | y = j * alpha + optim_point_2d[1] 300 | ``` 301 | where (x, y) is the 2D points in path_2d from PCA. Again, the unit 302 | vectors are reduced_w1 and reduced_w2. 303 | Return the grid coordinates in terms of (x, y) for the loss values 304 | """ 305 | converted_coord_xs = [] 306 | converted_coord_ys = [] 307 | for i in range(-res, res): 308 | x = self._convert_coord(i, self.optim_point_2d[0], alpha) 309 | y = self._convert_coord(i, self.optim_point_2d[1], alpha) 310 | converted_coord_xs.append(x) 311 | converted_coord_ys.append(y) 312 | return np.array(converted_coord_xs), np.array(converted_coord_ys) 313 | def indices_to_coords(self, indices, res, alpha): 314 | """Convert the (i, j) indices to (x, y) coordinates. 315 | Args: 316 | indices: (i, j) indices to convert. 317 | res: Resolution. 318 | alpha: Step size. 319 | Returns: 320 | The (x, y) coordinates in the projected 2D space. 
321 | """ 322 | grid_i, grid_j = indices 323 | i, j = grid_i - res, grid_j - res 324 | x = i * alpha + self.optim_point_2d[0] 325 | y = j * alpha + self.optim_point_2d[1] 326 | return x, y 327 | def _compute_stepsize(self, res): 328 | dist_2d = self.path_2d[-1] - self.path_2d[0] 329 | dist = (dist_2d[0] ** 2 + dist_2d[1] ** 2) ** 0.5 330 | return dist * (1 + MARGIN) / res 331 | def _animate_progress(current_frame, total_frames): 332 | print("\r" + f"Processing {current_frame+1}/{total_frames} frames...", end="") 333 | if current_frame + 1 == total_frames: 334 | print("\nConverting to gif, this may take a while...") 335 | def animate_contour( 336 | full_param_steps, 337 | lora_param_steps, 338 | pissa_param_steps, 339 | full_loss_steps, 340 | lora_loss_steps, 341 | pissa_loss_steps, 342 | loss_grid, 343 | coords, 344 | true_optim_point, 345 | giffps=15, 346 | figsize=(9, 6), 347 | filename="test.gif", 348 | ): 349 | n_frames = len(full_param_steps) 350 | print(f"\nTotal frames to process: {n_frames}, result frames per second: {giffps}") 351 | fig, ax = plt.subplots(figsize=figsize) 352 | coords_x, coords_y = coords 353 | from matplotlib.colors import LinearSegmentedColormap 354 | colors=["#FFD06F", "#FFE6B7","#AADCE0","#72BCD5", "#528FAD","#376795", "#1E466E"] 355 | custom_cmap = LinearSegmentedColormap.from_list("my_cmap", colors) 356 | ax.contourf(coords_x, coords_y, loss_grid, levels=35, alpha=0.9, cmap=custom_cmap) 357 | ax.plot(true_optim_point[0], true_optim_point[1], "bx", markersize=10, label="Target Local Minimum") 358 | plt.rcParams.update({'font.size': 14}) 359 | full_W0 = full_param_steps[0] 360 | lora_W0 = lora_param_steps[0] 361 | pissa_W0 = pissa_param_steps[0] 362 | full_w1s = [full_W0[0]] 363 | full_w2s = [full_W0[1]] 364 | lora_w1s = [lora_W0[0]] 365 | lora_w2s = [lora_W0[1]] 366 | pissa_w1s = [pissa_W0[0]] 367 | pissa_w2s = [pissa_W0[1]] 368 | (full_pathline,) = ax.plot(full_w1s, full_w2s, color="#E76254", lw=3, label="Full FT") 369 | (full_point,) = ax.plot(full_W0[0], full_W0[1], color="#E76254", marker='o') 370 | (lora_pathline,) = ax.plot(lora_w1s, lora_w2s, color="#528FAD", lw=3, label="LoRA") 371 | (lora_point,) = ax.plot(lora_W0[0], lora_W0[1], color="#528FAD", marker='o') 372 | (pissa_pathline,) = ax.plot(pissa_w1s, pissa_w2s, color="#F7AA58", lw=3, label="PiSSA") 373 | (pissa_point,) = ax.plot(pissa_W0[0], pissa_W0[1], color="#F7AA58", marker='o') 374 | 375 | def animate(i): 376 | full_W = full_param_steps[i] 377 | full_w1s.append(full_W[0]) 378 | full_w2s.append(full_W[1]) 379 | full_pathline.set_data([full_w1s, ], [full_w2s, ]) 380 | 381 | full_point.set_data([full_W[0], ], [full_W[1], ]) 382 | 383 | lora_W = lora_param_steps[i] 384 | lora_w1s.append(lora_W[0]) 385 | lora_w2s.append(lora_W[1]) 386 | lora_pathline.set_data([lora_w1s,], [lora_w2s, ]) 387 | lora_point.set_data([lora_W[0],], [lora_W[1], ]) 388 | 389 | pissa_W = pissa_param_steps[i] 390 | pissa_w1s.append(pissa_W[0]) 391 | pissa_w2s.append(pissa_W[1]) 392 | pissa_pathline.set_data([pissa_w1s, ], [ pissa_w2s, ]) 393 | pissa_point.set_data([pissa_W[0], ], [ pissa_W[1], ]) 394 | 395 | if i % 20 == 19: 396 | ax.plot(full_W[0], full_W[1], color="#E76254", marker='+', markersize=12) 397 | ax.plot(lora_W[0], lora_W[1], color="#528FAD", marker='+', markersize=12) 398 | ax.plot(pissa_W[0], pissa_W[1], color="#F7AA58", marker='+', markersize=12) 399 | 400 | full_pathline.set_label(f"Full FT Loss: {full_loss_steps[i]: .3f}") 401 | lora_pathline.set_label(f"LoRA Loss: {lora_loss_steps[i]: .3f}") 402 | 
pissa_pathline.set_label(f"PiSSA Loss: {pissa_loss_steps[i]: .3f}") 403 | plt.legend(loc="upper right") 404 | fig.savefig(filename.replace("gif","pdf")) 405 | global anim 406 | anim = FuncAnimation( 407 | fig, animate, frames=len(full_param_steps), interval=100, blit=False, repeat=False 408 | ) 409 | 410 | fig.tight_layout() 411 | print(f"Writing {filename}.") 412 | anim.save( 413 | f"./{filename}", 414 | writer="imagemagick", 415 | fps=giffps, 416 | progress_callback=_animate_progress, 417 | ) 418 | print(f"\n{filename} created successfully.") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.2.1 2 | transformers==4.45.1 3 | datasets==3.2.0 4 | peft==0.14.0 5 | deepspeed==0.15.4 6 | vllm==0.6.2 7 | tensorboardX 8 | tqdm 9 | attrdict 10 | human_eval 11 | evalplus 12 | fraction -------------------------------------------------------------------------------- /scripts/conveision_llama2_7b/run_full_finetune.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/conversation-FullFT-Llama-2-7b" 3 | DATA_PATH="pissa-dataset" 4 | export HF_ENDPOINT=https://hf-mirror.com 5 | # huggingface-cli download --token hf_*** --resume-download $BASE_MODEL --local-dir $BASE_MODEL 6 | 7 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 8 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 9 | --deepspeed configs/ds_config_zero2_no_offload.json \ 10 | --model_name_or_path $BASE_MODEL \ 11 | --full_finetune True \ 12 | --bf16 \ 13 | --data_path $DATA_PATH \ 14 | --sub_task conversation:100000 \ 15 | --dataset_split "train"\ 16 | --dataset_field instruction output \ 17 | --output_dir $OUTPUT_PATH \ 18 | --num_train_epochs 1 \ 19 | --model_max_length 512 \ 20 | --per_device_train_batch_size 2 \ 21 | --gradient_accumulation_steps 8 \ 22 | --save_strategy "steps" \ 23 | --save_steps 1000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 2e-5 \ 26 | --weight_decay 0. 
\ 27 | --warmup_ratio 0.03 \ 28 | --logging_steps 1 \ 29 | --lr_scheduler_type "cosine" \ 30 | --report_to "tensorboard" \ 31 | -------------------------------------------------------------------------------- /scripts/conveision_llama2_7b/run_loftq.sh: -------------------------------------------------------------------------------- 1 | # LoftQ only provide model with rank=64, one can DIY a rank=128 version following: 2 | # https://github.com/yxli2123/LoftQ/tree/main 3 | RESIDUAL_MODEL="LoftQ/Meta-Llama-3-8B-4bit-64rank" 4 | OUTPUT_PATH="output/LoftQ-Llama-3-8B-4bit-64rank" 5 | DATA_PATH="meta-math/MetaMathQA" 6 | 7 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 8 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 9 | --deepspeed configs/ds_config_zero2_no_offload.json \ 10 | --model_name_or_path $RESIDUAL_MODEL \ 11 | --full_finetune False \ 12 | --bf16 \ 13 | --bits 4 \ 14 | --use_lora True \ 15 | --adapter_name_or_path "loftq_init" \ 16 | --data_path $DATA_PATH \ 17 | --dataset_field query response \ 18 | --dataset_split "train[:100000]"\ 19 | --output_dir $OUTPUT_PATH \ 20 | --num_train_epochs 1 \ 21 | --model_max_length 512 \ 22 | --per_device_train_batch_size 1 \ 23 | --gradient_accumulation_steps 128 \ 24 | --save_strategy "steps" \ 25 | --save_steps 100 \ 26 | --save_total_limit 100 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --logging_steps 1 \ 31 | --lr_scheduler_type "cosine" \ 32 | --report_to "tensorboard" \ 33 | -------------------------------------------------------------------------------- /scripts/conveision_llama2_7b/run_lora.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/conversation-LoRA-Llama-2-7b-r128" 3 | DATA_PATH="pissa-dataset" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 7 | --deepspeed configs/ds_config_zero2_no_offload.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --full_finetune False \ 10 | --bf16 \ 11 | --init_weights True \ 12 | --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \ 13 | --lora_rank 128 \ 14 | --lora_alpha 128 \ 15 | --lora_dropout 0 \ 16 | --data_path $DATA_PATH \ 17 | --sub_task conversation \ 18 | --dataset_split train \ 19 | --dataset_field instruction output \ 20 | --output_dir $OUTPUT_PATH \ 21 | --num_train_epochs 1 \ 22 | --model_max_length 512 \ 23 | --per_device_train_batch_size 4 \ 24 | --gradient_accumulation_steps 4 \ 25 | --save_strategy "steps" \ 26 | --save_steps 1000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 2e-5 \ 29 | --weight_decay 0. 
\ 30 | --warmup_ratio 0.03 \ 31 | --logging_steps 1 \ 32 | --lr_scheduler_type "cosine" \ 33 | --report_to "tensorboard" \ 34 | --merge True \ 35 | 36 | -------------------------------------------------------------------------------- /scripts/conveision_llama2_7b/run_pissa.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | RES_MODEL="output/PiSSA-Llama-2-7b-r128" 3 | OUTPUT_PATH="output/conversation-PiSSA-Llama-2-7b-r128" 4 | DATA_PATH="pissa-dataset" 5 | export HF_ENDPOINT=https://hf-mirror.com 6 | 7 | #huggingface-cli download --token hf_*** --resume-download $RES_MODEL --local-dir $RES_MODEL 8 | if [ -e $RES_MODEL ]; then 9 | echo "Use pre-initialized residual model." 10 | else 11 | echo "Perform PiSSA initialization by myself." 12 | python utils/init_pissa.py --base_model_path $BASE_MODEL --output_dir $RES_MODEL --init_weights pissa_niter_16 --lora_r 128 --lora_alpha 128 --lora_dropout 0 --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj 13 | fi 14 | 15 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 16 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 17 | --deepspeed configs/ds_config_zero2_no_offload.json \ 18 | --model_name_or_path $RES_MODEL \ 19 | --full_finetune False \ 20 | --bf16 \ 21 | --adapter_name_or_path "pissa_init" \ 22 | --data_path $DATA_PATH \ 23 | --sub_task conversation \ 24 | --dataset_split train \ 25 | --dataset_field instruction output \ 26 | --output_dir $OUTPUT_PATH \ 27 | --num_train_epochs 1 \ 28 | --model_max_length 512 \ 29 | --per_device_train_batch_size 4 \ 30 | --gradient_accumulation_steps 4 \ 31 | --save_strategy "steps" \ 32 | --save_steps 1000 \ 33 | --save_total_limit 1 \ 34 | --learning_rate 2e-5 \ 35 | --weight_decay 0. \ 36 | --warmup_ratio 0.03 \ 37 | --logging_steps 1 \ 38 | --lr_scheduler_type "cosine" \ 39 | --report_to "tensorboard" \ 40 | --merge True \ 41 | -------------------------------------------------------------------------------- /scripts/conveision_llama2_7b/run_qlora.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/conversation-QLoRA-Llama-2-7B-4bit-r128" 3 | DATA_PATH="pissa-dataset" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 7 | --deepspeed configs/ds_config_zero2_no_offload.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --full_finetune False \ 10 | --bf16 \ 11 | --bits 4 \ 12 | --init_weights True \ 13 | --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \ 14 | --lora_rank 128 \ 15 | --lora_alpha 128 \ 16 | --lora_dropout 0 \ 17 | --data_path $DATA_PATH \ 18 | --dataset_split "train"\ 19 | --sub_task conversation \ 20 | --dataset_field instruction output \ 21 | --output_dir $OUTPUT_PATH \ 22 | --num_train_epochs 1 \ 23 | --model_max_length 512 \ 24 | --per_device_train_batch_size 1 \ 25 | --gradient_accumulation_steps 128 \ 26 | --save_strategy "steps" \ 27 | --save_steps 100 \ 28 | --save_total_limit 100 \ 29 | --learning_rate 2e-5 \ 30 | --weight_decay 0.
\ 31 | --warmup_ratio 0.03 \ 32 | --logging_steps 1 \ 33 | --lr_scheduler_type "cosine" \ 34 | --report_to "tensorboard" \ 35 | -------------------------------------------------------------------------------- /scripts/conveision_llama2_7b/run_qpissa.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | RES_MODEL="output/QPiSSA-Llama-2-7b-4bit-r128-5iter" 3 | OUTPUT_PATH="output/conversation-QPiSSA-Llama-2-7b-4bit-r128-5iter" 4 | DATA_PATH="pissa-dataset" 5 | 6 | if [ -e $RES_MODEL ]; then 7 | echo "Use pre-initialized residual model." 8 | else 9 | echo "Perform QPiSSA initialization by myself." 10 | python utils/init_qpissa.py --base_model_dir $BASE_MODEL --output_path $RES_MODEL --rank 128 --iter 5 --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj 11 | fi 12 | 13 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 14 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 15 | --deepspeed configs/ds_config_zero2_no_offload.json \ 16 | --model_name_or_path $RES_MODEL \ 17 | --full_finetune False \ 18 | --bf16 \ 19 | --bits 4 \ 20 | --adapter_name_or_path "pissa_init" \ 21 | --data_path $DATA_PATH \ 22 | --sub_task conversation \ 23 | --dataset_split train \ 24 | --dataset_field instruction output \ 25 | --output_dir $OUTPUT_PATH \ 26 | --num_train_epochs 1 \ 27 | --model_max_length 512 \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --save_strategy "steps" \ 31 | --save_steps 100 \ 32 | --save_total_limit 100 \ 33 | --learning_rate 2e-5 \ 34 | --weight_decay 0. \ 35 | --warmup_ratio 0.03 \ 36 | --logging_steps 1 \ 37 | --lr_scheduler_type "cosine" \ 38 | --report_to "tensorboard" \ 39 | -------------------------------------------------------------------------------- /scripts/metamath_llama2_7b/run_full_finetune.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/metamath-FullFT-Llama-2-7b" 3 | DATA_PATH="pissa-dataset" 4 | export HF_ENDPOINT=https://hf-mirror.com 5 | # huggingface-cli download --token hf_*** --resume-download $BASE_MODEL --local-dir $BASE_MODEL 6 | 7 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 8 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 9 | --deepspeed configs/ds_config_zero2_no_offload.json \ 10 | --model_name_or_path $BASE_MODEL \ 11 | --full_finetune True \ 12 | --bf16 \ 13 | --data_path $DATA_PATH \ 14 | --sub_task metamath:100000 \ 15 | --dataset_split "train"\ 16 | --dataset_field instruction output \ 17 | --output_dir $OUTPUT_PATH \ 18 | --num_train_epochs 1 \ 19 | --model_max_length 512 \ 20 | --per_device_train_batch_size 2 \ 21 | --gradient_accumulation_steps 8 \ 22 | --save_strategy "steps" \ 23 | --save_steps 1000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 2e-5 \ 26 | --weight_decay 0.
\ 27 | --warmup_ratio 0.03 \ 28 | --logging_steps 1 \ 29 | --lr_scheduler_type "cosine" \ 30 | --report_to "tensorboard" \ 31 | 32 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task metamath --output_file $OUTPUT_PATH/metamath_response.jsonl 33 | python utils/test_acc.py --input_file $OUTPUT_PATH/metamath_response.jsonl 34 | -------------------------------------------------------------------------------- /scripts/metamath_llama2_7b/run_loftq.sh: -------------------------------------------------------------------------------- 1 | # LoftQ only provide model with rank=64, one can DIY a rank=128 version following: 2 | # https://github.com/yxli2123/LoftQ/tree/main 3 | RESIDUAL_MODEL="LoftQ/Meta-Llama-3-8B-4bit-64rank" 4 | OUTPUT_PATH="output/metamath-LoftQ-Llama-3-8B-4bit-64rank" 5 | DATA_PATH="pissa-dataset" 6 | 7 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 8 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 9 | --deepspeed configs/ds_config_zero2_no_offload.json \ 10 | --model_name_or_path $RESIDUAL_MODEL \ 11 | --full_finetune False \ 12 | --bf16 \ 13 | --bits 4 \ 14 | --use_lora True \ 15 | --adapter_name_or_path "loftq_init" \ 16 | --data_path $DATA_PATH \ 17 | --sub_task metamath:100000 \ 18 | --dataset_split "train"\ 19 | --dataset_field instruction output \ 20 | --output_dir $OUTPUT_PATH \ 21 | --num_train_epochs 1 \ 22 | --model_max_length 512 \ 23 | --per_device_train_batch_size 1 \ 24 | --gradient_accumulation_steps 128 \ 25 | --save_strategy "steps" \ 26 | --save_steps 1000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 2e-5 \ 29 | --weight_decay 0. \ 30 | --warmup_ratio 0.03 \ 31 | --logging_steps 1 \ 32 | --lr_scheduler_type "cosine" \ 33 | --report_to "tensorboard" \ 34 | -------------------------------------------------------------------------------- /scripts/metamath_llama2_7b/run_lora.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/metamath-LoRA-Llama-2-7b-r128" 3 | DATA_PATH="pissa-dataset" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 7 | --deepspeed configs/ds_config_zero2_no_offload.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --full_finetune False \ 10 | --bf16 \ 11 | --init_weights True \ 12 | --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \ 13 | --lora_rank 128 \ 14 | --lora_alpha 128 \ 15 | --lora_dropout 0 \ 16 | --data_path $DATA_PATH \ 17 | --sub_task metamath:100000 \ 18 | --dataset_split train \ 19 | --dataset_field instruction output \ 20 | --output_dir $OUTPUT_PATH \ 21 | --num_train_epochs 1 \ 22 | --model_max_length 512 \ 23 | --per_device_train_batch_size 4 \ 24 | --gradient_accumulation_steps 4 \ 25 | --save_strategy "steps" \ 26 | --save_steps 1000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 2e-5 \ 29 | --weight_decay 0. 
\ 30 | --warmup_ratio 0.03 \ 31 | --logging_steps 1 \ 32 | --lr_scheduler_type "cosine" \ 33 | --report_to "tensorboard" \ 34 | --merge True \ 35 | 36 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task metamath --output_file $OUTPUT_PATH/metamath_response.jsonl 37 | python utils/test_acc.py --input_file $OUTPUT_PATH/metamath_response.jsonl 38 | -------------------------------------------------------------------------------- /scripts/metamath_llama2_7b/run_pissa.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | RES_MODEL="output/PiSSA-Llama-2-7b-r128" 3 | OUTPUT_PATH="output/metamath-PiSSA-Llama-2-7b-r128" 4 | DATA_PATH="pissa-dataset" 5 | export HF_ENDPOINT=https://hf-mirror.com 6 | 7 | #huggingface-cli download --token hf_*** --resume-download $RES_MODEL --local-dir $RES_MODEL 8 | if [ -e $RES_MODEL ]; then 9 | echo "Use pre-initialized residual model." 10 | else 11 | echo "Perform PiSSA initialization by my self." 12 | python utils/init_pissa.py --base_model_path $BASE_MODEL --output_dir $RES_MODEL --init_weights pissa_niter_16 --lora_r 128 --lora_alpha 128 --lora_dropout 0 --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj 13 | fi 14 | 15 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 16 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 17 | --deepspeed configs/ds_config_zero2_no_offload.json \ 18 | --model_name_or_path $RES_MODEL \ 19 | --full_finetune False \ 20 | --bf16 \ 21 | --adapter_name_or_path "pissa_init" \ 22 | --data_path $DATA_PATH \ 23 | --sub_task metamath:100000 \ 24 | --dataset_split train \ 25 | --dataset_field instruction output \ 26 | --output_dir $OUTPUT_PATH \ 27 | --num_train_epochs 1 \ 28 | --model_max_length 512 \ 29 | --per_device_train_batch_size 4 \ 30 | --gradient_accumulation_steps 4 \ 31 | --save_strategy "steps" \ 32 | --save_steps 1000 \ 33 | --save_total_limit 1 \ 34 | --learning_rate 2e-5 \ 35 | --weight_decay 0. 
\ 36 | --warmup_ratio 0.03 \ 37 | --logging_steps 1 \ 38 | --lr_scheduler_type "cosine" \ 39 | --report_to "tensorboard" \ 40 | --merge True \ 41 | 42 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task metamath --output_file $OUTPUT_PATH/metamath_response.jsonl 43 | python utils/test_acc.py --input_file $OUTPUT_PATH/metamath_response.jsonl 44 | -------------------------------------------------------------------------------- /scripts/metamath_llama2_7b/run_qlora.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/metamath-QLoRA-Llama-2-7B-4bit-r128" 3 | DATA_PATH="pissa-dataset" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 7 | --deepspeed configs/ds_config_zero2_no_offload.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --full_finetune False \ 10 | --bf16 \ 11 | --bits 4 \ 12 | --init_weights True \ 13 | --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \ 14 | --lora_rank 128 \ 15 | --lora_alpha 128 \ 16 | --lora_dropout 0 \ 17 | --data_path $DATA_PATH \ 18 | --dataset_split "train"\ 19 | --sub_task metamath:100000 \ 20 | --dataset_field instruction output \ 21 | --output_dir $OUTPUT_PATH \ 22 | --num_train_epochs 1 \ 23 | --model_max_length 512 \ 24 | --per_device_train_batch_size 1 \ 25 | --gradient_accumulation_steps 128 \ 26 | --save_strategy "steps" \ 27 | --save_steps 100 \ 28 | --save_total_limit 100 \ 29 | --learning_rate 2e-5 \ 30 | --weight_decay 0. \ 31 | --warmup_ratio 0.03 \ 32 | --logging_steps 1 \ 33 | --lr_scheduler_type "cosine" \ 34 | --report_to "tensorboard" \ 35 | 36 | python utils/merge_adapter.py --base_model $BASE_MODEL --adapter $OUTPUT_PATH/checkpoint-781/ --output_path $OUTPUT_PATH 37 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task metamath --output_file $OUTPUT_PATH/metamath_response.jsonl 38 | python utils/test_acc.py --input_file $OUTPUT_PATH/metamath_response.jsonl 39 | -------------------------------------------------------------------------------- /scripts/metamath_llama2_7b/run_qpissa.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | RES_MODEL="output/QPiSSA-Llama-2-7b-4bit-r128-5iter" 3 | OUTPUT_PATH="output/metamath-QPiSSA-Llama-2-7b-4bit-r128-5iter" 4 | DATA_PATH="pissa-dataset" 5 | 6 | if [ -e $RES_MODEL ]; then 7 | echo "Use pre-initialized residual model." 8 | else 9 | echo "Perform QPiSSA initialization by my self." 
10 | python utils/init_qpissa.py --base_model_dir $BASE_MODEL --output_path $RES_MODEL --rank 128 --iter 5 --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj 11 | fi 12 | 13 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 14 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 15 | --deepspeed configs/ds_config_zero2_no_offload.json \ 16 | --model_name_or_path $RES_MODEL \ 17 | --full_finetune False \ 18 | --bf16 \ 19 | --bits 4 \ 20 | --adapter_name_or_path "pissa_init" \ 21 | --data_path $DATA_PATH \ 22 | --sub_task metamath:100000 \ 23 | --dataset_split train \ 24 | --dataset_field instruction output \ 25 | --output_dir $OUTPUT_PATH \ 26 | --num_train_epochs 1 \ 27 | --model_max_length 512 \ 28 | --per_device_train_batch_size 1 \ 29 | --gradient_accumulation_steps 128 \ 30 | --save_strategy "steps" \ 31 | --save_steps 100 \ 32 | --save_total_limit 100 \ 33 | --learning_rate 2e-5 \ 34 | --weight_decay 0. \ 35 | --warmup_ratio 0.03 \ 36 | --logging_steps 1 \ 37 | --lr_scheduler_type "cosine" \ 38 | --report_to "tensorboard" \ 39 | 40 | python utils/merge_adapter.py --base_model $RES_MODEL --adapter $OUTPUT_PATH/checkpoint-781/ --output_path $OUTPUT_PATH 41 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task metamath --output_file $OUTPUT_PATH/metamath_response.jsonl 42 | python utils/test_acc.py --input_file $OUTPUT_PATH/metamath_response.jsonl 43 | -------------------------------------------------------------------------------- /scripts/python_llama2_7b/run_full_finetune.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/python-FullFT-Llama-2-7b" 3 | DATA_PATH="pissa-dataset" 4 | export HF_ENDPOINT=https://hf-mirror.com 5 | # huggingface-cli download --token hf_*** --resume-download $BASE_MODEL --local-dir $BASE_MODEL 6 | 7 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 8 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 9 | --deepspeed configs/ds_config_zero2_no_offload.json \ 10 | --model_name_or_path $BASE_MODEL \ 11 | --full_finetune True \ 12 | --bf16 \ 13 | --data_path $DATA_PATH \ 14 | --sub_task python \ 15 | --dataset_split "train"\ 16 | --dataset_field instruction output \ 17 | --output_dir $OUTPUT_PATH \ 18 | --num_train_epochs 1 \ 19 | --model_max_length 512 \ 20 | --per_device_train_batch_size 2 \ 21 | --gradient_accumulation_steps 8 \ 22 | --save_strategy "steps" \ 23 | --save_steps 1000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 2e-5 \ 26 | --weight_decay 0. 
\ 27 | --warmup_ratio 0.03 \ 28 | --logging_steps 1 \ 29 | --lr_scheduler_type "cosine" \ 30 | --report_to "tensorboard" \ 31 | 32 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task python --output_file $OUTPUT_PATH/python_response.jsonl 33 | python utils/code_process.py --path $OUTPUT_PATH/python_response.jsonl 34 | evalplus.evaluate --dataset humaneval --samples $OUTPUT_PATH/humaneval.jsonl 35 | evalplus.evaluate --dataset mbpp --samples $OUTPUT_PATH/mbpp.jsonl -------------------------------------------------------------------------------- /scripts/python_llama2_7b/run_loftq.sh: -------------------------------------------------------------------------------- 1 | # LoftQ only provides a rank-64 model; you can build a rank-128 version yourself by following: 2 | # https://github.com/yxli2123/LoftQ/tree/main 3 | RESIDUAL_MODEL="LoftQ/Meta-Llama-3-8B-4bit-64rank" 4 | OUTPUT_PATH="output/LoftQ-Llama-3-8B-4bit-64rank" 5 | DATA_PATH="meta-math/MetaMathQA" 6 | 7 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 8 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 9 | --deepspeed configs/ds_config_zero2_no_offload.json \ 10 | --model_name_or_path $RESIDUAL_MODEL \ 11 | --full_finetune False \ 12 | --bf16 \ 13 | --bits 4 \ 14 | --use_lora True \ 15 | --adapter_name_or_path "loftq_init" \ 16 | --data_path $DATA_PATH \ 17 | --dataset_field query response \ 18 | --dataset_split "train[:100000]"\ 19 | --output_dir $OUTPUT_PATH \ 20 | --num_train_epochs 1 \ 21 | --model_max_length 512 \ 22 | --per_device_train_batch_size 1 \ 23 | --gradient_accumulation_steps 128 \ 24 | --save_strategy "steps" \ 25 | --save_steps 100 \ 26 | --save_total_limit 100 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --logging_steps 1 \ 31 | --lr_scheduler_type "cosine" \ 32 | --report_to "tensorboard" \ 33 | -------------------------------------------------------------------------------- /scripts/python_llama2_7b/run_lora.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/python-LoRA-Llama-2-7b-r128" 3 | DATA_PATH="pissa-dataset" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 7 | --deepspeed configs/ds_config_zero2_no_offload.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --full_finetune False \ 10 | --bf16 \ 11 | --init_weights True \ 12 | --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \ 13 | --lora_rank 128 \ 14 | --lora_alpha 128 \ 15 | --lora_dropout 0 \ 16 | --data_path $DATA_PATH \ 17 | --sub_task python \ 18 | --dataset_split train \ 19 | --dataset_field instruction output \ 20 | --output_dir $OUTPUT_PATH \ 21 | --num_train_epochs 1 \ 22 | --model_max_length 512 \ 23 | --per_device_train_batch_size 4 \ 24 | --gradient_accumulation_steps 4 \ 25 | --save_strategy "steps" \ 26 | --save_steps 1000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 2e-5 \ 29 | --weight_decay 0.
\ 30 | --warmup_ratio 0.03 \ 31 | --logging_steps 1 \ 32 | --lr_scheduler_type "cosine" \ 33 | --report_to "tensorboard" \ 34 | --merge True \ 35 | 36 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task python --output_file $OUTPUT_PATH/python_response.jsonl 37 | python utils/test_acc.py --input_file $OUTPUT_PATH/python_response.jsonl 38 | -------------------------------------------------------------------------------- /scripts/python_llama2_7b/run_pissa.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | RES_MODEL="output/PiSSA-Llama-2-7b-r128" 3 | OUTPUT_PATH="output/python-PiSSA-Llama-2-7b-r128" 4 | DATA_PATH="pissa-dataset" 5 | export HF_ENDPOINT=https://hf-mirror.com 6 | 7 | #huggingface-cli download --token hf_*** --resume-download $RES_MODEL --local-dir $RES_MODEL 8 | if [ -e $RES_MODEL ]; then 9 | echo "Using the pre-initialized residual model." 10 | else 11 | echo "Performing PiSSA initialization from scratch." 12 | python utils/init_pissa.py --base_model_path $BASE_MODEL --output_dir $RES_MODEL --init_weights pissa_niter_16 --lora_r 128 --lora_alpha 128 --lora_dropout 0 --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj 13 | fi 14 | 15 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 16 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 17 | --deepspeed configs/ds_config_zero2_no_offload.json \ 18 | --model_name_or_path $RES_MODEL \ 19 | --full_finetune False \ 20 | --bf16 \ 21 | --adapter_name_or_path "pissa_init" \ 22 | --data_path $DATA_PATH \ 23 | --sub_task python \ 24 | --dataset_split train \ 25 | --dataset_field instruction output \ 26 | --output_dir $OUTPUT_PATH \ 27 | --num_train_epochs 1 \ 28 | --model_max_length 512 \ 29 | --per_device_train_batch_size 4 \ 30 | --gradient_accumulation_steps 4 \ 31 | --save_strategy "steps" \ 32 | --save_steps 1000 \ 33 | --save_total_limit 1 \ 34 | --learning_rate 2e-5 \ 35 | --weight_decay 0.
\ 36 | --warmup_ratio 0.03 \ 37 | --logging_steps 1 \ 38 | --lr_scheduler_type "cosine" \ 39 | --report_to "tensorboard" \ 40 | --merge True \ 41 | 42 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task python --output_file $OUTPUT_PATH/python_response.jsonl 43 | python utils/code_process.py --path $OUTPUT_PATH/python_response.jsonl 44 | evalplus.evaluate --dataset humaneval --samples $OUTPUT_PATH/humaneval.jsonl 45 | evalplus.evaluate --dataset mbpp --samples $OUTPUT_PATH/mbpp.jsonl 46 | -------------------------------------------------------------------------------- /scripts/python_llama2_7b/run_qlora.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | OUTPUT_PATH="output/python-QLoRA-Llama-2-7B-4bit-r128" 3 | DATA_PATH="pissa-dataset" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0 train.py \ 7 | --deepspeed configs/ds_config_zero2_no_offload.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --full_finetune False \ 10 | --bf16 \ 11 | --bits 4 \ 12 | --init_weights True \ 13 | --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \ 14 | --lora_rank 128 \ 15 | --lora_alpha 128 \ 16 | --lora_dropout 0 \ 17 | --data_path $DATA_PATH \ 18 | --dataset_split "train"\ 19 | --sub_task python \ 20 | --dataset_field instruction output \ 21 | --output_dir $OUTPUT_PATH \ 22 | --num_train_epochs 1 \ 23 | --model_max_length 512 \ 24 | --per_device_train_batch_size 1 \ 25 | --gradient_accumulation_steps 128 \ 26 | --save_strategy "steps" \ 27 | --save_steps 100 \ 28 | --save_total_limit 100 \ 29 | --learning_rate 2e-5 \ 30 | --weight_decay 0. \ 31 | --warmup_ratio 0.03 \ 32 | --logging_steps 1 \ 33 | --lr_scheduler_type "cosine" \ 34 | --report_to "tensorboard" \ 35 | 36 | python utils/merge_adapter.py --base_model $BASE_MODEL --adapter $OUTPUT_PATH/checkpoint-819/ --output_path $OUTPUT_PATH 37 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task python --output_file $OUTPUT_PATH/python_response.jsonl 38 | python utils/code_process.py --path $OUTPUT_PATH/python_response.jsonl 39 | evalplus.evaluate --dataset humaneval --samples $OUTPUT_PATH/humaneval.jsonl 40 | evalplus.evaluate --dataset mbpp --samples $OUTPUT_PATH/mbpp.jsonl 41 | -------------------------------------------------------------------------------- /scripts/python_llama2_7b/run_qpissa.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 2 | RES_MODEL="output/QPiSSA-Llama-2-7b-4bit-r128-5iter" 3 | OUTPUT_PATH="output/python-QPiSSA-Llama-2-7b-4bit-r128-5iter" 4 | DATA_PATH="pissa-dataset" 5 | 6 | if [ -e $RES_MODEL ]; then 7 | echo "Using the pre-initialized residual model." 8 | else 9 | echo "Performing QPiSSA initialization from scratch."
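# Reuses the same residual model as scripts/metamath_llama2_7b/run_qpissa.sh; if that run already produced
# output/QPiSSA-Llama-2-7b-4bit-r128-5iter, the check above skips re-initialization.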
10 | python utils/init_qpissa.py --base_model_dir $BASE_MODEL --output_path $RES_MODEL --rank 128 --iter 5 --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj 11 | fi 12 | 13 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 14 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 15 | --deepspeed configs/ds_config_zero2_no_offload.json \ 16 | --model_name_or_path $RES_MODEL \ 17 | --full_finetune False \ 18 | --bf16 \ 19 | --bits 4 \ 20 | --adapter_name_or_path "pissa_init" \ 21 | --data_path $DATA_PATH \ 22 | --sub_task python \ 23 | --dataset_split train \ 24 | --dataset_field instruction output \ 25 | --output_dir $OUTPUT_PATH \ 26 | --num_train_epochs 1 \ 27 | --model_max_length 512 \ 28 | --per_device_train_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --save_strategy "steps" \ 31 | --save_steps 100 \ 32 | --save_total_limit 100 \ 33 | --learning_rate 2e-5 \ 34 | --weight_decay 0. \ 35 | --warmup_ratio 0.03 \ 36 | --logging_steps 1 \ 37 | --lr_scheduler_type "cosine" \ 38 | --report_to "tensorboard" \ 39 | 40 | python utils/merge_adapter.py --base_model $RES_MODEL --adapter $OUTPUT_PATH/checkpoint-819/ --output_path $OUTPUT_PATH 41 | python utils/gen_vllm.py --model $OUTPUT_PATH --sub_task python --output_file $OUTPUT_PATH/python_response.jsonl 42 | python utils/code_process.py --path $OUTPUT_PATH/python_response.jsonl 43 | evalplus.evaluate --dataset humaneval --samples $OUTPUT_PATH/humaneval.jsonl 44 | evalplus.evaluate --dataset mbpp --samples $OUTPUT_PATH/mbpp.jsonl -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from dataclasses import dataclass, field 4 | from typing import Optional, Dict, Sequence, List 5 | import logging 6 | import os 7 | 8 | import torch 9 | import torch.distributed 10 | import transformers 11 | from transformers import Trainer, BitsAndBytesConfig 12 | from datasets import load_dataset, concatenate_datasets 13 | import datasets 14 | import numpy as np 15 | from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel, LoraRuntimeConfig 16 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 17 | 18 | IGNORE_INDEX = -100 19 | logger = logging.getLogger(__name__) 20 | 21 | PROMPT = ( 22 | "Below is an instruction that describes a task. 
" 23 | "Write a response that appropriately completes the request.\n\n" 24 | "### Instruction:\n{instruction}\n\n### Response:" 25 | ) 26 | 27 | @dataclass 28 | class TrainingArguments(transformers.TrainingArguments): 29 | # Base model or residual model setting 30 | model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B") 31 | attn_implementation : Optional[str] = field(default="flash_attention_2") 32 | # Lora or PiSSA setting 33 | full_finetune : Optional[bool] = field(default=True) 34 | adapter_name_or_path: Optional[str] = field(default=None,metadata={"help": ("Pre-initialized PiSSA adapter path; when this is not None, the following arguments are ignored."),},) 35 | init_weights: bool | str = field(default=True,metadata={"help": ("True -> LoRA; `pissa` -> PiSSA; `pissa_niter_16` -> Fast SVD PiSSA"),},) 36 | use_dora : Optional[bool] = field(default=False) 37 | target_modules : Optional[str] = field(default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj") 38 | lora_rank : Optional[int] = field(default=8) 39 | lora_alpha : Optional[float] = field(default=32.) 40 | lora_dropout : Optional[float] = field(default=0.,metadata={"help": ("Must be set to 0 when using PiSSA."),},) 41 | # Quantization setting 42 | bits: int = field(default=16,metadata={"help": "How many bits to use."}) 43 | double_quant: bool = field(default=True,metadata={"help": "Compress the quantization statistics through double quantization."}) 44 | quant_type: str = field(default="nf4",metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}) 45 | # DataArguments: 46 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 47 | sub_task: List[str] = field(default=None) 48 | dataset_split: str = field(default="train", metadata={"help": "(`['train', 'test', 'eval']`):"}) 49 | dataset_field: List[str] = field(default=None, metadata={"help": "Fields of dataset input and output."}) 50 | shuffle_dataset : Optional[bool] = field(default=False) 51 | # TrainingArguments 52 | optim: str = field(default="adamw_torch") 53 | model_max_length: int = field(default=512,metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."},) 54 | merge : Optional[bool] = field(default=False,metadata={"help": "Merge the PiSSA adapter to the residual model or LoRA to the base model"},) 55 | 56 | class SavePeftModelCallback(transformers.TrainerCallback): 57 | def save_model(self, args, state, kwargs): 58 | logger.info('Saving PEFT checkpoint...') 59 | if state.best_model_checkpoint is not None: 60 | checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model") 61 | else: 62 | checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") 63 | 64 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model") 65 | kwargs["model"].save_pretrained(peft_model_path) 66 | kwargs["tokenizer"].save_pretrained(peft_model_path) 67 | 68 | def on_save(self, args, state, control, **kwargs): 69 | self.save_model(args, state, kwargs) 70 | return control 71 | 72 | def on_train_end(self, args, state, control, **kwargs): 73 | def touch(fname, times=None): 74 | with open(fname, 'a'): 75 | os.utime(fname, times) 76 | touch(os.path.join(args.output_dir, 'completed')) 77 | self.save_model(args, state, kwargs) 78 | 79 | def get_last_checkpoint(checkpoint_dir): 80 | if os.path.isdir(checkpoint_dir): 81 | is_completed = os.path.exists(os.path.join(checkpoint_dir, 'completed')) 82 | if is_completed: return None # already finished 83 | max_step = 0 84 | for filename in os.listdir(checkpoint_dir): 85 | if os.path.isdir(os.path.join(checkpoint_dir, filename)) and filename.startswith(PREFIX_CHECKPOINT_DIR): 86 | max_step = max(max_step, int(filename.replace(PREFIX_CHECKPOINT_DIR + '-', ''))) 87 | if max_step == 0: return None 88 | latest_ckpt_dir = os.path.join(checkpoint_dir, f'{PREFIX_CHECKPOINT_DIR}-{max_step}') 89 | logger.info(f"Found a previous checkpoint at: {checkpoint_dir}") 90 | return latest_ckpt_dir 91 | return None # first training 92 | 93 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 94 | """Collects the state dict and dump to disk.""" 95 | state_dict = trainer.model.state_dict() 96 | if trainer.args.should_save: 97 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 98 | del state_dict 99 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 100 | 101 | 102 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 103 | """Tokenize a list of strings.""" 104 | tokenized_list = [tokenizer(text, max_length=tokenizer.model_max_length,truncation=True,)for text in strings] 105 | input_ids = labels = [np.array(tokenized.input_ids) for tokenized in tokenized_list] 106 | input_ids_lens = labels_lens = [len(tokenized.input_ids) for tokenized in tokenized_list] 107 | 108 | return dict( 109 | input_ids=input_ids, 110 | labels=labels, 111 | input_ids_lens=input_ids_lens, 112 | labels_lens=labels_lens, 113 | ) 114 | 115 | 116 | def preprocess( 117 | sources: Sequence[str], 118 | targets: Sequence[str], 119 | tokenizer: transformers.PreTrainedTokenizer, 120 | ) -> Dict: 121 | """Preprocess the data by tokenizing.""" 122 | examples = [s + t for s, t in zip(sources, targets)] 123 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 124 | input_ids = examples_tokenized["input_ids"] 125 | labels = copy.deepcopy(input_ids) 126 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 127 | label[:source_len] = IGNORE_INDEX 128 | return dict(input_ids=input_ids, 
labels=labels) 129 | 130 | @dataclass 131 | class DataCollatorForSupervisedDataset(object): 132 | """Collate examples for supervised fine-tuning.""" 133 | tokenizer: transformers.PreTrainedTokenizer 134 | 135 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 136 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 137 | input_ids = [torch.tensor(x) for x in input_ids] 138 | input_ids = torch.nn.utils.rnn.pad_sequence( 139 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 140 | ) 141 | labels = [torch.tensor(x) for x in labels] 142 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 143 | 144 | return dict( 145 | input_ids=input_ids, 146 | labels=labels, 147 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 148 | ) 149 | 150 | def train_tokenize_function(examples, tokenizer, query, response): 151 | sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]] 152 | targets = [f"{output}\n{tokenizer.eos_token}" for output in examples[response]] 153 | data_dict = preprocess(sources, targets, tokenizer) 154 | return data_dict 155 | 156 | def build_model(script_args, checkpoint_dir): 157 | if script_args.full_finetune: 158 | assert script_args.bits in [16, 32] 159 | compute_dtype = (torch.bfloat16 if script_args.bf16 else torch.float32) 160 | model = transformers.AutoModelForCausalLM.from_pretrained( 161 | script_args.model_name_or_path, 162 | quantization_config=BitsAndBytesConfig( 163 | load_in_4bit=script_args.bits == 4, 164 | load_in_8bit=script_args.bits == 8, 165 | llm_int8_threshold=6.0, 166 | llm_int8_has_fp16_weight=False, 167 | bnb_4bit_compute_dtype=compute_dtype, 168 | bnb_4bit_use_double_quant=script_args.double_quant, 169 | bnb_4bit_quant_type=script_args.quant_type, 170 | ) if script_args.bits in [4, 8] else None, 171 | torch_dtype=compute_dtype, 172 | trust_remote_code=True, 173 | ) 174 | setattr(model, 'model_parallel', True) 175 | setattr(model, 'is_parallelizable', True) 176 | # Tokenizer 177 | 178 | if not script_args.full_finetune: 179 | if script_args.bits < 16: 180 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=script_args.gradient_checkpointing) 181 | 182 | if checkpoint_dir is not None: 183 | logger.info(f"Loading adapters from {checkpoint_dir}.") 184 | # os.path.join(checkpoint_dir, 'adapter_model') 185 | model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True) 186 | elif script_args.adapter_name_or_path is not None: 187 | logger.info(f"Initializing LoRA/PiSSA/CLOVER adapters from {script_args.model_name_or_path}/{script_args.adapter_name_or_path}.") 188 | model = PeftModel.from_pretrained(model, script_args.model_name_or_path, subfolder=script_args.adapter_name_or_path, is_trainable=True) 189 | else: 190 | logger.info(f'Init LoRA/PiSSA modules...') 191 | peft_config = LoraConfig( 192 | use_dora=script_args.use_dora, 193 | runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=script_args.use_dora), 194 | task_type=TaskType.CAUSAL_LM, 195 | target_modules=script_args.target_modules.split(','), 196 | inference_mode=False, 197 | r=script_args.lora_rank, 198 | lora_alpha=script_args.lora_alpha, 199 | lora_dropout=script_args.lora_dropout, 200 | init_lora_weights=script_args.init_weights, 201 | ) 202 | model = get_peft_model(model, peft_config) 203 | 204 | for name, module in model.named_modules(): 205 | if 'norm' in name or 'gate' in 
name: 206 | module = module.to(torch.float32) 207 | return model 208 | 209 | def train(): 210 | parser = transformers.HfArgumentParser(TrainingArguments) 211 | script_args = parser.parse_args_into_dataclasses()[0] 212 | log_level = script_args.get_process_log_level() 213 | logger.setLevel(log_level) 214 | datasets.utils.logging.set_verbosity(log_level) 215 | transformers.utils.logging.set_verbosity(log_level) 216 | transformers.utils.logging.enable_default_handler() 217 | transformers.utils.logging.enable_explicit_format() 218 | 219 | if script_args.local_rank == 0: 220 | logger.info('='*100) 221 | logger.info(script_args) 222 | 223 | tokenizer = transformers.AutoTokenizer.from_pretrained( 224 | script_args.model_name_or_path, 225 | model_max_length=script_args.model_max_length, 226 | padding_side="right", 227 | use_fast=True, 228 | trust_remote_code=True 229 | ) 230 | if tokenizer.pad_token is None: 231 | tokenizer.pad_token = tokenizer.eos_token 232 | 233 | if script_args.local_rank == 0: 234 | logger.info("Loaded tokenizer from {}.".format(script_args.model_name_or_path)) 235 | 236 | resume_from_checkpoint_dir = get_last_checkpoint(script_args.output_dir) 237 | model = build_model(script_args, resume_from_checkpoint_dir) 238 | 239 | all_training_dataset = [] 240 | for task in script_args.sub_task: 241 | if ":" in task: # e.g. math:500, gsm8k:100 242 | cur_task, num_split = task.split(":") 243 | cur_split = f"{script_args.dataset_split}[:{num_split}]" 244 | else: 245 | cur_task, cur_split = task, script_args.dataset_split 246 | 247 | ds = load_dataset(script_args.data_path, data_dir=cur_task, split=cur_split) 248 | if script_args.local_rank == 0: 249 | print(f"{script_args.data_path}/{cur_task}/{cur_split}/{ds.num_rows}") 250 | for k,v in ds[0].items(): 251 | print("-"*100) 252 | print(k,end=':\t') 253 | print(v) 254 | print("+"*100) 255 | all_training_dataset.append(ds) 256 | 257 | raw_train_datasets = concatenate_datasets(all_training_dataset) 258 | if script_args.shuffle_dataset: 259 | if script_args.local_rank == 0: 260 | print(f"Shuffle dataset with seed={script_args.seed}") 261 | raw_train_datasets = raw_train_datasets.shuffle(seed=script_args.seed) 262 | 263 | if script_args.local_rank > 0: 264 | torch.distributed.barrier() 265 | 266 | train_dataset = raw_train_datasets.map( 267 | train_tokenize_function, 268 | batched=True, 269 | batch_size=3000, 270 | num_proc=32, 271 | remove_columns=raw_train_datasets.column_names, 272 | load_from_cache_file=True, 273 | desc="Running tokenizer on train dataset", 274 | fn_kwargs={"tokenizer": tokenizer, "query": script_args.dataset_field[0], "response": script_args.dataset_field[1]} 275 | ) 276 | 277 | 278 | if script_args.local_rank == 0: 279 | torch.distributed.barrier() 280 | print(model) 281 | logger.info(f"Training dataset samples: {len(train_dataset)}") 282 | for index in random.sample(range(len(train_dataset)), 3): 283 | logger.info(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.") 284 | logger.info(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.") 285 | 286 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 287 | data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 288 | 289 | trainer = Trainer(model=model, tokenizer=tokenizer, args=script_args, **data_module) 290 | if not script_args.full_finetune: 291 | trainer.add_callback(SavePeftModelCallback) 292 | 
trainer.train(resume_from_checkpoint = resume_from_checkpoint_dir) 293 | trainer.save_state() 294 | if not script_args.full_finetune and script_args.merge: 295 | model = model.merge_and_unload() 296 | model.save_pretrained(script_args.output_dir) 297 | tokenizer.save_pretrained(script_args.output_dir) 298 | if script_args.full_finetune: 299 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=script_args.output_dir) 300 | 301 | 302 | if __name__ == "__main__": 303 | train() 304 | -------------------------------------------------------------------------------- /utils/code_process.py: -------------------------------------------------------------------------------- 1 | from human_eval.data import write_jsonl, stream_jsonl 2 | import glob 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | # Inputs 9 | parser.add_argument( 10 | '--path', 11 | type=str, 12 | help="") 13 | parser.add_argument( 14 | '--out_path', 15 | type=str, 16 | help="") 17 | 18 | args = parser.parse_args() 19 | humaneval_output = [] 20 | mbpp_output = [] 21 | for code in stream_jsonl(args.path): 22 | if code['type'] not in ['humaneval', 'mbpp']: 23 | continue 24 | task_id = code['answer'] 25 | code['task_id'] = str(task_id) 26 | completion = code['output'].replace("\r", "") 27 | if '```python' in completion: 28 | def_line = completion.index('```python') 29 | completion = completion[def_line:].strip() 30 | completion = completion.replace('```python', '') 31 | try: 32 | next_line = completion.index('\n```') 33 | completion = completion[:next_line].strip() 34 | except: 35 | pass 36 | 37 | if "__name__ == \"__main__\"" in completion: 38 | next_line = completion.index('if __name__ == "__main__":') 39 | completion = completion[:next_line].strip() 40 | 41 | if "# Example usage" in completion: 42 | next_line = completion.index('# Example usage') 43 | completion = completion[:next_line].strip() 44 | 45 | if "assert" in completion: 46 | next_line = completion.index('assert') 47 | completion = completion[:next_line].strip() 48 | 49 | code['completion'] = completion 50 | 51 | if code['type'] == 'humaneval': 52 | humaneval_output.append(code) 53 | else: 54 | mbpp_output.append(code) 55 | 56 | import os 57 | humaneval_outpath = os.path.join(os.path.dirname(args.path), "humaneval.jsonl") 58 | mbpp_outpath = os.path.join(os.path.dirname(args.path), "mbpp.jsonl") 59 | 60 | print("save to {}".format(humaneval_outpath)) 61 | print("save to {}".format(mbpp_outpath)) 62 | write_jsonl(humaneval_outpath, humaneval_output) 63 | write_jsonl(mbpp_outpath, mbpp_output) -------------------------------------------------------------------------------- /utils/gen_vllm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import sys 4 | import os 5 | import json 6 | from vllm import LLM, SamplingParams 7 | from datasets import load_dataset, concatenate_datasets 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--model', type=str, help="") 11 | parser.add_argument("--data_path", type=str, default="pissa-dataset") 12 | parser.add_argument('--sub_task', nargs='+', help='') 13 | parser.add_argument('--dataset_split', type=str, default="test", help='') 14 | parser.add_argument('--output_file', type=str, default="model_response.jsonl", help="") 15 | parser.add_argument("--batch_size", type=int, default=400, help="") 16 | parser.add_argument('--temperature', type=float, default=0.0, help="") 17 | 
parser.add_argument('--top_p', type=float, default=1, help="") 18 | parser.add_argument('--max_tokens', type=int, default=1024, help="") 19 | args = parser.parse_args() 20 | 21 | stop_tokens = [] 22 | sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens, stop=stop_tokens) 23 | llm = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count()) 24 | 25 | def batch_data(data_list, batch_size=1): 26 | n = len(data_list) // batch_size 27 | batch_data = [] 28 | for i in range(n-1): 29 | start = i * batch_size 30 | end = (i+1)*batch_size 31 | batch_data.append(data_list[start:end]) 32 | 33 | last_start = (n-1) * batch_size 34 | last_end = sys.maxsize 35 | batch_data.append(data_list[last_start:last_end]) 36 | return batch_data 37 | 38 | if args.sub_task is None: 39 | dataset = load_dataset(args.data_path, split=args.dataset_split) 40 | else: 41 | all_test_dataset = [] 42 | for task in args.sub_task: 43 | ds = load_dataset(args.data_path, data_dir=task, split=args.dataset_split) 44 | print(f"{args.data_path}/{task}/{args.dataset_split}") 45 | for k,v in ds[0].items(): 46 | print("-"*100) 47 | print(k,end=':\t') 48 | print(v) 49 | print("+"*100) 50 | all_test_dataset.append(ds) 51 | 52 | dataset = concatenate_datasets(all_test_dataset) 53 | 54 | batch_dataset_query = batch_data(dataset["instruction"], batch_size=args.batch_size) 55 | batch_dataset_answer = batch_data(dataset["output"], batch_size=args.batch_size) 56 | batch_dataset_task = batch_data(dataset["type"], batch_size=args.batch_size) 57 | 58 | for idx, (batch_query, batch_answer, batch_task) in enumerate(zip(batch_dataset_query, batch_dataset_answer,batch_dataset_task)): 59 | with torch.no_grad(): 60 | completions = llm.generate(batch_query, sampling_params) 61 | for query, completion, answer, task in zip(batch_query, completions, batch_answer, batch_task): 62 | with open(args.output_file, 'a') as f: 63 | json.dump({'type': task, 'query': query, 'output': completion.outputs[0].text, 'answer': answer}, f) 64 | f.write('\n') 65 | -------------------------------------------------------------------------------- /utils/init_clover.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
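# Wraps the base model with CLOVER adapters on the requested target modules, then saves the adapter
# weights under "clover_init" separately from the unloaded residual model and tokenizer, mirroring the
# PiSSA workflow in utils/init_pissa.py. Note that CloverConfig is not part of upstream peft, so this
# script presumably requires the authors' patched peft build.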
14 | 15 | import torch 16 | import os 17 | from peft import CloverConfig, get_peft_model 18 | from transformers import AutoTokenizer, AutoModelForCausalLM 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description="Merge Adapter to Base Model") 22 | parser.add_argument("--base_model_path", type=str, help="The name or path of the fp32/16 base model.") 23 | parser.add_argument("--output_dir", type=str, default="clover_model") 24 | parser.add_argument("--bits", type=str, default="fp32", choices=["bf16", "fp16", "fp32"]) 25 | parser.add_argument("--init_weights", type=str, default="svd", help="(`['eye', 'svd']`)") 26 | parser.add_argument('--target_modules', nargs='+', help='', required=True) 27 | parser.add_argument("--head_dim", type=int) 28 | parser.add_argument("--num_head", type=int) 29 | args = parser.parse_args() 30 | print(args) 31 | 32 | model = AutoModelForCausalLM.from_pretrained( 33 | args.base_model_path, 34 | torch_dtype=( 35 | torch.float16 36 | if args.bits == "fp16" 37 | else (torch.bfloat16 if args.bits == "bf16" else torch.float32) 38 | ), 39 | device_map="auto", 40 | ) 41 | tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) 42 | clover_config = CloverConfig( 43 | init_clover_weights=args.init_weights, 44 | target_modules=args.target_modules, 45 | head_dim=args.head_dim, 46 | num_head=args.num_head, 47 | task_type="CAUSAL_LM", 48 | ) 49 | peft_model = get_peft_model(model, clover_config) 50 | print(peft_model.get_nb_trainable_parameters()) 51 | 52 | # Save CLOVER modules: 53 | peft_model.peft_config["default"].init_clover_weights = "eye" 54 | peft_model.save_pretrained(os.path.join(args.output_dir, "clover_init")) 55 | # Save residual model: 56 | peft_model = peft_model.unload() 57 | peft_model.save_pretrained(args.output_dir) 58 | # Save the tokenizer: 59 | tokenizer.save_pretrained(args.output_dir) -------------------------------------------------------------------------------- /utils/init_crossover.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
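# Same structure as utils/init_clover.py, but for CROSSOVER adapters: build a CrossoverConfig
# (kaiming/gaussian/orthogonal init, block_size, alpha, dropout), save the adapter under
# "crossover_init", then unload and save the residual model and tokenizer. CrossoverConfig is likewise
# assumed to come from the authors' modified peft rather than the upstream release.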
14 | 15 | import torch 16 | import os 17 | from peft import CrossoverConfig, get_peft_model 18 | from transformers import AutoTokenizer, AutoModelForCausalLM 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description="Merge Adapter to Base Model") 22 | parser.add_argument("--base_model_path", type=str, help="The name or path of the fp32/16 base model.") 23 | parser.add_argument("--output_dir", type=str, default="clover_model") 24 | parser.add_argument("--bits", type=str, default="fp32", choices=["bf16", "fp16", "fp32"]) 25 | parser.add_argument("--init_weights", type=str, default="kaiming", help="(`['kaiming', 'gaussian', 'orthogonal']`)") 26 | parser.add_argument('--target_modules', nargs='+', help='', required=True) 27 | parser.add_argument("--block_size", type=int) 28 | parser.add_argument("--alpha", type=int) 29 | parser.add_argument("--dropout", type=float) 30 | args = parser.parse_args() 31 | print(args) 32 | 33 | model = AutoModelForCausalLM.from_pretrained( 34 | args.base_model_path, 35 | torch_dtype=( 36 | torch.float16 37 | if args.bits == "fp16" 38 | else (torch.bfloat16 if args.bits == "bf16" else torch.float32) 39 | ), 40 | device_map="auto", 41 | ) 42 | tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) 43 | crossover_config = CrossoverConfig( 44 | init_crossover_weights=args.init_weights, 45 | target_modules=args.target_modules, 46 | block_size=args.block_size, 47 | alpha=args.alpha, 48 | dropout=args.dropout, 49 | task_type="CAUSAL_LM", 50 | ) 51 | peft_model = get_peft_model(model, crossover_config) 52 | print(peft_model.get_nb_trainable_parameters()) 53 | 54 | # Save crossover modules: 55 | peft_model.peft_config["default"].init_crossover_weights = 'kaiming' 56 | peft_model.save_pretrained(os.path.join(args.output_dir, "crossover_init")) 57 | # Save residual model: 58 | peft_model = peft_model.unload() 59 | peft_model.save_pretrained(args.output_dir) 60 | # Save the tokenizer: 61 | tokenizer.save_pretrained(args.output_dir) -------------------------------------------------------------------------------- /utils/init_pissa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
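# Splits the base model into a principal low-rank adapter and a residual model using peft's PiSSA
# initialization (init_lora_weights="pissa", or "pissa_niter_[k]" for the fast randomized SVD).
# The adapter is saved under <output_dir>/pissa_init with init_lora_weights reset to True, so later
# loads treat it as a plain LoRA checkpoint instead of re-running the SVD; the unloaded residual
# model and tokenizer are saved to <output_dir>.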
14 | 15 | import torch 16 | import os 17 | from peft import LoraConfig, get_peft_model 18 | from transformers import AutoTokenizer, AutoModelForCausalLM 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description="Separate the principal singular value and singular vectors from base model") 22 | parser.add_argument("--base_model_path", type=str, required=True, help="The name or path of the base model.") 23 | parser.add_argument("--output_dir", type=str, required=True) 24 | parser.add_argument("--bits", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) 25 | parser.add_argument("--init_weights", type=str, default="pissa", help="(`['pissa', 'pissa_niter_[number of iters]']`)") 26 | parser.add_argument("--lora_r", type=int, default=128) 27 | parser.add_argument("--lora_alpha", type=int, default=128) 28 | parser.add_argument("--lora_dropout", type=float, default=0) 29 | parser.add_argument('--target_modules', nargs='+', help='', required=True) 30 | script_args = parser.parse_args() 31 | print(script_args) 32 | 33 | model = AutoModelForCausalLM.from_pretrained( 34 | script_args.base_model_path, 35 | torch_dtype=( 36 | torch.float16 37 | if script_args.bits == "fp16" 38 | else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32) 39 | ), 40 | device_map="auto", 41 | ) 42 | tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_path) 43 | tokenizer.pad_token_id = tokenizer.eos_token_id 44 | lora_config = LoraConfig( 45 | r=script_args.lora_r, 46 | lora_alpha=script_args.lora_alpha, 47 | init_lora_weights=True if script_args.init_weights=="True" else script_args.init_weights, 48 | lora_dropout=script_args.lora_dropout, 49 | target_modules=script_args.target_modules, 50 | ) 51 | peft_model = get_peft_model(model, lora_config) 52 | 53 | # Save PiSSA modules: 54 | peft_model.peft_config["default"].init_lora_weights = True 55 | peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_init")) 56 | # Save residual model: 57 | peft_model = peft_model.unload() 58 | peft_model.save_pretrained(script_args.output_dir) 59 | # Save the tokenizer: 60 | tokenizer.save_pretrained(script_args.output_dir) -------------------------------------------------------------------------------- /utils/init_qpissa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from peft import get_peft_model, LoraConfig 5 | import bitsandbytes as bnb 6 | from tqdm import tqdm 7 | 8 | # python utils/init_qpissa.py --base_model_dir meta-llama/Llama-2-7b-hf/ --output_path llama-2-7b-pissa-4bit-r128-iter5 --iter 5 9 | 10 | parser = argparse.ArgumentParser(description="Initializing QPiSSA.") 11 | parser.add_argument("--base_model_dir", type=str, required=True) 12 | parser.add_argument("--output_path", type=str, required=True) 13 | parser.add_argument("--rank", type=int, default=128) 14 | parser.add_argument("--iter", type=int, default=1) 15 | parser.add_argument("--device", type=str, default="cuda") 16 | parser.add_argument('--target_modules', nargs='+', help='', required=True) 17 | args = parser.parse_args() 18 | 19 | def quantize_and_dequantized(weight): 20 | device = weight.device 21 | weight_nf4 = bnb.nn.Params4bit(weight.to("cpu"), requires_grad=False, compress_statistics=False, quant_type="nf4") 22 | weight_nf4 = weight_nf4.to(device) 23 | weight_dequantized = bnb.functional.dequantize_4bit( 24 | weight_nf4.data, weight_nf4.quant_state 25 | 
).to(torch.float32) 26 | return weight_nf4, weight_dequantized 27 | 28 | @torch.no_grad() 29 | def pissa_quant(weight, r=64, niter=5): 30 | res = weight.to(torch.float32) 31 | for i in range(niter): 32 | U, S, Vh = torch.linalg.svd(res, full_matrices=False) 33 | L = U @ (torch.sqrt(torch.diag(S)[:, :r])) 34 | R = torch.sqrt(torch.diag(S)[:r, :]) @ Vh 35 | res = weight - L @ R 36 | weight_nf4, weight_dequantized = quantize_and_dequantized(res) 37 | res = weight - weight_dequantized 38 | 39 | return weight_nf4, weight_dequantized, R, L 40 | 41 | 42 | base_model = AutoModelForCausalLM.from_pretrained(args.base_model_dir, device_map=args.device) 43 | tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir) 44 | lora_config = LoraConfig( 45 | r=args.rank, 46 | lora_alpha=args.rank, 47 | target_modules=args.target_modules, 48 | task_type="CAUSAL_LM", 49 | ) 50 | peft_model = get_peft_model(base_model, peft_config=lora_config) 51 | state_dict = {} 52 | for key, value in tqdm(peft_model.state_dict().items()): 53 | if "base_layer" in key: 54 | base_layer_in_4bits, base_layer, lora_A, lora_B = pissa_quant(value, args.rank, args.iter) 55 | state_dict[key] = base_layer.to('cpu') 56 | state_dict[key.replace("base_layer", "lora_A.default")] = lora_A.to('cpu') 57 | state_dict[key.replace("base_layer", "lora_B.default")] = lora_B.to('cpu') 58 | 59 | print(peft_model.load_state_dict(state_dict, strict=False)) 60 | peft_model.save_pretrained(f"{args.output_path}/pissa_init") 61 | peft_model = peft_model.unload() 62 | peft_model.save_pretrained(args.output_path) 63 | tokenizer.save_pretrained(args.output_path) -------------------------------------------------------------------------------- /utils/merge_adapter.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from peft import PeftModel, PeftConfig 3 | import argparse 4 | import torch 5 | 6 | parser = argparse.ArgumentParser(description='Merge Adapter to Base Model') 7 | parser.add_argument('--base_model', type=str) 8 | parser.add_argument('--adapter', type=str) 9 | parser.add_argument('--output_path', type=str) 10 | args = parser.parse_args() 11 | 12 | model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, trust_remote_code=True) 13 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 14 | model = PeftModel.from_pretrained(model, args.adapter) 15 | model = model.merge_and_unload() 16 | model.save_pretrained(args.output_path) 17 | tokenizer.save_pretrained(args.output_path) -------------------------------------------------------------------------------- /utils/nf4_to_bf16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import bitsandbytes as bnb 4 | from peft import PeftModel 5 | import torch 6 | from transformers import AutoModelForCausalLM 7 | 8 | parser = argparse.ArgumentParser( 9 | description="Calculate the quantization error of NF4 model." 
10 | ) 11 | parser.add_argument( 12 | "--base_model_path", 13 | type=str, 14 | required=True, 15 | ) 16 | parser.add_argument( 17 | "--quant_model_path", 18 | type=str, 19 | required=True, 20 | ) 21 | parser.add_argument( 22 | "--output_path", 23 | type=str, 24 | required=True, 25 | ) 26 | parser.add_argument( 27 | "--device", 28 | type=str, 29 | default="cuda", 30 | ) 31 | args = parser.parse_args() 32 | 33 | 34 | residual_model = AutoModelForCausalLM.from_pretrained( 35 | args.base_model_path, torch_dtype=torch.bfloat16, device_map=args.device 36 | ) 37 | quant_model = AutoModelForCausalLM.from_pretrained( 38 | args.quant_model_path, low_cpu_mem_usage=True 39 | ) 40 | 41 | with torch.no_grad(): 42 | for name, param in quant_model.named_parameters(): 43 | if "_proj" in name: 44 | W = residual_model.get_parameter(name) 45 | W.data = bnb.functional.dequantize_4bit(param.data, param.quant_state).to(torch.bfloat16).cpu() 46 | 47 | 48 | residual_model.save_pretrained(args.output_path) -------------------------------------------------------------------------------- /utils/test_acc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | from fraction import Fraction 5 | from collections import defaultdict 6 | 7 | def remove_right_units(string): 8 | # "\\text{ " only ever occurs (at least in the val set) when describing units 9 | if "\\text{ " in string: 10 | splits = string.split("\\text{ ") 11 | assert len(splits) == 2 12 | return splits[0] 13 | else: 14 | return string 15 | 16 | def fix_sqrt(string): 17 | if "\\sqrt" not in string: 18 | return string 19 | splits = string.split("\\sqrt") 20 | new_string = splits[0] 21 | for split in splits[1:]: 22 | if split[0] != "{": 23 | a = split[0] 24 | new_substr = "\\sqrt{" + a + "}" + split[1:] 25 | else: 26 | new_substr = "\\sqrt" + split 27 | new_string += new_substr 28 | return new_string 29 | 30 | def fix_fracs(string): 31 | substrs = string.split("\\frac") 32 | new_str = substrs[0] 33 | if len(substrs) > 1: 34 | substrs = substrs[1:] 35 | for substr in substrs: 36 | new_str += "\\frac" 37 | if substr[0] == "{": 38 | new_str += substr 39 | else: 40 | try: 41 | assert len(substr) >= 2 42 | except AssertionError: 43 | return string 44 | a = substr[0] 45 | b = substr[1] 46 | if b != "{": 47 | if len(substr) > 2: 48 | post_substr = substr[2:] 49 | new_str += "{" + a + "}{" + b + "}" + post_substr 50 | else: 51 | new_str += "{" + a + "}{" + b + "}" 52 | else: 53 | if len(substr) > 2: 54 | post_substr = substr[2:] 55 | new_str += "{" + a + "}" + b + post_substr 56 | else: 57 | new_str += "{" + a + "}" + b 58 | string = new_str 59 | return string 60 | 61 | def fix_a_slash_b(string): 62 | if len(string.split("/")) != 2: 63 | return string 64 | a = string.split("/")[0] 65 | b = string.split("/")[1] 66 | try: 67 | a = int(a) 68 | b = int(b) 69 | assert string == "{}/{}".format(a, b) 70 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 71 | return new_string 72 | except AssertionError: 73 | return string 74 | 75 | def strip_string(string): 76 | # linebreaks 77 | string = string.replace("\n", "") 78 | 79 | # remove inverse spaces 80 | string = string.replace("\\!", "") 81 | 82 | # replace \\ with \ 83 | string = string.replace("\\\\", "\\") 84 | 85 | # replace tfrac and dfrac with frac 86 | string = string.replace("tfrac", "frac") 87 | string = string.replace("dfrac", "frac") 88 | 89 | # remove \left and \right 90 | string = string.replace("\\left", "") 91 | string = 
string.replace("\\right", "") 92 | 93 | # Remove circ (degrees) 94 | string = string.replace("^{\\circ}", "") 95 | string = string.replace("^\\circ", "") 96 | 97 | # remove dollar signs 98 | string = string.replace("\\$", "") 99 | 100 | # remove units (on the right) 101 | string = remove_right_units(string) 102 | 103 | # remove percentage 104 | string = string.replace("\\%", "") 105 | string = string.replace("\%", "") # noqa: W605 106 | 107 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 108 | string = string.replace(" .", " 0.") 109 | string = string.replace("{.", "{0.") 110 | # if empty, return empty string 111 | if len(string) == 0: 112 | return string 113 | if string[0] == ".": 114 | string = "0" + string 115 | 116 | # to consider: get rid of e.g. "k = " or "q = " at beginning 117 | if len(string.split("=")) == 2: 118 | if len(string.split("=")[0]) <= 2: 119 | string = string.split("=")[1] 120 | 121 | # fix sqrt3 --> sqrt{3} 122 | string = fix_sqrt(string) 123 | 124 | # remove spaces 125 | string = string.replace(" ", "") 126 | 127 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 128 | string = fix_fracs(string) 129 | 130 | # manually change 0.5 --> \frac{1}{2} 131 | if string == "0.5": 132 | string = "\\frac{1}{2}" 133 | 134 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 135 | string = fix_a_slash_b(string) 136 | 137 | return string 138 | 139 | def is_equiv(str1, str2, verbose=False): 140 | if str1 is None and str2 is None: 141 | print("WARNING: Both None") 142 | return True 143 | if str1 is None or str2 is None: 144 | return False 145 | 146 | try: 147 | ss1 = strip_string(str1) 148 | ss2 = strip_string(str2) 149 | #pdb.set_trace() 150 | if verbose: 151 | print(ss1, ss2) 152 | return ss1 == ss2 153 | except Exception: 154 | return str1 == str2 155 | 156 | def process_math_results(completion, answer): 157 | split_ans = completion.split('The answer is: ') 158 | if len(split_ans) > 1: 159 | ans = split_ans[-1] 160 | extract_ans_temp = ans.split('.\n')[0] 161 | extract_ans_temp = extract_ans_temp.strip() 162 | if len(extract_ans_temp)>0 and extract_ans_temp[-1] == '.': 163 | extract_ans = extract_ans_temp[0:-1] 164 | else: 165 | extract_ans = extract_ans_temp 166 | extract_ans = extract_ans.strip() 167 | if is_equiv(extract_ans, answer): 168 | return True 169 | else: 170 | return False 171 | else: 172 | return False 173 | 174 | def is_number(s): 175 | try: 176 | float(s) 177 | return True 178 | except ValueError: 179 | pass 180 | try: 181 | import unicodedata 182 | unicodedata.numeric(s) 183 | return True 184 | except (TypeError, ValueError): 185 | pass 186 | return False 187 | 188 | def extract_answer_number(completion): 189 | text = completion.split('The answer is: ') 190 | if len(text) > 1: 191 | extract_ans = text[-1].strip() 192 | match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans) 193 | if match: 194 | if '/' in match.group(): 195 | denominator = match.group().split('/')[1] 196 | numerator = match.group().split('/')[0] 197 | if is_number(denominator) == True and is_number(numerator) == True: 198 | if denominator == '0': 199 | return round(float(numerator.replace(',', ''))) 200 | else: 201 | frac = Fraction(match.group().replace(',', '')) 202 | num_numerator = frac.numerator 203 | num_denominator = frac.denominator 204 | return round(float(num_numerator / 
num_denominator)) 205 | else: 206 | return None 207 | else: 208 | if float(match.group().replace(',', '')) == float('inf'): 209 | return None 210 | return round(float(match.group().replace(',', ''))) 211 | else: 212 | return None 213 | else: 214 | return None 215 | 216 | def extract_commonsense_answer(dataset, sentence: str) -> float: 217 | if dataset == 'boolq': 218 | sentence_ = sentence.strip() 219 | pred_answers = re.findall(r'true|false', sentence_) 220 | if not pred_answers: 221 | return "" 222 | return pred_answers[0] 223 | elif dataset == 'piqa': 224 | sentence_ = sentence.strip() 225 | pred_answers = re.findall(r'solution1|solution2', sentence_) 226 | if not pred_answers: 227 | return "" 228 | return pred_answers[0] 229 | elif dataset in ['siqa', 'arc_challenge', 'arc_easy', 'openbookqa']: 230 | sentence_ = sentence.strip() 231 | pred_answers = re.findall(r'answer1|answer2|answer3|answer4|answer5', sentence_) 232 | if not pred_answers: 233 | return "" 234 | return pred_answers[0] 235 | elif dataset == 'hellaswag': 236 | sentence_ = sentence.strip() 237 | pred_answers = re.findall(r'ending1|ending2|ending3|ending4', sentence_) 238 | if not pred_answers: 239 | return "" 240 | return pred_answers[0] 241 | elif dataset == 'winogrande': 242 | sentence_ = sentence.strip() 243 | pred_answers = re.findall(r'option1|option2', sentence_) 244 | if not pred_answers: 245 | return "" 246 | return pred_answers[0] 247 | 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument('--input_file', type=str, help="") 250 | args = parser.parse_args() 251 | 252 | results = defaultdict(list) 253 | with open(args.input_file, 'r') as f: 254 | for line in f.readlines(): 255 | data = json.loads(line) 256 | if data['type'] == 'gsm8k': 257 | y_pred = extract_answer_number(data['output']) 258 | if y_pred != None: 259 | results[data['type']].append(float(y_pred) == float(data["answer"])) 260 | else: 261 | results[data['type']].append(False) 262 | elif data['type'] == 'math': 263 | res = process_math_results(data['output'], data['answer']) 264 | results[data['type']].append(res) 265 | elif data['type'] in ['boolq', 'piqa', 'siqa', 'arc_challenge', 'arc_easy', 'openbookqa', 'hellaswag', 'winogrande']: 266 | y_pred = extract_commonsense_answer(data['type'] ,data['output']) 267 | if y_pred != None: 268 | results[data['type']].append(y_pred == data["answer"]) 269 | else: 270 | results[data['type']].append(False) 271 | 272 | for key, value in results.items(): 273 | acc = sum(value) / len(value) 274 | print(f'{key} length====', len(value), f', {key} acc====', acc) --------------------------------------------------------------------------------
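The decomposed checkpoints written by utils/init_pissa.py or utils/init_qpissa.py can also be loaded outside of train.py. Below is a minimal sketch, assuming a residual model was saved to output/PiSSA-Llama-2-7b-r128 (a placeholder path) with its adapter in the pissa_init subfolder; it mirrors what build_model in train.py does with --adapter_name_or_path "pissa_init", and the commented merge step follows utils/merge_adapter.py.

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Placeholder path: wherever utils/init_pissa.py wrote the residual model and its "pissa_init" adapter.
res_model_path = "output/PiSSA-Llama-2-7b-r128"

# Residual model = base model with the principal singular components subtracted out.
model = AutoModelForCausalLM.from_pretrained(res_model_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(res_model_path)

# Attach the PiSSA adapter from the "pissa_init" subfolder, as train.py does via --adapter_name_or_path.
model = PeftModel.from_pretrained(model, res_model_path, subfolder="pissa_init", is_trainable=True)

# After fine-tuning, the adapter can be folded back into the weights (see utils/merge_adapter.py):
# model = model.merge_and_unload()
```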