├── HMT-SiLLM.py ├── HMT-SiLLM.sh ├── HMT_Policy ├── L11_K8.json ├── L2_K4.json ├── L3_K6.json ├── L5_K6.json ├── L7_K6.json └── L9_K8.json ├── Model_Framework(2).pdf ├── README.md ├── SFT.sh ├── SFT_data └── DeEn_data.json ├── Wait-k-SiLLM.py ├── Wait-k-SiLLM.sh ├── finetune.py ├── model.PNG ├── requirements.txt ├── templates ├── README.md ├── Text_translation.json ├── alpaca.json ├── alpaca_legacy.json ├── alpaca_short.json └── vigogne.json ├── test.json └── utils ├── README.md ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── callbacks.cpython-38.pyc └── prompter.cpython-38.pyc ├── callbacks.py └── prompter.py /HMT-SiLLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import fire 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer 9 | from datasets import load_dataset 10 | from utils.callbacks import Iteratorize, Stream 11 | from utils.prompter import Prompter 12 | import json 13 | import time 14 | if torch.cuda.is_available(): 15 | device = "cuda" 16 | else: 17 | device = "cpu" 18 | 19 | try: 20 | if torch.backends.mps.is_available(): 21 | device = "mps" 22 | except: # noqa: E722 23 | pass 24 | 25 | 26 | def main( 27 | load_8bit: bool = False, 28 | base_model: str = "", 29 | lora_weights: str = "tloen/alpaca-lora-7b", 30 | prompt_template: str = "", # The prompt template to use, will default to alpaca. 31 | data_path: str = "", 32 | output_translation_path: str="", 33 | Bottom: int=1, 34 | Top: int=3, 35 | ): 36 | base_model = base_model or os.environ.get("BASE_MODEL", "") 37 | assert ( 38 | base_model 39 | ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'" 40 | 41 | prompter = Prompter(prompt_template) 42 | tokenizer = AutoTokenizer.from_pretrained(base_model) 43 | if device == "cuda": 44 | model = LlamaForCausalLM.from_pretrained( 45 | base_model, 46 | load_in_8bit=load_8bit, 47 | torch_dtype=torch.float16, 48 | device_map="auto", 49 | ) 50 | 51 | model = PeftModel.from_pretrained( 52 | model, 53 | lora_weights, 54 | torch_dtype=torch.float16, 55 | ) 56 | 57 | elif device == "mps": 58 | model = LlamaForCausalLM.from_pretrained( 59 | base_model, 60 | device_map={"": device}, 61 | torch_dtype=torch.float16, 62 | ) 63 | model = PeftModel.from_pretrained( 64 | model, 65 | lora_weights, 66 | device_map={"": device}, 67 | torch_dtype=torch.float16, 68 | ) 69 | else: 70 | model = LlamaForCausalLM.from_pretrained( 71 | base_model, device_map={"": device}, low_cpu_mem_usage=True 72 | ) 73 | model = PeftModel.from_pretrained( 74 | model, 75 | lora_weights, 76 | device_map={"": device}, 77 | ) 78 | 79 | # unwind broken decapoda-research config 80 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 81 | model.config.bos_token_id = 1 82 | model.config.eos_token_id = 2 83 | 84 | if not load_8bit: 85 | model.half() # seems to fix bugs for some users. 
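    # Descriptive note on the inference flow below: evaluate() rebuilds the full prompt at every
    # step from (instruction, revealed source prefix, committed target prefix) and returns the
    # decoded response together with the number of newly generated tokens; HMT_policy() clamps
    # each READ position policy[i] into [Lower + i, Upper + i] (capped at the source length)
    # before calling evaluate() on the corresponding source prefix.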
86 | 87 | model.eval() 88 | if torch.__version__ >= "2" and sys.platform != "win32": 89 | model = torch.compile(model) 90 | 91 | def evaluate( 92 | instruction, 93 | input=None, 94 | output=None, 95 | suppress_tokens=None, 96 | temperature=0.1, 97 | top_p=0.75, 98 | top_k=40, 99 | num_beams=4, 100 | max_new_tokens=128, 101 | stream_output=False, 102 | **kwargs, 103 | ): 104 | prompt = prompter.generate_prompt(instruction, input, output) 105 | inputs = tokenizer(prompt, return_tensors="pt") 106 | input_ids = inputs["input_ids"].to(device) 107 | generation_config = GenerationConfig( 108 | num_beams=num_beams, 109 | suppress_tokens=suppress_tokens, 110 | **kwargs, 111 | ) 112 | 113 | # Without streaming 114 | with torch.no_grad(): 115 | generation_output = model.generate( 116 | input_ids=input_ids, 117 | generation_config=generation_config, 118 | return_dict_in_generate=True, 119 | output_scores=True, 120 | max_new_tokens=max_new_tokens, 121 | ) 122 | s = generation_output.sequences[0] 123 | output = tokenizer.decode(s) 124 | return prompter.get_response(output), s.size(-1) - input_ids.size(-1) 125 | 126 | def HMT_policy( 127 | instruction, 128 | input=None, 129 | policy=[], 130 | Lower=1, 131 | Upper=3, 132 | num_beams=1, 133 | max_new_tokens=256 134 | ): 135 | cur_target_str = "" 136 | tokenized_input = input 137 | i = 0 138 | src_len = len(input.split()) 139 | tmp_max_new_tokens = 1 140 | rw_seq = [] 141 | first_time = True 142 | 143 | tran_tgt_seqLen = len(policy) 144 | supress_tokens = [2] 145 | total_tokens = 0 146 | for i in range(tran_tgt_seqLen): 147 | limited_policy = policy[i] 148 | if policy[i] < Lower+i: 149 | limited_policy = Lower+i 150 | elif policy[i] > Upper+i: 151 | limited_policy = Upper+i 152 | limited_policy = min(limited_policy, src_len) 153 | cut_input = ' '.join(input.split()[:limited_policy]) 154 | tmp_max_new_tokens = 3 155 | if i >= (tran_tgt_seqLen - 1): 156 | tmp_max_new_tokens = max_new_tokens 157 | supress_tokens = None 158 | cur_target_str, tmp_size = evaluate(instruction, cut_input, output=cur_target_str, suppress_tokens=None, num_beams=num_beams, max_new_tokens=tmp_max_new_tokens) 159 | total_tokens += tmp_size 160 | if i < (tran_tgt_seqLen - 1): 161 | cur_target_str = ' '.join(cur_target_str.split()[:i+1]) 162 | rw_seq.append(limited_policy) 163 | if cur_target_str.find('') != -1: 164 | break 165 | else: 166 | tmp_size = len(cur_target_str.split()) - i 167 | rw_seq = rw_seq + [src_len] * tmp_size 168 | 169 | rw_seq.append(src_len) 170 | return rw_seq, cur_target_str, total_tokens 171 | 172 | data = load_dataset("json", data_files=data_path) 173 | test_data = data["train"] 174 | output_text = [] 175 | j = 1 176 | total_generate_tokens = 0 177 | total_generate_words = 0 178 | start_time = time.time() 179 | for item_data in test_data: 180 | print('sample' + str(j)) 181 | j += 1 182 | tmp_result = HMT_policy(item_data["instruction"], item_data["input"], item_data['policy'], Bottom, Top, num_beams=1, max_new_tokens=1024) 183 | total_generate_tokens += tmp_result[2] 184 | total_generate_words += len(tmp_result[1].split(' ')) 185 | index = tmp_result[1].find('\n') 186 | tmp_str = tmp_result[1] 187 | if index!=-1: 188 | tmp_str = tmp_result[1][:index] 189 | output_text.append({'rw': tmp_result[0], 'translation': tmp_str}) 190 | end_time = time.time() 191 | with open(output_translation_path, "w", encoding='utf-8') as fp: 192 | json.dump(output_text, fp, indent=4, ensure_ascii=False) 193 | 194 | print('Total time: '+str(end_time-start_time) + 'Total_words: 
'+str(total_generate_words)) 195 | if __name__ == "__main__": 196 | fire.Fire(main) 197 | 198 | -------------------------------------------------------------------------------- /HMT-SiLLM.sh: -------------------------------------------------------------------------------- 1 | Base_Model=/path/base_model 2 | LoRA_Weithts=/path/LoRA_weights 3 | Output_Translation=/path/output 4 | Test_Data=./HMT_Policy/L2_K4.json 5 | Bottom=1 6 | Top=3 7 | 8 | python HMT-SiLLM.py \ 9 | --base_model ${Base_Model} \ 10 | --lora_weights ${LoRA_Weithts} \ 11 | --prompt_template 'Text_translation' \ 12 | --Bottom ${Bottom} \ 13 | --Top ${Top} \ 14 | --data_path ${Test_Data} \ 15 | --output_translation_path ${Output_Translation} 16 | -------------------------------------------------------------------------------- /Model_Framework(2).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/Model_Framework(2).pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SiLLM 2 | 3 | Source code for our paper "SiLLM: Large Language Models for Simultaneous Machine Translation". 4 | 5 |
 6 | [Figure: the SiLLM model framework (see model.PNG)] 7 | 
8 | 9 | The framework of SiLLM incorporates the LLM to achieve the Simultaneous Machine Translation. It generates the translations under the guidance of the policy decided by the conventional Simultaneous Machine Translation Model. 10 | 11 | Our method is implemented based on the open-source toolkit [Alpaca-LoRA](https://github.com/tloen/alpaca-lora). 12 | 13 | ## Requirements and Installation 14 | 15 | * Python version = 3.8 16 | 17 | * PyTorch version = 2.2 18 | 19 | * Install our library: 20 | 21 | ``` 22 | git clone https://github.com/ictnlp/SiLLM.git 23 | cd SiLLM 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | ## Quick Start 28 | 29 | ### Fine-tune 30 | 31 | We sample 100k data for fine-tuning LLM from WMT15 German-English (download [here](https://www.statmt.org/wmt15)) and MuST-C English-German (download [here](https://mt.fbk.eu/must-c/)), respectively. In the given example, we sample only 50k of data to provide the data format. 32 | 33 | 34 | We perform SFT for WMT15 German-English dataset using the script: 35 | ``` 36 | bash finetune.sh 37 | ``` 38 | 39 | ### Wait-k-SiLLM 40 | We can execute the Wait-k policy with LLM by running the following script: 41 | ``` 42 | bash Wait-k-SiLLM.sh 43 | ``` 44 | 45 | 46 | ### HMT-SiLLM 47 | We can execute the HMT policy with LLM and get the outputs by running the following script: 48 | ``` 49 | bash HMT-SiLLM.sh 50 | ``` 51 | 52 | 53 | ## Citation 54 | ``` 55 | @misc{guo2024sillm, 56 | title={SiLLM: Large Language Models for Simultaneous Machine Translation}, 57 | author={Shoutao Guo and Shaolei Zhang and Zhengrui Ma and Min Zhang and Yang Feng}, 58 | year={2024}, 59 | eprint={2402.13036}, 60 | archivePrefix={arXiv}, 61 | primaryClass={cs.CL} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /SFT.sh: -------------------------------------------------------------------------------- 1 | Base_Model=/path/base_model 2 | LoRA_Weithts=/path/LoRA_weights 3 | Data_File=./SFT_Data 4 | 5 | python finetune.py \ 6 | --base_model ${Base_Model} \ 7 | --data_path ${Data_File} \ 8 | --output_dir ${LoRA_Weithts} \ 9 | --batch_size 128 \ 10 | --micro_batch_size 4 \ 11 | --num_epochs 10 \ 12 | --learning_rate 1e-4 \ 13 | --cutoff_len 1024 \ 14 | --val_set_size 2000 \ 15 | --lora_r 8 \ 16 | --cutoff_len 1024 \ 17 | --lora_alpha 16 \ 18 | --lora_dropout 0.05 \ 19 | --lora_target_modules '[q_proj,k_proj,v_proj,o_proj]' \ 20 | --train_on_inputs \ 21 | --group_by_length \ 22 | --train_on_inputs False \ 23 | --prompt_template_name 'Text_translation' 24 | -------------------------------------------------------------------------------- /Wait-k-SiLLM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import fire 5 | import torch 6 | import transformers 7 | from peft import PeftModel 8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer 9 | from datasets import load_dataset 10 | from utils.callbacks import Iteratorize, Stream 11 | from utils.prompter import Prompter 12 | import json 13 | 14 | if torch.cuda.is_available(): 15 | device = "cuda" 16 | else: 17 | device = "cpu" 18 | 19 | try: 20 | if torch.backends.mps.is_available(): 21 | device = "mps" 22 | except: # noqa: E722 23 | pass 24 | 25 | 26 | def main( 27 | load_8bit: bool = False, 28 | base_model: str = "", 29 | lora_weights: str = "tloen/alpaca-lora-7b", 30 | prompt_template: str = "", # The prompt template to use, will default to 
alpaca. 31 | data_path: str = "", 32 | output_translation_path: str="", 33 | waitk: int=1, 34 | ): 35 | base_model = base_model or os.environ.get("BASE_MODEL", "") 36 | assert ( 37 | base_model 38 | ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'" 39 | 40 | prompter = Prompter(prompt_template) 41 | tokenizer = AutoTokenizer.from_pretrained(base_model) 42 | if device == "cuda": 43 | model = LlamaForCausalLM.from_pretrained( 44 | base_model, 45 | load_in_8bit=load_8bit, 46 | torch_dtype=torch.float16, 47 | device_map="auto", 48 | ) 49 | model = PeftModel.from_pretrained( 50 | model, 51 | lora_weights, 52 | torch_dtype=torch.float16, 53 | ) 54 | elif device == "mps": 55 | model = LlamaForCausalLM.from_pretrained( 56 | base_model, 57 | device_map={"": device}, 58 | torch_dtype=torch.float16, 59 | ) 60 | model = PeftModel.from_pretrained( 61 | model, 62 | lora_weights, 63 | device_map={"": device}, 64 | torch_dtype=torch.float16, 65 | ) 66 | else: 67 | model = LlamaForCausalLM.from_pretrained( 68 | base_model, device_map={"": device}, low_cpu_mem_usage=True 69 | ) 70 | model = PeftModel.from_pretrained( 71 | model, 72 | lora_weights, 73 | device_map={"": device}, 74 | ) 75 | 76 | # unwind broken decapoda-research config 77 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk 78 | model.config.bos_token_id = 1 79 | model.config.eos_token_id = 2 80 | 81 | if not load_8bit: 82 | model.half() # seems to fix bugs for some users. 83 | 84 | model.eval() 85 | if torch.__version__ >= "2" and sys.platform != "win32": 86 | model = torch.compile(model) 87 | 88 | def evaluate( 89 | instruction, 90 | input=None, 91 | output=None, 92 | suppress_tokens=None, 93 | temperature=0.1, 94 | top_p=0.75, 95 | top_k=40, 96 | num_beams=4, 97 | max_new_tokens=128, 98 | stream_output=False, 99 | **kwargs, 100 | ): 101 | prompt = prompter.generate_prompt(instruction, input, output) 102 | inputs = tokenizer(prompt, return_tensors="pt") 103 | input_ids = inputs["input_ids"].to(device) 104 | generation_config = GenerationConfig( 105 | num_beams=num_beams, 106 | suppress_tokens=suppress_tokens, 107 | **kwargs, 108 | ) 109 | 110 | # Without streaming 111 | with torch.no_grad(): 112 | generation_output = model.generate( 113 | input_ids=input_ids, 114 | generation_config=generation_config, 115 | return_dict_in_generate=True, 116 | output_scores=True, 117 | max_new_tokens=max_new_tokens, 118 | ) 119 | s = generation_output.sequences[0] 120 | output = tokenizer.decode(s) 121 | return prompter.get_response(output), s.size(-1) - input_ids.size(-1) 122 | 123 | def Waitk_policy( 124 | instruction, 125 | input=None, 126 | num_beams=1, 127 | waitk=1, 128 | max_new_tokens=256 129 | ): 130 | cur_target_str = "" 131 | tokenized_input = input 132 | i = 0 133 | src_len = len(input.split()) 134 | tmp_max_new_tokens = 1 135 | rw_seq = [] 136 | first_time = True 137 | suppress_tokens=[2] 138 | while (i+waitk <= src_len) or first_time: 139 | cut_input = ' '.join(input.split()[:min(i+waitk, src_len)]) 140 | tmp_max_new_tokens = 5 141 | if i+waitk >= src_len: 142 | tmp_max_new_tokens = max_new_tokens 143 | suppress_tokens=None 144 | cur_target_str, tmp_size = evaluate(instruction, cut_input, output=cur_target_str, suppress_tokens=suppress_tokens, num_beams=num_beams, max_new_tokens=tmp_max_new_tokens) 145 | if i+waitk < src_len: 146 | cur_target_str = ' '.join(cur_target_str.split()[:i+1]) 147 | rw_seq.append(i+waitk) 148 | if cur_target_str.find('') != -1: 149 | break 150 | else: 151 | tmp_size = 
len(cur_target_str.split()) - i 152 | rw_seq = rw_seq + [src_len] * tmp_size 153 | first_time=False 154 | i += 1 155 | rw_seq.append(src_len) 156 | 157 | return rw_seq, cur_target_str 158 | data = load_dataset("json", data_files=data_path) 159 | test_data = data["train"] 160 | output_text = [] 161 | j = 1 162 | for item_data in test_data: 163 | print('sample' + str(j)) 164 | j += 1 165 | tmp_result = Waitk_policy(item_data["instruction"], item_data["input"], num_beams=1, waitk=waitk, max_new_tokens=1024) 166 | index = tmp_result[1].find('\n') 167 | tmp_str = tmp_result[1] 168 | if index!=-1: 169 | tmp_str = tmp_result[1][:index] 170 | output_text.append({'rw': tmp_result[0], 'translation': tmp_str}) 171 | with open(output_translation_path, "w", encoding='utf-8') as fp: 172 | json.dump(output_text, fp, indent=4, ensure_ascii=False) 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | 177 | -------------------------------------------------------------------------------- /Wait-k-SiLLM.sh: -------------------------------------------------------------------------------- 1 | k=11 2 | Base_Model=/path/base_model 3 | LoRA_Weithts=/path/LoRA_weights 4 | Output_Translation=/path/output 5 | Test_Data=./test.json 6 | 7 | python Wait-k-SiLLM.py \ 8 | --base_model ${Base_Model} \ 9 | --lora_weights ${LoRA_Weithts} \ 10 | --prompt_template 'Text_translation' \ 11 | --data_path ${Test_Data} \ 12 | --output_translation_path ${Output_Translation} \ 13 | --waitk ${k} 14 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import List 4 | 5 | import fire 6 | import torch 7 | import transformers 8 | from datasets import load_dataset 9 | 10 | """ 11 | Unused imports: 12 | import torch.nn as nn 13 | import bitsandbytes as bnb 14 | """ 15 | 16 | from peft import ( 17 | LoraConfig, 18 | get_peft_model, 19 | get_peft_model_state_dict, 20 | prepare_model_for_int8_training, 21 | set_peft_model_state_dict, 22 | ) 23 | from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer 24 | 25 | from utils.prompter import Prompter 26 | 27 | 28 | def train( 29 | # model/data params 30 | base_model: str = "", # the only required argument 31 | data_path: str = "yahma/alpaca-cleaned", 32 | output_dir: str = "./lora-alpaca", 33 | # training hyperparams 34 | batch_size: int = 128, 35 | micro_batch_size: int = 4, 36 | num_epochs: int = 3, 37 | learning_rate: float = 3e-4, 38 | cutoff_len: int = 256, 39 | val_set_size: int = 2000, 40 | # lora hyperparams 41 | lora_r: int = 8, 42 | lora_alpha: int = 16, 43 | lora_dropout: float = 0.05, 44 | lora_target_modules: List[str] = [ 45 | "q_proj", 46 | "v_proj", 47 | ], 48 | # llm hyperparams 49 | train_on_inputs: bool = True, # if False, masks out inputs in loss 50 | add_eos_token: bool = False, 51 | group_by_length: bool = False, # faster, but produces an odd training loss curve 52 | # wandb params 53 | wandb_project: str = "", 54 | wandb_run_name: str = "", 55 | wandb_watch: str = "", # options: false | gradients | all 56 | wandb_log_model: str = "", # options: false | true 57 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter 58 | prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca. 
59 | ): 60 | if int(os.environ.get("LOCAL_RANK", 0)) == 0: 61 | print( 62 | f"Training Alpaca-LoRA model with params:\n" 63 | f"base_model: {base_model}\n" 64 | f"data_path: {data_path}\n" 65 | f"output_dir: {output_dir}\n" 66 | f"batch_size: {batch_size}\n" 67 | f"micro_batch_size: {micro_batch_size}\n" 68 | f"num_epochs: {num_epochs}\n" 69 | f"learning_rate: {learning_rate}\n" 70 | f"cutoff_len: {cutoff_len}\n" 71 | f"val_set_size: {val_set_size}\n" 72 | f"lora_r: {lora_r}\n" 73 | f"lora_alpha: {lora_alpha}\n" 74 | f"lora_dropout: {lora_dropout}\n" 75 | f"lora_target_modules: {lora_target_modules}\n" 76 | f"train_on_inputs: {train_on_inputs}\n" 77 | f"add_eos_token: {add_eos_token}\n" 78 | f"group_by_length: {group_by_length}\n" 79 | f"wandb_project: {wandb_project}\n" 80 | f"wandb_run_name: {wandb_run_name}\n" 81 | f"wandb_watch: {wandb_watch}\n" 82 | f"wandb_log_model: {wandb_log_model}\n" 83 | f"resume_from_checkpoint: {resume_from_checkpoint or False}\n" 84 | f"prompt template: {prompt_template_name}\n" 85 | ) 86 | assert ( 87 | base_model 88 | ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'" 89 | gradient_accumulation_steps = batch_size // micro_batch_size 90 | 91 | prompter = Prompter(prompt_template_name) 92 | 93 | device_map = "auto" 94 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 95 | ddp = world_size != 1 96 | if ddp: 97 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 98 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 99 | 100 | # Check if parameter passed or if set within environ 101 | use_wandb = len(wandb_project) > 0 or ( 102 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 103 | ) 104 | # Only overwrite environ if wandb param passed 105 | if len(wandb_project) > 0: 106 | os.environ["WANDB_PROJECT"] = wandb_project 107 | if len(wandb_watch) > 0: 108 | os.environ["WANDB_WATCH"] = wandb_watch 109 | if len(wandb_log_model) > 0: 110 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model 111 | 112 | model = LlamaForCausalLM.from_pretrained( 113 | base_model, 114 | load_in_8bit=True, 115 | torch_dtype=torch.float16, 116 | device_map=device_map, 117 | ) 118 | 119 | tokenizer = AutoTokenizer.from_pretrained(base_model) 120 | 121 | tokenizer.pad_token_id = ( 122 | 0 # unk. 
we want this to be different from the eos token 123 | ) 124 | tokenizer.padding_side = "left" # Allow batched inference 125 | 126 | def tokenize(prompt, add_eos_token=True): 127 | # there's probably a way to do this with the tokenizer settings 128 | # but again, gotta move fast 129 | result = tokenizer( 130 | prompt, 131 | truncation=True, 132 | max_length=cutoff_len, 133 | padding=False, 134 | return_tensors=None, 135 | ) 136 | if ( 137 | result["input_ids"][-1] != tokenizer.eos_token_id 138 | and len(result["input_ids"]) < cutoff_len 139 | and add_eos_token 140 | ): 141 | result["input_ids"].append(tokenizer.eos_token_id) 142 | result["attention_mask"].append(1) 143 | 144 | result["labels"] = result["input_ids"].copy() 145 | 146 | return result 147 | 148 | def generate_and_tokenize_prompt(data_point): 149 | full_prompt = prompter.generate_prompt( 150 | data_point["instruction"], 151 | data_point["input"], 152 | data_point["output"], 153 | ) 154 | tokenized_full_prompt = tokenize(full_prompt) 155 | if not train_on_inputs: 156 | user_prompt = prompter.generate_prompt( 157 | data_point["instruction"], data_point["input"] 158 | ) 159 | tokenized_user_prompt = tokenize( 160 | user_prompt, add_eos_token=add_eos_token 161 | ) 162 | user_prompt_len = len(tokenized_user_prompt["input_ids"]) 163 | 164 | if add_eos_token: 165 | user_prompt_len -= 1 166 | 167 | tokenized_full_prompt["labels"] = [ 168 | -100 169 | ] * user_prompt_len + tokenized_full_prompt["labels"][ 170 | user_prompt_len: 171 | ] # could be sped up, probably 172 | return tokenized_full_prompt 173 | 174 | model = prepare_model_for_int8_training(model) 175 | 176 | config = LoraConfig( 177 | r=lora_r, 178 | lora_alpha=lora_alpha, 179 | target_modules=lora_target_modules, 180 | lora_dropout=lora_dropout, 181 | bias="none", 182 | task_type="CAUSAL_LM", 183 | ) 184 | model = get_peft_model(model, config) 185 | 186 | if data_path.endswith(".json") or data_path.endswith(".jsonl"): 187 | data = load_dataset("json", data_files=data_path) 188 | else: 189 | data = load_dataset(data_path) 190 | 191 | if resume_from_checkpoint: 192 | # Check the available weights and load them 193 | checkpoint_name = os.path.join( 194 | resume_from_checkpoint, "pytorch_model.bin" 195 | ) # Full checkpoint 196 | if not os.path.exists(checkpoint_name): 197 | checkpoint_name = os.path.join( 198 | resume_from_checkpoint, "adapter_model.bin" 199 | ) # only LoRA model - LoRA config above has to fit 200 | resume_from_checkpoint = ( 201 | False # So the trainer won't try loading its state 202 | ) 203 | # The two files above have a different name depending on how they were saved, but are actually the same. 204 | if os.path.exists(checkpoint_name): 205 | print(f"Restarting from {checkpoint_name}") 206 | adapters_weights = torch.load(checkpoint_name) 207 | set_peft_model_state_dict(model, adapters_weights) 208 | else: 209 | print(f"Checkpoint {checkpoint_name} not found") 210 | 211 | model.print_trainable_parameters() # Be more transparent about the % of trainable params. 
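    # Descriptive note on the training flow below: when val_set_size > 0 the dataset is split with
    # a fixed seed (42) into train/validation sets, both mapped through generate_and_tokenize_prompt,
    # and training runs through transformers.Trainer with gradient accumulation
    # (batch_size // micro_batch_size) so the effective global batch size matches batch_size.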
212 | 213 | if val_set_size > 0: 214 | train_val = data["train"].train_test_split( 215 | test_size=val_set_size, shuffle=True, seed=42 216 | ) 217 | train_data = ( 218 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 219 | ) 220 | val_data = ( 221 | train_val["test"].shuffle().map(generate_and_tokenize_prompt) 222 | ) 223 | else: 224 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) 225 | val_data = None 226 | 227 | if not ddp and torch.cuda.device_count() > 1: 228 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available 229 | model.is_parallelizable = True 230 | model.model_parallel = True 231 | 232 | trainer = transformers.Trainer( 233 | model=model, 234 | train_dataset=train_data, 235 | eval_dataset=val_data, 236 | args=transformers.TrainingArguments( 237 | per_device_train_batch_size=micro_batch_size, 238 | gradient_accumulation_steps=gradient_accumulation_steps, 239 | warmup_steps=100, 240 | num_train_epochs=num_epochs, 241 | learning_rate=learning_rate, 242 | fp16=True, 243 | logging_steps=10, 244 | optim="adamw_torch", 245 | evaluation_strategy="steps" if val_set_size > 0 else "no", 246 | save_strategy="steps", 247 | eval_steps=200 if val_set_size > 0 else None, 248 | save_steps=200, 249 | output_dir=output_dir, 250 | save_total_limit=3, 251 | load_best_model_at_end=True if val_set_size > 0 else False, 252 | ddp_find_unused_parameters=False if ddp else None, 253 | group_by_length=group_by_length, 254 | report_to="wandb" if use_wandb else None, 255 | run_name=wandb_run_name if use_wandb else None, 256 | save_safetensors=False 257 | ), 258 | data_collator=transformers.DataCollatorForSeq2Seq( 259 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 260 | ), 261 | ) 262 | model.config.use_cache = False 263 | 264 | ''' 265 | old_state_dict = model.state_dict 266 | model.state_dict = ( 267 | lambda self, *_, **__: get_peft_model_state_dict( 268 | self, old_state_dict() 269 | ) 270 | ).__get__(model, type(model)) 271 | ''' 272 | 273 | if torch.__version__ >= "2" and sys.platform != "win32": 274 | model = torch.compile(model) 275 | 276 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 277 | 278 | model.save_pretrained(output_dir) 279 | 280 | print( 281 | "\n If there's a warning about missing keys above, please disregard :)" 282 | ) 283 | 284 | 285 | if __name__ == "__main__": 286 | fire.Fire(train) 287 | -------------------------------------------------------------------------------- /model.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/model.PNG -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | appdirs 3 | loralib 4 | bitsandbytes 5 | black 6 | black[jupyter] 7 | datasets 8 | fire 9 | git+https://github.com/huggingface/peft.git 10 | transformers>=4.28.0 11 | sentencepiece 12 | gradio -------------------------------------------------------------------------------- /templates/README.md: -------------------------------------------------------------------------------- 1 | # Prompt templates 2 | 3 | This directory contains template styles for the prompts used to finetune LoRA models. 
4 | 5 | ## Format 6 | 7 | A template is described via a JSON file with the following keys: 8 | 9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders. 10 | - `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders. 11 | - `description`: A short description of the template, with possible use cases. 12 | - `response_split`: The text to use as separator when cutting real response from the model output. 13 | 14 | No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest. 15 | 16 | ## Example template 17 | 18 | The default template, used unless otherwise specified, is `alpaca.json` 19 | 20 | ```json 21 | { 22 | "description": "Template used by Alpaca-LoRA.", 23 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 24 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 25 | "response_split": "### Response:" 26 | } 27 | 28 | ``` 29 | 30 | ## Current templates 31 | 32 | ### alpaca 33 | 34 | Default template used for generic LoRA fine tunes so far. 35 | 36 | ### alpaca_legacy 37 | 38 | Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments. 39 | 40 | ### alpaca_short 41 | 42 | A trimmed down alpaca template which seems to perform just as well and spare some tokens. Models created with the default template seem to be queryable by the short tempalte as well. More experiments are welcome. 43 | 44 | ### vigogne 45 | 46 | The default alpaca template, translated to french. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine tuning. 47 | -------------------------------------------------------------------------------- /templates/Text_translation.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template for Text Machine Translation.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### source sentence:\n{input}\n\n### target sentence:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### target sentence:\n", 5 | "response_split": "### target sentence:" 6 | } -------------------------------------------------------------------------------- /templates/alpaca.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Alpaca-LoRA.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/alpaca_legacy.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Legacy template, used by Original Alpaca repository.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/alpaca_short.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "A shorter template to experiment with.", 3 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /templates/vigogne.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "French template, used by Vigogne for finetuning.", 3 | "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n", 4 | "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n", 5 | "response_split": "### Réponse:" 6 | } 7 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Directory for helpers modules 2 | 3 | ## prompter.py 4 | 5 | Prompter class, a template manager. 6 | 7 | `from utils.prompter import Prompter` 8 | 9 | ## callbacks.py 10 | 11 | Helpers to support streaming generate output. 
12 | 13 | `from utils.callbacks import Iteratorize, Stream` 14 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__pycache__/callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/prompter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__pycache__/prompter.cpython-38.pyc -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to support streaming generate output. 3 | Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py 4 | """ 5 | 6 | import gc 7 | import traceback 8 | from queue import Queue 9 | from threading import Thread 10 | 11 | import torch 12 | import transformers 13 | 14 | 15 | class Stream(transformers.StoppingCriteria): 16 | def __init__(self, callback_func=None): 17 | self.callback_func = callback_func 18 | 19 | def __call__(self, input_ids, scores) -> bool: 20 | if self.callback_func is not None: 21 | self.callback_func(input_ids[0]) 22 | return False 23 | 24 | 25 | class Iteratorize: 26 | 27 | """ 28 | Transforms a function that takes a callback 29 | into a lazy iterator (generator). 
30 | """ 31 | 32 | def __init__(self, func, kwargs={}, callback=None): 33 | self.mfunc = func 34 | self.c_callback = callback 35 | self.q = Queue() 36 | self.sentinel = object() 37 | self.kwargs = kwargs 38 | self.stop_now = False 39 | 40 | def _callback(val): 41 | if self.stop_now: 42 | raise ValueError 43 | self.q.put(val) 44 | 45 | def gentask(): 46 | try: 47 | ret = self.mfunc(callback=_callback, **self.kwargs) 48 | except ValueError: 49 | pass 50 | except: 51 | traceback.print_exc() 52 | pass 53 | 54 | self.q.put(self.sentinel) 55 | if self.c_callback: 56 | self.c_callback(ret) 57 | 58 | self.thread = Thread(target=gentask) 59 | self.thread.start() 60 | 61 | def __iter__(self): 62 | return self 63 | 64 | def __next__(self): 65 | obj = self.q.get(True, None) 66 | if obj is self.sentinel: 67 | raise StopIteration 68 | else: 69 | return obj 70 | 71 | def __enter__(self): 72 | return self 73 | 74 | def __exit__(self, exc_type, exc_val, exc_tb): 75 | self.stop_now = True 76 | -------------------------------------------------------------------------------- /utils/prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 17 | template_name = "alpaca" 18 | file_name = osp.join("templates", f"{template_name}.json") 19 | if not osp.exists(file_name): 20 | raise ValueError(f"Can't read {file_name}") 21 | with open(file_name) as fp: 22 | self.template = json.load(fp) 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | ) -> str: 34 | # returns the full prompt from instruction and optional input 35 | # if a label (=response, =output) is provided, it's also appended. 36 | if input: 37 | res = self.template["prompt_input"].format( 38 | instruction=instruction, input=input 39 | ) 40 | else: 41 | res = self.template["prompt_no_input"].format( 42 | instruction=instruction 43 | ) 44 | if label: 45 | res = f"{res}{label}" 46 | if self._verbose: 47 | print(res) 48 | return res 49 | 50 | def get_response(self, output: str) -> str: 51 | return output.split(self.template["response_split"])[1].strip() 52 | --------------------------------------------------------------------------------