├── inspect.py ├── lora_sweep.yaml ├── LICENSE ├── eval_apps └── eval.py ├── requirements.txt ├── README.md ├── sft.py ├── .gitignore ├── data_utils.py ├── generate.py └── utils.py /inspect.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | with open("responses_apps_t0.2.json") as f: 4 | responses = json.load(f) 5 | 6 | for response in responses[:10]: 7 | print(response[0]) 8 | print() 9 | print("-" * 30) 10 | print() 11 | -------------------------------------------------------------------------------- /lora_sweep.yaml: -------------------------------------------------------------------------------- 1 | program: run_sft.sh 2 | name: lora_sweep 3 | method: grid 4 | metric: 5 | goal: minimize 6 | name: eval_loss 7 | parameters: 8 | lora_r: 9 | values: [8, 16, 32, 64] 10 | lora_alpha: 11 | values: [16, 32, 64, 128] 12 | 13 | command: 14 | - ${env} 15 | - bash 16 | - ${program} 17 | - ${args} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Martin Weyssow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /eval_apps/eval.py: -------------------------------------------------------------------------------- 1 | import evaluate 2 | import json 3 | import os 4 | import sys 5 | 6 | input_fp = sys.argv[1] 7 | method = sys.argv[2] 8 | n = sys.argv[3] 9 | base_path = os.path.dirname(input_fp) 10 | 11 | with open(input_fp, "r") as f: 12 | responses = json.load(f) 13 | 14 | apps_metric = evaluate.load('codeparrot/apps_metric', keep_in_memory=True) 15 | 16 | interview_responses = responses[:3000] 17 | competition_responses = responses[3000:4000] 18 | intro_responses = responses[4000:5000] 19 | 20 | for difficulty_responses, difficulty in [(interview_responses, "interview"), 21 | (competition_responses, "competition"), 22 | (intro_responses, "introductory")]: 23 | print(f"Evaluating {difficulty} -- {input_fp}") 24 | results = apps_metric.compute(predictions=difficulty_responses, level=difficulty, debug=False, count_errors=True) 25 | print(results) 26 | 27 | output_fp = os.path.join(base_path, f"apps_metrics_{difficulty}_{method}_n{n}.json") 28 | # output_fp = os.path.join(base_path, f"apps_metrics_{difficulty}.json") 29 | with open(output_fp, "w") as fout: 30 | json.dump(results, fout) 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | aiohttp==3.8.6 3 | aiosignal==1.3.1 4 | appdirs==1.4.4 5 | async-timeout==4.0.3 6 | attrs==23.1.0 7 | bitsandbytes==0.41.1 8 | certifi==2023.7.22 9 | charset-normalizer==3.3.0 10 | click==8.1.7 11 | datasets==2.14.5 12 | dill==0.3.7 13 | docker-pycreds==0.4.0 14 | evaluate==0.4.1 15 | filelock==3.12.4 16 | frozenlist==1.4.0 17 | fsspec==2023.6.0 18 | gitdb==4.0.11 19 | GitPython==3.1.40 20 | huggingface-hub==0.17.3 21 | idna==3.4 22 | Jinja2==3.1.2 23 | MarkupSafe==2.1.3 24 | mpmath==1.3.0 25 | multidict==6.0.4 26 | multiprocess==0.70.15 27 | networkx==3.1 28 | numpy==1.26.0 29 | nvidia-cublas-cu12==12.1.3.1 30 | nvidia-cuda-cupti-cu12==12.1.105 31 | nvidia-cuda-nvrtc-cu12==12.1.105 32 | nvidia-cuda-runtime-cu12==12.1.105 33 | nvidia-cudnn-cu12==8.9.2.26 34 | nvidia-cufft-cu12==11.0.2.54 35 | nvidia-curand-cu12==10.3.2.106 36 | nvidia-cusolver-cu12==11.4.5.107 37 | nvidia-cusparse-cu12==12.1.0.106 38 | nvidia-nccl-cu12==2.18.1 39 | nvidia-nvjitlink-cu12==12.2.140 40 | nvidia-nvtx-cu12==12.1.105 41 | packaging==23.2 42 | pandas==2.1.1 43 | pathtools==0.1.2 44 | peft==0.5.0 45 | Pillow==10.0.1 46 | protobuf==4.24.4 47 | psutil==5.9.5 48 | pyarrow==13.0.0 49 | python-dateutil==2.8.2 50 | pytz==2023.3.post1 51 | PyYAML==6.0.1 52 | regex==2023.10.3 53 | requests==2.31.0 54 | responses==0.18.0 55 | safetensors==0.4.0 56 | scipy==1.11.3 57 | sentry-sdk==1.32.0 58 | setproctitle==1.3.3 59 | six==1.16.0 60 | smmap==5.0.1 61 | sympy==1.12 62 | tiktoken==0.5.1 63 | tokenizers==0.14.1 64 | torch==2.1.0+cu121 65 | torchaudio==2.1.0 66 | torchvision==0.16.0 67 | tqdm==4.66.1 68 | transformers==4.34.0 69 | tree-sitter==0.20.2 70 | triton==2.1.0 71 | typing_extensions==4.8.0 72 | tzdata==2023.3 73 | urllib3==2.0.6 74 | wandb==0.15.12 75 | xxhash==3.4.1 76 | yarl==1.9.2 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Exploring Parameter-Efficient Fine-Tuning Techniques for Code Generation with Large 
Language Models 3 | Official replication package for our submission to TOSEM. 4 | 5 | In this readme, we provide detailed instructions on how to set up the repository and run the paper experiments. 6 | Our code can easily be adapted to further investigate parameter-efficient fine-tuning of large language models on other generation tasks. 7 | 8 | ## Installation 9 | 1. Clone this repository using `git`. 10 | 2. Set up a `Python 3` virtual environment and install the requirements. 11 | ```sh 12 | python3 -m venv env 13 | source env/bin/activate 14 | pip install -r requirements.txt 15 | mkdir runs 16 | ``` 17 | We used Python 3.11.5 to run all our experiments, and a single NVIDIA RTX A5000. 18 | We used CUDA release 12.3, V12.3.52. Please make sure the PyTorch version matches your hardware requirements. 19 | 20 | ## Running the experiments 21 | 22 | ### Fine-tune an LLM using PEFT 23 | ```shell 24 | CUDA_VISIBLE_DEVICES=0 python main.py \ 25 | --model_name_or_path codellama/CodeLlama-7b-hf \ 26 | --dataset codealpaca \ 27 | --tuning_method lora \ 28 | --num_epochs 5 \ 29 | --batch_size 4 \ 30 | --gradient_accumulation_steps 2 \ 31 | --learning_rate 3e-4 \ 32 | --lora_r 8 \ 33 | --lora_alpha 16 \ 34 | --do_train \ 35 | --use_wandb 36 | ``` 37 | 38 | - You can disable WandB logging by removing the `--use_wandb` argument. 39 | - For `QLoRA`, set `--tuning_method qlora-8bit` or `--tuning_method qlora-4bit`. 40 | - For joint training, set `--dataset joint`. 41 | - The script automatically saves the best model checkpoint in the directory: `/runs/checkpoints/{dataset}/{model_name}_{tuning_method}/` 42 | In our example: `/runs/checkpoints/codealpaca/CodeLlama-7b-hf_lora` 43 | 44 | ### Evaluating fine-tuned LLMs 45 | ```shell 46 | CUDA_VISIBLE_DEVICES=0 python main.py \ 47 | --model_name_or_path codellama/CodeLlama-7b-hf \ 48 | --adapter_path runs/checkpoints/conala/CodeLlama-7b-hf_lora \ 49 | --tuning_method lora \ 50 | --dataset conala \ 51 | --do_test 52 | ``` 53 | 54 | - `--adapter_path` corresponds to the local path of the best model checkpoint. 55 | - The script saves files in the directory: `/runs/test_results/{model_name}_{tuning_method}`: 56 | - `output_{dataset}.jsonl`: top-10 predictions to compute EM@*k*. 57 | - `predictions_{dataset}.txt` and `references_{dataset}.txt`: top-1 predictions to compute CodeBLEU. 58 | 59 | ### Evaluating LLMs with ICL 60 | Use the `scripts/inference_icl_seeds.sh` bash script to replicate the results from the paper: 61 | ```shell 62 | bash scripts/inference_icl_seeds.sh codellama/CodeLlama-7b-hf conala 0 63 | ``` 64 | The script runs inference on the input model with the following parameters (which you can adjust to your needs): 65 | - `n_examples=(1 2 3)`: 1 to 3 few-shot examples 66 | - `seeds=(42 777 5432 55555 97)`: run inference using 5 seeds 67 | - Note that this results in running inference 15 times, which can be time-consuming. 68 | 69 | ### Computing EM@k and CodeBLEU 70 | Use `compute_metrics.py`, which computes the EM@*k* and CodeBLEU on all the test results stored in the `runs/test_results` directory.
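The exact metric implementation lives in `compute_metrics.py` (not reproduced in this listing). As a rough illustration of how EM@*k* relates to the top-10 predictions stored in `output_{dataset}.jsonl`: a sample counts as solved at *k* if any of its first *k* predictions exactly matches the reference. The snippet below is a minimal sketch of that idea, assuming each JSONL line carries a `predictions` list and a `reference` string; the actual field names in the repository's output files may differ.

```python
import json

def em_at_k(jsonl_path, k=10):
    """Fraction of samples whose reference appears among the first k predictions."""
    hits, total = 0, 0
    with open(jsonl_path) as f:
        for line in f:
            sample = json.loads(line)
            top_k = [p.strip() for p in sample["predictions"][:k]]  # hypothetical field name
            hits += int(sample["reference"].strip() in top_k)       # hypothetical field name
            total += 1
    return hits / total if total else 0.0
```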
71 | -------------------------------------------------------------------------------- /sft.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from rich.console import Console 5 | from rich.logging import RichHandler 6 | from tqdm.rich import tqdm 7 | from transformers import ( 8 | AutoTokenizer, 9 | DataCollatorForLanguageModeling 10 | ) 11 | from trl import ( 12 | SFTTrainer, 13 | SFTConfig, 14 | get_quantization_config, 15 | get_kbit_device_map, 16 | RichProgressCallback, 17 | DataCollatorForCompletionOnlyLM 18 | ) 19 | from trl.commands.cli_utils import init_zero_verbose, TrlParser 20 | 21 | from datasets import load_from_disk 22 | from utils import ( 23 | SFTScriptArguments, 24 | ModelConfig, 25 | get_peft_config 26 | ) 27 | 28 | init_zero_verbose() 29 | tqdm.pandas() 30 | logging.basicConfig(format="%(message)s", datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) 35 | args, training_args, model_config = parser.parse_args_and_config() 36 | 37 | training_args.disable_tqdm = True 38 | console = Console() 39 | 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | model_config.model_name_or_path, 42 | trust_remote_code=model_config.trust_remote_code, 43 | use_fast=True 44 | ) 45 | if getattr(tokenizer, "pad_token", None) is None: 46 | tokenizer.pad_token = tokenizer.eos_token 47 | 48 | quantization_config = get_quantization_config(model_config) 49 | model_kwargs = dict( 50 | revision=model_config.model_revision, 51 | trust_remote_code=model_config.trust_remote_code, 52 | attn_implementation=model_config.attn_implementation, 53 | torch_dtype=model_config.torch_dtype, 54 | use_cache=False if training_args.gradient_checkpointing else True, 55 | device_map=get_kbit_device_map() if quantization_config is not None else None, 56 | quantization_config=quantization_config, 57 | ) 58 | training_args.model_init_kwargs = model_kwargs 59 | 60 | dataset = load_from_disk(args.dataset_name) 61 | train_dataset = dataset[args.dataset_train_split] 62 | eval_dataset = dataset[args.dataset_test_split] 63 | 64 | collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) 65 | if args.completion_only: 66 | # ensures the instruction is ignored during loss computation 67 | collator = DataCollatorForCompletionOnlyLM(args.response_template, tokenizer=tokenizer) 68 | 69 | trainer = SFTTrainer( 70 | model=model_config.model_name_or_path, 71 | args=training_args, 72 | train_dataset=train_dataset, 73 | eval_dataset=eval_dataset, 74 | tokenizer=tokenizer, 75 | data_collator=collator, 76 | peft_config=get_peft_config(model_config, tokenizer), 77 | callbacks=[RichProgressCallback()] 78 | ) 79 | 80 | """ 81 | train_dataloader = trainer.get_train_dataloader() 82 | 83 | for i, batch in enumerate(train_dataloader): 84 | input_ids = batch["input_ids"][3] 85 | labels = batch["labels"][3].cpu() 86 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 87 | print(tokenizer.decode(input_ids)) 88 | print("-" * 100) 89 | print(tokenizer.decode(labels)) 90 | break 91 | 92 | """ 93 | trainer.train() 94 | 95 | console.log(model_config) 96 | trainable_params, all_param = trainer.model.get_nb_trainable_parameters() 97 | console.log(f"trainable params: {trainable_params:,d} || " 98 | f"all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}") 99 | 
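A note on the `completion_only` option used in `sft.py`: with `DataCollatorForCompletionOnlyLM`, every label token that precedes the response template is replaced by `-100`, so the cross-entropy loss is computed only on the assistant's answer. The snippet below is a simplified, self-contained illustration of that masking logic; it is not TRL's actual implementation, and the token ids are made up.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss

def mask_prompt(input_ids, labels, response_template_ids):
    """Mask every label up to and including the response template."""
    n = len(response_template_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_template_ids:
            return [IGNORE_INDEX] * (start + n) + labels[start + n:]
    # Template not found: mask the whole example, mirroring the collator's fallback.
    return [IGNORE_INDEX] * len(labels)

# Toy example: ids 9 and 10 stand in for the tokenized response template.
input_ids = [1, 2, 3, 9, 10, 4, 5, 6]
print(mask_prompt(input_ids, list(input_ids), [9, 10]))
# [-100, -100, -100, -100, -100, 4, 5, 6]
```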
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .idea 163 | datasets 164 | *.sh 165 | runs -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from datasets import load_dataset, DatasetDict 4 | 5 | from utils import make_chat_template_prompt, INSTRUCTION_PREFIX 6 | 7 | 8 | def transform_magicoder(output_dir="datasets"): 9 | dataset = load_dataset("ise-uiuc/Magicoder-Evol-Instruct-110K") 10 | 11 | def process_example(e): 12 | messages = [ 13 | {"role": "user", "content": e["instruction"]}, 14 | {"role": "assistant", "content": e["response"]} 15 | ] 16 | return {"messages": messages} 17 | 18 | train_set = dataset["train"].shuffle(42) 19 | validation_set = train_set.select(range(1000)) 20 | train_set = train_set.select(range(1000, len(train_set))) 21 | dataset = DatasetDict({ 22 | "train": train_set, 23 | "validation": validation_set, 24 | }) 25 | 26 | for split in dataset.keys(): 27 | dataset[split] = dataset[split].map(lambda e: process_example(e), num_proc=8) 28 | dataset.save_to_disk(f"{output_dir}/magicoder") 29 | 30 | 31 | def transform_magicoder_oss(output_dir="datasets"): 32 | dataset = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K") 33 | 34 | def process_example(e): 35 | messages = [ 36 | {"role": "user", "content": e["problem"]}, 37 | {"role": "assistant", "content": e["solution"]} 38 | ] 39 | return {"messages": messages} 40 | 41 | train_set = dataset["train"].shuffle(42) 42 | validation_set = train_set.select(range(1000)) 43 | train_set = train_set.select(range(1000, len(train_set))) 44 | dataset = DatasetDict({ 45 | "train": train_set, 46 | "validation": validation_set, 47 | }) 48 | 49 | for split in dataset.keys(): 50 | dataset[split] = dataset[split].map(lambda e: process_example(e), num_proc=8) 51 | dataset.save_to_disk(f"{output_dir}/magicoder_oss") 52 | 53 | 54 | def transform_conala(output_dir="datasets"): 55 | dataset = load_dataset("neulab/docprompting-conala", trust_remote_code=True) 56 | instruction_prefix = INSTRUCTION_PREFIX["conala"] 57 | 58 | def process_example(e, split): 59 | user_content = e["nl"] 60 | assistant_content = None if split == "test" else e["cmd"] 61 | messages = make_chat_template_prompt(user_content, assistant_content, instruction_prefix) 62 | return {"messages": messages} 63 | 64 | for split in 
dataset.keys(): 65 | dataset[split] = dataset[split].map(lambda e: process_example(e, split), num_proc=8) 66 | dataset.save_to_disk(f"{output_dir}/conala") 67 | 68 | 69 | def transform_mbpp(output_dir="datasets"): 70 | dataset = load_dataset("google-research-datasets/mbpp", trust_remote_code=True) 71 | instruction_prefix = INSTRUCTION_PREFIX["mbpp"] 72 | 73 | def process_example(e, split): 74 | user_content = f"{e['text']} Your code should pass these tests:" 75 | for test in e["test_list"]: 76 | user_content += f"\n{test}" 77 | assistant_content = None if split == "test" else e["code"] 78 | messages = make_chat_template_prompt(user_content, assistant_content, instruction_prefix) 79 | return {"messages": messages} 80 | 81 | for split in dataset.keys(): 82 | dataset[split] = dataset[split].map(lambda e: process_example(e, split), num_proc=8) 83 | dataset.save_to_disk(f"{output_dir}/mbpp") 84 | 85 | 86 | def transform_apps(output_dir="datasets"): 87 | # this preprocessing follows the same format used in the original APPs paper: 88 | # https://github.com/hendrycks/apps/blob/main/train/dataset_apps/APPSBaseDataset.py 89 | # https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/example_script.py 90 | 91 | dataset = load_dataset("codeparrot/apps", trust_remote_code=True) 92 | instruction_prefix = INSTRUCTION_PREFIX["apps"] 93 | 94 | def process_example(e, split): 95 | starter_code = None if len(e["starter_code"]) == 0 else e["starter_code"] 96 | try: 97 | input_outpout = json.loads(e["input_output"]) 98 | fn_name = None if not input_outpout.get("fn_name") else input_outpout["fn_name"] 99 | except ValueError: 100 | fn_name = None 101 | try: 102 | solutions = json.loads(e["solutions"]) 103 | except ValueError: 104 | solutions = [""] 105 | 106 | user_content = e["question"] 107 | if starter_code: 108 | user_content += starter_code 109 | if fn_name: 110 | user_content += "\nUse Standard Input format\n" 111 | else: 112 | user_content += "\nUse Call-Based format\n" 113 | assistant_content = None if split == "test" else solutions[0] 114 | messages = make_chat_template_prompt(user_content, assistant_content, instruction_prefix) 115 | return {"messages": messages} 116 | 117 | # create validation set 118 | train_set = dataset["train"].shuffle(42) 119 | validation_set = train_set.select(range(500)) 120 | train_set = train_set.select(range(500, len(train_set))) 121 | dataset = DatasetDict({ 122 | "train": train_set, 123 | "validation": validation_set, 124 | "test": dataset["test"] 125 | }) 126 | 127 | for split in dataset.keys(): 128 | dataset[split] = dataset[split].map(lambda e: process_example(e, split), num_proc=8) 129 | dataset.save_to_disk(f"{output_dir}/apps") 130 | 131 | 132 | if __name__ == "__main__": 133 | transform_magicoder_oss() 134 | # transform_magicoder() 135 | # transform_conala() 136 | # transform_mbpp() 137 | # transform_apps() 138 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import anthropic 6 | import evaluate 7 | import torch 8 | from langchain.docstore.document import Document as LangchainDocument 9 | from langchain_community.vectorstores import FAISS 10 | from langchain_community.vectorstores.utils import DistanceStrategy 11 | from langchain_huggingface import HuggingFaceEmbeddings 12 | from openai import OpenAI 13 | from peft import PeftModelForCausalLM 14 | from ragatouille import 
RAGPretrainedModel 15 | from rich.progress import MofNCompleteColumn, BarColumn, Progress, TextColumn, TimeElapsedColumn 16 | from tqdm import tqdm 17 | from transformers import set_seed, AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList 18 | 19 | from datasets import load_from_disk 20 | from utils import track_gpu_usage, make_chat_template_prompt, encode_chat_template, INSTRUCTION_PREFIX 21 | 22 | 23 | def prepare_input(sample, knowledge_base_vectors, reranker, args): 24 | if args.use_rag: 25 | query = sample[args.instruction_field] 26 | retrieved_docs = knowledge_base_vectors.similarity_search(query=query, k=args.rag_n_retrieve) 27 | 28 | if args.use_reranking: 29 | docs = [doc.page_content for doc in retrieved_docs] 30 | docs_scores = reranker.rerank(query, docs, k=args.rag_n_retrieve) 31 | result_index = [doc["result_index"] for doc in docs_scores] 32 | retrieved_docs = [retrieved_docs[i] for i in result_index[:args.rag_n_final]] 33 | else: 34 | retrieved_docs = retrieved_docs[:args.rag_n_final] 35 | 36 | chat_docs = [] 37 | for doc in retrieved_docs: 38 | chat_docs += make_chat_template_prompt( 39 | doc.page_content, 40 | doc.metadata["code"], 41 | instruction_prefix=INSTRUCTION_PREFIX[args.dataset_name], 42 | ) 43 | return chat_docs + sample["messages"] 44 | return sample["messages"] 45 | 46 | 47 | class CustomStoppingCriteria(StoppingCriteria): 48 | def __init__(self, start, eos_tokens, tokenizer): 49 | self.start = start 50 | self.eos_tokens = eos_tokens 51 | self.tokenizer = tokenizer 52 | 53 | def __call__(self, input_ids, scores, **kwargs): 54 | tokens = self.tokenizer.decode(input_ids[0, self.start:]) 55 | return any([eos_token in tokens for eos_token in self.eos_tokens]) 56 | 57 | 58 | @track_gpu_usage 59 | def generate(args, dataset, model, tokenizer, knowledge_base_vectors=None, reranker=None): 60 | gen_kwargs = { 61 | "do_sample": args.do_sample, 62 | "temperature": args.temperature, 63 | "top_p": args.top_p, 64 | "top_k": args.top_k, 65 | } 66 | 67 | with (Progress( 68 | TextColumn(f"Generating responses •" + "[progress.percentage]{task.percentage:>3.0f}%"), 69 | BarColumn(), 70 | MofNCompleteColumn(), 71 | TextColumn("•"), 72 | TimeElapsedColumn(), 73 | ) as p): 74 | for sample in p.track(dataset): 75 | messages = prepare_input(sample, knowledge_base_vectors, reranker, args) 76 | if not args.api_model: 77 | inputs = encode_chat_template(messages, tokenizer) 78 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 79 | outputs = model.generate( 80 | input_ids=inputs["input_ids"], 81 | attention_mask=inputs["attention_mask"], 82 | max_new_tokens=args.max_new_tokens, 83 | stopping_criteria=[CustomStoppingCriteria(inputs["input_ids"].shape[1], args.eos, tokenizer)], 84 | **gen_kwargs 85 | ) 86 | response_ids = outputs[0][inputs["input_ids"].shape[1]:] 87 | response = tokenizer.decode(response_ids, skip_special_tokens=True) 88 | response = response.split("```")[0].strip() 89 | else: 90 | if isinstance(model, OpenAI): 91 | response = model.chat.completions.create( 92 | model=args.model_name_or_path, 93 | messages=messages, 94 | max_tokens=args.max_new_tokens, 95 | temperature=args.temperature 96 | ) 97 | response = response.choices[0].message.content 98 | elif isinstance(model, anthropic.Anthropic): 99 | pass 100 | print(response) 101 | print("-" * 25) 102 | yield response 103 | 104 | 105 | def compute_metrics(args, responses, dataset): 106 | chrf = evaluate.load("chrf") 107 | em = evaluate.load("exact_match") 108 | 109 | references = 
dataset[args.reference_field] 110 | results_em = em.compute(predictions=responses, references=references) 111 | 112 | references_chrf = [[ref] for ref in references] 113 | results_chrf = chrf.compute(predictions=responses, references=references_chrf) 114 | results_chrf2 = chrf.compute(predictions=responses, references=references_chrf, word_order=2) 115 | 116 | print(f"EM: {results_em}") 117 | print(f"chrF: {results_chrf}") 118 | print(f"chrF++: {results_chrf2}") 119 | 120 | return { 121 | "em": results_em, 122 | "chrf": results_chrf, 123 | "chrf2": results_chrf2 124 | } 125 | 126 | 127 | def main(args): 128 | if not args.api_model: 129 | model = AutoModelForCausalLM.from_pretrained( 130 | args.model_name_or_path, 131 | torch_dtype=torch.bfloat16, 132 | trust_remote_code=True, 133 | device_map="auto" 134 | ) 135 | if args.peft_checkpoint_path is not None: 136 | model = PeftModelForCausalLM.from_pretrained(model, args.peft_checkpoint_path) 137 | args.model_name = args.model_name_or_path.split("/")[-1] 138 | 139 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) 140 | args.eos = [tokenizer.eos_token, "\n```\n"] 141 | else: 142 | if "deepseek" in args.model_name_or_path: 143 | model = OpenAI(api_key=os.getenv("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com") 144 | elif "claude" in args.model_name_or_path: 145 | model = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) 146 | else: 147 | model = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) 148 | tokenizer = None 149 | args.model_name = args.model_name_or_path 150 | 151 | dataset = load_from_disk(args.dataset_name_or_path)["test"] 152 | args.dataset_name = args.dataset_name_or_path.split("/")[-1] 153 | 154 | if args.dataset_name == "conala": 155 | args.max_new_tokens = 128 156 | args.instruction_field = "nl" 157 | args.reference_field = "cmd" 158 | elif args.dataset_name == "mbpp": 159 | args.max_new_tokens = 512 160 | args.instruction_field = "text" 161 | args.reference_field = "code" 162 | else: 163 | args.max_new_tokens = 1024 164 | args.instruction_field = "question" 165 | args.reference_field = "solutions" 166 | 167 | knowledge_base_vectors = None 168 | reranker = None 169 | if args.use_icl: 170 | examples = ( 171 | load_from_disk(args.dataset_name_or_path)["train"] 172 | .shuffle(args.icl_seed) 173 | .select(range(args.num_icl_examples)) 174 | ) 175 | chat_icl = [] 176 | for example in examples: 177 | if args.dataset_name == "apps": 178 | reference = json.loads(example[args.reference_field])[0] 179 | else: 180 | reference = example[args.reference_field] 181 | chat_exemple = make_chat_template_prompt( 182 | example[args.instruction_field], 183 | reference, 184 | instruction_prefix=INSTRUCTION_PREFIX[args.dataset_name], 185 | ) 186 | chat_icl += chat_exemple 187 | 188 | def add_icl_prompt(example): 189 | example["messages"] = chat_icl + example["messages"] 190 | return example 191 | 192 | dataset = dataset.map(add_icl_prompt, num_proc=16) 193 | elif args.use_rag: 194 | examples = load_from_disk(args.dataset_name_or_path)["train"] 195 | # @todo: special case for APPs 196 | knowledge_base = [ 197 | LangchainDocument( 198 | page_content=sample[args.instruction_field], 199 | metadata={"code": sample[args.reference_field]} 200 | ) for sample in tqdm(examples) 201 | ] 202 | 203 | embedding_model = HuggingFaceEmbeddings( 204 | model_name=args.rag_encoder_model, 205 | multi_process=False, 206 | model_kwargs={"device": "cuda"}, 207 | encode_kwargs={"normalize_embeddings": True}, 208 | ) 
209 | 210 | if args.use_reranking: 211 | reranker = RAGPretrainedModel.from_pretrained(args.reranking_model) 212 | 213 | knowledge_base_vectors = FAISS.from_documents( 214 | knowledge_base, embedding_model, distance_strategy=DistanceStrategy.COSINE 215 | ) 216 | 217 | responses, init_gpu_memory, peak_gpu_memory, total_execution_time = ( 218 | generate(args, dataset, model, tokenizer, knowledge_base_vectors, reranker) 219 | ) 220 | 221 | metrics = { 222 | "init_gpu_memory": f"{init_gpu_memory} MB", 223 | "peak_gpu_memory": f"{peak_gpu_memory} MB", 224 | "total_execution_time": f"{total_execution_time} seconds" 225 | } 226 | 227 | if args.dataset_name == "conala": 228 | conala_metrics = compute_metrics(args, responses, dataset) 229 | metrics = {**metrics, **conala_metrics} 230 | 231 | output_dir = ( 232 | f"{args.peft_checkpoint_path}/results" if args.peft_checkpoint_path else f"runs/{args.model_name}/results" 233 | ) 234 | os.makedirs(output_dir, exist_ok=True) 235 | 236 | file_suffix = f"{args.dataset_name}_t{args.temperature}" 237 | if args.use_icl: 238 | file_suffix += f"_icl_n{args.num_icl_examples}_s{args.icl_seed}" 239 | elif args.use_rag: 240 | file_suffix += "_rag" 241 | if args.use_reranking: 242 | file_suffix += "_reranking" 243 | file_suffix += f"_n{args.rag_n_final}" 244 | 245 | with open(f"{output_dir}/metrics_{file_suffix}.jsonl", "w") as fout: 246 | json.dump(metrics, fout) 247 | 248 | data = [[response] for response in responses] 249 | with open(f"{output_dir}/responses_{file_suffix}.json", "w") as fout: 250 | json.dump(data, fout) 251 | 252 | 253 | if __name__ == "__main__": 254 | parser = argparse.ArgumentParser() 255 | parser.add_argument("--model_name_or_path", type=str, default=None) 256 | parser.add_argument("--peft_checkpoint_path", type=str, default=None) 257 | parser.add_argument("--dataset_name_or_path", type=str, default=None) 258 | parser.add_argument("--api_model", action="store_true", default=False) 259 | 260 | parser.add_argument("--do_sample", default=True, type=bool, help="do sampling in generation") 261 | parser.add_argument("--temperature", default=0.2, type=float, help="temperature for sampling") 262 | parser.add_argument("--top_p", default=0.95, type=float, help="top p for sampling") 263 | parser.add_argument("--top_k", default=0, type=float, help="top k for sampling") 264 | 265 | parser.add_argument("--use_icl", action="store_true", default=False) 266 | parser.add_argument("--icl_seed", type=int, default=42) 267 | parser.add_argument("--num_icl_examples", type=int, default=3) 268 | 269 | parser.add_argument("--use_rag", action="store_true", default=False) 270 | parser.add_argument("--rag_encoder_model", default="thenlper/gte-small", type=str) 271 | parser.add_argument("--rag_n_retrieve", default=1, type=int) 272 | parser.add_argument("--rag_n_final", default=1, type=int) 273 | parser.add_argument("--use_reranking", action="store_true", default=False) 274 | parser.add_argument("--reranking_model", default="colbert-ir/colbertv2.0", type=str) 275 | 276 | args = parser.parse_args() 277 | set_seed(42) 278 | main(args) 279 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import subprocess 4 | import time 5 | from dataclasses import field, dataclass 6 | from typing import Optional, List, Union, Any, Dict 7 | 8 | import torch 9 | from transformers import AutoTokenizer 10 | from trl import 
DataCollatorForCompletionOnlyLM 11 | from trl.core import flatten_dict 12 | 13 | from peft import ( 14 | LoraConfig, 15 | PeftConfig, 16 | PromptEncoderConfig, 17 | PromptTuningConfig, 18 | PromptTuningInit, 19 | PrefixTuningConfig 20 | ) 21 | 22 | 23 | LORA_TARGET_MODULES = { 24 | "Phi-3-mini-128k-instruct": ["o_proj", "qkv_proj"], 25 | "deepseek-coder-6.7b-instruct": ["q_proj", "v_proj", "o_proj", "k_proj"], 26 | "CodeQwen1.5-7B-Chat": ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"], 27 | "Meta-Llama-3.1-8B-Instruct": ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"], 28 | "Qwen2.5-Coder-7B-Instruct": ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"], 29 | "Qwen2.5-Coder-1.5B": ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"], 30 | "Qwen2.5-Coder-1.5B-Instruct": ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"], 31 | } 32 | 33 | _MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-" 34 | 35 | INSTRUCTION_PREFIX = { 36 | "conala": ( 37 | "Provide a self-contained Python script that solves the following problem in a markdown code block. " 38 | "Your solution should most likely contain a single line of code, or only a few ones." 39 | ), 40 | "mbpp": ( 41 | "Provide a self-contained Python script that solves the following problem in a markdown code block. " 42 | "You are given example test cases from which you can infer the function signature." 43 | ), 44 | "apps": ( 45 | "Provide a self-contained Python script that solves the following problem in a markdown code block. " 46 | "Make sure the solution obeys the constraints and passes the example test cases." 47 | ) 48 | } 49 | 50 | 51 | def encode_chat_template(chat_template, tokenizer): 52 | prompt = tokenizer.apply_chat_template(chat_template, tokenize=False).split(_MAGIC_SPLITTER_)[0] 53 | return tokenizer(prompt, return_attention_mask=True, return_tensors="pt") 54 | 55 | 56 | def make_chat_template_prompt(instruction, response, instruction_prefix): 57 | # https://github.com/evalplus/evalplus/blob/master/evalplus/provider/utility.py#L25 58 | user_content = f"{instruction_prefix}\n```\n{instruction.strip()}\n```" 59 | if response is None: 60 | assistant_content = f"```python\n{_MAGIC_SPLITTER_}\n```" 61 | else: 62 | assistant_content = f"```python\n{response.strip()}\n```" 63 | 64 | return [ 65 | {"role": "user", "content": user_content}, 66 | {"role": "assistant", "content": assistant_content} 67 | ] 68 | 69 | 70 | class CustomDataCollatorForCompletionOnlyLM(DataCollatorForCompletionOnlyLM): 71 | def __init__(self, *args, **kwargs): 72 | super().__init__(*args, **kwargs) 73 | 74 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: 75 | batch = super().torch_call(examples) 76 | 77 | return batch 78 | 79 | 80 | def get_gpu_memory_usage(): 81 | try: 82 | visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) 83 | if visible_devices is None: 84 | # select all available GPUs 85 | gpu_ids = range(torch.cuda.device_count()) 86 | else: 87 | # select active GPUs for the current script 88 | gpu_ids = list(map(int, visible_devices.split(','))) 89 | 90 | result = subprocess.check_output( 91 | ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'], 92 | encoding='utf-8' 93 | ) 94 | 95 | # sum GPU memory usage from active GPUs 96 | gpu_memory = [int(x) for x in result.strip().split('\n')] 97 | total_memory_used = 
sum(gpu_memory[gpu_id] for gpu_id in gpu_ids) 98 | 99 | return total_memory_used 100 | except Exception: 101 | return -1 102 | 103 | 104 | def track_gpu_usage(func): 105 | @functools.wraps(func) 106 | def wrapper_track_gpu_usage(*args, **kwargs): 107 | initial_gpu_memory = get_gpu_memory_usage() 108 | max_gpu_memory_usage = initial_gpu_memory 109 | start_time = time.time() 110 | result = [] 111 | try: 112 | gen = func(*args, **kwargs) 113 | for sample_index, output in enumerate(gen): 114 | current_gpu_memory = get_gpu_memory_usage() 115 | max_gpu_memory_usage = max(max_gpu_memory_usage, current_gpu_memory) 116 | result.append(output) 117 | except Exception as e: 118 | print(f"Error during execution: {e}") 119 | finally: 120 | end_time = time.time() 121 | 122 | return result, initial_gpu_memory, max_gpu_memory_usage, end_time - start_time 123 | 124 | return wrapper_track_gpu_usage 125 | 126 | 127 | @dataclass 128 | class SFTScriptArguments: 129 | dataset_name: str = field( 130 | default="timdettmers/openassistant-guanaco", 131 | metadata={"help": "the dataset name"}, 132 | ) 133 | dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"}) 134 | dataset_test_split: str = field(default="validation", metadata={"help": "The dataset split to evaluate on"}) 135 | config: str = field(default=None, metadata={"help": "Path to the optional config file"}) 136 | gradient_checkpointing_use_reentrant: bool = field( 137 | default=False, 138 | metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}, 139 | ) 140 | completion_only: bool = field( 141 | default=False, 142 | metadata={"help": "Whether to only consider the assistant's response in the loss calculation"} 143 | ) 144 | response_template: str = field( 145 | default=None, 146 | metadata={"help": "Response template when setting `completion_only`"} 147 | ) 148 | 149 | 150 | @dataclass 151 | class ModelConfig: 152 | """ 153 | Arguments which define the model and tokenizer to load. 154 | """ 155 | 156 | model_name_or_path: Optional[str] = field( 157 | default=None, 158 | metadata={"help": ("The model checkpoint for weights initialization.")}, 159 | ) 160 | model_revision: str = field( 161 | default="main", 162 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 163 | ) 164 | torch_dtype: Optional[str] = field( 165 | default=None, 166 | metadata={ 167 | "help": ( 168 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 169 | "dtype will be automatically derived from the model's weights." 
170 | ), 171 | "choices": ["auto", "bfloat16", "float16", "float32"], 172 | }, 173 | ) 174 | trust_remote_code: bool = field(default=True, metadata={"help": "Trust remote code when loading a model."}) 175 | attn_implementation: Optional[str] = field( 176 | default=None, 177 | metadata={ 178 | "help": ( 179 | "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" 180 | ) 181 | }, 182 | ) 183 | use_peft: bool = field( 184 | default=False, 185 | metadata={"help": ("Whether to use PEFT or not for training.")}, 186 | ) 187 | # LoRA methods 188 | use_lora: bool = field( 189 | default=False, 190 | metadata={"help": ("Whether to use LoRA.")}, 191 | ) 192 | lora_r: Optional[int] = field( 193 | default=16, 194 | metadata={"help": ("LoRA R value.")}, 195 | ) 196 | lora_alpha: Optional[int] = field( 197 | default=32, 198 | metadata={"help": ("LoRA alpha.")}, 199 | ) 200 | lora_dropout: Optional[float] = field( 201 | default=0.05, 202 | metadata={"help": ("LoRA dropout.")}, 203 | ) 204 | lora_target_modules: Optional[List[str]] = field( 205 | default=None, 206 | metadata={"help": ("LoRA target modules.")}, 207 | ) 208 | lora_modules_to_save: Optional[List[str]] = field( 209 | default=None, 210 | metadata={"help": ("Model layers to unfreeze & train")}, 211 | ) 212 | task_type: str = field( 213 | default="CAUSAL_LM", metadata={"help": "The task_type to pass for LoRA (use SEQ_CLS for reward modeling)"} 214 | ) 215 | # QLoRA 216 | load_in_8bit: bool = field( 217 | default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"} 218 | ) 219 | load_in_4bit: bool = field( 220 | default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"} 221 | ) 222 | bnb_4bit_quant_type: Optional[str] = field( 223 | default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"} 224 | ) 225 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) 226 | # Prompt methods 227 | use_p_tuning: bool = field( 228 | default=False, 229 | metadata={"help": ("Whether to use p-tuning.")}, 230 | ) 231 | use_prefix_tuning: bool = field( 232 | default=False, 233 | metadata={"help": ("Whether to use prefix tuning.")}, 234 | ) 235 | use_prompt_tuning: bool = field( 236 | default=False, 237 | metadata={"help": ("Whether to use prompt tuning.")}, 238 | ) 239 | num_virtual_tokens: int = field( 240 | default=20, 241 | metadata={"help": ("Number of virtual tokens for p-tuning or prefix tuning.")}, 242 | ) 243 | encoder_hidden_size: int = field( 244 | default=128, 245 | metadata={"help": ("Encoder hidden size for p-tuning.")}, 246 | ) 247 | active_gpu: int = field( 248 | default=-1, 249 | metadata={"help": ("The index of the active GPU used for training.")} 250 | ) 251 | 252 | def to_dict(self): 253 | output_dict = {} 254 | for key, value in self.__dict__.items(): 255 | output_dict[key] = value 256 | return flatten_dict(output_dict) 257 | 258 | def __post_init__(self): 259 | if self.load_in_8bit and self.load_in_4bit: 260 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 261 | 262 | if isinstance(self.lora_target_modules, list) and len(self.lora_target_modules) == 1: 263 | self.lora_target_modules = self.lora_target_modules[0] 264 | 265 | 266 | def get_peft_config(model_config: ModelConfig, tokenizer: AutoTokenizer) -> "Optional[PeftConfig]": 267 | if 
model_config.use_peft is False: 268 | return None 269 | 270 | model_name = model_config.model_name_or_path.split("/")[-1] 271 | 272 | if model_config.use_lora: 273 | peft_config = LoraConfig( 274 | r=model_config.lora_r, 275 | lora_alpha=model_config.lora_alpha, 276 | lora_dropout=model_config.lora_dropout, 277 | bias="none", 278 | task_type=model_config.task_type, 279 | target_modules=LORA_TARGET_MODULES[model_name], 280 | modules_to_save=model_config.lora_modules_to_save, 281 | ) 282 | elif model_config.use_p_tuning: 283 | peft_config = PromptEncoderConfig( 284 | task_type=model_config.task_type, 285 | num_virtual_tokens=model_config.num_virtual_tokens, 286 | encoder_hidden_size=model_config.encoder_hidden_size, 287 | ) 288 | elif model_config.use_prompt_tuning: 289 | prompt_tuning_init_text = "Generate a Python code that solves the given problem.\n" 290 | peft_config = PromptTuningConfig( 291 | task_type=model_config.task_type, 292 | prompt_tuning_init=PromptTuningInit.TEXT, 293 | num_virtual_tokens=len(tokenizer(prompt_tuning_init_text)["input_ids"]), 294 | prompt_tuning_init_text=prompt_tuning_init_text, 295 | tokenizer_name_or_path=model_config.model_name_or_path, 296 | ) 297 | elif model_config.use_prefix_tuning: 298 | peft_config = PrefixTuningConfig( 299 | task_type=model_config.task_type, 300 | num_virtual_tokens=model_config.num_virtual_tokens 301 | ) 302 | else: 303 | peft_config = None 304 | 305 | return peft_config 306 | --------------------------------------------------------------------------------
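A closing note on the LoRA hyperparameters swept in `lora_sweep.yaml`: in standard LoRA, each adapted weight matrix W of shape (d_out, d_in) gains two low-rank factors, B of shape (d_out, r) and A of shape (r, d_in), so the number of trainable parameters grows linearly with `lora_r`, while `lora_alpha` only rescales the update. The sketch below is a back-of-the-envelope estimate of the "trainable params" figure that `sft.py` logs after training, using assumed layer shapes that are not taken from this repository.

```python
def lora_trainable_params(layer_shapes, r):
    """Sum of r * (d_in + d_out) over every adapted weight matrix."""
    return sum(r * (d_in + d_out) for d_out, d_in in layer_shapes)

# Assumed example: q/k/v/o projections of shape 4096x4096 in a 32-layer model.
shapes = [(4096, 4096)] * 4 * 32
for r in (8, 16, 32, 64):  # the lora_r values swept in lora_sweep.yaml
    print(f"r={r}: {lora_trainable_params(shapes, r):,} trainable parameters")
```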