├── assets
│   ├── logo.png
│   ├── qq_party.jpeg
│   └── Cornucopia_LLM_training_inference_pipeline.png
├── utils
│   ├── README.md
│   └── prompter.py
├── templates
│   ├── ori_template.json
│   ├── fin_template.json
│   ├── alpaca_short.json
│   ├── alpaca.json
│   ├── alpaca_legacy.json
│   └── README.md
├── requirements.txt
├── base_models
│   └── load.sh
├── scripts
│   ├── finetune.sh
│   ├── infer.sh
│   └── comparison_test.sh
├── HOW_TO_CONTRIBUTE.md
├── instruction_data
│   ├── infer.json
│   └── fin_data.json
├── infer.py
├── README.md
├── LICENSE
└── tuning_train.py

/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/assets/qq_party.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/HEAD/assets/qq_party.jpeg
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
 1 | # Prompt construction module
 2 | 
 3 | ## prompter.py
 4 | 
 5 | Prompter class, a template manager.
 6 | 
 7 | ```
 8 | from utils.prompter import Prompter
 9 | ```
10 | 
--------------------------------------------------------------------------------
/assets/Cornucopia_LLM_training_inference_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/HEAD/assets/Cornucopia_LLM_training_inference_pipeline.png
--------------------------------------------------------------------------------
/templates/ori_template.json:
--------------------------------------------------------------------------------
1 | {
2 |     "description": "Template used by Llama without sft",
3 |     "prompt_input": "问题:{instruction} 回答:",
4 |     "prompt_no_input": "问题:{instruction} 回答:",
5 |     "response_split": "回答:"
6 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | accelerate==0.17.1
 2 | appdirs==1.4.4
 3 | bitsandbytes==0.37.1
 4 | black
 5 | black[jupyter]
 6 | datasets
 7 | fire
 8 | git+https://github.com/huggingface/peft.git
 9 | git+https://github.com/huggingface/transformers.git
10 | gradio
11 | torch
12 | sentencepiece
13 | wandb>=0.15.0
--------------------------------------------------------------------------------
/templates/fin_template.json:
--------------------------------------------------------------------------------
1 | {
2 |     "description": "Template used by Financial Instruction Tuning",
3 |     "prompt_input": "下面是一个问题,运用金融财经知识来正确回答问题.\n### 问题:\n{instruction}\n### 回答:\n",
4 |     "prompt_no_input": "下面是一个问题,运用金融财经知识来正确回答问题.\n### 问题:\n{instruction}\n### 回答:\n",
5 |     "response_split": "### 回答:"
6 | }
--------------------------------------------------------------------------------
/templates/alpaca_short.json:
--------------------------------------------------------------------------------
1 | {
2 |     "description": "A shorter template to experiment with.",
3 |     "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4 |     "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
5 |     "response_split": "### Response:"
6 | }
7 | 
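The JSON templates above are consumed by the Prompter class in utils/prompter.py (listed later in this dump). As a minimal sketch of the template contract, assuming the repository root as the working directory, a template can be applied by hand like this:

```python
import json

# Load one of the template files listed above.
with open("templates/alpaca_short.json", encoding="utf-8") as fp:
    template = json.load(fp)

# Fill the {instruction} placeholder, as Prompter.generate_prompt does
# for samples that carry no separate input field.
prompt = template["prompt_no_input"].format(instruction="What is a bond?")

# After generation, everything following response_split is taken as the
# model's answer, mirroring Prompter.get_response.
raw_output = prompt + "A bond is a fixed-income instrument."
print(raw_output.split(template["response_split"])[1].strip())
```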
--------------------------------------------------------------------------------
/base_models/load.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Expected to be run from the repository root.
 4 | base_model_dir="./base_models/llama-7b-hf"
 5 | if [ ! -d "$base_model_dir" ];then
 6 |     cd ./base_models/ || exit
 7 |     git clone https://huggingface.co/decapoda-research/llama-7b-hf
 8 |     cd ../ || exit
 9 | fi
10 | 
11 | base_model_dir="./base_models/Linly-Chinese-LLaMA-7b-hf"
12 | if [ ! -d "$base_model_dir" ];then
13 |     cd ./base_models/ || exit
14 |     git clone https://huggingface.co/P01son/Linly-Chinese-LLaMA-7b-hf
15 |     cd ../ || exit
16 | fi
17 | 
18 | 
--------------------------------------------------------------------------------
/templates/alpaca.json:
--------------------------------------------------------------------------------
1 | {
2 |     "description": "Template used by Alpaca-LoRA.",
3 |     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4 |     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
5 |     "response_split": "### Response:"
6 | }
7 | 
--------------------------------------------------------------------------------
/templates/alpaca_legacy.json:
--------------------------------------------------------------------------------
1 | {
2 |     "description": "Legacy template, used by Original Alpaca repository.",
3 |     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
4 |     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
5 |     "response_split": "### Response:"
6 | }
7 | 
--------------------------------------------------------------------------------
/scripts/finetune.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | 
 4 | exp_tag="Meta"
 5 | python3 tuning_train.py \
 6 |     --base_model './base_models/llama-7b-hf' \
 7 |     --data_path './instruction_data/fin_data.json' \
 8 |     --output_dir './Fin-Alpaca-LoRA-7B-'$exp_tag \
 9 |     --prompt_template_name 'fin_template' \
10 |     --micro_batch_size 64 \
11 |     --batch_size 64 \
12 |     --num_epochs 10 \
13 |     --wandb_run_name $exp_tag
14 | 
15 | 
16 | #exp_tag="Linly"
17 | #python3 tuning_train.py \
18 | #    --base_model './base_models/Linly-Chinese-LLaMA-7b-hf' \
19 | #    --data_path './instruction_data/fin_data.json' \
20 | #    --output_dir './Fin-Alpaca-LoRA-7B-'$exp_tag \
21 | #    --prompt_template_name 'fin_template' \
22 | #    --micro_batch_size 96 \
23 | #    --batch_size 96 \
24 | #    --num_epochs 10 \
25 | #    --wandb_run_name $exp_tag
--------------------------------------------------------------------------------
/scripts/infer.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/sh
 2 | 
 3 | # If inferring with the llama model, set 'use_lora' to 'False' and 'prompt_template' to 'ori_template'.
 4 | # If inferring with the default alpaca model, set 'use_lora' to 'True', 'lora_weights' to 'tloen/alpaca-lora-7b', and 'prompt_template' to 'alpaca'.
 5 | # If inferring with the llama-fin model, download the LoRA weights and set 'lora_weights' to './Fin-Alpaca-LoRA-7B-xxx' (or the exact directory of the LoRA weights) and 'prompt_template' to 'fin_template'.
 6 | 
 7 | BASE_MODEL="./base_models/llama-7b-hf"  # or "./base_models/Linly-Chinese-LLaMA-7b-hf"
 8 | exp_tag="Meta"  # or "Linly"
 9 | python3 infer.py \
10 |     --base_model ${BASE_MODEL} \
11 |     --lora_weights './Fin-Alpaca-LoRA-7B-'$exp_tag \
12 |     --use_lora True \
13 |     --instruct_dir './instruction_data/infer.json' \
14 |     --prompt_template 'fin_template'
15 | 
--------------------------------------------------------------------------------
/HOW_TO_CONTRIBUTE.md:
--------------------------------------------------------------------------------
 1 | 
 2 | Welcome to the Cornucopia (聚宝盆) project! Thank you for your interest in contributing. To make the submission process as smooth as possible, we have put together a few guidelines. Before you start contributing, please take a few minutes to review them.
 3 | 
 4 | ## How to contribute code
 5 | 
 6 | 1. Fork the project and clone it locally.
 7 | 2. Create a new branch with a descriptive name for your contribution.
 8 | 3. Make your changes and make sure they are properly tested.
 9 | 4. Submit a pull request against the main branch of our repository.
10 | 
11 | ## How to contribute data
12 | 
13 | If you are contributing prompts or prompt seeds, please open a new issue titled [New Prompt]: or [New Prompt Seed]:.
14 | 
15 | If you are contributing a new dataset:
16 | 1. Please check the format of our official datasets.
17 | 2. Upload your dataset somewhere, e.g. HuggingFace.
18 | 3. Open a new issue titled [Contributing Data]:, describing the dataset, e.g. its size and content.
19 | 4. Include a link to your dataset in the issue.
20 | 
21 | 
22 | ## Contribution guidelines
23 | 
24 | Please make sure your contribution follows these guidelines:
25 | 
26 | 1. Follow the coding style and conventions used in the project.
27 | 2. Make sure your contribution is well documented and easy to understand.
28 | 3. Keep your contribution concise and clear. If you are making multiple changes, consider splitting them into separate pull requests.
29 | 4. Do not submit contributions that contain proprietary or confidential information.
30 | 
31 | ## Submitting issues
32 | 
33 | If you run into any problems while using our project, please report them through our issue tracker. Provide as much information about the problem as possible, including the steps to reproduce it.
34 | 
35 | Before filing an issue, please search the existing ones first :)
36 | 
37 | ## Conclusion
38 | 
39 | Thank you for taking the time to read these guidelines. We appreciate your contributions and look forward to working with you! If you have any questions or concerns, please contact the project maintainers.
--------------------------------------------------------------------------------
/scripts/comparison_test.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/sh
 2 | 
 3 | # If inferring with the llama model, set 'use_lora' to 'False' and 'prompt_template' to 'ori_template'.
 4 | # If inferring with the llama-fin model, download the LoRA weights and set 'lora_weights' to './Fin-Alpaca-LoRA-7B-'$exp_tag (or the exact directory of the LoRA weights) and 'prompt_template' to 'fin_template'.
 5 | 
 6 | BASE_MODEL="./base_models/llama-7b-hf"
 7 | # only ori llama
 8 | o_cmd="python3 infer.py \
 9 |     --base_model ${BASE_MODEL} \
10 |     --use_lora False \
11 |     --prompt_template 'ori_template'"
12 | 
13 | # Fin-Alpaca-LoRA-7B-Meta
14 | exp_tag="Meta"
15 | a_cmd="python3 infer.py \
16 |     --base_model ${BASE_MODEL} \
17 |     --use_lora True \
18 |     --lora_weights './Fin-Alpaca-LoRA-7B-'$exp_tag \
19 |     --prompt_template 'fin_template'"
20 | 
21 | # Fin-Alpaca-LoRA-7B-Linly
22 | BASE_MODEL="./base_models/Linly-Chinese-LLaMA-7b-hf"
23 | exp_tag="Linly"
24 | m_cmd="python3 infer.py \
25 |     --base_model ${BASE_MODEL} \
26 |     --use_lora True \
27 |     --lora_weights './Fin-Alpaca-LoRA-7B-'$exp_tag \
28 |     --prompt_template 'fin_template'"
29 | 
30 | mkdir -p infer_result  # make sure the output directory exists before redirecting into it
31 | echo "only_ori_llama"
32 | eval $o_cmd > infer_result/o_tmp.txt
33 | echo "Fin-Alpaca-LoRA-7B-Meta"
34 | eval $a_cmd > infer_result/a_tmp.txt
35 | echo "Fin-Alpaca-LoRA-7B-Linly"
36 | eval $m_cmd > infer_result/m_tmp.txt
--------------------------------------------------------------------------------
/templates/README.md:
--------------------------------------------------------------------------------
 1 | # Prompt templates
 2 | 
 3 | This directory contains template styles for the prompts used to LoRA-tune the LLaMA model.
 4 | 
 5 | ## Format
 6 | 
 7 | A template is described by a JSON file containing the following keys:
 8 | 
 9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
10 | - `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
11 | - `description`: A short description of the template, with possible use cases.
12 | - `response_split`: The text to use as separator when cutting the real response out of the model output.
13 | 
14 | No `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest.
15 | 
16 | ## Template example
17 | 
18 | The default template, used unless otherwise specified, is `alpaca.json`.
19 | 
20 | ```json
21 | {
22 |     "description": "Template used by Alpaca-LoRA.",
23 |     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
24 |     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
25 |     "response_split": "### Response:"
26 | }
27 | 
28 | ```
29 | 
30 | ## Existing templates
31 | 
32 | ### alpaca
33 | 
34 | The default template so far, used for general LoRA fine-tuning.
35 | 
36 | ### alpaca_legacy
37 | 
38 | The legacy template used by the original Alpaca repository, with no "\n" after the response field. Kept for reference and experimentation.
39 | 
40 | ### alpaca_short
41 | 
42 | A trimmed-down Alpaca template which seems to perform just as well while saving some tokens. Models created with the default template also appear to answer queries made with the short template.
43 | 
--------------------------------------------------------------------------------
/instruction_data/infer.json:
--------------------------------------------------------------------------------
1 | {"instruction": "你好,我想了解一下创业板股票怎么买?", "input": "", "output": "您好,如果要购买创业板的股票,需要先开通创业板权限。\n开通创业板的三个条件:\n1.股东证券账户20个交易日日均资产10万以上;\n2.必须有24个月的交易经验;\n3.风险测试评级在积极型及以上;\n只要符合这三大条件的投资者都可以申请开通创业板交易权限。\n创业板是为中小型、创业性、科技型企业提供上市融资的场所。从这三类企业的角度出发,中小企业受限于规模,本身的抗风险能力就差;创业型企业的一般存在公司治理不完善的特点;对于科技型企业,高科技转化为产品或劳务具有不确定性,存在技术失败风险。所以对于普通投资者来说交易风险较大,把握不好还是要谨慎操作。"}
2 | {"instruction": "要是车抵押贷款的公司垮了咋办?", "input": "", "output": "办理车抵押贷款的公司垮了的话,建议如下处理:1.赶紧联系贷款公司的工作人员,让其安排后续还款事宜,有能力的赶紧还清,然后尽快拿回车辆登记证书。2.如果暂时联系不上人,看后续贷款公司是否会被接盘。而等新的金融机构接手后,届时还款给接手的金融机构就行。3.选择向当地公安部门报警处理,必要时还可以联系上其他同样在该贷款公司办有业务的客户一起向当地人民法院提起诉讼。至于车辆登记证书,可以根据公安出具的报案证明去车辆登记地的车管所申请补办。建议大家平时办理车辆抵押贷款最好选择银行或者正规大型持牌金融机构,一般不会面临倒闭情况。如果随便找一家不知名的贷款公司办理,很可能会上当受骗,甚至被索要高额利息和各种费用,造成个人钱财损失。"}
3 | {"instruction": "想了解一下理财单位净值和累计净值是什么意思?", "input": "", "output": "理财单位净值是指理财当前的价格情况,理财累计净值是指理财成立以来的价格情况,一般理财累计净值越高越好,累计净值越高代表理财的业绩越好,投资者获得的收益越多。净值就是价格的意思,理财净值由投资标的决定,投资标的上涨,那么理财的净值就会上涨,投资标的下跌,那么理财的净值就会下跌。理财投资标的有定期存款、债券、股票、外汇、期权、期货等。"}
4 | {"instruction": "我遇到个情况,就是信用卡到期了旧卡还能还款吗?", "input": "", "output": "信用卡到期了,用户没有激活新卡,那么就继续往旧卡里面还款,已经激活了新卡,则往新卡里面还款。有部分用户的旧卡到期了,但是没有收到新卡,为了避免出现信用卡逾期的情况,用户可以照常将还款金额存入旧卡中,银行仍然会按时扣款。当然,大部分用户选择激活了新卡,那么还款就可以在新卡里面进行,毕竟新卡激活以后,旧卡就没有任何用处了。"}
5 | {"instruction": "为什么医保350一年保费那么贵?", "input": "", "output": "医保350一年贵的原因在于:\n一方面医疗费用每年上涨,医保基金支出比较大,每年保持涨价50元的幅度,维持医保基金收支平衡;\n另一方面在于部分经济困难家庭可以支付350元一个人,但是一家如果是4口人或六口人,一年交1400-2100元,累计金额相对比较高,显得比较贵。\n其实城乡居民医保每年个人交费都享有政府补贴,比如个人交350,政府补贴了610元,相当于原价960元一年,已经是给予了价格补助。\n城乡居民医疗保险的特点在于:\n1、支持带病投保。无论是否已经罹患大病,都可以参保,可以报销,这是商业医疗险做不到的。\n2、保障范围涉及多个方面。比如普通门诊、大额门诊、住院待遇、特殊门诊、两病门诊、大病保险,如果家庭困难,医保报销以后还是解决不了问题,还可以申请政府医疗救助。\n3、报销比例方面:各省市有差异,在三级医院,如省人民医院,能够报65%-75%;在二级医院,比如县人员一样,能够报销75%-80%左右。\n家里有条件还是交居民医保,另外也可以配个商业医疗作为补充,比如看病住院花费5万,居民医保报销2万,商业医疗报销剩余的3万。"}
--------------------------------------------------------------------------------
/utils/prompter.py:
--------------------------------------------------------------------------------
 1 | # -*- coding:utf-8 -*-
 2 | 
 3 | """
 4 | Choose a template to build prompts.
 5 | """
 6 | 
 7 | import json
 8 | import os.path as osp
 9 | from typing import Union
10 | 
11 | 
12 | class Prompter(object):
13 |     """
14 |     Prompt builder -> chooses a template
15 |     """
16 |     __slots__ = ("template", "_verbose")
17 | 
18 |     def __init__(self, template_name: str = "", verbose: bool = False):
19 |         self._verbose = verbose
20 | 
21 |         if not template_name:
22 |             template_name = "alpaca"  # default
23 |         file_name = osp.join("templates", "{}.json".format(template_name))
24 | 
25 |         if not osp.exists(file_name):
26 |             raise ValueError("Can't read {}".format(file_name))
27 | 
28 |         with open(file_name) as fp:
29 |             self.template = json.load(fp)
30 |         if self._verbose:
31 |             print("Using prompt template {0}: {1}".format(template_name, self.template['description']))
32 | 
33 |     def generate_prompt(
34 |             self,
35 |             instruction: str,
36 |             input: Union[None, str] = None,
37 |             label: Union[None, str] = None,
38 |     ) -> str:
39 |         """
40 |         Return the full prompt built from the instruction and the optional input; if a label (=response, =output) is provided, it is appended as well.
41 |         """
42 |         if input:
43 |             res = self.template["prompt_input"].format(
44 |                 instruction=instruction, input=input
45 |             )
46 |         else:
47 |             res = self.template["prompt_no_input"].format(
48 |                 instruction=instruction
49 |             )
50 |         if label:
51 |             res = "{0}{1}".format(res, label)
52 |         if self._verbose:
53 |             print(res)
54 |         return res
55 | 
56 |     def get_response(self, output: str) -> str:
57 |         return output.split(self.template["response_split"])[1].strip()
58 | 
--------------------------------------------------------------------------------
/instruction_data/fin_data.json:
--------------------------------------------------------------------------------
 1 | {"instruction": "可转债风险大吗?和股票比风险大还是小?", "input": "", "output": "可转债和股票相比来说,风险是比较小一些的,打新债需要先开账户,开户完之后的2.个交易日就可以打新债了。\n开户流程如下\n1.打开优惠开户链接,点击开户;\n2.输入手机号和验证码;\n3.填写个人信息\n4.视频认证\n5.完成风险测评问题;\n6.3.方存管银行\n7.最后提交申请,等待审核。"}
 2 | {"instruction": "股票没有成交不散单可以吗?", "input": "", "output": "可以,当有相同价格的买单或者卖单就能成交,股票按照价格优先、时间优先的原则进行成交,若委托单在交易时间未成交也未撤单的,那么股票清算后自动撤销该笔委托单,股票清算时间:交易日下午16:00到晚上22:00点。\n股票交易时间:周一至周五上午9:30-11:30,下午13:00-15:00,法定节假日不交易。"}
 3 | {"instruction": "除权的股票能长期持股吗?", "input": "", "output": "除权是股票分红的一个阶段,对股票涨跌没有影响,股票是否长期持有不需要看是否除权,主要看股票是否具有投资价值,投资者可以根据基本面、技术面、消息面等情况分析股票是否具有投资价值。\n基本面主要分析财务指标、股东情况、公司概况等,财务指标看现金流量表(表示公司周转的资金,越多越好)、资产负债表(表示公司负债情况,负债越少越好,主要看流动比率和速动比率)和利润表(表示公司收入,越多越好);股东情况是公司股东的基本情况,包括:流通股东、非流通股东、股东户数、实际控制人、解禁时间等情况,查看股东情况可以了解股票是否有机构投资,公司概况是公司的基本信息,比如公司主要经营业务,投资者可以分析主营业务是否具有发展前景。\n技术面主要风险技术指标,KDJ指标:当K线向上穿过D线时,形成金叉是买入信号,MACD指标:当DIF线上穿DEA线时,形成金叉是买进信号,均线指标:当短期均线向上穿过长期均线时,形成金叉是买入信号,K线组合:早晨之星、曙光初现、红三兵等是买入信号。\n消息面主要分析有没有利好上市公司业务发展的政策和消息等,这些政策和消息会利于股票上涨。"}
 4 | {"instruction": "股票挂单没成功什么时候返钱?", "input": "", "output": "股票挂单没成功撤单后资金实时到账,投资者马上可以继续委托,股票挂单没成功未撤单的,该笔委托单交易时间内有效,直到成交或者撤单,若收市时未成交的,股票清算时会自动撤销该笔交易,资金实时到账。\n股票按照市场实时价格进行成交,按照价格优先、时间优先的原则,一般股票委托价格和市场实时价格偏离幅度不大,能实时成交。"}
 5 | {"instruction": "打新债要注意的坑有什么?要注意什么", "input": "", "output": "你好,很高兴为你解答,打新债要注意开户联系客户经理帮你申请低佣金账户,介绍几个可以提高打新债中签概率的小方法可以参考以下:\n1.,首先申购时间最好选择在上午10:31-11:30,以及下午14:31-15:00两个时间段申购的中签率会相对较高。这是证券公司所统计的中签率高的时间段,可以试一下。\n2.,建议多个账户进行申购,也可以提高投资者的中签概率,不过根据的可转债“打新”规则,同一投资者参与同一只可转债、可交换债网上申购只能使用一个证券账户,所以建议用家人身份证另外开通证券账户进行申购来提高中签。\n3.,最好选择顶格申购,新债申购没有市值要求,也没有资金要求,更不需要预先缴款,所以为了提高中签率,虽然是不同的券商交易的页面会不一样,但是一般都是默认顶格申购的。"}
 6 | {"instruction": "办理商业汇票应遵守哪些原则和规定?", "input": "", "output": "办理商业汇票应遵守下列原则和规定:1.使用商业汇票的单位,必须是在银行开立帐户的法人;2.商业汇票在同城和异地均可使用;3.签发商业汇票必须以合法的商品交易为基础;④经承兑的商业汇票,可向银行贴现;⑤商业汇票一律记名,允许背书转让;⑥商业汇票的付款期限由交易双方商定,最长不得超过6个月;⑦商业汇票经承兑后,承兑人即付款人负有到期无条件交付票款的责任;⑧商业汇票由银行印制和发售。"}
"为适应我国经济发展和深化改革的需要,我国政府于1994年先后组建了国家开发银行、中国农业发展银行、中国进出口银行等三家政策性银行。"} 8 | {"instruction": "CAPM与APT的联系与区别", "input": "", "output": "资本资产定价模型和套利定价理论都是资产定价理论,所讨论的都是期望收益率和风险的关系,但两者所用的假设和技术不同,两者既有联系、又有区别。\n联系:\n1.两者解决的问题相同,解决期望收益率和风险之间的关系,使期望收益与风险相匹配;\n2.两者对风险的看法相同,都是将风险分为系统性风险和非系统性风险,期望收益只与系统性风险相关,非系统性风险可以通过投资组合来分散掉。\n主要区别:\n1.在APT中,证券的风险有多个因素来解释,而在CAPM中,证券的风险只用证券相对于市场组合的β系数来解释;\n2.APT并没有假定投资者是风险厌恶的,没有对投资者的证券选择行为作出规定,因此APT的适用性增强了,而CAPM假定投资者按照期望收益率和标准差并利用无差异曲线选择投资组合;\n3.APT并不特别强调市场组合的作用,而CAPM强调市场组合是一个有效的组合;\n4.在APT中,资产均衡的得出是一个动态的过程,它是建立在一价定律的基础上的,而CAPM理论则建立在马科维茨的模型基础上的,强调的是一定风险下的收益最大化或一定收益下的风险最小化,均衡的导出是一个静态的过程。"} 9 | {"instruction": "借记账户和信用账户有什么区别?", "input": "", "output": "借记账户是你已经存入一定数量的钱的账户,以后你可以在其中取款和消费。信贷账户是指你实际上提取尚未存入但实际上是从银行借款的货币的账户。取款后,你将需要偿还该金额的利息。银行将为选定的商户提供分期付款的信用服务,分期付款最多12个月,最低消费(已用)的任何交易,以及借记账户,为你提供有竞争力的货币利息。"} 10 | {"instruction": "我这征信还能申请信用卡吗?", "input": "", "output": "(招商银行)可以申请,是否通过以系统审核结果为准。《征信业管理条例》规定:征信机构对个人不良信息的保存期限,自不良行为或者事件终止之日起为5年;超过5年的,应当予以删除。您在办理新的业务时,相关部门一般会优先考虑近期的消费、还款记录"} 11 | ...... 12 | ... -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | This file is to infer the tuned LLaMa model. 5 | """ 6 | 7 | import sys 8 | import json 9 | import argparse 10 | 11 | import torch 12 | from peft import PeftModel 13 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 14 | 15 | from utils.prompter import Prompter 16 | 17 | if torch.cuda.is_available(): 18 | device = "cuda" 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--load_8bit", default=False, type=bool) 23 | parser.add_argument("--base_model", default='', type=str, required=True, help="original pretrained llama weights") 24 | parser.add_argument("--instruct_dir", default='', type=str, help="dataset of infer.") 25 | parser.add_argument("--use_lora", default=True, type=bool) 26 | parser.add_argument("--lora_weights", default='tloen/alpaca-lora-7b', type=str, help="The lora weights of llama.") 27 | parser.add_argument("--prompt_template", default='fin_template', type=str, help="The template of infer data.") 28 | args = parser.parse_args() 29 | 30 | 31 | def load_instruction(instruct_dir): 32 | input_data = [] 33 | with open(instruct_dir, "r") as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | line = line.strip() 37 | d = json.loads(line) 38 | input_data.append(d) 39 | return input_data 40 | 41 | 42 | def main(): 43 | # ----------------------------------------- 44 | # 加载 prompt 模板与 llama模型 45 | prompter = Prompter(args.prompt_template) 46 | tokenizer = LlamaTokenizer.from_pretrained(args.base_model) 47 | model = LlamaForCausalLM.from_pretrained( 48 | args.base_model, 49 | load_in_8bit=args.load_8bit, 50 | torch_dtype=torch.float16, 51 | device_map="auto", 52 | ) 53 | 54 | if args.use_lora: 55 | print("using lora {}".format(args.lora_weights)) 56 | model = PeftModel.from_pretrained( 57 | model, 58 | args.lora_weights, 59 | torch_dtype=torch.float16 60 | ) 61 | # 重新配置 decapoda-research config 62 | model.config.pad_token_id, tokenizer.pad_token_id = 0, 0 # unk token 63 | model.config.bos_token_id = 1 64 | model.config.eos_token_id = 2 65 | if not args.load_8bit: 66 | model.half() # 开启半精度 67 | 68 | model.eval() # 保证BN层直接利用之前训练阶段得到的均值和方差 69 | 70 | if torch.__version__ >= "2" and sys.platform != "win32": 71 | model = torch.compile(model) 72 | 73 | # ----------------------------------------- 74 | def 
 75 |             instruction,
 76 |             input=None,
 77 |             temperature=0.2,
 78 |             top_p=0.85,
 79 |             top_k=40,
 80 |             num_beams=4,
 81 |             max_new_tokens=512,
 82 |             **kwargs
 83 |     ):
 84 |         prompt = prompter.generate_prompt(instruction, input)
 85 |         inputs = tokenizer(prompt, return_tensors="pt")
 86 |         input_ids = inputs["input_ids"].to(device)
 87 |         generation_config = GenerationConfig(
 88 |             temperature=temperature,
 89 |             top_p=top_p,
 90 |             top_k=top_k,
 91 |             num_beams=num_beams,
 92 |             **kwargs
 93 |         )
 94 |         with torch.no_grad():
 95 |             generation_output = model.generate(
 96 |                 input_ids=input_ids,
 97 |                 generation_config=generation_config,
 98 |                 return_dict_in_generate=True,
 99 |                 output_scores=True,
100 |                 max_new_tokens=max_new_tokens,
101 |                 early_stopping=True,
102 |                 remove_invalid_values=True,
103 |                 repetition_penalty=3.5,
104 |                 length_penalty=0.1,
105 |                 epsilon_cutoff=0.05,
106 |                 eos_token_id=2,
107 |                 forced_eos_token_id=2,
108 |                 pad_token_id=0
109 |             )
110 |         s = generation_output.sequences[0]
111 |         output = tokenizer.decode(s)
112 |         return prompter.get_response(output)
113 | 
114 |     def infer_from_json(instruct_dir):
115 |         input_data = load_instruction(instruct_dir)
116 |         for d in input_data:
117 |             instruction = d["instruction"]
118 |             output = d["output"]
119 |             print("###inferring###")
120 |             model_output = evaluate(instruction)
121 |             print("###instruction###")
122 |             print(instruction)
123 |             print("###golden output###")
124 |             print(output)
125 |             print("###model output###")
126 |             print(model_output)
127 | 
128 |     if args.instruct_dir != "":
129 |         infer_from_json(args.instruct_dir)
130 |     else:
131 |         for instruction in [
132 |             "老年人理财好还是存定期好?",
133 |             "手上有20万存款,可以作什么投资",
134 |             "净值型和传统理财产品的区别,有什么不同?",
135 |             "股票和基金能当天随买随卖吗?"
136 |         ]:
137 |             print("Instruction:", instruction)
138 |             print("Response:", evaluate(instruction))
139 |             print()
140 | 
141 | 
142 | if __name__ == "__main__":
143 |     main()
144 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | 
  4 | 
  5 | [Zhihu column](https://www.zhihu.com/people/xuyingjie521/columns)
  6 | 
  7 | [LICENSE](https://github.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/LICENSE)
  8 | 
  9 | [typing-svg](https://git.io/typing-svg)
 10 | 
 11 | 
102 | 
103 | 
202 | 
203 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/tuning_train.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | """
4 | This file provides a method to tune the LLaMA model with financial data:
5 | ### Lora + int8_training ###
6 | """
7 |
8 | import argparse
9 | import os
10 | import sys
11 | from typing import List
12 |
13 | import torch
14 | import transformers
15 | from transformers import LlamaForCausalLM, LlamaTokenizer
16 | from datasets import load_dataset
17 | from utils.prompter import Prompter
18 |
19 | from peft import (
20 | LoraConfig,
21 | get_peft_model,
22 | get_peft_model_state_dict,
23 | prepare_model_for_int8_training,
24 | set_peft_model_state_dict,
25 | )
26 |
27 | parser = argparse.ArgumentParser()
28 | # model/data params
29 | parser.add_argument("--base_model", default='', type=str, required=True, help="original pretrained llama weights")
30 | parser.add_argument("--data_path", default='yahma/alpaca-cleaned', type=str, help="The SFT dataset (a local path or a HF dataset name).")
31 | parser.add_argument("--output_dir", default='./lora-alpaca', type=str, help="Where to save the tuned LoRA model.")
32 | # training hyperparams
33 | parser.add_argument("--batch_size", default=128, type=int, help="Batch size per GPU/CPU for training.")
34 | parser.add_argument("--micro_batch_size", default=8, type=int, help="Batch size per process for training.")
35 | parser.add_argument("--num_epochs", default=8, type=int, help="Total number of training epochs to perform.")
36 | parser.add_argument("--learning_rate", default=2e-4, type=float, help="The initial learning rate for Adam.")
37 | parser.add_argument("--cutoff_len", default=512, type=int,
38 | help="The maximum total input sequence length after tokenization. Sequences longer "
39 | "than this will be truncated, sequences shorter will be padded.")
40 | parser.add_argument("--val_set_size", default=1200, type=int, help="Size of the validation split.")
41 | # lora hyperparams
42 | parser.add_argument("--lora_r", default=8, type=int, help="The LoRA rank (dimension of the low-rank update matrices).")
43 | parser.add_argument("--lora_alpha", default=16, type=int, help="Set the alpha parameter of LORA.")
44 | parser.add_argument("--lora_dropout", default=0.05, type=float, help="Set the dropout parameter of LORA.")
45 | parser.add_argument("--lora_target_modules", default=["q_proj", "v_proj"], type=str, nargs="+",
46 |                     help="Set the target modules for the PEFT model.")  # List[str] is not a callable argparse type
47 | # llm hyperparams
48 | parser.add_argument("--train_on_inputs", default=False, type=lambda s: s.lower() == "true", help="if False, masks out inputs in loss.")  # type=bool would treat "False" as truthy
49 | parser.add_argument("--group_by_length", default=False, type=lambda s: s.lower() == "true",
50 |                     help="faster, but produces an odd training loss curve.")
51 | # wandb params
52 | parser.add_argument("--wandb_project", default='llama_fin', type=str, help="The name of wandb_project.")
53 | parser.add_argument("--wandb_run_name", default='', type=str)
54 | parser.add_argument("--wandb_watch", default='', type=str, choices=['false', 'gradients', 'all'])
55 | parser.add_argument("--wandb_log_model", default='', type=str, choices=['false', 'true'])
56 | parser.add_argument("--resume_from_checkpoint", default=None, type=str,
57 | help="Either training checkpoint or final adapter.")
58 | parser.add_argument("--prompt_template_name", default='alpaca', type=str,
59 | help="The prompt template to use, will default to alpaca.")
60 | args = parser.parse_args()
61 |
62 |
63 | def do_tuning():
64 | if int(os.environ.get("LOCAL_RANK", 0)) == 0:
65 | print(
66 | "Training Alpaca-LoRA model with params:\n"
67 | "base_model: {}\n".format(args.base_model),
68 | "data_path: {}\n".format(args.data_path),
69 | "output_dir: {}\n".format(args.output_dir),
70 | "batch_size: {}\n".format(args.batch_size),
71 | "micro_batch_size: {}\n".format(args.micro_batch_size),
72 | "num_epochs: {}\n".format(args.num_epochs),
73 | "learning_rate: {}\n".format(args.learning_rate),
74 | "cutoff_len: {}\n".format(args.cutoff_len),
75 | "val_set_size: {}\n".format(args.val_set_size),
76 | "lora_r: {}\n".format(args.lora_r),
77 | "lora_alpha: {}\n".format(args.lora_alpha),
78 | "lora_dropout: {}\n".format(args.lora_dropout),
79 | "lora_target_modules: {}\n".format(args.lora_target_modules),
80 | "train_on_inputs: {}\n".format(args.train_on_inputs),
81 | "group_by_length: {}\n".format(args.group_by_length),
82 | "wandb_project: {}\n".format(args.wandb_project),
83 | "wandb_run_name: {}\n".format(args.wandb_run_name),
84 | "wandb_watch: {}\n".format(args.wandb_watch),
85 | "wandb_log_model: {}\n".format(args.wandb_log_model),
86 | "resume_from_checkpoint: {}\n".format(args.resume_from_checkpoint or None),
87 | "prompt template: {}\n".format(args.prompt_template_name)
88 | )
89 |
90 | # --------------------------Check--------------------------
91 | # Check if parameters passed or set in the environment
92 | use_wandb = len(args.wandb_project) > 0 or (
93 | "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
94 | )
95 | # Only overwrite environment if wandb param passed
96 | if len(args.wandb_project) > 0:
97 | os.environ["WANDB_PROJECT"] = args.wandb_project
98 | if len(args.wandb_watch) > 0:
99 | os.environ["WANDB_WATCH"] = args.wandb_watch
100 | if len(args.wandb_log_model) > 0:
101 | os.environ["WANDB_LOG_MODEL"] = args.wandb_log_model
102 |
103 | # Check if the base_model exists
104 | assert args.base_model, "Please specify a base_model, for example: 'decapoda-research/llama-7b-hf'"
105 |
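    # batch_size is the effective batch per optimizer step: gradients are accumulated over
    # batch_size // micro_batch_size forward passes, split further across ranks under DDP.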
106 | gradient_accumulation_steps = args.batch_size // args.micro_batch_size
107 | device_map = "auto"
108 | world_size = int(os.environ.get("WORLD_SIZE", 1))
109 | ddp = world_size != 1
110 | if ddp: # ddp adopts a multiprocessing approach
111 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
112 | gradient_accumulation_steps = gradient_accumulation_steps // world_size
113 |
114 | # --------------------------Model--------------------------
115 |     # Loading the model in 8-bit saves up to 4x memory compared with full precision
116 | model = LlamaForCausalLM.from_pretrained(
117 | args.base_model,
118 | load_in_8bit=True,
119 | torch_dtype=torch.float16,
120 | device_map=device_map,
121 | )
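    # prepare_model_for_int8_training freezes the int8 base weights, upcasts the layer
    # norms to fp32 for numerical stability, and enables gradient checkpointing.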
122 | model = prepare_model_for_int8_training(model)
123 |
124 | # using lora to tuning
125 | config = LoraConfig(
126 | r=args.lora_r,
127 | lora_alpha=args.lora_alpha,
128 | target_modules=args.lora_target_modules,
129 | lora_dropout=args.lora_dropout,
130 | bias="none",
131 | task_type="CAUSAL_LM",
132 | )
133 | model = get_peft_model(model, config)
134 |
135 | # Check the available weights and load
136 | if args.resume_from_checkpoint:
137 | # Full checkpoint
138 | checkpoint_name = os.path.join(
139 | args.resume_from_checkpoint, "pytorch_model.bin"
140 | )
141 |
142 | # only LoRA model
143 | if not os.path.exists(checkpoint_name):
144 | checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin")
145 | # So the trainer won't try loading its state
146 | args.resume_from_checkpoint = None
147 | # The two files above have a different name depending on how they were saved, but are actually the same.
148 | if os.path.exists(checkpoint_name):
149 | print("Restarting from {}".format(checkpoint_name))
150 | adapters_weights = torch.load(checkpoint_name)
151 |             set_peft_model_state_dict(model, adapters_weights)  # loads the adapter weights into the model in place
152 | else:
153 | print("Checkpoint {} not found".format(checkpoint_name))
154 |
155 |     # Be more transparent about the % of trainable parameters
156 | model.print_trainable_parameters()
157 |
158 | # --------------------------Tokenizer--------------------------
159 | tokenizer = LlamaTokenizer.from_pretrained(args.base_model)
160 | tokenizer.pad_token_id = 0 # 「unk」 different from the eos token
161 | tokenizer.padding_side = "left" # allow batched inference
162 |
163 | def tokenize(prompt, add_eos_token=True):
164 | result = tokenizer(
165 | prompt,
166 | truncation=True,
167 | max_length=args.cutoff_len,
168 | padding=False,
169 | return_tensors=None,
170 | )
171 |
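        # Append an EOS token so the model learns to stop, unless truncation
        # already used up the full cutoff_len.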
172 | if (
173 | result["input_ids"][-1] != tokenizer.eos_token_id
174 | and len(result["input_ids"]) < args.cutoff_len
175 | and add_eos_token
176 | ):
177 | result["input_ids"].append(tokenizer.eos_token_id)
178 | result["attention_mask"].append(1)
179 |
180 | result["labels"] = result["input_ids"].copy()
181 |
182 | return result
183 |
184 | # choose prompt template
185 | prompter = Prompter(args.prompt_template_name)
186 |
187 | def generate_and_tokenize_prompt(data_point):
188 | full_prompt = prompter.generate_prompt(
189 | data_point["instruction"],
190 | data_point["input"],
191 | data_point["output"],
192 | )
193 | tokenized_full_prompt = tokenize(full_prompt)
194 |
195 | if not args.train_on_inputs:
196 | user_prompt = prompter.generate_prompt(data_point["instruction"], data_point["input"])
197 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
198 | user_prompt_len = len(tokenized_user_prompt["input_ids"])
199 |
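            # -100 is the ignore_index of the cross-entropy loss: prompt tokens contribute
            # nothing to the gradient, so only the response tokens are learned.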
200 | tokenized_full_prompt["labels"] = [-100] * user_prompt_len \
201 | + tokenized_full_prompt["labels"][user_prompt_len:] # Maybe faster
202 | return tokenized_full_prompt
203 |
204 |     # load the dataset from a file (here xx.json)
205 | if args.data_path.endswith(".json") or args.data_path.endswith(".jsonl"):
206 | data = load_dataset("json", data_files=args.data_path)
207 | else:
208 | data = load_dataset(args.data_path)
209 |
210 | if args.val_set_size > 0:
211 | train_val = data["train"].train_test_split(
212 | test_size=args.val_set_size, shuffle=True, seed=2023
213 | )
214 | train_data = (
215 | train_val["train"].shuffle().map(generate_and_tokenize_prompt)
216 | )
217 | val_data = (
218 | train_val["test"].shuffle().map(generate_and_tokenize_prompt)
219 | )
220 | else:
221 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
222 | val_data = None
223 |
224 | # --------------------------Trainer--------------------------
225 |     # With multiple GPUs and no DDP, keep the Trainer from using its own DataParallelism
226 | if not ddp and torch.cuda.device_count() > 1:
227 | model.is_parallelizable = True
228 | model.model_parallel = True
229 |
230 | trainer = transformers.Trainer(
231 | model=model,
232 | train_dataset=train_data,
233 | eval_dataset=val_data,
234 | args=transformers.TrainingArguments(
235 | per_device_train_batch_size=args.micro_batch_size,
236 | gradient_accumulation_steps=gradient_accumulation_steps,
237 | warmup_ratio=0.1,
238 | num_train_epochs=args.num_epochs,
239 | learning_rate=args.learning_rate,
240 | fp16=True,
241 | logging_steps=8,
242 | optim="adamw_torch",
243 | evaluation_strategy="steps" if args.val_set_size > 0 else "no",
244 | save_strategy="steps",
245 | eval_steps=32 if args.val_set_size > 0 else None,
246 | save_steps=32,
247 | output_dir=args.output_dir,
248 | save_total_limit=5,
249 | load_best_model_at_end=True if args.val_set_size > 0 else False,
250 | ddp_find_unused_parameters=False if ddp else None,
251 | group_by_length=args.group_by_length,
252 | report_to="wandb" if use_wandb else None,
253 | run_name=args.wandb_run_name if use_wandb else None,
254 | ),
255 | data_collator=transformers.DataCollatorForSeq2Seq(
256 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
257 | ),
258 | )
259 |
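    # The KV cache only helps at generation time and conflicts with gradient
    # checkpointing, so it is disabled for training.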
260 | model.config.use_cache = False
261 |
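    # Route state_dict() through get_peft_model_state_dict so that saved checkpoints
    # contain only the small LoRA adapter weights rather than the full base model.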
262 | old_state_dict = model.state_dict
263 | model.state_dict = (
264 | lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
265 | ).__get__(model, type(model))
266 |
267 | if torch.__version__ >= "2" and sys.platform != "win32":
268 | model = torch.compile(model)
269 |
270 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
271 |
272 | model.save_pretrained(args.output_dir)
273 |
274 |     print("\n If a warning about missing keys appears above, please ignore it! o(^_^)o ~")
275 |
276 |
277 | if __name__ == "__main__":
278 | do_tuning()
279 |
--------------------------------------------------------------------------------