├── HMT-SiLLM.py
├── HMT-SiLLM.sh
├── HMT_Policy
│   ├── L11_K8.json
│   ├── L2_K4.json
│   ├── L3_K6.json
│   ├── L5_K6.json
│   ├── L7_K6.json
│   └── L9_K8.json
├── Model_Framework(2).pdf
├── README.md
├── SFT.sh
├── SFT_data
│   └── DeEn_data.json
├── Wait-k-SiLLM.py
├── Wait-k-SiLLM.sh
├── finetune.py
├── model.PNG
├── requirements.txt
├── templates
│   ├── README.md
│   ├── Text_translation.json
│   ├── alpaca.json
│   ├── alpaca_legacy.json
│   ├── alpaca_short.json
│   └── vigogne.json
├── test.json
└── utils
    ├── README.md
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── callbacks.cpython-38.pyc
    │   └── prompter.cpython-38.pyc
    ├── callbacks.py
    └── prompter.py
/HMT-SiLLM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pdb
4 | import fire
5 | import torch
6 | import transformers
7 | from peft import PeftModel
8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer
9 | from datasets import load_dataset
10 | from utils.callbacks import Iteratorize, Stream
11 | from utils.prompter import Prompter
12 | import json
13 | import time
14 | if torch.cuda.is_available():
15 | device = "cuda"
16 | else:
17 | device = "cpu"
18 |
19 | try:
20 | if torch.backends.mps.is_available():
21 | device = "mps"
22 | except: # noqa: E722
23 | pass
24 |
25 |
26 | def main(
27 | load_8bit: bool = False,
28 | base_model: str = "",
29 | lora_weights: str = "tloen/alpaca-lora-7b",
30 | prompt_template: str = "", # The prompt template to use, will default to alpaca.
31 | data_path: str = "",
32 | output_translation_path: str="",
33 | Bottom: int=1,
34 | Top: int=3,
35 | ):
36 | base_model = base_model or os.environ.get("BASE_MODEL", "")
37 | assert (
38 | base_model
39 | ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
40 |
41 | prompter = Prompter(prompt_template)
42 | tokenizer = AutoTokenizer.from_pretrained(base_model)
43 | if device == "cuda":
44 | model = LlamaForCausalLM.from_pretrained(
45 | base_model,
46 | load_in_8bit=load_8bit,
47 | torch_dtype=torch.float16,
48 | device_map="auto",
49 | )
50 |
51 | model = PeftModel.from_pretrained(
52 | model,
53 | lora_weights,
54 | torch_dtype=torch.float16,
55 | )
56 |
57 | elif device == "mps":
58 | model = LlamaForCausalLM.from_pretrained(
59 | base_model,
60 | device_map={"": device},
61 | torch_dtype=torch.float16,
62 | )
63 | model = PeftModel.from_pretrained(
64 | model,
65 | lora_weights,
66 | device_map={"": device},
67 | torch_dtype=torch.float16,
68 | )
69 | else:
70 | model = LlamaForCausalLM.from_pretrained(
71 | base_model, device_map={"": device}, low_cpu_mem_usage=True
72 | )
73 | model = PeftModel.from_pretrained(
74 | model,
75 | lora_weights,
76 | device_map={"": device},
77 | )
78 |
79 | # unwind broken decapoda-research config
80 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
81 | model.config.bos_token_id = 1
82 | model.config.eos_token_id = 2
83 |
84 | if not load_8bit:
85 | model.half() # seems to fix bugs for some users.
86 |
87 | model.eval()
88 | if torch.__version__ >= "2" and sys.platform != "win32":
89 | model = torch.compile(model)
90 |
91 | def evaluate(
92 | instruction,
93 | input=None,
94 | output=None,
95 | suppress_tokens=None,
96 | temperature=0.1,
97 | top_p=0.75,
98 | top_k=40,
99 | num_beams=4,
100 | max_new_tokens=128,
101 | stream_output=False,
102 | **kwargs,
103 | ):
104 | prompt = prompter.generate_prompt(instruction, input, output)
105 | inputs = tokenizer(prompt, return_tensors="pt")
106 | input_ids = inputs["input_ids"].to(device)
107 | generation_config = GenerationConfig(
108 | num_beams=num_beams,
109 | suppress_tokens=suppress_tokens,
110 | **kwargs,
111 | )
112 |
113 | # Without streaming
114 | with torch.no_grad():
115 | generation_output = model.generate(
116 | input_ids=input_ids,
117 | generation_config=generation_config,
118 | return_dict_in_generate=True,
119 | output_scores=True,
120 | max_new_tokens=max_new_tokens,
121 | )
122 | s = generation_output.sequences[0]
123 | output = tokenizer.decode(s)
124 | return prompter.get_response(output), s.size(-1) - input_ids.size(-1)
125 |
126 | def HMT_policy(
127 | instruction,
128 | input=None,
129 | policy=[],
130 | Lower=1,
131 | Upper=3,
132 | num_beams=1,
133 | max_new_tokens=256
134 | ):
135 | cur_target_str = ""
136 | tokenized_input = input
137 | i = 0
138 | src_len = len(input.split())
139 | tmp_max_new_tokens = 1
140 | rw_seq = []
141 | first_time = True
142 |
143 | tran_tgt_seqLen = len(policy)
144 |         suppress_tokens = [2]
145 | total_tokens = 0
146 | for i in range(tran_tgt_seqLen):
147 | limited_policy = policy[i]
148 | if policy[i] < Lower+i:
149 | limited_policy = Lower+i
150 | elif policy[i] > Upper+i:
151 | limited_policy = Upper+i
152 | limited_policy = min(limited_policy, src_len)
153 | cut_input = ' '.join(input.split()[:limited_policy])
154 | tmp_max_new_tokens = 3
155 | if i >= (tran_tgt_seqLen - 1):
156 | tmp_max_new_tokens = max_new_tokens
157 |                 suppress_tokens = None
158 |             cur_target_str, tmp_size = evaluate(instruction, cut_input, output=cur_target_str, suppress_tokens=suppress_tokens, num_beams=num_beams, max_new_tokens=tmp_max_new_tokens)
159 | total_tokens += tmp_size
160 | if i < (tran_tgt_seqLen - 1):
161 | cur_target_str = ' '.join(cur_target_str.split()[:i+1])
162 | rw_seq.append(limited_policy)
163 |                 if cur_target_str.find('</s>') != -1:
164 | break
165 | else:
166 | tmp_size = len(cur_target_str.split()) - i
167 | rw_seq = rw_seq + [src_len] * tmp_size
168 |
169 | rw_seq.append(src_len)
170 | return rw_seq, cur_target_str, total_tokens
171 |
172 | data = load_dataset("json", data_files=data_path)
173 | test_data = data["train"]
174 | output_text = []
175 | j = 1
176 | total_generate_tokens = 0
177 | total_generate_words = 0
178 | start_time = time.time()
179 | for item_data in test_data:
180 | print('sample' + str(j))
181 | j += 1
182 | tmp_result = HMT_policy(item_data["instruction"], item_data["input"], item_data['policy'], Bottom, Top, num_beams=1, max_new_tokens=1024)
183 | total_generate_tokens += tmp_result[2]
184 | total_generate_words += len(tmp_result[1].split(' '))
185 | index = tmp_result[1].find('\n')
186 | tmp_str = tmp_result[1]
187 | if index!=-1:
188 | tmp_str = tmp_result[1][:index]
189 | output_text.append({'rw': tmp_result[0], 'translation': tmp_str})
190 | end_time = time.time()
191 | with open(output_translation_path, "w", encoding='utf-8') as fp:
192 | json.dump(output_text, fp, indent=4, ensure_ascii=False)
193 |
194 |     print('Total time: ' + str(end_time - start_time) + ' Total_words: ' + str(total_generate_words))
195 | if __name__ == "__main__":
196 | fire.Fire(main)
197 |
198 |
--------------------------------------------------------------------------------
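
For reference, a minimal standalone sketch of the read-position clamping performed inside `HMT_policy` above: the externally supplied policy value for target step `i` is clipped to the window `[Lower + i, Upper + i]` (the `--Bottom`/`--Top` arguments) and to the source length. The helper name is hypothetical and the model call is omitted; this only restates the clamping logic of `HMT-SiLLM.py`.

```
def clamp_policy(policy, src_len, lower=1, upper=3):
    """Clip each read position to the [lower + i, upper + i] window and to src_len."""
    clamped = []
    for i, p in enumerate(policy):
        p = max(p, lower + i)            # read at least lower + i source words at step i
        p = min(p, upper + i)            # read at most upper + i source words at step i
        clamped.append(min(p, src_len))  # never read beyond the end of the source
    return clamped

# clamp_policy([1, 5, 9], src_len=6) -> [1, 4, 5]
```
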
/HMT-SiLLM.sh:
--------------------------------------------------------------------------------
1 | Base_Model=/path/base_model
2 | LoRA_Weights=/path/LoRA_weights
3 | Output_Translation=/path/output
4 | Test_Data=./HMT_Policy/L2_K4.json
5 | Bottom=1
6 | Top=3
7 |
8 | python HMT-SiLLM.py \
9 | --base_model ${Base_Model} \
10 | --lora_weights ${LoRA_Weights} \
11 | --prompt_template 'Text_translation' \
12 | --Bottom ${Bottom} \
13 | --Top ${Top} \
14 | --data_path ${Test_Data} \
15 | --output_translation_path ${Output_Translation}
16 |
--------------------------------------------------------------------------------
/Model_Framework(2).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/Model_Framework(2).pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SiLLM
2 |
3 | Source code for our paper "SiLLM: Large Language Models for Simultaneous Machine Translation".
4 |
5 |
6 |
7 |
8 |
9 | The SiLLM framework incorporates an LLM to perform simultaneous machine translation. It generates translations under the guidance of a policy decided by a conventional simultaneous machine translation model.
10 |
11 | Our method is implemented based on the open-source toolkit [Alpaca-LoRA](https://github.com/tloen/alpaca-lora).
12 |
13 | ## Requirements and Installation
14 |
15 | * Python version = 3.8
16 |
17 | * PyTorch version = 2.2
18 |
19 | * Install our library:
20 |
21 | ```
22 | git clone https://github.com/ictnlp/SiLLM.git
23 | cd SiLLM
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ## Quick Start
28 |
29 | ### Fine-tune
30 |
31 | We sample 100k sentence pairs each from WMT15 German-English (download [here](https://www.statmt.org/wmt15)) and MuST-C English-German (download [here](https://mt.fbk.eu/must-c/)) for fine-tuning the LLM. The included example contains only 50k pairs, to illustrate the data format.
32 |
33 |
34 | We perform SFT on the WMT15 German-English dataset using the script:
35 | ```
36 | bash SFT.sh
37 | ```
38 |
39 | ### Wait-k-SiLLM
40 | We can execute the Wait-k policy with the LLM by running the following script:
41 | ```
42 | bash Wait-k-SiLLM.sh
43 | ```
44 |
45 |
46 | ### HMT-SiLLM
47 | We can execute the HMT policy with the LLM and obtain the outputs by running the following script:
48 | ```
49 | bash HMT-SiLLM.sh
50 | ```
51 |
52 |
53 | ## Citation
54 | ```
55 | @misc{guo2024sillm,
56 | title={SiLLM: Large Language Models for Simultaneous Machine Translation},
57 | author={Shoutao Guo and Shaolei Zhang and Zhengrui Ma and Min Zhang and Yang Feng},
58 | year={2024},
59 | eprint={2402.13036},
60 | archivePrefix={arXiv},
61 | primaryClass={cs.CL}
62 | }
63 | ```
64 |
--------------------------------------------------------------------------------
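
As a rough guide to the data files referenced above: the field names below are taken from how `finetune.py`, `Wait-k-SiLLM.py` and `HMT-SiLLM.py` index each record, while the values are placeholders, so treat this as a sketch and consult `SFT_data/DeEn_data.json` and `HMT_Policy/*.json` for the real data.

```
# Hypothetical records, for illustration only.

# SFT record consumed by finetune.py:
sft_record = {
    "instruction": "...",  # translation instruction
    "input": "...",        # source sentence
    "output": "...",       # reference translation
}

# Test record consumed by HMT-SiLLM.py; "policy" gives, for each target word,
# how many source words have been read when that word is generated:
hmt_record = {
    "instruction": "...",
    "input": "...",
    "policy": [1, 3, 4, 6],
}
```
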
/SFT.sh:
--------------------------------------------------------------------------------
1 | Base_Model=/path/base_model
2 | LoRA_Weights=/path/LoRA_weights
3 | Data_File=./SFT_data/DeEn_data.json
4 |
5 | python finetune.py \
6 | --base_model ${Base_Model} \
7 | --data_path ${Data_File} \
8 | --output_dir ${LoRA_Weights} \
9 | --batch_size 128 \
10 | --micro_batch_size 4 \
11 | --num_epochs 10 \
12 | --learning_rate 1e-4 \
13 | --cutoff_len 1024 \
14 | --val_set_size 2000 \
15 | --lora_r 8 \
16 | --lora_alpha 16 \
17 | --lora_dropout 0.05 \
18 | --lora_target_modules '[q_proj,k_proj,v_proj,o_proj]' \
19 | --group_by_length \
20 | --train_on_inputs False \
21 | --prompt_template_name 'Text_translation'
22 |
--------------------------------------------------------------------------------
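
A quick check of the batch arithmetic that `finetune.py` applies to the settings above: with `batch_size=128` and `micro_batch_size=4`, each optimizer step accumulates gradients over 32 micro-batches, and under multi-GPU DDP that count is further divided by the world size.

```
batch_size = 128
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size  # 32
# e.g. with WORLD_SIZE=4 under DDP: 32 // 4 = 8 accumulation steps per device
```
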
/Wait-k-SiLLM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pdb
4 | import fire
5 | import torch
6 | import transformers
7 | from peft import PeftModel
8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer
9 | from datasets import load_dataset
10 | from utils.callbacks import Iteratorize, Stream
11 | from utils.prompter import Prompter
12 | import json
13 |
14 | if torch.cuda.is_available():
15 | device = "cuda"
16 | else:
17 | device = "cpu"
18 |
19 | try:
20 | if torch.backends.mps.is_available():
21 | device = "mps"
22 | except: # noqa: E722
23 | pass
24 |
25 |
26 | def main(
27 | load_8bit: bool = False,
28 | base_model: str = "",
29 | lora_weights: str = "tloen/alpaca-lora-7b",
30 | prompt_template: str = "", # The prompt template to use, will default to alpaca.
31 | data_path: str = "",
32 | output_translation_path: str="",
33 | waitk: int=1,
34 | ):
35 | base_model = base_model or os.environ.get("BASE_MODEL", "")
36 | assert (
37 | base_model
38 | ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
39 |
40 | prompter = Prompter(prompt_template)
41 | tokenizer = AutoTokenizer.from_pretrained(base_model)
42 | if device == "cuda":
43 | model = LlamaForCausalLM.from_pretrained(
44 | base_model,
45 | load_in_8bit=load_8bit,
46 | torch_dtype=torch.float16,
47 | device_map="auto",
48 | )
49 | model = PeftModel.from_pretrained(
50 | model,
51 | lora_weights,
52 | torch_dtype=torch.float16,
53 | )
54 | elif device == "mps":
55 | model = LlamaForCausalLM.from_pretrained(
56 | base_model,
57 | device_map={"": device},
58 | torch_dtype=torch.float16,
59 | )
60 | model = PeftModel.from_pretrained(
61 | model,
62 | lora_weights,
63 | device_map={"": device},
64 | torch_dtype=torch.float16,
65 | )
66 | else:
67 | model = LlamaForCausalLM.from_pretrained(
68 | base_model, device_map={"": device}, low_cpu_mem_usage=True
69 | )
70 | model = PeftModel.from_pretrained(
71 | model,
72 | lora_weights,
73 | device_map={"": device},
74 | )
75 |
76 | # unwind broken decapoda-research config
77 | model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
78 | model.config.bos_token_id = 1
79 | model.config.eos_token_id = 2
80 |
81 | if not load_8bit:
82 | model.half() # seems to fix bugs for some users.
83 |
84 | model.eval()
85 | if torch.__version__ >= "2" and sys.platform != "win32":
86 | model = torch.compile(model)
87 |
88 | def evaluate(
89 | instruction,
90 | input=None,
91 | output=None,
92 | suppress_tokens=None,
93 | temperature=0.1,
94 | top_p=0.75,
95 | top_k=40,
96 | num_beams=4,
97 | max_new_tokens=128,
98 | stream_output=False,
99 | **kwargs,
100 | ):
101 | prompt = prompter.generate_prompt(instruction, input, output)
102 | inputs = tokenizer(prompt, return_tensors="pt")
103 | input_ids = inputs["input_ids"].to(device)
104 | generation_config = GenerationConfig(
105 | num_beams=num_beams,
106 | suppress_tokens=suppress_tokens,
107 | **kwargs,
108 | )
109 |
110 | # Without streaming
111 | with torch.no_grad():
112 | generation_output = model.generate(
113 | input_ids=input_ids,
114 | generation_config=generation_config,
115 | return_dict_in_generate=True,
116 | output_scores=True,
117 | max_new_tokens=max_new_tokens,
118 | )
119 | s = generation_output.sequences[0]
120 | output = tokenizer.decode(s)
121 | return prompter.get_response(output), s.size(-1) - input_ids.size(-1)
122 |
123 | def Waitk_policy(
124 | instruction,
125 | input=None,
126 | num_beams=1,
127 | waitk=1,
128 | max_new_tokens=256
129 | ):
130 | cur_target_str = ""
131 | tokenized_input = input
132 | i = 0
133 | src_len = len(input.split())
134 | tmp_max_new_tokens = 1
135 | rw_seq = []
136 | first_time = True
137 | suppress_tokens=[2]
138 | while (i+waitk <= src_len) or first_time:
139 | cut_input = ' '.join(input.split()[:min(i+waitk, src_len)])
140 | tmp_max_new_tokens = 5
141 | if i+waitk >= src_len:
142 | tmp_max_new_tokens = max_new_tokens
143 | suppress_tokens=None
144 | cur_target_str, tmp_size = evaluate(instruction, cut_input, output=cur_target_str, suppress_tokens=suppress_tokens, num_beams=num_beams, max_new_tokens=tmp_max_new_tokens)
145 | if i+waitk < src_len:
146 | cur_target_str = ' '.join(cur_target_str.split()[:i+1])
147 | rw_seq.append(i+waitk)
148 |                 if cur_target_str.find('</s>') != -1:
149 | break
150 | else:
151 | tmp_size = len(cur_target_str.split()) - i
152 | rw_seq = rw_seq + [src_len] * tmp_size
153 | first_time=False
154 | i += 1
155 | rw_seq.append(src_len)
156 |
157 | return rw_seq, cur_target_str
158 | data = load_dataset("json", data_files=data_path)
159 | test_data = data["train"]
160 | output_text = []
161 | j = 1
162 | for item_data in test_data:
163 | print('sample' + str(j))
164 | j += 1
165 | tmp_result = Waitk_policy(item_data["instruction"], item_data["input"], num_beams=1, waitk=waitk, max_new_tokens=1024)
166 | index = tmp_result[1].find('\n')
167 | tmp_str = tmp_result[1]
168 | if index!=-1:
169 | tmp_str = tmp_result[1][:index]
170 | output_text.append({'rw': tmp_result[0], 'translation': tmp_str})
171 | with open(output_translation_path, "w", encoding='utf-8') as fp:
172 | json.dump(output_text, fp, indent=4, ensure_ascii=False)
173 |
174 | if __name__ == "__main__":
175 | fire.Fire(main)
176 |
177 |
--------------------------------------------------------------------------------
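
For intuition, a minimal sketch of the read schedule behind `Waitk_policy` (the function name is hypothetical): at target step i the model sees the first min(i + k, src_len) source words. The actual script also appends trailing `src_len` entries for target words produced after the whole source has been read, which this sketch omits.

```
def waitk_read_schedule(src_len, k):
    """Source words visible when each of the first target words is committed."""
    schedule = []
    i = 0
    while i + k <= src_len or i == 0:  # mirrors the loop condition in Waitk_policy
        schedule.append(min(i + k, src_len))
        i += 1
    return schedule

# waitk_read_schedule(src_len=6, k=3) -> [3, 4, 5, 6]
```
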
/Wait-k-SiLLM.sh:
--------------------------------------------------------------------------------
1 | k=11
2 | Base_Model=/path/base_model
3 | LoRA_Weights=/path/LoRA_weights
4 | Output_Translation=/path/output
5 | Test_Data=./test.json
6 |
7 | python Wait-k-SiLLM.py \
8 | --base_model ${Base_Model} \
9 | --lora_weights ${LoRA_Weights} \
10 | --prompt_template 'Text_translation' \
11 | --data_path ${Test_Data} \
12 | --output_translation_path ${Output_Translation} \
13 | --waitk ${k}
14 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List
4 |
5 | import fire
6 | import torch
7 | import transformers
8 | from datasets import load_dataset
9 |
10 | """
11 | Unused imports:
12 | import torch.nn as nn
13 | import bitsandbytes as bnb
14 | """
15 |
16 | from peft import (
17 | LoraConfig,
18 | get_peft_model,
19 | get_peft_model_state_dict,
20 | prepare_model_for_int8_training,
21 | set_peft_model_state_dict,
22 | )
23 | from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer
24 |
25 | from utils.prompter import Prompter
26 |
27 |
28 | def train(
29 | # model/data params
30 | base_model: str = "", # the only required argument
31 | data_path: str = "yahma/alpaca-cleaned",
32 | output_dir: str = "./lora-alpaca",
33 | # training hyperparams
34 | batch_size: int = 128,
35 | micro_batch_size: int = 4,
36 | num_epochs: int = 3,
37 | learning_rate: float = 3e-4,
38 | cutoff_len: int = 256,
39 | val_set_size: int = 2000,
40 | # lora hyperparams
41 | lora_r: int = 8,
42 | lora_alpha: int = 16,
43 | lora_dropout: float = 0.05,
44 | lora_target_modules: List[str] = [
45 | "q_proj",
46 | "v_proj",
47 | ],
48 | # llm hyperparams
49 | train_on_inputs: bool = True, # if False, masks out inputs in loss
50 | add_eos_token: bool = False,
51 | group_by_length: bool = False, # faster, but produces an odd training loss curve
52 | # wandb params
53 | wandb_project: str = "",
54 | wandb_run_name: str = "",
55 | wandb_watch: str = "", # options: false | gradients | all
56 | wandb_log_model: str = "", # options: false | true
57 | resume_from_checkpoint: str = None, # either training checkpoint or final adapter
58 | prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
59 | ):
60 | if int(os.environ.get("LOCAL_RANK", 0)) == 0:
61 | print(
62 | f"Training Alpaca-LoRA model with params:\n"
63 | f"base_model: {base_model}\n"
64 | f"data_path: {data_path}\n"
65 | f"output_dir: {output_dir}\n"
66 | f"batch_size: {batch_size}\n"
67 | f"micro_batch_size: {micro_batch_size}\n"
68 | f"num_epochs: {num_epochs}\n"
69 | f"learning_rate: {learning_rate}\n"
70 | f"cutoff_len: {cutoff_len}\n"
71 | f"val_set_size: {val_set_size}\n"
72 | f"lora_r: {lora_r}\n"
73 | f"lora_alpha: {lora_alpha}\n"
74 | f"lora_dropout: {lora_dropout}\n"
75 | f"lora_target_modules: {lora_target_modules}\n"
76 | f"train_on_inputs: {train_on_inputs}\n"
77 | f"add_eos_token: {add_eos_token}\n"
78 | f"group_by_length: {group_by_length}\n"
79 | f"wandb_project: {wandb_project}\n"
80 | f"wandb_run_name: {wandb_run_name}\n"
81 | f"wandb_watch: {wandb_watch}\n"
82 | f"wandb_log_model: {wandb_log_model}\n"
83 | f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
84 | f"prompt template: {prompt_template_name}\n"
85 | )
86 | assert (
87 | base_model
88 | ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
89 | gradient_accumulation_steps = batch_size // micro_batch_size
90 |
91 | prompter = Prompter(prompt_template_name)
92 |
93 | device_map = "auto"
94 | world_size = int(os.environ.get("WORLD_SIZE", 1))
95 | ddp = world_size != 1
96 | if ddp:
97 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
98 | gradient_accumulation_steps = gradient_accumulation_steps // world_size
99 |
100 | # Check if parameter passed or if set within environ
101 | use_wandb = len(wandb_project) > 0 or (
102 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
103 | )
104 | # Only overwrite environ if wandb param passed
105 | if len(wandb_project) > 0:
106 | os.environ["WANDB_PROJECT"] = wandb_project
107 | if len(wandb_watch) > 0:
108 | os.environ["WANDB_WATCH"] = wandb_watch
109 | if len(wandb_log_model) > 0:
110 | os.environ["WANDB_LOG_MODEL"] = wandb_log_model
111 |
112 | model = LlamaForCausalLM.from_pretrained(
113 | base_model,
114 | load_in_8bit=True,
115 | torch_dtype=torch.float16,
116 | device_map=device_map,
117 | )
118 |
119 | tokenizer = AutoTokenizer.from_pretrained(base_model)
120 |
121 | tokenizer.pad_token_id = (
122 | 0 # unk. we want this to be different from the eos token
123 | )
124 | tokenizer.padding_side = "left" # Allow batched inference
125 |
126 | def tokenize(prompt, add_eos_token=True):
127 | # there's probably a way to do this with the tokenizer settings
128 | # but again, gotta move fast
129 | result = tokenizer(
130 | prompt,
131 | truncation=True,
132 | max_length=cutoff_len,
133 | padding=False,
134 | return_tensors=None,
135 | )
136 | if (
137 | result["input_ids"][-1] != tokenizer.eos_token_id
138 | and len(result["input_ids"]) < cutoff_len
139 | and add_eos_token
140 | ):
141 | result["input_ids"].append(tokenizer.eos_token_id)
142 | result["attention_mask"].append(1)
143 |
144 | result["labels"] = result["input_ids"].copy()
145 |
146 | return result
147 |
148 | def generate_and_tokenize_prompt(data_point):
149 | full_prompt = prompter.generate_prompt(
150 | data_point["instruction"],
151 | data_point["input"],
152 | data_point["output"],
153 | )
154 | tokenized_full_prompt = tokenize(full_prompt)
155 | if not train_on_inputs:
156 | user_prompt = prompter.generate_prompt(
157 | data_point["instruction"], data_point["input"]
158 | )
159 | tokenized_user_prompt = tokenize(
160 | user_prompt, add_eos_token=add_eos_token
161 | )
162 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
163 |
164 | if add_eos_token:
165 | user_prompt_len -= 1
166 |
167 | tokenized_full_prompt["labels"] = [
168 | -100
169 | ] * user_prompt_len + tokenized_full_prompt["labels"][
170 | user_prompt_len:
171 | ] # could be sped up, probably
172 | return tokenized_full_prompt
173 |
174 | model = prepare_model_for_int8_training(model)
175 |
176 | config = LoraConfig(
177 | r=lora_r,
178 | lora_alpha=lora_alpha,
179 | target_modules=lora_target_modules,
180 | lora_dropout=lora_dropout,
181 | bias="none",
182 | task_type="CAUSAL_LM",
183 | )
184 | model = get_peft_model(model, config)
185 |
186 | if data_path.endswith(".json") or data_path.endswith(".jsonl"):
187 | data = load_dataset("json", data_files=data_path)
188 | else:
189 | data = load_dataset(data_path)
190 |
191 | if resume_from_checkpoint:
192 | # Check the available weights and load them
193 | checkpoint_name = os.path.join(
194 | resume_from_checkpoint, "pytorch_model.bin"
195 | ) # Full checkpoint
196 | if not os.path.exists(checkpoint_name):
197 | checkpoint_name = os.path.join(
198 | resume_from_checkpoint, "adapter_model.bin"
199 | ) # only LoRA model - LoRA config above has to fit
200 | resume_from_checkpoint = (
201 | False # So the trainer won't try loading its state
202 | )
203 | # The two files above have a different name depending on how they were saved, but are actually the same.
204 | if os.path.exists(checkpoint_name):
205 | print(f"Restarting from {checkpoint_name}")
206 | adapters_weights = torch.load(checkpoint_name)
207 | set_peft_model_state_dict(model, adapters_weights)
208 | else:
209 | print(f"Checkpoint {checkpoint_name} not found")
210 |
211 | model.print_trainable_parameters() # Be more transparent about the % of trainable params.
212 |
213 | if val_set_size > 0:
214 | train_val = data["train"].train_test_split(
215 | test_size=val_set_size, shuffle=True, seed=42
216 | )
217 | train_data = (
218 | train_val["train"].shuffle().map(generate_and_tokenize_prompt)
219 | )
220 | val_data = (
221 | train_val["test"].shuffle().map(generate_and_tokenize_prompt)
222 | )
223 | else:
224 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
225 | val_data = None
226 |
227 | if not ddp and torch.cuda.device_count() > 1:
228 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
229 | model.is_parallelizable = True
230 | model.model_parallel = True
231 |
232 | trainer = transformers.Trainer(
233 | model=model,
234 | train_dataset=train_data,
235 | eval_dataset=val_data,
236 | args=transformers.TrainingArguments(
237 | per_device_train_batch_size=micro_batch_size,
238 | gradient_accumulation_steps=gradient_accumulation_steps,
239 | warmup_steps=100,
240 | num_train_epochs=num_epochs,
241 | learning_rate=learning_rate,
242 | fp16=True,
243 | logging_steps=10,
244 | optim="adamw_torch",
245 | evaluation_strategy="steps" if val_set_size > 0 else "no",
246 | save_strategy="steps",
247 | eval_steps=200 if val_set_size > 0 else None,
248 | save_steps=200,
249 | output_dir=output_dir,
250 | save_total_limit=3,
251 | load_best_model_at_end=True if val_set_size > 0 else False,
252 | ddp_find_unused_parameters=False if ddp else None,
253 | group_by_length=group_by_length,
254 | report_to="wandb" if use_wandb else None,
255 | run_name=wandb_run_name if use_wandb else None,
256 | save_safetensors=False
257 | ),
258 | data_collator=transformers.DataCollatorForSeq2Seq(
259 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
260 | ),
261 | )
262 | model.config.use_cache = False
263 |
264 | '''
265 | old_state_dict = model.state_dict
266 | model.state_dict = (
267 | lambda self, *_, **__: get_peft_model_state_dict(
268 | self, old_state_dict()
269 | )
270 | ).__get__(model, type(model))
271 | '''
272 |
273 | if torch.__version__ >= "2" and sys.platform != "win32":
274 | model = torch.compile(model)
275 |
276 | trainer.train(resume_from_checkpoint=resume_from_checkpoint)
277 |
278 | model.save_pretrained(output_dir)
279 |
280 | print(
281 | "\n If there's a warning about missing keys above, please disregard :)"
282 | )
283 |
284 |
285 | if __name__ == "__main__":
286 | fire.Fire(train)
287 |
--------------------------------------------------------------------------------
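
A small illustration of what `generate_and_tokenize_prompt` does when `train_on_inputs` is `False` (as passed in `SFT.sh`): label positions covering the prompt are set to `-100` so the cross-entropy loss is computed only on the response tokens. The token IDs below are made up.

```
# toy example of the -100 label masking performed in finetune.py
full_ids   = [1, 887, 412, 9, 55, 301, 2]  # prompt tokens + response tokens (+ eos)
prompt_len = 4                             # number of tokens belonging to the prompt

labels = [-100] * prompt_len + full_ids[prompt_len:]
# -> [-100, -100, -100, -100, 55, 301, 2]
# Hugging Face's loss ignores positions labelled -100, so only the response is trained on.
```
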
/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/model.PNG
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | appdirs
3 | loralib
4 | bitsandbytes
5 | black
6 | black[jupyter]
7 | datasets
8 | fire
9 | git+https://github.com/huggingface/peft.git
10 | transformers>=4.28.0
11 | sentencepiece
12 | gradio
--------------------------------------------------------------------------------
/templates/README.md:
--------------------------------------------------------------------------------
1 | # Prompt templates
2 |
3 | This directory contains template styles for the prompts used to finetune LoRA models.
4 |
5 | ## Format
6 |
7 | A template is described via a JSON file with the following keys:
8 |
9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
10 | - `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
11 | - `description`: A short description of the template, with possible use cases.
12 | - `response_split`: The text to use as separator when cutting real response from the model output.
13 |
14 | No `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest.
15 |
16 | ## Example template
17 |
18 | The default template, used unless otherwise specified, is `alpaca.json`
19 |
20 | ```json
21 | {
22 | "description": "Template used by Alpaca-LoRA.",
23 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
24 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
25 | "response_split": "### Response:"
26 | }
27 |
28 | ```
29 |
30 | ## Current templates
31 |
32 | ### alpaca
33 |
34 | Default template used for generic LoRA fine tunes so far.
35 |
36 | ### alpaca_legacy
37 |
38 | Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
39 |
40 | ### alpaca_short
41 |
42 | A trimmed-down alpaca template which seems to perform just as well and saves some tokens. Models created with the default template seem to be queryable by the short template as well. More experiments are welcome.
43 |
44 | ### vigogne
45 |
46 | The default alpaca template, translated to French. This template was used to train the "Vigogne" LoRA and should be used to query it, or for extra fine-tuning.
47 |
--------------------------------------------------------------------------------
/templates/Text_translation.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "Template for Text Machine Translation.",
3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### source sentence:\n{input}\n\n### target sentence:\n",
4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### target sentence:\n",
5 | "response_split": "### target sentence:"
6 | }
--------------------------------------------------------------------------------
/templates/alpaca.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "Template used by Alpaca-LoRA.",
3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
5 | "response_split": "### Response:"
6 | }
7 |
--------------------------------------------------------------------------------
/templates/alpaca_legacy.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "Legacy template, used by Original Alpaca repository.",
3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
5 | "response_split": "### Response:"
6 | }
7 |
--------------------------------------------------------------------------------
/templates/alpaca_short.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "A shorter template to experiment with.",
3 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
5 | "response_split": "### Response:"
6 | }
7 |
--------------------------------------------------------------------------------
/templates/vigogne.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "French template, used by Vigogne for finetuning.",
3 | "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
4 | "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
5 | "response_split": "### Réponse:"
6 | }
7 |
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 | # Directory for helper modules
2 |
3 | ## prompter.py
4 |
5 | Prompter class, a template manager.
6 |
7 | `from utils.prompter import Prompter`
8 |
9 | ## callbacks.py
10 |
11 | Helpers to support streaming generate output.
12 |
13 | `from utils.callbacks import Iteratorize, Stream`
14 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/callbacks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__pycache__/callbacks.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/prompter.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SiLLM/2c952ec5dc6e78bf6ba2481f4496b522c39c52c8/utils/__pycache__/prompter.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to support streaming generate output.
3 | Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
4 | """
5 |
6 | import gc
7 | import traceback
8 | from queue import Queue
9 | from threading import Thread
10 |
11 | import torch
12 | import transformers
13 |
14 |
15 | class Stream(transformers.StoppingCriteria):
16 | def __init__(self, callback_func=None):
17 | self.callback_func = callback_func
18 |
19 | def __call__(self, input_ids, scores) -> bool:
20 | if self.callback_func is not None:
21 | self.callback_func(input_ids[0])
22 | return False
23 |
24 |
25 | class Iteratorize:
26 |
27 | """
28 | Transforms a function that takes a callback
29 | into a lazy iterator (generator).
30 | """
31 |
32 | def __init__(self, func, kwargs={}, callback=None):
33 | self.mfunc = func
34 | self.c_callback = callback
35 | self.q = Queue()
36 | self.sentinel = object()
37 | self.kwargs = kwargs
38 | self.stop_now = False
39 |
40 | def _callback(val):
41 | if self.stop_now:
42 | raise ValueError
43 | self.q.put(val)
44 |
45 | def gentask():
46 | try:
47 | ret = self.mfunc(callback=_callback, **self.kwargs)
48 |             except ValueError:
49 |                 ret = None  # stopped early; nothing to hand to the callback
50 |             except:  # noqa: E722
51 |                 traceback.print_exc()
52 |                 ret = None
53 |
54 | self.q.put(self.sentinel)
55 | if self.c_callback:
56 | self.c_callback(ret)
57 |
58 | self.thread = Thread(target=gentask)
59 | self.thread.start()
60 |
61 | def __iter__(self):
62 | return self
63 |
64 | def __next__(self):
65 | obj = self.q.get(True, None)
66 | if obj is self.sentinel:
67 | raise StopIteration
68 | else:
69 | return obj
70 |
71 | def __enter__(self):
72 | return self
73 |
74 | def __exit__(self, exc_type, exc_val, exc_tb):
75 | self.stop_now = True
76 |
--------------------------------------------------------------------------------
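
A minimal usage sketch for `Iteratorize`: it turns a function that reports results through a callback into a plain Python iterator. The producer below is a stand-in for a callback-driven `model.generate` call.

```
from utils.callbacks import Iteratorize

def produce(callback=None, n=3):
    # stand-in for a generate() call that pushes partial outputs via `callback`
    for step in range(n):
        callback(f"partial output {step}")

with Iteratorize(produce, kwargs={"n": 3}) as gen:
    for chunk in gen:
        print(chunk)
```
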
/utils/prompter.py:
--------------------------------------------------------------------------------
1 | """
2 | A dedicated helper to manage templates and prompt building.
3 | """
4 |
5 | import json
6 | import os.path as osp
7 | from typing import Union
8 |
9 |
10 | class Prompter(object):
11 | __slots__ = ("template", "_verbose")
12 |
13 | def __init__(self, template_name: str = "", verbose: bool = False):
14 | self._verbose = verbose
15 | if not template_name:
16 | # Enforce the default here, so the constructor can be called with '' and will not break.
17 | template_name = "alpaca"
18 | file_name = osp.join("templates", f"{template_name}.json")
19 | if not osp.exists(file_name):
20 | raise ValueError(f"Can't read {file_name}")
21 | with open(file_name) as fp:
22 | self.template = json.load(fp)
23 | if self._verbose:
24 | print(
25 | f"Using prompt template {template_name}: {self.template['description']}"
26 | )
27 |
28 | def generate_prompt(
29 | self,
30 | instruction: str,
31 | input: Union[None, str] = None,
32 | label: Union[None, str] = None,
33 | ) -> str:
34 | # returns the full prompt from instruction and optional input
35 | # if a label (=response, =output) is provided, it's also appended.
36 | if input:
37 | res = self.template["prompt_input"].format(
38 | instruction=instruction, input=input
39 | )
40 | else:
41 | res = self.template["prompt_no_input"].format(
42 | instruction=instruction
43 | )
44 | if label:
45 | res = f"{res}{label}"
46 | if self._verbose:
47 | print(res)
48 | return res
49 |
50 | def get_response(self, output: str) -> str:
51 | return output.split(self.template["response_split"])[1].strip()
52 |
--------------------------------------------------------------------------------
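
Finally, a short usage sketch tying `Prompter` to the `Text_translation` template (run from the repository root so the relative `templates/` path resolves; the instruction wording and sentences are placeholders):

```
from utils.prompter import Prompter

prompter = Prompter("Text_translation")

prompt = prompter.generate_prompt(
    instruction="Translate the following German sentence into English.",  # placeholder wording
    input="Das ist ein Test.",
)
print(prompt)  # ends with "### target sentence:\n"

# get_response() keeps only the text after the response_split marker
fake_model_output = prompt + "This is a test."
print(prompter.get_response(fake_model_output))  # -> "This is a test."
```
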