├── License.pdf ├── README.md ├── README_EN.md ├── demos ├── cli_demo.py ├── openai_api.py ├── requirements_web_demo.txt └── web_demo.py ├── evaluation ├── README.md ├── README_EN.md ├── all_config.yaml ├── chat_humaneval.py ├── eval.py └── run_eval.sh ├── finetune ├── README.md ├── README_EN.md ├── ds_config_zero3.json ├── finetune.py └── run_finetune.sh ├── model ├── configuration_codeshell.py ├── modeling_codeshell.py └── quantizer.py ├── requirements.txt └── tokenizer ├── eval_tokenizer.py ├── special_tokens_map.json ├── tokenizer.json └── tokenizer_config.json /License.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisdomShell/codeshell/09d1adc88ccada1a92924c69ece0cf0e73899b1b/License.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | 4 |

5 | 6 |

7 | 🤗 Hugging Face • 🤖 ModelScope • ⭕️ WiseModel • 🌐 PKU-KCL 8 |

9 | 10 |
11 | 12 | [![license](https://img.shields.io/github/license/modelscope/modelscope.svg)](https://github.com/WisdomShell/codeshell/blob/main/License.pdf) 13 |

14 |

中文|English

15 |

16 |
17 | 18 | ## Introduction 19 | 20 | CodeShell是[北京大学知识计算实验室](http://se.pku.edu.cn/kcl/)联合四川天府银行AI团队研发的多语言代码大模型基座。CodeShell具有70亿参数,在五千亿Tokens进行了训练,上下文窗口长度为8192。在权威的代码评估Benchmark(HumanEval与MBPP)上,CodeShell取得同等规模最好的性能。与此同时,我们提供了与CodeShell配套的部署方案与IDE插件,请参考代码库[CodeShell](https://github.com/WisdomShell/codeshell)。同时,为了方便中国用户下载,我们在[Modelscope](https://modelscope.cn/organization/WisdomShell)和[Wisemodel](https://www.wisemodel.cn/models/WisdomShell/CodeShell-7B/)中也上传了对应版本,国内用户可以访问。 21 | 22 | 23 | 本次开源的模型如下: 24 | 25 | - CodeShell Base:CodelShell底座模型,具有强大的代码基础能力。 26 | - CodeShell Chat:CodelShell对话模型,在代码问答、代码补全等下游任务重性能优异。 27 | - CodeShell Chat 4bit:CodelShell对话模型4bit量化版本,在保证模型性能的前提下内存消耗更小,速度更快。 28 | - CodeShell CPP:CodelShell对话模型CPP版本,支持开发者在没有GPU的个人电脑中使用。注意,CPP版本同样支持量化操作,用户可以在最小内存为8G的个人电脑中运行CodeShell。 29 | 30 | 31 | ## Main Characteristics of CodeShell 32 | 33 | - **强大的性能**:CodelShell在HumanEval和MBPP上达到了7B代码基座大模型的最优性能 34 | - **完整的体系**:除了代码大模型,同时开源IDE(VS Code与JetBrains)插件,形成开源的全栈技术体系 35 | - **轻量化部署**:支持本地C++部署,提供轻量快速的本地化软件开发助手解决方案 36 | - **全面的评测**:提供支持完整项目上下文、覆盖代码生成、代码缺陷检测与修复、测试用例生成等常见软件开发活动的多任务评测体系(即将开源) 37 | - **高效的训练**:基于高效的数据治理体系,CodeShell在完全冷启动情况下,只训练了五千亿Token即获得了优异的性能 38 | 39 | ## Performance 40 | 41 | 我们选取了目前最流行的两个代码评测数据集(HumanEval与MBPP)对模型进行评估,与目前最先进的两个7b代码大模型CodeLlama与Starcoder相比,Codeshell 取得了最优的成绩。具体评测结果如下。 42 | 43 | | 任务 | CodeShell-7b | CodeLlama-7b | Starcoder-7b | 44 | | ------- | --------- | --------- | --------- | 45 | | humaneval | **34.32** | 29.44 | 27.80 | 46 | | mbpp | **38.65** | 37.60 | 34.16 | 47 | | multiple-js | **33.17** | 31.30 | 27.02 | 48 | | multiple-java | **30.43** | 29.24 | 24.30 | 49 | | multiple-cpp | **28.21** | 27.33 | 23.04 | 50 | | multiple-swift | 24.30 | **25.32** | 15.70 | 51 | | multiple-php | **30.87** | 25.96 | 22.11 | 52 | | multiple-d | 8.85 | **11.60** | 8.08 | 53 | | multiple-jl | 22.08 | **25.28** | 22.96 | 54 | | multiple-lua | 22.39 | **30.50** | 22.92 | 55 | | multiple-r | **20.52** | 18.57 | 14.29 | 56 | | multiple-rkt | **17.20** | 12.55 | 10.43 | 57 | | multiple-rs | 24.55 | **25.90** | 22.82 | 58 | 59 | ## Requirements 60 | 61 | - python 3.8 and above 62 | - pytorch 2.0 and above are recommended 63 | - transformers 4.32 and above 64 | - CUDA 11.8 and above are recommended (this is for GPU users, flash-attention users, etc.) 
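安装依赖前,可以用类似下面的脚本快速自检本地环境是否满足上述版本要求(仅作示意,并非仓库自带的脚本;CUDA 版本仅在检测到 GPU 时打印):

```python
# 环境自检示意脚本(假设已安装 torch 与 transformers)
import sys
import torch
import transformers

assert sys.version_info >= (3, 8), "需要 Python 3.8 及以上版本"
print("torch:", torch.__version__)                  # 建议 2.0 及以上
print("transformers:", transformers.__version__)    # 建议 4.32 及以上
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA version:", torch.version.cuda)      # 建议 11.8 及以上
```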
65 | 66 | ## Quickstart 67 | 68 | CodeShell系列模型已经上传至 Hugging Face,开发者可以通过Transformers快速调用CodeShell和CodeShell-Chat。 69 | 70 | 在开始之前,请确保已经正确设置了环境,并安装了必要的代码包,以及满足上一小节的环境要求。你可以通过下列代码快速安装相关依赖。 71 | 72 | ``` 73 | pip install -r requirements.txt 74 | ``` 75 | 76 | 接下来你可以通过Transformers使用CodeShell。 77 | 78 | ### Code Generation 79 | 80 | 开发者可以使用CodeShell快速生成代码,加速开发效率。 81 | 82 | ```python 83 | import torch 84 | from transformers import AutoModelForCausalLM, AutoTokenizer 85 | 86 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 87 | tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B") 88 | model = AutoModelForCausalLM.from_pretrained("WisdomShell/CodeShell-7B", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) 89 | inputs = tokenizer('def merge_sort():', return_tensors='pt').to(device) 90 | outputs = model.generate(**inputs) 91 | print(tokenizer.decode(outputs[0])) 92 | ``` 93 | 94 | - Fill in the Moddle 95 | 96 | CodeShell 支持Fill-in-the-Middle模式,从而更好的支持软件开发过程。 97 | 98 | ```python 99 | input_text = "def print_hello_world():\n \n print('Hello world!')" 100 | inputs = tokenizer(input_text, return_tensors='pt').to(device) 101 | outputs = model.generate(**inputs) 102 | print(tokenizer.decode(outputs[0])) 103 | ``` 104 | 105 | - 代码问答 106 | 107 | CodeShell同时开源了代码助手模型CodeShell-7B-Chat,开发者可以通过下列代码与模型进行交互。 108 | 109 | ```python 110 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) 111 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat') 112 | 113 | history = [] 114 | query = '你是谁?' 115 | response = model.chat(query, history, tokenizer) 116 | print(response) 117 | history.append((query, response)) 118 | 119 | query = '用Python写一个HTTP server' 120 | response = model.chat(query, history, tokenizer) 121 | print(response) 122 | history.append((query, response)) 123 | ``` 124 | 125 | 开发者也可以通过VS Code与JetBrains插件与CodeShell-7B-Chat交互,详情请参[VSCode插件仓库](https://github.com/WisdomShell/codeshell-vscode)与[IntelliJ插件仓库](https://github.com/WisdomShell/codeshell-intellij)。 126 | 127 | 128 | - Model Quantization 129 | 130 | CodeShell 支持4 bit/8 bit量化,4 bit量化后,占用显存大小约6G,用户可以在显存较小的GPU上使用CodeShell。 131 | 132 | ```python 133 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4', trust_remote_code=True).to(device) 134 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4') 135 | ``` 136 | 137 | - CodeShell in c/c++ 138 | 139 | 由于大部分个人电脑没有GPU,CodeShell提供了C/C++版本的推理支持,开发者可以根据本地环境进行编译与使用,详见[CodeShell C/C++本地化版](https://github.com/WisdomShell/llama_cpp_for_codeshell)。 140 | 141 | ## Demo 142 | 143 | 我们提供了Web-UI、命令行、API、IDE四种形式的Demo。 144 | 145 | ### Web UI 146 | 147 | 开发者通过下列命令启动Web服务,服务启动后,可以通过`https://127.0.0.1:8000`进行访问。 148 | 149 | ``` 150 | python demos/web_demo.py 151 | ``` 152 | 153 | ### CLI Demo 154 | 155 | 我们也提供了命令行交互的Demo版本,开发者可以通过下列命令运行。 156 | 157 | ``` 158 | python demos/cli_demo.py 159 | ``` 160 | 161 | ### API 162 | 163 | CodeShell也提供了基于OpenAI API的部署方法。 164 | 165 | ``` 166 | python demos/openai_api.py 167 | ``` 168 | 169 | 启动后即可通过HTTP请求与CodeShell交互。 170 | 171 | ``` 172 | curl http://127.0.0.1:8000/v1/chat/completions \ 173 | -H "Content-Type: application/json" \ 174 | -d '{ 175 | "model": "CodeShell-7B-Chat", 176 | "messages": [ 177 | { 178 | "role": "user", 179 | "content": "你好" 180 | } 181 | ] 182 | }' 183 | ``` 184 | 185 | ### IDE 186 | 187 | 
CodeShell最后提供了线上IDE,开发者可以通过IDE进行代码补全、代码问答等操作。同时,IDE插件也同时发布,开发者可以自行在本地进行安装使用。插件相关问题欢迎在[VSCode插件仓库](https://github.com/WisdomShell/codeshell-vscode)与[IntelliJ插件仓库](https://github.com/WisdomShell/codeshell-intellij)中讨论。 188 | 189 | ## Model Details 190 | 191 | Code Shell使用GPT-2作为基础架构,采用Grouped-Query Attention、RoPE相对位置编码等技术。 192 | 193 | ### Hyper-parameter 194 | 195 | | Hyper-parameter | Value | 196 | |---|---| 197 | | n_layer | 42 | 198 | | n_embd | 4096 | 199 | | n_inner | 16384 | 200 | | n_head | 32 | 201 | | num_query_groups | 8 | 202 | | seq-length | 8192 | 203 | | vocab_size | 70144 | 204 | 205 | 206 | ### Data 207 | 208 | CodeShell基于自己爬取的Github数据、Big Code开源的Stack和StarCoder数据集、以及少量高质量的中英文数据进行训练。在原始数据集的基础上,CodeShell采用基于Minihash对数据去重,基于KenLM以及高质量数据筛选模型对数据进行了过滤与筛选,最终得到高质量的预训练数据集。 209 | 210 | ### Tokenizer 211 | 212 | CodeShell基于Starcoder词表进行了优化,去除了使用频率较低的词语,并添加了部分中文词表,显著提升了中文的压缩率,为Chat版本的训练提供了基础。 213 | 214 | 215 | | Tokenizer | Size | Chinese | English | Code | Total| 216 | |---|---|---|---|---|---| 217 | | Starcoder | 49152 | 1.22 | 3.47 | 3.30 | 2.66 | 218 | | CodeShell | 70020 | 1.50 | 3.47 | 3.30 | 2.95 | 219 | 220 | 221 | ## License 222 | 223 | 社区使用CodeShell模型需要遵循[《CodeShell模型许可协议》](https://github.com/WisdomShell/codeshell/blob/main/License.pdf)及[Apache 2.0许可协议](https://www.apache.org/licenses/LICENSE-2.0)。CodeShell模型允许用于商业用途,但如果您计划将CodeShell模型或其派生产品用于商业用途,需要您确认主体符合以下条件: 224 | 225 | 1. 关联方的服务或产品的每日平均活跃用户数(DAU)不能超过100万。 226 | 2. 关联方不得是软件服务提供商或云服务提供商。 227 | 3. 关联方不存在将获得授予的商业许可,在未经许可的前提下将其再授权给其他第三方的可能性。 228 | 229 | 在满足上述条件的前提下,您需要通过向codeshell.opensource@gmail.com发送电子邮件,提交《CodeShell模型许可协议》要求的申请材料。经审核通过后,将授予您一个全球的、非排他的、不可转让的、不可再授权的商业版权许可。 230 | 231 | 232 | ## Star History 233 | 234 | [![Star History Chart](https://api.star-history.com/svg?repos=WisdomShell/codeshell&type=Date)](https://star-history.com/#WisdomShell/codeshell&Date) 235 | 236 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | 6 |

7 | 🤗 Hugging Face • 🤖 ModelScope • ⭕️ WiseModel • 🌐 PKU-KCL 8 |

9 | 10 |
11 | 12 | [![license](https://img.shields.io/github/license/modelscope/modelscope.svg)](https://github.com/WisdomShell/codeshell/blob/main/License.pdf) 13 |

14 |

Chinese|English

15 |

16 |
17 | 18 | ## Introduction 19 | 20 | CodeShell is a code large language model (LLM) developed jointly by the [Knowledge Computing Lab at Peking University](http://se.pku.edu.cn/kcl/) and the AI team of Sichuan Tianfu Bank. CodeShell has 7 billion parameters, was trained on 500 billion tokens, and has a context window length of 8192. On authoritative code evaluation benchmarks (HumanEval and MBPP), CodeShell achieves the best performance for models of its scale. At the same time, we offer deployment solutions and IDE plugins that complement CodeShell. Please refer to the [CodeShell](https://github.com/WisdomShell/codeshell) repository for details. 21 | 22 | The open-source models are as follows: 23 | 24 | - CodeShell Base: The foundational model of CodeShell with strong coding capabilities. 25 | - CodeShell Chat: A dialogue model of CodeShell that excels in code Q&A, code completion, and other downstream tasks. 26 | - CodeShell Chat 4bit: A 4bit quantized version of the CodeShell dialogue model. While preserving model performance, it consumes less memory and operates faster. 27 | - CodeShell CPP: A C++ version of the CodeShell dialogue model. It allows developers to use it on personal computers without GPUs. Note that the CPP version also supports quantization, allowing users to run CodeShell on PCs with a minimum of 8GB RAM. 28 | 29 | ## Main Characteristics of CodeShell 30 | 31 | - **Powerful Performance**: CodeShell achieves optimal performance in 7B code base models on HumanEval and MBPP. 32 | - **Complete Ecosystem**: In addition to the code model, IDE plugins for open-source (VS Code and JetBrains) are provided, forming a complete open-source technology stack. 33 | - **Lightweight Deployment**: Supports local C++ deployment, providing a lightweight and fast local software development assistant solution. 34 | - **Comprehensive Evaluation**: A multi-task evaluation system that supports a complete project context and covers common software development activities such as code generation, code defect detection and repair, and test case generation will be open-sourced soon. 35 | - **Efficient Training**: Based on an efficient data governance system, CodeShell achieved excellent performance after training only 500 billion tokens from a complete cold start. 36 | 37 | ## Performance 38 | 39 | We selected the two most popular code evaluation datasets (HumanEval and MBPP) to evaluate the model. Compared with the two most advanced 7B code models, CodeLlama and Starcoder, Codeshell achieved optimal results. The specific evaluation results are as follows. 
40 | 41 | | Task | CodeShell-7b | CodeLlama-7b | Starcoder-7b | 42 | | ------- | --------- | --------- | --------- | 43 | | humaneval | **34.32** | 29.44 | 27.80 | 44 | | mbpp | **38.65** | 37.60 | 34.16 | 45 | | multiple-js | **33.17** | 31.30 | 27.02 | 46 | | multiple-java | **30.43** | 29.24 | 24.30 | 47 | | multiple-cpp | **28.21** | 27.33 | 23.04 | 48 | | multiple-swift | 24.30 | **25.32** | 15.70 | 49 | | multiple-php | **30.87** | 25.96 | 22.11 | 50 | | multiple-d | 8.85 | **11.60** | 8.08 | 51 | | multiple-jl | 22.08 | **25.28** | 22.96 | 52 | | multiple-lua | 22.39 | **30.50** | 22.92 | 53 | | multiple-r | **20.52** | 18.57 | 14.29 | 54 | | multiple-rkt | **17.20** | 12.55 | 10.43 | 55 | | multiple-rs | 24.55 | **25.90** | 22.82 | 56 | 57 | ## Requirements 58 | 59 | - python 3.8 and above 60 | - pytorch 2.0 and above are recommended 61 | - transformers 4.32 and above 62 | - CUDA 11.8 and above are recommended (for GPU users, flash-attention users, etc.) 63 | 64 | ## Quickstart 65 | 66 | The CodeShell series models have been uploaded to Hugging Face. Developers can quickly call CodeShell and CodeShell-Chat through Transformers. 67 | 68 | Before starting, make sure you have set up the environment correctly, installed the necessary packages, and meet the environmental requirements from the previous section. The necessary dependencies can be installed quickly using the following code: 69 | 70 | pip install -r requirements.txt 71 | 72 | 73 | Next, you can use CodeShell through Transformers. 74 | 75 | ### Code Generation 76 | 77 | Developers can use CodeShell to quickly generate code, accelerating development efficiency. 78 | 79 | ```python 80 | import torch 81 | from transformers import AutoModelForCausalLM, AutoTokenizer 82 | 83 | tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B") 84 | model = AutoModelForCausalLM.from_pretrained("WisdomShell/CodeShell-7B", trust_remote_code=True, torch_dtype=torch.bfloat16).cuda() 85 | inputs = tokenizer('def merge_sort():', return_tensors='pt').cuda() 86 | outputs = model.generate(inputs) 87 | print(tokenizer.decode(outputs[0])) 88 | ``` 89 | 90 | ### Fill in the Middle 91 | CodeShell supports the Fill-in-the-Middle mode to better assist the software development process. 92 | 93 | ```python 94 | input_text = "def print_hello_world():\n \n print('Hello world!')" 95 | inputs = tokenizer(input_text, return_tensors='pt').cuda() 96 | outputs = model.generate(inputs) 97 | print(tokenizer.decode(outputs[0])) 98 | ``` 99 | 100 | ### Code Q&A 101 | CodeShell has also open-sourced the CodeShell-7B-Chat code assistant model. Developers can interact with the model using the following code. 102 | 103 | ```python 104 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat', trust_remote_code=True, torch_dtype=torch.bfloat16).cuda() 105 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat') 106 | 107 | history = [] 108 | query = 'Who are you?' 109 | response = model.chat(query, history, tokenizer) 110 | print(response) 111 | history.append((query, response)) 112 | 113 | query = 'Write an HTTP server in Python' 114 | response = model.chat(query, history, tokenizer) 115 | print(response) 116 | history.append((query, response)) 117 | ``` 118 | 119 | Developers can also interact with CodeShell-7B-Chat through VS Code and JetBrains plugins. For details, please refer to the VSCode plugin repository and IntelliJ plugin repository. 
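The chat interface shown above can also stream responses. The following is a minimal sketch of that pattern, assuming the `model`, `tokenizer`, and `history` objects from the Code Q&A example and the `stream=True` argument used in `demos/cli_demo.py`; the query text is arbitrary.

```python
# Streaming sketch following the pattern in demos/cli_demo.py.
# Assumes `model`, `tokenizer`, and `history` from the Code Q&A example above.
query = 'Explain what a binary search tree is'
response = ''
for response in model.chat(query, history, tokenizer, stream=True):
    # Each iteration yields the cumulative partial response generated so far;
    # a generation_config can also be passed, as done in the demos.
    print(response)
history.append((query, response))
```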
120 | 121 | ### Model Quantization 122 | CodeShell supports 4 bit/8 bit quantization. After 4-bit quantization, the memory footprint is approximately 6GB, allowing users to use CodeShell on GPUs with smaller memory. 123 | 124 | ```python 125 | model = AutoModelForCausalLM.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4', trust_remote_code=True).to(device) 126 | tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat-int4') 127 | ``` 128 | 129 | ### CodeShell in c/c++ 130 | As most personal computers lack a GPU, CodeShell offers C/C++ inference support. Developers can compile based on the local environment. See CodeShell C/C++ local version. After compilation, the Web API service can be started with the following command. 131 | 132 | 133 | ## Demo 134 | We offer demos in four forms: Web-UI, command line, API, and IDE. 135 | 136 | ### Web UI 137 | Developers can start the Web service using the following command. After the service starts, it can be accessed at https://127.0.0.1:8000. 138 | 139 | ``` 140 | python demos/web_demo.py 141 | ``` 142 | 143 | ### CLI Demo 144 | 145 | We also offer a command-line interactive demo version. Developers can run it using the following command. 146 | 147 | ``` 148 | python cli_demo.py 149 | ``` 150 | 151 | ### API 152 | 153 | CodeShell also offers a deployment method based on the OpenAI API. 154 | 155 | ``` 156 | python openai_api.py 157 | ``` 158 | 159 | Then you can interact with CodeShell via HTTP requests: 160 | 161 | ``` 162 | curl http://127.0.0.1:8000/v1/chat/completions \ 163 | -H "Content-Type: application/json" \ 164 | -d '{ 165 | "model": "CodeShell-7B-Chat", 166 | "messages": [ 167 | { 168 | "role": "user", 169 | "content": "你好" 170 | } 171 | ] 172 | }' 173 | ``` 174 | 175 | ### IDE 176 | 177 | Finally, CodeShell offers an online IDE. Developers can use the IDE for code completion, code Q&A, and other operations. IDE plugins are also released, and developers can install and use them locally. For plugin-related issues, please discuss in the VS Code plugin repository. 178 | 179 | ## Model Details 180 | 181 | Code Shell uses GPT-2 as its basic architecture and employs technologies like Grouped-Query Attention and RoPE relative position encoding. 182 | 183 | ### Hyper-parameter 184 | 185 | | Hyper-parameter | Value | 186 | |---|---| 187 | | n_layer | 42 | 188 | | n_embd | 4096 | 189 | | n_inner | 16384 | 190 | | n_head | 32 | 191 | | num_query_groups | 8 | 192 | | seq-length | 8192 | 193 | | vocab_size | 70144 | 194 | 195 | 196 | ### Data 197 | 198 | CodeShell was trained based on its own scraped Github data, the open-source Stack and StarCoder datasets from Big Code, as well as a small amount of high-quality Chinese and English data. On top of the original dataset, CodeShell used Minihash for data deduplication, KenLM, and a high-quality data selection model for data filtering and selection, resulting in a high-quality pre-training dataset. 199 | 200 | ### Tokenizer 201 | 202 | CodeShell optimized the Starcoder vocabulary by removing infrequently used words and adding some Chinese vocabulary, significantly improving the Chinese compression rate, laying the groundwork for the training of the Chat version. 
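As a rough illustration only (characters per token is an assumed proxy here, not the measurement script behind the table below, and the sample string is made up), compression can be compared by tokenizing a mixed Chinese/English/code snippet:

```python
# Hypothetical compression-rate probe: characters per token on a small sample.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B")
sample = "def quick_sort(nums):\n    # 对列表进行快速排序\n    return sorted(nums)"
print(len(sample) / len(tok.encode(sample)))  # higher = better compression on this sample
```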
203 | 204 | 205 | | Tokenizer | Size | Chinese | English | Code | Total| 206 | |---|---|---|---|---|---| 207 | | Starcoder | 49152 | 1.22 | 3.47 | 3.30 | 2.66 | 208 | | CodeShell | 70144 | 1.50 | 3.47 | 3.30 | 2.95 | 209 | 210 | 211 | ## License 212 | The community's use of the CodeShell model must adhere to the ["CodeShell Model License Agreement" ](https://github.com/WisdomShell/codeshell/blob/main/License.pdf) and the [ Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). CodeShell is permitted for commercial use. However, if you plan to use the CodeShell model or its derivative products for commercial purposes, you must confirm that the entity meets the following conditions: 213 | 214 | - The daily average active user count (DAU) of the affiliated party's service or product cannot exceed 1 million. 215 | - The affiliated party must not be a software service provider or cloud service provider. 216 | - There is no possibility for the affiliated party to re-license the granted commercial license to another third party without proper authorization. 217 | 218 | Under the aforementioned conditions, you need to submit the application materials required by the "CodeShell Model License Agreement" by sending an email to codeshell.opensource@gmail.com. After approval, you will be granted a global, non-exclusive, non-transferable, non-sublicensable commercial copyright license. 219 | 220 | ## Star History 221 | 222 | [![Star History Chart](https://api.star-history.com/svg?repos=WisdomShell/codeshell&type=Date)](https://star-history.com/#WisdomShell/codeshell&Date) 223 | 224 | -------------------------------------------------------------------------------- /demos/cli_demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This code is based on Qwen's Cli Demo. It has been modified from 17 | # its original forms to accommodate CodeShell. 18 | 19 | # Copyright (c) Alibaba Cloud. 20 | # 21 | # This source code is licensed under the license found in the 22 | # LICENSE file in the root directory of this source tree. 23 | 24 | """A simple command-line interactive chat demo.""" 25 | 26 | import argparse 27 | import os 28 | import platform 29 | import shutil 30 | from copy import deepcopy 31 | 32 | import torch 33 | from transformers import AutoModelForCausalLM, AutoTokenizer 34 | from transformers.generation import GenerationConfig 35 | from transformers.trainer_utils import set_seed 36 | from transformers import StoppingCriteria, StoppingCriteriaList 37 | 38 | DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat' 39 | 40 | _WELCOME_MSG = '''\ 41 | Welcome to use CodeShell-Chat model, type text to start chat, type :h to show command help. 42 | (欢迎使用 CodeShell-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。) 43 | 44 | Note: This demo is governed by the original license of CodeShell. 
45 | We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc. 46 | (注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。) 47 | ''' 48 | _HELP_MSG = '''\ 49 | Commands: 50 | :help / :h Show this help message 显示帮助信息 51 | :exit / :quit / :q Exit the demo 退出Demo 52 | :clear / :cl Clear screen 清屏 53 | :clear-his / :clh Clear history 清除对话历史 54 | :history / :his Show history 显示对话历史 55 | :seed Show current random seed 显示当前随机种子 56 | :seed Set random seed to 设置随机种子 57 | :conf Show current generation config 显示生成配置 58 | :conf = Change generation config 修改生成配置 59 | ''' 60 | 61 | 62 | def _load_model_tokenizer(args): 63 | tokenizer = AutoTokenizer.from_pretrained( 64 | args.checkpoint_path, trust_remote_code=True, resume_download=True, 65 | ) 66 | 67 | model = AutoModelForCausalLM.from_pretrained( 68 | args.checkpoint_path, 69 | device_map=args.device, 70 | trust_remote_code=True, 71 | resume_download=True, 72 | torch_dtype=torch.bfloat16 73 | ).eval() 74 | 75 | config = GenerationConfig.from_pretrained( 76 | args.checkpoint_path, trust_remote_code=True, resume_download=True, 77 | ) 78 | 79 | return model, tokenizer, config 80 | 81 | 82 | def _gc(): 83 | import gc 84 | gc.collect() 85 | if torch.cuda.is_available(): 86 | torch.cuda.empty_cache() 87 | 88 | 89 | def _clear_screen(): 90 | if platform.system() == "Windows": 91 | os.system("cls") 92 | else: 93 | os.system("clear") 94 | 95 | 96 | def _print_history(history): 97 | terminal_width = shutil.get_terminal_size()[0] 98 | print(f'History ({len(history)})'.center(terminal_width, '=')) 99 | for index, (query, response) in enumerate(history): 100 | print(f'User[{index}]: {query}') 101 | print(f'CodeShell[{index}]: {response}') 102 | print('=' * terminal_width) 103 | 104 | 105 | def _get_input() -> str: 106 | while True: 107 | try: 108 | message = input('User> ').strip() 109 | except UnicodeDecodeError: 110 | print('[ERROR] Encoding error in input') 111 | continue 112 | except KeyboardInterrupt: 113 | exit(1) 114 | if message: 115 | return message 116 | print('[ERROR] Query is empty') 117 | 118 | def main(): 119 | parser = argparse.ArgumentParser( 120 | description='CodeShell-Chat command-line interactive chat demo.') 121 | parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, 122 | help="Checkpoint name or path, default to %(default)r") 123 | parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed") 124 | parser.add_argument("--device", type=str, default="cuda:0",help="Device name.") 125 | 126 | args = parser.parse_args() 127 | 128 | history, response = [], '' 129 | 130 | model, tokenizer, config = _load_model_tokenizer(args) 131 | orig_gen_config = deepcopy(model.generation_config) 132 | 133 | 134 | _clear_screen() 135 | print(_WELCOME_MSG) 136 | 137 | seed = args.seed 138 | 139 | while True: 140 | query = _get_input() 141 | 142 | # Process commands. 
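        # A leading ':' marks a demo command (see _HELP_MSG above), e.g. ':seed 42'
        # or ':conf max_new_tokens=512'; any other input is sent to the model as a chat query.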
143 | if query.startswith(':'): 144 | command_words = query[1:].strip().split() 145 | if not command_words: 146 | command = '' 147 | else: 148 | command = command_words[0] 149 | 150 | if command in ['exit', 'quit', 'q']: 151 | break 152 | elif command in ['clear', 'cl']: 153 | _clear_screen() 154 | print(_WELCOME_MSG) 155 | _gc() 156 | continue 157 | elif command in ['clear-history', 'clh']: 158 | print(f'[INFO] All {len(history)} history cleared') 159 | history.clear() 160 | _gc() 161 | continue 162 | elif command in ['help', 'h']: 163 | print(_HELP_MSG) 164 | continue 165 | elif command in ['history', 'his']: 166 | _print_history(history) 167 | continue 168 | elif command in ['seed']: 169 | if len(command_words) == 1: 170 | print(f'[INFO] Current random seed: {seed}') 171 | continue 172 | else: 173 | new_seed_s = command_words[1] 174 | try: 175 | new_seed = int(new_seed_s) 176 | except ValueError: 177 | print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number') 178 | else: 179 | print(f'[INFO] Random seed changed to {new_seed}') 180 | seed = new_seed 181 | continue 182 | elif command in ['conf']: 183 | if len(command_words) == 1: 184 | print(model.generation_config) 185 | else: 186 | for key_value_pairs_str in command_words[1:]: 187 | eq_idx = key_value_pairs_str.find('=') 188 | if eq_idx == -1: 189 | print('[WARNING] format: =') 190 | continue 191 | conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:] 192 | try: 193 | conf_value = eval(conf_value_str) 194 | except Exception as e: 195 | print(e) 196 | continue 197 | else: 198 | print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}') 199 | setattr(model.generation_config, conf_key, conf_value) 200 | continue 201 | elif command in ['reset-conf']: 202 | print('[INFO] Reset generation config') 203 | model.generation_config = deepcopy(orig_gen_config) 204 | print(model.generation_config) 205 | continue 206 | else: 207 | # As normal query. 208 | pass 209 | 210 | # Run chat. 211 | set_seed(seed) 212 | try: 213 | for response in model.chat(query, history, tokenizer, generation_config=config, stream=True): 214 | response = response.replace('|end|', '') 215 | response = response.replace('||', '') 216 | _clear_screen() 217 | print(f"\nUser: {query}") 218 | print(f"\nCodeShell-Chat: {response}") 219 | except KeyboardInterrupt: 220 | print('[WARNING] Generation interrupted') 221 | continue 222 | history.append((query, response)) 223 | 224 | if __name__ == "__main__": 225 | main() 226 | -------------------------------------------------------------------------------- /demos/openai_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This code is based on Qwen's opeai_api Demo. It has been modified from 17 | # its original forms to accommodate CodeShell. 
18 | 19 | # coding=utf-8 20 | # Implements API for CodeShell in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat) 21 | # Usage: python openai_api.py 22 | # Visit http://localhost:8000/docs for documents. 23 | 24 | import re 25 | import copy 26 | import json 27 | import time 28 | from copy import deepcopy 29 | from argparse import ArgumentParser 30 | from contextlib import asynccontextmanager 31 | from typing import Dict, List, Literal, Optional, Union 32 | 33 | import torch 34 | import uvicorn 35 | from fastapi import FastAPI, HTTPException 36 | from fastapi.middleware.cors import CORSMiddleware 37 | from pydantic import BaseModel, Field 38 | from sse_starlette.sse import EventSourceResponse 39 | from transformers import AutoTokenizer, AutoModelForCausalLM 40 | from transformers.generation import GenerationConfig 41 | 42 | 43 | def _gc(forced: bool = False): 44 | global args 45 | if args.disable_gc and not forced: 46 | return 47 | 48 | import gc 49 | gc.collect() 50 | if torch.cuda.is_available(): 51 | torch.cuda.empty_cache() 52 | 53 | 54 | @asynccontextmanager 55 | async def lifespan(app: FastAPI): # collects GPU memory 56 | yield 57 | _gc(forced=True) 58 | 59 | app = FastAPI(lifespan=lifespan) 60 | 61 | app.add_middleware( 62 | CORSMiddleware, 63 | allow_origins=["*"], 64 | allow_credentials=True, 65 | allow_methods=["*"], 66 | allow_headers=["*"], 67 | ) 68 | 69 | 70 | class ModelCard(BaseModel): 71 | id: str 72 | object: str = "model" 73 | created: int = Field(default_factory=lambda: int(time.time())) 74 | owned_by: str = "owner" 75 | root: Optional[str] = None 76 | parent: Optional[str] = None 77 | permission: Optional[list] = None 78 | 79 | 80 | class ModelList(BaseModel): 81 | object: str = "list" 82 | data: List[ModelCard] = [] 83 | 84 | class ChatMessage(BaseModel): 85 | role: Literal["user", "assistant"] 86 | content: Optional[str] 87 | 88 | class DeltaMessage(BaseModel): 89 | role: Optional[Literal["user", "assistant", "system"]] = None 90 | content: Optional[str] = None 91 | 92 | class ChatCompletionRequest(BaseModel): 93 | model: str 94 | messages: List[ChatMessage] 95 | functions: Optional[List[Dict]] = None 96 | temperature: Optional[float] = None 97 | top_p: Optional[float] = None 98 | max_length: Optional[int] = None 99 | stream: Optional[bool] = False 100 | stop: Optional[List[str]] = None 101 | 102 | 103 | class ChatCompletionResponseChoice(BaseModel): 104 | index: int 105 | message: ChatMessage 106 | finish_reason: Literal["stop", "length"] 107 | 108 | 109 | class ChatCompletionResponseStreamChoice(BaseModel): 110 | index: int 111 | delta: DeltaMessage 112 | finish_reason: Optional[Literal["stop", "length"]] 113 | 114 | 115 | class ChatCompletionResponse(BaseModel): 116 | model: str 117 | object: Literal["chat.completion", "chat.completion.chunk"] 118 | choices: List[ 119 | Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice] 120 | ] 121 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 122 | 123 | 124 | @app.get("/v1/models", response_model=ModelList) 125 | async def list_models(): 126 | global model_args 127 | model_card = ModelCard(id="CodeShell-7B-Chat") 128 | return ModelList(data=[model_card]) 129 | 130 | def trim_stop_words(response): 131 | response = response.replace('|end|', '') 132 | response = response.replace('||', '') 133 | return response 134 | 135 | _TEXT_COMPLETION_CMD = object() 136 | 137 | # 138 | # Temporarily, the system role does not work as expected. 
139 | # We advise that you write the setups for role-play in your query, 140 | # i.e., use the user role instead of the system role. 141 | # 142 | # TODO: Use real system role when the model is ready. 143 | # 144 | def parse_messages(messages): 145 | if all(m.role != "user" for m in messages): 146 | raise HTTPException( 147 | status_code=400, 148 | detail=f"Invalid request: Expecting at least one user message.", 149 | ) 150 | 151 | messages = copy.deepcopy(messages) 152 | 153 | _messages = messages 154 | messages = [] 155 | for m_idx, m in enumerate(_messages): 156 | role, content = m.role, m.content 157 | if content: 158 | content = content.lstrip("\n").rstrip() 159 | if role == "assistant": 160 | if len(messages) == 0: 161 | raise HTTPException( 162 | status_code=400, 163 | detail=f"Invalid request: Expecting role user before role assistant.", 164 | ) 165 | last_msg = messages[-1].content 166 | last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 167 | if messages[-1].role == "user": 168 | messages.append( 169 | ChatMessage(role="assistant", content=content.lstrip("\n").rstrip()) 170 | ) 171 | else: 172 | messages[-1].content += content 173 | elif role == "user": 174 | messages.append( 175 | ChatMessage(role="user", content=content.lstrip("\n").rstrip()) 176 | ) 177 | else: 178 | raise HTTPException( 179 | status_code=400, detail=f"Invalid request: Incorrect role {role}." 180 | ) 181 | 182 | query = _TEXT_COMPLETION_CMD 183 | if messages[-1].role == "user": 184 | query = messages[-1].content 185 | messages = messages[:-1] 186 | 187 | if len(messages) % 2 != 0: 188 | raise HTTPException(status_code=400, detail="Invalid request") 189 | 190 | history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] 191 | for i in range(0, len(messages), 2): 192 | if messages[i].role == "user" and messages[i + 1].role == "assistant": 193 | usr_msg = messages[i].content.lstrip("\n").rstrip() 194 | bot_msg = messages[i + 1].content.lstrip("\n").rstrip() 195 | history.append([usr_msg, bot_msg]) 196 | else: 197 | raise HTTPException( 198 | status_code=400, 199 | detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", 200 | ) 201 | return query, history 202 | 203 | 204 | def parse_response(response): 205 | func_name, func_args = "", "" 206 | i = response.rfind("\nAction:") 207 | j = response.rfind("\nAction Input:") 208 | k = response.rfind("\nObservation:") 209 | if 0 <= i < j: # If the text has `Action` and `Action input`, 210 | if k < j: # but does not contain `Observation`, 211 | # then it is likely that `Observation` is omitted by the LLM, 212 | # because the output text may have discarded the stop word. 213 | response = response.rstrip() + "\nObservation:" # Add it back. 
214 | k = response.rfind("\nObservation:") 215 | func_name = response[i + len("\nAction:") : j].strip() 216 | func_args = response[j + len("\nAction Input:") : k].strip() 217 | if func_name: 218 | choice_data = ChatCompletionResponseChoice( 219 | index=0, 220 | message=ChatMessage( 221 | role="assistant", 222 | content=response[:i], 223 | function_call={"name": func_name, "arguments": func_args}, 224 | ), 225 | finish_reason="function_call", 226 | ) 227 | return choice_data 228 | z = response.rfind("\nFinal Answer: ") 229 | if z >= 0: 230 | response = response[z + len("\nFinal Answer: ") :] 231 | choice_data = ChatCompletionResponseChoice( 232 | index=0, 233 | message=ChatMessage(role="assistant", content=response), 234 | finish_reason="stop", 235 | ) 236 | return choice_data 237 | 238 | 239 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 240 | async def create_chat_completion(request: ChatCompletionRequest): 241 | global model, tokenizer 242 | 243 | generation_config = copy.deepcopy(model.generation_config) 244 | if request.temperature is not None: 245 | if request.temperature < 0.01: 246 | generation_config.top_k = 1 # greedy decoding 247 | else: 248 | # Not recommended. Please tune top_p instead. 249 | generation_config.temperature = request.temperature 250 | if request.top_p is not None: 251 | generation_config.top_p = request.top_p 252 | 253 | query, history = parse_messages(request.messages) 254 | 255 | if request.stream: 256 | if request.functions: 257 | raise HTTPException( 258 | status_code=400, 259 | detail="Invalid request: Function calling is not yet implemented for stream mode.", 260 | ) 261 | generate = predict(query, history, request.model, generation_config) 262 | return EventSourceResponse(generate, media_type="text/event-stream") 263 | 264 | if query is _TEXT_COMPLETION_CMD: 265 | raise HTTPException( 266 | status_code=400, 267 | detail="Invalid request: COMPLETION model not supported now.", 268 | ) 269 | else: 270 | response = model.chat( 271 | query, 272 | history, 273 | tokenizer, 274 | generation_config=generation_config 275 | ) 276 | print(f"\n{history}\n{query}\n\n{response}\n") 277 | _gc() 278 | 279 | response = trim_stop_words(response) 280 | if request.functions: 281 | choice_data = parse_response(response) 282 | else: 283 | choice_data = ChatCompletionResponseChoice( 284 | index=0, 285 | message=ChatMessage(role="assistant", content=response), 286 | finish_reason="stop", 287 | ) 288 | return ChatCompletionResponse( 289 | model=request.model, choices=[choice_data], object="chat.completion" 290 | ) 291 | 292 | 293 | async def predict( 294 | query: str, history: List[List[str]], model_id: str, generation_config: GenerationConfig 295 | ): 296 | global model, tokenizer 297 | choice_data = ChatCompletionResponseStreamChoice( 298 | index=0, delta=DeltaMessage(role="assistant"), finish_reason=None 299 | ) 300 | chunk = ChatCompletionResponse( 301 | model=model_id, choices=[choice_data], object="chat.completion.chunk" 302 | ) 303 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 304 | 305 | current_length = 0 306 | response_generator = model.chat(query, history, tokenizer, stream=True, generation_config=generation_config) 307 | 308 | for new_response in response_generator: 309 | if len(new_response) == current_length: 310 | continue 311 | 312 | new_text = new_response[current_length:] 313 | current_length = len(new_response) 314 | 315 | choice_data = ChatCompletionResponseStreamChoice( 316 | index=0, 
delta=DeltaMessage(content=new_text), finish_reason=None 317 | ) 318 | chunk = ChatCompletionResponse( 319 | model=model_id, choices=[choice_data], object="chat.completion.chunk" 320 | ) 321 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 322 | 323 | choice_data = ChatCompletionResponseStreamChoice( 324 | index=0, delta=DeltaMessage(), finish_reason="stop" 325 | ) 326 | chunk = ChatCompletionResponse( 327 | model=model_id, choices=[choice_data], object="chat.completion.chunk" 328 | ) 329 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 330 | yield "[DONE]" 331 | 332 | _gc() 333 | 334 | 335 | def _get_args(): 336 | parser = ArgumentParser() 337 | parser.add_argument( 338 | "-c", 339 | "--checkpoint-path", 340 | type=str, 341 | default="WisdomShell/CodeShell-7B-Chat", 342 | help="Checkpoint name or path, default to %(default)r", 343 | ) 344 | parser.add_argument("--device", type=str, default="cuda:0",help="Device name.") 345 | parser.add_argument( 346 | "--server-port", type=int, default=8000, help="Demo server port." 347 | ) 348 | parser.add_argument( 349 | "--server-name", 350 | type=str, 351 | default="127.0.0.1", 352 | help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer." 353 | " If you want other computers to access your server, use 0.0.0.0 instead.", 354 | ) 355 | parser.add_argument("--disable-gc", action="store_true", 356 | help="Disable GC after each response generated.") 357 | 358 | args = parser.parse_args() 359 | return args 360 | 361 | 362 | if __name__ == "__main__": 363 | args = _get_args() 364 | 365 | tokenizer = AutoTokenizer.from_pretrained( 366 | args.checkpoint_path, 367 | trust_remote_code=True, 368 | resume_download=True, 369 | ) 370 | 371 | model = AutoModelForCausalLM.from_pretrained( 372 | args.checkpoint_path, 373 | device_map=args.device, 374 | trust_remote_code=True, 375 | resume_download=True, 376 | torch_dtype=torch.bfloat16 377 | ).eval() 378 | 379 | model.generation_config = GenerationConfig.from_pretrained( 380 | args.checkpoint_path, 381 | trust_remote_code=True, 382 | resume_download=True, 383 | ) 384 | 385 | uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1) -------------------------------------------------------------------------------- /demos/requirements_web_demo.txt: -------------------------------------------------------------------------------- 1 | gradio<3.42 2 | mdtex2html -------------------------------------------------------------------------------- /demos/web_demo.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This code is based on Qwen's web Demo. It has been modified from 17 | # its original forms to accommodate CodeShell. 18 | 19 | # Copyright (c) Alibaba Cloud. 
20 | # 21 | # This source code is licensed under the license found in the 22 | # LICENSE file in the root directory of this source tree. 23 | 24 | """A simple web interactive chat demo based on gradio.""" 25 | import os 26 | from argparse import ArgumentParser 27 | 28 | import gradio as gr 29 | import mdtex2html 30 | 31 | import torch 32 | from transformers import AutoModelForCausalLM, AutoTokenizer 33 | from transformers.generation import GenerationConfig 34 | 35 | 36 | DEFAULT_CKPT_PATH = 'WisdomShell/CodeShell-7B-Chat' 37 | 38 | 39 | def _get_args(): 40 | parser = ArgumentParser() 41 | parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH, 42 | help="Checkpoint name or path, default to %(default)r") 43 | parser.add_argument("--device", type=str, default="cuda:0", help="GPU device.") 44 | 45 | parser.add_argument("--share", action="store_true", default=False, 46 | help="Create a publicly shareable link for the interface.") 47 | parser.add_argument("--inbrowser", action="store_true", default=False, 48 | help="Automatically launch the interface in a new tab on the default browser.") 49 | parser.add_argument("--server-port", type=int, default=8000, 50 | help="Demo server port.") 51 | parser.add_argument("--server-name", type=str, default="127.0.0.1", 52 | help="Demo server name.") 53 | 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | def _load_model_tokenizer(args): 59 | tokenizer = AutoTokenizer.from_pretrained( 60 | args.checkpoint_path, trust_remote_code=True, resume_download=True, 61 | ) 62 | 63 | model = AutoModelForCausalLM.from_pretrained( 64 | args.checkpoint_path, 65 | device_map=args.device, 66 | trust_remote_code=True, 67 | resume_download=True, 68 | torch_dtype=torch.bfloat16 69 | ).eval() 70 | 71 | config = GenerationConfig.from_pretrained( 72 | args.checkpoint_path, trust_remote_code=True, resume_download=True, 73 | ) 74 | 75 | return model, tokenizer, config 76 | 77 | 78 | def postprocess(self, y): 79 | if y is None: 80 | return [] 81 | for i, (message, response) in enumerate(y): 82 | y[i] = ( 83 | None if message is None else mdtex2html.convert(message), 84 | None if response is None else mdtex2html.convert(response), 85 | ) 86 | return y 87 | 88 | 89 | gr.Chatbot.postprocess = postprocess 90 | 91 | 92 | def _parse_text(text): 93 | lines = text.split("\n") 94 | lines = [line for line in lines if line != ""] 95 | count = 0 96 | for i, line in enumerate(lines): 97 | if "```" in line: 98 | count += 1 99 | items = line.split("`") 100 | if count % 2 == 1: 101 | lines[i] = f'
'
102 |             else:
103 |                 lines[i] = f"
" 104 | else: 105 | if i > 0: 106 | if count % 2 == 1: 107 | line = line.replace("`", r"\`") 108 | line = line.replace("<", "<") 109 | line = line.replace(">", ">") 110 | line = line.replace(" ", " ") 111 | line = line.replace("*", "*") 112 | line = line.replace("_", "_") 113 | line = line.replace("-", "-") 114 | line = line.replace(".", ".") 115 | line = line.replace("!", "!") 116 | line = line.replace("(", "(") 117 | line = line.replace(")", ")") 118 | line = line.replace("$", "$") 119 | lines[i] = "
" + line 120 | text = "".join(lines) 121 | return text 122 | 123 | 124 | def _gc(): 125 | import gc 126 | gc.collect() 127 | if torch.cuda.is_available(): 128 | torch.cuda.empty_cache() 129 | 130 | 131 | def _launch_demo(args, model, tokenizer, config): 132 | 133 | def predict(_query, _chatbot, _task_history): 134 | print(f"User: {_parse_text(_query)}") 135 | _chatbot.append((_parse_text(_query), "")) 136 | full_response = "" 137 | 138 | for response in model.chat(_query, _task_history, tokenizer, generation_config=config, stream=True): 139 | response = response.replace('|end|', '') 140 | response = response.replace('||', '') 141 | _chatbot[-1] = (_parse_text(_query), _parse_text(response)) 142 | 143 | yield _chatbot 144 | full_response = _parse_text(response) 145 | 146 | print(f"History: {_task_history}") 147 | _task_history.append((_query, full_response)) 148 | print(f"CodeShell-Chat: {_parse_text(full_response)}") 149 | 150 | def regenerate(_chatbot, _task_history): 151 | if not _task_history: 152 | yield _chatbot 153 | return 154 | item = _task_history.pop(-1) 155 | _chatbot.pop(-1) 156 | yield from predict(item[0], _chatbot, _task_history) 157 | 158 | def reset_user_input(): 159 | return gr.update(value="") 160 | 161 | def reset_state(_chatbot, _task_history): 162 | _task_history.clear() 163 | _chatbot.clear() 164 | _gc() 165 | return _chatbot 166 | 167 | with gr.Blocks() as demo: 168 | gr.Markdown("""
CodeShell-Chat Bot
""") 169 | 170 | chatbot = gr.Chatbot(label='CodeShell-Chat', elem_classes="control-height") 171 | query = gr.Textbox(lines=2, label='Input') 172 | task_history = gr.State([]) 173 | 174 | with gr.Row(): 175 | empty_btn = gr.Button("🧹 Clear History (清除历史)") 176 | submit_btn = gr.Button("🚀 Submit (发送)") 177 | regen_btn = gr.Button("🤔️ Regenerate (重试)") 178 | 179 | submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True) 180 | submit_btn.click(reset_user_input, [], [query]) 181 | empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True) 182 | regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True) 183 | 184 | gr.Markdown("""\ 185 | Note: This demo is governed by the original license of CodeShell. \ 186 | We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \ 187 | including hate speech, violence, pornography, deception, etc. \ 188 | (注:本演示受CodeShell的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\ 189 | 包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""") 190 | 191 | demo.queue().launch( 192 | share=args.share, 193 | inbrowser=args.inbrowser, 194 | server_port=args.server_port, 195 | server_name=args.server_name, 196 | ) 197 | 198 | 199 | def main(): 200 | args = _get_args() 201 | 202 | model, tokenizer, config = _load_model_tokenizer(args) 203 | 204 | _launch_demo(args, model, tokenizer, config) 205 | 206 | 207 | if __name__ == '__main__': 208 | main() -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # 代码评估 2 | 本文档将向您完整地介绍codeshell的代码评估过程,该评估脚本都是基于[bigcode-evaluation-harness](https://github.com/bigcode-project/bigcode-evaluation-harness)。 3 | 4 | ## 开始步骤 5 | 6 | 首先,复制bigcode-evaluation-harness仓库并导航至所在的文件夹内: 7 | 8 | ```bash 9 | git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git 10 | cd bigcode-evaluation-harness 11 | ``` 12 | 13 | 接下来,依照您设备的规格,安装PyTorch,然后运行以下命令安装剩余的依赖: 14 | 15 | ```bash 16 | pip install -e . 17 | ``` 18 | 19 | 要使用评估脚本生成和评估任务,请按下述样例进行。确保您位于正确的目录中(`codeshell/evaluation`),然后依次执行两个 `run_eval.sh` 命令: 20 | 21 | ```bash 22 | cd codeshell/evaluation 23 | ./run_eval.sh local_gen humaneval $model_name_or_path $save_folder 24 | ./run_eval.sh eval humaneval $model_name_or_path $save_folder 25 | ``` 26 | -------------------------------------------------------------------------------- /evaluation/README_EN.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | This guide introduces the evaluation process of codeshell. The evaluation script is base on [bigcode-evaluation-harness](https://github.com/bigcode-project/bigcode-evaluation-harness). 3 | 4 | ## Quick Start 5 | 6 | Begin by cloning the bigcode-evaluation-harness repository and entering the directory: 7 | 8 | ```bash 9 | git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git 10 | cd bigcode-evaluation-harness 11 | ``` 12 | 13 | Next, install PyTorch according to your device specifications, and then install the remaining packages using the command: 14 | 15 | ```bash 16 | pip install -e . 17 | ``` 18 | 19 | To generate and evaluate tasks with the evaluation script, follow the example below. 
Ensure you are in the appropriate directory (`codeshell/evaluation`), then execute the two `run_eval.sh` commands: 20 | 21 | ```bash 22 | cd codeshell/evaluation 23 | ./run_eval.sh local_gen humaneval $model_name_or_path $save_folder 24 | ./run_eval.sh eval humaneval $model_name_or_path $save_folder 25 | ``` 26 | 27 | By following this guide, you can now effectively utilize the bigcode-evaluation-harness to evaluate your model's performance on specific tasks. -------------------------------------------------------------------------------- /evaluation/all_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | downcast_bf16: 'no' 4 | gpu_ids: all 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 8 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /evaluation/chat_humaneval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Evaluating Large Language Models Trained on Code 17 | https://arxiv.org/abs/2107.03374 18 | 19 | The HumanEval dataset released by OpenAI includes 164 programming problems with a function signature, 20 | docstring, body, and several unit tests. 21 | They were handwritten to ensure not to be included in the training set of code generation models. 22 | 23 | Homepage: https://github.com/openai/human-eval 24 | """ 25 | 26 | from evaluate import load 27 | 28 | from lm_eval.base import Task 29 | 30 | _CITATION = """ 31 | @misc{chen2021evaluating, 32 | title={Evaluating Large Language Models Trained on Code}, 33 | author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. 
Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba}, 34 | year={2021}, 35 | eprint={2107.03374}, 36 | archivePrefix={arXiv}, 37 | primaryClass={cs.LG} 38 | } 39 | """ 40 | 41 | 42 | class ChatHumanEval(Task): 43 | """A task represents an entire benchmark including its dataset, problems, 44 | answers, generation settings and evaluation methods. 45 | """ 46 | 47 | DATASET_PATH = "openai_humaneval" 48 | 49 | def __init__(self): 50 | super().__init__( 51 | stop_words=["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif", "\n\n\n\n", "\n\n\n"], 52 | requires_execution=True, 53 | ) 54 | 55 | def get_dataset(self): 56 | """Returns dataset for the task or an iterable of any object, that get_prompt can handle""" 57 | return self.dataset["test"] 58 | 59 | def get_prompt(self, doc): 60 | """Builds the prompt for the LM to generate from.""" 61 | # return doc["prompt"].strip() 62 | PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n### Instruction:\nCreate a Python script for this problem:\n{Question}\n### Response:\n{Question}""" 63 | 64 | return PROMPT.format(Question=doc["prompt"].strip()) 65 | 66 | def get_reference(self, doc): 67 | """Builds the reference solution for the doc (sample from the test dataset).""" 68 | test_func = doc["test"] 69 | entry_point = f"check({doc['entry_point']})" 70 | return "\n" + test_func + "\n" + entry_point 71 | 72 | @staticmethod 73 | def _stop_at_stop_token(decoded_string, stop_tokens): 74 | """ 75 | Produces the prefix of decoded_string that ends at the first occurrence of 76 | a stop_token. 77 | WARNING: the decoded_string *must not* include the prompt, which may have stop tokens 78 | itself. 79 | """ 80 | min_stop_index = len(decoded_string) 81 | for stop_token in stop_tokens: 82 | stop_index = decoded_string.find(stop_token) 83 | if stop_index != -1 and stop_index < min_stop_index: 84 | min_stop_index = stop_index 85 | return decoded_string[:min_stop_index] 86 | 87 | def postprocess_generation(self, generation, idx): 88 | """Defines the postprocessing for a LM generation. 89 | :param generation: str 90 | code generation from LM 91 | :param idx: int 92 | index of doc in the dataset to which the generation belongs 93 | (not used for Humaneval-Task) 94 | """ 95 | prompt = self.get_prompt(self.dataset["test"][idx]) 96 | # print("before postprocessing", generation) 97 | # print("prefix", prompt) 98 | generation = generation[len(prompt) :] 99 | # print("after postprocessing", generation) 100 | return self.dataset["test"][idx]["prompt"] + self._stop_at_stop_token(generation, self.stop_words) 101 | 102 | def process_results(self, generations, references): 103 | """Takes the list of LM generations and evaluates them against ground truth references, 104 | returning the metric for the generations. 
105 | :param generations: list(list(str)) 106 | list of lists containing generations 107 | :param references: list(str) 108 | list of str containing refrences 109 | """ 110 | code_metric = load("code_eval") 111 | results, _ = code_metric.compute( 112 | references=references, 113 | predictions=generations, 114 | ) 115 | return results 116 | -------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import fnmatch 17 | import json 18 | 19 | import datasets 20 | import torch 21 | import transformers 22 | from accelerate import Accelerator 23 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 24 | 25 | from lm_eval.arguments import EvalArguments 26 | from lm_eval.evaluator import Evaluator 27 | 28 | from lm_eval.tasks import (humaneval, mbpp, multiple) 29 | from .chat_humaneval import ChatHumanEval 30 | 31 | TASK_REGISTRY = { 32 | **multiple.create_all_tasks(), 33 | "humaneval": humaneval.HumanEval, 34 | "mbpp": mbpp.MBPP, 35 | "chat-humaneval": chat_humaneval.ChatHumanEval, 36 | } 37 | 38 | ALL_TASKS = sorted(list(TASK_REGISTRY)) 39 | 40 | class MultiChoice: 41 | def __init__(self, choices): 42 | self.choices = choices 43 | 44 | # Simple wildcard support (linux filename patterns) 45 | def __contains__(self, values): 46 | for value in values.split(","): 47 | if len(fnmatch.filter(self.choices, value)) == 0: 48 | return False 49 | 50 | return True 51 | 52 | def __iter__(self): 53 | for choice in self.choices: 54 | yield choice 55 | 56 | 57 | def parse_args(): 58 | parser = HfArgumentParser(EvalArguments) 59 | 60 | parser.add_argument( 61 | "--model", 62 | default="codeparrot/codeparrot-small", 63 | help="Model to evaluate, provide a repo name in Hugging Face hub or a local path", 64 | ) 65 | parser.add_argument( 66 | "--tasks", 67 | default=None, 68 | choices=MultiChoice(ALL_TASKS), 69 | help=f"Evaluation tasks from {ALL_TASKS}", 70 | ) 71 | parser.add_argument( 72 | "--batch_size", 73 | type=int, 74 | default=1, 75 | help="Batch size for evaluation on each worker, can be larger for HumanEval", 76 | ) 77 | parser.add_argument( 78 | "--max_length_generation", 79 | type=int, 80 | default=1024, 81 | help="Maximum length of generated sequence (prompt+generation)", 82 | ) 83 | parser.add_argument( 84 | "--precision", 85 | type=str, 86 | default="bf16", 87 | help="Model precision, from: fp32, fp16 or bf16", 88 | ) 89 | parser.add_argument( 90 | "--limit", 91 | type=int, 92 | default=None, 93 | help="Number of samples to solve and evaluate from the benchmark", 94 | ) 95 | parser.add_argument( 96 | "--postprocess", 97 | action="store_false", 98 | help="Postprocess model outputs before execution, always on except during generation tests", 99 | ) 100 | parser.add_argument( 101 | 
"--allow_code_execution", 102 | action="store_true", 103 | help="Allow code evaluation to execute external/untrusted Python code on your machine", 104 | ) 105 | parser.add_argument( 106 | "--generation_only", 107 | action="store_true", 108 | help="Do code generation but no evaluation", 109 | ) 110 | parser.add_argument( 111 | "--load_generations_path", 112 | type=str, 113 | default=None, 114 | help="Path of file with previously generated solutions, if provided generation is skipped and only evaluation is done", 115 | ) 116 | parser.add_argument( 117 | "--metric_output_path", 118 | type=str, 119 | default="evaluation_results.json", 120 | help="Path to save the results", 121 | ) 122 | parser.add_argument( 123 | "--save_generations", 124 | action="store_true", 125 | help="Whether to save code generations", 126 | ) 127 | parser.add_argument( 128 | "--save_generations_path", 129 | type=str, 130 | default="generations.json", 131 | help="Path for saving the code generations", 132 | ) 133 | parser.add_argument( 134 | "--save_references", 135 | action="store_true", 136 | help="Whether to save reference solutions/tests", 137 | ) 138 | parser.add_argument( 139 | "--save_references_path", 140 | type=str, 141 | default="references.json", 142 | help="Path for saving the code references", 143 | ) 144 | return parser.parse_args() 145 | 146 | 147 | def pattern_match(patterns, source_list): 148 | """Returns a list containing all values of the source_list that 149 | match at least one of the patterns""" 150 | task_names = set() 151 | for pattern in patterns: 152 | for matching in fnmatch.filter(source_list, pattern): 153 | task_names.add(matching) 154 | return list(task_names) 155 | 156 | 157 | def main(): 158 | args = parse_args() 159 | transformers.logging.set_verbosity_error() 160 | datasets.logging.set_verbosity_error() 161 | 162 | task_names = pattern_match(args.tasks.split(","), ALL_TASKS) 163 | 164 | accelerator = Accelerator() 165 | if accelerator.is_main_process: 166 | print(f"Selected Tasks: {task_names}") 167 | 168 | results = {} 169 | if args.load_generations_path: 170 | # here we don't generate code but only evaluate previously computed generations 171 | if accelerator.is_main_process: 172 | print("evaluation only mode") 173 | evaluator = Evaluator(accelerator, None, None, args) 174 | for task in task_names: 175 | results[task] = evaluator.evaluate(task) 176 | else: 177 | # here we generate code and save it (evaluation is optional but True by default) 178 | dict_precisions = { 179 | "fp32": torch.float32, 180 | "fp16": torch.float16, 181 | "bf16": torch.bfloat16, 182 | } 183 | if args.precision not in dict_precisions: 184 | raise ValueError( 185 | f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16" 186 | ) 187 | print(f"Loading tokenizer and model (in {args.precision})") 188 | model = AutoModelForCausalLM.from_pretrained( 189 | args.model, 190 | revision=args.revision, 191 | torch_dtype=dict_precisions[args.precision], 192 | trust_remote_code=True, 193 | ) 194 | 195 | tokenizer_path = args.model 196 | tokenizer = AutoTokenizer.from_pretrained( 197 | tokenizer_path, 198 | revision=args.revision, 199 | trust_remote_code=True, 200 | ) 201 | if not tokenizer.eos_token: 202 | if tokenizer.bos_token: 203 | tokenizer.eos_token = tokenizer.bos_token 204 | print("bos_token used as eos_token") 205 | else: 206 | raise ValueError("No eos_token or bos_token found") 207 | if not tokenizer.pad_token: 208 | tokenizer.pad_token = tokenizer.eos_token 209 | 210 | evaluator = 
Evaluator(accelerator, model, tokenizer, args) 211 | 212 | for task in task_names: 213 | if args.generation_only: 214 | if accelerator.is_main_process: 215 | print("generation mode only") 216 | generations, references = evaluator.generate_text(task) 217 | if accelerator.is_main_process: 218 | with open(args.save_generations_path, "w", encoding="utf-8") as fp: 219 | json.dump(generations, fp, indent=4, ensure_ascii=False) 220 | print(f"generations were saved at {args.save_generations_path}") 221 | if args.save_references: 222 | with open(args.save_references_path, "w", encoding="utf-8") as fp: 223 | json.dump(references, fp, indent=4, ensure_ascii=False) 224 | print("references were saved") 225 | else: 226 | results[task] = evaluator.evaluate(task) 227 | 228 | results["config"] = { 229 | "model": args.model, 230 | "revision": args.revision, 231 | "temperature": args.temperature, 232 | "n_samples": args.n_samples, 233 | } 234 | if not args.generation_only: 235 | dumped = json.dumps(results, indent=2) 236 | if accelerator.is_main_process: 237 | print(dumped) 238 | 239 | with open(args.metric_output_path, "w") as f: 240 | f.write(dumped) 241 | 242 | if __name__ == "__main__": 243 | main() 244 | -------------------------------------------------------------------------------- /evaluation/run_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export TOKENIZERS_PARALLELISM=false 3 | export TRANSFORMERS_OFFLINE=0 4 | 5 | function local_gen() { 6 | project_dir=$(cd "$(dirname $0)"; pwd) 7 | accelerate_args="--main_process_port=$((10000 + RANDOM % 20000)) --config_file=$project_dir/all_config.yaml" 8 | 9 | dataset=$1 10 | model_name_or_path=$2 11 | run_id=$3 12 | n_samples=40 13 | 14 | batch_size=10 15 | 16 | mkdir -p $project_dir/log/$run_id/$dataset 17 | 18 | accelerate launch $accelerate_args eval.py \ 19 | --model $model_name_or_path \ 20 | --tasks $dataset \ 21 | --max_length_generation 2048 \ 22 | --temperature 0.2 \ 23 | --precision bf16 \ 24 | --do_sample True \ 25 | --n_samples $n_samples \ 26 | --batch_size $batch_size \ 27 | --save_generations \ 28 | --save_references \ 29 | --generation_only \ 30 | --save_generations_path $project_dir/log/$run_id/$dataset/generations.json \ 31 | --save_references_path $project_dir/log/$run_id/$dataset/references.json \ 32 | --metric_output_path $project_dir/log/$run_id/$dataset/evaluation.json \ 33 | | tee $project_dir/log/$run_id/$dataset/evaluation.log 2>&1 34 | 35 | } 36 | 37 | function eval() { 38 | dataset=$1 39 | model_name_or_path=$2 40 | run_id=$3 41 | 42 | project_dir=$(cd "$(dirname $0)"; pwd) 43 | 44 | python3 eval.py \ 45 | --model $model_name_or_path \ 46 | --tasks $dataset \ 47 | --load_generations_path $project_dir/log/$run_id/$dataset/generations.json \ 48 | --allow_code_execution \ 49 | --n_samples 40 \ 50 | --metric_output_path $project_dir/log/$run_id/$dataset/evaluation.json 51 | 52 | } 53 | 54 | task=$1 55 | dataset=$2 56 | model_name_or_path=$3 57 | run_id=$4 58 | 59 | if [ $task == "local_gen" ]; then 60 | local_gen $dataset $model_name_or_path $run_id 61 | elif [ $task == "eval" ]; then 62 | eval $dataset $model_name_or_path $run_id 63 | elif [ $task == "help" ]; then 64 | echo "./scripts/run_eval.sh [local_gen|eval] [humaneval|mbpp|chat-humaneval|multiple-*] model_name_or_path run_id" 65 | else 66 | echo "task should be local_gen or eval" 67 | fi -------------------------------------------------------------------------------- /finetune/README.md: 
-------------------------------------------------------------------------------- 1 | 该文档为希望在特定领域任务中应用CodeShell模型的用户提供了官方微调示例。 2 | 3 | 开始前,您需要通过执行以下命令来配置必要的环境: 4 | ```bash 5 | pip install peft deepspeed 6 | ``` 7 | 8 | 您需要按照 JSON 格式整理训练数据,其中每个样本是一个包含 ID 和对话列表的字典。该对话列表是消息对象的数组,代表了用户和助手之间的交谈。如下所示为一个样例: 9 | 10 | ```json 11 | [ 12 | { 13 | "id": "identity_0", 14 | "conversations": [ 15 | { 16 | "from": "human", 17 | "value": "你好" 18 | }, 19 | { 20 | "from": "assistant", 21 | "value": "您好,我是CodeShell,请问有什么可以帮助您的吗?" 22 | } 23 | ] 24 | } 25 | ] 26 | ``` 27 | 28 | 当数据准备完毕后,导航至微调的目录并执行 `run_finetune.sh` 脚本,命令如下: 29 | 30 | ```bash 31 | cd codeshell/finetune 32 | ./run_finetune.sh $model_name_or_path $dataset_path $save_path 33 | ``` 34 | 35 | 按照这些步骤操作,您可以将预训练的模型微调,使其更加精确地适应您的特定任务。 36 | 37 | 该微调脚本基于qwen、fastchat 和 tatsu-lab/stanford_alpaca 的微调脚本。 -------------------------------------------------------------------------------- /finetune/README_EN.md: -------------------------------------------------------------------------------- 1 | In this guide, we present the official fine-tuning script for users who wish to adapt pretrained models for their domain-specific tasks. 2 | 3 | To begin, please set up the required environment by executing the following command: 4 | ```bash 5 | pip install peft deepspeed 6 | ``` 7 | 8 | The training data should be organized in JSON format, with each sample being a dictionary containing an ID and a conversation list. The conversation list is an array of message objects, representing the conversation between the user and the assistant. See the example below: 9 | 10 | ```json 11 | [ 12 | { 13 | "id": "identity_0", 14 | "conversations": [ 15 | { 16 | "from": "human", 17 | "value": "你好" 18 | }, 19 | { 20 | "from": "assistant", 21 | "value": "您好,我是CodeShell,请问有什么可以帮助您的吗?" 22 | } 23 | ] 24 | } 25 | ] 26 | ``` 27 | 28 | Once the data is prepared, navigate to the fine-tuning directory and execute the `run_finetune.sh` script using the following command: 29 | 30 | ```bash 31 | cd codeshell/finetune 32 | ./run_finetune.sh $model_name_or_path $dataset_path $save_path 33 | ``` 34 | 35 | By following these instructions, you'll be able to fine-tune the pretrained model to better suit your specific downstream tasks. 36 | 37 | This code is based on the revised code from qwen, fastchat and tatsu-lab/stanford_alpaca. 
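As a quick sanity check, the sketch below shows one way to generate and validate a dataset file in the format described above before launching `run_finetune.sh` (the file name and the sample record are illustrative; `"copilot"` is also accepted as a sender by `finetune.py`):

```python
import json

# Illustrative records following the schema described above.
samples = [
    {
        "id": "identity_0",
        "conversations": [
            {"from": "human", "value": "Write a function that reverses a string."},
            {"from": "assistant", "value": "def reverse(s):\n    return s[::-1]"},
        ],
    }
]

# Minimal structural checks mirroring what finetune.py expects.
for sample in samples:
    assert "id" in sample and "conversations" in sample
    for message in sample["conversations"]:
        assert message["from"] in {"human", "assistant", "copilot"}
        assert isinstance(message["value"], str)

with open("train_data.json", "w", encoding="utf-8") as f:  # illustrative path
    json.dump(samples, f, ensure_ascii=False, indent=2)
```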
-------------------------------------------------------------------------------- /finetune/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | 15 | "zero_optimization": { 16 | "stage": 3, 17 | "offload_optimizer": { 18 | "device": "none", 19 | "pin_memory": true 20 | }, 21 | "offload_param": { 22 | "device": "none", 23 | "pin_memory": true 24 | }, 25 | "overlap_comm": true, 26 | "contiguous_gradients": true, 27 | "sub_group_size": 1e9, 28 | "reduce_bucket_size": "auto", 29 | "stage3_prefetch_bucket_size": "auto", 30 | "stage3_param_persistence_threshold": "auto", 31 | "stage3_max_live_parameters": 1e9, 32 | "stage3_max_reuse_distance": 1e9, 33 | "stage3_gather_16bit_weights_on_model_save": true 34 | }, 35 | 36 | "gradient_accumulation_steps": "auto", 37 | "gradient_clipping": "auto", 38 | "steps_per_print": 2000, 39 | "train_batch_size": "auto", 40 | "train_micro_batch_size_per_gpu": "auto", 41 | "wall_clock_breakdown": false 42 | } 43 | -------------------------------------------------------------------------------- /finetune/finetune.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass, field 3 | from typing import Dict, Optional 4 | 5 | import torch 6 | import transformers 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | from transformers.training_args import TrainingArguments 10 | 11 | 12 | @dataclass 13 | class ModelArguments: 14 | model_name_or_path: Optional[str] = field(default="gpt2") 15 | 16 | 17 | @dataclass 18 | class DataArguments: 19 | data_path: str = field( 20 | default=None, metadata={"help": "Path to the training data."} 21 | ) 22 | 23 | 24 | @dataclass 25 | class TrainingArguments(transformers.TrainingArguments): 26 | cache_dir: Optional[str] = field(default=None) 27 | optim: str = field(default="adamw_torch") 28 | model_max_length: int = field( 29 | default=512, 30 | metadata={ 31 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
32 | }, 33 | ) 34 | use_lora: bool = field(default=False) 35 | 36 | 37 | class SupervisedDataset(Dataset): 38 | """Dataset for supervised fine-tuning.""" 39 | 40 | def __init__( 41 | self, 42 | data_path, 43 | tokenizer, 44 | model_max_length, 45 | user_string="## human:", 46 | copilot_string="## copilot:", 47 | assistant_string="## assistant:", 48 | end_string=" || ", 49 | ): 50 | super(SupervisedDataset, self).__init__() 51 | self.data = json.load(open(data_path)) 52 | # self.data = self.data[:1000] 53 | self.tokenizer = tokenizer 54 | self.model_max_length = model_max_length 55 | self.user_string = user_string 56 | self.assistant_string = assistant_string 57 | self.end_string = end_string 58 | self.user_tokens = self.tokenizer.encode(user_string) 59 | self.copilot_tokens = self.tokenizer.encode(copilot_string) 60 | self.assistant_tokens = self.tokenizer.encode(assistant_string) 61 | self.end_tokens = self.tokenizer.encode(end_string) 62 | self.ignore_index = -100 63 | 64 | self.preprocessed_data = self.preprocessing() 65 | item = self.preprocessed_data[0] 66 | print("input:", self.tokenizer.decode(item["input_ids"])) 67 | labels = [] 68 | for id_ in item["labels"]: 69 | if id_ == -100: 70 | continue 71 | 72 | labels.append(id_) 73 | print("label:", self.tokenizer.decode(labels)) 74 | 75 | def __len__(self): 76 | return len(self.preprocessed_data) 77 | 78 | def preprocessing(self): 79 | preprocessed_data = [] 80 | for example in tqdm(self.data, desc="Preprocessing"): 81 | preprocess_example = self.preprocess_one(example) 82 | if len(preprocess_example["input_ids"]) <= 16: 83 | continue 84 | preprocessed_data.append(preprocess_example) 85 | return preprocessed_data 86 | 87 | def preprocess_one(self, example): 88 | input_ids = [] 89 | labels = [] 90 | 91 | chat_mode = "human" 92 | if "copilot" in [message["from"] for message in example["conversations"]]: 93 | chat_mode = "copilot" 94 | 95 | if chat_mode == "human": 96 | for idx, message in enumerate(example["conversations"]): 97 | if idx == 0: 98 | input_ids += [self.tokenizer.eos_token_id] 99 | labels += [self.ignore_index] 100 | from_ = message["from"] 101 | value = message["value"] 102 | value_ids = self.tokenizer.encode(value) 103 | 104 | if len(input_ids) + len(self.user_tokens + value_ids + self.end_tokens) > self.model_max_length: 105 | break 106 | 107 | if from_ == "human": 108 | input_ids += self.user_tokens + value_ids + self.end_tokens 109 | labels += [self.ignore_index] * len( 110 | self.user_tokens + value_ids + self.end_tokens 111 | ) 112 | else: 113 | input_ids += self.assistant_tokens + value_ids + self.end_tokens 114 | labels += [self.ignore_index] * len(self.assistant_tokens) \ 115 | + value_ids + self.end_tokens 116 | elif chat_mode == "copilot": 117 | for idx, message in enumerate(example["conversations"]): 118 | from_ = message["from"] 119 | value = message["value"] 120 | value_ids = self.tokenizer.encode(value) 121 | 122 | if len(input_ids) + len(value_ids) > self.model_max_length: 123 | break 124 | 125 | if from_ == "copilot": 126 | input_ids += value_ids 127 | labels += [self.ignore_index] * len(value_ids) 128 | else: 129 | input_ids += value_ids + [self.tokenizer.eos_token_id] 130 | labels += value_ids + [self.tokenizer.eos_token_id] 131 | else: 132 | raise ValueError("chat_mode should be human or copilot") 133 | 134 | input_ids = input_ids[-self.model_max_length:] 135 | labels = labels[-self.model_max_length:] 136 | input_ids = torch.LongTensor(input_ids) 137 | labels = torch.LongTensor(labels) 138 | 
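# Positions whose id equals eos_token_id are treated as padding here (train() falls back to pad_token_id = eos_token_id when no pad token is set), so they are excluded from attention.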
attention_mask = input_ids.ne(self.tokenizer.eos_token_id) 139 | return { 140 | "input_ids": input_ids, 141 | "labels": labels, 142 | "attention_mask": attention_mask, 143 | } 144 | 145 | def __getitem__(self, idx) -> Dict[str, torch.Tensor]: 146 | # print(self.preprocessed_data[idx]["input_ids"].shape) 147 | return self.preprocessed_data[idx] 148 | 149 | def print_dataset_example(self, num=3): 150 | for idx in range(num): 151 | example = self.preprocessed_data[idx] 152 | print("input_ids:\n{}".format(example["input_ids"])) 153 | print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 154 | print("label_ids:\n{}".format(example["labels"])) 155 | print("labels:\n{}".format( 156 | self.tokenizer.decode([d if d != self.ignore_index else self.tokenizer.eos_token_id for d in example["labels"]], 157 | skip_special_tokens=False) 158 | )) 159 | 160 | 161 | def train(): 162 | parser = transformers.HfArgumentParser( 163 | (ModelArguments, DataArguments, TrainingArguments) 164 | ) 165 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 166 | 167 | model = transformers.AutoModelForCausalLM.from_pretrained( 168 | model_args.model_name_or_path, 169 | trust_remote_code=True, 170 | cache_dir=training_args.cache_dir, 171 | ) 172 | tokenizer = transformers.AutoTokenizer.from_pretrained( 173 | model_args.model_name_or_path, 174 | use_fast=True, 175 | trust_remote_code=True, 176 | model_max_length=training_args.model_max_length, 177 | ) 178 | 179 | # tokenizer.eos_token_id = 70000 180 | # tokenizer.eos_token = "<|endoftext|>" 181 | 182 | if tokenizer.eos_token_id is None: 183 | tokenizer.eos_token_id = tokenizer.bos_token_id 184 | if tokenizer.eos_token is None: 185 | tokenizer.eos_token = tokenizer.bos_token 186 | if tokenizer.pad_token_id is None: 187 | tokenizer.pad_token_id = tokenizer.eos_token_id 188 | if tokenizer.pad_token is None: 189 | tokenizer.pad_token = tokenizer.eos_token 190 | 191 | if training_args.use_lora: 192 | from peft import LoraConfig, TaskType, get_peft_model 193 | 194 | peft_config = LoraConfig( 195 | task_type=TaskType.CAUSAL_LM, 196 | target_modules=["c_attn"], 197 | inference_mode=False, 198 | r=1, 199 | lora_alpha=32, 200 | lora_dropout=0.1, 201 | ) 202 | model.enable_input_require_grads() 203 | model = get_peft_model(model, peft_config) 204 | model.print_trainable_parameters() 205 | 206 | dataset = SupervisedDataset( 207 | data_args.data_path, tokenizer, training_args.model_max_length 208 | ) 209 | dataset.print_dataset_example() 210 | 211 | trainer = transformers.Trainer( 212 | model=model, args=training_args, train_dataset=dataset, tokenizer=tokenizer 213 | ) 214 | trainer.train() 215 | trainer.save_state() 216 | trainer.save_model(output_dir=training_args.output_dir) 217 | 218 | 219 | if __name__ == "__main__": 220 | train() -------------------------------------------------------------------------------- /finetune/run_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export WANDB_DISABLED=true 4 | 5 | project_dir=$(cd "$(dirname $0)"; pwd) 6 | 7 | model=$1 8 | data_path=$2 9 | exp_id=$3 10 | 11 | output_dir=${project_dir}/output_models/${exp_id} 12 | log_dir=${project_dir}/log/${exp_id} 13 | mkdir -p ${output_dir} ${log_dir} 14 | 15 | # deepspeed_args="--master_port=23333 --hostfile=${project_dir}/configs/hostfile.txt --master_addr=10.0.0.16" # Default argument 16 | # deepspeed_args="--master_port=$((10000 + RANDOM % 20000)) 
--include=localhost:0,1,2,3" # Default argument 17 | deepspeed_args="--master_port=$((10000 + RANDOM % 20000))" # Default argument 18 | 19 | deepspeed ${deepspeed_args} ${project_dir}/finetune.py \ 20 | --deepspeed ${project_dir}/ds_config_zero3.json \ 21 | --model_name_or_path ${model} \ 22 | --data_path ${data_path} \ 23 | --model_max_length 4096 \ 24 | --output_dir ${output_dir} \ 25 | --per_device_train_batch_size 1 \ 26 | --gradient_accumulation_steps 8 \ 27 | --gradient_checkpointing True \ 28 | --lr_scheduler_type cosine \ 29 | --logging_steps 1 \ 30 | --save_steps 100 \ 31 | --learning_rate 2e-5 \ 32 | --num_train_epochs 3 \ 33 | --bf16 \ 34 | | tee ${log_dir}/train.log \ 35 | 2> ${log_dir}/train.err 36 | -------------------------------------------------------------------------------- /model/configuration_codeshell.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This code is based on Bigcode's GPTBigCode configuration. It has been modified from 17 | # its original forms to accommodate minor architectural differences compared to 18 | # GPTBigCode Configuration that trained the model. 19 | 20 | # coding=utf-8 21 | # Copyright 2023 The BigCode team and HuggingFace Inc. team. 22 | # 23 | # Licensed under the Apache License, Version 2.0 (the "License"); 24 | # you may not use this file except in compliance with the License. 25 | # You may obtain a copy of the License at 26 | # 27 | # http://www.apache.org/licenses/LICENSE-2.0 28 | # 29 | # Unless required by applicable law or agreed to in writing, software 30 | # distributed under the License is distributed on an "AS IS" BASIS, 31 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | # See the License for the specific language governing permissions and 33 | # limitations under the License. 34 | """ CodeShell configuration""" 35 | 36 | from transformers.configuration_utils import PretrainedConfig 37 | from transformers.utils import logging 38 | 39 | 40 | logger = logging.get_logger(__name__) 41 | 42 | 43 | class CodeShellConfig(PretrainedConfig): 44 | """ 45 | This is the configuration class to store the configuration of a [`CodeShellModel`]. It is used to instantiate a 46 | CodeShell model according to the specified arguments, defining the model architecture. 47 | 48 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 49 | documentation from [`PretrainedConfig`] for more information. 50 | 51 | Args: 52 | vocab_size (`int`, *optional*, defaults to 50257): 53 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 54 | `inputs_ids` passed when calling [`CodeShellModel`]. 
55 | n_positions (`int`, *optional*, defaults to 8192): 56 | The maximum sequence length that this model might ever be used with. Typically set this to something large 57 | just in case (e.g., 512 or 1024 or 2048). 58 | n_embd (`int`, *optional*, defaults to 4096): 59 | Dimensionality of the embeddings and hidden states. 60 | n_layer (`int`, *optional*, defaults to 42): 61 | Number of hidden layers in the Transformer encoder. 62 | n_head (`int`, *optional*, defaults to 32): 63 | Number of attention heads for each attention layer in the Transformer encoder. 64 | n_inner (`int`, *optional*, defaults to None): 65 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 66 | activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): 67 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", 68 | "gelu_pytorch_tanh"]`. 69 | resid_pdrop (`float`, *optional*, defaults to 0.1): 70 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 71 | embd_pdrop (`float`, *optional*, defaults to 0.1): 72 | The dropout ratio for the embeddings. 73 | attn_pdrop (`float`, *optional*, defaults to 0.1): 74 | The dropout ratio for the attention. 75 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 76 | The epsilon to use in the layer normalization layers. 77 | initializer_range (`float`, *optional*, defaults to 0.02): 78 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 79 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 80 | Scale attention weights by dividing by sqrt(hidden_size). 81 | use_cache (`bool`, *optional*, defaults to `True`): 82 | Whether or not the model should return the last key/values attentions (not used by all models). 83 | attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): 84 | Whether to call the fused softmax in float32. 85 | scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): 86 | Whether to scale the attention softmax in float32. 87 | group_query_attention (`bool`, *optional*, defaults to `True`): 88 | Whether to use Grouped-Query Attention (`True`) or Multi-Head Attention (`False`).
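num_query_groups (`int`, *optional*, defaults to 1): Number of key/value head groups used when `group_query_attention` is `True`; 1 corresponds to multi-query attention.
position_embedding_type (`str`, *optional*, defaults to `"learned_absolute"`): Position embedding scheme, either `"learned_absolute"` or `"rope"`.
rope_scaling (`dict`, *optional*, defaults to `None`): RoPE scaling configuration such as `{"type": "linear", "factor": 2.0}` or `{"type": "dynamic", "factor": 2.0}`; only used when `position_embedding_type` is `"rope"`.

Example (a minimal sketch; the import path is an assumption and should be adapted to however this module is loaded, e.g. via `trust_remote_code`):

```python
from configuration_codeshell import CodeShellConfig  # assumed import path

# Default configuration, using the defaults documented above
config = CodeShellConfig()

# RoPE with dynamic NTK scaling for longer contexts
long_ctx_config = CodeShellConfig(
    position_embedding_type="rope",
    rope_scaling={"type": "dynamic", "factor": 2.0},
)
```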
89 | """ 90 | 91 | model_type = "codeshell" 92 | keys_to_ignore_at_inference = ["past_key_values"] 93 | attribute_map = { 94 | "hidden_size": "n_embd", 95 | "max_position_embeddings": "n_positions", 96 | "num_attention_heads": "n_head", 97 | "num_hidden_layers": "n_layer", 98 | } 99 | 100 | def __init__( 101 | self, 102 | vocab_size=70144, 103 | n_positions=8192, 104 | n_embd=4096, 105 | n_layer=42, 106 | n_head=32, 107 | n_inner=None, 108 | activation_function="gelu_pytorch_tanh", 109 | resid_pdrop=0.1, 110 | embd_pdrop=0.1, 111 | attn_pdrop=0.1, 112 | layer_norm_epsilon=1e-5, 113 | initializer_range=0.02, 114 | scale_attn_weights=True, 115 | use_cache=True, 116 | bos_token_id=70000, 117 | eos_token_id=70000, 118 | attention_softmax_in_fp32=True, 119 | scale_attention_softmax_in_fp32=True, 120 | group_query_attention=True, 121 | num_query_groups=1, 122 | position_embedding_type="learned_absolute", 123 | rope_scaling=None, 124 | **kwargs, 125 | ): 126 | self.vocab_size = vocab_size 127 | self.n_positions = n_positions 128 | self.n_embd = n_embd 129 | self.n_layer = n_layer 130 | self.n_head = n_head 131 | self.n_inner = n_inner 132 | self.activation_function = activation_function 133 | self.resid_pdrop = resid_pdrop 134 | self.embd_pdrop = embd_pdrop 135 | self.attn_pdrop = attn_pdrop 136 | self.layer_norm_epsilon = layer_norm_epsilon 137 | self.initializer_range = initializer_range 138 | self.scale_attn_weights = scale_attn_weights 139 | self.use_cache = use_cache 140 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32 141 | self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 142 | self.group_query_attention = group_query_attention 143 | self.num_query_groups = num_query_groups 144 | self.position_embedding_type = position_embedding_type 145 | self.rope_scaling = rope_scaling 146 | assert self.position_embedding_type in [ 147 | "learned_absolute", "rope" 148 | ], "position_embedding_type must be one of ['learned_absolute', 'rope']" 149 | 150 | self.bos_token_id = bos_token_id 151 | self.eos_token_id = eos_token_id 152 | 153 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 154 | -------------------------------------------------------------------------------- /model/modeling_codeshell.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This code is based on Bigcode's GPTBigCode model. It has been modified from 17 | # its original forms to accommodate minor architectural differences compared to 18 | # GPTBigCode model that trained the model. 19 | 20 | # Copyright 2023 The Bigcode team and HuggingFace Inc. team. 21 | # Licensed under the Apache License, Version 2.0 (the "License"); 22 | # you may not use this file except in compliance with the License. 
23 | # You may obtain a copy of the License at 24 | # 25 | # http://www.apache.org/licenses/LICENSE-2.0 26 | # 27 | # Unless required by applicable law or agreed to in writing, software 28 | # distributed under the License is distributed on an "AS IS" BASIS, 29 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | # See the License for the specific language governing permissions and 31 | # limitations under the License. 32 | """PyTorch CodeShell model.""" 33 | import os 34 | import math 35 | from typing import List, Optional, Tuple, Union, Callable 36 | from threading import Thread 37 | from queue import Queue 38 | 39 | 40 | import torch 41 | import torch.utils.checkpoint 42 | from torch import nn 43 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 44 | 45 | from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig 46 | from transformers.generation.utils import GenerationConfig 47 | 48 | from transformers.activations import ACT2FN 49 | from transformers.modeling_outputs import ( 50 | BaseModelOutputWithPastAndCrossAttentions, 51 | CausalLMOutputWithCrossAttentions, 52 | ) 53 | from transformers.modeling_utils import PreTrainedModel 54 | from transformers.utils import ( 55 | add_start_docstrings, 56 | add_start_docstrings_to_model_forward, 57 | ) 58 | from .configuration_codeshell import CodeShellConfig 59 | 60 | # Fused kernels 61 | # Use separate functions for each case because conditionals prevent kernel fusion. 62 | # TODO: Could have better fused kernels depending on scaling, dropout and head mask. 63 | # Is it doable without writing 32 functions? 64 | @torch.jit.script 65 | def upcast_masked_softmax( 66 | x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype 67 | ): 68 | input_dtype = x.dtype 69 | x = x.to(softmax_dtype) * scale 70 | x = torch.where(mask, x, mask_value) 71 | x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) 72 | return x 73 | 74 | 75 | @torch.jit.script 76 | def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): 77 | input_dtype = x.dtype 78 | x = x.to(softmax_dtype) * scale 79 | x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) 80 | return x 81 | 82 | 83 | @torch.jit.script 84 | def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): 85 | x = torch.where(mask, x, mask_value) 86 | x = torch.nn.functional.softmax(x, dim=-1) 87 | return x 88 | 89 | 90 | class CodeShellRotaryEmbedding(torch.nn.Module): 91 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 92 | super().__init__() 93 | 94 | self.dim = dim 95 | self.max_position_embeddings = max_position_embeddings 96 | self.base = base 97 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 98 | self.register_buffer("inv_freq", inv_freq) 99 | 100 | # Build here to make `torch.jit.trace` work. 
101 | self._set_cos_sin_cache( 102 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 103 | ) 104 | 105 | def _set_cos_sin_cache(self, seq_len, device, dtype): 106 | self.max_seq_len_cached = seq_len 107 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 108 | 109 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 110 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 111 | emb = torch.cat((freqs, freqs), dim=-1) 112 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 113 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 114 | 115 | def forward(self, x, seq_len=None): 116 | # x: [bs, num_attention_heads, seq_len, head_size] 117 | if seq_len > self.max_seq_len_cached: 118 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 119 | 120 | return ( 121 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 122 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 123 | ) 124 | 125 | 126 | class CodeShellLinearScalingRotaryEmbedding(CodeShellRotaryEmbedding): 127 | """CodeShellRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 128 | 129 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 130 | self.scaling_factor = scaling_factor 131 | super().__init__(dim, max_position_embeddings, base, device) 132 | 133 | def _set_cos_sin_cache(self, seq_len, device, dtype): 134 | self.max_seq_len_cached = seq_len 135 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 136 | t = t / self.scaling_factor 137 | 138 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 139 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 140 | emb = torch.cat((freqs, freqs), dim=-1) 141 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 142 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 143 | 144 | 145 | class CodeShellDynamicNTKScalingRotaryEmbedding(CodeShellRotaryEmbedding): 146 | """ShellRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 147 | 148 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 149 | self.scaling_factor = scaling_factor 150 | super().__init__(dim, max_position_embeddings, base, device) 151 | 152 | def _set_cos_sin_cache(self, seq_len, device, dtype): 153 | self.max_seq_len_cached = seq_len 154 | 155 | if seq_len > self.max_position_embeddings: 156 | base = self.base * ( 157 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 158 | ) ** (self.dim / (self.dim - 2)) 159 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 160 | self.register_buffer("inv_freq", inv_freq) 161 | 162 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 163 | 164 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 165 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 166 | emb = torch.cat((freqs, freqs), dim=-1) 167 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) 168 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) 169 | 170 | def rotate_half(x): 171 | """Rotates half the hidden dims of the input.""" 172 | x1 = x[..., : x.shape[-1] // 2] 173 | x2 = x[..., x.shape[-1] // 2 :] 174 | return torch.cat((-x2, x1), dim=-1) 175 | 176 | 177 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 178 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 179 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 180 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 181 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 182 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 183 | q_embed = (q * cos) + (rotate_half(q) * sin) 184 | k_embed = (k * cos) + (rotate_half(k) * sin) 185 | return q_embed, k_embed 186 | 187 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 188 | """ 189 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 190 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 191 | """ 192 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 193 | if n_rep == 1: 194 | return hidden_states 195 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 196 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 197 | 198 | class CodeShellAttention(nn.Module): 199 | def __init__(self, config, layer_idx=None): 200 | super().__init__() 201 | self.mask_value = None 202 | 203 | self.position_embedding_type = config.position_embedding_type 204 | self.rope_scaling = config.rope_scaling 205 | self.max_position_embeddings = config.max_position_embeddings 206 | 207 | self.group_query_attention = config.group_query_attention 208 | self.num_query_groups = config.num_query_groups 209 | self.num_key_value_groups = config.num_attention_heads // config.num_query_groups 210 | 211 | self.embed_dim = config.hidden_size 212 | self.num_heads = config.num_attention_heads 213 | self.head_dim = self.embed_dim // self.num_heads 214 | self.kv_heads = config.num_query_groups if self.group_query_attention else self.num_heads 215 | self.kv_dim = self.kv_heads * self.head_dim 216 | self.split_size = self.embed_dim 217 | if self.head_dim * self.num_heads != self.embed_dim: 218 | raise ValueError( 219 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 220 | f" {self.num_heads})." 221 | ) 222 | 223 | self.layer_idx = layer_idx 224 | 225 | self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) 226 | self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) 227 | 228 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 229 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 230 | 231 | if self.position_embedding_type == "rope": 232 | self._init_rope() 233 | 234 | def _init_rope(self): 235 | if self.rope_scaling is None: 236 | self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) 237 | else: 238 | scaling_type = self.rope_scaling["type"] 239 | scaling_factor = self.rope_scaling["factor"] 240 | if scaling_type == "linear": 241 | self.rotary_emb = CodeShellLinearScalingRotaryEmbedding( 242 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor 243 | ) 244 | elif scaling_type == "dynamic": 245 | self.rotary_emb = CodeShellDynamicNTKScalingRotaryEmbedding( 246 | self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor 247 | ) 248 | else: 249 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 250 | 251 | 252 | def _get_mask_value(self, device, dtype): 253 | # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
254 | if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: 255 | self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) 256 | return self.mask_value 257 | 258 | def forward( 259 | self, 260 | hidden_states: torch.Tensor, 261 | layer_past: Optional[torch.Tensor] = None, 262 | attention_mask: Optional[torch.Tensor] = None, 263 | position_ids: Optional[torch.LongTensor] = None, 264 | head_mask: Optional[torch.Tensor] = None, 265 | use_cache: Optional[bool] = False, 266 | output_attentions: Optional[bool] = False, 267 | ) -> Union[ 268 | Tuple[torch.Tensor, Optional[torch.Tensor]], 269 | Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], 270 | ]: 271 | bsz, q_len, _ = hidden_states.size() 272 | query_states, key_states, value_states = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2) 273 | 274 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 275 | key_states = key_states.view(bsz, q_len, self.num_query_groups, self.head_dim).transpose(1, 2) 276 | value_states = value_states.view(bsz, q_len, self.num_query_groups, self.head_dim).transpose(1, 2) 277 | 278 | kv_seq_len = key_states.shape[-2] 279 | if layer_past is not None: 280 | kv_seq_len += layer_past[0].shape[-2] 281 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 282 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 283 | 284 | if layer_past is not None: 285 | # reuse k, v, self_attention 286 | key_states = torch.cat([layer_past[0], key_states], dim=2) 287 | value_states = torch.cat([layer_past[1], value_states], dim=2) 288 | 289 | layer_past = (key_states, value_states) if use_cache else None 290 | 291 | # repeat k/v heads if n_kv_heads < n_heads 292 | key_states = repeat_kv(key_states, self.num_heads // self.kv_heads) 293 | value_states = repeat_kv(value_states, self.num_heads // self.kv_heads) 294 | 295 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 296 | 297 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 298 | raise ValueError( 299 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 300 | f" {attn_weights.size()}" 301 | ) 302 | 303 | if attention_mask is not None: 304 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 305 | raise ValueError( 306 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 307 | ) 308 | mask_value = self._get_mask_value(attn_weights.device, attn_weights.dtype) 309 | # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. 
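# Masked positions are filled with the dtype's minimum value so they receive (near-)zero weight after the softmax below.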
310 | attn_weights = torch.where(attention_mask, attn_weights, mask_value) 311 | 312 | # upcast attention to fp32 313 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 314 | attn_weights = self.attn_dropout(attn_weights) 315 | attn_output = torch.matmul(attn_weights, value_states) 316 | 317 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 318 | raise ValueError( 319 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 320 | f" {attn_output.size()}" 321 | ) 322 | 323 | attn_output = attn_output.transpose(1, 2).contiguous() 324 | attn_output = attn_output.reshape(bsz, q_len, self.embed_dim) 325 | 326 | attn_output = self.c_proj(attn_output) 327 | attn_output = self.resid_dropout(attn_output) 328 | 329 | outputs = (attn_output, layer_past) 330 | if output_attentions: 331 | outputs += (attn_weights,) 332 | 333 | return outputs # a, present, (attentions) 334 | 335 | 336 | class CodeShellMLP(nn.Module): 337 | def __init__(self, intermediate_size, config): 338 | super().__init__() 339 | embed_dim = config.hidden_size 340 | self.c_fc = nn.Linear(embed_dim, intermediate_size) 341 | self.c_proj = nn.Linear(intermediate_size, embed_dim) 342 | self.act = ACT2FN[config.activation_function] 343 | self.dropout = nn.Dropout(config.resid_pdrop) 344 | 345 | # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward 346 | def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor: 347 | hidden_states = self.c_fc(hidden_states) 348 | hidden_states = self.act(hidden_states) 349 | hidden_states = self.c_proj(hidden_states) 350 | hidden_states = self.dropout(hidden_states) 351 | return hidden_states 352 | 353 | 354 | class CodeShellBlock(nn.Module): 355 | def __init__(self, config, layer_idx=None): 356 | super().__init__() 357 | hidden_size = config.hidden_size 358 | self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 359 | 360 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 361 | self.attn = CodeShellAttention(config, layer_idx=layer_idx) 362 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 363 | 364 | self.mlp = CodeShellMLP(self.inner_dim, config) 365 | 366 | def forward( 367 | self, 368 | hidden_states: Optional[Tuple[torch.Tensor]], 369 | layer_past: Optional[torch.Tensor] = None, 370 | attention_mask: Optional[torch.Tensor] = None, 371 | position_ids: Optional[torch.LongTensor] = None, 372 | head_mask: Optional[torch.Tensor] = None, 373 | encoder_hidden_states: Optional[torch.Tensor] = None, 374 | encoder_attention_mask: Optional[torch.Tensor] = None, 375 | use_cache: Optional[bool] = False, 376 | output_attentions: Optional[bool] = False, 377 | ) -> Union[ 378 | Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 379 | ]: 380 | residual = hidden_states 381 | hidden_states = self.ln_1(hidden_states) 382 | attn_outputs = self.attn( 383 | hidden_states, 384 | layer_past=layer_past, 385 | attention_mask=attention_mask, 386 | position_ids=position_ids, 387 | head_mask=head_mask, 388 | use_cache=use_cache, 389 | output_attentions=output_attentions, 390 | ) 391 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 392 | 393 | outputs = attn_outputs[1:] 394 | # residual connection 395 | hidden_states = attn_output + residual 396 | 397 | residual = hidden_states 398 | hidden_states = self.ln_2(hidden_states) 399 | 
feed_forward_hidden_states = self.mlp(hidden_states) 400 | # residual connection 401 | hidden_states = residual + feed_forward_hidden_states 402 | 403 | if use_cache: 404 | outputs = (hidden_states,) + outputs 405 | else: 406 | outputs = (hidden_states,) + outputs[1:] 407 | 408 | return outputs # hidden_states, present, (attentions, cross_attentions) 409 | 410 | 411 | class CodeShellPreTrainedModel(PreTrainedModel): 412 | """ 413 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 414 | models. 415 | """ 416 | 417 | config_class = CodeShellConfig 418 | base_model_prefix = "transformer" 419 | supports_gradient_checkpointing = True 420 | _no_split_modules = ["ShellBlock"] 421 | _skip_keys_device_placement = "past_key_values" 422 | 423 | def __init__(self, *inputs, **kwargs): 424 | super().__init__(*inputs, **kwargs) 425 | 426 | def _init_weights(self, module): 427 | """Initialize the weights.""" 428 | if isinstance(module, (CodeShellMLP, CodeShellAttention)): 429 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 430 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 431 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 432 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 433 | # 434 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 435 | module.c_proj.weight.data.normal_( 436 | mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) 437 | ) 438 | module.c_proj._is_hf_initialized = True 439 | elif isinstance(module, nn.Linear): 440 | # Slightly different from the TF version which uses truncated_normal for initialization 441 | # cf https://github.com/pytorch/pytorch/pull/5617 442 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 443 | if module.bias is not None: 444 | module.bias.data.zero_() 445 | elif isinstance(module, nn.Embedding): 446 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 447 | if module.padding_idx is not None: 448 | module.weight.data[module.padding_idx].zero_() 449 | elif isinstance(module, nn.LayerNorm): 450 | module.bias.data.zero_() 451 | module.weight.data.fill_(1.0) 452 | 453 | # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->Shell 454 | def _set_gradient_checkpointing(self, module, value=False): 455 | if isinstance(module, CodeShellModel): 456 | module.gradient_checkpointing = value 457 | 458 | 459 | GPT_BIGCODE_START_DOCSTRING = r""" 460 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 461 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 462 | etc.) 463 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 464 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 465 | and behavior. 466 | Parameters: 467 | config ([`CodeShellConfig`]): Model configuration class with all the parameters of the model. 468 | Initializing with a config file does not load the weights associated with the model, only the 469 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
470 | """ 471 | 472 | GPT_BIGCODE_INPUTS_DOCSTRING = r""" 473 | Args: 474 | input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): 475 | `input_ids_length` = `sequence_length` if `past_key_values` is `None` else 476 | `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input 477 | sequence tokens in the vocabulary. 478 | If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as 479 | `input_ids`. 480 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 481 | [`PreTrainedTokenizer.__call__`] for details. 482 | [What are input IDs?](../glossary#input-ids) 483 | past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`): 484 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 485 | `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have 486 | their past given to this model should not be passed as `input_ids` as they have already been computed. 487 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 488 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 489 | - 1 for tokens that are **not masked**, 490 | - 0 for tokens that are **masked**. 491 | If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for 492 | `past_key_values`. In other words, the `attention_mask` always has to have the length: 493 | `len(past_key_values) + len(input_ids)` 494 | [What are attention masks?](../glossary#attention-mask) 495 | token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*): 496 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 497 | 1]`: 498 | - 0 corresponds to a *sentence A* token, 499 | - 1 corresponds to a *sentence B* token. 500 | [What are token type IDs?](../glossary#token-type-ids) 501 | position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 502 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 503 | config.max_position_embeddings - 1]`. 504 | [What are position IDs?](../glossary#position-ids) 505 | head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 506 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 507 | - 1 indicates the head is **not masked**, 508 | - 0 indicates the head is **masked**. 509 | inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 510 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 511 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 512 | model's internal embedding lookup matrix. 513 | If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see 514 | `past_key_values`). 515 | use_cache (`bool`, *optional*): 516 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 517 | `past_key_values`). 518 | output_attentions (`bool`, *optional*): 519 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 520 | tensors for more detail. 
521 | output_hidden_states (`bool`, *optional*): 522 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 523 | more detail. 524 | return_dict (`bool`, *optional*): 525 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 526 | """ 527 | 528 | 529 | @add_start_docstrings( 530 | "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.", 531 | GPT_BIGCODE_START_DOCSTRING, 532 | ) 533 | class CodeShellModel(CodeShellPreTrainedModel): 534 | def __init__(self, config): 535 | super().__init__(config) 536 | self.group_query_attention = config.group_query_attention 537 | self.num_query_groups = config.num_query_groups 538 | self.position_embedding_type = config.position_embedding_type 539 | self.embed_dim = config.hidden_size 540 | 541 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 542 | if self.position_embedding_type == "learned_absolute": 543 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 544 | else: 545 | pass 546 | 547 | self.drop = nn.Dropout(config.embd_pdrop) 548 | self.h = nn.ModuleList([CodeShellBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 549 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 550 | 551 | max_positions = config.max_position_embeddings 552 | self.register_buffer( 553 | "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False 554 | ) 555 | 556 | self.gradient_checkpointing = False 557 | 558 | # Initialize weights and apply final processing 559 | self.post_init() 560 | 561 | def get_input_embeddings(self): 562 | return self.wte 563 | 564 | def set_input_embeddings(self, new_embeddings): 565 | self.wte = new_embeddings 566 | 567 | @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) 568 | def forward( 569 | self, 570 | input_ids: Optional[torch.Tensor] = None, 571 | past_key_values: Optional[List[torch.Tensor]] = None, 572 | attention_mask: Optional[torch.Tensor] = None, 573 | token_type_ids: Optional[torch.Tensor] = None, 574 | position_ids: Optional[torch.Tensor] = None, 575 | head_mask: Optional[torch.Tensor] = None, 576 | inputs_embeds: Optional[torch.Tensor] = None, 577 | encoder_hidden_states: Optional[torch.Tensor] = None, 578 | encoder_attention_mask: Optional[torch.Tensor] = None, 579 | use_cache: Optional[bool] = None, 580 | output_attentions: Optional[bool] = None, 581 | output_hidden_states: Optional[bool] = None, 582 | return_dict: Optional[bool] = None, 583 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 584 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 585 | output_hidden_states = ( 586 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 587 | ) 588 | use_cache = use_cache if use_cache is not None else self.config.use_cache 589 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 590 | 591 | if input_ids is not None and inputs_embeds is not None: 592 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 593 | elif input_ids is not None: 594 | input_shape = input_ids.size() 595 | input_ids = input_ids.reshape(-1, input_shape[-1]) 596 | batch_size = input_ids.shape[0] 597 | elif inputs_embeds is not None: 598 | input_shape = inputs_embeds.size()[:-1] 599 | batch_size = inputs_embeds.shape[0] 600 | else: 
601 | raise ValueError("You have to specify either input_ids or inputs_embeds") 602 | 603 | if batch_size <= 0: 604 | raise ValueError("batch_size has to be defined and > 0") 605 | 606 | device = input_ids.device if input_ids is not None else inputs_embeds.device 607 | 608 | if token_type_ids is not None: 609 | token_type_ids = token_type_ids.reshape(-1, input_shape[-1]) 610 | if position_ids is not None: 611 | position_ids = position_ids.reshape(-1, input_shape[-1]) 612 | 613 | if past_key_values is None: 614 | past_length = 0 615 | past_key_values = tuple([None] * len(self.h)) 616 | else: 617 | past_length = past_key_values[0][0].size(-2) 618 | 619 | if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: 620 | # create position_ids on the fly for batch generation 621 | position_ids = attention_mask.long().cumsum(-1) - 1 622 | position_ids.masked_fill_(attention_mask == 0, 1) 623 | if past_length > 0: 624 | position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] 625 | elif position_ids is None: 626 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 627 | position_ids = position_ids.unsqueeze(0).reshape(-1, input_shape[-1]) 628 | 629 | # Self-attention mask. 630 | query_length = input_shape[-1] 631 | key_length = past_length + query_length 632 | self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] 633 | 634 | if attention_mask is not None: 635 | self_attention_mask = self_attention_mask * attention_mask.reshape(batch_size, 1, -1).to( 636 | dtype=torch.bool, device=self_attention_mask.device 637 | ) 638 | 639 | # MQA models: (batch_size, query_length, n_heads, key_length) 640 | # MHA models: (batch_size, n_heads, query_length, key_length) 641 | attention_mask = self_attention_mask.unsqueeze(1) 642 | 643 | encoder_attention_mask = None 644 | 645 | # Prepare head mask if needed 646 | # 1.0 in head_mask indicate we keep the head 647 | # attention_probs has shape bsz x n_heads x N x N 648 | # head_mask has shape n_layer x batch x n_heads x N x N 649 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 650 | 651 | if inputs_embeds is None: 652 | inputs_embeds = self.wte(input_ids) 653 | 654 | hidden_states = inputs_embeds 655 | if self.position_embedding_type == "learned_absolute": 656 | position_embeds = self.wpe(position_ids) 657 | hidden_states = hidden_states + position_embeds 658 | 659 | if token_type_ids is not None: 660 | token_type_embeds = self.wte(token_type_ids) 661 | hidden_states = hidden_states + token_type_embeds 662 | 663 | hidden_states = self.drop(hidden_states) 664 | 665 | output_shape = input_shape + (hidden_states.size(-1),) 666 | 667 | presents = [] if use_cache else None 668 | all_self_attentions = () if output_attentions else None 669 | all_hidden_states = () if output_hidden_states else None 670 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 671 | if output_hidden_states: 672 | all_hidden_states = all_hidden_states + (hidden_states,) 673 | 674 | if self.gradient_checkpointing and self.training: 675 | 676 | def create_custom_forward(module): 677 | def custom_forward(*inputs): 678 | # None for past_key_value 679 | return module(*inputs, use_cache, output_attentions) 680 | 681 | return custom_forward 682 | 683 | outputs = torch.utils.checkpoint.checkpoint( 684 | create_custom_forward(block), 685 | hidden_states, 686 | None, 687 | attention_mask, 688 | position_ids, 689 | head_mask[i], 
690 | encoder_hidden_states, 691 | encoder_attention_mask, 692 | ) 693 | else: 694 | outputs = block( 695 | hidden_states, 696 | layer_past=layer_past, 697 | attention_mask=attention_mask, 698 | position_ids=position_ids, 699 | head_mask=head_mask[i], 700 | encoder_hidden_states=encoder_hidden_states, 701 | encoder_attention_mask=encoder_attention_mask, 702 | use_cache=use_cache, 703 | output_attentions=output_attentions, 704 | ) 705 | 706 | hidden_states = outputs[0] 707 | if use_cache: 708 | presents.append(outputs[1]) 709 | 710 | if output_attentions: 711 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 712 | 713 | hidden_states = self.ln_f(hidden_states) 714 | hidden_states = hidden_states.reshape(output_shape) 715 | # Add last hidden state 716 | if output_hidden_states: 717 | all_hidden_states = all_hidden_states + (hidden_states,) 718 | 719 | 720 | if not return_dict: 721 | return tuple( 722 | v 723 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions] 724 | if v is not None 725 | ) 726 | 727 | return BaseModelOutputWithPastAndCrossAttentions( 728 | last_hidden_state=hidden_states, 729 | past_key_values=presents, 730 | hidden_states=all_hidden_states, 731 | attentions=all_self_attentions, 732 | ) 733 | 734 | class EndOfFunctionCriteria(StoppingCriteria): 735 | """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed.""" 736 | def __init__(self, input_lengths, eof_strings, tokenizer): 737 | self.input_lengths = input_lengths 738 | self.eof_strings = eof_strings 739 | self.tokenizer = tokenizer 740 | 741 | def __call__(self, input_ids, scores, **kwargs): 742 | """Returns true if all generated sequences contain any of the end-of-function strings.""" 743 | decoded_generations = [] 744 | for _input_ids, input_length in zip(input_ids, self.input_lengths): 745 | decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:])) 746 | done = [] 747 | for decoded_generation in decoded_generations: 748 | done.append( 749 | any( 750 | [ 751 | stop_string in decoded_generation 752 | for stop_string in self.eof_strings 753 | ] 754 | ) 755 | ) 756 | return all(done) 757 | 758 | class TextIterStreamer: 759 | def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False): 760 | self.tokenizer = tokenizer 761 | self.skip_prompt = skip_prompt 762 | self.skip_special_tokens = skip_special_tokens 763 | self.tokens = [] 764 | self.text_queue = Queue() 765 | self.next_tokens_are_prompt = True 766 | 767 | def put(self, value): 768 | if self.skip_prompt and self.next_tokens_are_prompt: 769 | self.next_tokens_are_prompt = False 770 | else: 771 | if len(value.shape) > 1: 772 | value = value[0] 773 | self.tokens.extend(value.tolist()) 774 | self.text_queue.put( 775 | self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)) 776 | 777 | def end(self): 778 | self.text_queue.put(None) 779 | 780 | def __iter__(self): 781 | return self 782 | 783 | def __next__(self): 784 | value = self.text_queue.get() 785 | if value is None: 786 | raise StopIteration() 787 | else: 788 | return value 789 | 790 | 791 | @add_start_docstrings( 792 | """ 793 | The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input 794 | embeddings). 
795 | """, 796 | GPT_BIGCODE_START_DOCSTRING, 797 | ) 798 | class CodeShellForCausalLM(CodeShellPreTrainedModel): 799 | _tied_weights_keys = ["lm_head.weight"] 800 | 801 | def __init__(self, config): 802 | super().__init__(config) 803 | self.transformer = CodeShellModel(config) 804 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 805 | 806 | # Initialize weights and apply final processing 807 | self.post_init() 808 | 809 | def quantize(self, bits: int): 810 | try: 811 | import bitsandbytes 812 | from .quantizer import quantize 813 | except ImportError: 814 | raise ImportError(f"Needs bitsandbytes to run quantize.") 815 | return quantize(self, bits) 816 | 817 | def get_output_embeddings(self): 818 | return self.lm_head 819 | 820 | def set_output_embeddings(self, new_embeddings): 821 | self.lm_head = new_embeddings 822 | 823 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 824 | token_type_ids = kwargs.get("token_type_ids", None) 825 | # only last token for inputs_ids if past is defined in kwargs 826 | if past_key_values: 827 | input_ids = input_ids[:, -1].unsqueeze(-1) 828 | if token_type_ids is not None: 829 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 830 | 831 | attention_mask = kwargs.get("attention_mask", None) 832 | position_ids = kwargs.get("position_ids", None) 833 | 834 | if attention_mask is not None and position_ids is None: 835 | # create position_ids on the fly for batch generation 836 | position_ids = attention_mask.long().cumsum(-1) - 1 837 | position_ids.masked_fill_(attention_mask == 0, 1) 838 | if past_key_values: 839 | position_ids = position_ids[:, -1].unsqueeze(-1) 840 | else: 841 | position_ids = None 842 | 843 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 844 | if inputs_embeds is not None and past_key_values is None: 845 | model_inputs = {"inputs_embeds": inputs_embeds} 846 | else: 847 | model_inputs = {"input_ids": input_ids} 848 | 849 | model_inputs.update( 850 | { 851 | "past_key_values": past_key_values, 852 | "use_cache": kwargs.get("use_cache"), 853 | "position_ids": position_ids, 854 | "attention_mask": attention_mask, 855 | "token_type_ids": token_type_ids, 856 | } 857 | ) 858 | return model_inputs 859 | 860 | @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) 861 | def forward( 862 | self, 863 | input_ids: Optional[torch.Tensor] = None, 864 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 865 | attention_mask: Optional[torch.Tensor] = None, 866 | token_type_ids: Optional[torch.Tensor] = None, 867 | position_ids: Optional[torch.Tensor] = None, 868 | head_mask: Optional[torch.Tensor] = None, 869 | inputs_embeds: Optional[torch.Tensor] = None, 870 | encoder_hidden_states: Optional[torch.Tensor] = None, 871 | encoder_attention_mask: Optional[torch.Tensor] = None, 872 | labels: Optional[torch.Tensor] = None, 873 | use_cache: Optional[bool] = None, 874 | output_attentions: Optional[bool] = None, 875 | output_hidden_states: Optional[bool] = None, 876 | return_dict: Optional[bool] = None, 877 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 878 | r""" 879 | labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 880 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set 881 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 882 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 883 | """ 884 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 885 | 886 | transformer_outputs = self.transformer( 887 | input_ids, 888 | past_key_values=past_key_values, 889 | attention_mask=attention_mask, 890 | token_type_ids=token_type_ids, 891 | position_ids=position_ids, 892 | head_mask=head_mask, 893 | inputs_embeds=inputs_embeds, 894 | encoder_hidden_states=encoder_hidden_states, 895 | encoder_attention_mask=encoder_attention_mask, 896 | use_cache=use_cache, 897 | output_attentions=output_attentions, 898 | output_hidden_states=output_hidden_states, 899 | return_dict=return_dict, 900 | ) 901 | hidden_states = transformer_outputs[0] 902 | lm_logits = self.lm_head(hidden_states) 903 | loss = None 904 | if labels is not None: 905 | # Shift so that tokens < n predict n 906 | shift_logits = lm_logits[..., :-1, :].contiguous() 907 | shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) 908 | # Flatten the tokens 909 | loss_fct = CrossEntropyLoss() 910 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 911 | 912 | if not return_dict: 913 | output = (lm_logits,) + transformer_outputs[1:] 914 | return ((loss,) + output) if loss is not None else output 915 | 916 | return CausalLMOutputWithCrossAttentions( 917 | loss=loss, 918 | logits=lm_logits, 919 | past_key_values=transformer_outputs.past_key_values, 920 | hidden_states=transformer_outputs.hidden_states, 921 | attentions=transformer_outputs.attentions, 922 | ) 923 | 924 | @staticmethod 925 | def _reorder_cache(past_key_values, beam_idx): 926 | reordered_past = () 927 | for layer_past in past_key_values: 928 | reordered_past += ( 929 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 930 | ) 931 | return reordered_past 932 | 933 | 934 | def build_chat_input(self, query, history, tokenizer, max_new_tokens=None): 935 | user_name = "## human:" 936 | ai_name = "## assistant: " 937 | stop = '||' 938 | 939 | prompt = '' 940 | for q, r in history: 941 | prompt += f"{user_name}{q}{stop}" 942 | prompt += f"{ai_name}{r}{stop}" 943 | prompt += f"{user_name}{query}{stop}" 944 | prompt += ai_name.rstrip() 945 | 946 | max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens 947 | max_new_tokens = max_new_tokens or 128 948 | max_input_tokens = self.config.n_positions - max_new_tokens 949 | 950 | input_tokens = tokenizer.encode(prompt) 951 | input_tokens = input_tokens[-max_input_tokens:] # truncate left 952 | return torch.LongTensor([input_tokens]).to(self.device) 953 | 954 | def chat(self, query, history, tokenizer, stream=False, 955 | generation_config: Optional[GenerationConfig]=None): 956 | generation_config = generation_config or self.generation_config 957 | input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens) 958 | stopping_criteria = StoppingCriteriaList( 959 | [EndOfFunctionCriteria([len(input_ids[0])], ['||', '|end|', '<|endoftext|>', '## human'], tokenizer)] 960 | ) 961 | 962 | if stream: 963 | streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 964 | Thread(target=self.generate, kwargs=dict( 965 | inputs=input_ids, streamer=streamer, 966 | stopping_criteria = stopping_criteria, 967 | 
generation_config=generation_config,
968 |             )).start()
969 |             return streamer
970 |         else:
971 |             outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
972 |             response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
973 |             return response
974 | 
975 |     def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
976 |         generation_config = generation_config or self.generation_config
977 |         max_input_tokens = self.config.n_positions - generation_config.max_new_tokens
978 | 
979 |         input_ids = tokenizer.encode(prompt)
980 |         input_ids = input_ids[-max_input_tokens:]  # truncate left
981 |         input_ids = torch.LongTensor([input_ids]).to(self.device)  # batch of size 1; generate() expects a tensor
982 |         stopping_criteria = StoppingCriteriaList(
983 |             [EndOfFunctionCriteria([len(input_ids[0])], ['||', '|end|', '<|endoftext|>', '## human'], tokenizer)]
984 |         )
985 | 
986 |         streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
987 |         Thread(target=self.generate, kwargs=dict(
988 |             inputs=input_ids, streamer=streamer, generation_config=generation_config, stopping_criteria=stopping_criteria, **kwargs
989 |         )).start()
990 |         return streamer
991 | 
992 | 
993 | class CodeShell4bitForCausalLM(CodeShellForCausalLM):
994 |     def __init__(self, config):
995 |         CodeShellPreTrainedModel.__init__(self, config)
996 |         self.transformer = CodeShellModel(config)
997 |         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
998 | 
999 |         try:
1000 |             import bitsandbytes
1001 |             from .quantizer import quantize_offline
1002 |             quantize_offline(self)
1003 |         except ImportError:
1004 |             raise ImportError(f"Needs bitsandbytes to run quantize.")
1005 | 
1006 |         self.post_init()
1007 | 
1008 |     @classmethod
1009 |     def from_pretrained(
1010 |         cls,
1011 |         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1012 |         *model_args,
1013 |         config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
1014 |         cache_dir: Optional[Union[str, os.PathLike]] = None,
1015 |         ignore_mismatched_sizes: bool = False,
1016 |         force_download: bool = False,
1017 |         local_files_only: bool = False,
1018 |         token: Optional[Union[str, bool]] = None,
1019 |         revision: str = "main",
1020 |         use_safetensors: bool = None,
1021 |         **kwargs,
1022 |     ):
1023 |         if not isinstance(config, PretrainedConfig):
1024 |             config_path = config if config is not None else pretrained_model_name_or_path
1025 |             config, _ = cls.config_class.from_pretrained(
1026 |                 config_path,
1027 |                 cache_dir=cache_dir,
1028 |                 return_unused_kwargs=True,
1029 |                 force_download=force_download,
1030 |                 resume_download=False,
1031 |                 proxies=None,
1032 |                 local_files_only=local_files_only,
1033 |                 token=token,
1034 |                 revision=revision,
1035 |                 subfolder="",
1036 |                 _from_auto=False,
1037 |                 _from_pipeline=None,
1038 |                 **kwargs,
1039 |             )
1040 | 
1041 |         # Build the quantized model skeleton, then load its saved state dict
1042 |         from .quantizer import load_state_dict_for_qunantied_model
1043 |         model = cls(config)
1044 |         state_dict = torch.load(os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin'), map_location="cpu")
1045 |         model = load_state_dict_for_qunantied_model(model, state_dict)
1046 |         model.eval()
1047 | 
1048 |         # If it is a model with generation capabilities, attempt to load the generation config
1049 |         if model.can_generate():
1050 |             try:
1051 |                 model.generation_config = GenerationConfig.from_pretrained(
1052 |                     pretrained_model_name_or_path,
1053 |                     cache_dir=cache_dir,
1054 |                     force_download=force_download,
1055 |                     resume_download=False,
1056 |                     proxies=None,
1057 |                     local_files_only=local_files_only,
1058 |                     token=token,
1059 | 
revision=revision, 1060 | subfolder="", 1061 | _from_auto=False, 1062 | _from_pipeline=None, 1063 | **kwargs, 1064 | ) 1065 | except (OSError, TypeError): 1066 | pass 1067 | 1068 | device_map = kwargs.pop("device_map", None) 1069 | if device_map is not None: 1070 | model = model.to(torch.device(device_map)) 1071 | 1072 | return model -------------------------------------------------------------------------------- /model/quantizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | try: 17 | import bitsandbytes as bnb 18 | from bitsandbytes.nn.modules import Params4bit, Int8Params 19 | except ImportError: 20 | pass 21 | import torch 22 | 23 | def Params4bitCuda(self, device): 24 | self.data = self.data.cuda(device) 25 | if self.quant_state is not None: 26 | self.quant_state[0] = self.quant_state[0].cuda(device) 27 | self.quant_state[6] = self.quant_state[6].cuda(device) 28 | return self 29 | 30 | def Params4bitTo(self, *args, **kwargs): 31 | device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) 32 | 33 | if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): 34 | return self.cuda(device) 35 | else: 36 | if self.quant_state is not None: 37 | # make sure the quantization state is on the right device 38 | self.quant_state[0] = self.quant_state[0].to(device) 39 | self.quant_state[6] = self.quant_state[6].to(device) 40 | new_param = Params4bit(self.to(device=device, dtype=dtype, non_blocking=non_blocking), 41 | requires_grad=self.requires_grad, quant_state=self.quant_state, 42 | blocksize=self.blocksize, compress_statistics=self.compress_statistics, 43 | quant_type=self.quant_type) 44 | 45 | return new_param 46 | 47 | class Linear4bitOnline(torch.nn.Module): 48 | def __init__(self, weight, bias, quant_type): 49 | super().__init__() 50 | self.weight = Params4bit( 51 | weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type 52 | ) 53 | self.compute_dtype = None 54 | #self.weight.cuda(weight.device) 55 | self.bias = bias 56 | 57 | def forward(self, x: torch.Tensor): 58 | # weights are cast automatically as Int8Params, but the bias has to be cast manually 59 | if self.bias is not None and self.bias.dtype != x.dtype: 60 | self.bias.data = self.bias.data.to(x.dtype) 61 | 62 | if getattr(self.weight, "quant_state", None) is None: 63 | print( 64 | "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first." 
65 | ) 66 | inp_dtype = x.dtype 67 | if self.compute_dtype is not None: 68 | x = x.to(self.compute_dtype) 69 | 70 | bias = None if self.bias is None else self.bias.to(self.compute_dtype) 71 | out = bnb.matmul_4bit( 72 | x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state 73 | ) 74 | 75 | out = out.to(inp_dtype) 76 | 77 | return out 78 | 79 | class Linear8bitLtOnline(torch.nn.Module): 80 | def __init__( 81 | self, 82 | weight, 83 | bias, 84 | has_fp16_weights=True, 85 | memory_efficient_backward=False, 86 | threshold=0.0, 87 | index=None, 88 | ): 89 | super().__init__() 90 | assert ( 91 | not memory_efficient_backward 92 | ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" 93 | self.state = bnb.MatmulLtState() 94 | self.index = index 95 | 96 | # Necessary for stacked layers 97 | self.state.threshold = threshold 98 | self.state.has_fp16_weights = has_fp16_weights 99 | self.state.memory_efficient_backward = memory_efficient_backward 100 | if threshold > 0.0 and not has_fp16_weights: 101 | self.state.use_pool = True 102 | 103 | self.weight = Int8Params( 104 | weight.data, 105 | has_fp16_weights=has_fp16_weights, 106 | requires_grad=has_fp16_weights, 107 | ) 108 | self.bias = bias 109 | 110 | def init_8bit_state(self): 111 | self.state.CB = self.weight.CB 112 | self.state.SCB = self.weight.SCB 113 | self.weight.CB = None 114 | self.weight.SCB = None 115 | 116 | def forward(self, x: torch.Tensor): 117 | self.state.is_training = self.training 118 | if self.weight.CB is not None: 119 | self.init_8bit_state() 120 | 121 | # weights are cast automatically as Int8Params, but the bias has to be cast manually 122 | if self.bias is not None and self.bias.dtype != x.dtype: 123 | self.bias.data = self.bias.data.to(x.dtype) 124 | 125 | out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) 126 | 127 | if not self.state.has_fp16_weights: 128 | if self.state.CB is not None and self.state.CxB is not None: 129 | # we converted 8-bit row major to turing/ampere format in the first inference pass 130 | # we no longer need the row-major weight 131 | del self.state.CB 132 | self.weight.data = self.state.CxB 133 | return out 134 | 135 | def quantize_online(model, bits: int): 136 | def quant(weight, bias=None): 137 | if bits == 8: 138 | linear = Linear8bitLtOnline( 139 | weight, 140 | bias, 141 | has_fp16_weights=False, 142 | threshold=6.0, 143 | ) 144 | if bias is not None: 145 | linear.bias = torch.nn.Parameter(bias) 146 | elif bits == 4: 147 | linear = Linear4bitOnline( 148 | weight, 149 | bias, 150 | quant_type="nf4", #fp4/nf4 151 | ) 152 | else: 153 | raise ValueError("quantize only support 4/8 bit") 154 | return linear 155 | 156 | def auto_quant(layer): 157 | if hasattr(layer,"bias"): 158 | linear = quant(layer.weight,bias=layer.bias) 159 | else: 160 | linear = quant(layer.weight) 161 | return linear 162 | 163 | for i,layer in enumerate(model.transformer.h): 164 | layer.mlp.c_fc = auto_quant(layer.mlp.c_fc) 165 | layer.mlp.c_proj = auto_quant(layer.mlp.c_proj) 166 | 167 | layer.attn.c_attn=auto_quant(layer.attn.c_attn) 168 | layer.attn.c_proj=auto_quant(layer.attn.c_proj) 169 | 170 | return model 171 | 172 | 173 | general_weight_dict = { 174 | "transformer.wte.weight": False, 175 | "transformer.ln_f.weight": False, 176 | "transformer.ln_f.bias": False, 177 | "lm_head.weight": False, 178 | } 179 | 180 | layer_weight_dict = { 181 | "transformer.h.{i}.ln_1.weight": False, 182 | "transformer.h.{i}.ln_1.bias": False, 
183 | "transformer.h.{i}.attn.c_attn.weight": True, 184 | "transformer.h.{i}.attn.c_attn.bias": False, 185 | "transformer.h.{i}.attn.c_proj.weight": True, 186 | "transformer.h.{i}.attn.c_proj.bias": False, 187 | "transformer.h.{i}.attn.rotary_emb.inv_freq": False, 188 | "transformer.h.{i}.ln_2.weight": False, 189 | "transformer.h.{i}.ln_2.bias": False, 190 | "transformer.h.{i}.mlp.c_fc.weight": True, 191 | "transformer.h.{i}.mlp.c_fc.bias": False, 192 | "transformer.h.{i}.mlp.c_proj.weight": True, 193 | "transformer.h.{i}.mlp.c_proj.bias": False, 194 | } 195 | num_dict = {str(i):i for i in range(100)} 196 | 197 | def set_value(model, name, state_dict, is_4bit): 198 | keys = name.split('.') 199 | parent = model 200 | for key in keys[:-1]: 201 | if key in num_dict: 202 | parent = parent[num_dict[key]] 203 | else: 204 | parent = getattr(parent, key) 205 | if is_4bit: 206 | weight_data = state_dict[f'{name}.data'] 207 | weight_quant_state = state_dict[f'{name}.quant_state'] 208 | assert weight_data is not None, name 209 | assert weight_quant_state is not None, name 210 | setattr(parent, keys[-1], Params4bit(weight_data, requires_grad=False, quant_state=weight_quant_state)) 211 | else: 212 | setattr(parent, keys[-1], state_dict[name]) 213 | 214 | def quantize_offline(model): 215 | for i, layer in enumerate(model.transformer.h): 216 | layer.mlp.c_fc = bnb.nn.Linear4bit( 217 | layer.mlp.c_fc.weight.shape[1], 218 | layer.mlp.c_fc.weight.shape[0], 219 | False, 220 | torch.bfloat16, 221 | compress_statistics=True, 222 | quant_type="nf4", 223 | ) 224 | layer.mlp.c_proj = bnb.nn.Linear4bit( 225 | layer.mlp.c_proj.weight.shape[1], 226 | layer.mlp.c_proj.weight.shape[0], 227 | False, 228 | torch.bfloat16, 229 | compress_statistics=True, 230 | quant_type="nf4", 231 | ) 232 | 233 | layer.attn.c_attn = bnb.nn.Linear4bit( 234 | layer.attn.c_attn.weight.shape[1], 235 | layer.attn.c_attn.weight.shape[0], 236 | False, 237 | torch.bfloat16, 238 | compress_statistics=True, 239 | quant_type="nf4", 240 | ) 241 | layer.attn.c_proj = bnb.nn.Linear4bit( 242 | layer.attn.c_proj.weight.shape[1], 243 | layer.attn.c_proj.weight.shape[0], 244 | False, 245 | torch.bfloat16, 246 | compress_statistics=True, 247 | quant_type="nf4", 248 | ) 249 | return model 250 | 251 | def load_state_dict_for_qunantied_model(model, state_dict): 252 | #replace Params4bit.cuda with Params4bitCuda 253 | Params4bit.cuda = Params4bitCuda 254 | Params4bit.to = Params4bitTo 255 | 256 | for name, is_4bit in general_weight_dict.items(): 257 | set_value(model, name, state_dict, is_4bit) 258 | 259 | for layer_i in range(len(model.transformer.h)): 260 | for name, is_4bit in layer_weight_dict.items(): 261 | name = name.replace('{i}', str(layer_i)) 262 | set_value(model, name, state_dict, is_4bit) 263 | return model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.34.0 2 | torch>=2.0.1 3 | -------------------------------------------------------------------------------- /tokenizer/eval_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 WisdomShell Inc. All Rights Reserved. 3 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from transformers import AutoTokenizer 17 | 18 | import jsonlines 19 | import os 20 | 21 | from fire import Fire 22 | 23 | 24 | def evaluate(tokenizer_path: str, corpora_path: str): 25 | """ 26 | Evaluate the compression ratio of a tokenizer on a corpora 27 | """ 28 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) 29 | corpora = [os.path.join(corpora_path, f) for f in os.listdir(corpora_path) 30 | if not os.path.isdir(os.path.join(corpora_path, f))] 31 | 32 | total_characters = 0 33 | total_tokens = 0 34 | 35 | for corpus in corpora: 36 | print(f"Processing {corpus}") 37 | texts = [] 38 | partial_characters = 0 39 | partial_tokens = 0 40 | with open(corpus, "r", encoding="utf-8") as f_in: 41 | for item in jsonlines.Reader(f_in): 42 | texts.append(item["text"]) 43 | partial_characters += len(item["text"]) 44 | 45 | # noinspection PyUnboundLocalVariable 46 | tokens = tokenizer(texts) 47 | 48 | for seg in tokens["input_ids"]: 49 | partial_tokens += len(seg) 50 | 51 | total_characters += partial_characters 52 | total_tokens += partial_tokens 53 | print(f"Characters: {partial_characters}") 54 | print(f"Tokens: {partial_tokens}") 55 | print(f"Compression ratio: {partial_characters / partial_tokens}") 56 | 57 | print(f"Total characters: {total_characters}") 58 | print(f"Total tokens: {total_tokens}") 59 | print(f"Compression ratio: {total_characters / total_tokens}") 60 | 61 | 62 | if __name__ == "__main__": 63 | Fire(evaluate) -------------------------------------------------------------------------------- /tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "additional_special_tokens": [ 4 | "<|endoftext|>", 5 | "", 6 | "", 7 | "", 8 | "", 9 | "", 10 | "", 11 | "", 12 | "", 13 | "", 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | "<|end|>" 24 | ], 25 | "bos_token": "<|endoftext|>", 26 | "clean_up_tokenization_spaces": true, 27 | "eos_token": "<|endoftext|>", 28 | "model_max_length": 8192, 29 | "tokenizer_class": "GPT2Tokenizer", 30 | "unk_token": "<|endoftext|>", 31 | "vocab_size": 70020, 32 | "pad_token": "<|endoftext|>" 33 | } 34 | --------------------------------------------------------------------------------
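A minimal sketch of streaming chat built on the `chat(..., stream=True)` path and `TextIterStreamer` defined in `model/modeling_codeshell.py`. The prompt is illustrative; note that the streamer yields the full decoded response so far on each iteration rather than an incremental delta, so the loop prints only the new suffix.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained('WisdomShell/CodeShell-7B-Chat')
model = AutoModelForCausalLM.from_pretrained(
    'WisdomShell/CodeShell-7B-Chat', trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)

# chat(..., stream=True) starts generation in a background thread and returns a
# TextIterStreamer; each yielded item is the full decoded response so far, not a delta.
streamer = model.chat('Write a quicksort function in Python', [], tokenizer, stream=True)
printed = 0
for text in streamer:
    print(text[printed:], end='', flush=True)
    printed = len(text)
print()
```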