├── assets ├── logo.png ├── qq_party.jpeg └── Cornucopia_LLM_training_inference_pipeline.png ├── utils ├── README.md └── prompter.py ├── templates ├── ori_template.json ├── fin_template.json ├── alpaca_short.json ├── alpaca.json ├── alpaca_legacy.json └── README.md ├── requirements.txt ├── base_models └── load.sh ├── scripts ├── finetune.sh ├── infer.sh └── comparison_test.sh ├── HOW_TO_CONTRIBUTE.md ├── instruction_data ├── infer.json └── fin_data.json ├── infer.py ├── README.md ├── LICENSE └── tuning_train.py /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/HEAD/assets/logo.png -------------------------------------------------------------------------------- /assets/qq_party.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/HEAD/assets/qq_party.jpeg -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # 提示构建模块 2 | 3 | ## prompter.py 4 | 5 | Prompter class, a template manager. 6 | 7 | ``` 8 | from utils.prompter import Prompter 9 | ``` 10 | -------------------------------------------------------------------------------- /assets/Cornucopia_LLM_training_inference_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/HEAD/assets/Cornucopia_LLM_training_inference_pipeline.png -------------------------------------------------------------------------------- /templates/ori_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Llama without sft", 3 | "prompt_input": "问题:{instruction} 回答:", 4 | "prompt_no_input": "问题:{instruction} 回答:", 5 | "response_split": "回答:" 6 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.17.1 2 | appdirs==1.4.4 3 | bitsandbytes==0.37.1 4 | black 5 | black[jupyter] 6 | datasets 7 | fire 8 | git+https://github.com/huggingface/peft.git 9 | git+https://github.com/huggingface/transformers.git 10 | gradio 11 | torch 12 | sentencepiece 13 | wandb>=0.15.0 -------------------------------------------------------------------------------- /templates/fin_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Financial Instruction Tuning", 3 | "prompt_input": "下面是一个问题,运用金融财经知识来正确回答问题.\n### 问题:\n{instruction}\n### 回答:\n", 4 | "prompt_no_input": "下面是一个问题,运用金融财经知识来正确回答问题.\n### 问题:\n{instruction}\n### 回答:\n", 5 | "response_split": "### 回答:" 6 | } -------------------------------------------------------------------------------- /templates/alpaca_short.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "A shorter template to experiment with.", 3 | "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | 
-------------------------------------------------------------------------------- /base_models/load.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Run this script from the repo root, e.g. bash ./base_models/load.sh
3 | 
4 | base_model_dir="./base_models/llama-7b-hf"
5 | if [ ! -d "$base_model_dir" ];then
6 |     cd ./base_models/ || exit
7 |     git clone https://huggingface.co/decapoda-research/llama-7b-hf
8 |     cd ../ || exit
9 | fi
10 | 
11 | base_model_dir="./base_models/Linly-Chinese-LLaMA-7b-hf"
12 | if [ ! -d "$base_model_dir" ];then
13 |     cd ./base_models/ || exit
14 |     git clone https://huggingface.co/P01son/Linly-Chinese-LLaMA-7b-hf
15 |     cd ../ || exit
16 | fi
17 | 
18 | 
-------------------------------------------------------------------------------- /templates/alpaca.json: --------------------------------------------------------------------------------
1 | {
2 |     "description": "Template used by Alpaca-LoRA.",
3 |     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
4 |     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
5 |     "response_split": "### Response:"
6 | }
7 | 
-------------------------------------------------------------------------------- /templates/alpaca_legacy.json: --------------------------------------------------------------------------------
1 | {
2 |     "description": "Legacy template, used by Original Alpaca repository.",
3 |     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
4 |     "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
5 |     "response_split": "### Response:"
6 | }
7 | 
-------------------------------------------------------------------------------- /scripts/finetune.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | exp_tag="Meta"
5 | python3 tuning_train.py \
6 |     --base_model './base_models/llama-7b-hf' \
7 |     --data_path './instruction_data/fin_data.json' \
8 |     --output_dir './Fin-Alpaca-LoRA-7B-'$exp_tag \
9 |     --prompt_template_name 'fin_template' \
10 |     --micro_batch_size 64 \
11 |     --batch_size 64 \
12 |     --num_epochs 10 \
13 |     --wandb_run_name $exp_tag
14 | 
15 | 
16 | #exp_tag="Linly"
17 | #python3 tuning_train.py \
18 | #    --base_model './base_models/Linly-Chinese-LLaMA-7b-hf' \
19 | #    --data_path './instruction_data/fin_data.json' \
20 | #    --output_dir './Fin-Alpaca-LoRA-7B-'$exp_tag \
21 | #    --prompt_template_name 'fin_template' \
22 | #    --micro_batch_size 96 \
23 | #    --batch_size 96 \
24 | #    --num_epochs 10 \
25 | #    --wandb_run_name $exp_tag
-------------------------------------------------------------------------------- /scripts/infer.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # If inferring with the llama model, set 'use_lora' to 'False' and 'prompt_template' to 'ori_template'.
4 | # If inferring with the default alpaca model, set 'use_lora' to 'True', 'lora_weights' to 'tloen/alpaca-lora-7b', and 'prompt_template' to 'alpaca'.
5 | # If inferring with the llama-fin model, download the LoRA weights and set 'lora_weights' to './Fin-Alpaca-LoRA-7B-xxx' (or the exact directory of the LoRA weights) and 'prompt_template' to 'fin_template'.
6 | 
7 | BASE_MODEL="./base_models/llama-7b-hf"  # or「./base_models/Linly-Chinese-LLaMA-7b-hf」
8 | exp_tag="Meta"  # or「"Linly"」
9 | python3 infer.py \
10 |     --base_model ${BASE_MODEL} \
11 |     --lora_weights './Fin-Alpaca-LoRA-7B-'$exp_tag \
12 |     --use_lora True \
13 |     --instruct_dir './instruction_data/infer.json' \
14 |     --prompt_template 'fin_template'
15 | 
-------------------------------------------------------------------------------- /HOW_TO_CONTRIBUTE.md: --------------------------------------------------------------------------------
1 | 
2 | 欢迎来到 聚宝盆(Cornucopia)项目!感谢您有兴趣为我们的项目做出贡献。为了使贡献过程尽可能顺利,我们制定了一些指南来帮助您提交贡献。在您开始贡献之前,请花几分钟时间查看以下指南。
3 | 
4 | ## 如何贡献代码
5 | 
6 | 1. Fork 项目并将其 clone 到本地。
7 | 2. 使用描述性名称为您的贡献创建一个新分支。
8 | 3. 进行更改并确保它们经过适当的测试。
9 | 4. 向我们存储库的主分支提交 pull request。
10 | 
11 | ## 如何贡献 Prompt 与数据
12 | 
13 | 如果您正在贡献 prompts 或 prompt seeds,请使用以下标题格式打开一个新 issue:[New Prompt]: 或 [New Prompt Seed]:。
14 | 
15 | 如果您正在贡献新的数据集:
16 | 1. 请核实我们官方数据集的格式。
17 | 2. 将您的数据集上传到某处,例如 HuggingFace。
18 | 3. 新建一个 issue,标题为:[Contributing Data]:,并描述数据集,例如规模、内容等。
19 | 4. 在 issue 中包含指向您的数据集的链接。
20 | 
21 | 
22 | ## 贡献指南
23 | 
24 | 请确保您的贡献符合以下准则:
25 | 
26 | 1. 遵循项目中使用的编码风格和约定。
27 | 2. 确保您的贡献有据可查且易于理解。
28 | 3. 保持您的贡献简洁明了。如果您正在进行多项更改,请考虑将它们分成单独的 pull request。
29 | 4. 不要提交包含专有或机密信息的文稿。
30 | 
31 | ## 提交 Issues
32 | 
33 | 如果您在使用我们的项目时遇到任何问题,请通过我们的 issue 跟踪器报告。请提供尽可能多的有关该问题的信息,包括重现该问题的步骤。
34 | 
35 | 在提交问题之前,请先搜索现有问题 :)
36 | 
37 | ## 结论
38 | 
39 | 感谢您花时间阅读这些指南。我们感谢您的贡献,并期待与您合作!如果您有任何问题或疑虑,请联系项目维护人员。
-------------------------------------------------------------------------------- /scripts/comparison_test.sh: --------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # If inferring with the llama model, set 'use_lora' to 'False' and 'prompt_template' to 'ori_template'.
4 | # If inferring with the llama-fin model, download the LoRA weights and set 'lora_weights' to './Fin-Alpaca-LoRA-7B-'$exp_tag (or the exact directory of the LoRA weights) and 'prompt_template' to 'fin_template'.
5 | 
6 | BASE_MODEL="./base_models/llama-7b-hf"
7 | # only ori llama
8 | o_cmd="python3 infer.py \
9 |     --base_model ${BASE_MODEL} \
10 |     --use_lora False \
11 |     --prompt_template 'ori_template'"
12 | 
13 | # Fin-Alpaca-LoRA-7B-Meta
14 | exp_tag="Meta"
15 | a_cmd="python3 infer.py \
16 |     --base_model ${BASE_MODEL} \
17 |     --use_lora True \
18 |     --lora_weights './Fin-Alpaca-LoRA-7B-'$exp_tag \
19 |     --prompt_template 'fin_template'"
20 | 
21 | # Fin-Alpaca-LoRA-7B-Linly
22 | BASE_MODEL="./base_models/Linly-Chinese-LLaMA-7b-hf"
23 | exp_tag="Linly"
24 | m_cmd="python3 infer.py \
25 |     --base_model ${BASE_MODEL} \
26 |     --use_lora True \
27 |     --lora_weights './Fin-Alpaca-LoRA-7B-'$exp_tag \
28 |     --prompt_template 'fin_template'"
29 | mkdir -p infer_result  # ensure the output directory exists before redirecting into it
30 | echo "only_ori_llama"
31 | eval $o_cmd > infer_result/o_tmp.txt
32 | echo "Fin-Alpaca-LoRA-7B-Meta"
33 | eval $a_cmd > infer_result/a_tmp.txt
34 | echo "Fin-Alpaca-LoRA-7B-Linly"
35 | eval $m_cmd > infer_result/m_tmp.txt
-------------------------------------------------------------------------------- /templates/README.md: --------------------------------------------------------------------------------
1 | # 提示词模板
2 | 
3 | 此目录包含用于 LoRA 微调 LLaMA 模型时所使用的提示词模板。
4 | 
5 | ## Format
6 | 
7 | 模板通过一个 JSON 文件描述,该文件包含以下键:
8 | 
9 | - `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
10 | - `prompt_no_input`: The template to use when input is None. Uses the `{instruction}` placeholder.
11 | - `description`: A short description of the template, with possible use cases.
12 | - `response_split`: The text to use as separator when cutting the real response out of the model output.
13 | 
14 | No `{response}` placeholder is used, since the response is always the last element of the template and is simply concatenated to the rest.
15 | 
16 | ## 模板案例
17 | 
18 | The default template, used unless otherwise specified, is `alpaca.json`.
19 | 
20 | ```json
21 | {
22 |     "description": "Template used by Alpaca-LoRA.",
23 |     "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
24 |     "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 25 | "response_split": "### Response:" 26 | } 27 | 28 | ``` 29 | 30 | ## 现有模板 31 | 32 | ### alpaca 33 | 34 | 到目前为止,用于通用LoRA微调的默认模板。 35 | 36 | ### alpaca_legacy 37 | 38 | 原始羊驼使用的旧模板,响应字段后没有“\n”。保留以供参考和实验。 39 | 40 | ### alpaca_short 41 | 42 | 一个修剪过的羊驼模板,它似乎也表现得很好,并保留了一些 tokens。使用默认模板创建的模型似乎也可以通过短时间查询。 43 | -------------------------------------------------------------------------------- /instruction_data/infer.json: -------------------------------------------------------------------------------- 1 | {"instruction": "你好,我想了解一下创业板股票怎么买?", "input": "", "output": "您好,如果要购买创业板的股票,需要先开通创业板权限。\n开通创业板的三个条件:\n1.股东证券账户20个交易日日均资产10万以上;\n2.必须有24个月的交易经验;\n3.风险测试评级在积极型及以上;\n只要符合这三大条件的投资者都可以申请开通创业板交易权限。\n创业板是为中小型、创业性、科技型企业提供上市融资的场所。从这三类企业的角度出发,中小企业受限于规模,本身的抗风险能力就差;创业型企业的一般存在公司治理不完善的特点;对于科技型企业,高科技转化为产品或劳务具有不确定性,存在技术失败风险。所以对于普通投资者来说交易风险较大,把握不好还是要谨慎操作。"} 2 | {"instruction": "要是车抵押贷款的公司垮了咋办?", "input": "", "output": "办理车抵押贷款的公司垮了的话,建议如下处理:1.赶紧联系贷款公司的工作人员,让其安排后续还款事宜,有能力的赶紧还清,然后尽快拿回车辆登记证书。2.如果暂时联系不上人,看后续贷款公司是否会被接盘。而等新的金融机构接手后,届时还款给接手的金融机构就行。3.选择向当地公安部门报警处理,必要时还可以联系上其他同样在该贷款公司办有业务的客户一起向当地人民法院提起诉讼。至于车辆登记证书,可以根据公安出具的报案证明去车辆登记地的车管所申请补办。建议大家平时办理车辆抵押贷款最好选择银行或者正规大型持牌金融机构,一般不会面临倒闭情况。如果随便找一家不知名的贷款公司办理,很可能会上当受骗,甚至被索要高额利息和各种费用,造成个人钱财损失。"} 3 | {"instruction": "想了解一下理财单位净值和累计净值是什么意思?", "input": "", "output": "理财单位净值是指理财当前的价格情况,理财累计净值是指理财成立以来的价格情况,一般理财累计净值越高越好,累计净值越高代表理财的业绩越好,投资者获得的收益越多。净值就是价格的意思,理财净值由投资标的决定,投资标的上涨,那么理财的净值就会上涨,投资标的下跌,那么理财的净值就会下跌。理财投资标的有定期存款、债券、股票、外汇、期权、期货等。"} 4 | {"instruction": "我遇到个情况,就是信用卡到期了旧卡还能还款吗?", "input": "", "output": "信用卡到期了,用户没有激活新卡,那么就继续往旧卡里面还款,已经激活了新卡,则往新卡里面还款。有部分用户的旧卡到期了,但是没有收到新卡,为了避免出现信用卡逾期的情况,用户可以照常将还款金额存入旧卡中,银行仍然会按时扣款。当然,大部分用户选择激活了新卡,那么还款就可以在新卡里面进行,毕竟新卡激活以后,旧卡就没有任何用处了。"} 5 | {"instruction": "为什么医保350一年保费那么贵?", "input": "", "output": "医保350一年贵的原因在于:\n一方面医疗费用每年上涨,医保基金支出比较大,每年保持涨价50元的幅度,维持医保基金收支平衡;\n另一方面在于部分经济困难家庭可以支付350元一个人,但是一家如果是4口人或六口人,一年交1400-2100元,累计金额相对比较高,显得比较贵。\n其实城乡居民医保每年个人交费都享有政府补贴,比如个人交350,政府补贴了610元,相当于原价960元一年,已经是给予了价格补助。\n城乡居民医疗保险的特点在于:\n1、支持带病投保。无论是否已经罹患大病,都可以参保,可以报销,这是商业医疗险做不到的。\n2、保障范围涉及多个方面。比如普通门诊、大额门诊、住院待遇、特殊门诊、两病门诊、大病保险,如果家庭困难,医保报销以后还是解决不了问题,还可以申请政府医疗救助。\n3、报销比例方面:各省市有差异,在三级医院,如省人民医院,能够报65%-75%;在二级医院,比如县人员一样,能够报销75%-80%左右。\n家里有条件还是交居民医保,另外也可以配个商业医疗作为补充,比如看病住院花费5万,居民医保报销2万,商业医疗报销剩余的3万。"} -------------------------------------------------------------------------------- /utils/prompter.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | Choose template to build prompt. 
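
A minimal usage sketch (assuming the caller runs from the repo root, so the templates/ directory can be found):

    from utils.prompter import Prompter

    prompter = Prompter("fin_template")
    prompt = prompter.generate_prompt("股票和基金能当天随买随卖吗?")
    # 将 prompt 交给模型生成后,再从模型输出中切出回答部分:
    # answer = prompter.get_response(model_output)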
5 | """ 6 | 7 | import json 8 | import os.path as osp 9 | from typing import Union 10 | 11 | 12 | class Prompter(object): 13 | """ 14 | prompt 构造器 -> chose template 15 | """ 16 | __slots__ = ("template", "_verbose") 17 | 18 | def __init__(self, template_name: str = "", verbose: bool = False): 19 | self._verbose = verbose 20 | 21 | if not template_name: 22 | template_name = "alpaca" # default 23 | file_name = osp.join("templates", "{}.json".format(template_name)) 24 | 25 | if not osp.exists(file_name): 26 | raise ValueError("Can't read {}".format(file_name)) 27 | 28 | with open(file_name) as fp: 29 | self.template = json.load(fp) 30 | if self._verbose: 31 | print("Chose prompt template {0}: {1}".format(template_name, self.template['description'])) 32 | 33 | def generate_prompt( 34 | self, 35 | instruction: str, 36 | input: Union[None, str] = None, 37 | label: Union[None, str] = None, 38 | ) -> str: 39 | """ 40 | 从指令和可选输入返回完整 prompt,如果提供了一个label (=response, =output),也会被添加 41 | """ 42 | if input: 43 | res = self.template["prompt_input"].format( 44 | instruction=instruction, input=input 45 | ) 46 | else: 47 | res = self.template["prompt_no_input"].format( 48 | instruction=instruction 49 | ) 50 | if label: 51 | res = "{0}{1}".format(res, label) 52 | if self._verbose: 53 | print(res) 54 | return res 55 | 56 | def get_response(self, output: str) -> str: 57 | return output.split(self.template["response_split"])[1].strip() 58 | -------------------------------------------------------------------------------- /instruction_data/fin_data.json: -------------------------------------------------------------------------------- 1 | {"instruction": "可转债风险大吗?和股票比风险大还是小?", "input": "", "output": "可转债和股票相比来说,风险是比较小一些的,打新债需要先开账户,开户完之后的2.个交易日就可以打新债了。\n开户流程如下\n1.打开优惠开户链接,点击开户;\n2.输入手机号和验证码;\n3.填写个人信息\n4.视频认证\n5.完成风险测评问题;\n6.3.方存管银行\n7.最后提交申请,等待审核。"} 2 | {"instruction": "股票没有成交不散单可以吗?", "input": "", "output": "可以,当有相同价格的买单或者卖单就能成交,股票按照价格优先、时间优先的原则进行成交,若委托单在交易时间未成交也未撤单的,那么股票清算后自动撤销该笔委托单,股票清算时间:交易日下午16:00到晚上22:00点。\n股票交易时间:周一至周五上午9:30-11:30,下午13:00-15:00,法定节假日不交易。"} 3 | {"instruction": "除权的股票能长期持股吗?", "input": "", "output": "除权是股票分红的一个阶段,对股票涨跌没有影响,股票是否长期持有不需要看是否除权,主要看股票是否具有投资价值,投资者可以根据基本面、技术面、消息面等情况分析股票是否具有投资价值。\n基本面主要分析财务指标、股东情况、公司概况等,财务指标看现金流量表(表示公司周转的资金,越多越好)、资产负债表(表示公司负债情况,负债越少越好,主要看流动比率和速动比率)和利润表(表示公司收入,越多越好);股东情况是公司股东的基本情况,包括:流通股东、非流通股东、股东户数、实际控制人、解禁时间等情况,查看股东情况可以了解股票是否有机构投资,公司概况是公司的基本信息,比如公司主要经营业务,投资者可以分析主营业务是否具有发展前景。\n技术面主要风险技术指标,KDJ指标:当K线向上穿过D线时,形成金叉是买入信号,MACD指标:当DIF线上穿DEA线时,形成金叉是买进信号,均线指标:当短期均线向上穿过长期均线时,形成金叉是买入信号,K线组合:早晨之星、曙光初现、红三兵等是买入信号。\n消息面主要分析有没有利好上市公司业务发展的政策和消息等,这些政策和消息会利于股票上涨。"} 4 | {"instruction": "股票挂单没成功什么时候返钱?", "input": "", "output": "股票挂单没成功撤单后资金实时到账,投资者马上可以继续委托,股票挂单没成功未撤单的,该笔委托单交易时间内有效,直到成交或者撤单,若收市时未成交的,股票清算时会自动撤销该笔交易,资金实时到账。\n股票按照市场实时价格进行成交,按照价格优先、时间优先的原则,一般股票委托价格和市场实时价格偏离幅度不大,能实时成交。"} 5 | {"instruction": "打新债要注意的坑有什么?要注意什么", "input": "", "output": "你好,很高兴为你解答,打新债要注意开户联系客户经理帮你申请低佣金账户,介绍几个可以提高打新债中签概率的小方法可以参考以下:\n1.,首先申购时间最好选择在上午10:31-11:30,以及下午14:31-15:00两个时间段申购的中签率会相对较高。这是证券公司所统计的中签率高的时间段,可以试一下。\n2.,建议多个账户进行申购,也可以提高投资者的中签概率,不过根据的可转债“打新”规则,同一投资者参与同一只可转债、可交换债网上申购只能使用一个证券账户,所以建议用家人身份证另外开通证券账户进行申购来提高中签。\n3.,最好选择顶格申购,新债申购没有市值要求,也没有资金要求,更不需要预先缴款,所以为了提高中签率,虽然是不同的券商交易的页面会不一样,但是一般都是默认顶格申购的。"} 6 | {"instruction": "办理商业汇票应遵守哪些原则和规定?", "input": "", "output": "办理商业汇票应遵守下列原则和规定:1.使用商业汇票的单位,必须是在银行开立帐户的法人;2.商业汇票在同城和异地均可使用;3.签发商业汇票必须以合法的商品交易为基础;④经承兑的商业汇票,可向银行贴现;⑤商业汇票一律记名,允许背书转让;⑥商业汇票的付款期限由交易双方商定,最长不得超过6个月;⑦商业汇票经承兑后,承兑人即付款人负有到期无条件交付票款的责任;⑧商业汇票由银行印制和发售。"} 7 | {"instruction": "我国有哪些政策性银行?", "input": "", "output": 
"为适应我国经济发展和深化改革的需要,我国政府于1994年先后组建了国家开发银行、中国农业发展银行、中国进出口银行等三家政策性银行。"} 8 | {"instruction": "CAPM与APT的联系与区别", "input": "", "output": "资本资产定价模型和套利定价理论都是资产定价理论,所讨论的都是期望收益率和风险的关系,但两者所用的假设和技术不同,两者既有联系、又有区别。\n联系:\n1.两者解决的问题相同,解决期望收益率和风险之间的关系,使期望收益与风险相匹配;\n2.两者对风险的看法相同,都是将风险分为系统性风险和非系统性风险,期望收益只与系统性风险相关,非系统性风险可以通过投资组合来分散掉。\n主要区别:\n1.在APT中,证券的风险有多个因素来解释,而在CAPM中,证券的风险只用证券相对于市场组合的β系数来解释;\n2.APT并没有假定投资者是风险厌恶的,没有对投资者的证券选择行为作出规定,因此APT的适用性增强了,而CAPM假定投资者按照期望收益率和标准差并利用无差异曲线选择投资组合;\n3.APT并不特别强调市场组合的作用,而CAPM强调市场组合是一个有效的组合;\n4.在APT中,资产均衡的得出是一个动态的过程,它是建立在一价定律的基础上的,而CAPM理论则建立在马科维茨的模型基础上的,强调的是一定风险下的收益最大化或一定收益下的风险最小化,均衡的导出是一个静态的过程。"} 9 | {"instruction": "借记账户和信用账户有什么区别?", "input": "", "output": "借记账户是你已经存入一定数量的钱的账户,以后你可以在其中取款和消费。信贷账户是指你实际上提取尚未存入但实际上是从银行借款的货币的账户。取款后,你将需要偿还该金额的利息。银行将为选定的商户提供分期付款的信用服务,分期付款最多12个月,最低消费(已用)的任何交易,以及借记账户,为你提供有竞争力的货币利息。"} 10 | {"instruction": "我这征信还能申请信用卡吗?", "input": "", "output": "(招商银行)可以申请,是否通过以系统审核结果为准。《征信业管理条例》规定:征信机构对个人不良信息的保存期限,自不良行为或者事件终止之日起为5年;超过5年的,应当予以删除。您在办理新的业务时,相关部门一般会优先考虑近期的消费、还款记录"} 11 | ...... 12 | ... -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | This file is to infer the tuned LLaMa model. 5 | """ 6 | 7 | import sys 8 | import json 9 | import argparse 10 | 11 | import torch 12 | from peft import PeftModel 13 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer 14 | 15 | from utils.prompter import Prompter 16 | 17 | if torch.cuda.is_available(): 18 | device = "cuda" 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--load_8bit", default=False, type=bool) 23 | parser.add_argument("--base_model", default='', type=str, required=True, help="original pretrained llama weights") 24 | parser.add_argument("--instruct_dir", default='', type=str, help="dataset of infer.") 25 | parser.add_argument("--use_lora", default=True, type=bool) 26 | parser.add_argument("--lora_weights", default='tloen/alpaca-lora-7b', type=str, help="The lora weights of llama.") 27 | parser.add_argument("--prompt_template", default='fin_template', type=str, help="The template of infer data.") 28 | args = parser.parse_args() 29 | 30 | 31 | def load_instruction(instruct_dir): 32 | input_data = [] 33 | with open(instruct_dir, "r") as f: 34 | lines = f.readlines() 35 | for line in lines: 36 | line = line.strip() 37 | d = json.loads(line) 38 | input_data.append(d) 39 | return input_data 40 | 41 | 42 | def main(): 43 | # ----------------------------------------- 44 | # 加载 prompt 模板与 llama模型 45 | prompter = Prompter(args.prompt_template) 46 | tokenizer = LlamaTokenizer.from_pretrained(args.base_model) 47 | model = LlamaForCausalLM.from_pretrained( 48 | args.base_model, 49 | load_in_8bit=args.load_8bit, 50 | torch_dtype=torch.float16, 51 | device_map="auto", 52 | ) 53 | 54 | if args.use_lora: 55 | print("using lora {}".format(args.lora_weights)) 56 | model = PeftModel.from_pretrained( 57 | model, 58 | args.lora_weights, 59 | torch_dtype=torch.float16 60 | ) 61 | # 重新配置 decapoda-research config 62 | model.config.pad_token_id, tokenizer.pad_token_id = 0, 0 # unk token 63 | model.config.bos_token_id = 1 64 | model.config.eos_token_id = 2 65 | if not args.load_8bit: 66 | model.half() # 开启半精度 67 | 68 | model.eval() # 保证BN层直接利用之前训练阶段得到的均值和方差 69 | 70 | if torch.__version__ >= "2" and sys.platform != "win32": 71 | model = torch.compile(model) 72 | 73 | # ----------------------------------------- 74 | def 
evaluate( 75 | instruction, 76 | input=None, 77 | temperature=0.2, 78 | top_p=0.85, 79 | top_k=40, 80 | num_beams=4, 81 | max_new_tokens=512, 82 | **kwargs 83 | ): 84 | prompt = prompter.generate_prompt(instruction, input) 85 | inputs = tokenizer(prompt, return_tensors="pt") 86 | input_ids = inputs["input_ids"].to(device) 87 | generation_config = GenerationConfig( 88 | temperature=temperature, 89 | top_p=top_p, 90 | top_k=top_k, 91 | num_beams=num_beams, 92 | **kwargs 93 | ) 94 | with torch.no_grad(): 95 | generation_output = model.generate( 96 | input_ids=input_ids, 97 | generation_config=generation_config, 98 | return_dict_in_generate=True, 99 | output_scores=True, 100 | max_new_tokens=max_new_tokens, 101 | early_stopping=True, 102 | remove_invalid_values=True, 103 | repetition_penalty=3.5, 104 | length_penalty=0.1, 105 | epsilon_cutoff=0.05, 106 | eos_token_id=2, 107 | forced_eos_token_id=2, 108 | pad_token_id=0 109 | ) 110 | s = generation_output.sequences[0] 111 | output = tokenizer.decode(s) 112 | return prompter.get_response(output) 113 | 114 | def infer_from_json(instruct_dir): 115 | input_data = load_instruction(instruct_dir) 116 | for d in input_data: 117 | instruction = d["instruction"] 118 | output = d["output"] 119 | print("###infering###") 120 | model_output = evaluate(instruction) 121 | print("###instruction###") 122 | print(instruction) 123 | print("###golden output###") 124 | print(output) 125 | print("###model output###") 126 | print(model_output) 127 | 128 | if args.instruct_dir != "": 129 | infer_from_json(args.instruct_dir) 130 | else: 131 | for instruction in [ 132 | "老年人理财好还是存定期好?", 133 | "手上有20万存款,可以作什么投资", 134 | "净值型和传统理财产品的区别,有什么不同?", 135 | "股票和基金能当天随买随卖吗?" 136 | ]: 137 | print("Instruction:", instruction) 138 | print("Response:", evaluate(instruction)) 139 | print() 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
 
Cornucopia-LLaMa

[![Author](https://img.shields.io/badge/Author-yangmu.yu-green.svg "Author")](https://www.zhihu.com/people/xuyingjie521/columns) [![OS](https://img.shields.io/badge/OS-Linux/Windows/Mac-red.svg "OS")](./) [![Based](https://img.shields.io/badge/Based-huggingface_transformers-blue.svg "Based")](./)

[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/blob/main/LICENSE) [![Python](https://img.shields.io/badge/Python-3.9+-darkseagreen.svg "Python")](./)

[![Typing SVG](https://readme-typing-svg.demolab.com?font=Fira+Code&pause=100&color=1AF783&center=true&vCenter=true&width=600&lines=Transformers+load+Cornucopia+LLaMa>>>>>>>>>>+GO!)](https://git.io/typing-svg)

12 | 13 | 14 | # Cornucopia-LLaMA-Fin-Chinese 15 | ### 聚宝盆(Cornucopia): 基于中文金融知识的 LLaMA 系微调模型 16 | 17 | 18 | 本项目开源了基于 LLaMA 系基模型经过中文金融知识指令精调/指令微调(Instruct-tuning) 的微调模型。通过中文金融公开问答数据+爬取的金融问答数据构建指令数据集,并在此基础上对 LLaMA 系模型进行了指令微调,提高了 LLaMA 在金融领域的问答效果。 19 | 20 | 基于已有数据和继续爬取的中文金融数据,将继续利用 GPT3.5/4.0 API 构建高质量的数据集,另在[中文知识图谱-金融](http://www.openkg.cn/group/finance)、CFLEB 金融数据集等数据上进一步扩充高质量指令数据集。 21 | 22 | 陆续会发布新的**中文场景的金融模型**(next-pretrain、multi-task SFT、RLHF),欢迎大家届时使用体验,敬请期待。 23 | 24 | ## 🔄 News 25 | 26 | [2023/05/10] 发布了基于 Chinese-LLaMA 和中文金融数据进行指令微调的模型。 27 | 28 | [2023/05/07] 发布了基于 Meta-LLaMA 和中文金融数据进行指令微调的模型。 29 | 30 | 31 | ## 🚀 快速开始 32 | 33 | 首先安装依赖包,python环境建议3.9+ 34 | 35 | ``` 36 | 37 | pip install -r requirements.txt 38 | 39 | ``` 40 | 其次安装 lfs 方便本地下载 LLaMa 大模型 41 | 42 | ``` 43 | 44 | git lfs install 45 | 46 | # 下载7B模型到本地 47 | bash ./base_models/load.sh 48 | 49 | ``` 50 | 51 | ### 🤖 模型下载 52 | 53 | LoRA 权重可以通过 Huggingface 下载,结构如下: 54 | ``` 55 | Fin-Alpaca-LoRA-7B-Meta/ 56 | - adapter_config.json # LoRA权重配置文件 57 | - adapter_model.bin # LoRA权重文件 58 | ``` 59 | | LoRA模型下载 | 分类 | 重构模型 | 训练数据 | 训练序列长度 | 版本 | 60 | |----------------------------------------------------------------------------------|------------|-----|----------|--------|------| 61 | | [Fin-Alpaca-LoRA-7B-Meta](https://huggingface.co/yuyangmu125/Fin-Alpaca-LoRA-7B-Meta) | 中文金融问答微调模型 | decapoda-research/llama-7b-hf | 12M 指令数据 | 512 | V1.0 | 62 | | [Fin-Alpaca-LoRA-7B-Linly](https://huggingface.co/yuyangmu125/Fin-Alpaca-LoRA-7B-Linly) | 中文金融问答微调模型 | Linly-AI/Chinese-LLaMA-7B | 14M 指令数据 | 512 | V1.1 | 63 | 64 | 65 | ### Inference 66 | 67 | 目前在`./instruction_data/infer.json`中提供了一些测试用例,也可替换成其它的数据集但注意格式保持一致 68 | 69 | 70 | 运行infer脚本 71 | 72 | ``` 73 | 74 | # 单模型推理 75 | bash ./scripts/infer.sh 76 | 77 | # 多模型对比 78 | bash ./scripts/comparison_test.sh 79 | 80 | ``` 81 | 82 | ### 📊 数据集构建 83 | 84 | 此前版本采用了**公开和爬取的中文金融领域问答数据**,涉及到保险、理财、股票、基金、贷款、信用卡、社保等。 85 | 86 | 指令数据示例如下: 87 | 88 | ``` 89 | 90 | 问题:办理商业汇票应遵守哪些原则和规定? 91 | 92 | 回答: 办理商业汇票应遵守下列原则和规定:1.使用商业汇票的单位,必须是在银行开立帐户的法人;2.商业汇票在同城和异地均可使用;3.签发商业汇票必须以合法的商品交易为基础;4.经承兑的商业汇票,可向银行贴现;5.商业汇票一律记名,允许背书转让;6.商业汇票的付款期限由交易双方商定,最长不得超过6个月;7.商业汇票经承兑后,承兑人即付款人负有到期无条件交付票款的责任;8.商业汇票由银行印制和发售。 93 | 94 | ``` 95 | 96 | 针对此前数据仍存在不准确和类型单一等不完善的地方;目前我们利用GPT3.5/4.0 接口进一步优化数据、并扩充中文金融知识库,设置**多种Prompt形式、multi-task形式拓展丰富指令数据集**,实现金融领域多业务场景覆盖。 97 | 98 | 最新模型情况:(即将发布,敬请期待~) 99 | 100 | 101 |

102 | Cornucopia-LLaMA-Fin-Chinese 103 |

104 | 105 | 106 | 107 | ### ⚙️ Finetune 108 | 109 | 若想用自己的数据集微调 LLaMA,请按照`./instruction_data/fin_data.json`的格式构建自己的数据集 110 | 111 | 运行 finetune 脚本 112 | 113 | ``` 114 | 115 | bash ./scripts/finetune.sh 116 | 117 | ``` 118 | 119 | 120 | ## 🏃 训练细节 121 | 122 | ### 计算资源需求 123 | 124 | 目前训练设备为一张A100-SXM-80GB显卡,训练总轮次10轮。batch_size=64的情况下显存占用在40G左右、batch_size=96的情况下显存占用在65G左右。预计3090/4090显卡(24GB显存)以上显卡可以较好支持,根据显存大小来调整batch_size。 125 | 126 | ### 实验记录 127 | 128 | 实验指标记录与超参设置可以参考wandb的链接: 129 | 130 | https://api.wandb.ai/links/1812316597/hkyiriw5 131 | 132 | 133 | ## 📑 模型效果对比 134 | 135 | 136 | | 测试输入 | 原始Llama输出 | Cornucopia(Fin-Alpaca-LoRA-7B-Meta)输出 | Cornucopia(Fin-Alpaca-LoRA-7B-Linly)输出 | 文心一言输出 | 讯飞星火认知输出 | 137 | |:------------ |:----------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------|:------| 138 | |老年人理财好还是存定期好?|老年人的理财办法应该适合于自身情况。如果他/她有足够的时间和投资能力,则建议他/她利用现在的时间与投资机会进行理财活动。因为老年人可以获得余额、保险、房地产等多种收入方式,所以他/她可以选择不同类型的理财项目,包括公开市场、非公开市场、基金、信托计划等。当然,老年人也应该注意其实需要的生活财务安全。| **若投资者风险承受能力比较高,那么可以选择理财产品进行投资,若投资者风险承受能力比较低,那么可以选择定期存款进行投资。目前理财产品和定期存款都是保本型产品,不会亏损本金,并且收益由投资标的决定,所以理财产品和定期存款相比较更加稳定。需要注意的是理财产品和定期存款没有同样的活期利率,而且理财产品每日计算收益,所以理财产品收益越高,投资者获得收益的概率也越高。9QDII基金是指中国市场上开放式外资基金,这类基金主要投资于:A股、港股、美股、英股、德股、日股、法股、台股、香港股、新加坡股等;QFII基金是指中国交易所开放式外资基金,这类基金主要投资于:A股、港股、美股、英股、德股、日股、法股、台股、香港股、新加坡股等。** | **对于风险承受能力比较低的老年人来说,选择定期存款比较好,因为其风险性比较低,收益也比较稳定,而对于风险承受能力比较高的老年人来说,选择理财产品比较好,因为其风险性比较大,收益也比较高。需要注意的是目前理财产品不是保本型产品,有亏损本金的可能,并且根据投资标的分为不同的风险等级,若投资者风险承受能力比较低,那么可以选择风险等级R2及以下的理财产品,若投资者风险承受能力比较高,那么可以选择风险等级R3及以上的理财产品。当前理财产品包括:债券型基金、股票型基金、混合型基金、指数基金和联接基金等,投资者可以根据自身实际情况选择合适的产品进行投资。** |老年人存定期比较好。老年人没工作收入,资产总额也是在逐渐降低,而且老人容易生病,而医疗费支出较大。如果老年人投资理财产品,伴随着其产品净值的变化,是有可能会出现亏损现象的。而定期存款,其利率一般是固定的,期满返本还息,相对而言,安全系数高,盈利更持久。|作为AI语言模型,我没有个人观点和建议。但是,老年人理财和存定期都有其优缺点。老年人理财的优点是可以更好地规划自己的财务状况,包括退休金、医疗保险、子女教育等方面的支出。此外,老年人可以通过投资股票、基金等金融产品来增加财富。存定期的优点是风险较小,收益相对稳定。老年人可以选择不同期限的定期存款,根据自己的需求和风险承受能力来选择。综合来看,老年人应该根据自己的实际情况和需求来选择适合自己的理财方式。如果老年人有足够的资金和风险承受能力,可以考虑投资股票、基金等金融产品来增加财富。如果老年人更注重稳定性和安全性,可以选择存定期来保障自己的财务状况。| 139 | |股票和基金能当天随买随卖吗?|Yes, stocks and funds can be bought and sold on the same day.| **不能,股票和基金实行T+1交易,交易日当天买入,按照买入当天收盘时的净值计算,第二个交易日确认份额,份额确认后计算收益,所以在交易日买入就要等到第二个交易日才能获得收益。比如投资者周一买入了某只股票,周二市场上公布的净值为10元,那么投资者可以根据这些数据计算收益:1000×10=10000;10000/10=1000。若投资者持有此股票超过30个工作日,那么可以直接转换成现金,也就是将该股票全部或者部分提前支取。需要注意的是,目前没有任何法律法规限制投资者在交易日买入基金、股票,因此投资者在交易日买入基金、股票都能进行当天开放型交易。同时,投资者在交易日买入股票之后,其股票会自动转换为现金,并且投资者在交易日买入股票之后,其股票会自动转换为涨跌幅限制(20%左右),从而影响投资者的手续费和收益。**| **不能,股票交易时间:周一至周五上午9:30-11:30,下午13:00-15:00,法定节假日不交易。基金交易时间:周一至周五上午9:30-11:30,下午13:00-15:00,法定节假日不交易。股票实行T+1交易,按照市场实时价格进行成交,遵循价格优先、时间优先的原则,沪深两市股票涨跌幅限制为10%,创业板和科创板股票涨跌幅限制为20%,北交所涨跌幅限制为30%,ST以及*ST涨跌幅限制为5%(创业板和科创板涨跌幅还没有这个限制),港股涨跌幅限制为20%,新三板涨跌幅限制为10%(创业板和科创板涨跌幅还没有这个限制),对于投资者来说,股票交易比较灵活,可以根据自身情况进行交易,并且股票交易手续费比较低,基金交易手续费也比较低,同时股票波动比较大,风险也比较大。** |不可以,股票基金实行T+1交易,交易日当天买入(交易时间),按照买入当天收盘时的净值计算,第二个交易日确认份额,份额确认后才能将基金卖出,一般交易日当天买入,第三个交易日才能卖出,第三个交易日随时可以卖出。| 
股票和基金的交易方式因不同的证券市场而异。在一些证券市场上,股票和基金是可以当天随买随卖的,也就是所谓的“T+0”交易。但是,在另一些证券市场上,股票和基金是不能当天随买随卖的,需要等待一定的交易时间后才能进行买卖。在中国A股市场中,股票和基金是不能当天随买随卖的,需要进行T+1或T+2的交易。这意味着如果您今天买入了一只股票或基金,那么您需要等到明天才能卖出它。 在其他一些国家和地区,股票和基金的交易方式可能会有所不同。因此,如果您想进行股票或基金的投资,建议您了解当地的交易规则和市场情况。 | 140 | 141 | 142 | ## 项目参与者 143 | 本项目由[中科院成都计算机应用研究所 (中科信息)](http://www.casit.com.cn/)人工智能部Yangmu Yu、Wenhuan Hong完成。 144 | 145 | 146 | ## 👏 致谢 147 | 148 | 149 | 本项目参考了以下开源项目,在此对相关项目和研究开发人员表示感谢。 150 | 151 | - Facebook LLaMA: https://github.com/facebookresearch/llama 152 | - Stanford Alpaca: https://github.com/tatsu-lab/stanford_alpaca 153 | - alpaca-lora by @tloen: https://github.com/tloen/alpaca-lora 154 | - Huatuo-Llama-Med-Chinese: https://github.com/SCIR-HI/Hatuo-Llama-Med-Chinese 155 | - 文心一言 https://yiyan.baidu.com/welcome 本项目的logo由文心一言自动生成 156 | - 讯飞星火认知 https://xinghuo.xfyun.cn/desk 157 | 158 | 159 | ## 免责声明 160 | 161 | 本项目相关资源仅供学术研究之用,严禁用于商业用途。使用涉及第三方代码的部分时,请严格遵循相应的开源协议。模型生成的内容受模型计算、随机性和量化精度损失等因素影响,本项目无法对其准确性作出保证。对于模型输出的任何内容,不作为任何投资建议,本项目不承担任何法律责任,亦不对因使用相关资源和输出结果而可能产生的任何损失承担责任。 162 | 163 | 164 | ## TODO List 165 | 166 | - [ ] 支持中文金融领域 multi-task SFT 167 | - [ ] 支持量化模型 CUDA 部署 168 | - [ ] 强化学习 -> Chat化 169 | - [ ] 中文金融领域 next-pretrain 170 | - [ ] 支持 13B 模型 171 | 172 | 173 | ## 📌 Citation 174 | 175 | 如果你使用了本项目的数据或者代码,请声明引用 176 | 177 | 178 | ``` 179 | 180 | @misc{Cornucopia-LLaMA-Fin-Chinese, 181 | title={Cornucopia-LLaMA-Fin-Chinese}, 182 | author={YangMu Yu}, 183 | year={2023}, 184 | publisher = {GitHub}, 185 | journal = {GitHub repository}, 186 | howpublished = {\url{https://github.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese}}, 187 | } 188 | 189 | ``` 190 | 191 | 192 | ## ⛽️ 如何贡献 193 | 194 | 如果您想参与本项目,提交贡献数据/代码,请参考[如何贡献](https://github.com/jerry1993-tech/Cornucopia-LLaMA-Fin-Chinese/blob/main/HOW_TO_CONTRIBUTE.md)。 195 | 196 | 197 | ## ☎️ 交流和问题反馈 198 | 199 | 欢迎点赞🌟、关注、分享,一键三连;如有问题,请在GitHub Issue中提交,或进群进一步探讨: 200 | 201 | 202 | 203 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tuning_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | This file provides a method to tuning LLaMa model with financial data: 5 | ### Lora + int8_training ### 6 | """ 7 | 8 | import argparse 9 | import os 10 | import sys 11 | from typing import List 12 | 13 | import torch 14 | import transformers 15 | from transformers import LlamaForCausalLM, LlamaTokenizer 16 | from datasets import load_dataset 17 | from utils.prompter import Prompter 18 | 19 | from peft import ( 20 | LoraConfig, 21 | get_peft_model, 22 | get_peft_model_state_dict, 23 | prepare_model_for_int8_training, 24 | set_peft_model_state_dict, 25 | ) 26 | 27 | parser = argparse.ArgumentParser() 28 | # model/data params 29 | parser.add_argument("--base_model", default='', type=str, required=True, help="original pretrained llama weights") 30 | parser.add_argument("--data_path", default='yahma/alpaca-cleaned', type=str, help="dataset of SFT.") 31 | parser.add_argument("--output_dir", default='./lora-alpaca', type=str, help="The path of lora model.") 32 | # training hyperparams 33 | parser.add_argument("--batch_size", default=128, type=int, help="Batch size per GPU/CPU for training.") 34 | parser.add_argument("--micro_batch_size", default=8, type=int, help="Batch size per process for training.") 35 | parser.add_argument("--num_epochs", default=8, type=int, help="Total number of training epochs to perform.") 36 | parser.add_argument("--learning_rate", default=2e-4, type=float, help="The initial learning rate for Adam.") 37 | parser.add_argument("--cutoff_len", default=512, type=int, 38 | help="The maximum total input sequence length after tokenization. 
Sequences longer
39 |                          "than this will be truncated, sequences shorter will be padded.")
40 | parser.add_argument("--val_set_size", default=1200, type=int, help="Number of samples held out for validation.")
41 | # lora hyperparams
42 | parser.add_argument("--lora_r", default=8, type=int, help="The LoRA rank r.")
43 | parser.add_argument("--lora_alpha", default=16, type=int, help="Set the alpha parameter of LoRA.")
44 | parser.add_argument("--lora_dropout", default=0.05, type=float, help="Set the dropout parameter of LoRA.")
45 | parser.add_argument("--lora_target_modules", default=["q_proj", "v_proj"], nargs="+", type=str,  # typing.List is not callable and cannot serve as an argparse type
46 |                     help="Set the target modules for the PEFT model.")
47 | # llm hyperparams
48 | parser.add_argument("--train_on_inputs", default=False, type=lambda x: str(x).lower() in ("true", "1"), help="if False, masks out inputs in loss.")  # plain type=bool would turn "False" into True
49 | parser.add_argument("--group_by_length", default=False, type=lambda x: str(x).lower() in ("true", "1"),
50 |                     help="faster, but produces an odd training loss curve.")
51 | # wandb params
52 | parser.add_argument("--wandb_project", default='llama_fin', type=str, help="The name of wandb_project.")
53 | parser.add_argument("--wandb_run_name", default='', type=str)
54 | parser.add_argument("--wandb_watch", default='', type=str, choices=['false', 'gradients', 'all'])
55 | parser.add_argument("--wandb_log_model", default='', type=str, choices=['false', 'true'])
56 | parser.add_argument("--resume_from_checkpoint", default=None, type=str,
57 |                     help="Either training checkpoint or final adapter.")
58 | parser.add_argument("--prompt_template_name", default='alpaca', type=str,
59 |                     help="The prompt template to use, will default to alpaca.")
60 | args = parser.parse_args()
61 | 
62 | 
63 | def do_tuning():
64 |     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
65 |         print(
66 |             "Training Alpaca-LoRA model with params:\n"
67 |             "base_model: {}\n".format(args.base_model),
68 |             "data_path: {}\n".format(args.data_path),
69 |             "output_dir: {}\n".format(args.output_dir),
70 |             "batch_size: {}\n".format(args.batch_size),
71 |             "micro_batch_size: {}\n".format(args.micro_batch_size),
72 |             "num_epochs: {}\n".format(args.num_epochs),
73 |             "learning_rate: {}\n".format(args.learning_rate),
74 |             "cutoff_len: {}\n".format(args.cutoff_len),
75 |             "val_set_size: {}\n".format(args.val_set_size),
76 |             "lora_r: {}\n".format(args.lora_r),
77 |             "lora_alpha: {}\n".format(args.lora_alpha),
78 |             "lora_dropout: {}\n".format(args.lora_dropout),
79 |             "lora_target_modules: {}\n".format(args.lora_target_modules),
80 |             "train_on_inputs: {}\n".format(args.train_on_inputs),
81 |             "group_by_length: {}\n".format(args.group_by_length),
82 |             "wandb_project: {}\n".format(args.wandb_project),
83 |             "wandb_run_name: {}\n".format(args.wandb_run_name),
84 |             "wandb_watch: {}\n".format(args.wandb_watch),
85 |             "wandb_log_model: {}\n".format(args.wandb_log_model),
86 |             "resume_from_checkpoint: {}\n".format(args.resume_from_checkpoint or None),
87 |             "prompt template: {}\n".format(args.prompt_template_name)
88 |         )
89 | 
90 |     # --------------------------Check--------------------------
91 |     # Check if parameters passed or set in the environment
92 |     use_wandb = len(args.wandb_project) > 0 or (
93 |         "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
94 |     )
95 |     # Only overwrite environment if wandb param passed
96 |     if len(args.wandb_project) > 0:
97 |         os.environ["WANDB_PROJECT"] = args.wandb_project
98 |     if len(args.wandb_watch) > 0:
99 |         os.environ["WANDB_WATCH"] = args.wandb_watch
100 |     if len(args.wandb_log_model) > 0:
101 |         os.environ["WANDB_LOG_MODEL"] = args.wandb_log_model
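    # 注:全局 batch_size = micro_batch_size * gradient_accumulation_steps;
    # 下面据此推导梯度累积步数,DDP 多卡时再除以 world_size,
    # 以保证所有进程合计的有效 batch 大小仍为 batch_size。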
102 | 103 | # Check if the base_model exists 104 | assert args.base_model, "Please specify a base_model, for example: 'decapoda-research/llama-7b-hf'" 105 | 106 | gradient_accumulation_steps = args.batch_size // args.micro_batch_size 107 | device_map = "auto" 108 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 109 | ddp = world_size != 1 110 | if ddp: # ddp adopts a multiprocessing approach 111 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} 112 | gradient_accumulation_steps = gradient_accumulation_steps // world_size 113 | 114 | # --------------------------Model-------------------------- 115 | # 与全精度模型相比,以 8 位精度加载模型最多可节省 4 倍的内存 116 | model = LlamaForCausalLM.from_pretrained( 117 | args.base_model, 118 | load_in_8bit=True, 119 | torch_dtype=torch.float16, 120 | device_map=device_map, 121 | ) 122 | model = prepare_model_for_int8_training(model) 123 | 124 | # using lora to tuning 125 | config = LoraConfig( 126 | r=args.lora_r, 127 | lora_alpha=args.lora_alpha, 128 | target_modules=args.lora_target_modules, 129 | lora_dropout=args.lora_dropout, 130 | bias="none", 131 | task_type="CAUSAL_LM", 132 | ) 133 | model = get_peft_model(model, config) 134 | 135 | # Check the available weights and load 136 | if args.resume_from_checkpoint: 137 | # Full checkpoint 138 | checkpoint_name = os.path.join( 139 | args.resume_from_checkpoint, "pytorch_model.bin" 140 | ) 141 | 142 | # only LoRA model 143 | if not os.path.exists(checkpoint_name): 144 | checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin") 145 | # So the trainer won't try loading its state 146 | args.resume_from_checkpoint = None 147 | # The two files above have a different name depending on how they were saved, but are actually the same. 148 | if os.path.exists(checkpoint_name): 149 | print("Restarting from {}".format(checkpoint_name)) 150 | adapters_weights = torch.load(checkpoint_name) 151 | model = set_peft_model_state_dict(model, adapters_weights) 152 | else: 153 | print("Checkpoint {} not found".format(checkpoint_name)) 154 | 155 | # More transparent to x% of trainable parameters 156 | model.print_trainable_parameters() 157 | 158 | # --------------------------Tokenizer-------------------------- 159 | tokenizer = LlamaTokenizer.from_pretrained(args.base_model) 160 | tokenizer.pad_token_id = 0 # 「unk」 different from the eos token 161 | tokenizer.padding_side = "left" # allow batched inference 162 | 163 | def tokenize(prompt, add_eos_token=True): 164 | result = tokenizer( 165 | prompt, 166 | truncation=True, 167 | max_length=args.cutoff_len, 168 | padding=False, 169 | return_tensors=None, 170 | ) 171 | 172 | if ( 173 | result["input_ids"][-1] != tokenizer.eos_token_id 174 | and len(result["input_ids"]) < args.cutoff_len 175 | and add_eos_token 176 | ): 177 | result["input_ids"].append(tokenizer.eos_token_id) 178 | result["attention_mask"].append(1) 179 | 180 | result["labels"] = result["input_ids"].copy() 181 | 182 | return result 183 | 184 | # choose prompt template 185 | prompter = Prompter(args.prompt_template_name) 186 | 187 | def generate_and_tokenize_prompt(data_point): 188 | full_prompt = prompter.generate_prompt( 189 | data_point["instruction"], 190 | data_point["input"], 191 | data_point["output"], 192 | ) 193 | tokenized_full_prompt = tokenize(full_prompt) 194 | 195 | if not args.train_on_inputs: 196 | user_prompt = prompter.generate_prompt(data_point["instruction"], data_point["input"]) 197 | tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False) 198 | user_prompt_len = 
len(tokenized_user_prompt["input_ids"]) 199 | 200 | tokenized_full_prompt["labels"] = [-100] * user_prompt_len \ 201 | + tokenized_full_prompt["labels"][user_prompt_len:] # Maybe faster 202 | return tokenized_full_prompt 203 | 204 | # load dataset from file(here is xx.json) 205 | if args.data_path.endswith(".json") or args.data_path.endswith(".jsonl"): 206 | data = load_dataset("json", data_files=args.data_path) 207 | else: 208 | data = load_dataset(args.data_path) 209 | 210 | if args.val_set_size > 0: 211 | train_val = data["train"].train_test_split( 212 | test_size=args.val_set_size, shuffle=True, seed=2023 213 | ) 214 | train_data = ( 215 | train_val["train"].shuffle().map(generate_and_tokenize_prompt) 216 | ) 217 | val_data = ( 218 | train_val["test"].shuffle().map(generate_and_tokenize_prompt) 219 | ) 220 | else: 221 | train_data = data["train"].shuffle().map(generate_and_tokenize_prompt) 222 | val_data = None 223 | 224 | # --------------------------Trainer-------------------------- 225 | # 当多张显卡时,阻止 Trainer 使用自己的 DataParallelism 226 | if not ddp and torch.cuda.device_count() > 1: 227 | model.is_parallelizable = True 228 | model.model_parallel = True 229 | 230 | trainer = transformers.Trainer( 231 | model=model, 232 | train_dataset=train_data, 233 | eval_dataset=val_data, 234 | args=transformers.TrainingArguments( 235 | per_device_train_batch_size=args.micro_batch_size, 236 | gradient_accumulation_steps=gradient_accumulation_steps, 237 | warmup_ratio=0.1, 238 | num_train_epochs=args.num_epochs, 239 | learning_rate=args.learning_rate, 240 | fp16=True, 241 | logging_steps=8, 242 | optim="adamw_torch", 243 | evaluation_strategy="steps" if args.val_set_size > 0 else "no", 244 | save_strategy="steps", 245 | eval_steps=32 if args.val_set_size > 0 else None, 246 | save_steps=32, 247 | output_dir=args.output_dir, 248 | save_total_limit=5, 249 | load_best_model_at_end=True if args.val_set_size > 0 else False, 250 | ddp_find_unused_parameters=False if ddp else None, 251 | group_by_length=args.group_by_length, 252 | report_to="wandb" if use_wandb else None, 253 | run_name=args.wandb_run_name if use_wandb else None, 254 | ), 255 | data_collator=transformers.DataCollatorForSeq2Seq( 256 | tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True 257 | ), 258 | ) 259 | 260 | model.config.use_cache = False 261 | 262 | old_state_dict = model.state_dict 263 | model.state_dict = ( 264 | lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) 265 | ).__get__(model, type(model)) 266 | 267 | if torch.__version__ >= "2" and sys.platform != "win32": 268 | model = torch.compile(model) 269 | 270 | trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 271 | 272 | model.save_pretrained(args.output_dir) 273 | 274 | print("\n 若上面出现有关于keys丢失的警告,请忽略! o(^_^)o ~") 275 | 276 | 277 | if __name__ == "__main__": 278 | do_tuning() 279 | --------------------------------------------------------------------------------