├── License.pdf
├── README.md
├── README_EN.md
├── demos
├── cli_demo.py
├── openai_api.py
├── requirements_web_demo.txt
└── web_demo.py
├── evaluation
├── README.md
├── README_EN.md
├── all_config.yaml
├── chat_humaneval.py
├── eval.py
└── run_eval.sh
├── finetune
├── README.md
├── README_EN.md
├── ds_config_zero3.json
├── finetune.py
└── run_finetune.sh
├── model
├── configuration_codeshell.py
├── modeling_codeshell.py
└── quantizer.py
├── requirements.txt
└── tokenizer
├── eval_tokenizer.py
├── special_tokens_map.json
├── tokenizer.json
└── tokenizer_config.json
/License.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WisdomShell/codeshell/09d1adc88ccada1a92924c69ece0cf0e73899b1b/License.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | 🤗 Hugging Face • 🤖 ModelScope • ⭕️ WiseModel • 🌐 PKU-KCL
8 |
9 |
10 |
11 |
12 | [License](https://github.com/WisdomShell/codeshell/blob/main/License.pdf)
13 |
14 |
中文|English
15 |
16 |
17 |
18 | ## Introduction
19 |
20 | CodeShell是[北京大学知识计算实验室](http://se.pku.edu.cn/kcl/)联合四川天府银行AI团队研发的多语言代码大模型基座。CodeShell具有70亿参数,在五千亿Tokens进行了训练,上下文窗口长度为8192。在权威的代码评估Benchmark(HumanEval与MBPP)上,CodeShell取得同等规模最好的性能。与此同时,我们提供了与CodeShell配套的部署方案与IDE插件,请参考代码库[CodeShell](https://github.com/WisdomShell/codeshell)。同时,为了方便中国用户下载,我们在[Modelscope](https://modelscope.cn/organization/WisdomShell)和[Wisemodel](https://www.wisemodel.cn/models/WisdomShell/CodeShell-7B/)中也上传了对应版本,国内用户可以访问。
21 |
22 |
23 | 本次开源的模型如下:
24 |
25 | - CodeShell Base:CodeShell底座模型,具有强大的代码基础能力。
26 | - CodeShell Chat:CodeShell对话模型,在代码问答、代码补全等下游任务中性能优异。
27 | - CodeShell Chat 4bit:CodeShell对话模型4bit量化版本,在保证模型性能的前提下内存消耗更小,速度更快。
28 | - CodeShell CPP:CodeShell对话模型CPP版本,支持开发者在没有GPU的个人电脑中使用。注意,CPP版本同样支持量化操作,用户可以在最小内存为8G的个人电脑中运行CodeShell。
29 |
30 |
31 | ## Main Characteristics of CodeShell
32 |
33 | - **强大的性能**:CodeShell在HumanEval和MBPP上达到了7B代码基座大模型的最优性能
34 | - **完整的体系**:除了代码大模型,同时开源IDE(VS Code与JetBrains)插件,形成开源的全栈技术体系
35 | - **轻量化部署**:支持本地C++部署,提供轻量快速的本地化软件开发助手解决方案
36 | - **全面的评测**:提供支持完整项目上下文、覆盖代码生成、代码缺陷检测与修复、测试用例生成等常见软件开发活动的多任务评测体系(即将开源)
37 | - **高效的训练**:基于高效的数据治理体系,CodeShell在完全冷启动情况下,只训练了五千亿Token即获得了优异的性能
38 |
39 | ## Performance
40 |
41 | 我们选取了目前最流行的两个代码评测数据集(HumanEval与MBPP)对模型进行评估,与目前最先进的两个7B代码大模型CodeLlama与StarCoder相比,CodeShell取得了最优的成绩。具体评测结果如下。
42 |
43 | | 任务 | CodeShell-7b | CodeLlama-7b | Starcoder-7b |
44 | | ------- | --------- | --------- | --------- |
45 | | humaneval | **34.32** | 29.44 | 27.80 |
46 | | mbpp | **38.65** | 37.60 | 34.16 |
47 | | multiple-js | **33.17** | 31.30 | 27.02 |
48 | | multiple-java | **30.43** | 29.24 | 24.30 |
49 | | multiple-cpp | **28.21** | 27.33 | 23.04 |
50 | | multiple-swift | 24.30 | **25.32** | 15.70 |
51 | | multiple-php | **30.87** | 25.96 | 22.11 |
52 | | multiple-d | 8.85 | **11.60** | 8.08 |
53 | | multiple-jl | 22.08 | **25.28** | 22.96 |
54 | | multiple-lua | 22.39 | **30.50** | 22.92 |
55 | | multiple-r | **20.52** | 18.57 | 14.29 |
56 | | multiple-rkt | **17.20** | 12.55 | 10.43 |
57 | | multiple-rs | 24.55 | **25.90** | 22.82 |
58 |
59 | ## Requirements
60 |
61 | - python 3.8 and above
62 | - pytorch 2.0 and above are recommended
63 | - transformers 4.32 and above
64 | - CUDA 11.8 and above are recommended (this is for GPU users, flash-attention users, etc.)
65 |
66 | ## Quickstart
67 |
68 | CodeShell系列模型已经上传至 Hugging Face,开发者可以通过Transformers快速调用CodeShell和CodeShell-Chat。
69 |
70 | 在开始之前,请确保已经正确设置了环境,并安装了必要的代码包,以及满足上一小节的环境要求。你可以通过下列代码快速安装相关依赖。
71 |
72 | ```
73 | pip install -r requirements.txt
74 | ```
75 |
76 | 接下来你可以通过Transformers使用CodeShell。
77 |
78 | ### Code Generation
79 |
80 | 开发者可以使用CodeShell快速生成代码,提升开发效率。
81 |
82 | ```python
83 | import torch
84 | from transformers import AutoModelForCausalLM, AutoTokenizer
85 |
86 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
87 | tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B")
88 | model = AutoModelForCausalLM.from_pretrained("WisdomShell/CodeShell-7B", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
89 | inputs = tokenizer('def merge_sort():', return_tensors='pt').to(device)
90 | outputs = model.generate(**inputs)
91 | print(tokenizer.decode(outputs[0]))
92 | ```
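
如果默认配置生成的代码片段过短,可以按需向`generate`传入生成参数。下面是一个简单示例(参数取值仅为示例,请根据任务自行调整):

```python
# 示例:控制生成长度与采样策略(以下参数仅供参考)
outputs = model.generate(
    **inputs,
    max_new_tokens=256,                    # 最多生成256个新token
    do_sample=True,                        # 开启采样
    temperature=0.2,                       # 较低温度使生成的代码更稳定
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```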
93 |
94 | - Fill in the Middle
95 |
96 | CodeShell 支持Fill-in-the-Middle模式,从而更好地支持软件开发过程。
97 |
98 | ```python
99 | input_text = "<fim_prefix>def print_hello_world():\n    <fim_suffix>\n    print('Hello world!')<fim_middle>"
100 | inputs = tokenizer(input_text, return_tensors='pt').to(device)
101 | outputs = model.generate(**inputs)
102 | print(tokenizer.decode(outputs[0]))
103 | ```
104 |
105 | - 代码问答
106 |
107 | CodeShell同时开源了代码助手模型CodeShell-7B-Chat,开发者可以通过下列代码与模型进行交互。
108 |
109 | ```python
110 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
111 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat')
112 |
113 | history = []
114 | query = '你是谁?'
115 | response = model.chat(query, history, tokenizer)
116 | print(response)
117 | history.append((query, response))
118 |
119 | query = '用Python写一个HTTP server'
120 | response = model.chat(query, history, tokenizer)
121 | print(response)
122 | history.append((query, response))
123 | ```
124 |
125 | 开发者也可以通过VS Code与JetBrains插件与CodeShell-7B-Chat交互,详情请参考[VSCode插件仓库](https://github.com/WisdomShell/codeshell-vscode)与[IntelliJ插件仓库](https://github.com/WisdomShell/codeshell-intellij)。
126 |
127 |
128 | - Model Quantization
129 |
130 | CodeShell 支持4 bit/8 bit量化,4 bit量化后,占用显存大小约6G,用户可以在显存较小的GPU上使用CodeShell。
131 |
132 | ```python
133 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4', trust_remote_code=True).to(device)
134 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4')
135 | ```
136 |
137 | - CodeShell in c/c++
138 |
139 | 由于大部分个人电脑没有GPU,CodeShell提供了C/C++版本的推理支持,开发者可以根据本地环境进行编译与使用,详见[CodeShell C/C++本地化版](https://github.com/WisdomShell/llama_cpp_for_codeshell)。
140 |
141 | ## Demo
142 |
143 | 我们提供了Web-UI、命令行、API、IDE四种形式的Demo。
144 |
145 | ### Web UI
146 |
147 | 开发者通过下列命令启动Web服务,服务启动后,可以通过`http://127.0.0.1:8000`进行访问。
148 |
149 | ```
150 | python demos/web_demo.py
151 | ```
152 |
153 | ### CLI Demo
154 |
155 | 我们也提供了命令行交互的Demo版本,开发者可以通过下列命令运行。
156 |
157 | ```
158 | python demos/cli_demo.py
159 | ```
160 |
161 | ### API
162 |
163 | CodeShell也提供了基于OpenAI API的部署方法。
164 |
165 | ```
166 | python demos/openai_api.py
167 | ```
168 |
169 | 启动后即可通过HTTP请求与CodeShell交互。
170 |
171 | ```
172 | curl http://127.0.0.1:8000/v1/chat/completions \
173 | -H "Content-Type: application/json" \
174 | -d '{
175 | "model": "CodeShell-7B-Chat",
176 | "messages": [
177 | {
178 | "role": "user",
179 | "content": "你好"
180 | }
181 | ]
182 | }'
183 | ```
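
除了curl,也可以用任意HTTP客户端调用该接口。下面是一个基于`requests`的简单示例(假设服务已按上文在本地8000端口启动,请求字段与上文curl示例一致,仅供参考):

```python
import requests

# 调用本地OpenAI风格接口(非流式),地址与端口以实际启动参数为准
resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",
    json={
        "model": "CodeShell-7B-Chat",
        "messages": [{"role": "user", "content": "用Python写一个快速排序"}],
    },
    timeout=300,
)
print(resp.json()["choices"][0]["message"]["content"])
```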
184 |
185 | ### IDE
186 |
187 | CodeShell最后提供了线上IDE,开发者可以通过IDE进行代码补全、代码问答等操作。同时,IDE插件也同时发布,开发者可以自行在本地进行安装使用。插件相关问题欢迎在[VSCode插件仓库](https://github.com/WisdomShell/codeshell-vscode)与[IntelliJ插件仓库](https://github.com/WisdomShell/codeshell-intellij)中讨论。
188 |
189 | ## Model Details
190 |
191 | CodeShell使用GPT-2作为基础架构,采用Grouped-Query Attention、RoPE相对位置编码等技术。
192 |
193 | ### Hyper-parameter
194 |
195 | | Hyper-parameter | Value |
196 | |---|---|
197 | | n_layer | 42 |
198 | | n_embd | 4096 |
199 | | n_inner | 16384 |
200 | | n_head | 32 |
201 | | num_query_groups | 8 |
202 | | seq-length | 8192 |
203 | | vocab_size | 70144 |
204 |
205 |
206 | ### Data
207 |
208 | CodeShell基于自己爬取的Github数据、Big Code开源的Stack和StarCoder数据集、以及少量高质量的中英文数据进行训练。在原始数据集的基础上,CodeShell采用MinHash对数据去重,基于KenLM以及高质量数据筛选模型对数据进行了过滤与筛选,最终得到高质量的预训练数据集。
209 |
210 | ### Tokenizer
211 |
212 | CodeShell基于Starcoder词表进行了优化,去除了使用频率较低的词语,并添加了部分中文词表,显著提升了中文的压缩率,为Chat版本的训练提供了基础。
213 |
214 |
215 | | Tokenizer | Size | Chinese | English | Code | Total|
216 | |---|---|---|---|---|---|
217 | | Starcoder | 49152 | 1.22 | 3.47 | 3.30 | 2.66 |
218 | | CodeShell | 70020 | 1.50 | 3.47 | 3.30 | 2.95 |
219 |
220 |
221 | ## License
222 |
223 | 社区使用CodeShell模型需要遵循[《CodeShell模型许可协议》](https://github.com/WisdomShell/codeshell/blob/main/License.pdf)及[Apache 2.0许可协议](https://www.apache.org/licenses/LICENSE-2.0)。CodeShell模型允许用于商业用途,但如果您计划将CodeShell模型或其派生产品用于商业用途,需要您确认主体符合以下条件:
224 |
225 | 1. 关联方的服务或产品的每日平均活跃用户数(DAU)不能超过100万。
226 | 2. 关联方不得是软件服务提供商或云服务提供商。
227 | 3. 关联方不存在将获得授予的商业许可,在未经许可的前提下将其再授权给其他第三方的可能性。
228 |
229 | 在满足上述条件的前提下,您需要通过向codeshell.opensource@gmail.com发送电子邮件,提交《CodeShell模型许可协议》要求的申请材料。经审核通过后,将授予您一个全球的、非排他的、不可转让的、不可再授权的商业版权许可。
230 |
231 |
232 | ## Star History
233 |
234 | [](https://star-history.com/#WisdomShell/codeshell&Date)
235 |
236 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | 🤗 Hugging Face • 🤖 ModelScope • ⭕️ WiseModel • 🌐 PKU-KCL
8 |
9 |
10 |
11 |
12 | [License](https://github.com/WisdomShell/codeshell/blob/main/License.pdf)
13 |
14 |
Chinese|English
15 |
16 |
17 |
18 | ## Introduction
19 |
20 | CodeShell is a multilingual code large language model (LLM) developed jointly by the [Knowledge Computing Lab at Peking University](http://se.pku.edu.cn/kcl/) and the AI team of Sichuan Tianfu Bank. CodeShell has 7 billion parameters, was trained on 500 billion tokens, and has a context window length of 8192. On authoritative code evaluation benchmarks (HumanEval and MBPP), CodeShell achieves the best performance among models of its scale. We also provide deployment solutions and IDE plugins that complement CodeShell; please refer to the [CodeShell](https://github.com/WisdomShell/codeshell) repository for details.
21 |
22 | The open-source models are as follows:
23 |
24 | - CodeShell Base: The foundational model of CodeShell with strong coding capabilities.
25 | - CodeShell Chat: A dialogue model of CodeShell that excels in code Q&A, code completion, and other downstream tasks.
26 | - CodeShell Chat 4bit: A 4bit quantized version of the CodeShell dialogue model. While preserving model performance, it consumes less memory and operates faster.
27 | - CodeShell CPP: A C++ version of the CodeShell dialogue model. It allows developers to use it on personal computers without GPUs. Note that the CPP version also supports quantization, allowing users to run CodeShell on PCs with a minimum of 8GB RAM.
28 |
29 | ## Main Characteristics of CodeShell
30 |
31 | - **Powerful Performance**: CodeShell achieves the best performance among 7B code base models on HumanEval and MBPP.
32 | - **Complete Ecosystem**: In addition to the code model, open-source IDE plugins (VS Code and JetBrains) are provided, forming a complete open-source technology stack.
33 | - **Lightweight Deployment**: Supports local C++ deployment, providing a lightweight and fast local software development assistant solution.
34 | - **Comprehensive Evaluation**: A multi-task evaluation system that supports a complete project context and covers common software development activities such as code generation, code defect detection and repair, and test case generation will be open-sourced soon.
35 | - **Efficient Training**: Based on an efficient data governance system, CodeShell achieved excellent performance after training on only 500 billion tokens from a complete cold start.
36 |
37 | ## Performance
38 |
39 | We selected the two most popular code evaluation datasets (HumanEval and MBPP) to evaluate the model. Compared with the two most advanced 7B code models, CodeLlama and StarCoder, CodeShell achieved the best results. The specific evaluation results are as follows.
40 |
41 | | Task | CodeShell-7b | CodeLlama-7b | Starcoder-7b |
42 | | ------- | --------- | --------- | --------- |
43 | | humaneval | **34.32** | 29.44 | 27.80 |
44 | | mbpp | **38.65** | 37.60 | 34.16 |
45 | | multiple-js | **33.17** | 31.30 | 27.02 |
46 | | multiple-java | **30.43** | 29.24 | 24.30 |
47 | | multiple-cpp | **28.21** | 27.33 | 23.04 |
48 | | multiple-swift | 24.30 | **25.32** | 15.70 |
49 | | multiple-php | **30.87** | 25.96 | 22.11 |
50 | | multiple-d | 8.85 | **11.60** | 8.08 |
51 | | multiple-jl | 22.08 | **25.28** | 22.96 |
52 | | multiple-lua | 22.39 | **30.50** | 22.92 |
53 | | multiple-r | **20.52** | 18.57 | 14.29 |
54 | | multiple-rkt | **17.20** | 12.55 | 10.43 |
55 | | multiple-rs | 24.55 | **25.90** | 22.82 |
56 |
57 | ## Requirements
58 |
59 | - python 3.8 and above
60 | - pytorch 2.0 and above are recommended
61 | - transformers 4.32 and above
62 | - CUDA 11.8 and above are recommended (for GPU users, flash-attention users, etc.)
63 |
64 | ## Quickstart
65 |
66 | The CodeShell series models have been uploaded to Hugging Face. Developers can quickly call CodeShell and CodeShell-Chat through Transformers.
67 |
68 | Before starting, make sure you have correctly set up the environment, installed the necessary packages, and met the requirements listed in the previous section. The necessary dependencies can be installed quickly with the following command:
69 |
70 | ```
71 | pip install -r requirements.txt
72 | ```
73 | Next, you can use CodeShell through Transformers.
74 |
75 | ### Code Generation
76 |
77 | Developers can use CodeShell to quickly generate code, accelerating the development process.
78 |
79 | ```python
80 | import torch
81 | from transformers import AutoModelForCausalLM, AutoTokenizer
82 |
83 | tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B")
84 | model = AutoModelForCausalLM.from_pretrained("WisdomShell/CodeShell-7B", trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
85 | inputs = tokenizer('def merge_sort():', return_tensors='pt').to('cuda')
86 | outputs = model.generate(**inputs)
87 | print(tokenizer.decode(outputs[0]))
88 | ```
89 |
90 | ### Fill in the Middle
91 | CodeShell supports the Fill-in-the-Middle mode to better assist the software development process.
92 |
93 | ```python
94 | input_text = "<fim_prefix>def print_hello_world():\n    <fim_suffix>\n    print('Hello world!')<fim_middle>"
95 | inputs = tokenizer(input_text, return_tensors='pt').to('cuda')
96 | outputs = model.generate(**inputs)
97 | print(tokenizer.decode(outputs[0]))
98 | ```
99 |
100 | ### Code Q&A
101 | CodeShell has also open-sourced the CodeShell-7B-Chat code assistant model. Developers can interact with the model using the following code.
102 |
103 | ```python
104 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat', trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
105 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat')
106 |
107 | history = []
108 | query = 'Who are you?'
109 | response = model.chat(query, history, tokenizer)
110 | print(response)
111 | history.append((query, response))
112 |
113 | query = 'Write an HTTP server in Python'
114 | response = model.chat(query, history, tokenizer)
115 | print(response)
116 | history.append((query, response))
117 | ```
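
The `chat` method loaded via `trust_remote_code` can also stream partial responses, which is how the bundled demos render output incrementally (see `demos/cli_demo.py`). A minimal sketch reusing the `model`, `tokenizer`, and `history` objects from above:

```python
# Each yielded value is the full response generated so far (not a delta).
query = 'Write a quicksort function in Python'
response = ''
for response in model.chat(query, history, tokenizer, stream=True):
    pass  # update a UI or print progress here if desired
print(response)
history.append((query, response))
```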
118 |
119 | Developers can also interact with CodeShell-7B-Chat through VS Code and JetBrains plugins. For details, please refer to the [VS Code plugin repository](https://github.com/WisdomShell/codeshell-vscode) and the [IntelliJ plugin repository](https://github.com/WisdomShell/codeshell-intellij).
120 |
121 | ### Model Quantization
122 | CodeShell supports 4 bit/8 bit quantization. After 4-bit quantization, the memory footprint is approximately 6GB, allowing users to use CodeShell on GPUs with smaller memory.
123 |
124 | ```python
125 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4', trust_remote_code=True).to('cuda')
126 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4')
127 | ```
128 |
129 | ### CodeShell in c/c++
130 | As most personal computers lack a GPU, CodeShell offers C/C++ inference support. Developers can compile it for their local environment; see the [CodeShell C/C++ local version](https://github.com/WisdomShell/llama_cpp_for_codeshell) for details.
131 |
132 |
133 | ## Demo
134 | We offer demos in four forms: Web-UI, command line, API, and IDE.
135 |
136 | ### Web UI
137 | Developers can start the Web service using the following command. After the service starts, it can be accessed at http://127.0.0.1:8000.
138 |
139 | ```
140 | python demos/web_demo.py
141 | ```
142 |
143 | ### CLI Demo
144 |
145 | We also offer a command-line interactive demo version. Developers can run it using the following command.
146 |
147 | ```
148 | python demos/cli_demo.py
149 | ```
150 |
151 | ### API
152 |
153 | CodeShell also offers a deployment method based on the OpenAI API.
154 |
155 | ```
156 | python demos/openai_api.py
157 | ```
158 |
159 | Then you can interact with CodeShell via HTTP requests:
160 |
161 | ```
162 | curl http://127.0.0.1:8000/v1/chat/completions \
163 | -H "Content-Type: application/json" \
164 | -d '{
165 | "model": "CodeShell-7B-Chat",
166 | "messages": [
167 | {
168 | "role": "user",
169 | "content": "你好"
170 | }
171 | ]
172 | }'
173 | ```
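
Because the endpoint follows the OpenAI chat-completions format, OpenAI-compatible clients can usually be pointed at it as well. A minimal sketch using the `openai` Python package (v1+; it is not listed in `requirements.txt`, and the API key is a placeholder since the server does not check it):

```python
from openai import OpenAI

# Point the client at the local server started by demos/openai_api.py.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="CodeShell-7B-Chat",
    messages=[{"role": "user", "content": "Write a quicksort function in Python"}],
)
print(completion.choices[0].message.content)
```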
174 |
175 | ### IDE
176 |
177 | Finally, CodeShell offers an online IDE. Developers can use the IDE for code completion, code Q&A, and other operations. IDE plugins have also been released, and developers can install and use them locally. For plugin-related issues, please discuss in the [VS Code plugin repository](https://github.com/WisdomShell/codeshell-vscode) or the [IntelliJ plugin repository](https://github.com/WisdomShell/codeshell-intellij).
178 |
179 | ## Model Details
180 |
181 | CodeShell uses GPT-2 as its basic architecture and employs technologies like Grouped-Query Attention and RoPE relative position encoding.
182 |
183 | ### Hyper-parameter
184 |
185 | | Hyper-parameter | Value |
186 | |---|---|
187 | | n_layer | 42 |
188 | | n_embd | 4096 |
189 | | n_inner | 16384 |
190 | | n_head | 32 |
191 | | num_query_groups | 8 |
192 | | seq-length | 8192 |
193 | | vocab_size | 70144 |
194 |
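With `n_head = 32` query heads and `num_query_groups = 8` key/value heads, every group of 4 query heads shares one key/value head. The following is an illustrative shape sketch of that grouped-query attention layout (toy batch and sequence sizes; not the model's actual implementation, which lives in `model/modeling_codeshell.py`):

```python
import torch

# Toy sizes for illustration; n_embd, n_head, n_groups match the table above.
batch, seq, n_embd, n_head, n_groups = 1, 16, 4096, 32, 8
head_dim = n_embd // n_head                      # 128

q = torch.randn(batch, n_head, seq, head_dim)    # 32 query heads
k = torch.randn(batch, n_groups, seq, head_dim)  # only 8 key heads
v = torch.randn(batch, n_groups, seq, head_dim)  # only 8 value heads

# Each key/value head serves n_head // n_groups = 4 query heads.
k = k.repeat_interleave(n_head // n_groups, dim=1)
v = v.repeat_interleave(n_head // n_groups, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
attn_out = torch.softmax(scores, dim=-1) @ v     # shape: (1, 32, 16, 128)
print(attn_out.shape)
```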
195 |
196 | ### Data
197 |
198 | CodeShell was trained on its own crawled GitHub data, the open-source Stack and StarCoder datasets from BigCode, and a small amount of high-quality Chinese and English data. On top of the raw data, CodeShell applied MinHash-based deduplication and used KenLM together with a high-quality data selection model for filtering and selection, resulting in a high-quality pre-training dataset.
199 |
200 | ### Tokenizer
201 |
202 | CodeShell optimized the Starcoder vocabulary by removing infrequently used words and adding some Chinese vocabulary, significantly improving the Chinese compression rate, laying the groundwork for the training of the Chat version.
203 |
204 |
205 | | Tokenizer | Size | Chinese | English | Code | Total|
206 | |---|---|---|---|---|---|
207 | | Starcoder | 49152 | 1.22 | 3.47 | 3.30 | 2.66 |
208 | | CodeShell | 70144 | 1.50 | 3.47 | 3.30 | 2.95 |
209 |
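The ratios above are compression rates: the higher the value, the fewer tokens a given text needs. A quick way to compute a comparable number on your own data, under the assumption that the table reports characters per token (that interpretation is ours, not stated in the table):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B")

def chars_per_token(text: str) -> float:
    """Average number of characters covered by one token."""
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    return len(text) / max(len(ids), 1)

print(chars_per_token("def quick_sort(arr):\n    return sorted(arr)"))
print(chars_per_token("使用CodeShell生成一段快速排序代码"))
```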
210 |
211 | ## License
212 | The community's use of the CodeShell model must adhere to the ["CodeShell Model License Agreement"](https://github.com/WisdomShell/codeshell/blob/main/License.pdf) and the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). CodeShell is permitted for commercial use. However, if you plan to use the CodeShell model or its derivative products for commercial purposes, you must confirm that the entity meets the following conditions:
213 |
214 | - The daily average active user count (DAU) of the affiliated party's service or product cannot exceed 1 million.
215 | - The affiliated party must not be a software service provider or cloud service provider.
216 | - There is no possibility for the affiliated party to re-license the granted commercial license to another third party without proper authorization.
217 |
218 | Under the aforementioned conditions, you need to submit the application materials required by the "CodeShell Model License Agreement" by sending an email to codeshell.opensource@gmail.com. After approval, you will be granted a global, non-exclusive, non-transferable, non-sublicensable commercial copyright license.
219 |
220 | ## Star History
221 |
222 | [](https://star-history.com/#WisdomShell/codeshell&Date)
223 |
224 |
--------------------------------------------------------------------------------
/demos/cli_demo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # This code is based on Qwen's Cli Demo. It has been modified from
17 | # its original forms to accommodate CodeShell.
18 |
19 | # Copyright (c) Alibaba Cloud.
20 | #
21 | # This source code is licensed under the license found in the
22 | # LICENSE file in the root directory of this source tree.
23 |
24 | """A simple command-line interactive chat demo."""
25 |
26 | import argparse
27 | import os
28 | import platform
29 | import shutil
30 | from copy import deepcopy
31 |
32 | import torch
33 | from transformers import AutoModelForCausalLM, AutoTokenizer
34 | from transformers.generation import GenerationConfig
35 | from transformers.trainer_utils import set_seed
36 | from transformers import StoppingCriteria, StoppingCriteriaList
37 |
38 | DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat'
39 |
40 | _WELCOME_MSG = '''\
41 | Welcome to use CodeShell-Chat model, type text to start chat, type :h to show command help.
42 | (欢迎使用 CodeShell-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)
43 |
44 | Note: This demo is governed by the original license of CodeShell.
45 | We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
46 | (注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
47 | '''
48 | _HELP_MSG = '''\
49 | Commands:
50 |     :help / :h              Show this help message              显示帮助信息
51 |     :exit / :quit / :q      Exit the demo                       退出Demo
52 |     :clear / :cl            Clear screen                        清屏
53 |     :clear-history / :clh   Clear history                       清除对话历史
54 |     :history / :his         Show history                        显示对话历史
55 |     :seed                   Show current random seed            显示当前随机种子
56 |     :seed <N>               Set random seed to <N>              设置随机种子
57 |     :conf                   Show current generation config      显示生成配置
58 |     :conf <key>=<value>     Change generation config            修改生成配置
59 | '''
60 |
61 |
62 | def _load_model_tokenizer(args):
63 | tokenizer = AutoTokenizer.from_pretrained(
64 | args.checkpoint_path, trust_remote_code=True, resume_download=True,
65 | )
66 |
67 | model = AutoModelForCausalLM.from_pretrained(
68 | args.checkpoint_path,
69 | device_map=args.device,
70 | trust_remote_code=True,
71 | resume_download=True,
72 | torch_dtype=torch.bfloat16
73 | ).eval()
74 |
75 | config = GenerationConfig.from_pretrained(
76 | args.checkpoint_path, trust_remote_code=True, resume_download=True,
77 | )
78 |
79 | return model, tokenizer, config
80 |
81 |
82 | def _gc():
83 | import gc
84 | gc.collect()
85 | if torch.cuda.is_available():
86 | torch.cuda.empty_cache()
87 |
88 |
89 | def _clear_screen():
90 | if platform.system() == "Windows":
91 | os.system("cls")
92 | else:
93 | os.system("clear")
94 |
95 |
96 | def _print_history(history):
97 | terminal_width = shutil.get_terminal_size()[0]
98 | print(f'History ({len(history)})'.center(terminal_width, '='))
99 | for index, (query, response) in enumerate(history):
100 | print(f'User[{index}]: {query}')
101 | print(f'CodeShell[{index}]: {response}')
102 | print('=' * terminal_width)
103 |
104 |
105 | def _get_input() -> str:
106 | while True:
107 | try:
108 | message = input('User> ').strip()
109 | except UnicodeDecodeError:
110 | print('[ERROR] Encoding error in input')
111 | continue
112 | except KeyboardInterrupt:
113 | exit(1)
114 | if message:
115 | return message
116 | print('[ERROR] Query is empty')
117 |
118 | def main():
119 | parser = argparse.ArgumentParser(
120 | description='CodeShell-Chat command-line interactive chat demo.')
121 | parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
122 | help="Checkpoint name or path, default to %(default)r")
123 | parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
124 | parser.add_argument("--device", type=str, default="cuda:0",help="Device name.")
125 |
126 | args = parser.parse_args()
127 |
128 | history, response = [], ''
129 |
130 | model, tokenizer, config = _load_model_tokenizer(args)
131 | orig_gen_config = deepcopy(model.generation_config)
132 |
133 |
134 | _clear_screen()
135 | print(_WELCOME_MSG)
136 |
137 | seed = args.seed
138 |
139 | while True:
140 | query = _get_input()
141 |
142 | # Process commands.
143 | if query.startswith(':'):
144 | command_words = query[1:].strip().split()
145 | if not command_words:
146 | command = ''
147 | else:
148 | command = command_words[0]
149 |
150 | if command in ['exit', 'quit', 'q']:
151 | break
152 | elif command in ['clear', 'cl']:
153 | _clear_screen()
154 | print(_WELCOME_MSG)
155 | _gc()
156 | continue
157 | elif command in ['clear-history', 'clh']:
158 | print(f'[INFO] All {len(history)} history cleared')
159 | history.clear()
160 | _gc()
161 | continue
162 | elif command in ['help', 'h']:
163 | print(_HELP_MSG)
164 | continue
165 | elif command in ['history', 'his']:
166 | _print_history(history)
167 | continue
168 | elif command in ['seed']:
169 | if len(command_words) == 1:
170 | print(f'[INFO] Current random seed: {seed}')
171 | continue
172 | else:
173 | new_seed_s = command_words[1]
174 | try:
175 | new_seed = int(new_seed_s)
176 | except ValueError:
177 | print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
178 | else:
179 | print(f'[INFO] Random seed changed to {new_seed}')
180 | seed = new_seed
181 | continue
182 | elif command in ['conf']:
183 | if len(command_words) == 1:
184 | print(model.generation_config)
185 | else:
186 | for key_value_pairs_str in command_words[1:]:
187 | eq_idx = key_value_pairs_str.find('=')
188 | if eq_idx == -1:
189 |                             print('[WARNING] format: <key>=<value>')
190 | continue
191 | conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
192 | try:
193 | conf_value = eval(conf_value_str)
194 | except Exception as e:
195 | print(e)
196 | continue
197 | else:
198 | print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
199 | setattr(model.generation_config, conf_key, conf_value)
200 | continue
201 | elif command in ['reset-conf']:
202 | print('[INFO] Reset generation config')
203 | model.generation_config = deepcopy(orig_gen_config)
204 | print(model.generation_config)
205 | continue
206 | else:
207 | # As normal query.
208 | pass
209 |
210 | # Run chat.
211 | set_seed(seed)
212 | try:
213 | for response in model.chat(query, history, tokenizer, generation_config=config, stream=True):
214 | response = response.replace('|end|', '')
215 |                 response = response.replace('|<end>|', '')
216 | _clear_screen()
217 | print(f"\nUser: {query}")
218 | print(f"\nCodeShell-Chat: {response}")
219 | except KeyboardInterrupt:
220 | print('[WARNING] Generation interrupted')
221 | continue
222 | history.append((query, response))
223 |
224 | if __name__ == "__main__":
225 | main()
226 |
--------------------------------------------------------------------------------
/demos/openai_api.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # This code is based on Qwen's openai_api Demo. It has been modified from
17 | # its original forms to accommodate CodeShell.
18 |
19 | # coding=utf-8
20 | # Implements API for CodeShell in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
21 | # Usage: python openai_api.py
22 | # Visit http://localhost:8000/docs for documents.
23 |
24 | import re
25 | import copy
26 | import json
27 | import time
28 | from copy import deepcopy
29 | from argparse import ArgumentParser
30 | from contextlib import asynccontextmanager
31 | from typing import Dict, List, Literal, Optional, Union
32 |
33 | import torch
34 | import uvicorn
35 | from fastapi import FastAPI, HTTPException
36 | from fastapi.middleware.cors import CORSMiddleware
37 | from pydantic import BaseModel, Field
38 | from sse_starlette.sse import EventSourceResponse
39 | from transformers import AutoTokenizer, AutoModelForCausalLM
40 | from transformers.generation import GenerationConfig
41 |
42 |
43 | def _gc(forced: bool = False):
44 | global args
45 | if args.disable_gc and not forced:
46 | return
47 |
48 | import gc
49 | gc.collect()
50 | if torch.cuda.is_available():
51 | torch.cuda.empty_cache()
52 |
53 |
54 | @asynccontextmanager
55 | async def lifespan(app: FastAPI): # collects GPU memory
56 | yield
57 | _gc(forced=True)
58 |
59 | app = FastAPI(lifespan=lifespan)
60 |
61 | app.add_middleware(
62 | CORSMiddleware,
63 | allow_origins=["*"],
64 | allow_credentials=True,
65 | allow_methods=["*"],
66 | allow_headers=["*"],
67 | )
68 |
69 |
70 | class ModelCard(BaseModel):
71 | id: str
72 | object: str = "model"
73 | created: int = Field(default_factory=lambda: int(time.time()))
74 | owned_by: str = "owner"
75 | root: Optional[str] = None
76 | parent: Optional[str] = None
77 | permission: Optional[list] = None
78 |
79 |
80 | class ModelList(BaseModel):
81 | object: str = "list"
82 | data: List[ModelCard] = []
83 |
84 | class ChatMessage(BaseModel):
85 | role: Literal["user", "assistant"]
86 | content: Optional[str]
87 |
88 | class DeltaMessage(BaseModel):
89 | role: Optional[Literal["user", "assistant", "system"]] = None
90 | content: Optional[str] = None
91 |
92 | class ChatCompletionRequest(BaseModel):
93 | model: str
94 | messages: List[ChatMessage]
95 | functions: Optional[List[Dict]] = None
96 | temperature: Optional[float] = None
97 | top_p: Optional[float] = None
98 | max_length: Optional[int] = None
99 | stream: Optional[bool] = False
100 | stop: Optional[List[str]] = None
101 |
102 |
103 | class ChatCompletionResponseChoice(BaseModel):
104 | index: int
105 | message: ChatMessage
106 | finish_reason: Literal["stop", "length"]
107 |
108 |
109 | class ChatCompletionResponseStreamChoice(BaseModel):
110 | index: int
111 | delta: DeltaMessage
112 | finish_reason: Optional[Literal["stop", "length"]]
113 |
114 |
115 | class ChatCompletionResponse(BaseModel):
116 | model: str
117 | object: Literal["chat.completion", "chat.completion.chunk"]
118 | choices: List[
119 | Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
120 | ]
121 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
122 |
123 |
124 | @app.get("/v1/models", response_model=ModelList)
125 | async def list_models():
126 | global model_args
127 | model_card = ModelCard(id="CodeShell-7B-Chat")
128 | return ModelList(data=[model_card])
129 |
130 | def trim_stop_words(response):
131 | response = response.replace('|end|', '')
132 |     response = response.replace('|<end>|', '')
133 | return response
134 |
135 | _TEXT_COMPLETION_CMD = object()
136 |
137 | #
138 | # Temporarily, the system role does not work as expected.
139 | # We advise that you write the setups for role-play in your query,
140 | # i.e., use the user role instead of the system role.
141 | #
142 | # TODO: Use real system role when the model is ready.
143 | #
144 | def parse_messages(messages):
145 | if all(m.role != "user" for m in messages):
146 | raise HTTPException(
147 | status_code=400,
148 | detail=f"Invalid request: Expecting at least one user message.",
149 | )
150 |
151 | messages = copy.deepcopy(messages)
152 |
153 | _messages = messages
154 | messages = []
155 | for m_idx, m in enumerate(_messages):
156 | role, content = m.role, m.content
157 | if content:
158 | content = content.lstrip("\n").rstrip()
159 | if role == "assistant":
160 | if len(messages) == 0:
161 | raise HTTPException(
162 | status_code=400,
163 | detail=f"Invalid request: Expecting role user before role assistant.",
164 | )
165 | last_msg = messages[-1].content
166 | last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
167 | if messages[-1].role == "user":
168 | messages.append(
169 | ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
170 | )
171 | else:
172 | messages[-1].content += content
173 | elif role == "user":
174 | messages.append(
175 | ChatMessage(role="user", content=content.lstrip("\n").rstrip())
176 | )
177 | else:
178 | raise HTTPException(
179 | status_code=400, detail=f"Invalid request: Incorrect role {role}."
180 | )
181 |
182 | query = _TEXT_COMPLETION_CMD
183 | if messages[-1].role == "user":
184 | query = messages[-1].content
185 | messages = messages[:-1]
186 |
187 | if len(messages) % 2 != 0:
188 | raise HTTPException(status_code=400, detail="Invalid request")
189 |
190 | history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
191 | for i in range(0, len(messages), 2):
192 | if messages[i].role == "user" and messages[i + 1].role == "assistant":
193 | usr_msg = messages[i].content.lstrip("\n").rstrip()
194 | bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
195 | history.append([usr_msg, bot_msg])
196 | else:
197 | raise HTTPException(
198 | status_code=400,
199 | detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
200 | )
201 | return query, history
202 |
203 |
204 | def parse_response(response):
205 | func_name, func_args = "", ""
206 | i = response.rfind("\nAction:")
207 | j = response.rfind("\nAction Input:")
208 | k = response.rfind("\nObservation:")
209 | if 0 <= i < j: # If the text has `Action` and `Action input`,
210 | if k < j: # but does not contain `Observation`,
211 | # then it is likely that `Observation` is omitted by the LLM,
212 | # because the output text may have discarded the stop word.
213 | response = response.rstrip() + "\nObservation:" # Add it back.
214 | k = response.rfind("\nObservation:")
215 | func_name = response[i + len("\nAction:") : j].strip()
216 | func_args = response[j + len("\nAction Input:") : k].strip()
217 | if func_name:
218 | choice_data = ChatCompletionResponseChoice(
219 | index=0,
220 | message=ChatMessage(
221 | role="assistant",
222 | content=response[:i],
223 | function_call={"name": func_name, "arguments": func_args},
224 | ),
225 | finish_reason="function_call",
226 | )
227 | return choice_data
228 | z = response.rfind("\nFinal Answer: ")
229 | if z >= 0:
230 | response = response[z + len("\nFinal Answer: ") :]
231 | choice_data = ChatCompletionResponseChoice(
232 | index=0,
233 | message=ChatMessage(role="assistant", content=response),
234 | finish_reason="stop",
235 | )
236 | return choice_data
237 |
238 |
239 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
240 | async def create_chat_completion(request: ChatCompletionRequest):
241 | global model, tokenizer
242 |
243 | generation_config = copy.deepcopy(model.generation_config)
244 | if request.temperature is not None:
245 | if request.temperature < 0.01:
246 | generation_config.top_k = 1 # greedy decoding
247 | else:
248 | # Not recommended. Please tune top_p instead.
249 | generation_config.temperature = request.temperature
250 | if request.top_p is not None:
251 | generation_config.top_p = request.top_p
252 |
253 | query, history = parse_messages(request.messages)
254 |
255 | if request.stream:
256 | if request.functions:
257 | raise HTTPException(
258 | status_code=400,
259 | detail="Invalid request: Function calling is not yet implemented for stream mode.",
260 | )
261 | generate = predict(query, history, request.model, generation_config)
262 | return EventSourceResponse(generate, media_type="text/event-stream")
263 |
264 | if query is _TEXT_COMPLETION_CMD:
265 | raise HTTPException(
266 | status_code=400,
267 | detail="Invalid request: COMPLETION model not supported now.",
268 | )
269 | else:
270 | response = model.chat(
271 | query,
272 | history,
273 | tokenizer,
274 | generation_config=generation_config
275 | )
276 | print(f"\n{history}\n{query}\n\n{response}\n")
277 | _gc()
278 |
279 | response = trim_stop_words(response)
280 | if request.functions:
281 | choice_data = parse_response(response)
282 | else:
283 | choice_data = ChatCompletionResponseChoice(
284 | index=0,
285 | message=ChatMessage(role="assistant", content=response),
286 | finish_reason="stop",
287 | )
288 | return ChatCompletionResponse(
289 | model=request.model, choices=[choice_data], object="chat.completion"
290 | )
291 |
292 |
293 | async def predict(
294 | query: str, history: List[List[str]], model_id: str, generation_config: GenerationConfig
295 | ):
296 | global model, tokenizer
297 | choice_data = ChatCompletionResponseStreamChoice(
298 | index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
299 | )
300 | chunk = ChatCompletionResponse(
301 | model=model_id, choices=[choice_data], object="chat.completion.chunk"
302 | )
303 | yield "{}".format(chunk.model_dump_json(exclude_unset=True))
304 |
305 | current_length = 0
306 | response_generator = model.chat(query, history, tokenizer, stream=True, generation_config=generation_config)
307 |
308 | for new_response in response_generator:
309 | if len(new_response) == current_length:
310 | continue
311 |
312 | new_text = new_response[current_length:]
313 | current_length = len(new_response)
314 |
315 | choice_data = ChatCompletionResponseStreamChoice(
316 | index=0, delta=DeltaMessage(content=new_text), finish_reason=None
317 | )
318 | chunk = ChatCompletionResponse(
319 | model=model_id, choices=[choice_data], object="chat.completion.chunk"
320 | )
321 | yield "{}".format(chunk.model_dump_json(exclude_unset=True))
322 |
323 | choice_data = ChatCompletionResponseStreamChoice(
324 | index=0, delta=DeltaMessage(), finish_reason="stop"
325 | )
326 | chunk = ChatCompletionResponse(
327 | model=model_id, choices=[choice_data], object="chat.completion.chunk"
328 | )
329 | yield "{}".format(chunk.model_dump_json(exclude_unset=True))
330 | yield "[DONE]"
331 |
332 | _gc()
333 |
334 |
335 | def _get_args():
336 | parser = ArgumentParser()
337 | parser.add_argument(
338 | "-c",
339 | "--checkpoint-path",
340 | type=str,
341 | default="WisdomShell/CodeShell-7B-Chat",
342 | help="Checkpoint name or path, default to %(default)r",
343 | )
344 | parser.add_argument("--device", type=str, default="cuda:0",help="Device name.")
345 | parser.add_argument(
346 | "--server-port", type=int, default=8000, help="Demo server port."
347 | )
348 | parser.add_argument(
349 | "--server-name",
350 | type=str,
351 | default="127.0.0.1",
352 | help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
353 | " If you want other computers to access your server, use 0.0.0.0 instead.",
354 | )
355 | parser.add_argument("--disable-gc", action="store_true",
356 | help="Disable GC after each response generated.")
357 |
358 | args = parser.parse_args()
359 | return args
360 |
361 |
362 | if __name__ == "__main__":
363 | args = _get_args()
364 |
365 | tokenizer = AutoTokenizer.from_pretrained(
366 | args.checkpoint_path,
367 | trust_remote_code=True,
368 | resume_download=True,
369 | )
370 |
371 | model = AutoModelForCausalLM.from_pretrained(
372 | args.checkpoint_path,
373 | device_map=args.device,
374 | trust_remote_code=True,
375 | resume_download=True,
376 | torch_dtype=torch.bfloat16
377 | ).eval()
378 |
379 | model.generation_config = GenerationConfig.from_pretrained(
380 | args.checkpoint_path,
381 | trust_remote_code=True,
382 | resume_download=True,
383 | )
384 |
385 | uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1)
--------------------------------------------------------------------------------
/demos/requirements_web_demo.txt:
--------------------------------------------------------------------------------
1 | gradio<3.42
2 | mdtex2html
--------------------------------------------------------------------------------
/demos/web_demo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # This code is based on Qwen's web Demo. It has been modified from
17 | # its original forms to accommodate CodeShell.
18 |
19 | # Copyright (c) Alibaba Cloud.
20 | #
21 | # This source code is licensed under the license found in the
22 | # LICENSE file in the root directory of this source tree.
23 |
24 | """A simple web interactive chat demo based on gradio."""
25 | import os
26 | from argparse import ArgumentParser
27 |
28 | import gradio as gr
29 | import mdtex2html
30 |
31 | import torch
32 | from transformers import AutoModelForCausalLM, AutoTokenizer
33 | from transformers.generation import GenerationConfig
34 |
35 |
36 | DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat'
37 |
38 |
39 | def _get_args():
40 | parser = ArgumentParser()
41 | parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
42 | help="Checkpoint name or path, default to %(default)r")
43 | parser.add_argument("--device", type=str, default="cuda:0", help="GPU device.")
44 |
45 | parser.add_argument("--share", action="store_true", default=False,
46 | help="Create a publicly shareable link for the interface.")
47 | parser.add_argument("--inbrowser", action="store_true", default=False,
48 | help="Automatically launch the interface in a new tab on the default browser.")
49 | parser.add_argument("--server-port", type=int, default=8000,
50 | help="Demo server port.")
51 | parser.add_argument("--server-name", type=str, default="127.0.0.1",
52 | help="Demo server name.")
53 |
54 | args = parser.parse_args()
55 | return args
56 |
57 |
58 | def _load_model_tokenizer(args):
59 | tokenizer = AutoTokenizer.from_pretrained(
60 | args.checkpoint_path, trust_remote_code=True, resume_download=True,
61 | )
62 |
63 | model = AutoModelForCausalLM.from_pretrained(
64 | args.checkpoint_path,
65 | device_map=args.device,
66 | trust_remote_code=True,
67 | resume_download=True,
68 | torch_dtype=torch.bfloat16
69 | ).eval()
70 |
71 | config = GenerationConfig.from_pretrained(
72 | args.checkpoint_path, trust_remote_code=True, resume_download=True,
73 | )
74 |
75 | return model, tokenizer, config
76 |
77 |
78 | def postprocess(self, y):
79 | if y is None:
80 | return []
81 | for i, (message, response) in enumerate(y):
82 | y[i] = (
83 | None if message is None else mdtex2html.convert(message),
84 | None if response is None else mdtex2html.convert(response),
85 | )
86 | return y
87 |
88 |
89 | gr.Chatbot.postprocess = postprocess
90 |
91 |
92 | def _parse_text(text):
93 | lines = text.split("\n")
94 | lines = [line for line in lines if line != ""]
95 | count = 0
96 | for i, line in enumerate(lines):
97 | if "```" in line:
98 | count += 1
99 | items = line.split("`")
100 | if count % 2 == 1:
101 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
102 |             else:
103 |                 lines[i] = f"<br></code></pre>"
104 |         else:
105 |             if i > 0:
106 |                 if count % 2 == 1:
107 |                     line = line.replace("`", r"\`")
108 |                     line = line.replace("<", "&lt;")
109 |                     line = line.replace(">", "&gt;")
110 |                     line = line.replace(" ", "&nbsp;")
111 |                     line = line.replace("*", "&ast;")
112 |                     line = line.replace("_", "&lowbar;")
113 |                     line = line.replace("-", "&#45;")
114 |                     line = line.replace(".", "&#46;")
115 |                     line = line.replace("!", "&#33;")
116 |                     line = line.replace("(", "&#40;")
117 |                     line = line.replace(")", "&#41;")
118 |                     line = line.replace("$", "&#36;")
119 |                 lines[i] = "<br>" + line
120 | text = "".join(lines)
121 | return text
122 |
123 |
124 | def _gc():
125 | import gc
126 | gc.collect()
127 | if torch.cuda.is_available():
128 | torch.cuda.empty_cache()
129 |
130 |
131 | def _launch_demo(args, model, tokenizer, config):
132 |
133 | def predict(_query, _chatbot, _task_history):
134 | print(f"User: {_parse_text(_query)}")
135 | _chatbot.append((_parse_text(_query), ""))
136 | full_response = ""
137 |
138 | for response in model.chat(_query, _task_history, tokenizer, generation_config=config, stream=True):
139 | response = response.replace('|end|', '')
140 |             response = response.replace('|<end>|', '')
141 | _chatbot[-1] = (_parse_text(_query), _parse_text(response))
142 |
143 | yield _chatbot
144 | full_response = _parse_text(response)
145 |
146 | print(f"History: {_task_history}")
147 | _task_history.append((_query, full_response))
148 | print(f"CodeShell-Chat: {_parse_text(full_response)}")
149 |
150 | def regenerate(_chatbot, _task_history):
151 | if not _task_history:
152 | yield _chatbot
153 | return
154 | item = _task_history.pop(-1)
155 | _chatbot.pop(-1)
156 | yield from predict(item[0], _chatbot, _task_history)
157 |
158 | def reset_user_input():
159 | return gr.update(value="")
160 |
161 | def reset_state(_chatbot, _task_history):
162 | _task_history.clear()
163 | _chatbot.clear()
164 | _gc()
165 | return _chatbot
166 |
167 | with gr.Blocks() as demo:
168 | gr.Markdown("""CodeShell-Chat Bot""")
169 |
170 | chatbot = gr.Chatbot(label='CodeShell-Chat', elem_classes="control-height")
171 | query = gr.Textbox(lines=2, label='Input')
172 | task_history = gr.State([])
173 |
174 | with gr.Row():
175 | empty_btn = gr.Button("🧹 Clear History (清除历史)")
176 | submit_btn = gr.Button("🚀 Submit (发送)")
177 | regen_btn = gr.Button("🤔️ Regenerate (重试)")
178 |
179 | submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
180 | submit_btn.click(reset_user_input, [], [query])
181 | empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
182 | regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
183 |
184 | gr.Markdown("""\
185 | Note: This demo is governed by the original license of CodeShell. \
186 | We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
187 | including hate speech, violence, pornography, deception, etc. \
188 | (注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
189 | 包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
190 |
191 | demo.queue().launch(
192 | share=args.share,
193 | inbrowser=args.inbrowser,
194 | server_port=args.server_port,
195 | server_name=args.server_name,
196 | )
197 |
198 |
199 | def main():
200 | args = _get_args()
201 |
202 | model, tokenizer, config = _load_model_tokenizer(args)
203 |
204 | _launch_demo(args, model, tokenizer, config)
205 |
206 |
207 | if __name__ == '__main__':
208 | main()
--------------------------------------------------------------------------------
/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # 代码评估
2 | 本文档将向您完整地介绍CodeShell的代码评估过程,评估脚本基于[bigcode-evaluation-harness](https://github.com/bigcode-project/bigcode-evaluation-harness)实现。
3 |
4 | ## 开始步骤
5 |
6 | 首先,克隆bigcode-evaluation-harness仓库并进入对应目录:
7 |
8 | ```bash
9 | git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git
10 | cd bigcode-evaluation-harness
11 | ```
12 |
13 | 接下来,依照您设备的规格,安装PyTorch,然后运行以下命令安装剩余的依赖:
14 |
15 | ```bash
16 | pip install -e .
17 | ```
18 |
19 | 要使用评估脚本生成和评估任务,请按下述样例进行。确保您位于正确的目录中(`codeshell/evaluation`),然后依次执行两个 `run_eval.sh` 命令:
20 |
21 | ```bash
22 | cd codeshell/evaluation
23 | ./run_eval.sh local_gen humaneval $model_name_or_path $save_folder
24 | ./run_eval.sh eval humaneval $model_name_or_path $save_folder
25 | ```
26 |
--------------------------------------------------------------------------------
/evaluation/README_EN.md:
--------------------------------------------------------------------------------
1 | # Evaluation
2 | This guide introduces the evaluation process of CodeShell. The evaluation script is based on [bigcode-evaluation-harness](https://github.com/bigcode-project/bigcode-evaluation-harness).
3 |
4 | ## Quick Start
5 |
6 | Begin by cloning the bigcode-evaluation-harness repository and entering the directory:
7 |
8 | ```bash
9 | git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git
10 | cd bigcode-evaluation-harness
11 | ```
12 |
13 | Next, install PyTorch according to your device specifications, and then install the remaining packages using the command:
14 |
15 | ```bash
16 | pip install -e .
17 | ```
18 |
19 | To generate and evaluate tasks with the evaluation script, follow the example below. Ensure you are in the appropriate directory (`codeshell/evaluation`), then execute the two `run_eval.sh` commands:
20 |
21 | ```bash
22 | cd codeshell/evaluation
23 | ./run_eval.sh local_gen humaneval $model_name_or_path $save_folder
24 | ./run_eval.sh eval humaneval $model_name_or_path $save_folder
25 | ```
26 |
27 | By following this guide, you can now effectively utilize the bigcode-evaluation-harness to evaluate your model's performance on specific tasks.
--------------------------------------------------------------------------------
/evaluation/all_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: MULTI_GPU
3 | downcast_bf16: 'no'
4 | gpu_ids: all
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: bf16
8 | num_machines: 1
9 | num_processes: 8
10 | rdzv_backend: static
11 | same_network: true
12 | tpu_env: []
13 | tpu_use_cluster: false
14 | tpu_use_sudo: false
15 | use_cpu: false
16 |
--------------------------------------------------------------------------------
/evaluation/chat_humaneval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Evaluating Large Language Models Trained on Code
17 | https://arxiv.org/abs/2107.03374
18 |
19 | The HumanEval dataset released by OpenAI includes 164 programming problems with a function signature,
20 | docstring, body, and several unit tests.
21 | They were handwritten to ensure not to be included in the training set of code generation models.
22 |
23 | Homepage: https://github.com/openai/human-eval
24 | """
25 |
26 | from evaluate import load
27 |
28 | from lm_eval.base import Task
29 |
30 | _CITATION = """
31 | @misc{chen2021evaluating,
32 | title={Evaluating Large Language Models Trained on Code},
33 | author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba},
34 | year={2021},
35 | eprint={2107.03374},
36 | archivePrefix={arXiv},
37 | primaryClass={cs.LG}
38 | }
39 | """
40 |
41 |
42 | class ChatHumanEval(Task):
43 | """A task represents an entire benchmark including its dataset, problems,
44 | answers, generation settings and evaluation methods.
45 | """
46 |
47 | DATASET_PATH = "openai_humaneval"
48 |
49 | def __init__(self):
50 | super().__init__(
51 | stop_words=["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif", "\n\n\n\n", "\n\n\n"],
52 | requires_execution=True,
53 | )
54 |
55 | def get_dataset(self):
56 | """Returns dataset for the task or an iterable of any object, that get_prompt can handle"""
57 | return self.dataset["test"]
58 |
59 | def get_prompt(self, doc):
60 | """Builds the prompt for the LM to generate from."""
61 | # return doc["prompt"].strip()
62 | PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n### Instruction:\nCreate a Python script for this problem:\n{Question}\n### Response:\n{Question}"""
63 |
64 | return PROMPT.format(Question=doc["prompt"].strip())
65 |
66 | def get_reference(self, doc):
67 | """Builds the reference solution for the doc (sample from the test dataset)."""
68 | test_func = doc["test"]
69 | entry_point = f"check({doc['entry_point']})"
70 | return "\n" + test_func + "\n" + entry_point
71 |
72 | @staticmethod
73 | def _stop_at_stop_token(decoded_string, stop_tokens):
74 | """
75 | Produces the prefix of decoded_string that ends at the first occurrence of
76 | a stop_token.
77 | WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
78 | itself.
79 | """
80 | min_stop_index = len(decoded_string)
81 | for stop_token in stop_tokens:
82 | stop_index = decoded_string.find(stop_token)
83 | if stop_index != -1 and stop_index < min_stop_index:
84 | min_stop_index = stop_index
85 | return decoded_string[:min_stop_index]
86 |
87 | def postprocess_generation(self, generation, idx):
88 | """Defines the postprocessing for a LM generation.
89 | :param generation: str
90 | code generation from LM
91 | :param idx: int
92 | index of doc in the dataset to which the generation belongs
93 | (not used for Humaneval-Task)
94 | """
95 | prompt = self.get_prompt(self.dataset["test"][idx])
96 | # print("before postprocessing", generation)
97 | # print("prefix", prompt)
98 | generation = generation[len(prompt) :]
99 | # print("after postprocessing", generation)
100 | return self.dataset["test"][idx]["prompt"] + self._stop_at_stop_token(generation, self.stop_words)
101 |
102 | def process_results(self, generations, references):
103 | """Takes the list of LM generations and evaluates them against ground truth references,
104 | returning the metric for the generations.
105 | :param generations: list(list(str))
106 | list of lists containing generations
107 | :param references: list(str)
108 |             list of str containing references
109 | """
110 | code_metric = load("code_eval")
111 | results, _ = code_metric.compute(
112 | references=references,
113 | predictions=generations,
114 | )
115 | return results
116 |
--------------------------------------------------------------------------------
/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import fnmatch
17 | import json
18 |
19 | import datasets
20 | import torch
21 | import transformers
22 | from accelerate import Accelerator
23 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
24 |
25 | from lm_eval.arguments import EvalArguments
26 | from lm_eval.evaluator import Evaluator
27 |
28 | from lm_eval.tasks import (humaneval, mbpp, multiple)
29 | from chat_humaneval import ChatHumanEval
30 |
31 | TASK_REGISTRY = {
32 | **multiple.create_all_tasks(),
33 | "humaneval": humaneval.HumanEval,
34 | "mbpp": mbpp.MBPP,
35 |     "chat-humaneval": ChatHumanEval,
36 | }
37 |
38 | ALL_TASKS = sorted(list(TASK_REGISTRY))
39 |
40 | class MultiChoice:
41 | def __init__(self, choices):
42 | self.choices = choices
43 |
44 | # Simple wildcard support (linux filename patterns)
45 | def __contains__(self, values):
46 | for value in values.split(","):
47 | if len(fnmatch.filter(self.choices, value)) == 0:
48 | return False
49 |
50 | return True
51 |
52 | def __iter__(self):
53 | for choice in self.choices:
54 | yield choice
55 |
56 |
57 | def parse_args():
58 | parser = HfArgumentParser(EvalArguments)
59 |
60 | parser.add_argument(
61 | "--model",
62 | default="codeparrot/codeparrot-small",
63 | help="Model to evaluate, provide a repo name in Hugging Face hub or a local path",
64 | )
65 | parser.add_argument(
66 | "--tasks",
67 | default=None,
68 | choices=MultiChoice(ALL_TASKS),
69 | help=f"Evaluation tasks from {ALL_TASKS}",
70 | )
71 | parser.add_argument(
72 | "--batch_size",
73 | type=int,
74 | default=1,
75 | help="Batch size for evaluation on each worker, can be larger for HumanEval",
76 | )
77 | parser.add_argument(
78 | "--max_length_generation",
79 | type=int,
80 | default=1024,
81 | help="Maximum length of generated sequence (prompt+generation)",
82 | )
83 | parser.add_argument(
84 | "--precision",
85 | type=str,
86 | default="bf16",
87 | help="Model precision, from: fp32, fp16 or bf16",
88 | )
89 | parser.add_argument(
90 | "--limit",
91 | type=int,
92 | default=None,
93 | help="Number of samples to solve and evaluate from the benchmark",
94 | )
95 | parser.add_argument(
96 | "--postprocess",
97 | action="store_false",
98 | help="Postprocess model outputs before execution, always on except during generation tests",
99 | )
100 | parser.add_argument(
101 | "--allow_code_execution",
102 | action="store_true",
103 | help="Allow code evaluation to execute external/untrusted Python code on your machine",
104 | )
105 | parser.add_argument(
106 | "--generation_only",
107 | action="store_true",
108 | help="Do code generation but no evaluation",
109 | )
110 | parser.add_argument(
111 | "--load_generations_path",
112 | type=str,
113 | default=None,
114 | help="Path of file with previously generated solutions, if provided generation is skipped and only evaluation is done",
115 | )
116 | parser.add_argument(
117 | "--metric_output_path",
118 | type=str,
119 | default="evaluation_results.json",
120 | help="Path to save the results",
121 | )
122 | parser.add_argument(
123 | "--save_generations",
124 | action="store_true",
125 | help="Whether to save code generations",
126 | )
127 | parser.add_argument(
128 | "--save_generations_path",
129 | type=str,
130 | default="generations.json",
131 | help="Path for saving the code generations",
132 | )
133 | parser.add_argument(
134 | "--save_references",
135 | action="store_true",
136 | help="Whether to save reference solutions/tests",
137 | )
138 | parser.add_argument(
139 | "--save_references_path",
140 | type=str,
141 | default="references.json",
142 | help="Path for saving the code references",
143 | )
144 | return parser.parse_args()
145 |
146 |
147 | def pattern_match(patterns, source_list):
148 | """Returns a list containing all values of the source_list that
149 | match at least one of the patterns"""
150 | task_names = set()
151 | for pattern in patterns:
152 | for matching in fnmatch.filter(source_list, pattern):
153 | task_names.add(matching)
154 | return list(task_names)
155 |
156 |
157 | def main():
158 | args = parse_args()
159 | transformers.logging.set_verbosity_error()
160 | datasets.logging.set_verbosity_error()
161 |
162 | task_names = pattern_match(args.tasks.split(","), ALL_TASKS)
163 |
164 | accelerator = Accelerator()
165 | if accelerator.is_main_process:
166 | print(f"Selected Tasks: {task_names}")
167 |
168 | results = {}
169 | if args.load_generations_path:
170 | # here we don't generate code but only evaluate previously computed generations
171 | if accelerator.is_main_process:
172 | print("evaluation only mode")
173 | evaluator = Evaluator(accelerator, None, None, args)
174 | for task in task_names:
175 | results[task] = evaluator.evaluate(task)
176 | else:
177 | # here we generate code and save it (evaluation is optional but True by default)
178 | dict_precisions = {
179 | "fp32": torch.float32,
180 | "fp16": torch.float16,
181 | "bf16": torch.bfloat16,
182 | }
183 | if args.precision not in dict_precisions:
184 | raise ValueError(
185 | f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16"
186 | )
187 | print(f"Loading tokenizer and model (in {args.precision})")
188 | model = AutoModelForCausalLM.from_pretrained(
189 | args.model,
190 | revision=args.revision,
191 | torch_dtype=dict_precisions[args.precision],
192 | trust_remote_code=True,
193 | )
194 |
195 | tokenizer_path = args.model
196 | tokenizer = AutoTokenizer.from_pretrained(
197 | tokenizer_path,
198 | revision=args.revision,
199 | trust_remote_code=True,
200 | )
201 | if not tokenizer.eos_token:
202 | if tokenizer.bos_token:
203 | tokenizer.eos_token = tokenizer.bos_token
204 | print("bos_token used as eos_token")
205 | else:
206 | raise ValueError("No eos_token or bos_token found")
207 | if not tokenizer.pad_token:
208 | tokenizer.pad_token = tokenizer.eos_token
209 |
210 | evaluator = Evaluator(accelerator, model, tokenizer, args)
211 |
212 | for task in task_names:
213 | if args.generation_only:
214 | if accelerator.is_main_process:
215 | print("generation mode only")
216 | generations, references = evaluator.generate_text(task)
217 | if accelerator.is_main_process:
218 | with open(args.save_generations_path, "w", encoding="utf-8") as fp:
219 | json.dump(generations, fp, indent=4, ensure_ascii=False)
220 | print(f"generations were saved at {args.save_generations_path}")
221 | if args.save_references:
222 | with open(args.save_references_path, "w", encoding="utf-8") as fp:
223 | json.dump(references, fp, indent=4, ensure_ascii=False)
224 | print("references were saved")
225 | else:
226 | results[task] = evaluator.evaluate(task)
227 |
228 | results["config"] = {
229 | "model": args.model,
230 | "revision": args.revision,
231 | "temperature": args.temperature,
232 | "n_samples": args.n_samples,
233 | }
234 | if not args.generation_only:
235 | dumped = json.dumps(results, indent=2)
236 | if accelerator.is_main_process:
237 | print(dumped)
238 |
239 | with open(args.metric_output_path, "w") as f:
240 | f.write(dumped)
241 |
242 | if __name__ == "__main__":
243 | main()
244 |
--------------------------------------------------------------------------------
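The evaluation-only path of `eval.py` (`--load_generations_path`) ends up scoring saved generations with the `code_eval` metric from the `evaluate` library, the same metric that `process_results` loads in the task above. The sketch below does that directly on previously saved files; the paths are placeholders, and the metric only executes generated code when `HF_ALLOW_CODE_EVAL=1` is set (the harness equivalent is `--allow_code_execution`).

```python
# Minimal sketch: score previously saved generations with the `code_eval` metric.
# Paths are placeholders; executing untrusted generated code is opt-in.
import json
import os

from evaluate import load

os.environ["HF_ALLOW_CODE_EVAL"] = "1"  # code_eval refuses to run without this

with open("log/my_run/chat-humaneval/generations.json", encoding="utf-8") as fp:
    generations = json.load(fp)   # list[list[str]]: n_samples candidates per problem
with open("log/my_run/chat-humaneval/references.json", encoding="utf-8") as fp:
    references = json.load(fp)    # list[str]: test code plus check(entry_point) per problem

code_metric = load("code_eval")
pass_at_k, _ = code_metric.compute(
    references=references,
    predictions=generations,
    k=[1, 10],
)
print(pass_at_k)  # e.g. {"pass@1": ..., "pass@10": ...}
```

In this repository the same flow is wrapped by the `eval` function of `run_eval.sh` below.
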
/evaluation/run_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export TOKENIZERS_PARALLELISM=false
3 | export TRANSFORMERS_OFFLINE=0
4 |
5 | function local_gen() {
6 | project_dir=$(cd "$(dirname $0)"; pwd)
7 | accelerate_args="--main_process_port=$((10000 + RANDOM % 20000)) --config_file=$project_dir/all_config.yaml"
8 |
9 | dataset=$1
10 | model_name_or_path=$2
11 | run_id=$3
12 | n_samples=40
13 |
14 | batch_size=10
15 |
16 | mkdir -p $project_dir/log/$run_id/$dataset
17 |
18 | accelerate launch $accelerate_args eval.py \
19 | --model $model_name_or_path \
20 | --tasks $dataset \
21 | --max_length_generation 2048 \
22 | --temperature 0.2 \
23 | --precision bf16 \
24 | --do_sample True \
25 | --n_samples $n_samples \
26 | --batch_size $batch_size \
27 | --save_generations \
28 | --save_references \
29 | --generation_only \
30 | --save_generations_path $project_dir/log/$run_id/$dataset/generations.json \
31 | --save_references_path $project_dir/log/$run_id/$dataset/references.json \
32 | --metric_output_path $project_dir/log/$run_id/$dataset/evaluation.json \
33 |     2>&1 | tee $project_dir/log/$run_id/$dataset/evaluation.log
34 |
35 | }
36 |
37 | function eval() {
38 | dataset=$1
39 | model_name_or_path=$2
40 | run_id=$3
41 |
42 | project_dir=$(cd "$(dirname $0)"; pwd)
43 |
44 | python3 eval.py \
45 | --model $model_name_or_path \
46 | --tasks $dataset \
47 | --load_generations_path $project_dir/log/$run_id/$dataset/generations.json \
48 | --allow_code_execution \
49 | --n_samples 40 \
50 | --metric_output_path $project_dir/log/$run_id/$dataset/evaluation.json
51 |
52 | }
53 |
54 | task=$1
55 | dataset=$2
56 | model_name_or_path=$3
57 | run_id=$4
58 |
59 | if [ "$task" == "local_gen" ]; then
60 |     local_gen $dataset $model_name_or_path $run_id
61 | elif [ "$task" == "eval" ]; then
62 |     eval $dataset $model_name_or_path $run_id
63 | elif [ "$task" == "help" ]; then
64 |     echo "./scripts/run_eval.sh [local_gen|eval] [humaneval|mbpp|chat-humaneval|multiple-*] model_name_or_path run_id"
65 | else
66 |     echo "task should be one of local_gen, eval or help"
67 | fi
--------------------------------------------------------------------------------
/finetune/README.md:
--------------------------------------------------------------------------------
1 | 该文档为希望在特定领域任务中应用CodeShell模型的用户提供了官方微调示例。
2 |
3 | 开始前,您需要通过执行以下命令来配置必要的环境:
4 | ```bash
5 | pip install peft deepspeed
6 | ```
7 |
8 | 您需要按照 JSON 格式整理训练数据,其中每个样本是一个包含 ID 和对话列表的字典。该对话列表是消息对象的数组,代表了用户和助手之间的交谈。如下所示为一个样例:
9 |
10 | ```json
11 | [
12 | {
13 | "id": "identity_0",
14 | "conversations": [
15 | {
16 | "from": "human",
17 | "value": "你好"
18 | },
19 | {
20 | "from": "assistant",
21 | "value": "您好,我是CodeShell,请问有什么可以帮助您的吗?"
22 | }
23 | ]
24 | }
25 | ]
26 | ```
27 |
28 | 当数据准备完毕后,导航至微调的目录并执行 `run_finetune.sh` 脚本,命令如下:
29 |
30 | ```bash
31 | cd codeshell/finetune
32 | ./run_finetune.sh $model_name_or_path $dataset_path $save_path
33 | ```
34 |
35 | 按照这些步骤操作,您可以将预训练的模型微调,使其更加精确地适应您的特定任务。
36 |
37 | 该微调脚本基于qwen、fastchat 和 tatsu-lab/stanford_alpaca 的微调脚本。
--------------------------------------------------------------------------------
/finetune/README_EN.md:
--------------------------------------------------------------------------------
1 | In this guide, we present the official fine-tuning script for users who wish to adapt pretrained models for their domain-specific tasks.
2 |
3 | To begin, please set up the required environment by executing the following command:
4 | ```bash
5 | pip install peft deepspeed
6 | ```
7 |
8 | The training data should be organized in JSON format, with each sample being a dictionary containing an ID and a conversation list. The conversation list is an array of message objects, representing the conversation between the user and the assistant. See the example below:
9 |
10 | ```json
11 | [
12 | {
13 | "id": "identity_0",
14 | "conversations": [
15 | {
16 | "from": "human",
17 | "value": "你好"
18 | },
19 | {
20 | "from": "assistant",
21 | "value": "您好,我是CodeShell,请问有什么可以帮助您的吗?"
22 | }
23 | ]
24 | }
25 | ]
26 | ```
27 |
28 | Once the data is prepared, navigate to the fine-tuning directory and execute the `run_finetune.sh` script using the following command:
29 |
30 | ```bash
31 | cd codeshell/finetune
32 | ./run_finetune.sh $model_name_or_path $dataset_path $save_path
33 | ```
34 |
35 | By following these instructions, you'll be able to fine-tune the pretrained model to better suit your specific downstream tasks.
36 |
37 | This fine-tuning script is adapted from the fine-tuning code of qwen, fastchat and tatsu-lab/stanford_alpaca.
--------------------------------------------------------------------------------
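Both fine-tuning READMEs above describe the same conversation JSON layout. The sketch below is only a convenience illustration: it writes a tiny `train.json` in that layout (the two samples are invented), which can then be passed to `run_finetune.sh` as the dataset path.

```python
# Minimal sketch: build a training file in the conversation format expected by
# finetune.py. The two samples below are invented placeholders.
import json

samples = [
    {
        "id": "identity_0",
        "conversations": [
            {"from": "human", "value": "Write a Python function that reverses a string."},
            {"from": "assistant", "value": "def reverse(s):\n    return s[::-1]"},
        ],
    },
    {
        "id": "identity_1",
        "conversations": [
            {"from": "human", "value": "What does list.sort() return?"},
            {"from": "assistant", "value": "It sorts the list in place and returns None."},
        ],
    },
]

with open("train.json", "w", encoding="utf-8") as fp:
    json.dump(samples, fp, ensure_ascii=False, indent=2)
```
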
/finetune/ds_config_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 |
11 | "bf16": {
12 | "enabled": "auto"
13 | },
14 |
15 | "zero_optimization": {
16 | "stage": 3,
17 | "offload_optimizer": {
18 | "device": "none",
19 | "pin_memory": true
20 | },
21 | "offload_param": {
22 | "device": "none",
23 | "pin_memory": true
24 | },
25 | "overlap_comm": true,
26 | "contiguous_gradients": true,
27 | "sub_group_size": 1e9,
28 | "reduce_bucket_size": "auto",
29 | "stage3_prefetch_bucket_size": "auto",
30 | "stage3_param_persistence_threshold": "auto",
31 | "stage3_max_live_parameters": 1e9,
32 | "stage3_max_reuse_distance": 1e9,
33 | "stage3_gather_16bit_weights_on_model_save": true
34 | },
35 |
36 | "gradient_accumulation_steps": "auto",
37 | "gradient_clipping": "auto",
38 | "steps_per_print": 2000,
39 | "train_batch_size": "auto",
40 | "train_micro_batch_size_per_gpu": "auto",
41 | "wall_clock_breakdown": false
42 | }
43 |
--------------------------------------------------------------------------------
/finetune/finetune.py:
--------------------------------------------------------------------------------
1 | import json
2 | from dataclasses import dataclass, field
3 | from typing import Dict, Optional
4 |
5 | import torch
6 | import transformers
7 | from torch.utils.data import Dataset
8 | from tqdm import tqdm
9 | from transformers.training_args import TrainingArguments
10 |
11 |
12 | @dataclass
13 | class ModelArguments:
14 | model_name_or_path: Optional[str] = field(default="gpt2")
15 |
16 |
17 | @dataclass
18 | class DataArguments:
19 | data_path: str = field(
20 | default=None, metadata={"help": "Path to the training data."}
21 | )
22 |
23 |
24 | @dataclass
25 | class TrainingArguments(transformers.TrainingArguments):
26 | cache_dir: Optional[str] = field(default=None)
27 | optim: str = field(default="adamw_torch")
28 | model_max_length: int = field(
29 | default=512,
30 | metadata={
31 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
32 | },
33 | )
34 | use_lora: bool = field(default=False)
35 |
36 |
37 | class SupervisedDataset(Dataset):
38 | """Dataset for supervised fine-tuning."""
39 |
40 | def __init__(
41 | self,
42 | data_path,
43 | tokenizer,
44 | model_max_length,
45 | user_string="## human:",
46 | copilot_string="## copilot:",
47 | assistant_string="## assistant:",
48 | end_string=" || ",
49 | ):
50 | super(SupervisedDataset, self).__init__()
51 | self.data = json.load(open(data_path))
52 | # self.data = self.data[:1000]
53 | self.tokenizer = tokenizer
54 | self.model_max_length = model_max_length
55 | self.user_string = user_string
56 | self.assistant_string = assistant_string
57 | self.end_string = end_string
58 | self.user_tokens = self.tokenizer.encode(user_string)
59 | self.copilot_tokens = self.tokenizer.encode(copilot_string)
60 | self.assistant_tokens = self.tokenizer.encode(assistant_string)
61 | self.end_tokens = self.tokenizer.encode(end_string)
62 | self.ignore_index = -100
63 |
64 | self.preprocessed_data = self.preprocessing()
65 | item = self.preprocessed_data[0]
66 | print("input:", self.tokenizer.decode(item["input_ids"]))
67 | labels = []
68 | for id_ in item["labels"]:
69 | if id_ == -100:
70 | continue
71 |
72 | labels.append(id_)
73 | print("label:", self.tokenizer.decode(labels))
74 |
75 | def __len__(self):
76 | return len(self.preprocessed_data)
77 |
78 | def preprocessing(self):
79 | preprocessed_data = []
80 | for example in tqdm(self.data, desc="Preprocessing"):
81 | preprocess_example = self.preprocess_one(example)
82 | if len(preprocess_example["input_ids"]) <= 16:
83 | continue
84 | preprocessed_data.append(preprocess_example)
85 | return preprocessed_data
86 |
87 | def preprocess_one(self, example):
88 | input_ids = []
89 | labels = []
90 |
91 | chat_mode = "human"
92 | if "copilot" in [message["from"] for message in example["conversations"]]:
93 | chat_mode = "copilot"
94 |
95 | if chat_mode == "human":
96 | for idx, message in enumerate(example["conversations"]):
97 | if idx == 0:
98 | input_ids += [self.tokenizer.eos_token_id]
99 | labels += [self.ignore_index]
100 | from_ = message["from"]
101 | value = message["value"]
102 | value_ids = self.tokenizer.encode(value)
103 |
104 | if len(input_ids) + len(self.user_tokens + value_ids + self.end_tokens) > self.model_max_length:
105 | break
106 |
107 | if from_ == "human":
108 | input_ids += self.user_tokens + value_ids + self.end_tokens
109 | labels += [self.ignore_index] * len(
110 | self.user_tokens + value_ids + self.end_tokens
111 | )
112 | else:
113 | input_ids += self.assistant_tokens + value_ids + self.end_tokens
114 | labels += [self.ignore_index] * len(self.assistant_tokens) \
115 | + value_ids + self.end_tokens
116 | elif chat_mode == "copilot":
117 | for idx, message in enumerate(example["conversations"]):
118 | from_ = message["from"]
119 | value = message["value"]
120 | value_ids = self.tokenizer.encode(value)
121 |
122 | if len(input_ids) + len(value_ids) > self.model_max_length:
123 | break
124 |
125 | if from_ == "copilot":
126 | input_ids += value_ids
127 | labels += [self.ignore_index] * len(value_ids)
128 | else:
129 | input_ids += value_ids + [self.tokenizer.eos_token_id]
130 | labels += value_ids + [self.tokenizer.eos_token_id]
131 | else:
132 | raise ValueError("chat_mode should be human or copilot")
133 |
134 | input_ids = input_ids[-self.model_max_length:]
135 | labels = labels[-self.model_max_length:]
136 | input_ids = torch.LongTensor(input_ids)
137 | labels = torch.LongTensor(labels)
138 | attention_mask = input_ids.ne(self.tokenizer.eos_token_id)
139 | return {
140 | "input_ids": input_ids,
141 | "labels": labels,
142 | "attention_mask": attention_mask,
143 | }
144 |
145 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
146 | # print(self.preprocessed_data[idx]["input_ids"].shape)
147 | return self.preprocessed_data[idx]
148 |
149 | def print_dataset_example(self, num=3):
150 | for idx in range(num):
151 | example = self.preprocessed_data[idx]
152 | print("input_ids:\n{}".format(example["input_ids"]))
153 | print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
154 | print("label_ids:\n{}".format(example["labels"]))
155 | print("labels:\n{}".format(
156 | self.tokenizer.decode([d if d != self.ignore_index else self.tokenizer.eos_token_id for d in example["labels"]],
157 | skip_special_tokens=False)
158 | ))
159 |
160 |
161 | def train():
162 | parser = transformers.HfArgumentParser(
163 | (ModelArguments, DataArguments, TrainingArguments)
164 | )
165 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
166 |
167 | model = transformers.AutoModelForCausalLM.from_pretrained(
168 | model_args.model_name_or_path,
169 | trust_remote_code=True,
170 | cache_dir=training_args.cache_dir,
171 | )
172 | tokenizer = transformers.AutoTokenizer.from_pretrained(
173 | model_args.model_name_or_path,
174 | use_fast=True,
175 | trust_remote_code=True,
176 | model_max_length=training_args.model_max_length,
177 | )
178 |
179 | # tokenizer.eos_token_id = 70000
180 | # tokenizer.eos_token = "<|endoftext|>"
181 |
182 | if tokenizer.eos_token_id is None:
183 | tokenizer.eos_token_id = tokenizer.bos_token_id
184 | if tokenizer.eos_token is None:
185 | tokenizer.eos_token = tokenizer.bos_token
186 | if tokenizer.pad_token_id is None:
187 | tokenizer.pad_token_id = tokenizer.eos_token_id
188 | if tokenizer.pad_token is None:
189 | tokenizer.pad_token = tokenizer.eos_token
190 |
191 | if training_args.use_lora:
192 | from peft import LoraConfig, TaskType, get_peft_model
193 |
194 | peft_config = LoraConfig(
195 | task_type=TaskType.CAUSAL_LM,
196 | target_modules=["c_attn"],
197 | inference_mode=False,
198 | r=1,
199 | lora_alpha=32,
200 | lora_dropout=0.1,
201 | )
202 | model.enable_input_require_grads()
203 | model = get_peft_model(model, peft_config)
204 | model.print_trainable_parameters()
205 |
206 | dataset = SupervisedDataset(
207 | data_args.data_path, tokenizer, training_args.model_max_length
208 | )
209 | dataset.print_dataset_example()
210 |
211 | trainer = transformers.Trainer(
212 | model=model, args=training_args, train_dataset=dataset, tokenizer=tokenizer
213 | )
214 | trainer.train()
215 | trainer.save_state()
216 | trainer.save_model(output_dir=training_args.output_dir)
217 |
218 |
219 | if __name__ == "__main__":
220 | train()
--------------------------------------------------------------------------------
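The heart of `SupervisedDataset.preprocess_one` above is the label masking for the "human" chat mode: the leading EOS, the user turns and the `## assistant:` marker are set to -100, so only the assistant's reply (plus the ` || ` end marker) is supervised. The toy sketch below makes that pattern visible; the whitespace tokenizer is a stand-in for illustration only, not the real CodeShell tokenizer, and truncation to `model_max_length` is omitted.

```python
# Toy illustration of the -100 label masking used for "human" conversations.
IGNORE_INDEX = -100


class FakeTokenizer:
    """Whitespace 'tokenizer' used only to make the masking pattern visible."""

    def __init__(self):
        self.vocab = {}
        self.eos_token_id = 0

    def encode(self, text):
        return [self.vocab.setdefault(word, len(self.vocab) + 1) for word in text.split()]


tok = FakeTokenizer()
user_tokens = tok.encode("## human:")
assistant_tokens = tok.encode("## assistant:")
end_tokens = tok.encode(" || ")

conversation = [
    {"from": "human", "value": "print hello"},
    {"from": "assistant", "value": "print ( 'hello' )"},
]

input_ids, labels = [tok.eos_token_id], [IGNORE_INDEX]
for message in conversation:
    value_ids = tok.encode(message["value"])
    if message["from"] == "human":
        chunk = user_tokens + value_ids + end_tokens
        input_ids += chunk
        labels += [IGNORE_INDEX] * len(chunk)  # user turns never contribute to the loss
    else:
        input_ids += assistant_tokens + value_ids + end_tokens
        labels += [IGNORE_INDEX] * len(assistant_tokens) + value_ids + end_tokens

print(input_ids)
print(labels)  # -100 everywhere except the assistant's reply and its end marker
```
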
/finetune/run_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export WANDB_DISABLED=true
4 |
5 | project_dir=$(cd "$(dirname $0)"; pwd)
6 |
7 | model=$1
8 | data_path=$2
9 | exp_id=$3
10 |
11 | output_dir=${project_dir}/output_models/${exp_id}
12 | log_dir=${project_dir}/log/${exp_id}
13 | mkdir -p ${output_dir} ${log_dir}
14 |
15 | # deepspeed_args="--master_port=23333 --hostfile=${project_dir}/configs/hostfile.txt --master_addr=10.0.0.16" # Default argument
16 | # deepspeed_args="--master_port=$((10000 + RANDOM % 20000)) --include=localhost:0,1,2,3" # Default argument
17 | deepspeed_args="--master_port=$((10000 + RANDOM % 20000))" # Default argument
18 |
19 | deepspeed ${deepspeed_args} ${project_dir}/finetune.py \
20 | --deepspeed ${project_dir}/ds_config_zero3.json \
21 | --model_name_or_path ${model} \
22 | --data_path ${data_path} \
23 | --model_max_length 4096 \
24 | --output_dir ${output_dir} \
25 | --per_device_train_batch_size 1 \
26 | --gradient_accumulation_steps 8 \
27 | --gradient_checkpointing True \
28 | --lr_scheduler_type cosine \
29 | --logging_steps 1 \
30 | --save_steps 100 \
31 | --learning_rate 2e-5 \
32 | --num_train_epochs 3 \
33 | --bf16 \
34 |     2> ${log_dir}/train.err \
35 |     | tee ${log_dir}/train.log
36 |
--------------------------------------------------------------------------------
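After `run_finetune.sh` finishes, the checkpoint lands in `finetune/output_models/<exp_id>`. The sketch below shows one way to load it back; the paths and base-model id are placeholders, full fine-tuning is assumed to produce a complete Transformers checkpoint, and the LoRA branch assumes the peft adapter written when `finetune.py` is launched with `--use_lora True`.

```python
# Sketch only: load what run_finetune.sh leaves in output_models/<exp_id>.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "WisdomShell/CodeShell-7B"        # placeholder
output_dir = "finetune/output_models/my_exp"   # placeholder: the exp_id used above
use_lora = False

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

if not use_lora:
    # Full fine-tuning: the output directory is assumed to be a complete checkpoint.
    model = AutoModelForCausalLM.from_pretrained(
        output_dir, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
else:
    # LoRA fine-tuning: attach the saved adapter to the base model, then merge it.
    from peft import PeftModel

    model = AutoModelForCausalLM.from_pretrained(
        base_model, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    model = PeftModel.from_pretrained(model, output_dir)
    model = model.merge_and_unload()  # fold the LoRA deltas back into the base weights

model.eval()
```
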
/model/configuration_codeshell.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # This code is based on BigCode's GPTBigCode configuration. It has been modified from
17 | # its original form to accommodate minor architectural differences between CodeShell
18 | # and the GPTBigCode configuration it is derived from.
19 |
20 | # coding=utf-8
21 | # Copyright 2023 The BigCode team and HuggingFace Inc. team.
22 | #
23 | # Licensed under the Apache License, Version 2.0 (the "License");
24 | # you may not use this file except in compliance with the License.
25 | # You may obtain a copy of the License at
26 | #
27 | # http://www.apache.org/licenses/LICENSE-2.0
28 | #
29 | # Unless required by applicable law or agreed to in writing, software
30 | # distributed under the License is distributed on an "AS IS" BASIS,
31 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32 | # See the License for the specific language governing permissions and
33 | # limitations under the License.
34 | """ CodeShell configuration"""
35 |
36 | from transformers.configuration_utils import PretrainedConfig
37 | from transformers.utils import logging
38 |
39 |
40 | logger = logging.get_logger(__name__)
41 |
42 |
43 | class CodeShellConfig(PretrainedConfig):
44 | """
45 | This is the configuration class to store the configuration of a [`CodeShellModel`]. It is used to instantiate a
46 | CodeShell model according to the specified arguments, defining the model architecture.
47 |
48 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
49 | documentation from [`PretrainedConfig`] for more information.
50 |
51 | Args:
52 |         vocab_size (`int`, *optional*, defaults to 70144):
53 |             Vocabulary size of the CodeShell model. Defines the number of different tokens that can be represented by the
54 |             `input_ids` passed when calling [`CodeShellModel`].
55 |         n_positions (`int`, *optional*, defaults to 8192):
56 |             The maximum sequence length that this model might ever be used with. Typically set this to something large
57 |             just in case (e.g., 2048, 4096 or 8192).
58 |         n_embd (`int`, *optional*, defaults to 4096):
59 |             Dimensionality of the embeddings and hidden states.
60 |         n_layer (`int`, *optional*, defaults to 42):
61 |             Number of hidden layers in the Transformer decoder.
62 |         n_head (`int`, *optional*, defaults to 32):
63 |             Number of attention heads for each attention layer in the Transformer decoder.
64 | n_inner (`int`, *optional*, defaults to None):
65 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
66 | activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
67 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new",
68 | "gelu_pytorch_tanh"]`.
69 | resid_pdrop (`float`, *optional*, defaults to 0.1):
70 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
71 | embd_pdrop (`float`, *optional*, defaults to 0.1):
72 | The dropout ratio for the embeddings.
73 | attn_pdrop (`float`, *optional*, defaults to 0.1):
74 | The dropout ratio for the attention.
75 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
76 | The epsilon to use in the layer normalization layers.
77 | initializer_range (`float`, *optional*, defaults to 0.02):
78 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
79 | scale_attn_weights (`bool`, *optional*, defaults to `True`):
80 |             Scale attention weights by dividing by sqrt(hidden_size).
81 | use_cache (`bool`, *optional*, defaults to `True`):
82 | Whether or not the model should return the last key/values attentions (not used by all models).
83 | attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
84 | Whether to call the fused softmax in float32.
85 | scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
86 | Whether to scale the attention softmax in float32.
87 |         group_query_attention (`bool`, *optional*, defaults to `True`):
88 |             Whether to use Grouped-Query Attention (`True`) or Multi-Head Attention (`False`).
89 | """
90 |
91 | model_type = "codeshell"
92 | keys_to_ignore_at_inference = ["past_key_values"]
93 | attribute_map = {
94 | "hidden_size": "n_embd",
95 | "max_position_embeddings": "n_positions",
96 | "num_attention_heads": "n_head",
97 | "num_hidden_layers": "n_layer",
98 | }
99 |
100 | def __init__(
101 | self,
102 | vocab_size=70144,
103 | n_positions=8192,
104 | n_embd=4096,
105 | n_layer=42,
106 | n_head=32,
107 | n_inner=None,
108 | activation_function="gelu_pytorch_tanh",
109 | resid_pdrop=0.1,
110 | embd_pdrop=0.1,
111 | attn_pdrop=0.1,
112 | layer_norm_epsilon=1e-5,
113 | initializer_range=0.02,
114 | scale_attn_weights=True,
115 | use_cache=True,
116 | bos_token_id=70000,
117 | eos_token_id=70000,
118 | attention_softmax_in_fp32=True,
119 | scale_attention_softmax_in_fp32=True,
120 | group_query_attention=True,
121 | num_query_groups=1,
122 | position_embedding_type="learned_absolute",
123 | rope_scaling=None,
124 | **kwargs,
125 | ):
126 | self.vocab_size = vocab_size
127 | self.n_positions = n_positions
128 | self.n_embd = n_embd
129 | self.n_layer = n_layer
130 | self.n_head = n_head
131 | self.n_inner = n_inner
132 | self.activation_function = activation_function
133 | self.resid_pdrop = resid_pdrop
134 | self.embd_pdrop = embd_pdrop
135 | self.attn_pdrop = attn_pdrop
136 | self.layer_norm_epsilon = layer_norm_epsilon
137 | self.initializer_range = initializer_range
138 | self.scale_attn_weights = scale_attn_weights
139 | self.use_cache = use_cache
140 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32
141 | self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
142 | self.group_query_attention = group_query_attention
143 | self.num_query_groups = num_query_groups
144 | self.position_embedding_type = position_embedding_type
145 | self.rope_scaling = rope_scaling
146 | assert self.position_embedding_type in [
147 | "learned_absolute", "rope"
148 | ], "position_embedding_type must be one of ['learned_absolute', 'rope']"
149 |
150 | self.bos_token_id = bos_token_id
151 | self.eos_token_id = eos_token_id
152 |
153 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
154 |
--------------------------------------------------------------------------------
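As a usage note for the configuration above: the sketch below instantiates `CodeShellConfig` directly and reads a few attribute-mapped fields. It assumes `configuration_codeshell.py` is importable (for example, when run from the `model/` directory); when loading from the Hub with `trust_remote_code=True`, `AutoConfig.from_pretrained` should resolve to the same class.

```python
# Sketch only: inspect the default CodeShell architecture hyper-parameters.
# Assumes configuration_codeshell.py is on the import path (e.g. run from model/).
from configuration_codeshell import CodeShellConfig

config = CodeShellConfig()  # defaults: 42 layers, hidden size 4096, 8192 positions
print(config.n_layer, config.n_embd, config.n_positions)

# attribute_map exposes the standard transformers names as aliases:
print(config.hidden_size == config.n_embd)                   # True
print(config.max_position_embeddings == config.n_positions)  # True

# RoPE with linear scaling, in the shape CodeShellAttention._init_rope (below) reads:
rope_config = CodeShellConfig(
    position_embedding_type="rope",
    rope_scaling={"type": "linear", "factor": 2.0},
)
print(rope_config.position_embedding_type, rope_config.rope_scaling)
```
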
/model/modeling_codeshell.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # This code is based on BigCode's GPTBigCode model. It has been modified from
17 | # its original form to accommodate minor architectural differences between CodeShell
18 | # and the GPTBigCode model it is derived from.
19 |
20 | # Copyright 2023 The Bigcode team and HuggingFace Inc. team.
21 | # Licensed under the Apache License, Version 2.0 (the "License");
22 | # you may not use this file except in compliance with the License.
23 | # You may obtain a copy of the License at
24 | #
25 | # http://www.apache.org/licenses/LICENSE-2.0
26 | #
27 | # Unless required by applicable law or agreed to in writing, software
28 | # distributed under the License is distributed on an "AS IS" BASIS,
29 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
30 | # See the License for the specific language governing permissions and
31 | # limitations under the License.
32 | """PyTorch CodeShell model."""
33 | import os
34 | import math
35 | from typing import List, Optional, Tuple, Union, Callable
36 | from threading import Thread
37 | from queue import Queue
38 |
39 |
40 | import torch
41 | import torch.utils.checkpoint
42 | from torch import nn
43 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
44 |
45 | from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
46 | from transformers.generation.utils import GenerationConfig
47 |
48 | from transformers.activations import ACT2FN
49 | from transformers.modeling_outputs import (
50 | BaseModelOutputWithPastAndCrossAttentions,
51 | CausalLMOutputWithCrossAttentions,
52 | )
53 | from transformers.modeling_utils import PreTrainedModel
54 | from transformers.utils import (
55 | add_start_docstrings,
56 | add_start_docstrings_to_model_forward,
57 | )
58 | from .configuration_codeshell import CodeShellConfig
59 |
60 | # Fused kernels
61 | # Use separate functions for each case because conditionals prevent kernel fusion.
62 | # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
63 | # Is it doable without writing 32 functions?
64 | @torch.jit.script
65 | def upcast_masked_softmax(
66 | x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
67 | ):
68 | input_dtype = x.dtype
69 | x = x.to(softmax_dtype) * scale
70 | x = torch.where(mask, x, mask_value)
71 | x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
72 | return x
73 |
74 |
75 | @torch.jit.script
76 | def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
77 | input_dtype = x.dtype
78 | x = x.to(softmax_dtype) * scale
79 | x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype)
80 | return x
81 |
82 |
83 | @torch.jit.script
84 | def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
85 | x = torch.where(mask, x, mask_value)
86 | x = torch.nn.functional.softmax(x, dim=-1)
87 | return x
88 |
89 |
90 | class CodeShellRotaryEmbedding(torch.nn.Module):
91 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
92 | super().__init__()
93 |
94 | self.dim = dim
95 | self.max_position_embeddings = max_position_embeddings
96 | self.base = base
97 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
98 | self.register_buffer("inv_freq", inv_freq)
99 |
100 | # Build here to make `torch.jit.trace` work.
101 | self._set_cos_sin_cache(
102 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
103 | )
104 |
105 | def _set_cos_sin_cache(self, seq_len, device, dtype):
106 | self.max_seq_len_cached = seq_len
107 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
108 |
109 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
110 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
111 | emb = torch.cat((freqs, freqs), dim=-1)
112 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
113 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
114 |
115 | def forward(self, x, seq_len=None):
116 | # x: [bs, num_attention_heads, seq_len, head_size]
117 | if seq_len > self.max_seq_len_cached:
118 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
119 |
120 | return (
121 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
122 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
123 | )
124 |
125 |
126 | class CodeShellLinearScalingRotaryEmbedding(CodeShellRotaryEmbedding):
127 | """CodeShellRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
128 |
129 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
130 | self.scaling_factor = scaling_factor
131 | super().__init__(dim, max_position_embeddings, base, device)
132 |
133 | def _set_cos_sin_cache(self, seq_len, device, dtype):
134 | self.max_seq_len_cached = seq_len
135 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
136 | t = t / self.scaling_factor
137 |
138 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
139 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
140 | emb = torch.cat((freqs, freqs), dim=-1)
141 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
142 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
143 |
144 |
145 | class CodeShellDynamicNTKScalingRotaryEmbedding(CodeShellRotaryEmbedding):
146 |     """CodeShellRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
147 |
148 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
149 | self.scaling_factor = scaling_factor
150 | super().__init__(dim, max_position_embeddings, base, device)
151 |
152 | def _set_cos_sin_cache(self, seq_len, device, dtype):
153 | self.max_seq_len_cached = seq_len
154 |
155 | if seq_len > self.max_position_embeddings:
156 | base = self.base * (
157 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
158 | ) ** (self.dim / (self.dim - 2))
159 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
160 | self.register_buffer("inv_freq", inv_freq)
161 |
162 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
163 |
164 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
165 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
166 | emb = torch.cat((freqs, freqs), dim=-1)
167 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
168 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
169 |
170 | def rotate_half(x):
171 | """Rotates half the hidden dims of the input."""
172 | x1 = x[..., : x.shape[-1] // 2]
173 | x2 = x[..., x.shape[-1] // 2 :]
174 | return torch.cat((-x2, x1), dim=-1)
175 |
176 |
177 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
178 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
179 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
180 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
181 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
182 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
183 | q_embed = (q * cos) + (rotate_half(q) * sin)
184 | k_embed = (k * cos) + (rotate_half(k) * sin)
185 | return q_embed, k_embed
186 |
187 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
188 | """
189 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
190 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
191 | """
192 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
193 | if n_rep == 1:
194 | return hidden_states
195 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
196 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
197 |
198 | class CodeShellAttention(nn.Module):
199 | def __init__(self, config, layer_idx=None):
200 | super().__init__()
201 | self.mask_value = None
202 |
203 | self.position_embedding_type = config.position_embedding_type
204 | self.rope_scaling = config.rope_scaling
205 | self.max_position_embeddings = config.max_position_embeddings
206 |
207 | self.group_query_attention = config.group_query_attention
208 | self.num_query_groups = config.num_query_groups
209 | self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
210 |
211 | self.embed_dim = config.hidden_size
212 | self.num_heads = config.num_attention_heads
213 | self.head_dim = self.embed_dim // self.num_heads
214 | self.kv_heads = config.num_query_groups if self.group_query_attention else self.num_heads
215 | self.kv_dim = self.kv_heads * self.head_dim
216 | self.split_size = self.embed_dim
217 | if self.head_dim * self.num_heads != self.embed_dim:
218 | raise ValueError(
219 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
220 | f" {self.num_heads})."
221 | )
222 |
223 | self.layer_idx = layer_idx
224 |
225 | self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
226 | self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
227 |
228 | self.attn_dropout = nn.Dropout(config.attn_pdrop)
229 | self.resid_dropout = nn.Dropout(config.resid_pdrop)
230 |
231 | if self.position_embedding_type == "rope":
232 | self._init_rope()
233 |
234 | def _init_rope(self):
235 | if self.rope_scaling is None:
236 | self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
237 | else:
238 | scaling_type = self.rope_scaling["type"]
239 | scaling_factor = self.rope_scaling["factor"]
240 | if scaling_type == "linear":
241 | self.rotary_emb = CodeShellLinearScalingRotaryEmbedding(
242 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
243 | )
244 | elif scaling_type == "dynamic":
245 | self.rotary_emb = CodeShellDynamicNTKScalingRotaryEmbedding(
246 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
247 | )
248 | else:
249 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
250 |
251 |
252 | def _get_mask_value(self, device, dtype):
253 | # torch.where expects a tensor. We use a cache to avoid recreating it every time.
254 | if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
255 | self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
256 | return self.mask_value
257 |
258 | def forward(
259 | self,
260 | hidden_states: torch.Tensor,
261 | layer_past: Optional[torch.Tensor] = None,
262 | attention_mask: Optional[torch.Tensor] = None,
263 | position_ids: Optional[torch.LongTensor] = None,
264 | head_mask: Optional[torch.Tensor] = None,
265 | use_cache: Optional[bool] = False,
266 | output_attentions: Optional[bool] = False,
267 | ) -> Union[
268 | Tuple[torch.Tensor, Optional[torch.Tensor]],
269 | Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
270 | ]:
271 | bsz, q_len, _ = hidden_states.size()
272 | query_states, key_states, value_states = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2)
273 |
274 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
275 | key_states = key_states.view(bsz, q_len, self.num_query_groups, self.head_dim).transpose(1, 2)
276 | value_states = value_states.view(bsz, q_len, self.num_query_groups, self.head_dim).transpose(1, 2)
277 |
278 | kv_seq_len = key_states.shape[-2]
279 | if layer_past is not None:
280 | kv_seq_len += layer_past[0].shape[-2]
281 |         if self.position_embedding_type == "rope":
282 |             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
283 |             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
284 | if layer_past is not None:
285 | # reuse k, v, self_attention
286 | key_states = torch.cat([layer_past[0], key_states], dim=2)
287 | value_states = torch.cat([layer_past[1], value_states], dim=2)
288 |
289 | layer_past = (key_states, value_states) if use_cache else None
290 |
291 | # repeat k/v heads if n_kv_heads < n_heads
292 | key_states = repeat_kv(key_states, self.num_heads // self.kv_heads)
293 | value_states = repeat_kv(value_states, self.num_heads // self.kv_heads)
294 |
295 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
296 |
297 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
298 | raise ValueError(
299 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
300 | f" {attn_weights.size()}"
301 | )
302 |
303 | if attention_mask is not None:
304 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
305 | raise ValueError(
306 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
307 | )
308 | mask_value = self._get_mask_value(attn_weights.device, attn_weights.dtype)
309 | # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
310 | attn_weights = torch.where(attention_mask, attn_weights, mask_value)
311 |
312 | # upcast attention to fp32
313 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
314 | attn_weights = self.attn_dropout(attn_weights)
315 | attn_output = torch.matmul(attn_weights, value_states)
316 |
317 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
318 | raise ValueError(
319 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
320 | f" {attn_output.size()}"
321 | )
322 |
323 | attn_output = attn_output.transpose(1, 2).contiguous()
324 | attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
325 |
326 | attn_output = self.c_proj(attn_output)
327 | attn_output = self.resid_dropout(attn_output)
328 |
329 | outputs = (attn_output, layer_past)
330 | if output_attentions:
331 | outputs += (attn_weights,)
332 |
333 | return outputs # a, present, (attentions)
334 |
335 |
336 | class CodeShellMLP(nn.Module):
337 | def __init__(self, intermediate_size, config):
338 | super().__init__()
339 | embed_dim = config.hidden_size
340 | self.c_fc = nn.Linear(embed_dim, intermediate_size)
341 | self.c_proj = nn.Linear(intermediate_size, embed_dim)
342 | self.act = ACT2FN[config.activation_function]
343 | self.dropout = nn.Dropout(config.resid_pdrop)
344 |
345 | # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward
346 | def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
347 | hidden_states = self.c_fc(hidden_states)
348 | hidden_states = self.act(hidden_states)
349 | hidden_states = self.c_proj(hidden_states)
350 | hidden_states = self.dropout(hidden_states)
351 | return hidden_states
352 |
353 |
354 | class CodeShellBlock(nn.Module):
355 | def __init__(self, config, layer_idx=None):
356 | super().__init__()
357 | hidden_size = config.hidden_size
358 | self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
359 |
360 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
361 | self.attn = CodeShellAttention(config, layer_idx=layer_idx)
362 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
363 |
364 | self.mlp = CodeShellMLP(self.inner_dim, config)
365 |
366 | def forward(
367 | self,
368 | hidden_states: Optional[Tuple[torch.Tensor]],
369 | layer_past: Optional[torch.Tensor] = None,
370 | attention_mask: Optional[torch.Tensor] = None,
371 | position_ids: Optional[torch.LongTensor] = None,
372 | head_mask: Optional[torch.Tensor] = None,
373 | encoder_hidden_states: Optional[torch.Tensor] = None,
374 | encoder_attention_mask: Optional[torch.Tensor] = None,
375 | use_cache: Optional[bool] = False,
376 | output_attentions: Optional[bool] = False,
377 | ) -> Union[
378 | Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
379 | ]:
380 | residual = hidden_states
381 | hidden_states = self.ln_1(hidden_states)
382 | attn_outputs = self.attn(
383 | hidden_states,
384 | layer_past=layer_past,
385 | attention_mask=attention_mask,
386 | position_ids=position_ids,
387 | head_mask=head_mask,
388 | use_cache=use_cache,
389 | output_attentions=output_attentions,
390 | )
391 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
392 |
393 | outputs = attn_outputs[1:]
394 | # residual connection
395 | hidden_states = attn_output + residual
396 |
397 | residual = hidden_states
398 | hidden_states = self.ln_2(hidden_states)
399 | feed_forward_hidden_states = self.mlp(hidden_states)
400 | # residual connection
401 | hidden_states = residual + feed_forward_hidden_states
402 |
403 | if use_cache:
404 | outputs = (hidden_states,) + outputs
405 | else:
406 | outputs = (hidden_states,) + outputs[1:]
407 |
408 | return outputs # hidden_states, present, (attentions, cross_attentions)
409 |
410 |
411 | class CodeShellPreTrainedModel(PreTrainedModel):
412 | """
413 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
414 | models.
415 | """
416 |
417 | config_class = CodeShellConfig
418 | base_model_prefix = "transformer"
419 | supports_gradient_checkpointing = True
420 |     _no_split_modules = ["CodeShellBlock"]
421 | _skip_keys_device_placement = "past_key_values"
422 |
423 | def __init__(self, *inputs, **kwargs):
424 | super().__init__(*inputs, **kwargs)
425 |
426 | def _init_weights(self, module):
427 | """Initialize the weights."""
428 | if isinstance(module, (CodeShellMLP, CodeShellAttention)):
429 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
430 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
431 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
432 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/
433 | #
434 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
435 | module.c_proj.weight.data.normal_(
436 | mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
437 | )
438 | module.c_proj._is_hf_initialized = True
439 | elif isinstance(module, nn.Linear):
440 | # Slightly different from the TF version which uses truncated_normal for initialization
441 | # cf https://github.com/pytorch/pytorch/pull/5617
442 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
443 | if module.bias is not None:
444 | module.bias.data.zero_()
445 | elif isinstance(module, nn.Embedding):
446 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
447 | if module.padding_idx is not None:
448 | module.weight.data[module.padding_idx].zero_()
449 | elif isinstance(module, nn.LayerNorm):
450 | module.bias.data.zero_()
451 | module.weight.data.fill_(1.0)
452 |
453 | # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->Shell
454 | def _set_gradient_checkpointing(self, module, value=False):
455 | if isinstance(module, CodeShellModel):
456 | module.gradient_checkpointing = value
457 |
458 |
459 | GPT_BIGCODE_START_DOCSTRING = r"""
460 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
461 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
462 | etc.)
463 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
464 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
465 | and behavior.
466 | Parameters:
467 | config ([`CodeShellConfig`]): Model configuration class with all the parameters of the model.
468 | Initializing with a config file does not load the weights associated with the model, only the
469 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
470 | """
471 |
472 | GPT_BIGCODE_INPUTS_DOCSTRING = r"""
473 | Args:
474 | input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
475 | `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
476 | `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
477 | sequence tokens in the vocabulary.
478 | If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
479 | `input_ids`.
480 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
481 | [`PreTrainedTokenizer.__call__`] for details.
482 | [What are input IDs?](../glossary#input-ids)
483 | past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`):
484 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
485 | `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
486 | their past given to this model should not be passed as `input_ids` as they have already been computed.
487 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
488 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
489 | - 1 for tokens that are **not masked**,
490 | - 0 for tokens that are **masked**.
491 | If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
492 | `past_key_values`. In other words, the `attention_mask` always has to have the length:
493 | `len(past_key_values) + len(input_ids)`
494 | [What are attention masks?](../glossary#attention-mask)
495 | token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
496 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
497 | 1]`:
498 | - 0 corresponds to a *sentence A* token,
499 | - 1 corresponds to a *sentence B* token.
500 | [What are token type IDs?](../glossary#token-type-ids)
501 | position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
502 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
503 | config.max_position_embeddings - 1]`.
504 | [What are position IDs?](../glossary#position-ids)
505 | head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
506 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
507 | - 1 indicates the head is **not masked**,
508 | - 0 indicates the head is **masked**.
509 | inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
510 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
511 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
512 | model's internal embedding lookup matrix.
513 | If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
514 | `past_key_values`).
515 | use_cache (`bool`, *optional*):
516 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
517 | `past_key_values`).
518 | output_attentions (`bool`, *optional*):
519 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
520 | tensors for more detail.
521 | output_hidden_states (`bool`, *optional*):
522 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
523 | more detail.
524 | return_dict (`bool`, *optional*):
525 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
526 | """
527 |
528 |
529 | @add_start_docstrings(
530 | "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.",
531 | GPT_BIGCODE_START_DOCSTRING,
532 | )
533 | class CodeShellModel(CodeShellPreTrainedModel):
534 | def __init__(self, config):
535 | super().__init__(config)
536 | self.group_query_attention = config.group_query_attention
537 | self.num_query_groups = config.num_query_groups
538 | self.position_embedding_type = config.position_embedding_type
539 | self.embed_dim = config.hidden_size
540 |
541 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
542 | if self.position_embedding_type == "learned_absolute":
543 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
544 | else:
545 | pass
546 |
547 | self.drop = nn.Dropout(config.embd_pdrop)
548 | self.h = nn.ModuleList([CodeShellBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
549 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
550 |
551 | max_positions = config.max_position_embeddings
552 | self.register_buffer(
553 | "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
554 | )
555 |
556 | self.gradient_checkpointing = False
557 |
558 | # Initialize weights and apply final processing
559 | self.post_init()
560 |
561 | def get_input_embeddings(self):
562 | return self.wte
563 |
564 | def set_input_embeddings(self, new_embeddings):
565 | self.wte = new_embeddings
566 |
567 | @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
568 | def forward(
569 | self,
570 | input_ids: Optional[torch.Tensor] = None,
571 | past_key_values: Optional[List[torch.Tensor]] = None,
572 | attention_mask: Optional[torch.Tensor] = None,
573 | token_type_ids: Optional[torch.Tensor] = None,
574 | position_ids: Optional[torch.Tensor] = None,
575 | head_mask: Optional[torch.Tensor] = None,
576 | inputs_embeds: Optional[torch.Tensor] = None,
577 | encoder_hidden_states: Optional[torch.Tensor] = None,
578 | encoder_attention_mask: Optional[torch.Tensor] = None,
579 | use_cache: Optional[bool] = None,
580 | output_attentions: Optional[bool] = None,
581 | output_hidden_states: Optional[bool] = None,
582 | return_dict: Optional[bool] = None,
583 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
584 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
585 | output_hidden_states = (
586 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
587 | )
588 | use_cache = use_cache if use_cache is not None else self.config.use_cache
589 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
590 |
591 | if input_ids is not None and inputs_embeds is not None:
592 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
593 | elif input_ids is not None:
594 | input_shape = input_ids.size()
595 | input_ids = input_ids.reshape(-1, input_shape[-1])
596 | batch_size = input_ids.shape[0]
597 | elif inputs_embeds is not None:
598 | input_shape = inputs_embeds.size()[:-1]
599 | batch_size = inputs_embeds.shape[0]
600 | else:
601 | raise ValueError("You have to specify either input_ids or inputs_embeds")
602 |
603 | if batch_size <= 0:
604 | raise ValueError("batch_size has to be defined and > 0")
605 |
606 | device = input_ids.device if input_ids is not None else inputs_embeds.device
607 |
608 | if token_type_ids is not None:
609 | token_type_ids = token_type_ids.reshape(-1, input_shape[-1])
610 | if position_ids is not None:
611 | position_ids = position_ids.reshape(-1, input_shape[-1])
612 |
613 | if past_key_values is None:
614 | past_length = 0
615 | past_key_values = tuple([None] * len(self.h))
616 | else:
617 | past_length = past_key_values[0][0].size(-2)
618 |
619 | if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
620 | # create position_ids on the fly for batch generation
621 | position_ids = attention_mask.long().cumsum(-1) - 1
622 | position_ids.masked_fill_(attention_mask == 0, 1)
623 | if past_length > 0:
624 |                 position_ids = position_ids[:, past_length : input_shape[-1] + past_length]
625 | elif position_ids is None:
626 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
627 | position_ids = position_ids.unsqueeze(0).reshape(-1, input_shape[-1])
628 |
629 | # Self-attention mask.
630 | query_length = input_shape[-1]
631 | key_length = past_length + query_length
632 | self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
633 |
634 | if attention_mask is not None:
635 | self_attention_mask = self_attention_mask * attention_mask.reshape(batch_size, 1, -1).to(
636 | dtype=torch.bool, device=self_attention_mask.device
637 | )
638 |
639 | # MQA models: (batch_size, query_length, n_heads, key_length)
640 | # MHA models: (batch_size, n_heads, query_length, key_length)
641 | attention_mask = self_attention_mask.unsqueeze(1)
642 |
643 | encoder_attention_mask = None
644 |
645 | # Prepare head mask if needed
646 | # 1.0 in head_mask indicate we keep the head
647 | # attention_probs has shape bsz x n_heads x N x N
648 | # head_mask has shape n_layer x batch x n_heads x N x N
649 | head_mask = self.get_head_mask(head_mask, self.config.n_layer)
650 |
651 | if inputs_embeds is None:
652 | inputs_embeds = self.wte(input_ids)
653 |
654 | hidden_states = inputs_embeds
655 | if self.position_embedding_type == "learned_absolute":
656 | position_embeds = self.wpe(position_ids)
657 | hidden_states = hidden_states + position_embeds
658 |
659 | if token_type_ids is not None:
660 | token_type_embeds = self.wte(token_type_ids)
661 | hidden_states = hidden_states + token_type_embeds
662 |
663 | hidden_states = self.drop(hidden_states)
664 |
665 | output_shape = input_shape + (hidden_states.size(-1),)
666 |
667 | presents = [] if use_cache else None
668 | all_self_attentions = () if output_attentions else None
669 | all_hidden_states = () if output_hidden_states else None
670 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
671 | if output_hidden_states:
672 | all_hidden_states = all_hidden_states + (hidden_states,)
673 |
674 | if self.gradient_checkpointing and self.training:
675 |
676 | def create_custom_forward(module):
677 | def custom_forward(*inputs):
678 | # None for past_key_value
679 | return module(*inputs, use_cache, output_attentions)
680 |
681 | return custom_forward
682 |
683 | outputs = torch.utils.checkpoint.checkpoint(
684 | create_custom_forward(block),
685 | hidden_states,
686 | None,
687 | attention_mask,
688 | position_ids,
689 | head_mask[i],
690 | encoder_hidden_states,
691 | encoder_attention_mask,
692 | )
693 | else:
694 | outputs = block(
695 | hidden_states,
696 | layer_past=layer_past,
697 | attention_mask=attention_mask,
698 | position_ids=position_ids,
699 | head_mask=head_mask[i],
700 | encoder_hidden_states=encoder_hidden_states,
701 | encoder_attention_mask=encoder_attention_mask,
702 | use_cache=use_cache,
703 | output_attentions=output_attentions,
704 | )
705 |
706 | hidden_states = outputs[0]
707 | if use_cache:
708 | presents.append(outputs[1])
709 |
710 | if output_attentions:
711 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
712 |
713 | hidden_states = self.ln_f(hidden_states)
714 | hidden_states = hidden_states.reshape(output_shape)
715 | # Add last hidden state
716 | if output_hidden_states:
717 | all_hidden_states = all_hidden_states + (hidden_states,)
718 |
719 |
720 | if not return_dict:
721 | return tuple(
722 | v
723 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
724 | if v is not None
725 | )
726 |
727 | return BaseModelOutputWithPastAndCrossAttentions(
728 | last_hidden_state=hidden_states,
729 | past_key_values=presents,
730 | hidden_states=all_hidden_states,
731 | attentions=all_self_attentions,
732 | )
733 |
734 | class EndOfFunctionCriteria(StoppingCriteria):
735 | """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
736 | def __init__(self, input_lengths, eof_strings, tokenizer):
737 | self.input_lengths = input_lengths
738 | self.eof_strings = eof_strings
739 | self.tokenizer = tokenizer
740 |
741 | def __call__(self, input_ids, scores, **kwargs):
742 | """Returns true if all generated sequences contain any of the end-of-function strings."""
743 | decoded_generations = []
744 | for _input_ids, input_length in zip(input_ids, self.input_lengths):
745 | decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:]))
746 | done = []
747 | for decoded_generation in decoded_generations:
748 | done.append(
749 | any(
750 | [
751 | stop_string in decoded_generation
752 | for stop_string in self.eof_strings
753 | ]
754 | )
755 | )
756 | return all(done)
757 |
758 | class TextIterStreamer:
759 | def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
760 | self.tokenizer = tokenizer
761 | self.skip_prompt = skip_prompt
762 | self.skip_special_tokens = skip_special_tokens
763 | self.tokens = []
764 | self.text_queue = Queue()
765 | self.next_tokens_are_prompt = True
766 |
767 | def put(self, value):
768 | if self.skip_prompt and self.next_tokens_are_prompt:
769 | self.next_tokens_are_prompt = False
770 | else:
771 | if len(value.shape) > 1:
772 | value = value[0]
773 | self.tokens.extend(value.tolist())
774 | self.text_queue.put(
775 | self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
776 |
777 | def end(self):
778 | self.text_queue.put(None)
779 |
780 | def __iter__(self):
781 | return self
782 |
783 | def __next__(self):
784 | value = self.text_queue.get()
785 | if value is None:
786 | raise StopIteration()
787 | else:
788 | return value
789 |
790 |
791 | @add_start_docstrings(
792 | """
793 | The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
794 | embeddings).
795 | """,
796 | GPT_BIGCODE_START_DOCSTRING,
797 | )
798 | class CodeShellForCausalLM(CodeShellPreTrainedModel):
799 | _tied_weights_keys = ["lm_head.weight"]
800 |
801 | def __init__(self, config):
802 | super().__init__(config)
803 | self.transformer = CodeShellModel(config)
804 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
805 |
806 | # Initialize weights and apply final processing
807 | self.post_init()
808 |
809 | def quantize(self, bits: int):
810 | try:
811 | import bitsandbytes
812 |             from .quantizer import quantize_online
813 |         except ImportError:
814 |             raise ImportError("Needs bitsandbytes to run quantize.")
815 |         return quantize_online(self, bits)
816 |
817 | def get_output_embeddings(self):
818 | return self.lm_head
819 |
820 | def set_output_embeddings(self, new_embeddings):
821 | self.lm_head = new_embeddings
822 |
823 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
824 | token_type_ids = kwargs.get("token_type_ids", None)
825 | # only last token for inputs_ids if past is defined in kwargs
826 | if past_key_values:
827 | input_ids = input_ids[:, -1].unsqueeze(-1)
828 | if token_type_ids is not None:
829 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
830 |
831 | attention_mask = kwargs.get("attention_mask", None)
832 | position_ids = kwargs.get("position_ids", None)
833 |
834 | if attention_mask is not None and position_ids is None:
835 | # create position_ids on the fly for batch generation
836 | position_ids = attention_mask.long().cumsum(-1) - 1
837 | position_ids.masked_fill_(attention_mask == 0, 1)
838 | if past_key_values:
839 | position_ids = position_ids[:, -1].unsqueeze(-1)
840 | else:
841 | position_ids = None
842 |
843 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
844 | if inputs_embeds is not None and past_key_values is None:
845 | model_inputs = {"inputs_embeds": inputs_embeds}
846 | else:
847 | model_inputs = {"input_ids": input_ids}
848 |
849 | model_inputs.update(
850 | {
851 | "past_key_values": past_key_values,
852 | "use_cache": kwargs.get("use_cache"),
853 | "position_ids": position_ids,
854 | "attention_mask": attention_mask,
855 | "token_type_ids": token_type_ids,
856 | }
857 | )
858 | return model_inputs
859 |
860 | @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
861 | def forward(
862 | self,
863 | input_ids: Optional[torch.Tensor] = None,
864 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
865 | attention_mask: Optional[torch.Tensor] = None,
866 | token_type_ids: Optional[torch.Tensor] = None,
867 | position_ids: Optional[torch.Tensor] = None,
868 | head_mask: Optional[torch.Tensor] = None,
869 | inputs_embeds: Optional[torch.Tensor] = None,
870 | encoder_hidden_states: Optional[torch.Tensor] = None,
871 | encoder_attention_mask: Optional[torch.Tensor] = None,
872 | labels: Optional[torch.Tensor] = None,
873 | use_cache: Optional[bool] = None,
874 | output_attentions: Optional[bool] = None,
875 | output_hidden_states: Optional[bool] = None,
876 | return_dict: Optional[bool] = None,
877 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
878 | r"""
879 | labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
880 |             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
881 |             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
882 |             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
883 | """
884 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
885 |
886 | transformer_outputs = self.transformer(
887 | input_ids,
888 | past_key_values=past_key_values,
889 | attention_mask=attention_mask,
890 | token_type_ids=token_type_ids,
891 | position_ids=position_ids,
892 | head_mask=head_mask,
893 | inputs_embeds=inputs_embeds,
894 | encoder_hidden_states=encoder_hidden_states,
895 | encoder_attention_mask=encoder_attention_mask,
896 | use_cache=use_cache,
897 | output_attentions=output_attentions,
898 | output_hidden_states=output_hidden_states,
899 | return_dict=return_dict,
900 | )
901 | hidden_states = transformer_outputs[0]
902 | lm_logits = self.lm_head(hidden_states)
903 | loss = None
904 | if labels is not None:
905 | # Shift so that tokens < n predict n
906 | shift_logits = lm_logits[..., :-1, :].contiguous()
907 | shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
908 | # Flatten the tokens
909 | loss_fct = CrossEntropyLoss()
910 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
911 |
912 | if not return_dict:
913 | output = (lm_logits,) + transformer_outputs[1:]
914 | return ((loss,) + output) if loss is not None else output
915 |
916 | return CausalLMOutputWithCrossAttentions(
917 | loss=loss,
918 | logits=lm_logits,
919 | past_key_values=transformer_outputs.past_key_values,
920 | hidden_states=transformer_outputs.hidden_states,
921 | attentions=transformer_outputs.attentions,
922 | )
923 |
924 | @staticmethod
925 | def _reorder_cache(past_key_values, beam_idx):
926 | reordered_past = ()
927 | for layer_past in past_key_values:
928 | reordered_past += (
929 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
930 | )
931 | return reordered_past
932 |
933 |
934 | def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
935 | user_name = "## human:"
936 | ai_name = "## assistant: "
937 | stop = '||'
938 |
939 | prompt = ''
940 | for q, r in history:
941 | prompt += f"{user_name}{q}{stop}"
942 | prompt += f"{ai_name}{r}{stop}"
943 | prompt += f"{user_name}{query}{stop}"
944 | prompt += ai_name.rstrip()
945 |
946 | max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
947 | max_new_tokens = max_new_tokens or 128
948 | max_input_tokens = self.config.n_positions - max_new_tokens
949 |
950 | input_tokens = tokenizer.encode(prompt)
951 | input_tokens = input_tokens[-max_input_tokens:] # truncate left
952 | return torch.LongTensor([input_tokens]).to(self.device)
953 |
954 | def chat(self, query, history, tokenizer, stream=False,
955 | generation_config: Optional[GenerationConfig]=None):
956 | generation_config = generation_config or self.generation_config
957 | input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens)
958 | stopping_criteria = StoppingCriteriaList(
959 | [EndOfFunctionCriteria([len(input_ids[0])], ['||', '|end|', '<|endoftext|>', '## human'], tokenizer)]
960 | )
961 |
962 | if stream:
963 | streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
964 | Thread(target=self.generate, kwargs=dict(
965 | inputs=input_ids, streamer=streamer,
966 |                 stopping_criteria=stopping_criteria,
967 | generation_config=generation_config,
968 | )).start()
969 | return streamer
970 | else:
971 |             outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
972 | response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
973 | return response
974 |
975 |     def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
976 |         generation_config = generation_config or self.generation_config
977 |         max_input_tokens = self.config.n_positions - (generation_config.max_new_tokens or 128)
978 |         input_ids = tokenizer.encode(prompt)[-max_input_tokens:]  # truncate left
979 |         input_ids = torch.LongTensor([input_ids]).to(self.device)
980 | 
981 |         stopping_criteria = StoppingCriteriaList(
982 |             [EndOfFunctionCriteria([len(input_ids[0])], ['||', '|end|', '<|endoftext|>', '## human'], tokenizer)]
983 |         )
984 | 
985 |         streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
986 |         Thread(target=self.generate, kwargs=dict(
987 |             inputs=input_ids, streamer=streamer, stopping_criteria=stopping_criteria,
988 |             generation_config=generation_config, **kwargs
989 |         )).start()
990 |         return streamer
991 |
992 |
993 | class CodeShell4bitForCausalLM(CodeShellForCausalLM):
994 | def __init__(self, config):
995 | CodeShellPreTrainedModel.__init__(self, config)
996 | self.transformer = CodeShellModel(config)
997 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
998 |
999 | try:
1000 | import bitsandbytes
1001 | from .quantizer import quantize_offline
1002 | quantize_offline(self)
1003 | except ImportError:
1004 |             raise ImportError("Needs bitsandbytes to run quantize.")
1005 |
1006 | self.post_init()
1007 |
1008 | @classmethod
1009 | def from_pretrained(
1010 | cls,
1011 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1012 | *model_args,
1013 | config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
1014 | cache_dir: Optional[Union[str, os.PathLike]] = None,
1015 | ignore_mismatched_sizes: bool = False,
1016 | force_download: bool = False,
1017 | local_files_only: bool = False,
1018 | token: Optional[Union[str, bool]] = None,
1019 | revision: str = "main",
1020 |         use_safetensors: Optional[bool] = None,
1021 | **kwargs,
1022 | ):
1023 | if not isinstance(config, PretrainedConfig):
1024 | config_path = config if config is not None else pretrained_model_name_or_path
1025 | config, _ = cls.config_class.from_pretrained(
1026 | config_path,
1027 | cache_dir=cache_dir,
1028 | return_unused_kwargs=True,
1029 | force_download=force_download,
1030 | resume_download=False,
1031 | proxies=None,
1032 | local_files_only=local_files_only,
1033 | token=token,
1034 | revision=revision,
1035 | subfolder="",
1036 | _from_auto=False,
1037 | _from_pipeline=None,
1038 | **kwargs,
1039 | )
1040 |
1041 |         # Build the quantized model and load its pre-quantized state dict
1042 | from .quantizer import load_state_dict_for_qunantied_model
1043 | model = cls(config)
1044 | state_dict = torch.load(os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin'), map_location="cpu")
1045 | model = load_state_dict_for_qunantied_model(model, state_dict)
1046 | model.eval()
1047 |
1048 | # If it is a model with generation capabilities, attempt to load the generation config
1049 | if model.can_generate():
1050 | try:
1051 | model.generation_config = GenerationConfig.from_pretrained(
1052 | pretrained_model_name_or_path,
1053 | cache_dir=cache_dir,
1054 | force_download=force_download,
1055 | resume_download=False,
1056 | proxies=None,
1057 | local_files_only=local_files_only,
1058 | token=token,
1059 | revision=revision,
1060 | subfolder="",
1061 | _from_auto=False,
1062 | _from_pipeline=None,
1063 | **kwargs,
1064 | )
1065 | except (OSError, TypeError):
1066 | pass
1067 |
1068 | device_map = kwargs.pop("device_map", None)
1069 | if device_map is not None:
1070 | model = model.to(torch.device(device_map))
1071 |
1072 | return model
--------------------------------------------------------------------------------
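For reference, a minimal sketch of how the `chat` and streaming helpers defined above in `modeling_codeshell.py` are typically driven through `transformers` with `trust_remote_code=True`. The checkpoint id below is illustrative rather than taken from this repository; substitute whichever CodeShell-Chat checkpoint you actually use.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "WisdomShell/CodeShell-7B-Chat"  # placeholder checkpoint id

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()

# Blocking call: build_chat_input() assembles the "## human:" / "## assistant:" prompt,
# and chat() decodes the completion once the EndOfFunctionCriteria stop strings fire.
history = []  # list of (query, response) pairs
response = model.chat("Write a quicksort function in Python.", history, tokenizer)
print(response)

# Streaming call: chat(..., stream=True) returns the TextIterStreamer defined above,
# which yields the cumulative decoded text as new tokens arrive.
for partial in model.chat("Add type hints to that function.", history, tokenizer, stream=True):
    print(partial, end="\r")
```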
/model/quantizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | try:
17 | import bitsandbytes as bnb
18 | from bitsandbytes.nn.modules import Params4bit, Int8Params
19 | except ImportError:
20 | pass
21 | import torch
22 |
23 | def Params4bitCuda(self, device):
24 | self.data = self.data.cuda(device)
25 | if self.quant_state is not None:
26 | self.quant_state[0] = self.quant_state[0].cuda(device)
27 | self.quant_state[6] = self.quant_state[6].cuda(device)
28 | return self
29 |
30 | def Params4bitTo(self, *args, **kwargs):
31 | device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
32 |
33 | if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
34 | return self.cuda(device)
35 | else:
36 | if self.quant_state is not None:
37 | # make sure the quantization state is on the right device
38 | self.quant_state[0] = self.quant_state[0].to(device)
39 | self.quant_state[6] = self.quant_state[6].to(device)
40 | new_param = Params4bit(self.to(device=device, dtype=dtype, non_blocking=non_blocking),
41 | requires_grad=self.requires_grad, quant_state=self.quant_state,
42 | blocksize=self.blocksize, compress_statistics=self.compress_statistics,
43 | quant_type=self.quant_type)
44 |
45 | return new_param
46 |
47 | class Linear4bitOnline(torch.nn.Module):
48 | def __init__(self, weight, bias, quant_type):
49 | super().__init__()
50 | self.weight = Params4bit(
51 | weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
52 | )
53 | self.compute_dtype = None
54 | #self.weight.cuda(weight.device)
55 | self.bias = bias
56 |
57 | def forward(self, x: torch.Tensor):
58 |         # the 4-bit weights are already wrapped as Params4bit; only the bias may need casting to the input dtype
59 | if self.bias is not None and self.bias.dtype != x.dtype:
60 | self.bias.data = self.bias.data.to(x.dtype)
61 |
62 | if getattr(self.weight, "quant_state", None) is None:
63 | print(
64 | "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
65 | )
66 | inp_dtype = x.dtype
67 | if self.compute_dtype is not None:
68 | x = x.to(self.compute_dtype)
69 |
70 | bias = None if self.bias is None else self.bias.to(self.compute_dtype)
71 | out = bnb.matmul_4bit(
72 | x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
73 | )
74 |
75 | out = out.to(inp_dtype)
76 |
77 | return out
78 |
79 | class Linear8bitLtOnline(torch.nn.Module):
80 | def __init__(
81 | self,
82 | weight,
83 | bias,
84 | has_fp16_weights=True,
85 | memory_efficient_backward=False,
86 | threshold=0.0,
87 | index=None,
88 | ):
89 | super().__init__()
90 | assert (
91 | not memory_efficient_backward
92 | ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
93 | self.state = bnb.MatmulLtState()
94 | self.index = index
95 |
96 | # Necessary for stacked layers
97 | self.state.threshold = threshold
98 | self.state.has_fp16_weights = has_fp16_weights
99 | self.state.memory_efficient_backward = memory_efficient_backward
100 | if threshold > 0.0 and not has_fp16_weights:
101 | self.state.use_pool = True
102 |
103 | self.weight = Int8Params(
104 | weight.data,
105 | has_fp16_weights=has_fp16_weights,
106 | requires_grad=has_fp16_weights,
107 | )
108 | self.bias = bias
109 |
110 | def init_8bit_state(self):
111 | self.state.CB = self.weight.CB
112 | self.state.SCB = self.weight.SCB
113 | self.weight.CB = None
114 | self.weight.SCB = None
115 |
116 | def forward(self, x: torch.Tensor):
117 | self.state.is_training = self.training
118 | if self.weight.CB is not None:
119 | self.init_8bit_state()
120 |
121 | # weights are cast automatically as Int8Params, but the bias has to be cast manually
122 | if self.bias is not None and self.bias.dtype != x.dtype:
123 | self.bias.data = self.bias.data.to(x.dtype)
124 |
125 | out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
126 |
127 | if not self.state.has_fp16_weights:
128 | if self.state.CB is not None and self.state.CxB is not None:
129 | # we converted 8-bit row major to turing/ampere format in the first inference pass
130 | # we no longer need the row-major weight
131 | del self.state.CB
132 | self.weight.data = self.state.CxB
133 | return out
134 |
135 | def quantize_online(model, bits: int):
136 | def quant(weight, bias=None):
137 | if bits == 8:
138 | linear = Linear8bitLtOnline(
139 | weight,
140 | bias,
141 | has_fp16_weights=False,
142 | threshold=6.0,
143 | )
144 | if bias is not None:
145 | linear.bias = torch.nn.Parameter(bias)
146 | elif bits == 4:
147 | linear = Linear4bitOnline(
148 | weight,
149 | bias,
150 |                 quant_type="nf4",  # "fp4" or "nf4"
151 | )
152 | else:
153 | raise ValueError("quantize only support 4/8 bit")
154 | return linear
155 |
156 |     def auto_quant(layer):
157 |         if hasattr(layer, "bias"):
158 |             linear = quant(layer.weight, bias=layer.bias)
159 |         else:
160 |             linear = quant(layer.weight)
161 |         return linear
162 | 
163 |     for layer in model.transformer.h:
164 |         layer.mlp.c_fc = auto_quant(layer.mlp.c_fc)
165 |         layer.mlp.c_proj = auto_quant(layer.mlp.c_proj)
166 | 
167 |         layer.attn.c_attn = auto_quant(layer.attn.c_attn)
168 |         layer.attn.c_proj = auto_quant(layer.attn.c_proj)
169 |
170 | return model
171 |
172 |
173 | general_weight_dict = {
174 | "transformer.wte.weight": False,
175 | "transformer.ln_f.weight": False,
176 | "transformer.ln_f.bias": False,
177 | "lm_head.weight": False,
178 | }
179 |
180 | layer_weight_dict = {
181 | "transformer.h.{i}.ln_1.weight": False,
182 | "transformer.h.{i}.ln_1.bias": False,
183 | "transformer.h.{i}.attn.c_attn.weight": True,
184 | "transformer.h.{i}.attn.c_attn.bias": False,
185 | "transformer.h.{i}.attn.c_proj.weight": True,
186 | "transformer.h.{i}.attn.c_proj.bias": False,
187 | "transformer.h.{i}.attn.rotary_emb.inv_freq": False,
188 | "transformer.h.{i}.ln_2.weight": False,
189 | "transformer.h.{i}.ln_2.bias": False,
190 | "transformer.h.{i}.mlp.c_fc.weight": True,
191 | "transformer.h.{i}.mlp.c_fc.bias": False,
192 | "transformer.h.{i}.mlp.c_proj.weight": True,
193 | "transformer.h.{i}.mlp.c_proj.bias": False,
194 | }
195 | num_dict = {str(i):i for i in range(100)}
196 |
197 | def set_value(model, name, state_dict, is_4bit):
198 | keys = name.split('.')
199 | parent = model
200 | for key in keys[:-1]:
201 | if key in num_dict:
202 | parent = parent[num_dict[key]]
203 | else:
204 | parent = getattr(parent, key)
205 | if is_4bit:
206 | weight_data = state_dict[f'{name}.data']
207 | weight_quant_state = state_dict[f'{name}.quant_state']
208 | assert weight_data is not None, name
209 | assert weight_quant_state is not None, name
210 | setattr(parent, keys[-1], Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state))
211 | else:
212 | setattr(parent, keys[-1], state_dict[name])
213 |
214 | def quantize_offline(model):
215 | for i, layer in enumerate(model.transformer.h):
216 | layer.mlp.c_fc = bnb.nn.Linear4bit(
217 | layer.mlp.c_fc.weight.shape[1],
218 | layer.mlp.c_fc.weight.shape[0],
219 | False,
220 | torch.bfloat16,
221 | compress_statistics=True,
222 | quant_type="nf4",
223 | )
224 | layer.mlp.c_proj = bnb.nn.Linear4bit(
225 | layer.mlp.c_proj.weight.shape[1],
226 | layer.mlp.c_proj.weight.shape[0],
227 | False,
228 | torch.bfloat16,
229 | compress_statistics=True,
230 | quant_type="nf4",
231 | )
232 |
233 | layer.attn.c_attn = bnb.nn.Linear4bit(
234 | layer.attn.c_attn.weight.shape[1],
235 | layer.attn.c_attn.weight.shape[0],
236 | False,
237 | torch.bfloat16,
238 | compress_statistics=True,
239 | quant_type="nf4",
240 | )
241 | layer.attn.c_proj = bnb.nn.Linear4bit(
242 | layer.attn.c_proj.weight.shape[1],
243 | layer.attn.c_proj.weight.shape[0],
244 | False,
245 | torch.bfloat16,
246 | compress_statistics=True,
247 | quant_type="nf4",
248 | )
249 | return model
250 |
251 | def load_state_dict_for_qunantied_model(model, state_dict):
252 | #replace Params4bit.cuda with Params4bitCuda
253 | Params4bit.cuda = Params4bitCuda
254 | Params4bit.to = Params4bitTo
255 |
256 | for name, is_4bit in general_weight_dict.items():
257 | set_value(model, name, state_dict, is_4bit)
258 |
259 | for layer_i in range(len(model.transformer.h)):
260 | for name, is_4bit in layer_weight_dict.items():
261 | name = name.replace('{i}', str(layer_i))
262 | set_value(model, name, state_dict, is_4bit)
263 | return model
--------------------------------------------------------------------------------
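As a usage note, the online path above (`quantize_online`, reached through `CodeShellForCausalLM.quantize`) targets an already-loaded full-precision checkpoint, while the offline helpers (`quantize_offline`, `load_state_dict_for_qunantied_model`) back the pre-quantized `CodeShell4bitForCausalLM`. A minimal sketch of the online path, assuming `bitsandbytes` and a CUDA device are available and using a placeholder checkpoint id:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "WisdomShell/CodeShell-7B"  # placeholder checkpoint id

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

# quantize(4) swaps every attention/MLP projection for the Linear4bitOnline wrapper;
# the weights are actually quantized when the parameters are moved onto the GPU.
model = model.quantize(4).cuda().eval()

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))
```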
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.34.0
2 | torch>=2.0.1
3 |
--------------------------------------------------------------------------------
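Note that `requirements.txt` only covers the core runtime. The optional code paths in this repository pull in additional packages: `bitsandbytes` for the 4/8-bit quantizers in `model/quantizer.py`, and `jsonlines` plus `fire` for `tokenizer/eval_tokenizer.py`. A small, illustrative check for those optional dependencies:

```python
import importlib.util

# Core packages are pinned in requirements.txt; these are only needed by optional scripts.
optional = {
    "bitsandbytes": "model/quantizer.py",
    "jsonlines": "tokenizer/eval_tokenizer.py",
    "fire": "tokenizer/eval_tokenizer.py",
}
for package, used_by in optional.items():
    status = "available" if importlib.util.find_spec(package) else "missing"
    print(f"{package:12s} {status} (used by {used_by})")
```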
/tokenizer/eval_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved.
3 |
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from transformers import AutoTokenizer
17 |
18 | import jsonlines
19 | import os
20 |
21 | from fire import Fire
22 |
23 |
24 | def evaluate(tokenizer_path: str, corpora_path: str):
25 | """
26 |     Evaluate the compression ratio (characters per token) of a tokenizer over a directory of JSONL corpora.
27 | """
28 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
29 | corpora = [os.path.join(corpora_path, f) for f in os.listdir(corpora_path)
30 | if not os.path.isdir(os.path.join(corpora_path, f))]
31 |
32 | total_characters = 0
33 | total_tokens = 0
34 |
35 | for corpus in corpora:
36 | print(f"Processing {corpus}")
37 | texts = []
38 | partial_characters = 0
39 | partial_tokens = 0
40 | with open(corpus, "r", encoding="utf-8") as f_in:
41 | for item in jsonlines.Reader(f_in):
42 | texts.append(item["text"])
43 | partial_characters += len(item["text"])
44 |
45 | # noinspection PyUnboundLocalVariable
46 | tokens = tokenizer(texts)
47 |
48 | for seg in tokens["input_ids"]:
49 | partial_tokens += len(seg)
50 |
51 | total_characters += partial_characters
52 | total_tokens += partial_tokens
53 | print(f"Characters: {partial_characters}")
54 | print(f"Tokens: {partial_tokens}")
55 | print(f"Compression ratio: {partial_characters / partial_tokens}")
56 |
57 | print(f"Total characters: {total_characters}")
58 | print(f"Total tokens: {total_tokens}")
59 | print(f"Compression ratio: {total_characters / total_tokens}")
60 |
61 |
62 | if __name__ == "__main__":
63 | Fire(evaluate)
--------------------------------------------------------------------------------
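For reference, `evaluate` expects `corpora_path` to be a directory of JSONL files whose records each carry a `"text"` field; the reported compression ratio is characters per token. A minimal invocation with placeholder paths:

```python
from eval_tokenizer import evaluate

# Both paths are placeholders: point tokenizer_path at the directory containing
# tokenizer.json / tokenizer_config.json, and corpora_path at a folder of *.jsonl files.
evaluate(tokenizer_path="./tokenizer", corpora_path="./corpora")
```

Because the script is wrapped with `fire.Fire`, the same call can be made from the shell, e.g. `python eval_tokenizer.py --tokenizer_path ./tokenizer --corpora_path ./corpora`.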
/tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "additional_special_tokens": [
4 | "<|endoftext|>",
5 | "",
6 | "",
7 | "",
8 | "",
9 | "",
10 | "",
11 | "",
12 | "",
13 | "",
14 | "",
15 | "",
16 | "",
17 | "",
18 | "",
19 | "",
20 | "",
21 | "",
22 | "",
23 | "<|end|>"
24 | ],
25 | "bos_token": "<|endoftext|>",
26 | "clean_up_tokenization_spaces": true,
27 | "eos_token": "<|endoftext|>",
28 | "model_max_length": 8192,
29 | "tokenizer_class": "GPT2Tokenizer",
30 | "unk_token": "<|endoftext|>",
31 | "vocab_size": 70020,
32 | "pad_token": "<|endoftext|>"
33 | }
34 |
--------------------------------------------------------------------------------