├── .gitignore ├── README.md ├── config.py ├── data_convert.py ├── data_generate.py ├── seed_tasks.jsonl ├── template ├── disease_classification.py └── prompt_template.txt ├── tools ├── cfg_wapper.py └── qianfan_requestor.py └── unsloth_finetune.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | *.vscode/* 3 | *.idea/* 4 | __pycache__ 5 | *.pyc 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SFT_data_generation 2 | ## Introduction 3 | The main purpose of this project is to generate instruction fine-tuning data for fine-tuning open source LLMs. Drawing on the work of Stanford Alpaca[1] in generating instruction-following data, the data generation process of this work refers to the SELF-INSTRUCT[2] work of the University of Washington. By constructing seed data and using LLMs to generate instruction fine-tuning data for scene adaptation fine-tuning, we achieve efficient and low-cost automated generation of fine-tuning data. We have successfully conducted fine-tuning experiments on Llama3 using the generated fine-tuning data. 4 | 5 | The fine-tuning scenario for this project is an online medical service scenario, so the coding is centered around this scenario. We can modify and adjust the code according to our own scenarios. 6 | 7 | ## Processing 8 | - Construct seed data 9 | - Set up seed data for different tasks based on the fine-tuning scenarios you have set. In this example, we have set three types of instruction tasks. Please refer to seed_tasks.jsonl 10 | - Build generation templates; 11 | - Construct prompt templates and place the seed data into the prompts as examples. 12 | Please refer to template/prompt_template.txt 13 | - The "disease_list" field in the prompt template is designed to consider diverse entities. Since the scenario of this project is medical, this field represents a specific dictionary of disease classifications. During data generation, several diseases will be randomly selected from this list as references for generating data. 14 | - Call large model API to generate data; 15 | - The main process of calling the large model API to generate instruction data is to use the large model's in-context learning. By providing only a few examples, the large model can generate high-quality instruction data that follows the requirements of the prompt. The API used in this project is Baidu ERNIE-4.0-8K, but you can also switch to other large model APIs. 16 | - Condat data 17 | - The Fine-tuning data for models can often be supplemented with some collected open source fine-tuning data. Additionally, it is necessary to incorporate a portion of general corpus (to enhance the generalization of model fine-tuning). Data from different sources should be converted into a unified format to construct the final instruction fine-tuning data. However, determining the optimal ratio of data from different sources and the appropriate data scale to achieve better fine-tuning results requires further experimentation and validation on our own. 18 | 19 | ### How to use 20 | - modify the seed_tasks.jsonl according to your own fine-tuning scenario requirements. The more seed data you provide, the better the diversity of the generated content will be. 21 | - modify the template/prompt_template.txt according to your own fine-tuning scenario requirements.This prompt template is designed for Chinese. You can modify it for other languages to make the data generated by the model closer to real-world question and answer scenarios. 22 | - Run the script "data_generate.py" to generate instruction tuning data 23 | - add your api key and api secret of Baidu ernie api in the config.py 24 | - Please note that the actual number of large model API calls is equal to the product of generate_times and random_seed_num. Additionally, you can specify how many data entries to generate per API call in the 9th item of the prompt template. 25 | 26 | ### installation 27 | - Make sure you have installed requests package,recomand version is 2.26.0. 28 | - If you want to try fine-tuning Lamma3 using unsloth[3], you need to install unsloth. 29 | - install unsloth 30 | - pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git" 31 | - pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes 32 | - try finetune Llama3 33 | - Run the script "unsloth_finetune.py" 34 | 35 | 36 | 37 | ### refrence 38 | - [1] https://github.com/tatsu-lab/stanford_alpaca/tree/main 39 | - [2] https://arxiv.org/pdf/2212.10560 40 | - [3] https://github.com/unslothai/unsloth 41 | 42 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | config_dict=dict( 2 | access_config = dict( 3 | api_key="", # add your api_key here 4 | api_secret="", # add your api_secret here 5 | client_credentials_address="https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials", 6 | ), 7 | model_config = dict( 8 | model_dict={ 9 | "ERNIE-Lite-8K":'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', 10 | "ERNIE-4.0-8K":'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', 11 | }) 12 | ) -------------------------------------------------------------------------------- /data_convert.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | def random_choice_from_file(file_path,choose_num=1000,new_path_suffix='random',data_type='huatuo'): 5 | with open(file_path, 'r',encoding="utf-8") as f: 6 | if data_type == 'huatuo': 7 | lines = f.readlines() 8 | random.shuffle(lines) 9 | choose_samples = lines[:choose_num] 10 | else: 11 | # for json object 12 | data = json.load(f) 13 | random.shuffle(data) 14 | choose_samples = data[:choose_num] 15 | # convert dict element to string 16 | choose_samples = [json.dumps(sample,ensure_ascii=False)+'\n' for sample in choose_samples] 17 | 18 | # add suffix to new file and keep json format 19 | file_path,file_type = file_path.split('.') 20 | new_path = file_path+'_'+new_path_suffix+ f'_{choose_num}' + '.'+file_type 21 | 22 | with open(new_path, 'w+',encoding="utf-8") as f_out: 23 | for line in choose_samples: 24 | f_out.write(line) 25 | 26 | # split data to train and val 27 | def split_data(input_file,out_dir,split_ratio=0.9): 28 | with open(input_file, 'r',encoding="utf-8") as f: 29 | lines = f.readlines() 30 | random.shuffle(lines) 31 | train_lines = lines[:int(len(lines)*split_ratio)] 32 | val_lines = lines[int(len(lines)*split_ratio):] 33 | with open(out_dir+'train.json', 'w+',encoding="utf-8") as f_train: 34 | for line in train_lines: 35 | f_train.write(line) 36 | with open(out_dir+'val.json', 'w+',encoding="utf-8") as f_val: 37 | for line in val_lines: 38 | f_val.write(line) 39 | 40 | 41 | # convert another dict type to uniform dict type 42 | def convert_dict_type(input_file,output_file,data_type='huatuo'): 43 | 44 | # standard format data 45 | #{"id": "1_gout_appointment", "name": "gout_appointment_instruction", "instruction": "我想预约看痛风,该挂什么科?", "instances": [{"input": "", "output": "您好,痛风一般属于风湿免疫科,建议您挂风湿免疫科的号进行就诊。"}]} 46 | 47 | format_data = [] 48 | with open(input_file, 'r',encoding="utf-8") as f: 49 | origin_data = f.readlines() 50 | if data_type == 'huatuo': 51 | # "query" and "answer" convert to "instruction" and "output",'input' as null 52 | for line in origin_data: 53 | data = json.loads(line) 54 | standard_dict = {"instruction": "","instances": [{"input": "", "output": ""}]} 55 | standard_dict['instruction'] = data['query'] 56 | standard_dict['instances'][0]['output'] = data['answer'] 57 | format_data.append(standard_dict) 58 | elif data_type == 'general': 59 | for line in origin_data: 60 | data = json.loads(line) 61 | standard_dict = {"instruction": "","instances": [{"input": "", "output": ""}]} 62 | standard_dict['instruction'] = data['instruction'] 63 | standard_dict['instances'][0]['input'] = data['input'] 64 | standard_dict['instances'][0]['output'] = data['output'] 65 | format_data.append(standard_dict) 66 | else: 67 | for line in origin_data: 68 | data = json.loads(line) 69 | standard_dict = {"instruction": "","instances": [{"input": "", "output": ""}]} 70 | standard_dict['instruction'] = data['instruction'] 71 | standard_dict['instances'] = data['instances'] 72 | format_data.append(standard_dict) 73 | 74 | return format_data 75 | 76 | 77 | 78 | # concat three data files to final sft data file 79 | def concat_data(med_data_file,general_data_file,generate_sft_raw_data_file,final_sft_data_file): 80 | 81 | total_lines = [] 82 | 83 | with open(med_data_file, 'r',encoding="utf-8") as f: 84 | med_lines = f.readlines() 85 | # convert dict type to uniform dict type 86 | med_lines = convert_dict_type(med_data_file,med_lines,data_type='huatuo') 87 | total_lines.extend(med_lines) 88 | with open(general_data_file, 'r',encoding="utf-8") as f: 89 | general_lines = f.readlines() 90 | # convert dict type to uniform dict type 91 | general_lines = convert_dict_type(general_data_file,general_lines,data_type='general') 92 | total_lines.extend(general_lines) 93 | with open(generate_sft_raw_data_file, 'r',encoding="utf-8") as f: 94 | generate_lines = f.readlines() 95 | # convert dict type to uniform dict type 96 | generate_lines = convert_dict_type(generate_sft_raw_data_file,generate_lines,data_type='sft') 97 | total_lines.extend(generate_lines) 98 | 99 | # shuffle data 100 | random.shuffle(total_lines) 101 | with open(final_sft_data_file, 'w+',encoding="utf-8") as f: 102 | for line in total_lines: 103 | f.write(json.dumps(line,ensure_ascii=False)+'\n') 104 | 105 | 106 | 107 | 108 | if __name__ == '__main__': 109 | 110 | med_data_file = 'data/public/huatuo-GPT-226k.jsonl' 111 | # random choice from hua tuo data 112 | # random_choice_from_file(med_data_file,choose_num=350,new_path_suffix='random') 113 | 114 | general_data_file = 'data/public/alpaca_data_zh_51k.json' 115 | # random choice from general data 116 | # random_choice_from_file(general_data_file,choose_num=500,new_path_suffix='random',data_type='general') 117 | 118 | 119 | # concat data 120 | generate_sft_raw_data_file = 'data/output/sft_data_raw.jsonl' 121 | med_data_file = 'data/public/huatuo-GPT-226k_random_350.jsonl' 122 | general_data_file = 'data/public/alpaca_data_zh_51k_random_500.json' 123 | final_sft_data_file = 'data/output/sft_data_1500.jsonl' 124 | concat_data(med_data_file,general_data_file,generate_sft_raw_data_file,final_sft_data_file) 125 | 126 | # # split data to train and val 127 | # out_dir = 'data/' 128 | # split_data(output_file,out_dir,split_ratio=0.9) -------------------------------------------------------------------------------- /data_generate.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | from tools.qianfan_requestor import init_qianfan_requestor 4 | from config import config_dict 5 | import json 6 | from tqdm import tqdm 7 | from template.disease_classification import disease_dictionary 8 | 9 | cut_space = re.compile(r'[\s\'\n]+|```|json') 10 | 11 | def format_prompt(prompt, instruction, disease_pro_list,instruct_tag='',personal_tag=''): 12 | 13 | prompt = cut_space.sub('', prompt) 14 | # replace instruction tag with instruction 15 | prompt = prompt.replace(instruct_tag, instruction) 16 | # replace specific disease tag with disease proposals 17 | prompt = prompt.replace(personal_tag, disease_pro_list) 18 | return prompt 19 | 20 | def choose_disease_proposal(disease_dict,choose_num=3): 21 | disease_list = [] 22 | for disease in disease_dict.keys(): 23 | disease_list.extend(disease_dict[disease]) 24 | 25 | return random.sample(disease_list,choose_num) 26 | 27 | 28 | def generate_sft_data(instruction_file, prompt_file,output_file,chat_requestor,random_seed_num=3,generate_times=20): 29 | 30 | instruction_list = [] 31 | personal_info_list = [] 32 | 33 | with open(instruction_file, 'r',encoding="utf-8") as f: 34 | 35 | for line in f: 36 | # cut space for each line 37 | instruction_line = cut_space.sub('', line.strip()) 38 | instruction_list.append(instruction_line) 39 | 40 | # load prompt from file 41 | with open(prompt_file, 'r',encoding="utf-8") as f: 42 | prompt = f.read() 43 | 44 | with open(output_file, 'w+',encoding="utf-8") as f: 45 | 46 | for i in tqdm(range(generate_times)): 47 | 48 | # choose random instruction and disease proposal 49 | instruction = random.sample(instruction_list,random_seed_num) 50 | disease_pro_list = choose_disease_proposal(disease_dictionary,choose_num=random_seed_num) 51 | # concat instruction and personal info to prompt 52 | # there is only one shot for each instruction due the limitation nums of seed, if you have more seeds, you can add more samples in the context 53 | for j in range(random_seed_num): 54 | f_prompt = format_prompt(prompt, instruction[j], disease_pro_list[j]) 55 | res = chat_requestor.send_message(f_prompt) 56 | print('api response:',res) 57 | try: gene_sample_list = json.loads(cut_space.sub('', res['result'])) 58 | except Exception as e: 59 | print('json loads error,continue') 60 | continue 61 | 62 | for gene_sample in gene_sample_list: 63 | gene_sample_str = json.dumps(gene_sample,ensure_ascii=False) 64 | f.write(gene_sample_str+'\n') 65 | f.flush() 66 | 67 | 68 | 69 | if __name__ == '__main__': 70 | 71 | instruction_file = 'seed_tasks.jsonl' 72 | prompt_file = './template/prompt_template.txt' 73 | output_file = './data/output/sft_data_raw_1.jsonl' 74 | 75 | # initial baidu_ai_chat 76 | baidu_chat = init_qianfan_requestor(config_dict) 77 | 78 | generate_sft_data(instruction_file, prompt_file,output_file,baidu_chat,generate_times=70) 79 | 80 | -------------------------------------------------------------------------------- /seed_tasks.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "seed_task_1", "name": "appointment_platform", "instruction": "有什么挂号预约平台推荐吗?","instances": [{"input":"明天想去医院看颈椎" ,"output": "您好,如果您明天想去医院看颈椎,我建议您提前进行预约。您可以通过一些在线医疗预约平台进行预约,比如微医、健康160等,这些平台通常可以提供多家医院的号源。另外,您也可以直接拨打您想去的医院的预约电话进行预约。\n\n在选择医院时,您可以考虑选择有专业骨科或者神经外科的医院,比如市人民医院或者骨科医院等。如果您不确定哪家医院更合适,可以在预约平台上查看医院的评价和医生的资质等信息,以便做出更好的选择。\n\n希望这些建议能帮到您,祝您早日康复!"}]} 2 | {"id": "seed_task_1", "name": "appointment_phone", "instruction": "如何通过电话预约挂号?我听说可以电话预约挂号,是真的吗?该怎么操作呢", "instances": [{"input": "", "output": "您好,是的,很多医院都提供电话预约挂号服务。您可以查找您想去的医院的预约电话,通常在医院的官方网站或公告板上会有相关信息。拨打电话后,按照语音提示或人工服务的指引,提供您的个人信息和想要预约的科室及时间,即可完成预约。"}]} 3 | {"id": "seed_task_1", "name": "appointment_processing", "instruction": "你好,他们的挂号预约流程是怎样的?", "instances": [{"input": "我想去市第一医院就诊", "output": "您好,如果您想去市第一医院就诊,可以先访问他们的官方网站或使用相关的手机应用进行预约。在预约时,您需要选择想要看的科室和医生,以及合适的就诊时间。请提前准备好个人身份证件和相关医疗信息,以便顺利完成预约。"}]} 4 | {"id": "seed_task_1", "name": "appointment_disease", "instruction": "我感觉心脏很不舒服,想要紧急预约一个心内科的号,应该怎么办?", "instances": [{"input": "", "output": "您好,如果您感觉心脏不适,建议您立即拨打急救电话或前往就近的医院进行紧急处理。如果您想预约心内科的号,可以尝试联系您所在地区的三甲医院或专业心血管医院,询问是否有紧急预约服务。同时,您也可以尝试在线医疗预约平台,查找可以提供紧急预约服务的医生。"}]} 5 | {"id": "seed_task_2", "name": "guiding_consultation", "instruction": "你好,我感觉胃不太舒服改挂哪个科?", "instances": [{"input": "", "output": "您好,感觉胃不太舒服通常需要去消化内科就诊。请问您是否有胃痛、胃胀、反酸等症状?另外,您之前是否有过胃病史或者最近饮食习惯有所改变?这些信息将有助于医生更好地了解您的病情。请前往消化内科挂号,并详细描述您的症状,以便医生为您提供专业的诊断和治疗建议。"}]} 6 | {"id": "seed_task_2", "name": "guiding_consultation", "instruction": "我最近总是感觉疲劳无力,该去哪个科室检查?", "instances": [{"input": "", "output": "您好,长期感到疲劳无力可能与多种原因有关,建议您先咨询内科。请问您除了疲劳无力外,还有其他如食欲不振、失眠等症状吗?这些信息将帮助医生更全面地评估您的健康状况。请通过内科在线预约挂号,并向医生详细描述您的症状和持续时间,以便得到适当的检查和治疗建议。"}]} 7 | {"id": "seed_task_2", "name": "guiding_consultation", "instruction": "我应该看哪个科?", "instances": [{"input": "最近皮肤很痒,还有红疹", "output": "您好,皮肤痒和红疹可能是皮肤问题的症状,建议您咨询皮肤科医生。请问您出现这些症状有多长时间了?是否有过敏史或接触过可能引起过敏的物质?这些信息对医生诊断非常重要。请通过在线平台预约皮肤科医生,就诊时请详细描述症状的发展过程和您的个人病史,以便医生为您提供精准的治疗方案。"}]} 8 | {"id": "seed_task_2", "name": "guiding_consultation", "instruction": "我应该去哪个科室看看?", "instances": [{"input": "最近总是心慌,有时心跳很快,", "output": "您好,心慌和心跳加快可能与心脏健康有关,建议您咨询心血管内科。请问您除了心慌外,是否有胸闷、气短等其他症状?这些症状的出现频率和持续时间如何?这些信息将有助于医生评估您的心脏功能。请尽快通过在线导诊系统预约心血管内科医生,并在就诊时详细叙述您的症状和病史,以便医生为您做出准确的诊断和治疗计划。"}]} 9 | {"id": "seed_task_3", "name": "health_management_diabetes", "instruction": "针对糖尿病患者的健康建议是什么?", "instances": [{"input": "我被诊断为2型糖尿病", "output": "作为糖尿病患者,保持稳定的血糖水平至关重要。建议您定期监测血糖,遵医嘱用药,保持低糖饮食,适量运动,并避免高糖高脂食物。此外,保持充足的睡眠和减少压力也有助于控制血糖。"}]} 10 | {"id": "seed_task_3", "name": "health_management_arthritis", "instruction": "关节炎患者应该如何缓解疼痛?", "instances": [{"input": "", "output": "关节炎患者应保持适度的运动以增强关节周围肌肉的力量,同时避免过度使用关节。保持温暖和舒适的环境,使用热敷或冷敷来缓解疼痛。如果疼痛严重,可以咨询医生使用止痛药或物理治疗。"}]} 11 | {"id": "seed_task_3", "name": "health_management_hypertension", "instruction": "高血压患者应该如何控制血压?", "instances": [{"input": "", "output": "高血压患者应遵医嘱用药,保持低盐低脂饮食,适量运动,戒烟限酒,保持心情愉快,避免过度劳累。定期监测血压,及时调整治疗方案。"}]} 12 | {"id": "seed_task_3", "name": "health_management_asthma", "instruction": "哮喘患者应该如何预防哮喘发作?", "instances": [{"input": "", "output": "哮喘患者应避免接触过敏原,保持室内空气清新,避免烟尘刺激。定期使用吸入器,遵医嘱用药,保持良好的生活习惯,避免过度劳累。"}]} -------------------------------------------------------------------------------- /template/disease_classification.py: -------------------------------------------------------------------------------- 1 | disease_dictionary = { 2 | "呼吸道疾病": [ 3 | "感冒", 4 | "急性上呼吸道感染", 5 | "急慢性呼吸道感染", 6 | "肺炎", 7 | "肺结核", 8 | "急性扁桃体炎", 9 | "慢性支气管炎", 10 | "支气管哮喘", 11 | "急性鼻窦炎", 12 | "咽喉炎" 13 | ], 14 | "胃肠道疾病": [ 15 | "胃溃疡", 16 | "急慢性胃炎", 17 | "胆囊炎", 18 | "胆结石", 19 | "阑尾炎", 20 | "结肠炎", 21 | "十二指肠溃疡", 22 | "便秘", 23 | "感染性腹泻" 24 | ], 25 | "心脑血管疾病": [ 26 | "高血压", 27 | "冠心病", 28 | "心力衰竭", 29 | "心肌病", 30 | "心包炎", 31 | "先天性心脏病", 32 | "心绞痛", 33 | "心肌梗塞", 34 | "脑出血", 35 | "脑梗塞", 36 | "脑血栓" 37 | ], 38 | "内分泌疾病": [ 39 | "糖尿病", 40 | "骨质疏松", 41 | "痛风", 42 | "甲亢", 43 | "甲状腺肿瘤" 44 | ], 45 | "其他疾病": [ 46 | "贫血", 47 | "肾炎", 48 | "尿路感染", 49 | "白内障", 50 | "青光眼" 51 | ] 52 | } -------------------------------------------------------------------------------- /template/prompt_template.txt: -------------------------------------------------------------------------------- 1 | 你被要求提供一条包含两个任务的指令微调数据生成 2 | 以下是要求: 3 | 1.尽量不要在每个指令中重复动词,以最大限度地提高多样性; 4 | 2.指令中使用的语言也应该多样化。体现丰富性; 5 | 4.输出应该是对指令和输入的适当响应。确保输出少于100个字。 6 | 5.生成的数据需要增加不同的病种,并且适配不同的任务; 7 | 6.注意生成的数据的指令类型只有这三类,指令范围务必限制在这三个范围内:(1)挂号预约;(2)用户导诊;(3)健康管理。健康管理与疾病咨询不同,偏向于生活习惯,运动,饮食注意事项等; 8 | 8.生成的数据格式与示例数据格式保持一致; 9 | 9.输出3条不一样的数据,并组成json格式的数据返回,注意不要增加多余字符,直接返回json格式数据,列表中包含3条生成的字典数据:[{...}]; 10 | 10.注意instruction为必需的指令字段,input为对instruction的指令补充,有可能为空,如果instruction指令描述清楚时input为空,如果需要进一步补充信息才不为空。 11 | 11.本次主要参考的疾病列表:; 12 | 以下是相关示例: 13 | 14 | 输出: 15 | -------------------------------------------------------------------------------- /tools/cfg_wapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class cfg_dict(object): 4 | 5 | # constructor 6 | def __init__(self, dict1): 7 | self.__dict__.update(dict1) 8 | 9 | def getitem(self, key): 10 | return self.__dict__.get(key) 11 | 12 | 13 | 14 | def load_config(dict1): 15 | 16 | return json.loads(json.dumps(dict1), object_hook=cfg_dict) -------------------------------------------------------------------------------- /tools/qianfan_requestor.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | from config import config_dict 4 | from tools.cfg_wapper import load_config 5 | 6 | class BaiduAIChat: 7 | def __init__(self, api_key, secret_key,api_address,client_credentials_address): 8 | self.api_key = api_key 9 | self.secret_key = secret_key 10 | self.api_address = api_address 11 | self.client_credentials_address = client_credentials_address 12 | self.access_token = self.get_access_token() 13 | self.req_url = f"{self.api_address}?access_token={self.access_token}" 14 | 15 | def get_access_token(self): 16 | url = ( 17 | f"{self.client_credentials_address}" 18 | f"&client_id={self.api_key}&client_secret={self.secret_key}" 19 | ) 20 | payload = json.dumps("") 21 | headers = { 22 | 'Content-Type': 'application/json', 23 | 'Accept': 'application/json' 24 | } 25 | 26 | response = requests.request("POST", url, headers=headers, data=payload) 27 | return response.json().get("access_token") 28 | 29 | 30 | def send_message(self, message): 31 | 32 | payload = json.dumps({ 33 | "messages": [ 34 | { 35 | "role": "user", 36 | "content": f"{message}" 37 | } 38 | ] 39 | }) 40 | headers = { 41 | 'Content-Type': 'application/json' 42 | } 43 | 44 | response = requests.request("POST", self.req_url, headers=headers, data=payload) 45 | 46 | 47 | return response.json() 48 | 49 | # initial qianfan_requestor use the following code 50 | def init_qianfan_requestor(config_dict): 51 | config = load_config(config_dict) 52 | 53 | access_config = config.access_config 54 | 55 | api_key = access_config.api_key 56 | api_secret = access_config.api_secret 57 | client_credentials_address = access_config.client_credentials_address 58 | 59 | model_config = config.model_config 60 | api_model_address = model_config.model_dict.getitem('ERNIE-4.0-8K') 61 | baidu_ai_chat = BaiduAIChat(api_key, api_secret,api_model_address,client_credentials_address) 62 | return baidu_ai_chat 63 | 64 | 65 | # 使用示例 66 | if __name__ == '__main__': 67 | 68 | config = load_config(config_dict) 69 | 70 | access_config = config.access_config 71 | 72 | api_key = access_config.api_key 73 | api_secret = access_config.api_secret 74 | client_credentials_address = access_config.client_credentials_address 75 | 76 | 77 | model_config = config.model_config 78 | api_model_address = model_config.model_dict.getitem('ERNIE-4.0-8K') 79 | baidu_ai_chat = BaiduAIChat(api_key, api_secret,api_model_address,client_credentials_address) 80 | 81 | res = baidu_ai_chat.send_message("你好") 82 | 83 | print(res) -------------------------------------------------------------------------------- /unsloth_finetune.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import torch 3 | 4 | from datasets import load_dataset 5 | from trl import SFTTrainer 6 | from transformers import TrainingArguments 7 | 8 | max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally! 9 | dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ 10 | load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False. 11 | 12 | 13 | alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 14 | 15 | ### Instruction: 16 | {} 17 | 18 | ### Input: 19 | {} 20 | 21 | ### Response: 22 | {}""" 23 | 24 | 25 | 26 | # 4bit pre quantized models we support for 4x faster downloading + no OOMs. 27 | fourbit_models = [ 28 | "unsloth/mistral-7b-bnb-4bit", 29 | "unsloth/mistral-7b-instruct-v0.2-bnb-4bit", 30 | "unsloth/llama-2-7b-bnb-4bit", 31 | "unsloth/gemma-7b-bnb-4bit", 32 | "unsloth/gemma-7b-it-bnb-4bit", # Instruct version of Gemma 7b 33 | "unsloth/gemma-2b-bnb-4bit", 34 | "unsloth/gemma-2b-it-bnb-4bit", # Instruct version of Gemma 2b 35 | "unsloth/llama-3-8b-bnb-4bit", # [NEW] 15 Trillion token Llama-3 36 | ] # More models at https://huggingface.co/unsloth 37 | 38 | # load model 39 | def load_model(model_name = "unsloth/llama-3-8b-bnb-4bit"): 40 | 41 | model, tokenizer = FastLanguageModel.from_pretrained( 42 | model_name = model_name, 43 | max_seq_length = max_seq_length, 44 | dtype = dtype, 45 | load_in_4bit = load_in_4bit, 46 | # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf 47 | ) 48 | 49 | return model, tokenizer 50 | 51 | # add LoRA adapters so we only need to update 1 to 10% of all parameter 52 | def lora_adapt(base_model): 53 | 54 | model = FastLanguageModel.get_peft_model( 55 | base_model, 56 | r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 57 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", 58 | "gate_proj", "up_proj", "down_proj",], 59 | lora_alpha = 16, 60 | lora_dropout = 0, # Supports any, but = 0 is optimized 61 | bias = "none", # Supports any, but = "none" is optimized 62 | # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! 63 | use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context 64 | random_state = 3407, 65 | use_rslora = False, # We support rank stabilized LoRA 66 | loftq_config = None, # And LoftQ 67 | ) 68 | 69 | return model 70 | 71 | # format prompt for customized dataset 72 | def formatting_sft_prompts_func(examples): 73 | 74 | instructions = examples["instruction"] 75 | 76 | instances = examples['instances'] 77 | inputs = [] 78 | outputs = [] 79 | 80 | for instance in instances: 81 | inputs.append(instance[0]['input']) 82 | outputs.append(instance[0]['output']) 83 | texts = [] 84 | 85 | for instruction, input, output in zip(instructions, inputs, outputs): 86 | # print(f'instruction = {instruction}\ninput = {input}\noutput = {output}') 87 | # Must add EOS_TOKEN, otherwise your generation will go on forever! 88 | text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN 89 | texts.append(text) 90 | return { "text" : texts, } 91 | 92 | 93 | # load customize dataset 94 | def load_data(dataset_path = "/content/train_sft", split = 'train'): 95 | dataset = load_dataset(path=dataset_path,split =split) 96 | dataset = dataset.map(formatting_sft_prompts_func, batched = True,) 97 | 98 | return dataset 99 | 100 | 101 | # initialize trainer 102 | def init_trainer(model, tokenizer, dataset, max_seq_length = 2048): 103 | trainer = SFTTrainer( 104 | model = model, 105 | tokenizer = tokenizer, 106 | train_dataset = dataset, 107 | dataset_text_field = "text", 108 | max_seq_length = max_seq_length, 109 | dataset_num_proc = 2, 110 | packing = False, # Can make training 5x faster for short sequences. 111 | args = TrainingArguments( 112 | per_device_train_batch_size = 2, 113 | gradient_accumulation_steps = 4, 114 | warmup_steps = 5, 115 | max_steps = 60, 116 | learning_rate = 2e-4, 117 | fp16 = not torch.cuda.is_bf16_supported(), 118 | bf16 = torch.cuda.is_bf16_supported(), 119 | logging_steps = 1, 120 | optim = "adamw_8bit", 121 | weight_decay = 0.01, 122 | lr_scheduler_type = "linear", 123 | seed = 3407, 124 | output_dir = "outputs", 125 | ), 126 | ) 127 | 128 | return trainer 129 | 130 | # save model 131 | def save_model(model, tokenizer, path = "outputs"): 132 | model.save_pretrained(path) 133 | tokenizer.save_pretrained(path) 134 | 135 | 136 | def load_model(model_path = "outputs"): 137 | model, tokenizer = FastLanguageModel.from_pretrained( 138 | model_name = model_path, 139 | max_seq_length = max_seq_length, 140 | dtype = dtype, 141 | load_in_4bit = load_in_4bit, 142 | # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf 143 | ) 144 | FastLanguageModel.for_inference(model) # Enable native 2x faster inference 145 | 146 | return model, tokenizer 147 | 148 | # inference 149 | def inference(model, tokenizer,intruction='Continue the fibonnaci sequence', input='1, 1, 2, 3, 5, 8'): 150 | FastLanguageModel.for_inference(model) # Enable native 2x faster inference 151 | inputs = tokenizer( 152 | [ 153 | alpaca_prompt.format( 154 | intruction, # instruction 155 | input, # input 156 | "", # output - leave this blank for generation! 157 | ) 158 | ], return_tensors = "pt").to("cuda") 159 | 160 | outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True) 161 | res = tokenizer.batch_decode(outputs) 162 | return res 163 | 164 | 165 | def train_pipeline(ori_model_path,data_path,output_path = "./data/output/lora_ckpt"): 166 | # load model 167 | model, tokenizer = load_model(ori_model_path) 168 | 169 | global EOS_TOKEN 170 | EOS_TOKEN = tokenizer.eos_token # assign EOS_TOKEN 171 | 172 | # lora adapt 173 | model = lora_adapt(model) 174 | # load dataset 175 | dataset = load_data(data_path) 176 | 177 | # initialize trainer 178 | trainer = init_trainer(model, tokenizer, dataset) 179 | # train model 180 | trainer.train() 181 | 182 | print("Training finished!") 183 | 184 | # save model 185 | save_model(model, tokenizer, path = output_path) 186 | 187 | print("loRa model saved!") 188 | 189 | 190 | if __name__ == '__main__': 191 | 192 | ori_model_path = "/root/data/llama-3-8b-bnb-4bit" 193 | data_path = "/root/code/LLM/unsloth/sft_data" 194 | save_path = "/root/code/LLM/unsloth/output" 195 | # sft train 196 | train_pipeline(ori_model_path,data_path,output_path = save_path) 197 | 198 | # inference 199 | # model, tokenizer = load_model(model_path = save_path) 200 | # task_case = {'instruction':'Continue the fibonnaci sequence','input':'1, 1, 2, 3, 5, 8'} 201 | # res = inference(model, tokenizer,intruction=task_case['instruction'], input=task_case['input']) 202 | # print(res) 203 | 204 | # infrence use cmd input 205 | while 1: 206 | task_case = input("Please input your task case:") 207 | task_case = {"instruction":task_case,'input':''} 208 | res = inference(model, tokenizer,intruction=task_case['instruction'], input=task_case['input']) 209 | print(res) 210 | 211 | 212 | 213 | 214 | 215 | --------------------------------------------------------------------------------