├── .env.example ├── .gitignore ├── README.md ├── arguments.py ├── constants.py ├── data ├── qa_medical_pairs.json ├── qa_pairs_edoctor.json └── qa_vinmec.json ├── data_collators.py ├── dataset.py ├── inference.py ├── prompt ├── __init__.py ├── base.py ├── qa_prompt.py └── role.py ├── requirements.txt ├── tokenizer.py ├── train.sh ├── train_debug.sh ├── training.py └── utils.py /.env.example: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungnlp/viBioGPT/ddded74366cfa5e22c866c31b0b859de93a7a300/.env.example -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.pyc 4 | *.pyo 5 | *.pyd 6 | 7 | # Virtual environment 8 | venv/ 9 | env/ 10 | .venv/ 11 | env.bak/ 12 | venv.bak/ 13 | .vscode/ 14 | 15 | # IDE files 16 | .idea/ 17 | *.suo 18 | *.ntvs* 19 | *.njsproj 20 | *.sln 21 | *.pyproj 22 | *.user 23 | *.vscode/ 24 | 25 | # Compiled source 26 | *.com 27 | *.class 28 | *.dll 29 | *.exe 30 | *.o 31 | *.pyc 32 | *.pyo 33 | 34 | # Logs and databases 35 | *.log 36 | *.sqlite 37 | *.db 38 | 39 | # Package files 40 | *.egg 41 | *.egg-info/ 42 | dist/ 43 | build/ 44 | .DS_Store 45 | # Distribution / packaging 46 | .Python 47 | env/ 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 | var/ 59 | wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # Jupyter Notebook 68 | .ipynb_checkpoints 69 | 70 | # IPython 71 | profile_default/ 72 | ipython_config.py 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # pipenv 78 | Pipfile 79 | Pipfile.lock 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # Pyre type checker 110 | .pyre/ 111 | 112 | # pipenv 113 | Pipfile.lock 114 | wandb 115 | data/news 116 | data_collector/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## viBioGPT: A Vietnamese Large Language Model for Biomedical Question Answering 2 | 3 | ---- 4 | 5 | **viBioGPT-7B-instruct** is a Vietnamese Large Language Model (LLM) fine-tuned for the task of Question Answering within 6 | the medical and healthcare domain. This model uses 7 | pre-trained [Vistral-Chat-7B](https://huggingface.co/Viet-Mistral/Vistral-7B-Chat), then QLora technique 8 | to fine-tune. 
9 | 
10 | ### Table of Contents
11 | 
12 | ---
13 | 
14 | * [Using viBioGPT with Transformers](#using-vibiogpt-with-transformers)
15 | * [Run on your device](#run-on-your-device)
16 | * [Run on Google Colab](#run-on-google-colab)
17 | * [Training Data](#training-data)
18 | * [Training](#training)
19 | 
20 | ## Using viBioGPT with Transformers
21 | 
22 | ### Model Download
23 | 
24 | Our model was fine-tuned from the pre-trained
25 | model [Vistral-7B-Chat](https://huggingface.co/Viet-Mistral/Vistral-7B-Chat).
26 | 
27 | | Size | Hugging Face Model                                                                                            | Base Model                                                             |
28 | |------|---------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------|
29 | | 7B   | [hungnm/viBioGPT-7B-instruct-qlora-adapter](https://huggingface.co/hungnm/viBioGPT-7B-instruct-qlora-adapter) | [Vistral-7B-Chat](https://huggingface.co/Viet-Mistral/Vistral-7B-Chat) |
30 | 
31 | ### Run on your device
32 | 
33 | Create a conda environment
34 | 
35 | ```shell
36 | conda create -n biogpt python=3.10
37 | ```
38 | 
39 | Activate the environment
40 | 
41 | ```shell
42 | conda activate biogpt
43 | ```
44 | 
45 | Install dependencies
46 | 
47 | ```shell
48 | pip install peft==0.7.1 bitsandbytes==0.41.3.post2 transformers==4.36.2 torch==2.1.2 typer==0.9.0
49 | ```
50 | 
51 | Install Flash Attention 2
52 | 
53 | ```shell
54 | pip install flash-attn==2.3.3 --no-build-isolation
55 | ```
56 | 
57 | Example usage
58 | 
59 | _Note: replace `HF_TOKEN` with your Hugging Face access token_
60 | 
61 | ```python
62 | import torch
63 | from peft import PeftModel
64 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
65 | 
66 | HF_TOKEN = ""
67 | model_name = "Viet-Mistral/Vistral-7B-Chat"
68 | adapter = "hungnm/viBioGPT-7B-instruct-qlora-adapter"
69 | 
70 | compute_dtype = getattr(torch, "bfloat16")
71 | bnb_config = BitsAndBytesConfig(
72 |     load_in_4bit=True,
73 |     bnb_4bit_quant_type="nf4",
74 |     bnb_4bit_compute_dtype=compute_dtype,
75 |     bnb_4bit_use_double_quant=True,
76 | )
77 | model = AutoModelForCausalLM.from_pretrained(model_name,
78 |                                              quantization_config=bnb_config,
79 |                                              device_map={"": 0},
80 |                                              token=HF_TOKEN
81 |                                              )
82 | model = PeftModel.from_pretrained(model, adapter)
83 | 
84 | # load and configure the tokenizer
85 | tokenizer = AutoTokenizer.from_pretrained(model_name,
86 |                                           token=HF_TOKEN)
87 | tokenizer.padding_side = "left"
88 | tokenizer.pad_token_id = tokenizer.eos_token_id
89 | 
90 | system_prompt = ("Bạn là một trợ lý ảo AI trong lĩnh vực Y học, Sức Khỏe. Tên của bạn là AI-Doctor. "
91 |                  "Nhiệm vụ của bạn là trả lời các thắc mắc hoặc các câu hỏi về Y học, Sức khỏe.")
92 | 
93 | question = "tôi có một ít nhân sâm nhưng đang bị viêm dạ dày. Vậy tôi có nên ăn nhân sâm ko?"
94 | conversation = [
95 |     {
96 |         "role": "system",
97 |         "content": system_prompt},
98 |     {
99 |         "role": "user",
100 |         "content": question
101 |     }]
102 | instruction_str = tokenizer.apply_chat_template(conversation=conversation,
103 |                                                 tokenize=False)
104 | token_ids = tokenizer([instruction_str], return_tensors="pt")["input_ids"]
105 | token_ids = token_ids.to(model.device)
106 | outputs = model.generate(input_ids=token_ids,
107 |                          max_new_tokens=768,
108 |                          do_sample=True,
109 |                          temperature=0.001,
110 |                          top_p=0.95,
111 |                          top_k=40,
112 |                          repetition_penalty=1.2)
113 | all_token_ids = outputs[0].tolist()
114 | output_token_ids = all_token_ids[token_ids.shape[-1]:]
115 | output = tokenizer.decode(output_token_ids)
116 | 
117 | print(output)
118 | 
119 | ```
120 | 
121 | Output:
122 | 
123 | ```text
124 | Chào anh!
125 | Nhân sâm được biết đến như loại thảo dược quý hiếm và rất tốt cho sức khoẻ con người tuy nhiên không phải ai cũng dùng được nó đặc biệt với những bệnh nhân đau dạ dày thì càng cần thận trọng khi sử dụng vì nếu lạm dụng sẽ gây ra nhiều tác hại nghiêm trọng tới hệ tiêu hoá nói chung và tình trạng đau dạ dày nói riêng .
126 | Vì vậy trước tiên anh hãy điều trị dứt điểm căn bênh này rồi mới nghĩ tới việc bổ sung thêm dinh dưỡng từ nhân sâm nhé !
127 | Chúc anh mau khỏi bệnh ạ!
128 | ```
129 | 
130 | Or run inference with the provided script:
131 | 
132 | ```shell
133 | python inference.py text-generate "tôi có một ít nhân sâm nhưng đang bị viêm dạ dày. Vậy tôi có nên ăn nhân sâm ko?"
134 | ```
135 | 
136 | ### Run on Google Colab
137 | 
138 | [Notebook](https://colab.research.google.com/drive/1yo53qWNo6bsfBNjp0IgLORQG0Howx30o?usp=drive_link)
139 | 
140 | ## Training Data
141 | 
142 | The dataset was collected from [edoctor](https://edoctor.io/hoi-dap)
143 | and [vinmec](https://www.vinmec.com/vi/tin-tuc/hoi-dap-bac-si/).
144 | 
145 | * Size: after merging data from these two sources, we obtained 9,335 QA pairs.
146 | * Language: Vietnamese
147 | 
148 | Data example:
149 | 
150 | ```json
151 | 
152 | {
153 |     "question": "Chào bác sĩ,\nRăng cháu hiện tại có mủ ở dưới lợi nhưng khi đau cháu sẽ không ngủ được (quá đau). Tuy nhiên chỉ vài ngày là hết mà thỉnh thoảng nó lại bị đau. Chị cháu bảo là trước chị cháu cũng bị như vậy chỉ là đau răng tuổi dậy thì thôi. Bác sĩ cho cháu hỏi đau răng kèm có mủ dưới lợi là bệnh gì? Cháu có cần đi chữa trị không? Cháu cảm ơn.",
154 |     "answer": "Chào bạn,\nĐể trả lời câu hỏi trên, bác sĩ xin giải đáp như sau:\nRăng bạn hiện tại có mủ dưới lợi gây đau nhức nhiều. Bạn có thể đến phòng khám răng hàm mặt bệnh viện để được thăm khám, chụp phim và tư vấn cho bạn được chính xác\nTrân trọng!"
155 | }
156 | 
157 | ```
158 | 
159 | ## Training
160 | 
161 | Because we use the pretrained [Viet-Mistral/Vistral-7B-Chat](https://huggingface.co/Viet-Mistral/Vistral-7B-Chat) model, make sure
162 | you have been granted access to it.
163 | 
164 | Then, create a `.env` file and set your **HF_TOKEN** and **WANDB_KEY**:
165 | 
166 | ```shell
167 | HF_TOKEN=
168 | WANDB_KEY=
169 | ```
170 | 
171 | To train the model, run:
172 | 
173 | ```shell
174 | sh train.sh
175 | ```
176 | 
177 | or run:
178 | 
179 | ```shell
180 | python -m training \
181 |     --model_name_or_path Viet-Mistral/Vistral-7B-Chat \
182 |     --train_path ./data/qa_pairs_edoctor.json,./data/qa_vinmec.json \
183 |     --lora True \
184 |     --qlora True \
185 |     --bf16 True \
186 |     --output_dir models/bioGPT-instruct \
187 |     --num_train_epochs 2 \
188 |     --per_device_train_batch_size 8 \
189 |     --per_device_eval_batch_size 8 \
190 |     --gradient_accumulation_steps 3 \
191 |     --eval_accumulation_steps 1 \
192 |     --evaluation_strategy "epoch" \
193 |     --eval_steps 40 \
194 |     --save_strategy "epoch" \
195 |     --save_steps 100 \
196 |     --save_total_limit 3 \
197 |     --learning_rate 1.2e-5 \
198 |     --lr_scheduler_type "cosine" \
199 |     --logging_steps 1 \
200 |     --tf32 True \
201 |     --model_max_length 1024 \
202 |     --gradient_checkpointing True \
203 |     --packing False \
204 |     --report_to "wandb"
205 | ```
206 | 
207 | ## Citation
208 | 
209 | If you find our project helpful, please star our repo and cite our work. Thanks!
210 | 
211 | ```bibtex
212 | @misc{viBioGPT,
213 |   title={Vietnamese Medical QA: Question Answering dataset for medical in Vietnamese},
214 |   author={Hung Nguyen},
215 |   howpublished={\url{https://github.com/hungnm-ai/viBioGPT}},
216 |   year={2024},
217 | }
218 | ```
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 | from dataclasses import dataclass, field
3 | from transformers import TrainingArguments
4 | 
5 | 
6 | @dataclass
7 | class TokenizerArgs:
8 |     _model_name_or_path: Optional[str] = field(default="Viet-Mistral/Vistral-7B-Chat",
9 |                                                metadata={"help": "Tokenizer model name on HuggingFace"})
10 |     padding_side: Optional[str] = field(default="left",
11 |                                         metadata={"help": "Set the padding side to left or right"})
12 | 
13 | 
14 | @dataclass
15 | class ModelArgs:
16 |     model_name_or_path: Optional[str] = field(
17 |         # default="bkai-foundation-models/vietnamese-llama2-7b-120GB",
18 |         default="Viet-Mistral/Vistral-7B-Chat",
19 |         metadata={"help": "Model name or path to pretrained model"})
20 |     lora: Optional[bool] = field(default=True,
21 |                                  metadata={"help": "Use lora to train"})
22 |     qlora: Optional[bool] = field(default=True,
23 |                                   metadata={"help": "Use qlora to train"})
24 |     flash_attention: Optional[bool] = field(default=True,
25 |                                             metadata={"help": "Use flash_attention to train"})
26 |     model_max_length: Optional[int] = field(default=1024,
27 |                                             metadata={
28 |                                                 "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."})
29 | 
30 | 
31 | @dataclass
32 | class DataArgs:
33 |     train_path: str = field(default="",
34 |                             metadata={"help": "Path to the training data. Use comma for multi input"})
35 |     valid_path: str = field(default=None,
36 |                             metadata={"help": "Path to the evaluation data. 
Use comma for multi input"}) 37 | 38 | 39 | @dataclass 40 | class TrainingArguments(TrainingArguments): 41 | per_device_train_batch_size: int = field(default=8, 42 | metadata={ 43 | "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training"}) 44 | per_device_eval_batch_size: int = field(default=8, 45 | metadata={ 46 | "help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation"}) 47 | cache_dir: Optional[str] = field(default=None) 48 | optim: str = field(default="adamw_torch") 49 | packing: bool = field(default=False, 50 | metadata={"help": "Whether use packing or not"}) 51 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dotenv import load_dotenv 4 | 5 | load_dotenv() 6 | 7 | HF_TOKEN = os.getenv('HF_TOKEN') 8 | WANDB_KEY = os.getenv('WANDB_KEY') 9 | 10 | SYS_PROMPT = ("Bạn là một trợ lý ảo AI trong lĩnh vực Y học, Sức Khỏe. Tên của bạn là AI-Doctor. " 11 | "Nhiệm vụ của bạn là trả lời các thắc mắc hoặc các câu hỏi về Y học, Sức khỏe.") 12 | -------------------------------------------------------------------------------- /data_collators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import * 3 | from tokenizer import SpecialToken 4 | from transformers import DataCollatorForLanguageModeling 5 | 6 | 7 | class DataCollatorForCompletionLM(DataCollatorForLanguageModeling): 8 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: 9 | 10 | batch = super().torch_call(examples) 11 | labels = batch["labels"].clone() 12 | attention_mask = batch["labels"].clone() 13 | attention_mask[attention_mask != -100] = 1 14 | attention_mask[attention_mask == -100] = 0 15 | 16 | batch['attention_mask'] = attention_mask 17 | # The code then encodes a special token, RESPONSE_KEY_NL, 18 | # representing the end of the prompt followed by a newline. 19 | # It searches for this token in the sequence of tokens (labels) 20 | # and finds its index. 21 | # print(batch['attention_mask']) 22 | # print(batch['labels']) 23 | # print(batch['input_ids']) 24 | response_token_ids = self.tokenizer.encode(SpecialToken.end_instruct, 25 | add_special_tokens=False) 26 | for i in range(len(examples)): 27 | label = batch["labels"][i] 28 | response_token_ids_start_idx = None 29 | for idx in np.where(label == response_token_ids[0])[0]: 30 | response_token_ids_start_idx = idx 31 | break 32 | 33 | if response_token_ids_start_idx is None: 34 | response_token_ids_end_idx = -1 35 | # If the response token is not found in the sequence, it raises a RuntimeError. 36 | # Otherwise, it determines the end index of the response token. 37 | 38 | print(f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}') 39 | # raise RuntimeError( 40 | # f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}' 41 | # ) 42 | else: 43 | response_token_ids_end_idx = response_token_ids_start_idx + 1 44 | 45 | # To train the model to predict only the response and ignore the prompt tokens, 46 | # it sets the label values before the response token to -100. 47 | # This ensures that those tokens are ignored by the PyTorch loss function during training. 
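            # (-100 is the default ignore_index of torch.nn.CrossEntropyLoss, which Hugging Face
            # causal-LM models use when computing the loss, so the prompt positions contribute nothing.)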
48 | labels[i, :response_token_ids_end_idx] = -100 49 | 50 | batch["labels"] = labels 51 | return batch 52 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from typing import List, Dict, Union 4 | from torch.utils.data import Dataset 5 | from prompt.qa_prompt import QAPrompt 6 | from transformers import PreTrainedTokenizer 7 | 8 | 9 | class DataReader: 10 | def __init__(self, path: Union[str, List[str]]): 11 | 12 | if isinstance(path, str): 13 | path = [p.strip() for p in path.split(",")] 14 | self.path = path 15 | 16 | @staticmethod 17 | def clean_text(text: str) -> str: 18 | text = re.sub(r"Với câu hỏi “.*?”", "Để trả lời câu hỏi trên", text, flags=re.MULTILINE | re.IGNORECASE) 19 | text = re.sub(r"bệnh viện Vinmec", "bệnh viện", text, flags=re.MULTILINE | re.IGNORECASE) 20 | text = re.sub(r"hệ thống Y Khoa Vinmec", "bệnh viện", text, flags=re.MULTILINE | re.IGNORECASE) 21 | text = re.sub(r"đặt hẹn qua tổng đài Vinmec", "thăm khám tại bệnh viện", text, 22 | flags=re.MULTILINE | re.IGNORECASE) 23 | text = re.sub(r"edoctor", "AI-Doctor", text, flags=re.MULTILINE | re.IGNORECASE) 24 | return text 25 | 26 | def load_data(self) -> List[Dict[str, str]]: 27 | 28 | qa_pairs = [] 29 | for path in self.path: 30 | with open(path, 'r') as f: 31 | data = json.load(f) 32 | for item in data: 33 | question = item['question'] 34 | answer = item['answer'] 35 | answer = self.clean_text(answer) 36 | if "vinmec" not in question.lower(): 37 | answer = re.sub("vinmec", "AI-Doctor", answer, 38 | flags=re.MULTILINE | re.IGNORECASE) 39 | answer = answer.strip() 40 | question = question.strip() 41 | if len(answer) == 0 or len(question) == 0: 42 | continue 43 | qa_pairs.append({ 44 | 'question': question, 45 | 'answer': answer 46 | }) 47 | print("Number of pairs: ", len(qa_pairs)) 48 | # with open(os.path.join(app_path, "data", 'qa_medical_pairs.json'), 'w') as f: 49 | # json.dump(qa_pairs, f, ensure_ascii=False, indent=4) 50 | return qa_pairs 51 | 52 | 53 | class BioDataset(Dataset): 54 | def __init__(self, 55 | examples: List, 56 | tokenizer: PreTrainedTokenizer, 57 | max_length: int = 1024, 58 | truncation: bool = True, 59 | ignore_sample: bool = True, 60 | sort_by_length: bool = False) -> None: 61 | qa_prompt = QAPrompt() 62 | dataset = [] 63 | 64 | min_seq_len = max_length 65 | 66 | for i, example in enumerate(examples): 67 | prompt_ids = qa_prompt.build_prompt_template(example=example, 68 | tokenizer=tokenizer, 69 | tokenize=True, 70 | max_length=max_length, 71 | truncation=truncation) 72 | 73 | if i == 5: 74 | prompt_str = qa_prompt.build_prompt_template(example=example, 75 | tokenizer=tokenizer, 76 | tokenize=False, 77 | max_length=max_length, 78 | truncation=truncation) 79 | 80 | print("prompt_str: ", prompt_str) 81 | print("prompt_ids: ", prompt_ids) 82 | 83 | if (len(prompt_ids) > max_length or len(prompt_ids) < 128) and ignore_sample: 84 | continue 85 | 86 | dataset.append({ 87 | "prompt_ids": prompt_ids, 88 | "length": len(prompt_ids) 89 | }) 90 | 91 | if len(prompt_ids) < min_seq_len: 92 | min_seq_len = len(prompt_ids) 93 | if sort_by_length: 94 | dataset = sorted(dataset, key=lambda item: item['length'], reverse=True) 95 | self.dataset = [example["prompt_ids"] for example in dataset] 96 | 97 | print("Total dataset: ", len(self.dataset)) 98 | 99 | def __len__(self) -> int: 100 | return len(self.dataset) 101 | 102 | def __getitem__(self, 
index: int): 103 | return self.dataset[index] 104 | 105 | 106 | if __name__ == '__main__': 107 | import os 108 | from utils import app_path 109 | 110 | paths = [os.path.join(app_path, 'data', 'qa_vinmec.json'), 111 | os.path.join(app_path, 'data', 'qa_pairs_edoctor.json') 112 | ] 113 | train_reader = DataReader(paths) 114 | data = train_reader.load_data() 115 | from tokenizer import load_tokenizer 116 | from arguments import TokenizerArgs 117 | 118 | _tokenizer = load_tokenizer(TokenizerArgs()) 119 | bio_dataset = BioDataset(examples=data, 120 | tokenizer=_tokenizer) 121 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import typer 2 | import torch 3 | from peft import PeftModel 4 | from constants import HF_TOKEN 5 | from tokenizer import SpecialToken 6 | from prompt.qa_prompt import QAPrompt 7 | from transformers import (AutoTokenizer, 8 | BitsAndBytesConfig, 9 | AutoModelForCausalLM, 10 | PreTrainedTokenizer, 11 | MistralForCausalLM, 12 | StoppingCriteria, 13 | StoppingCriteriaList) 14 | 15 | app = typer.Typer() 16 | 17 | model_name = "Viet-Mistral/Vistral-7B-Chat" 18 | adapter = "hungnm/viBioGPT-7B-instruct-qlora-adapter" 19 | 20 | 21 | class StoppingCriteriaSub(StoppingCriteria): 22 | 23 | def __init__(self, stops=[], encounters=1): 24 | super().__init__() 25 | self.stops = [stop for stop in stops] 26 | 27 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 28 | # print(input_ids.device) 29 | for stop in self.stops: 30 | stop = stop.to(input_ids.device) 31 | if torch.all((stop == input_ids[0][-len(stop):])).item(): 32 | return True 33 | 34 | return False 35 | 36 | 37 | def load_tokenizer(): 38 | """load and config tokenizer""" 39 | 40 | tokenizer = AutoTokenizer.from_pretrained(model_name, 41 | token=HF_TOKEN) 42 | tokenizer.padding_side = "left" 43 | tokenizer.pad_token_id = tokenizer.eos_token_id 44 | 45 | return tokenizer 46 | 47 | 48 | @app.command() 49 | def merge_adapter(save_dir: str = "./models/merge_adapter"): 50 | """Merge adapter into pretrained model and save model""" 51 | # Load the pretrained model 52 | model = AutoModelForCausalLM.from_pretrained(model_name) 53 | 54 | # Load and activate the adapter on top of the base model 55 | model = PeftModel.from_pretrained(model, adapter) 56 | 57 | # Merge the adapter with the base model 58 | model = model.merge_and_unload() 59 | 60 | tokenizer = load_tokenizer() 61 | 62 | # Save the merged model in a director in the safetensors format 63 | model.save_pretrained(save_dir, 64 | safe_serialization=True) 65 | tokenizer.save_pretrained(save_dir) 66 | 67 | 68 | def load_adapter_merged(save_dir: str = typer.Option(default="./models/merge_adapter")): 69 | """Load adapter merged with pretrained model""" 70 | compute_dtype = getattr(torch, "bfloat16") 71 | bnb_config = BitsAndBytesConfig( 72 | load_in_4bit=True, 73 | bnb_4bit_quant_type="nf4", 74 | bnb_4bit_compute_dtype=compute_dtype, 75 | bnb_4bit_use_double_quant=True, 76 | ) 77 | model = AutoModelForCausalLM.from_pretrained(save_dir, 78 | quantization_config=bnb_config, 79 | device_map={"": 0}, 80 | use_flash_attention_2=True 81 | ) 82 | 83 | return model 84 | 85 | 86 | def load_adapter(): 87 | """ 88 | Load adapter on top of pretrained model (Vistral-7B-Chat) 89 | Note: Using the same loading hyperparameters used for fine-tuning 90 | """ 91 | compute_dtype = getattr(torch, "bfloat16") 92 | bnb_config = BitsAndBytesConfig( 93 | 
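        # The 4-bit NF4 quantization settings below (double quantization, bfloat16 compute)
        # mirror the configuration used for QLoRA fine-tuning in training.py.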
load_in_4bit=True, 94 | bnb_4bit_quant_type="nf4", 95 | bnb_4bit_compute_dtype=compute_dtype, 96 | bnb_4bit_use_double_quant=True, 97 | ) 98 | model = AutoModelForCausalLM.from_pretrained(model_name, 99 | quantization_config=bnb_config, 100 | device_map={"": 0}, 101 | use_flash_attention_2=True 102 | ) 103 | model = PeftModel.from_pretrained(model, adapter) 104 | 105 | return model 106 | 107 | 108 | @app.command() 109 | def text_generate(question: str, 110 | model=None, 111 | tokenizer=None): 112 | question = question.strip() 113 | if model is None: 114 | model = load_adapter() 115 | if tokenizer is None: 116 | tokenizer = load_tokenizer() 117 | 118 | stop_words = [SpecialToken.eos, SpecialToken.end_instruct] 119 | stop_words_ids = [tokenizer(stop_word, return_tensors='pt', 120 | add_special_tokens=False)['input_ids'].squeeze(1) 121 | for stop_word in stop_words] 122 | stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) 123 | 124 | prompt = QAPrompt() 125 | instruction_str = prompt.build_prompt_instruction(question=question, tokenizer=tokenizer) 126 | token_ids = tokenizer([instruction_str], return_tensors="pt")["input_ids"] 127 | token_ids = token_ids.to(model.device) 128 | outputs = model.generate(input_ids=token_ids, 129 | max_new_tokens=768, 130 | do_sample=True, 131 | temperature=0.001, 132 | top_p=0.95, 133 | top_k=40, 134 | repetition_penalty=1.2, 135 | stopping_criteria=stopping_criteria 136 | ) 137 | all_token_ids = outputs[0].tolist() 138 | output_token_ids = all_token_ids[token_ids.shape[-1]:] 139 | output = tokenizer.decode(output_token_ids) 140 | 141 | print(f"User: {question}\n") 142 | print(f"AI-Doctor: {output}") 143 | 144 | return output 145 | 146 | 147 | if __name__ == '__main__': 148 | app() 149 | -------------------------------------------------------------------------------- /prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungnlp/viBioGPT/ddded74366cfa5e22c866c31b0b859de93a7a300/prompt/__init__.py -------------------------------------------------------------------------------- /prompt/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Union, List 3 | from transformers import PreTrainedTokenizer 4 | 5 | 6 | class PromptBase(ABC): 7 | 8 | @abstractmethod 9 | def build_prompt_template(self, 10 | example: Dict[str, str], 11 | tokenizer: PreTrainedTokenizer, 12 | tokenize: bool = True) -> Union[str, List[int]]: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /prompt/qa_prompt.py: -------------------------------------------------------------------------------- 1 | from prompt.role import Role 2 | from constants import SYS_PROMPT 3 | from prompt.base import PromptBase 4 | from typing import Dict, Union, List 5 | from transformers import PreTrainedTokenizer 6 | 7 | 8 | class QAPrompt(PromptBase): 9 | def __init__(self, system_prompt: str = SYS_PROMPT): 10 | self.system_prompt = system_prompt 11 | 12 | def build_prompt_template(self, 13 | example: Dict[str, str], 14 | tokenizer: PreTrainedTokenizer, 15 | tokenize: bool = True, 16 | truncation: bool = True, 17 | max_length: int = 1024) -> Union[str, List[int]]: 18 | conversation = [ 19 | { 20 | "role": Role.system, 21 | "content": self.system_prompt 22 | }, 23 | { 24 | "role": Role.user, 25 | "content": example['question'] 26 | }, 27 | 
            {
28 |                 "role": Role.assistant,
29 |                 "content": example['answer']
30 |             }
31 |         ]
32 |         prompt_str = tokenizer.apply_chat_template(conversation=conversation,
33 |                                                    tokenize=tokenize,
34 |                                                    truncation=truncation,
35 |                                                    max_length=max_length)
36 | 
37 |         return prompt_str
38 | 
39 |     def build_prompt_instruction(self,
40 |                                  question: str,
41 |                                  tokenizer: PreTrainedTokenizer,
42 |                                  tokenize: bool = False,
43 |                                  truncation: bool = False,
44 |                                  max_length: int = 1024) -> Union[str, List[int]]:
45 |         conversation = [
46 |             {
47 |                 "role": Role.system,
48 |                 "content": self.system_prompt
49 |             },
50 |             {
51 |                 "role": Role.user,
52 |                 "content": question
53 |             }
54 |         ]
55 |         instruction_str = tokenizer.apply_chat_template(conversation=conversation,
56 |                                                         tokenize=tokenize,
57 |                                                         truncation=truncation,
58 |                                                         max_length=max_length)
59 | 
60 |         return instruction_str
61 | 
--------------------------------------------------------------------------------
/prompt/role.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | 
3 | 
4 | class Role(str):
5 |     user = "user"
6 |     system = "system"
7 |     assistant = "assistant"
8 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.11.3
2 | typer==0.9.0
3 | sentence_transformers==2.2.2
4 | accelerate==0.21.0
5 | sentencepiece==0.1.99
6 | pydantic==1.10.13
7 | transformers==4.36.2
8 | chromadb==0.3.23
9 | colorama==0.4.6
10 | requests==2.31.0
11 | peft==0.7.1
12 | bitsandbytes==0.41.3.post2
13 | tiktoken==0.5.2
14 | openai==1.6.1
15 | scikit-learn==1.3.2
16 | wandb==0.16.1
17 | numpy==1.26.2
18 | datasets==2.15.0
19 | python-dotenv==1.0.1
20 | selenium==4.17.2
21 | scrapy==2.11.1
22 | html2text==2020.1.16
23 | tqdm==4.66.1
24 | 
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 | from constants import HF_TOKEN
2 | from arguments import TokenizerArgs
3 | from transformers import AutoTokenizer
4 | from transformers import PreTrainedTokenizer
5 | 
6 | 
7 | class SpecialToken:
8 |     bos = "<s>"
9 |     eos = "</s>"
10 |     start_instruct = "[/INST]"
11 |     end_instruct = "[INST]"
12 |     start_sys = "<<SYS>>"
13 |     end_sys = "<</SYS>>"
14 | 
15 | 
16 | def load_tokenizer(tokenizer_args: TokenizerArgs) -> PreTrainedTokenizer:
17 |     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_args._model_name_or_path,
18 |                                               token=HF_TOKEN)
19 |     tokenizer.padding_side = tokenizer_args.padding_side
20 |     tokenizer.pad_token_id = tokenizer.eos_token_id
21 | 
22 |     return tokenizer
23 | 
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python -m training \
2 |     --model_name_or_path Viet-Mistral/Vistral-7B-Chat \
3 |     --train_path ./data/qa_pairs_edoctor.json,./data/qa_vinmec.json \
4 |     --lora True \
5 |     --qlora True \
6 |     --bf16 True \
7 |     --output_dir models/bioGPT-instruct \
8 |     --num_train_epochs 2 \
9 |     --per_device_train_batch_size 8 \
10 |     --per_device_eval_batch_size 8 \
11 |     --gradient_accumulation_steps 3 \
12 |     --eval_accumulation_steps 1 \
13 |     --evaluation_strategy "epoch" \
14 |     --eval_steps 40 \
15 |     --save_strategy "epoch" \
16 |     --save_steps 100 \
17 |     --save_total_limit 3 \
18 |     --learning_rate 1.2e-5 \
19 |     --lr_scheduler_type "cosine" \
20 | 
--logging_steps 1 \ 21 | --tf32 True \ 22 | --model_max_length 1024 \ 23 | --gradient_checkpointing True \ 24 | --packing False \ 25 | --report_to "wandb" -------------------------------------------------------------------------------- /train_debug.sh: -------------------------------------------------------------------------------- 1 | python -m training \ 2 | --model_name_or_path gpt2 \ 3 | --train_path ./data/qa_pairs_edoctor.json \ 4 | --lora False \ 5 | --qlora False \ 6 | --flash_attention False \ 7 | --bf16 False \ 8 | --output_dir models/vietnamese-llama2-7b-instruct \ 9 | --num_train_epochs 2 \ 10 | --per_device_train_batch_size 2 \ 11 | --per_device_eval_batch_size 2 \ 12 | --gradient_accumulation_steps 16 \ 13 | --eval_accumulation_steps 1 \ 14 | --evaluation_strategy "epoch" \ 15 | --eval_steps 40 \ 16 | --save_strategy "epoch" \ 17 | --save_steps 80 \ 18 | --save_total_limit 3 \ 19 | --learning_rate 2e-4 \ 20 | --lr_scheduler_type "cosine" \ 21 | --logging_steps 1 \ 22 | --tf32 False \ 23 | --model_max_length 1024 \ 24 | --gradient_checkpointing True \ 25 | --packing False \ 26 | --remove_unused_columns False \ 27 | --report_to "wandb" -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import random 4 | import bitsandbytes as bnb 5 | from constants import WANDB_KEY 6 | from tokenizer import load_tokenizer 7 | from dataset import BioDataset, DataReader 8 | from sklearn.model_selection import train_test_split 9 | from data_collators import DataCollatorForCompletionLM 10 | from arguments import ModelArgs, TokenizerArgs, TrainingArguments, DataArgs 11 | from peft import get_peft_model, prepare_model_for_kbit_training, LoraConfig 12 | from transformers import (Trainer, HfArgumentParser, BitsAndBytesConfig, 13 | PreTrainedTokenizer, AutoModelForCausalLM) 14 | 15 | SEED = 100 16 | 17 | wandb.login(key=WANDB_KEY) 18 | run = wandb.init( 19 | project='bioGPT-instruct', 20 | job_type="training", 21 | anonymous="allow" 22 | ) 23 | 24 | 25 | def set_seed(seed: int = SEED): 26 | """Set random to ensure result reproducible""" 27 | random.seed(seed) 28 | torch.manual_seed(seed) 29 | 30 | 31 | def find_all_linear_names(model): 32 | lora_module_names = set() 33 | for name, module in model.named_modules(): 34 | if isinstance(module, bnb.nn.Linear4bit) or isinstance(module, torch.nn.Linear): 35 | names = name.split(".") 36 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 37 | 38 | if "lm_head" in lora_module_names: # needed for 16-bit 39 | lora_module_names.remove("lm_head") 40 | return list(lora_module_names) 41 | 42 | 43 | def print_trainable_parameters(model): 44 | """ 45 | Prints the number of trainable parameters in the model. 
46 |     """
47 |     trainable_params = 0
48 |     all_param = 0
49 |     for _, param in model.named_parameters():
50 |         all_param += param.numel()
51 |         if param.requires_grad:
52 |             trainable_params += param.numel()
53 |     print(
54 |         f"Trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
55 | 
56 | 
57 | def load_model(model_args: ModelArgs, tokenizer: PreTrainedTokenizer):
58 |     device = {"": 0} if torch.cuda.is_available() else "cpu"
59 |     quantization_config = None
60 |     if model_args.qlora:
61 |         compute_dtype = getattr(torch, "bfloat16")
62 |         quantization_config = BitsAndBytesConfig(load_in_4bit=True,
63 |                                                  bnb_4bit_quant_type="nf4",
64 |                                                  bnb_4bit_use_double_quant=True,
65 |                                                  bnb_4bit_compute_dtype=compute_dtype)
66 |     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
67 |                                                  device_map=device,
68 |                                                  trust_remote_code=True,
69 |                                                  quantization_config=quantization_config,
70 |                                                  use_flash_attention_2=model_args.flash_attention
71 |                                                  )
72 | 
73 |     model.resize_token_embeddings(len(tokenizer))
74 | 
75 |     model.config.pad_token_id = tokenizer.pad_token_id
76 |     model.gradient_checkpointing_enable()
77 | 
78 |     modules = find_all_linear_names(model)
79 | 
80 |     if model_args.lora:
81 |         if model_args.qlora:
82 |             print("Using QLoRA to train...")
83 |             model = prepare_model_for_kbit_training(model)
84 |         else:
85 |             print("Using LoRA to train...")
86 | 
87 |         print("LoRA target modules: {}".format(modules))
88 |         # Add the LoRA adapter to the selected layers
89 |         peft_config = LoraConfig(
90 |             r=16,  # dimension of the updated matrices
91 |             lora_alpha=64,  # parameter for scaling
92 |             # ['k_proj', 'q_proj', 'v_proj', 'o_proj',
93 |             # "gate_proj", "down_proj", "up_proj"]
94 |             target_modules=modules,
95 |             lora_dropout=0.1,  # dropout probability for layers
96 |             bias="none",
97 |             task_type="CAUSAL_LM",
98 |             modules_to_save=["lm_head", "embed_tokens"],
99 |         )
100 |         model = get_peft_model(model, peft_config)
101 | 
102 |     """The KV cache is only needed at inference time and is useless during training (fine-tuning).
103 |     For a generative language model, a training iteration computes all outputs in parallel with a
104 |     causal mask and teacher forcing, so the keys and values for every input token are computed
105 |     in a single pass. See:
106 | https://stackoverflow.com/questions/76633335/why-does-hugging-face-falcon-model-use-mode-config-use-cache-false-why-wouldn 107 | """ 108 | model.config.use_cache = False 109 | 110 | print_trainable_parameters(model) 111 | 112 | return model 113 | 114 | 115 | def train(): 116 | arg_parser = HfArgumentParser((ModelArgs, DataArgs, TrainingArguments, TokenizerArgs)) 117 | model_args, data_args, training_args, tokenizer_args = arg_parser.parse_args_into_dataclasses() 118 | 119 | train_reader = DataReader(data_args.train_path) 120 | train_data = train_reader.load_data() 121 | if data_args.valid_path: 122 | train_reader = DataReader(data_args.valid_path) 123 | valid_data = train_reader.load_data() 124 | 125 | else: 126 | train_data, valid_data = train_test_split(train_data, 127 | test_size=0.1, 128 | random_state=SEED, 129 | shuffle=True) 130 | 131 | tokenizer_args._model_name_or_path = model_args.model_name_or_path 132 | tokenizer = load_tokenizer(tokenizer_args) 133 | 134 | print("Number of training examples: {}".format(len(train_data))) 135 | print("Number of valid examples: {}".format(len(valid_data))) 136 | train_dataset = BioDataset(examples=train_data, 137 | tokenizer=tokenizer) 138 | 139 | valid_dataset = BioDataset(examples=valid_data, 140 | tokenizer=tokenizer) 141 | 142 | model = load_model(model_args, tokenizer) 143 | data_collator = DataCollatorForCompletionLM(mlm=False, tokenizer=tokenizer) 144 | trainer = Trainer(model=model, 145 | args=training_args, 146 | train_dataset=train_dataset, 147 | eval_dataset=valid_dataset, 148 | data_collator=data_collator 149 | ) 150 | 151 | trainer.train() 152 | 153 | 154 | if __name__ == '__main__': 155 | train() 156 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | app_path = pathlib.Path(__file__).parent.resolve() 4 | --------------------------------------------------------------------------------
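A possible end-to-end flow with the scripts in this repository (a sketch only: the merge step is optional, the paths are the defaults used above, and the command names follow Typer's default dash-style naming):

```shell
# fine-tune the QLoRA adapter (HF_TOKEN and WANDB_KEY are read from .env)
sh train.sh

# optionally merge the published adapter into the base model and save it
python inference.py merge-adapter --save-dir ./models/merge_adapter

# answer a question with the published adapter loaded on top of Vistral-7B-Chat
python inference.py text-generate "tôi có một ít nhân sâm nhưng đang bị viêm dạ dày. Vậy tôi có nên ăn nhân sâm ko?"
```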