├── .gitignore ├── gguf ├── convert.sh └── download_model.py ├── requirements.txt ├── threads ├── phi3.txt ├── template_default.txt ├── template_chatml.txt ├── template_llama2.txt ├── finetune_prompts.py └── finetune_prompts_dpo.py ├── translators ├── base.py ├── m2m.py ├── madlad.py ├── nllb.py ├── towerinstruct.py ├── mbart.py ├── seamless_m4t_v2.py ├── gemini_pro.py └── opus.py ├── .github └── ISSUE_TEMPLATE │ └── default-issue-template.md ├── changelog.md ├── merge_adapter.py ├── combine_checkpoints.py ├── run_inference.py ├── benchmark.py ├── finetune_orpo.py ├── finetune_dpo.py ├── finetune.py ├── LICENSE ├── translate.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *pycache* 2 | wandb/ 3 | .idea -------------------------------------------------------------------------------- /gguf/convert.sh: -------------------------------------------------------------------------------- 1 | echo "Running this script assumes you have llama.cpp installed with all its requirements." 2 | echo "Usage: convert.sh model_name local_location llamacpp_location llamacpp_outfile llamacpp_outtype" 3 | python download_model.py $1 $2 4 | python $3/convert.py $2 --outfile $4 --outtype $5 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.25.0 2 | peft==0.10.0 3 | bitsandbytes==0.41.2.post2 4 | transformers==4.36.2 5 | trl==0.8.2 6 | sentencepiece==0.1.99 7 | sacremoses==0.1.1 8 | datasets==2.15.0 9 | huggingface_hub==0.19.4 10 | scipy==1.11.4 11 | tensorboardx==2.6.2.2 12 | pandas==1.5.3 13 | stanza==1.7.0 14 | tqdm 15 | sacrebleu 16 | google-generativeai -------------------------------------------------------------------------------- /threads/phi3.txt: -------------------------------------------------------------------------------- 1 | {{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|user|>' + '\n'}}{% elif (message['role'] == 'user') %}{{message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %} -------------------------------------------------------------------------------- /threads/template_default.txt: -------------------------------------------------------------------------------- 1 | {% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %} -------------------------------------------------------------------------------- /threads/template_chatml.txt: -------------------------------------------------------------------------------- 1 | {% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %} 
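Note: the template files above are plain Jinja chat templates; the finetuning and inference scripts read them as text and hand them to the tokenizer's apply_chat_template. A minimal sketch of how one of these templates is applied — the model name is simply the repository's default base model and is illustrative only:

from transformers import AutoTokenizer

# Illustrative: any Hugging Face tokenizer with BOS/EOS tokens can be used here
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

with open("threads/template_default.txt", "r", encoding="utf8") as f:
    chat_template = f.read()

thread = [
    {"role": "system", "content": "You are a generic chatbot that always answers in English."},
    {"role": "user", "content": "Hello, who are you?"},
]

# tokenize=False returns the fully formatted prompt string instead of token ids
prompt = tokenizer.apply_chat_template(thread, tokenize=False, chat_template=chat_template, add_generation_prompt=True)
print(prompt)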
-------------------------------------------------------------------------------- /translators/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class BaseTranslator(abc.ABC): 4 | def __init__(self, device, quant4, quant4_config, quant8, max_length): 5 | self.device = device 6 | self.quant4 = quant4 7 | self.quant4_config = quant4_config 8 | self.quant8 = quant8 9 | self.max_length = max_length 10 | 11 | @abc.abstractmethod 12 | def translate(self, texts, source_lang, target_lang): 13 | pass -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/default-issue-template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Default issue template 3 | about: Use this template for bugs or questions 4 | title: Question or bug 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Branch** 11 | 12 | **Environment** 13 | **RAM/vRAM** 14 | 15 | **Script with parameters** 16 | 17 | **Data layout or HF dataset** 18 | 19 | **Problem description/Question** 20 | -------------------------------------------------------------------------------- /gguf/download_model.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import snapshot_download 2 | import argparse 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser(description="Downloads a model from HF hub to disk.") 6 | parser.add_argument('model_name', type=str, 7 | help='The name of the model to download.') 8 | parser.add_argument('folder', type=str, 9 | help='The output folder to store the model.') 10 | 11 | 12 | args = parser.parse_args() 13 | model_name = args.model_name 14 | folder = args.folder 15 | 16 | snapshot_download(repo_id=model_name, local_dir=folder, local_dir_use_symlinks=False, revision="main") 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /threads/template_llama2.txt: -------------------------------------------------------------------------------- 1 | {% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %} {% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\n' + content.strip() + '\n<>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %} -------------------------------------------------------------------------------- /changelog.md: -------------------------------------------------------------------------------- 1 | # Change info 2 | _v0.6_ 3 | * **[2024-04-19]** 👉 Added LLaMa3 support and first models. 👈 4 | * **[2024-04-12]** Added ORPO training as an RLHF substitute. 
5 | 6 | _v0.5_ 7 | * **[2024-02-08]** Added DPO training as an RLHF substitute. 8 | * **[2024-02-08]** Added more translation methods like Tower Instruct (LLM) and Google Gemini via API (no GPU required) 9 | 10 | _v0.4_ 11 | * **[2024-01-18]** The create threads script has been removed. We now directly use the chat template provided by the base model's tokenizer, thus supporting mutliple chat/instruct/prompt templates. 12 | 13 | _v0.3_ 14 | * **[2024-01-12]** You can now benchmark different translation models using `benchmark.py`. 15 | * **[2024-01-09]** We have significantly refactored the translation process. Please follow the readme carefully if you come from v0.2. 16 | * **[2024-01-09]** We now support translation through M2M. 17 | * **[2024-01-04]** We now support translation through MADLAD. Especially for models where Helsinki has a low BLEU score (less than 40), MADLAD (or the faster M2M) is preferred. Using MADLAD drastically slows down training time, especially if you quantize (4 bit is even slower than 8 bit). 18 | * **[2024-01-04]** We now use argparser to parse command line arguments. Make sure you update your calls to our scripts accordingly. Use `-h` on all scripts to get help. 19 | 20 | _v0.2_ 21 | * **[2023-12-29]** We now batch translations in `translate.py` for a 30-60% speed increase. If you have checkpoints from before this date, you can **not** continue using the main branch but instead must use the [v0.1 branch](https://github.com/UnderstandLingBV/LLaMa2lang/tree/v0.1). 22 | -------------------------------------------------------------------------------- /merge_adapter.py: -------------------------------------------------------------------------------- 1 | from peft import AutoPeftModelForCausalLM 2 | import torch 3 | import os 4 | import argparse 5 | from transformers import AutoTokenizer 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser(description="Script to run merge a (Q)LoRA adapter into the base model.") 9 | parser.add_argument('model_name', type=str, 10 | help='The name of the tuned adapter model that you pushed to Huggingface after finetuning or DPO.') 11 | parser.add_argument('output_name', type=str, 12 | help='The name of the output (merged) model. Can either be on Huggingface or on disk') 13 | parser.add_argument('--cpu', action='store_true', 14 | help="Forces usage of CPU. 
By default GPU is taken if available.") 15 | 16 | args = parser.parse_args() 17 | model_name = args.model_name 18 | force_cpu = args.cpu 19 | device = torch.device("cuda:0" if torch.cuda.is_available() and not (force_cpu) else "cpu") 20 | output_model = args.output_name 21 | 22 | # Load the model and merge with base 23 | model = AutoPeftModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16, trust_remote_code=True) 24 | model = model.merge_and_unload() 25 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 26 | 27 | if os.path.isdir(output_model): 28 | model.save_to_disk(output_model) 29 | tokenizer.save_to_disk(output_model) 30 | else: 31 | # Try to push to hub, requires HF_TOKEN environment variable to be set, see https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken 32 | model.push_to_hub(output_model) 33 | tokenizer.push_to_hub(output_model) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /combine_checkpoints.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datasets import Dataset, DatasetDict 4 | import pandas as pd 5 | import sys 6 | import argparse 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description="Combine checkpoint files from translation.") 10 | parser.add_argument('input_folder', type=str, 11 | help='The checkpoint folder used in translation, with the target language appended. Example: "./checkpoints_nl".') 12 | parser.add_argument('output_location', type=str, 13 | help='Where to write the Huggingface Dataset. Can be a disk location or a Huggingface Dataset repository.') 14 | args = parser.parse_args() 15 | input_folder = args.input_folder 16 | output_location = args.output_location 17 | 18 | dataset = {} 19 | # Get the subdirectories which will become the keys of the Dataset 20 | folds = [name for name in os.listdir(input_folder) if os.path.isdir(os.path.join(input_folder, name))] 21 | 22 | for fold in folds: 23 | all_data = [] 24 | 25 | for lang_folder in os.listdir(os.path.join(input_folder, fold)): 26 | for filename in os.listdir(os.path.join(input_folder, fold, lang_folder)): 27 | if filename.endswith('.json'): 28 | file_path = os.path.join(input_folder, fold, lang_folder, filename) 29 | with open(file_path, 'r', encoding='utf-8') as file: 30 | data = json.load(file) 31 | all_data.extend(data) 32 | 33 | dataset[fold] = Dataset.from_pandas(pd.DataFrame(data=all_data)) 34 | 35 | dataset = DatasetDict(dataset) 36 | # Check if output location is a valid directory 37 | if os.path.isdir(output_location): 38 | dataset.save_to_disk(output_location) 39 | else: 40 | # Try to push to hub, requires HF_TOKEN environment variable to be set, see https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken 41 | dataset.push_to_hub(output_location) 42 | 43 | if __name__ == "__main__": 44 | main() -------------------------------------------------------------------------------- /translators/m2m.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import ( 3 | M2M100ForConditionalGeneration, 4 | M2M100Tokenizer 5 | ) 6 | import torch 7 | 8 | class M2MTranslator(BaseTranslator): 9 | def __init__(self, device, quant4, quant4_config, quant8, max_length, model_size): 10 | 
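        # Shared translator settings (device, quantization flags/config, max_length) are stored by the base class in translators/base.py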
super().__init__(device, quant4, quant4_config, quant8, max_length) 11 | self.model_size = model_size 12 | 13 | model_name = f'facebook/m2m100_{self.model_size}' 14 | # Load model and tokenizer 15 | if self.quant4: 16 | model = M2M100ForConditionalGeneration.from_pretrained(model_name, device_map=device, quantization_config=self.quant4_config, load_in_4bit=True) 17 | elif self.quant8: 18 | model = M2M100ForConditionalGeneration.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 19 | else: 20 | model = M2M100ForConditionalGeneration.from_pretrained(model_name).to(self.device) 21 | tokenizer = M2M100Tokenizer.from_pretrained(model_name) 22 | 23 | self.model = model 24 | self.tokenizer = tokenizer 25 | 26 | def translate(self, texts, source_lang, target_lang): 27 | # Small fix for odd language codes 28 | if source_lang == 'pt-BR': 29 | source_lang = 'pt' 30 | if source_lang == 'uk-UA': 31 | source_lang = 'uk' 32 | with torch.no_grad(): 33 | if source_lang == 'eu': 34 | # Not supported by M2M 35 | return None 36 | # Set the source language for the tokenizer 37 | self.tokenizer.src_lang = source_lang 38 | if self.max_length is None: 39 | encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.device) 40 | generated_tokens = self.model.generate(**encoded_batch, forced_bos_token_id=self.tokenizer.get_lang_id(target_lang)) 41 | else: 42 | encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device) 43 | generated_tokens = self.model.generate(**encoded_batch, max_length=self.max_length, forced_bos_token_id=self.tokenizer.get_lang_id(target_lang)) 44 | translated_texts = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 45 | 46 | return translated_texts 47 | -------------------------------------------------------------------------------- /translators/madlad.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import ( 3 | T5ForConditionalGeneration, 4 | T5Tokenizer 5 | ) 6 | import torch 7 | 8 | class MADLADTranslator(BaseTranslator): 9 | def __init__(self, device, quant4, quant4_config, quant8, max_length, model_size): 10 | super().__init__(device, quant4, quant4_config, quant8, max_length) 11 | self.model_size = model_size 12 | 13 | model_name = f'google/madlad400-{self.model_size}-mt' 14 | # Quick rewrite the model name for bt 15 | if self.model_size == '7b-bt': 16 | model_name = 'google/madlad400-7b-mt-bt' 17 | # Load model and tokenizer 18 | if self.quant4: 19 | model = T5ForConditionalGeneration.from_pretrained(model_name, device_map=device, quantization_config=self.quant4_config, load_in_4bit=True) 20 | elif self.quant8: 21 | model = T5ForConditionalGeneration.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 22 | else: 23 | model = T5ForConditionalGeneration.from_pretrained(model_name).to(self.device) 24 | tokenizer = T5Tokenizer.from_pretrained(model_name) 25 | 26 | self.model = model 27 | self.tokenizer = tokenizer 28 | 29 | def translate(self, texts, source_lang, target_lang): 30 | # Small fix for odd language codes 31 | if source_lang == 'pt-BR': 32 | source_lang = 'pt' 33 | if source_lang == 'uk-UA': 34 | source_lang = 'uk' 35 | with torch.no_grad(): 36 | # Preprocess texts and add target language prefix 37 | madlad_texts = [f'<2{target_lang}> ' + text.replace("\n", " ") for text in texts] 38 | if self.max_length is None: 39 | 
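                # No explicit max_length: pad to the longest text in the batch and bound generation via max_new_tokens below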
encoded_batch = self.tokenizer(madlad_texts, return_tensors="pt", padding=True).to(self.device) 40 | outputs = self.model.generate(input_ids=encoded_batch['input_ids'], max_new_tokens=2048) # max_new_tokens is required otherwise we get 20 41 | else: 42 | encoded_batch = self.tokenizer(madlad_texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device) 43 | outputs = self.model.generate(input_ids=encoded_batch['input_ids'], max_new_tokens=self.max_length) 44 | translated_texts = [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs] 45 | 46 | return translated_texts 47 | -------------------------------------------------------------------------------- /translators/nllb.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 3 | import torch 4 | 5 | class NLLBTranslator(BaseTranslator): 6 | language_mapping = { 7 | 'en': 'eng_Latn', 8 | 'es': 'spa_Latn', 9 | 'de': 'deu_Latn', 10 | 'ru': 'rus_Cyrl', 11 | 'ja': 'jpn_Jpan', 12 | 'pt-BR': 'por_Latn', 13 | 'ca': 'cat_Latn', 14 | 'fr': 'fra_Latn', 15 | 'pl': 'pol_Latn', 16 | 'vi': 'vie_Latn', 17 | 'zh': 'zho_Hant', 18 | 'hu': 'hun_Latn', 19 | 'ko': 'kor_Hang', 20 | 'eu': 'eus_Latn', 21 | 'it': 'ita_Latn', 22 | 'uk-UA': 'ukr_Cyrl', 23 | 'uk': 'ukr_Cyrl', 24 | 'id': 'ind_Latn', 25 | 'ar': 'arb_Arab', 26 | 'fi': 'fin_Latn', 27 | 'tr': 'tur_Latn', 28 | 'da': 'dan_Latn', 29 | 'th': 'tha_Thai', 30 | 'sv': 'swe_Latn', 31 | 'cs': 'ces_Latn', 32 | 'nl': 'nld_Latn' 33 | } 34 | 35 | def __init__(self, device, quant4, quant4_config, quant8, max_length, model_size): 36 | super().__init__(device, quant4, quant4_config, quant8, max_length) 37 | 38 | model_name = f'facebook/nllb-200-{model_size}' 39 | # Load model and tokenizer 40 | if self.quant4: 41 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=device, quantization_config=self.quant4_config, load_in_4bit=True) 42 | elif self.quant8: 43 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 44 | else: 45 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device) 46 | tokenizer = AutoTokenizer.from_pretrained(model_name) 47 | 48 | self.model = model 49 | self.tokenizer = tokenizer 50 | 51 | def translate(self, texts, source_lang, target_lang): 52 | self.tokenizer.src_lang = self.language_mapping[source_lang] 53 | with torch.no_grad(): 54 | if self.max_length is None: 55 | encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.device) 56 | else: 57 | encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device) 58 | outputs = self.model.generate(**encoded_batch, forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang]) 59 | translated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 60 | 61 | return translated_texts 62 | -------------------------------------------------------------------------------- /threads/finetune_prompts.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from tqdm import tqdm 3 | 4 | # We only continue the thread with the highest ranked answer to each input 5 | def find_highest_ranked_child(df, parent_id, base_dataset_parent_field, base_dataset_rank_field): 6 | children = df[df[base_dataset_parent_field] == 
parent_id] 7 | if not children.empty: 8 | return children.loc[children[base_dataset_rank_field].idxmin()] 9 | return None 10 | 11 | # Creates the prompts 12 | def create_prompts(dataset, tokenizer, base_dataset_rank_field, base_dataset_parent_field, base_dataset_id_field, base_dataset_text_field, base_dataset_role_field, instruction_prompt, chat_template): 13 | # Construct threads 14 | threads = [] 15 | df = dataset.to_pandas() 16 | 17 | # Replace NULLs in rank with a value higher than the highest rank 18 | max_rank = df[base_dataset_rank_field].max() 19 | df[base_dataset_rank_field].fillna(max_rank + 1, inplace=True) 20 | 21 | # Identify root messages (those without a parent_id) 22 | root_messages = df[df[base_dataset_parent_field].isna()] 23 | 24 | with tqdm(total=len(root_messages)) as pbar: 25 | for _, root_message in root_messages.iterrows(): 26 | if root_message[base_dataset_text_field] is None: 27 | continue 28 | # Create the thread 29 | thread = [ 30 | { 31 | 'content': instruction_prompt, 32 | 'role': 'system' 33 | }, 34 | { 35 | 'content': root_message[base_dataset_text_field], 36 | 'role': 'user' 37 | } 38 | ] 39 | next_message = find_highest_ranked_child(df, root_message[base_dataset_id_field], base_dataset_parent_field, base_dataset_rank_field) 40 | 41 | while next_message is not None: 42 | role = next_message[base_dataset_role_field] 43 | if role == 'prompter': 44 | role = 'user' 45 | thread.append({ 46 | 'content': next_message[base_dataset_text_field], 47 | 'role': role 48 | }) 49 | next_message = find_highest_ranked_child(df, next_message[base_dataset_id_field], base_dataset_parent_field, base_dataset_rank_field) 50 | 51 | # Turn this into LLaMa3 format 52 | try: 53 | threads.append({'text': tokenizer.apply_chat_template(thread, tokenize=False, chat_template=chat_template)}) 54 | except Exception as e: 55 | print(f"ERROR: {e}") 56 | print(thread) 57 | import sys 58 | sys.exit(0) 59 | # Update progress 60 | pbar.update(1) 61 | 62 | return threads -------------------------------------------------------------------------------- /translators/towerinstruct.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline 3 | import torch 4 | 5 | class TowerInstructTranslator(BaseTranslator): 6 | language_mapping = { 7 | 'en': 'English', 8 | 'pt': 'Portuguese', 9 | 'pt-BR': 'Portuguese', 10 | 'es': 'Spanish', 11 | 'fr': 'French', 12 | 'de': 'German', 13 | 'nl': 'Dutch', 14 | 'it': 'Italian', 15 | 'ko': 'Korean', 16 | 'zh': 'Chinese', 17 | 'ru': 'Russian', 18 | 'uk': 'Ukrainian' 19 | } 20 | def __init__(self, device, quant4, quant4_config, quant8, max_length): 21 | super().__init__(device, quant4, quant4_config, quant8, max_length) 22 | 23 | model_name = f'Unbabel/TowerInstruct-7B-v0.1' 24 | # Load model and tokenizer 25 | if self.quant4: 26 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, quantization_config=self.quant4_config, load_in_4bit=True) 27 | elif self.quant8: 28 | model = AutoModelForCausalLM.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 29 | else: 30 | model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device) 31 | tokenizer = AutoTokenizer.from_pretrained(model_name) 32 | 33 | self.nlp_pipeline = pipeline("text-generation", model=model, device_map=self.device, tokenizer=tokenizer) 34 | self.printed_error_langs = {} 35 | 36 | def translate(self, texts, 
source_lang, target_lang): 37 | if source_lang in self.language_mapping and target_lang in self.language_mapping: 38 | src_lang = self.language_mapping[source_lang] 39 | trgt_lang = self.language_mapping[target_lang] 40 | 41 | with torch.no_grad(): 42 | texts = [{'role':'user','content': f'Translate the following text from {src_lang} into {trgt_lang}.\n{src_lang}: {t}\n{trgt_lang}:'} for t in texts] 43 | prompts = [self.nlp_pipeline.tokenizer.apply_chat_template([text], tokenize=False, add_generation_prompt=True) for text in texts] 44 | if self.max_length is None: 45 | outputs = [self.nlp_pipeline(prompt, do_sample=False) for prompt in prompts] 46 | else: 47 | outputs = [self.nlp_pipeline(prompt, max_new_tokens=self.max_length, do_sample=False) for prompt in prompts] 48 | 49 | # Remove the prompts from the outputs 50 | result = [] 51 | for output, prompt in zip(outputs, prompts): 52 | result.append(output[0]['generated_text'][len(prompt):]) 53 | 54 | return result 55 | else: 56 | if not(source_lang in self.printed_error_langs): 57 | print(f"[---- LLaMa2Lang ----] Tower Instruct cannot translate from source language {source_lang} or to your target language {target_lang}, returning originals") 58 | self.printed_error_langs[source_lang] = True 59 | return None -------------------------------------------------------------------------------- /threads/finetune_prompts_dpo.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from tqdm import tqdm 3 | 4 | def format_dpo( 5 | thread: list[str], 6 | system_instruction: str, 7 | bad_child: str, 8 | tokenizer, 9 | chat_template) \ 10 | -> dict[str, str]: 11 | chat = [ 12 | {"role": "system", "content": system_instruction} 13 | ] 14 | 15 | formatted_thread: dict[str, str] = {} 16 | 17 | for i in range(0, len(thread) - 1): 18 | if i % 2 == 0: 19 | chat.append({"role": "user", "content": thread[i]}) 20 | else: 21 | chat.append({"role": "assistant", "content": thread[i]}) 22 | 23 | # Run it untokenized so we can write it out 24 | formatted_thread['prompt'] = tokenizer.apply_chat_template(chat, tokenize=False, chat_template=chat_template)[len(tokenizer.bos_token):] 25 | formatted_thread['chosen'] = thread[-1] 26 | formatted_thread['rejected'] = bad_child 27 | 28 | return formatted_thread 29 | 30 | 31 | # We only continue the thread with the highest ranked answer to each input 32 | def find_children_and_highest_ranked_child( 33 | df: DataFrame, 34 | parent_id: int, 35 | base_dataset_parent_field: str, 36 | base_dataset_rank_field: str) -> tuple[DataFrame, DataFrame]: 37 | children = df[df[base_dataset_parent_field] == parent_id] 38 | min_rank = children[base_dataset_rank_field].min() 39 | 40 | if not children.empty: 41 | return children[children[base_dataset_rank_field] == min_rank], children[children[base_dataset_rank_field] != min_rank] 42 | 43 | df_empty = children.iloc[:0, :].copy() 44 | 45 | return df_empty, df_empty 46 | 47 | 48 | def create_prompts(dataset, tokenizer, base_dataset_rank_field, base_dataset_parent_field, base_dataset_id_field, base_dataset_text_field, instruction_prompt, chat_template): 49 | # Construct threads 50 | threads = [] 51 | df = dataset.to_pandas() 52 | 53 | # Replace NULLs in rank with a value highest than the highest rank 54 | max_rank = df[base_dataset_rank_field].max() 55 | df[base_dataset_rank_field].fillna(max_rank + 1, inplace=True) 56 | 57 | # Identify root messages (those without a parent_id) 58 | root_messages = 
df[df[base_dataset_parent_field].isna()] 59 | 60 | with tqdm(total=len(root_messages)) as pbar: 61 | for _, root_message in root_messages.iterrows(): 62 | # Create the thread 63 | thread: list[str] = [root_message[base_dataset_text_field]] 64 | 65 | good_child, bad_children = find_children_and_highest_ranked_child(df, 66 | root_message[base_dataset_id_field], base_dataset_parent_field, base_dataset_rank_field) 67 | 68 | while not good_child.empty: 69 | thread.append(good_child.iloc[0][base_dataset_text_field]) 70 | 71 | for bad_child in bad_children.iterrows(): 72 | formatted_dpo = format_dpo(thread, instruction_prompt, bad_child[1][base_dataset_text_field], tokenizer, chat_template) 73 | threads.append(formatted_dpo) 74 | 75 | good_child, bad_children = find_children_and_highest_ranked_child(df, good_child[ 76 | base_dataset_id_field].iloc[0], base_dataset_parent_field, base_dataset_rank_field) 77 | 78 | # Update progress 79 | pbar.update(1) 80 | 81 | return threads 82 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import bitsandbytes as bnb 2 | from functools import partial 3 | from peft import AutoPeftModelForCausalLM 4 | import torch 5 | from transformers import AutoTokenizer 6 | import sys 7 | import argparse 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="Script to run inference on a tuned model.") 11 | parser.add_argument('model_name', type=str, 12 | help='The name of the tuned model that you pushed to Huggingface after finetuning or DPO.') 13 | parser.add_argument('instruction_prompt', type=str, 14 | help='An instruction message added to every prompt given to the chatbot to force it to answer in the target language.') 15 | parser.add_argument('--cpu', action='store_true', 16 | help="Forces usage of CPU. By default GPU is taken if available.") 17 | parser.add_argument('--thread_template', type=str, default="threads/template_default.txt", 18 | help='A file containing the thread template to use. 
Default is threads/template_fefault.txt') 19 | parser.add_argument('--padding', type=str, default="left", 20 | help='What padding to use, can be either left or right.') 21 | 22 | 23 | args = parser.parse_args() 24 | model_name = args.model_name 25 | instruction_prompt = args.instruction_prompt 26 | thread_template_file = args.thread_template 27 | force_cpu = args.cpu 28 | device = torch.device("cuda:0" if torch.cuda.is_available() and not (force_cpu) else "cpu") 29 | padding = args.padding 30 | 31 | # Get the template 32 | with open(thread_template_file, 'r', encoding="utf8") as f: 33 | chat_template = f.read() 34 | 35 | # Load the model and merge with base 36 | model = AutoPeftModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16) 37 | model = model.merge_and_unload() 38 | model.eval() 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | if padding == 'left': 41 | tokenizer.pad_token_id = 0 42 | else: 43 | tokenizer.pad_token_id = tokenizer.eos_token_id 44 | tokenizer.padding_side = padding 45 | 46 | thread = [ 47 | {'role': 'system', 'content': instruction_prompt} 48 | ] 49 | while True: 50 | user_input = input("Enter your input, use ':n' for a new thread or ':q' to quit: ") 51 | if user_input.lower() == ':q': 52 | break 53 | elif user_input.lower() == ':n': 54 | thread = [{'role': 'system', 'content': instruction_prompt}] 55 | continue 56 | 57 | # Prepare input in LLaMa3 chat format 58 | thread.append({ 59 | 'role': 'user', 'content': user_input 60 | }) 61 | input_chat = tokenizer.apply_chat_template(thread, tokenize=False, chat_template=chat_template) 62 | inputs = tokenizer(input_chat, return_tensors="pt").to(device) 63 | 64 | # Generate response and decode 65 | output_sequences = model.generate( 66 | input_ids=inputs['input_ids'], 67 | max_length=200, 68 | repetition_penalty=1.2 # LLaMa3 is sensitive to repetition 69 | ) 70 | generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True) 71 | print(generated_text) 72 | # Get the answer only 73 | answer = generated_text[(len(input_chat)-len(tokenizer.bos_token)+1):] 74 | thread.append({ 75 | 'role': 'assistant', 'content': answer 76 | }) 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /translators/mbart.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import MBartForConditionalGeneration, MBart50TokenizerFast 3 | import torch 4 | 5 | class mBARTTranslator(BaseTranslator): 6 | language_mapping = { 7 | 'ar': 'ar_AR', 8 | 'cs': 'cs_CZ', 9 | 'de': 'de_DE', 10 | 'en': 'en_XX', 11 | 'es': 'es_XX', 12 | 'et': 'et_EE', 13 | 'fi': 'fi_FI', 14 | 'fr': 'fr_XX', 15 | 'gu': 'gu_IN', 16 | 'hi': 'hi_IN', 17 | 'it': 'it_IT', 18 | 'ja': 'ja_XX', 19 | 'kk': 'kk_KZ', 20 | 'ko': 'ko_KR', 21 | 'lt': 'lt_LT', 22 | 'lv': 'lv_LV', 23 | 'my': 'my_MM', 24 | 'ne': 'ne_NP', 25 | 'nl': 'nl_XX', 26 | 'ro': 'ro_RO', 27 | 'ru': 'ru_RU', 28 | 'si': 'si_LK', 29 | 'tr': 'tr_TR', 30 | 'vi': 'vi_VN', 31 | 'zh': 'zh_CN', 32 | 'af': 'af_ZA', 33 | 'az': 'az_AZ', 34 | 'bn': 'bn_IN', 35 | 'fa': 'fa_IR', 36 | 'he': 'he_IL', 37 | 'hr': 'hr_HR', 38 | 'id': 'id_ID', 39 | 'ka': 'ka_GE', 40 | 'km': 'km_KH', 41 | 'mk': 'mk_MK', 42 | 'ml': 'ml_IN', 43 | 'mn': 'mn_MN', 44 | 'mr': 'mr_IN', 45 | 'pl': 'pl_PL', 46 | 'ps': 'ps_AF', 47 | 'pt': 'pt_XX', 48 | 'pt-BR': 'pt_XX', 49 | 'sv': 'sv_SE', 50 | 'sw': 'sw_KE', 51 | 'ta': 'ta_IN', 
52 | 'te': 'te_IN', 53 | 'th': 'th_TH', 54 | 'tl': 'tl_XX', 55 | 'uk_UA': 'uk_UA', 56 | 'uk': 'uk_UA', 57 | 'ur': 'ur_PK', 58 | 'xh': 'xh_ZA', 59 | 'gl': 'gl_ES', 60 | 'sl': 'sl_SI' 61 | } 62 | 63 | def __init__(self, device, quant4, quant4_config, quant8, max_length): 64 | super().__init__(device, quant4, quant4_config, quant8, max_length) 65 | 66 | model_name = 'facebook/mbart-large-50-many-to-many-mmt' 67 | # Load model and tokenizer 68 | if self.quant4: 69 | model = MBartForConditionalGeneration.from_pretrained(model_name, device_map=device, quantization_config=self.quant4_config, load_in_4bit=True) 70 | elif self.quant8: 71 | model = MBartForConditionalGeneration.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 72 | else: 73 | model = MBartForConditionalGeneration.from_pretrained(model_name).to(self.device) 74 | tokenizer = MBart50TokenizerFast.from_pretrained(model_name) 75 | 76 | self.model = model 77 | self.tokenizer = tokenizer 78 | self.printed_error_langs = {} 79 | 80 | def translate(self, texts, source_lang, target_lang): 81 | if source_lang in self.language_mapping: 82 | self.tokenizer.src_lang = self.language_mapping[source_lang] 83 | with torch.no_grad(): 84 | if self.max_length is None: 85 | encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.device) 86 | else: 87 | encoded_batch = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(self.device) 88 | outputs = self.model.generate(**encoded_batch, forced_bos_token_id=self.tokenizer.lang_code_to_id[target_lang]) 89 | translated_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 90 | 91 | return translated_texts 92 | else: 93 | if not(source_lang in self.printed_error_langs): 94 | print(f"[---- LLaMa2Lang ----] mBART cannot translate from source language {source_lang}, returning originals") 95 | self.printed_error_langs[source_lang] = True 96 | return None 97 | -------------------------------------------------------------------------------- /translators/seamless_m4t_v2.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import SeamlessM4Tv2ForTextToText, AutoProcessor 3 | from stanza.pipeline.core import DownloadMethod 4 | import stanza 5 | import torch 6 | 7 | 8 | class Seamless_M4T_V2(BaseTranslator): 9 | language_mapping = { 10 | 'en': 'eng', 11 | 'es': 'spa', 12 | 'de': 'deu', 13 | 'ru': 'rus', 14 | 'ja': 'jpn', 15 | 'pt-BR': 'por', 16 | 'ca': 'cat', 17 | 'fr': 'fra', 18 | 'pl': 'pol', 19 | 'vi': 'vie', 20 | 'zh': 'zho', 21 | 'hu': 'hun', 22 | 'ko': 'kor', 23 | 'eu': 'eus', 24 | 'it': 'ita', 25 | 'uk-UA': 'ukr', 26 | 'uk': 'ukr', 27 | 'id': 'ind', 28 | 'ar': 'arb', 29 | 'fi': 'fin', 30 | 'tr': 'tur', 31 | 'da': 'dan', 32 | 'th': 'tha', 33 | 'sv': 'swe', 34 | 'cs': 'ces' 35 | } 36 | 37 | def __init__(self, device, quant4, quant4_config, quant8, max_length): 38 | super().__init__(device, quant4, quant4_config, quant8, max_length) 39 | 40 | model_name = f'facebook/seamless-m4t-v2-large' 41 | # Load model and tokenizer 42 | if self.quant4: 43 | model = SeamlessM4Tv2ForTextToText.from_pretrained(model_name, device_map=device, 44 | quantization_config=self.quant4_config, 45 | load_in_4bit=True) 46 | elif self.quant8: 47 | model = SeamlessM4Tv2ForTextToText.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 48 | else: 49 | model = 
SeamlessM4Tv2ForTextToText.from_pretrained(model_name).to(self.device) 50 | processor = AutoProcessor.from_pretrained(model_name) 51 | self.model = model 52 | self.processor = processor 53 | 54 | def translate(self, texts, source_lang, target_lang): 55 | self.processor.src_lang = self.language_mapping[source_lang] 56 | with torch.no_grad(): 57 | # Seamless is good for short messages/sentences, 58 | # so there is need to conduct sentence segmentation to have a 59 | # good quality translation of texts 60 | nlp_processors = {'tokenize': 'spacy'} if source_lang == 'en' else 'tokenize' 61 | nlp = stanza.Pipeline( 62 | lang=source_lang, 63 | download_method=DownloadMethod.REUSE_RESOURCES, 64 | processors=nlp_processors, 65 | use_gpu=True, 66 | verbose=False 67 | ) 68 | sentence_segmented_texts = nlp.bulk_process(texts) 69 | translated_texts = [] 70 | for document in sentence_segmented_texts: 71 | translated_text = "" 72 | for sentence in document.sentences: 73 | decoder_input_ids = self.translate_text(target_lang, sentence.text) 74 | translated_text += self.processor.decode(decoder_input_ids, skip_special_tokens=True) + " " 75 | translated_texts.append(translated_text.strip()) 76 | return translated_texts 77 | 78 | def translate_text(self, target_lang, text): 79 | if self.max_length is None: 80 | encoded_batch = self.processor(text, return_tensors="pt", padding=True).to(self.device) 81 | else: 82 | encoded_batch = self.processor(text, return_tensors="pt", padding=True, truncation=True, 83 | max_length=self.max_length).to(self.device) 84 | decoder_input_ids = self.model.generate(**encoded_batch, 85 | tgt_lang=self.language_mapping[target_lang])[0].tolist() 86 | return decoder_input_ids 87 | -------------------------------------------------------------------------------- /translators/gemini_pro.py: -------------------------------------------------------------------------------- 1 | from google.api_core.exceptions import InternalServerError 2 | from translators.base import BaseTranslator 3 | 4 | import google.generativeai as genai 5 | import asyncio 6 | import codecs 7 | 8 | 9 | class GeminiProTranslator(BaseTranslator): 10 | # based on https://ai.google.dev/available_regions#available_languages 11 | # make sure that you have access to Gemini Region 12 | language_mapping = { 13 | "en": "English", 14 | "pt": "Portuguese", 15 | "pt-BR": "Portuguese", 16 | "es": "Spanish", 17 | "fr": "French", 18 | "de": "German", 19 | "nl": "Dutch", 20 | "it": "Italian", 21 | "ko": "Korean", 22 | "zh": "Chinese", 23 | "uk": "Ukrainian", 24 | "uk-UA": "Ukrainian", 25 | "ja": "Japan", 26 | "pl": "Polish", 27 | "ar": "Arabic", 28 | "bn": "Bengali", 29 | "bg": "Bulgarian", 30 | "hr": "Croatian", 31 | "cs": "Czech", 32 | "da": "Danish", 33 | "et": "Estonian", 34 | "fi": "Finnish", 35 | "el": "Greek", 36 | "iw": "Hebrew", 37 | "hi": "Hindi", 38 | "hu": "Hungarian", 39 | "id": "Indonesian", 40 | "lv": "Latvian", 41 | "lt": "Lithuanian", 42 | "no": "Norwegian", 43 | "ro": "Romanian", 44 | "ru": "Russian", 45 | "sr": "Serbian", 46 | "sk": "Slovak", 47 | "sl": "Slovenian", 48 | "sw": "Swahili", 49 | "sv": "Swedish", 50 | "th": "Thai", 51 | "tr": "Turkish", 52 | "vi": "Vietnamese" 53 | } 54 | 55 | def __init__(self, access_token, max_length): 56 | if access_token is None: 57 | raise Exception("Access token is required!") 58 | super().__init__(None, None, None, None, max_length) 59 | genai.configure(api_key=access_token) 60 | self.printed_error_langs = {} 61 | self.model = genai.GenerativeModel('gemini-pro') 62 | 63 | async def 
translate_text(self, text, prompt): 64 | try: 65 | ## Need to ignore safety to correctly translate input from different languages 66 | result = self.model.generate_content_async(f"{prompt}\n{text}", safety_settings={'HARASSMENT': 'block_none', 67 | 'HARM_CATEGORY_SEXUALLY_EXPLICIT': 'block_none', 68 | 'harm_category_dangerous_content': 'block_none', 69 | 'harm_category_hate_speech': 'block_none', 70 | 'harm_category_harassment': 'block_none' 71 | }) 72 | return await result 73 | except InternalServerError: 74 | return await self.translate_text(text, prompt) 75 | def decode_result(self, response): 76 | try: 77 | return response.text 78 | except: 79 | try: 80 | result = "".join(map(lambda part: part.text, response.parts)) 81 | decoded_result = codecs.escape_decode(result)[0].decode("utf8") 82 | return decoded_result 83 | except: 84 | if len(response.candidates) == 0: 85 | return 86 | result = "".join(map(lambda part: part.text, response.candidates[0].content.parts)) 87 | decoded_result = codecs.escape_decode(result)[0].decode("utf8") 88 | return decoded_result 89 | 90 | async def translate_texts(self, texts, prompt): 91 | tasks = [] 92 | for text in texts: 93 | tasks.append(self.translate_text(text, prompt)) 94 | await asyncio.sleep(1) 95 | results = await asyncio.gather(*tasks) 96 | decoded_results = [] 97 | for i in range(0,len(results)): 98 | try: 99 | decoded_results.append(self.decode_result(results[i])) 100 | except: 101 | print("Error during translation, returning source language") 102 | decoded_results.append(texts[i]) 103 | 104 | return decoded_results 105 | 106 | def translate(self, texts, source_lang, target_lang): 107 | if len(texts) > 60: 108 | raise Exception("Batch size cannot be more than 60 for this translator due ratelimit in 60 RPM!") 109 | if source_lang in self.language_mapping and target_lang in self.language_mapping: 110 | trgt_lang = self.language_mapping[target_lang] 111 | prompt = (f"Translate text below to {trgt_lang} language and preserve formatting and special characters. " 112 | f"Respond with translated text ONLY. 
Here is text to translate:\n") 113 | loop = asyncio.get_event_loop() 114 | result = loop.run_until_complete(self.translate_texts(texts, prompt)) 115 | return result 116 | else: 117 | if not (source_lang in self.printed_error_langs): 118 | print( 119 | f"[---- LLaMa2Lang ----] Gemini Pro cannot translate from source language {source_lang} or to your target language {target_lang}, returning originals") 120 | self.printed_error_langs[source_lang] = True 121 | return None 122 | -------------------------------------------------------------------------------- /translators/opus.py: -------------------------------------------------------------------------------- 1 | from translators.base import BaseTranslator 2 | from transformers import ( 3 | AutoTokenizer, 4 | AutoModelForSeq2SeqLM 5 | ) 6 | import torch 7 | 8 | class OPUSTranslator(BaseTranslator): 9 | def __init__(self, device, quant4, quant4_config, quant8, max_length): 10 | super().__init__(device, quant4, quant4_config, quant8, max_length) 11 | # Cache for loaded translation models, seemingly faster than letting Huggingface handle it 12 | self.model_cache = {} 13 | # Alternative models that are not created by Helsink-NLP 14 | self.alternative_models = { 15 | "en-pl": 'gsarti/opus-mt-tc-en-pl', 16 | "en-ja": 'gsarti/opus-mt-tc-base-en-ja' 17 | } 18 | 19 | def translate(self, texts, source_lang, target_lang): 20 | with torch.no_grad(): 21 | model, tokenizer = self.get_helsinki_nlp_model(source_lang, target_lang) 22 | if model is None or tokenizer is None: 23 | # Try via intermediate language 24 | model_i, tokenizer_i = self.get_helsinki_nlp_model(source_lang, 'en') 25 | model_t, tokenizer_t = self.get_helsinki_nlp_model('en', target_lang) 26 | if model_i is None or tokenizer_i is None or model_t is None or tokenizer_t is None: 27 | print(f"[---- LLaMa2Lang ----] No translation possible from {source_lang} to {target_lang}") 28 | return None 29 | 30 | # To intermediate language first 31 | if self.max_length is None: 32 | # OPUS crashes if we pass it more than 512 tokens 33 | inputs = tokenizer_i(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device) 34 | translated_outputs = model_i.generate(inputs.input_ids) 35 | else: 36 | inputs = tokenizer_i(texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length).to(self.device) 37 | translated_outputs = model_i.generate(inputs.input_ids, max_length=self.max_length) 38 | intermediate_texts = [tokenizer_i.decode(output, skip_special_tokens=True) for output in translated_outputs] 39 | 40 | # Now to target 41 | if self.max_length is None: 42 | inputs = tokenizer_t(intermediate_texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device) 43 | translated_outputs = model_t.generate(inputs.input_ids) 44 | else: 45 | inputs = tokenizer_t(intermediate_texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length).to(self.device) 46 | translated_outputs = model_t.generate(inputs.input_ids, max_length=self.max_length) 47 | translated_texts = [tokenizer_t.decode(output, skip_special_tokens=True) for output in translated_outputs] 48 | return translated_texts 49 | else: 50 | if self.max_length is None: 51 | inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt").to(self.device) 52 | translated_outputs = model.generate(inputs.input_ids) 53 | else: 54 | inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", 
max_length=self.max_length).to(self.device) 55 | translated_outputs = model.generate(inputs.input_ids, max_length=self.max_length) 56 | translated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in translated_outputs] 57 | return translated_texts 58 | 59 | def load_model(self, model_name, model_key): 60 | tokenizer = AutoTokenizer.from_pretrained(model_name) 61 | # Apply quantization if needed 62 | if self.quant4: 63 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=self.device, quantization_config=self.quant4_config, load_in_4bit=True) 64 | elif self.quant8: 65 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map=self.device, load_in_8bit=True) 66 | else: 67 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device) 68 | self.model_cache[model_key] = (model, tokenizer) 69 | return model, tokenizer 70 | 71 | # Tries to obtain a translation model from the Helsinki-NLP groups OPUS models. Returns None, None if no model is found for this language pair 72 | def get_helsinki_nlp_model(self, source_lang, target_lang): 73 | # Small fix for odd language codes 74 | if source_lang == 'pt-BR': 75 | source_lang = 'bzs' 76 | if source_lang == 'uk-UA': 77 | source_lang = 'uk' 78 | model_key = f'{source_lang}-{target_lang}' 79 | 80 | if model_key in self.model_cache: 81 | return self.model_cache[model_key] 82 | 83 | model_name = f'Helsinki-NLP/opus-mt-{source_lang}-{target_lang}' 84 | try: 85 | return self.load_model(model_name, model_key) 86 | except OSError as e: 87 | # Try to load the tc-big naming convention files 88 | try: 89 | model_name = f'Helsinki-NLP/opus-mt-tc-big-{source_lang}-{target_lang}' 90 | return self.load_model(model_name, model_key) 91 | except OSError as e: 92 | try: 93 | model_name = self.alternative_models[model_key] 94 | return self.load_model(model_name, model_key) 95 | except Exception as e: 96 | return None, None -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import BitsAndBytesConfig 4 | from datasets import load_dataset 5 | import gc 6 | from sacrebleu.metrics import BLEU, CHRF 7 | from translators.m2m import M2MTranslator 8 | from translators.madlad import MADLADTranslator 9 | from translators.mbart import mBARTTranslator 10 | from translators.nllb import NLLBTranslator 11 | from translators.opus import OPUSTranslator 12 | from translators.seamless_m4t_v2 import Seamless_M4T_V2 13 | 14 | def main(): 15 | allowed_models = ['opus', 'm2m_418m', 'm2m_1.2b', 'madlad_3b', 'madlad_7b', 'madlad_10b', 'madlad_7bbt', 'mbart', 'nllb_distilled-600m', 'nllb_1.3b', 'nllb_distilled-1.3b', 'nllb_3.3b', 'seamless'] 16 | parser = argparse.ArgumentParser(description="Benchmark all the different translation models for a specific source and target language to find out which performs best. This uses 4bit quantization to limit GPU usage. Note: the outcomes are indicative - you cannot assume corretness of the BLEU and CHRF scores but you can compare models against each other relatively.") 17 | parser.add_argument('source_language', type=str, 18 | help='The source language you want to test for. Check your dataset to see which occur most prevalent or use English as a good start.') 19 | parser.add_argument('target_language', type=str, 20 | help='The source language you want to test for. 
This should be the language you want to apply the translate script on. Note: in benchmark, we use 2-character language codes, in constrast to translate.py where you need to specify whatever your model expects.') 21 | mdls = ', '.join(allowed_models) 22 | parser.add_argument('included_models', type=str, 23 | help=f'Comma-separated list of models to include. Allowed values are: {mdls}') 24 | parser.add_argument('--cpu', action='store_true', 25 | help="Forces usage of CPU. By default GPU is taken if available.") 26 | parser.add_argument('--start', type=int, default=0, 27 | help="The starting offset to include sentences from the OPUS books dataset from. Defaults to 0.") 28 | parser.add_argument('--n', type=int, default=100, 29 | help="The number of sentences to benchmark on. Defaults to 100.") 30 | parser.add_argument('--max_length', type=int, default=512, 31 | help="How much tokens to generate at most. More tokens might be more accurate for lengthy input but creates a risk of running out of memory. Default is 512.") 32 | args = parser.parse_args() 33 | source_language = args.source_language 34 | target_language = args.target_language 35 | included_models = args.included_models 36 | force_cpu = args.cpu 37 | start = args.start 38 | n = args.n 39 | max_length = args.max_length 40 | 41 | # Initialize common parameters 42 | device = torch.device("cuda:0" if torch.cuda.is_available() and not(force_cpu) else "cpu") 43 | # Set up quantization configs if required 44 | quant4_config = BitsAndBytesConfig( 45 | load_in_4bit=True, 46 | bnb_4bit_use_double_quant=True, 47 | bnb_4bit_quant_type="nf4", 48 | bnb_4bit_compute_dtype=torch.bfloat16 49 | ) 50 | 51 | # Initialize scorers 52 | bleu = BLEU() 53 | chrf = CHRF() 54 | 55 | # Handle the models 56 | models = [m.strip() for m in included_models.lower().split(",") if m.strip() in allowed_models] 57 | print(f"[---- LLaMa2Lang ----] Starting benchmarking from {source_language} to {target_language} for models {models} on {n} records on device {device}") 58 | 59 | # Load the OPUS dataset 60 | dataset = load_dataset("opus100", f'{source_language}-{target_language}', split=f'train[{start}:{start+n}]').shuffle().select(range(n)) 61 | 62 | # Process each model one at a time 63 | translator = None 64 | for model in models: 65 | # Clear CUDA 66 | del translator 67 | if str(device).startswith('cuda'): 68 | torch.cuda.empty_cache() 69 | gc.collect() 70 | 71 | # Handle the model naming 72 | model_target_language = target_language 73 | if model.startswith('madlad'): 74 | model_size = model.split('_')[1] 75 | if model_size == '7bbt': 76 | model_size = '7b-bt' 77 | translator = MADLADTranslator(device, True, quant4_config, False, max_length, model_size) 78 | elif model.startswith('m2m'): 79 | model_size = model.split('_')[1] 80 | translator = M2MTranslator(device, True, quant4_config, False, max_length, model_size) 81 | elif model.startswith('mbart'): 82 | translator = mBARTTranslator(device, True, quant4_config, False, max_length) 83 | model_target_language = translator.language_mapping[target_language] 84 | elif model.startswith('nllb'): 85 | model_size = model.split('_')[1][:-1] + model.split('_')[1][-1].upper() 86 | translator = NLLBTranslator(device, True, quant4_config, False, max_length, model_size) 87 | # TODO: Extend this later, there are far more languages 88 | model_target_language = translator.language_mapping[target_language] 89 | elif model.startswith('seamless'): 90 | model_size = 'large' # Currently only one on HF 91 | translator = Seamless_M4T_V2(device, 
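            # positional args (see translators/base.py): device, quant4=True, quant4_config, quant8=False, max_length — i.e. 4-bit quantized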
True, quant4_config, False, max_length) 92 | else: 93 | translator = OPUSTranslator(device, False, quant4_config, False, max_length) 94 | 95 | # Run the translations 96 | translated = [] 97 | for s in dataset['translation']: 98 | translated += translator.translate([s[source_language]], source_language, model_target_language) 99 | 100 | # Compute scores, using max_length is not at all correct but it's better than not doing it at all 101 | b_score = bleu.corpus_score([s[:max_length] for s in translated], [[s[target_language][:max_length] for s in dataset['translation']]]) 102 | c_score = chrf.corpus_score([s[:max_length] for s in translated], [[s[target_language][:max_length] for s in dataset['translation']]]) 103 | # Report 104 | print(f"[---- LLaMa2Lang ----] [{model}] BLEU: {b_score.score}") 105 | print(f"[---- LLaMa2Lang ----] [{model}] CHRF: {c_score.score}") 106 | print("") 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /finetune_orpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datasets import load_dataset, load_from_disk, Dataset, DatasetDict 4 | from transformers import ( 5 | AutoModelForCausalLM, 6 | AutoTokenizer, 7 | BitsAndBytesConfig, 8 | TrainingArguments, 9 | ) 10 | from trl import ORPOConfig, ORPOTrainer 11 | import argparse 12 | from threads import finetune_prompts_dpo 13 | import pandas as pd 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser(description="Finetune a base instruct/chat model using (Q)LoRA and PEFT using ORPO (RLHF)") 17 | parser.add_argument('tuned_model', type=str, 18 | help='The name of the resulting tuned model.') 19 | parser.add_argument('dataset_name', type=str, 20 | help='The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script.') 21 | parser.add_argument('instruction_prompt', type=str, 22 | help='An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English."') 23 | parser.add_argument('--base_model', type=str, default="NousResearch/Meta-Llama-3-8B-Instruct", 24 | help='The base foundation model. Default is "NousResearch/Meta-Llama-3-8B-Instruct".') 25 | parser.add_argument('--base_dataset_text_field', type=str, default="text", 26 | help="The dataset's column name containing the actual text to translate. Defaults to text") 27 | parser.add_argument('--base_dataset_rank_field', type=str, default="rank", 28 | help="The dataset's column name containing the rank of an answer given to a prompt. Defaults to rank") 29 | parser.add_argument('--base_dataset_id_field', type=str, default="message_id", 30 | help="The dataset's column name containing the id of a text. Defaults to message_id") 31 | parser.add_argument('--base_dataset_parent_field', type=str, default="parent_id", 32 | help="The dataset's column name containing the parent id of a text. Defaults to parent_id") 33 | parser.add_argument('--quant8', action='store_true', 34 | help='Finetunes the model in 8 bits. Requires more memory than the default 4 bit.') 35 | parser.add_argument('--noquant', action='store_true', 36 | help='Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit.') 37 | parser.add_argument('--max_seq_length', type=int, default=512, 38 | help='The maximum sequence length to use in finetuning. 
Should most likely line up with your base model\'s default max_seq_length. Default is 512.') 39 | parser.add_argument('--max_prompt_length', type=int, default=512, 40 | help='The maximum length of the prompts to use. Default is 512.') 41 | parser.add_argument('--num_train_epochs', type=int, default=2, 42 | help='Number of epochs to use. 2 is default and has been shown to work well.') 43 | parser.add_argument('--batch_size', type=int, default=4, 44 | help='The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4') 45 | parser.add_argument('--threads_output_name', type=str, default=None, 46 | help='If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub.') 47 | parser.add_argument('--thread_template', type=str, default="threads/template_default.txt", 48 | help='A file containing the thread template to use. Default is threads/template_fefault.txt') 49 | parser.add_argument('--max_steps', type=int, default=-1, 50 | help='The maximum number of steps to run ORPO for. Default is -1 which will run the data through fully for the number of epochs but this will be very time-consuming.') 51 | parser.add_argument('--padding', type=str, default="left", 52 | help='What padding to use, can be either left or right.') 53 | 54 | args = parser.parse_args() 55 | base_model = args.base_model 56 | tuned_model = args.tuned_model 57 | dataset_name = args.dataset_name 58 | instruction_prompt = args.instruction_prompt 59 | base_dataset_text_field = args.base_dataset_text_field 60 | base_dataset_rank_field = args.base_dataset_rank_field 61 | base_dataset_id_field = args.base_dataset_id_field 62 | base_dataset_parent_field = args.base_dataset_parent_field 63 | quant8 = args.quant8 64 | noquant = args.noquant 65 | max_seq_length = args.max_seq_length 66 | num_train_epochs = args.num_train_epochs 67 | per_device_train_batch_size = args.batch_size 68 | threads_output_name = args.threads_output_name 69 | thread_template_file = args.thread_template 70 | max_prompt_length = args.max_prompt_length 71 | max_steps = args.max_steps 72 | padding = args.padding 73 | 74 | # Check for HF_TOKEN 75 | if 'HF_TOKEN' not in os.environ: 76 | print("[WARNING] Environment variable 'HF_TOKEN' is not set!") 77 | user_input = input("Do you want to continue? 
(yes/no): ").strip().lower() 78 | 79 | if user_input != "yes": 80 | print("Terminating the program.") 81 | exit() 82 | 83 | # Load the base translated dataset 84 | if os.path.isdir(dataset_name): 85 | dataset = load_from_disk(dataset_name) 86 | else: 87 | dataset = load_dataset(dataset_name) 88 | 89 | # Load base tokenizer 90 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 91 | # Get the template 92 | with open(thread_template_file, 'r', encoding="utf8") as f: 93 | chat_template = f.read() 94 | 95 | # Compute the threads 96 | prompts = {k: [] for k in dataset.keys()} 97 | for fold in prompts: 98 | print(f"[---- LLaMa2Lang ----] Generating prompts using chat template {thread_template_file} for fold {fold}") 99 | templated_prompts = finetune_prompts_dpo.create_prompts(dataset[fold], tokenizer, base_dataset_rank_field, base_dataset_parent_field, base_dataset_id_field, base_dataset_text_field, instruction_prompt, chat_template) 100 | prompts[fold] = Dataset.from_pandas(pd.DataFrame(data=templated_prompts)) 101 | 102 | prompts = DatasetDict(prompts) 103 | # Check if we need to write out 104 | if threads_output_name is not None: 105 | # Also do the other folds 106 | print(f"[---- LLaMa2Lang ----] Writing out ORPO thread prompts dataset to {threads_output_name}") 107 | if os.path.isdir(threads_output_name): 108 | prompts.save_to_disk(threads_output_name) 109 | else: 110 | prompts.push_to_hub(threads_output_name) 111 | 112 | if noquant: 113 | # Load base model 114 | model = AutoModelForCausalLM.from_pretrained(base_model, device_map={"": 0}, trust_remote_code=True) 115 | elif quant8: 116 | quant_config = BitsAndBytesConfig( 117 | load_in_8bit=True, 118 | bnb_8bit_quant_type="qat8", 119 | bnb_8bit_compute_dtype=getattr(torch, "float32"), 120 | bnb_8bit_use_double_quant=False 121 | ) 122 | # Load base model 123 | model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=quant_config, device_map={"": 0}, trust_remote_code=True) 124 | else: 125 | # Set up quantization config 126 | quant_config = BitsAndBytesConfig( 127 | load_in_4bit=True, 128 | bnb_4bit_quant_type="nf4", 129 | bnb_4bit_compute_dtype=getattr(torch, "float16"), 130 | bnb_4bit_use_double_quant=True, 131 | ) 132 | # Load base model 133 | model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=quant_config, device_map={"": 0}, trust_remote_code=True) 134 | 135 | model.config.use_cache = False 136 | model.config.pretraining_tp = 1 137 | 138 | # Load base tokenizer 139 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 140 | # Just like Alpaca, because we allow to add history in the prompts, it makes more sense to do left-padding to have the most informative text at the end. 141 | # In this case, we need a different pad token than EOS because we actually do _not_ pad end of sentence. 
142 |     if padding == 'left': 143 |         tokenizer.pad_token_id = 0 144 |     else: 145 |         tokenizer.pad_token_id = tokenizer.eos_token_id 146 |     tokenizer.padding_side = padding 147 | 148 |     orpo_config = ORPOConfig( 149 |         num_train_epochs=num_train_epochs, 150 |         per_device_train_batch_size=per_device_train_batch_size, 151 |         gradient_accumulation_steps=1, 152 |         gradient_checkpointing=True, 153 |         learning_rate=5e-5, 154 |         lr_scheduler_type="cosine", 155 |         max_steps=max_steps, 156 |         save_strategy="no", 157 |         logging_steps=1, 158 |         output_dir="./results", 159 |         optim="paged_adamw_32bit", 160 |         warmup_steps=100, 161 |         bf16=True, 162 |         report_to=None, 163 |         remove_unused_columns=False, 164 |         beta=0.1, # the lambda/alpha hyperparameter in the paper/code 165 |         target_modules='all-linear', 166 |     ) 167 | 168 |     trainer = ORPOTrainer( 169 |         model, 170 |         args=orpo_config, 171 |         train_dataset=prompts['train'], 172 |         tokenizer=tokenizer, 173 |     ) 174 | 175 |     # Before starting training, free up memory 176 |     torch.cuda.empty_cache() 177 |     # Train the ORPO model 178 |     trainer.train() 179 | 180 |     # Check if output location is a valid directory 181 |     print(f"[---- LLaMa2Lang ----] Writing model and tokenizer out to {tuned_model}") 182 |     if os.path.isdir(tuned_model): 183 |         trainer.model.save_pretrained(tuned_model) 184 |         trainer.tokenizer.save_pretrained(tuned_model) 185 |     else: 186 |         # Try to push to hub, requires HF_TOKEN environment variable to be set, see https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken 187 |         trainer.model.push_to_hub(tuned_model) 188 |         trainer.tokenizer.push_to_hub(tuned_model) 189 | 190 | 191 | if __name__ == "__main__": 192 |     main() -------------------------------------------------------------------------------- /finetune_dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datasets import load_dataset, load_from_disk, Dataset, DatasetDict 4 | from transformers import ( 5 |     AutoModelForCausalLM, 6 |     AutoTokenizer, 7 |     BitsAndBytesConfig, 8 |     TrainingArguments, 9 | ) 10 | from peft import LoraConfig 11 | from trl import DPOTrainer 12 | import sys 13 | import argparse 14 | from threads import finetune_prompts_dpo 15 | import pandas as pd 16 | 17 | def main(): 18 |     parser = argparse.ArgumentParser(description="Finetune a base instruct/chat model using (Q)LoRA and PEFT using DPO (RLHF)") 19 |     parser.add_argument('tuned_model', type=str, 20 |                         help='The name of the resulting tuned model.') 21 |     parser.add_argument('dataset_name', type=str, 22 |                         help='The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script.') 23 |     parser.add_argument('instruction_prompt', type=str, 24 |                         help='An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English."') 25 |     parser.add_argument('--base_model', type=str, default="NousResearch/Meta-Llama-3-8B-Instruct", 26 |                         help='The base foundation model. Default is "NousResearch/Meta-Llama-3-8B-Instruct".') 27 |     parser.add_argument('--base_dataset_text_field', type=str, default="text", 28 |                         help="The dataset's column name containing the actual text to translate. Defaults to text") 29 |     parser.add_argument('--base_dataset_rank_field', type=str, default="rank", 30 |                         help="The dataset's column name containing the rank of an answer given to a prompt. 
Defaults to rank") 31 | parser.add_argument('--base_dataset_id_field', type=str, default="message_id", 32 | help="The dataset's column name containing the id of a text. Defaults to message_id") 33 | parser.add_argument('--base_dataset_parent_field', type=str, default="parent_id", 34 | help="The dataset's column name containing the parent id of a text. Defaults to parent_id") 35 | parser.add_argument('--quant8', action='store_true', 36 | help='Finetunes the model in 8 bits. Requires more memory than the default 4 bit.') 37 | parser.add_argument('--noquant', action='store_true', 38 | help='Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit.') 39 | parser.add_argument('--max_seq_length', type=int, default=512, 40 | help='The maximum sequence length to use in finetuning. Should most likely line up with your base model\'s default max_seq_length. Default is 512.') 41 | parser.add_argument('--max_prompt_length', type=int, default=512, 42 | help='The maximum length of the prompts to use. Default is 512.') 43 | parser.add_argument('--num_train_epochs', type=int, default=2, 44 | help='Number of epochs to use. 2 is default and has been shown to work well.') 45 | parser.add_argument('--batch_size', type=int, default=4, 46 | help='The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4') 47 | parser.add_argument('--threads_output_name', type=str, default=None, 48 | help='If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub.') 49 | parser.add_argument('--thread_template', type=str, default="threads/template_default.txt", 50 | help='A file containing the thread template to use. Default is threads/template_fefault.txt') 51 | parser.add_argument('--max_steps', type=int, default=-1, 52 | help='The maximum number of steps to run DPO for. Default is -1 which will run the data through fully for the number of epochs but this will be very time-consuming.') 53 | parser.add_argument('--padding', type=str, default="left", 54 | help='What padding to use, can be either left or right.') 55 | 56 | args = parser.parse_args() 57 | base_model = args.base_model 58 | tuned_model = args.tuned_model 59 | dataset_name = args.dataset_name 60 | instruction_prompt = args.instruction_prompt 61 | base_dataset_text_field = args.base_dataset_text_field 62 | base_dataset_rank_field = args.base_dataset_rank_field 63 | base_dataset_id_field = args.base_dataset_id_field 64 | base_dataset_parent_field = args.base_dataset_parent_field 65 | quant8 = args.quant8 66 | noquant = args.noquant 67 | max_seq_length = args.max_seq_length 68 | num_train_epochs = args.num_train_epochs 69 | per_device_train_batch_size = args.batch_size 70 | threads_output_name = args.threads_output_name 71 | thread_template_file = args.thread_template 72 | max_prompt_length = args.max_prompt_length 73 | max_steps = args.max_steps 74 | padding = args.padding 75 | 76 | # Check for HF_TOKEN 77 | if 'HF_TOKEN' not in os.environ: 78 | print("[WARNING] Environment variable 'HF_TOKEN' is not set!") 79 | user_input = input("Do you want to continue? 
(yes/no): ").strip().lower() 80 | 81 | if user_input != "yes": 82 | print("Terminating the program.") 83 | exit() 84 | 85 | # Load the base translated dataset 86 | if os.path.isdir(dataset_name): 87 | dataset = load_from_disk(dataset_name) 88 | else: 89 | dataset = load_dataset(dataset_name) 90 | 91 | # Load base tokenizer 92 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 93 | # Get the template 94 | with open(thread_template_file, 'r', encoding="utf8") as f: 95 | chat_template = f.read() 96 | 97 | # Compute the threads 98 | prompts = {k: [] for k in dataset.keys()} 99 | for fold in prompts: 100 | print(f"[---- LLaMa2Lang ----] Generating prompts using chat template {thread_template_file} for fold {fold}") 101 | templated_prompts = finetune_prompts_dpo.create_prompts(dataset[fold], tokenizer, base_dataset_rank_field, base_dataset_parent_field, base_dataset_id_field, base_dataset_text_field, instruction_prompt, chat_template) 102 | prompts[fold] = Dataset.from_pandas(pd.DataFrame(data=templated_prompts)) 103 | 104 | prompts = DatasetDict(prompts) 105 | # Check if we need to write out 106 | if threads_output_name is not None: 107 | # Also do the other folds 108 | print(f"[---- LLaMa2Lang ----] Writing out DPO thread prompts dataset to {threads_output_name}") 109 | if os.path.isdir(threads_output_name): 110 | prompts.save_to_disk(threads_output_name) 111 | else: 112 | prompts.push_to_hub(threads_output_name) 113 | 114 | if noquant: 115 | # Load base model 116 | model = AutoModelForCausalLM.from_pretrained(base_model, device_map={"": 0}, trust_remote_code=True) 117 | elif quant8: 118 | quant_config = BitsAndBytesConfig( 119 | load_in_8bit=True, 120 | bnb_8bit_quant_type="qat8", 121 | bnb_8bit_compute_dtype=getattr(torch, "float32"), 122 | bnb_8bit_use_double_quant=False 123 | ) 124 | # Load base model 125 | model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=quant_config, device_map={"": 0}, trust_remote_code=True) 126 | else: 127 | # Set up quantization config 128 | quant_config = BitsAndBytesConfig( 129 | load_in_4bit=True, 130 | bnb_4bit_quant_type="nf4", 131 | bnb_4bit_compute_dtype=getattr(torch, "float16"), 132 | bnb_4bit_use_double_quant=True, 133 | ) 134 | # Load base model 135 | model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=quant_config, device_map={"": 0}, trust_remote_code=True) 136 | 137 | model.config.use_cache = False 138 | model.config.pretraining_tp = 1 139 | 140 | # Load base tokenizer 141 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 142 | # Just like Alpaca, because we allow to add history in the prompts, it makes more sense to do left-padding to have the most informative text at the end. 143 | # In this case, we need a different pad token than EOS because we actually do _not_ pad end of sentence. 
144 |     if padding == 'left': 145 |         tokenizer.pad_token_id = 0 146 |     else: 147 |         tokenizer.pad_token_id = tokenizer.eos_token_id 148 |     tokenizer.padding_side = padding 149 | 150 |     # Set up LoRA configuration 151 |     peft_params = LoraConfig( 152 |         lora_alpha=16, 153 |         lora_dropout=0.1, 154 |         r=64, 155 |         bias="none", 156 |         task_type="CAUSAL_LM", 157 |         target_modules='all-linear', 158 |     ) 159 | 160 |     # Pass quant and lora to trainer 161 |     training_params = TrainingArguments( 162 |         num_train_epochs=num_train_epochs, 163 |         per_device_train_batch_size=per_device_train_batch_size, 164 |         gradient_accumulation_steps=1, 165 |         gradient_checkpointing=True, 166 |         learning_rate=5e-5, 167 |         lr_scheduler_type="cosine", 168 |         max_steps=max_steps, 169 |         save_strategy="no", 170 |         logging_steps=1, 171 |         output_dir="./results", 172 |         optim="paged_adamw_32bit", 173 |         warmup_steps=100, 174 |         bf16=True, 175 |         report_to=None, 176 |         remove_unused_columns=False 177 |     ) 178 | 179 |     trainer = DPOTrainer( 180 |         model=model, 181 |         ref_model=None, 182 |         args=training_params, 183 |         train_dataset=prompts['train'], 184 |         tokenizer=tokenizer, 185 |         peft_config=peft_params, 186 |         beta=0.1, 187 |         max_prompt_length=max_prompt_length, 188 |         max_length=max_seq_length, 189 |     ) 190 | 191 |     # Before starting training, free up memory 192 |     torch.cuda.empty_cache() 193 |     # Train the DPO model 194 |     trainer.train() 195 | 196 |     # Check if output location is a valid directory 197 |     print(f"[---- LLaMa2Lang ----] Writing model and tokenizer out to {tuned_model}") 198 |     if os.path.isdir(tuned_model): 199 |         trainer.model.save_pretrained(tuned_model) 200 |         trainer.tokenizer.save_pretrained(tuned_model) 201 |     else: 202 |         # Try to push to hub, requires HF_TOKEN environment variable to be set, see https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken 203 |         trainer.model.push_to_hub(tuned_model) 204 |         trainer.tokenizer.push_to_hub(tuned_model) 205 | 206 | 207 | if __name__ == "__main__": 208 |     main() -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datasets import load_dataset, load_from_disk 4 | from transformers import ( 5 |     AutoModelForCausalLM, 6 |     AutoTokenizer, 7 |     BitsAndBytesConfig, 8 |     TrainingArguments, 9 | ) 10 | from peft import LoraConfig 11 | from trl import SFTTrainer 12 | import sys 13 | import argparse 14 | from threads import finetune_prompts 15 | from datasets import Dataset, DatasetDict 16 | import pandas as pd 17 | 18 | def main(): 19 |     parser = argparse.ArgumentParser(description="Finetune a base instruct/chat model using (Q)LoRA and PEFT") 20 |     parser.add_argument('tuned_model', type=str, 21 |                         help='The name of the resulting tuned model.') 22 |     parser.add_argument('dataset_name', type=str, 23 |                         help='The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script.') 24 |     parser.add_argument('instruction_prompt', type=str, 25 |                         help='An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English."') 26 |     parser.add_argument('--base_model', type=str, default="NousResearch/Meta-Llama-3-8B-Instruct", 27 |                         help='The base foundation model. 
Default is "NousResearch/Meta-Llama-3-8B-Instruct".') 28 | parser.add_argument('--base_dataset_text_field', type=str, default="text", 29 | help="The dataset's column name containing the actual text to translate. Defaults to text") 30 | parser.add_argument('--base_dataset_rank_field', type=str, default="rank", 31 | help="The dataset's column name containing the rank of an answer given to a prompt. Defaults to rank") 32 | parser.add_argument('--base_dataset_id_field', type=str, default="message_id", 33 | help="The dataset's column name containing the id of a text. Defaults to message_id") 34 | parser.add_argument('--base_dataset_parent_field', type=str, default="parent_id", 35 | help="The dataset's column name containing the parent id of a text. Defaults to parent_id") 36 | parser.add_argument('--base_dataset_role_field', type=str, default="role", 37 | help="The dataset's column name containing the role of the author of the text (eg. prompter, assistant). Defaults to role") 38 | parser.add_argument('--quant8', action='store_true', 39 | help='Finetunes the model in 8 bits. Requires more memory than the default 4 bit.') 40 | parser.add_argument('--noquant', action='store_true', 41 | help='Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit.') 42 | parser.add_argument('--max_seq_length', type=int, default=512, 43 | help='The maximum sequence length to use in finetuning. Should most likely line up with your base model\'s default max_seq_length. Default is 512.') 44 | parser.add_argument('--num_train_epochs', type=int, default=2, 45 | help='Number of epochs to use. 2 is default and has been shown to work well.') 46 | parser.add_argument('--batch_size', type=int, default=4, 47 | help='The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4') 48 | parser.add_argument('--threads_output_name', type=str, default=None, 49 | help='If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub.') 50 | parser.add_argument('--thread_template', type=str, default="threads/template_default.txt", 51 | help='A file containing the thread template to use. Default is threads/template_fefault.txt') 52 | parser.add_argument('--padding', type=str, default="left", 53 | help='What padding to use, can be either left or right.') 54 | parser.add_argument('--force_cpu', action='store_true', 55 | help="Forces usage of CPU. 
By default GPU is taken if available.") 56 | 57 | args = parser.parse_args() 58 | base_model = args.base_model 59 | tuned_model = args.tuned_model 60 | dataset_name = args.dataset_name 61 | instruction_prompt = args.instruction_prompt 62 | base_dataset_text_field = args.base_dataset_text_field 63 | base_dataset_rank_field = args.base_dataset_rank_field 64 | base_dataset_id_field = args.base_dataset_id_field 65 | base_dataset_parent_field = args.base_dataset_parent_field 66 | base_dataset_role_field = args.base_dataset_role_field 67 | quant8 = args.quant8 68 | noquant = args.noquant 69 | max_seq_length = args.max_seq_length 70 | num_train_epochs = args.num_train_epochs 71 | per_device_train_batch_size = args.batch_size 72 | threads_output_name = args.threads_output_name 73 | thread_template_file = args.thread_template 74 | padding = args.padding 75 | force_cpu = args.force_cpu 76 | device = torch.device("cuda:0" if torch.cuda.is_available() and not (force_cpu) else "cpu") 77 | 78 | # Check for HF_TOKEN 79 | if 'HF_TOKEN' not in os.environ: 80 | print("[WARNING] Environment variable 'HF_TOKEN' is not set!") 81 | user_input = input("Do you want to continue? (yes/no): ").strip().lower() 82 | 83 | if user_input != "yes": 84 | print("Terminating the program.") 85 | exit() 86 | 87 | # Load the base translated dataset 88 | if os.path.isdir(dataset_name): 89 | dataset = load_from_disk(dataset_name) 90 | else: 91 | dataset = load_dataset(dataset_name) 92 | 93 | # Load base tokenizer 94 | tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) 95 | # Get the template 96 | with open(thread_template_file, 'r', encoding="utf8") as f: 97 | chat_template = f.read() 98 | 99 | # Compute the threads 100 | prompts = {k: [] for k in dataset.keys()} 101 | for fold in prompts: 102 | print(f"[---- LLaMa2Lang ----] Generating prompts using chat template {thread_template_file} for fold {fold}") 103 | templated_prompts = finetune_prompts.create_prompts(dataset[fold], tokenizer, base_dataset_rank_field, base_dataset_parent_field, base_dataset_id_field, base_dataset_text_field, base_dataset_role_field, instruction_prompt, chat_template) 104 | prompts[fold] = Dataset.from_pandas(pd.DataFrame(data=templated_prompts)) 105 | 106 | prompts = DatasetDict(prompts) 107 | # Check if we need to write out 108 | if threads_output_name is not None: 109 | # Also do the other folds 110 | print(f"[---- LLaMa2Lang ----] Writing out thread prompts dataset to {threads_output_name}") 111 | if os.path.isdir(threads_output_name): 112 | prompts.save_to_disk(threads_output_name) 113 | else: 114 | prompts.push_to_hub(threads_output_name) 115 | 116 | if noquant: 117 | # Load base model 118 | model = AutoModelForCausalLM.from_pretrained(base_model, device_map=device, trust_remote_code=True) 119 | elif quant8: 120 | quant_config = BitsAndBytesConfig( 121 | load_in_8bit=True, 122 | bnb_8bit_quant_type="qat8", 123 | bnb_8bit_compute_dtype=getattr(torch, "float32"), 124 | bnb_8bit_use_double_quant=False 125 | ) 126 | # Load base model 127 | model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=quant_config, device_map=device, trust_remote_code=True) 128 | else: 129 | # Set up quantization config 130 | quant_config = BitsAndBytesConfig( 131 | load_in_4bit=True, 132 | bnb_4bit_quant_type="nf4", 133 | bnb_4bit_compute_dtype=getattr(torch, "float16"), 134 | bnb_4bit_use_double_quant=True, 135 | ) 136 | # Load base model 137 | model = AutoModelForCausalLM.from_pretrained(base_model, 
quantization_config=quant_config, device_map=device, trust_remote_code=True) 138 | 139 |     model.config.use_cache = False 140 |     model.config.pretraining_tp = 1 141 | 142 |     # Just like Alpaca, because we allow to add history in the prompts, it makes more sense to do left-padding to have the most informative text at the end. 143 |     # In this case, we need a different pad token than EOS because we actually do _not_ pad end of sentence. 144 |     if padding == 'left': 145 |         tokenizer.pad_token_id = 0 146 |     else: 147 |         tokenizer.pad_token_id = tokenizer.eos_token_id 148 |     tokenizer.padding_side = padding 149 | 150 |     # Set up LoRA configuration 151 |     peft_params = LoraConfig( 152 |         lora_alpha=16, 153 |         lora_dropout=0.1, 154 |         r=64, 155 |         bias="none", 156 |         task_type="CAUSAL_LM", 157 |         target_modules = 'all-linear', 158 |     ) 159 | 160 |     # Pass quant and lora to trainer 161 |     use_fp16 = not(noquant or quant8) 162 |     training_params = TrainingArguments( 163 |         output_dir="./results", 164 |         num_train_epochs=num_train_epochs, 165 |         per_device_train_batch_size=per_device_train_batch_size, 166 |         gradient_accumulation_steps=1, 167 |         optim="paged_adamw_32bit", 168 |         save_steps=1000, 169 |         logging_steps=500, 170 |         learning_rate=2e-4, 171 |         weight_decay=0.001, 172 |         fp16=use_fp16, 173 |         bf16=False, 174 |         max_grad_norm=0.3, 175 |         max_steps=-1, 176 |         warmup_ratio=0.03, 177 |         group_by_length=True, 178 |         lr_scheduler_type="constant", 179 |         report_to="tensorboard" 180 |     ) 181 |     trainer = SFTTrainer( 182 |         model=model, 183 |         train_dataset=prompts['train'], 184 |         peft_config=peft_params, 185 |         dataset_text_field=base_dataset_text_field, 186 |         max_seq_length=max_seq_length, 187 |         tokenizer=tokenizer, 188 |         args=training_params, 189 |         packing=False, 190 |     ) 191 | 192 |     # Before starting training, free up memory 193 |     torch.cuda.empty_cache() 194 |     print(f"[---- LLaMa2Lang ----] Starting training") 195 |     # Train the model 196 |     trainer.train() 197 | 198 |     # Check if output location is a valid directory 199 |     print(f"[---- LLaMa2Lang ----] Writing model and tokenizer out to {tuned_model}") 200 |     if os.path.isdir(tuned_model): 201 |         trainer.model.save_pretrained(tuned_model) 202 |         trainer.tokenizer.save_pretrained(tuned_model) 203 |     else: 204 |         # Try to push to hub, requires HF_TOKEN environment variable to be set, see https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken 205 |         trainer.model.push_to_hub(tuned_model) 206 |         trainer.tokenizer.push_to_hub(tuned_model) 207 | 208 | if __name__ == "__main__": 209 |     main() 210 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 |                                  Apache License 2 |                            Version 2.0, January 2004 3 |                         http://www.apache.org/licenses/ 4 | 5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 |    1. Definitions. 8 | 9 |       "License" shall mean the terms and conditions for use, reproduction, 10 |       and distribution as defined by Sections 1 through 9 of this document. 11 | 12 |       "Licensor" shall mean the copyright owner or entity authorized by 13 |       the copyright owner that is granting the License. 14 | 15 |       "Legal Entity" shall mean the union of the acting entity and all 16 |       other entities that control, are controlled by, or are under common 17 |       control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from datasets import load_dataset 4 | from transformers import BitsAndBytesConfig 5 | import json 6 | import re 7 | import gc 8 | from tqdm import tqdm 9 | import argparse 10 | from translators.m2m import M2MTranslator 11 | from translators.madlad import MADLADTranslator 12 | from translators.mbart import mBARTTranslator 13 | from translators.nllb import NLLBTranslator 14 | from translators.opus import OPUSTranslator 15 | from translators.seamless_m4t_v2 import Seamless_M4T_V2 16 | from translators.towerinstruct import TowerInstructTranslator 17 | from translators.gemini_pro import GeminiProTranslator 18 | 19 | 20 | # Find the max checkpoint number to continue from 21 | def find_largest_checkpoint(checkpoint_location): 22 | pattern = r'upto_(\d+).json' 23 | files = os.listdir(checkpoint_location) 24 | numbers = [int(re.search(pattern, file).group(1)) for file in files if re.match(pattern, file)] 25 | if numbers: 26 | return max(numbers) 27 | else: 28 | return 0 29 | 30 | 31 | # Group all records in a dataset by language so we can use a single model in a batched fashion 32 | def group_records_by_language(dataset, lang_field): 33 | grouped_records = {} 34 | for record in dataset: 35 | lang = record[lang_field] 36 | if lang not in grouped_records: 37 | grouped_records[lang] = [] 38 | grouped_records[lang].append(record) 39 | return grouped_records 40 | 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser( 44 | description="Translate an instruct/RLHF dataset to a given target language using a variety of translation models") 45 | subparsers = parser.add_subparsers(dest='model', help='The model/architecture used for translation.') 46 | 47 | parser.add_argument('target_lang', type=str, 48 | help="The target language. Make sure you use language codes defined by the translation model you are using.") 49 | parser.add_argument('checkpoint_location', type=str, 50 | help="The folder the script will write (JSONized) checkpoint files to. Folder will be created if it doesn't exist.") 51 | 52 | parser.add_argument('--quant8', action='store_true', 53 | help='Optional flag to load the translation model in 8 bits. Decreases memory usage, increases running time') 54 | parser.add_argument('--quant4', action='store_true', 55 | help='Optional flag to load the translation model in 4 bits. Decreases memory usage, increases running time') 56 | parser.add_argument('--base_dataset', type=str, default="OpenAssistant/oasst1", 57 | help="The base dataset to translate, defaults to OpenAssistant/oasst1") 58 | parser.add_argument('--base_dataset_text_field', type=str, default="text", 59 | help="The base dataset's column name containing the actual text to translate. Defaults to text") 60 | parser.add_argument('--base_dataset_lang_field', type=str, default="lang", 61 | help="The base dataset's column name containing the language the source text was written in. 
Defaults to lang") 62 | parser.add_argument('--checkpoint_n', type=int, default=400, 63 | help="An integer representing how often a checkpoint file will be written out. To start off, 400 is a reasonable number.") 64 | parser.add_argument('--batch_size', type=int, default=10, 65 | help="The batch size for a single translation model. Adjust based on your GPU capacity. Default is 10.") 66 | parser.add_argument('--max_length', type=int, default=None, 67 | help='How much tokens to generate at most. More tokens might be more accurate for lengthy input but creates a risk of running out of memory. Default is unlimited.') 68 | parser.add_argument('--cpu', action='store_true', 69 | help="Forces usage of CPU. By default GPU is taken if available.") 70 | parser.add_argument('--source_lang', type=str, default=None, 71 | help="Source language to select from OASST based on lang property of dataset") 72 | parser.add_argument('--start_index', type=int, default=None, 73 | help="Set start index for processing in dataset by range") 74 | parser.add_argument('--end_index', type=int, default=None, 75 | help="Set end index for processing in dataset by range") 76 | 77 | parser_opus = subparsers.add_parser('opus', help='Translate the dataset using HelsinkiNLP OPUS models.') 78 | 79 | parser_mbart = subparsers.add_parser('mbart', help='Translate the dataset using mBART.') 80 | 81 | parser_madlad = subparsers.add_parser('madlad', help='Translate the dataset using Google\'s MADLAD models.') 82 | parser_madlad.add_argument('--model_size', type=str, default="3b", choices=['3b', '7b', '7b-bt', '10b'], 83 | help='The size of the MADLAD model to use. 7b-bt is the backtrained version (best to avoid unless you know what you are doing).') 84 | 85 | parser_m2m = subparsers.add_parser('m2m', help='Translate the dataset using Facebook\'s M2M models.') 86 | parser_m2m.add_argument('--model_size', type=str, default="418M", choices=['418M', '1.2B'], 87 | help='The size of the M2M model to use. Default is 418M') 88 | 89 | parser_nllb = subparsers.add_parser('nllb', help='Translate the dataset using Facebook\'s NLLB models.') 90 | parser_nllb.add_argument('--model_size', type=str, default="distilled-600M", 91 | choices=['distilled-600M', '1.3B', 'distilled-1.3B', '3.3B'], 92 | help='The size of the NLLB model to use. Default is distilled-600M') 93 | 94 | parser_seamlessv2 = subparsers.add_parser('seamless_m4t_v2', 95 | help='Translate the dataset using Facebook\'s SeamlessM4T-v2 multimodal models.') 96 | 97 | parser_towerinstruct = subparsers.add_parser('towerinstruct', help='Translate the dataset using Unbabel\'s Tower Instruct. 
Make sure your target language is in the 10 languages supported by the model.') 98 | 99 | parser_gemini_pro = subparsers.add_parser('gemini_pro', help='Gemini Pro translation model') 100 | 101 | parser_gemini_pro.add_argument('--auth_token', type=str, default=None, 102 | help='Gemini Pro retrieved here https://makersuite.google.com/app/apikey') 103 | # Default arguments shared across models 104 | args = parser.parse_args() 105 | model = args.model 106 | target_lang = args.target_lang 107 | checkpoint_location = args.checkpoint_location 108 | quant4 = args.quant4 109 | quant8 = args.quant8 110 | base_dataset = args.base_dataset 111 | base_dataset_text_field = args.base_dataset_text_field 112 | base_dataset_lang_field = args.base_dataset_lang_field 113 | checkpoint_n = args.checkpoint_n 114 | batch_size = args.batch_size 115 | force_cpu = args.cpu 116 | selected_source_language = args.source_lang 117 | start_index = args.start_index 118 | end_index = args.end_index 119 | 120 | device = torch.device("cuda:0" if torch.cuda.is_available() and not (force_cpu) else "cpu") 121 | 122 | if checkpoint_n % batch_size != 0: 123 | raise Exception("Checkpoint N must be a multiple of batch size!") 124 | 125 | # Load the base dataset that we want to translate 126 | dataset = load_dataset(base_dataset) 127 | 128 | # Set up quantization configs if required 129 | quant4_config = BitsAndBytesConfig( 130 | load_in_4bit=True, 131 | bnb_4bit_use_double_quant=True, 132 | bnb_4bit_quant_type="nf4", 133 | bnb_4bit_compute_dtype=torch.bfloat16 134 | ) 135 | 136 | print(f"[---- LLaMa2Lang ----] Starting translation of {base_dataset} using {model} on device {device}") 137 | 138 | # Load the correct model 139 | if model == 'madlad': 140 | translator = MADLADTranslator(device, quant4, quant4_config, quant8, args.max_length, args.model_size) 141 | elif model == 'm2m': 142 | translator = M2MTranslator(device, quant4, quant4_config, quant8, args.max_length, args.model_size) 143 | elif model == 'mbart': 144 | translator = mBARTTranslator(device, quant4, quant4_config, quant8, args.max_length) 145 | elif model == 'nllb': 146 | translator = NLLBTranslator(device, quant4, quant4_config, quant8, args.max_length, args.model_size) 147 | elif model == 'seamless_m4t_v2': 148 | translator = Seamless_M4T_V2(device, quant4, quant4_config, quant8, args.max_length) 149 | elif model == 'towerinstruct': 150 | translator = TowerInstructTranslator(device, quant4, quant4_config, quant8, args.max_length) 151 | elif model == 'gemini_pro': 152 | translator = GeminiProTranslator(args.auth_token, args.max_length) 153 | else: 154 | translator = OPUSTranslator(device, quant4, quant4_config, quant8, args.max_length) 155 | 156 | # Loop through the actual data and translate 157 | with tqdm(total=sum(len(split) for split in dataset.values())) as pbar: 158 | for fold in dataset: 159 | records_by_lang = group_records_by_language(dataset[fold], base_dataset_lang_field) 160 | if selected_source_language is not None: 161 | records = records_by_lang[selected_source_language] 162 | translate_records(base_dataset_lang_field, base_dataset_text_field, batch_size, checkpoint_location, 163 | checkpoint_n, device, fold, pbar, records, selected_source_language, target_lang, translator, 164 | last_checkpoint=start_index, end_of_range=end_index) 165 | else: 166 | for source_lang, records in records_by_lang.items(): 167 | translate_records(base_dataset_lang_field, base_dataset_text_field, batch_size, checkpoint_location, 168 | checkpoint_n, device, fold, pbar, records, 
source_lang, target_lang, translator, 169 | last_checkpoint=start_index, end_of_range=end_index) 170 | # One source language down, release the memory 171 | gc.collect() 172 | if str(device).startswith('cuda'): 173 | torch.cuda.empty_cache() 174 | 175 | 176 | def translate_records(base_dataset_lang_field, base_dataset_text_field, batch_size, checkpoint_location, checkpoint_n, 177 | device, fold, pbar, records, source_lang, target_lang, translator, last_checkpoint = None, 178 | end_of_range = None): 179 | lang_checkpoint_location = os.path.join(checkpoint_location, fold, f'from_{source_lang}') 180 | os.makedirs(lang_checkpoint_location, exist_ok=True) 181 | last_checkpoint_n = last_checkpoint if last_checkpoint is not None else find_largest_checkpoint(lang_checkpoint_location) 182 | translated_texts = [] 183 | records_length = len(records) if end_of_range is None else end_of_range 184 | print( 185 | f'[---- LLaMa2Lang ----] Got {len(records)} records for source language {source_lang}, skipping {last_checkpoint_n}, will process till {records_length}') 186 | pbar.total = records_length 187 | pbar.update(last_checkpoint_n) 188 | last_cnt = last_checkpoint_n 189 | for cnt in range(last_checkpoint_n, records_length, batch_size): 190 | # Translate a full batch 191 | batch = records[cnt:cnt + batch_size] 192 | texts_to_translate = [record[base_dataset_text_field] for record in batch] 193 | # Offload translation to class implementation 194 | translated_batch = translator.translate(texts_to_translate, source_lang, target_lang) 195 | if translated_batch is not None: 196 | # Combine original record with translated text 197 | for record, translation in zip(batch, translated_batch): 198 | record[base_dataset_text_field] = translation 199 | record[base_dataset_lang_field] = target_lang 200 | translated_texts.append(record) 201 | 202 | pbar.update(batch_size) 203 | 204 | # Write out checkpoint file 205 | if (cnt + batch_size) % checkpoint_n == 0 and cnt != 0: 206 | print( 207 | f"[---- LLaMa2Lang ----] Writing out checkpoint #{str(cnt + batch_size)} for source language {source_lang}") 208 | with open(os.path.join(lang_checkpoint_location, f'upto_{str(cnt + batch_size)}.json'), 'w', 209 | encoding='utf-8') as f: 210 | json.dump(translated_texts, f) 211 | translated_texts = [] 212 | # Free some memory 213 | gc.collect() 214 | if str(device).startswith('cuda'): 215 | torch.cuda.empty_cache() 216 | last_cnt = cnt 217 | # Write checkpoint 218 | batch = records[last_cnt:] 219 | checkpoint_file = os.path.join(lang_checkpoint_location, f'upto_{last_cnt}.json') 220 | with open(checkpoint_file, 'w', encoding='utf-8') as f: 221 | json.dump(batch, f) 222 | 223 | 224 | if __name__ == "__main__": 225 | main() 226 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚀 Now with LLaMa3 support 🚀 2 | 3 | 4 | # LLaMa2lang v0.6 5 | This repository contains convenience scripts to finetune LLaMa3-8B (or any other foundation model) for chat towards any language (that isn't English). The rationale behind this is that LLaMa3 is trained on primarily English data and while it works to some extent for other languages, its performance is poor compared to English. 6 | 7 | Combine the power of fine-tuning with the power of RAG - check out our [RAG Me Up repository](https://github.com/UnderstandLingBV/RAGMeUp) on RAG which can be used on top of your models tuned with LLaMa2Lang. 
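The translate.py script above writes its output as JSON checkpoint files (upto_N.json) per fold and source language. For reference, here is a minimal sketch of how such checkpoint files could be combined back into a Hugging Face dataset. This is not the repository's actual combine_checkpoints.py (its source is not reproduced here); the folder layout and record format follow translate_records above, while the function name, the glob pattern and the example paths are illustrative assumptions.

```
# Illustrative sketch only -- not the repository's combine_checkpoints.py.
# Assumes the checkpoint layout written by translate.py:
#   <checkpoint_location>/<fold>/from_<source_lang>/upto_<N>.json
import glob
import json
import os

import pandas as pd
from datasets import Dataset, DatasetDict

def combine_checkpoints(checkpoint_location):
    folds = {}
    for fold in os.listdir(checkpoint_location):
        records = []
        pattern = os.path.join(checkpoint_location, fold, 'from_*', 'upto_*.json')
        for checkpoint_file in sorted(glob.glob(pattern)):
            # Each checkpoint file holds a JSON array of translated records
            with open(checkpoint_file, encoding='utf-8') as f:
                records.extend(json.load(f))
        folds[fold] = Dataset.from_pandas(pd.DataFrame(records))
    return DatasetDict(folds)

# Example usage (paths are placeholders):
# dataset = combine_checkpoints('./output_nl')
# dataset.save_to_disk('./oasst1_nl')  # or dataset.push_to_hub('<hf_user>/oasst1_nl')
```

In practice, use the provided combine_checkpoints.py as described in the Usage section below; the sketch only illustrates what the checkpoint files contain.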
8 | 9 | # TL;DR 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | 14 | # Translate OASST1 to target language 15 | python translate.py m2m target_lang checkpoint_location 16 | 17 | # Combine the checkpoint files into a dataset 18 | python combine_checkpoints.py input_folder output_location 19 | 20 | # Finetune 21 | python finetune.py tuned_model dataset_name instruction_prompt 22 | 23 | # Optionally finetune with DPO (RLHF) 24 | python finetune_dpo.py tuned_model dataset_name instruction_prompt 25 | 26 | # Run inference 27 | python run_inference.py model_name instruction_prompt input 28 | ``` 29 | 30 | # What it does 31 | The process we follow to tune a foundation model such as LLaMa3 for a specific language is as follows: 32 | 33 | 1. Load a dataset that contains Q&A/instruction pairs. 34 | 2. Translate the entire dataset to a given target language. 35 | 3. Load the translated dataset and extract threads by recursively selecting prompts with their respective answers with the highest rank only, through to subsequent prompts, etc. 36 | 4. Turn the threads into prompts following a given template (customizable). 37 | 5. Use QLoRA and PEFT to finetune a base foundation model's instruct finetune on this dataset. 38 | 6. * Use QLoRA and PEFT to finetune with [DPO](https://huggingface.co/docs/trl/main/en/dpo_trainer) to extend the model's capacities even further and teach it preferred answers over rejected ones. Note that your base dataset must have this information. 39 | * Alternatively to DPO, you can achieve the same with [ORPO](https://huggingface.co/docs/trl/main/en/orpo_trainer) 40 | 7. Run inference using the newly trained model. 41 | 42 | # Supported paradigms 43 | ## Translation 44 | * OPUS 45 | * M2M 46 | * MADLAD 47 | * mBART 48 | * NLLB 49 | * Seamless (Large only) 50 | * Tower Instruct (Can correct spelling mistakes) 51 | ## Base datasets 52 | The following have been tested but potentially more will work 53 | * OASST1 54 | * OASST2 55 | ## Supported foundation models 56 | * **LLaMa3** 57 | * LLaMa2 58 | * Mistral 59 | * (Unofficial) Mixtral 8x7B 60 | 61 | # Roadmap 62 | * [L2L-6] Investigate interoperability with other libraries (Axolotl, llamacpp, unsloth) 63 | * [L2L-7] Allow for different quantizations next to QLoRA (GGUF, GPTQ, AWQ) 64 | * [L2L-10] Support extending the tokenizer and vocabulary 65 | 66 | ## Cost and runtime 67 | 68 | The above process can be fully run on a free Google Colab T4 GPU. The last step however, can only be successfully run with short enough context windows and a batch of at most 2. In addition, the translation in step 2 takes about 36 hours in total for any given language so should be run in multiple steps if you want to stick with a free Google Colab GPU. 69 | 70 | Our fine-tuned models for step 5 were performed using an A40 on [vast.ai](https://vast.ai/) and cost us less than a dollar for each model, completing in about 1.5 hours. 71 | 72 | # Usage 73 | 1. Make sure pytorch is installed and working for your environment (use of CUDA preferable): https://pytorch.org/get-started/locally/ 74 | 75 | 2. Clone the repo and install the requirements. 76 | 77 | `pip install -r requirements.txt` 78 | 79 | 2. Translate your base dataset to your designated target language. 
80 | 81 | ``` 82 | usage: translate.py [-h] [--quant8] [--quant4] [--base_dataset BASE_DATASET] [--base_dataset_text_field BASE_DATASET_TEXT_FIELD] [--base_dataset_lang_field BASE_DATASET_LANG_FIELD] 83 | [--checkpoint_n CHECKPOINT_N] [--batch_size BATCH_SIZE] [--max_length MAX_LENGTH] [--cpu] [--source_lang SOURCE_LANG] 84 | {opus,mbart,madlad,m2m,nllb,seamless_m4t_v2,towerinstruct} ... target_lang checkpoint_location 85 | 86 | Translate an instruct/RLHF dataset to a given target language using a variety of translation models 87 | 88 | positional arguments: 89 | {opus,mbart,madlad,m2m,nllb,seamless_m4t_v2,towerinstruct} 90 | The model/architecture used for translation. 91 | opus Translate the dataset using HelsinkiNLP OPUS models. 92 | mbart Translate the dataset using mBART. 93 | madlad Translate the dataset using Google's MADLAD models. 94 | m2m Translate the dataset using Facebook's M2M models. 95 | nllb Translate the dataset using Facebook's NLLB models. 96 | seamless_m4t_v2 Translate the dataset using Facebook's SeamlessM4T-v2 multimodal models. 97 | towerinstruct Translate the dataset using Unbabel's Tower Instruct. Make sure your target language is in the 10 languages supported by the model. 98 | target_lang The target language. Make sure you use language codes defined by the translation model you are using. 99 | checkpoint_location The folder the script will write (JSONized) checkpoint files to. Folder will be created if it doesn't exist. 100 | 101 | options: 102 | -h, --help show this help message and exit 103 | --quant8 Optional flag to load the translation model in 8 bits. Decreases memory usage, increases running time 104 | --quant4 Optional flag to load the translation model in 4 bits. Decreases memory usage, increases running time 105 | --base_dataset BASE_DATASET 106 | The base dataset to translate, defaults to OpenAssistant/oasst1 107 | --base_dataset_text_field BASE_DATASET_TEXT_FIELD 108 | The base dataset's column name containing the actual text to translate. Defaults to text 109 | --base_dataset_lang_field BASE_DATASET_LANG_FIELD 110 | The base dataset's column name containing the language the source text was written in. Defaults to lang 111 | --checkpoint_n CHECKPOINT_N 112 | An integer representing how often a checkpoint file will be written out. To start off, 400 is a reasonable number. 113 | --batch_size BATCH_SIZE 114 | The batch size for a single translation model. Adjust based on your GPU capacity. Default is 10. 115 | --max_length MAX_LENGTH 116 | How much tokens to generate at most. More tokens might be more accurate for lengthy input but creates a risk of running out of memory. Default is unlimited. 117 | --cpu Forces usage of CPU. By default GPU is taken if available. 118 | --source_lang SOURCE_LANG 119 | Source language to select from OASST based on lang property of dataset 120 | ``` 121 | 122 | If you want more parameters for the different translation models, run: 123 | ``` 124 | python translate.py [MODEL] -h 125 | ``` 126 | 127 | Be sure to specify model-specific parameters first before you specify common parameters from the list above. 
Example calls: 128 | ``` 129 | # Using M2M with 4bit quantization and different batch sizes to translate Dutch 130 | python translate.py m2m nl ./output_nl --quant4 --batch_size 20 131 | 132 | # Using madlad 7B with 8bit quantization for German with different max_length 133 | python translate.py madlad --model_size 7b de ./output_de --quant8 --batch_size 5 --max_length 512 134 | 135 | # Be sure to use target language codes that the model you use understands 136 | python translate.py mbart xh_ZA ./output_xhosa 137 | python translate.py nllb nld_Latn ./output_nl 138 | ``` 139 | 140 | 3. Combine the JSON arrays from the checkpoints' files into a Huggingface Dataset and then either write it to disk or publish it to Huggingface. The script will try to write to disk by default and fall back to publishing to Huggingface if the folder doesn't exist on disk. For publishing to Huggingface, make sure you have your `HF_TOKEN` environment variable set up as per [the documentation](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hftoken). 141 | 142 | ``` 143 | usage: combine_checkpoints.py [-h] input_folder output_location 144 | 145 | Combine checkpoint files from translation. 146 | 147 | positional arguments: 148 |   input_folder     The checkpoint folder used in translation, with the target language appended. 149 |                    Example: "./output_nl". 150 |   output_location  Where to write the Huggingface Dataset. Can be a disk location or a Huggingface 151 |                    Dataset repository. 152 | 153 | options: 154 |   -h, --help       show this help message and exit 155 | ``` 156 | 157 | 5. Turn the translated messages into chat/instruct/prompt threads and finetune a foundation model's instruct using LoRA and PEFT. 158 | 159 | ``` 160 | usage: finetune.py [-h] [--base_model BASE_MODEL] [--base_dataset_text_field BASE_DATASET_TEXT_FIELD] [--base_dataset_rank_field BASE_DATASET_RANK_FIELD] [--base_dataset_id_field BASE_DATASET_ID_FIELD] [--base_dataset_parent_field BASE_DATASET_PARENT_FIELD] 161 |                    [--base_dataset_role_field BASE_DATASET_ROLE_FIELD] [--quant8] [--noquant] [--max_seq_length MAX_SEQ_LENGTH] [--num_train_epochs NUM_TRAIN_EPOCHS] [--batch_size BATCH_SIZE] [--threads_output_name THREADS_OUTPUT_NAME] [--thread_template THREAD_TEMPLATE] 162 |                    [--padding PADDING] 163 |                    tuned_model dataset_name instruction_prompt 164 | 165 | Finetune a base instruct/chat model using (Q)LoRA and PEFT 166 | 167 | positional arguments: 168 |   tuned_model           The name of the resulting tuned model. 169 |   dataset_name          The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script. 170 |   instruction_prompt    An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English." 171 | 172 | options: 173 |   -h, --help            show this help message and exit 174 |   --base_model BASE_MODEL 175 |                         The base foundation model. Default is "NousResearch/Meta-Llama-3-8B-Instruct". 176 |   --base_dataset_text_field BASE_DATASET_TEXT_FIELD 177 |                         The dataset's column name containing the actual text to translate. Defaults to text 178 |   --base_dataset_rank_field BASE_DATASET_RANK_FIELD 179 |                         The dataset's column name containing the rank of an answer given to a prompt. Defaults to rank 180 |   --base_dataset_id_field BASE_DATASET_ID_FIELD 181 |                         The dataset's column name containing the id of a text. 
Defaults to message_id 182 | --base_dataset_parent_field BASE_DATASET_PARENT_FIELD 183 | The dataset's column name containing the parent id of a text. Defaults to parent_id 184 | --base_dataset_role_field BASE_DATASET_ROLE_FIELD 185 | The dataset's column name containing the role of the author of the text (e.g. prompter, assistant). Defaults to role 186 | --quant8 Finetunes the model in 8 bits. Requires more memory than the default 4 bit. 187 | --noquant Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit. 188 | --max_seq_length MAX_SEQ_LENGTH 189 | The maximum sequence length to use in finetuning. Should most likely line up with your base model's default max_seq_length. Default is 512. 190 | --num_train_epochs NUM_TRAIN_EPOCHS 191 | Number of epochs to use. 2 is default and has been shown to work well. 192 | --batch_size BATCH_SIZE 193 | The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4 194 | --threads_output_name THREADS_OUTPUT_NAME 195 | If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub. 196 | --thread_template THREAD_TEMPLATE 197 | A file containing the thread template to use. Default is threads/template_default.txt 198 | --padding PADDING What padding to use; can be either left or right. 199 | ``` 200 | 201 | 6.1 [OPTIONAL] Finetune using DPO (similar to RLHF) 202 | ``` 203 | usage: finetune_dpo.py [-h] [--base_model BASE_MODEL] [--base_dataset_text_field BASE_DATASET_TEXT_FIELD] [--base_dataset_rank_field BASE_DATASET_RANK_FIELD] [--base_dataset_id_field BASE_DATASET_ID_FIELD] [--base_dataset_parent_field BASE_DATASET_PARENT_FIELD] [--quant8] 204 | [--noquant] [--max_seq_length MAX_SEQ_LENGTH] [--max_prompt_length MAX_PROMPT_LENGTH] [--num_train_epochs NUM_TRAIN_EPOCHS] [--batch_size BATCH_SIZE] [--threads_output_name THREADS_OUTPUT_NAME] [--thread_template THREAD_TEMPLATE] [--max_steps MAX_STEPS] 205 | [--padding PADDING] 206 | tuned_model dataset_name instruction_prompt 207 | 208 | Finetune a base instruct/chat model using (Q)LoRA and PEFT with DPO (RLHF) 209 | 210 | positional arguments: 211 | tuned_model The name of the resulting tuned model. 212 | dataset_name The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script. 213 | instruction_prompt An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English." 214 | 215 | options: 216 | -h, --help show this help message and exit 217 | --base_model BASE_MODEL 218 | The base foundation model. Default is "NousResearch/Meta-Llama-3-8B-Instruct". 219 | --base_dataset_text_field BASE_DATASET_TEXT_FIELD 220 | The dataset's column name containing the actual text to translate. Defaults to text 221 | --base_dataset_rank_field BASE_DATASET_RANK_FIELD 222 | The dataset's column name containing the rank of an answer given to a prompt. Defaults to rank 223 | --base_dataset_id_field BASE_DATASET_ID_FIELD 224 | The dataset's column name containing the id of a text. Defaults to message_id 225 | --base_dataset_parent_field BASE_DATASET_PARENT_FIELD 226 | The dataset's column name containing the parent id of a text. Defaults to parent_id 227 | --quant8 Finetunes the model in 8 bits. Requires more memory than the default 4 bit. 228 | --noquant Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit.
229 | --max_seq_length MAX_SEQ_LENGTH 230 | The maximum sequence length to use in finetuning. Should most likely line up with your base model's default max_seq_length. Default is 512. 231 | --max_prompt_length MAX_PROMPT_LENGTH 232 | The maximum length of the prompts to use. Default is 512. 233 | --num_train_epochs NUM_TRAIN_EPOCHS 234 | Number of epochs to use. 2 is default and has been shown to work well. 235 | --batch_size BATCH_SIZE 236 | The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4 237 | --threads_output_name THREADS_OUTPUT_NAME 238 | If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub. 239 | --thread_template THREAD_TEMPLATE 240 | A file containing the thread template to use. Default is threads/template_default.txt 241 | --max_steps MAX_STEPS 242 | The maximum number of steps to run DPO for. Default is -1, which will run the data through fully for the number of epochs, but this will be very time-consuming. 243 | --padding PADDING What padding to use; can be either left or right. 244 | ``` 245 | 246 | 6.2 [OPTIONAL] Finetune using ORPO (similar to RLHF) 247 | ``` 248 | usage: finetune_orpo.py [-h] [--base_model BASE_MODEL] [--base_dataset_text_field BASE_DATASET_TEXT_FIELD] [--base_dataset_rank_field BASE_DATASET_RANK_FIELD] [--base_dataset_id_field BASE_DATASET_ID_FIELD] [--base_dataset_parent_field BASE_DATASET_PARENT_FIELD] [--quant8] 249 | [--noquant] [--max_seq_length MAX_SEQ_LENGTH] [--max_prompt_length MAX_PROMPT_LENGTH] [--num_train_epochs NUM_TRAIN_EPOCHS] [--batch_size BATCH_SIZE] [--threads_output_name THREADS_OUTPUT_NAME] [--thread_template THREAD_TEMPLATE] [--max_steps MAX_STEPS] 250 | [--padding PADDING] 251 | tuned_model dataset_name instruction_prompt 252 | 253 | Finetune a base instruct/chat model using (Q)LoRA and PEFT with ORPO (RLHF) 254 | 255 | positional arguments: 256 | tuned_model The name of the resulting tuned model. 257 | dataset_name The name of the dataset to use for fine-tuning. This should be the output of the combine_checkpoints script. 258 | instruction_prompt An instruction message added to every prompt given to the chatbot to force it to answer in the target language. Example: "You are a generic chatbot that always answers in English." 259 | 260 | options: 261 | -h, --help show this help message and exit 262 | --base_model BASE_MODEL 263 | The base foundation model. Default is "NousResearch/Meta-Llama-3-8B-Instruct". 264 | --base_dataset_text_field BASE_DATASET_TEXT_FIELD 265 | The dataset's column name containing the actual text to translate. Defaults to text 266 | --base_dataset_rank_field BASE_DATASET_RANK_FIELD 267 | The dataset's column name containing the rank of an answer given to a prompt. Defaults to rank 268 | --base_dataset_id_field BASE_DATASET_ID_FIELD 269 | The dataset's column name containing the id of a text. Defaults to message_id 270 | --base_dataset_parent_field BASE_DATASET_PARENT_FIELD 271 | The dataset's column name containing the parent id of a text. Defaults to parent_id 272 | --quant8 Finetunes the model in 8 bits. Requires more memory than the default 4 bit. 273 | --noquant Do not quantize the finetuning. Requires more memory than the default 4 bit and optional 8 bit. 274 | --max_seq_length MAX_SEQ_LENGTH 275 | The maximum sequence length to use in finetuning. Should most likely line up with your base model's default max_seq_length. Default is 512.
276 | --max_prompt_length MAX_PROMPT_LENGTH 277 | The maximum length of the prompts to use. Default is 512. 278 | --num_train_epochs NUM_TRAIN_EPOCHS 279 | Number of epochs to use. 2 is default and has been shown to work well. 280 | --batch_size BATCH_SIZE 281 | The batch size to use in finetuning. Adjust to fit in your GPU vRAM. Default is 4 282 | --threads_output_name THREADS_OUTPUT_NAME 283 | If specified, the threads created in this script for finetuning will also be saved to disk or HuggingFace Hub. 284 | --thread_template THREAD_TEMPLATE 285 | A file containing the thread template to use. Default is threads/template_default.txt 286 | --max_steps MAX_STEPS 287 | The maximum number of steps to run ORPO for. Default is -1, which will run the data through fully for the number of epochs, but this will be very time-consuming. 288 | --padding PADDING What padding to use; can be either left or right. 289 | ``` 290 | 291 | 7. Run inference using the newly created QLoRA model. 292 | 293 | ``` 294 | usage: run_inference.py [-h] model_name instruction_prompt input 295 | 296 | Script to run inference on a tuned model. 297 | 298 | positional arguments: 299 | model_name The name of the tuned model that you pushed to Huggingface in the previous 300 | step. 301 | instruction_prompt An instruction message added to every prompt given to the chatbot to force 302 | it to answer in the target language. 303 | input The actual chat input prompt. The script is only meant for testing purposes 304 | and exits after answering. 305 | 306 | options: 307 | -h, --help show this help message and exit 308 | 309 | ``` 310 | 311 | # Choosing the right translation model 312 | > How do I know which translation model to choose for my target language? 313 | 314 | **We got you covered** with our `benchmark.py` script, which helps you make an educated guess (the dataset we use is the same one the OPUS models were trained on, so the outcomes are always favorable towards OPUS). For usage, see the help of this script below. Models are loaded in 4-bit quantization and run on a small sample of the OPUS books subset. 315 | 316 | Be sure to use the most commonly occurring languages in your base dataset as source_language and your target translation language as target_language. For OASST1, for example, be sure to at least run `en` and `es` as source languages. 317 | 318 | ``` 319 | usage: benchmark.py [-h] [--cpu] [--start START] [--n N] [--max_length MAX_LENGTH] source_language target_language included_models 320 | 321 | Benchmark all the different translation models for a specific source and target language to find out which performs best. This uses 4bit quantization to limit GPU usage. Note: 322 | the outcomes are indicative - you cannot assume correctness of the BLEU and CHRF scores but you can compare models against each other relatively. 323 | 324 | positional arguments: 325 | source_language The source language you want to test for. Check your dataset to see which languages occur most frequently, or use English as a good start. 326 | target_language The target language you want to test for. This should be the language you want to apply the translate script on. Note: in benchmark, we use 2-character 327 | language codes, in contrast to translate.py where you need to specify whatever your model expects. 328 | included_models Comma-separated list of models to include.
Allowed values are: opus, m2m_418m, m2m_1.2b, madlad_3b, madlad_7b, madlad_10b, madlad_7bbt, mbart, 329 | nllb_distilled600m, nllb_1.3b, nllb_distilled1.3b, nllb_3.3b, seamless 330 | 331 | options: 332 | -h, --help show this help message and exit 333 | --cpu Forces usage of CPU. By default GPU is taken if available. 334 | --start START The starting offset in the OPUS books dataset to include sentences from. Defaults to 0. 335 | --n N The number of sentences to benchmark on. Defaults to 100. 336 | --max_length MAX_LENGTH 337 | How many tokens to generate at most. More tokens might be more accurate for lengthy input but create a risk of running out of memory. Default is 512. 338 | ``` 339 | 340 | # Datasets and models 341 | 342 | We have already created numerous datasets and models and will continue to create more. **Want to help democratize LLMs?** Clone the repo and create datasets and models for other languages, then create a PR. 343 | 344 | ## Translated oasst1 datasets 345 | 346 | | | | | | 347 | |---------|---------|---------|---------| 348 | | Dutch [UnderstandLing/oasst1_nl](https://huggingface.co/datasets/UnderstandLing/oasst1_nl) | Spanish [UnderstandLing/oasst1_es](https://huggingface.co/datasets/UnderstandLing/oasst1_es) | French [UnderstandLing/oasst1_fr](https://huggingface.co/datasets/UnderstandLing/oasst1_fr) | German [UnderstandLing/oasst1_de](https://huggingface.co/datasets/UnderstandLing/oasst1_de) | 349 | | Catalan [xaviviro/oasst1_ca](https://huggingface.co/datasets/xaviviro/oasst1_ca) | Portuguese [UnderstandLing/oasst1_pt](https://huggingface.co/datasets/UnderstandLing/oasst1_pt) | Arabic [HeshamHaroon/oasst-arabic](https://huggingface.co/datasets/HeshamHaroon/oasst-arabic) | Italian [UnderstandLing/oasst1_it](https://huggingface.co/datasets/UnderstandLing/oasst1_it) | 350 | | Russian [UnderstandLing/oasst1_ru](https://huggingface.co/datasets/UnderstandLing/oasst1_ru) | Hindi [UnderstandLing/oasst1_hi](https://huggingface.co/datasets/UnderstandLing/oasst1_hi) | Chinese [UnderstandLing/oasst1_zh](https://huggingface.co/datasets/UnderstandLing/oasst1_zh) | Polish [chrystians/oasst1_pl](https://huggingface.co/datasets/chrystians/oasst1_pl) | 351 | | Japanese [UnderstandLing/oasst1_jap](https://huggingface.co/datasets/UnderstandLing/oasst1_jap) | Basque [xezpeleta/oasst1_eu](https://huggingface.co/datasets/xezpeleta/oasst1_eu) | Bengali [UnderstandLing/oasst1_bn](https://huggingface.co/datasets/UnderstandLing/oasst1_bn) | Turkish [UnderstandLing/oasst1_tr](https://huggingface.co/datasets/UnderstandLing/oasst1_tr) | 352 | 353 | ## Language-specific ❗LLaMa3-8B❗ chat model adapters 354 | 355 | Make sure you have access to Meta's [LLaMa3-8B model](https://huggingface.co/meta-llama/Meta-Llama-3-8B) and set your HF_TOKEN before using these models.
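As a quick sanity check, here is a hypothetical example of testing one of the adapters listed below with the `run_inference.py` script documented above. It assumes you have been granted access to the base model on Huggingface and that you replace `<your_hf_token>` with a real access token; the instruction prompt mirrors the Dutch examples in the Empirical performance section further down.

```
# Hypothetical example: expose your Hugging Face token, then ask the Dutch adapter a single question
export HF_TOKEN=<your_hf_token>
python run_inference.py UnderstandLing/Llama-3-8B-Instruct-nl "Je bent een generieke chatbot die altijd in het Nederlands antwoord geeft." "Wat is de hoofdstad van Nederland?"
```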
356 | 357 | | | | | | 358 | |---------|---------|---------|---------| 359 | | [UnderstandLing/Llama-3-8B-Instruct-nl](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-nl) Dutch | [UnderstandLing/Llama-3-8B-Instruct-es](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-es) Spanish | [UnderstandLing/Llama-3-8B-Instruct-fr](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-fr) French | [UnderstandLing/Llama-3-8B-Instruct-de](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-de) German | 360 | | [UnderstandLing/Llama-3-8B-Instruct-pt](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-pt) Portuguese | [UnderstandLing/Llama-3-8B-Instruct-it](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-it) Italian | [UnderstandLing/Llama-3-8B-Instruct-hi](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-hi) Hindi | [UnderstandLing/Llama-3-8B-Instruct-ru](https://huggingface.co/UnderstandLing/Llama-3-8B-Instruct-ru) Russian | 361 | 362 | 363 | ## Translated LLaMa2 thread chat prompt datasets 364 | 365 | | | | | | 366 | |---------|---------|---------|---------| 367 | | Dutch [UnderstandLing/oasst1_nl_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_nl_threads) | Spanish [UnderstandLing/oasst1_es_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_es_threads) | French [UnderstandLing/oasst1_fr_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_fr_threads) | German [UnderstandLing/oasst1_de_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_de_threads) | 368 | | Catalan [xaviviro/oasst1_ca_threads](https://huggingface.co/datasets/xaviviro/oasst1_ca_threads) | Portuguese [UnderstandLing/oasst1_pt_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_pt_threads) | Arabic [HeshamHaroon/oasst-arabic_threads](https://huggingface.co/datasets/HeshamHaroon/oasst-arabic_threads) | Italian [UnderstandLing/oasst1_it_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_it_threads) | 369 | | Russian [UnderstandLing/oasst1_ru_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_ru_threads) | Hindi [UnderstandLing/oasst1_hi_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_hi_threads) | Chinese [UnderstandLing/oasst1_zh_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_zh_threads) | Polish [chrystians/oasst1_pl_threads](https://huggingface.co/datasets/chrystians/oasst1_pl_threads) | 370 | | Japanese [UnderstandLing/oasst1_jap_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_jap_threads) | Basque [xezpeleta/oasst1_eu_threads](https://huggingface.co/datasets/xezpeleta/oasst1_eu_threads) | Bengali [UnderstandLing/oasst1_bn_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_bn_threads) | Turkish [UnderstandLing/oasst1_tr_threads](https://huggingface.co/datasets/UnderstandLing/oasst1_tr_threads) | 371 | 372 | ## Language-specific LLaMa2-7B chat model adapters 373 | | | | | | 374 | |---------|---------|---------|---------| 375 | | [UnderstandLing/llama-2-7b-chat-nl](https://huggingface.co/UnderstandLing/llama-2-7b-chat-nl) Dutch | [UnderstandLing/llama-2-7b-chat-es](https://huggingface.co/UnderstandLing/llama-2-7b-chat-es) Spanish | [UnderstandLing/llama-2-7b-chat-fr](https://huggingface.co/UnderstandLing/llama-2-7b-chat-fr) French |[UnderstandLing/llama-2-7b-chat-de](https://huggingface.co/UnderstandLing/llama-2-7b-chat-de) German | 376 | [xaviviro/llama-2-7b-chat-ca](https://huggingface.co/xaviviro/llama-2-7b-chat-ca) Catalan | 
[UnderstandLing/llama-2-7b-chat-pt](https://huggingface.co/UnderstandLing/llama-2-7b-chat-pt) Portuguese | [HeshamHaroon/llama-2-7b-chat-ar](https://huggingface.co/HeshamHaroon/llama-2-7b-chat-ar) Arabic | [UnderstandLing/llama-2-7b-chat-it](https://huggingface.co/UnderstandLing/llama-2-7b-chat-it) Italian | 377 | [UnderstandLing/llama-2-7b-chat-ru](https://huggingface.co/UnderstandLing/llama-2-7b-chat-ru) Russian | [UnderstandLing/llama-2-7b-chat-hi](https://huggingface.co/UnderstandLing/llama-2-7b-chat-hi) Hindi | [UnderstandLing/llama-2-7b-chat-zh](https://huggingface.co/UnderstandLing/llama-2-7b-chat-zh) Chinese | [chrystians/llama-2-7b-chat-pl-polish-polski](https://huggingface.co/chrystians/llama-2-7b-chat-pl-polish-polski) Polish | 378 | | [xezpeleta/llama-2-7b-chat-eu](https://huggingface.co/xezpeleta/llama-2-7b-chat-eu) Basque | [UnderstandLing/llama-2-7b-chat-bn](https://huggingface.co/UnderstandLing/llama-2-7b-chat-bn) Bengali | [UnderstandLing/llama-2-7b-chat-tr](https://huggingface.co/UnderstandLing/llama-2-7b-chat-tr) Turkish | | 379 | 380 | ## Language-specific Mistral chat model adapters 381 | | | | | | 382 | |---------|---------|---------|---------| 383 | | [UnderstandLing/Mistral-7B-Instruct-v0.2-nl](https://huggingface.co/UnderstandLing/Mistral-7B-Instruct-v0.2-nl) Dutch | [UnderstandLing/Mistral-7B-Instruct-v0.2-es](https://huggingface.co/UnderstandLing/Mistral-7B-Instruct-v0.2-es) Spanish | [UnderstandLing/Mistral-7B-Instruct-v0.2-de](https://huggingface.co/UnderstandLing/Mistral-7B-Instruct-v0.2-de) German | | 384 | 385 | ## Language-specific LLaMa2-13B chat model adapters 386 | | | | | | 387 | |---------|---------|---------|---------| 388 | | [UnderstandLing/llama-2-13b-chat-nl](https://huggingface.co/UnderstandLing/llama-2-13b-chat-nl) Dutch | [UnderstandLing/llama-2-13b-chat-es](https://huggingface.co/UnderstandLing/llama-2-13b-chat-es) Spanish | [UnderstandLing/llama-2-13b-chat-fr](https://huggingface.co/UnderstandLing/llama-2-13b-chat-fr) French | | 389 | 390 | ## Language-specific Mixtral-8x7B chat model adapters 391 | | | | | | 392 | |---------|---------|---------|---------| 393 | | [UnderstandLing/Mixtral-8x7B-Instruct-nl](https://huggingface.co/UnderstandLing/Mixtral-8x7B-Instruct-nl) Dutch | | | | 394 | 395 | # Empirical performance 396 | 397 | ## Dutch 398 | 399 | `[INST] <<SYS>> Je bent een generieke chatbot die altijd in het Nederlands antwoord geeft. <</SYS>> Wat is de hoofdstad van Nederland? [/INST] Amsterdam` 400 | 401 | `[INST] <<SYS>> Je bent een generieke chatbot die altijd in het Nederlands antwoord geeft. <</SYS>> Wat is de hoofdstad van Nederland? [/INST] Amsterdam[INST] Hoeveel inwoners heeft die stad? [/INST] 850 duizend inwoners (2023)` 402 | 403 | `[INST] <<SYS>> Je bent een generieke chatbot die altijd in het Nederlands antwoord geeft. <</SYS>> Wat is de hoofdstad van Nederland? [/INST] Amsterdam[INST] Hoeveel inwoners heeft die stad? [/INST] 850 duizend inwoners (2023)[INST] In welke provincie ligt die stad? [/INST] In de provincie Noord-Holland` 404 | 405 | `[INST] <<SYS>> Je bent een generieke chatbot die altijd in het Nederlands antwoord geeft. <</SYS>> Wie is de minister-president van Nederland? [/INST] Mark Rutte is sinds 2010 minister-president van Nederland. Hij is meerdere keren herkozen.` 406 | 407 | # FAQ 408 | 409 | * Q: Why do you translate the full OASST1/2 dataset first? Wouldn't it be faster to only translate the highest-ranked threads?
410 | * A: While you can gain quite a lot in terms of throughput time by first creating the threads and then translating them, we provide full OASST1/2 translations to the community as we believe they can be useful on their own. 411 | 412 | * Q: How well do the fine-tunes perform compared to vanilla LLaMa3? 413 | * A: While we do not have formal benchmarks, getting LLaMa3 to consistently speak a language other than English is challenging, if not impossible, to begin with. The non-English language it does produce is often grammatically broken. Our fine-tunes do not show this behavior. 414 | 415 | * Q: Can I use other frameworks for fine-tuning? 416 | * A: Yes, you can; we use [Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) for training on multi-GPU setups. 417 | 418 | * Q: Can I mix different translation models? 419 | * A: Absolutely, we think it might even increase performance to have the translation done by multiple models. You can achieve this by stopping a translation early and rerunning the translate script with a different translation model so it continues from the existing checkpoints. 420 | 421 | # Funding 422 | We are actively looking for funding to democratize AI and advance its applications. Contact us at info@commandos.ai if you want to invest. 423 | --------------------------------------------------------------------------------