├── LICENSE ├── LLM_Pretrain_Datasets.md ├── README.md ├── embedding_init.ipynb ├── generate_parameter.ipynb ├── memory_precision.ipynb ├── precision.ipynb ├── quality_hash.ipynb ├── sentencepiece.ipynb ├── tokenization.ipynb └── transformer_torch.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/LLM_Pretrain_Datasets.md:
--------------------------------------------------------------------------------
1 | 
2 | Open-source datasets usable for LLM pretraining
3 | 
4 | | Dataset | Language | Size | Notes | Link |
5 | | ---------------------------------- | -------- | ------------- | --------------------------------------------------------------------- | --------------------------------------------------------------------- |
6 | | WuDaoCorpora | Chinese | 200G | Cleaned by the Beijing Academy of Artificial Intelligence (BAAI) from 100TB of raw web data; covers 50+ industry tags such as education and technology; 5TB in total, of which 200G is open-sourced | [https://data.baai.ac.cn/details/WuDaoCorporaText](https://data.baai.ac.cn/details/WuDaoCorporaText "https://data.baai.ac.cn/details/WuDaoCorporaText") |
7 | | WanJuan1.0 | Chinese & English | ~1T, ~500G Chinese | Cleaned pretraining corpus from the Shanghai AI Laboratory, drawn from web pages, encyclopedias, books, patents, textbooks, exam questions, and other sources; over 500 million documents and more than 1TB of data | [https://opendatalab.org.cn/OpenDataLab/WanJuan1\_dot\_0](https://opendatalab.org.cn/OpenDataLab/WanJuan1_dot_0 "https://opendatalab.org.cn/OpenDataLab/WanJuan1_dot_0") |
8 | | MiChao (蜜巢·花粉1.0) | Chinese | ~240G | Built by Midu from publicly accessible Chinese web data covering news, government affairs, and more; after keyword filtering, image extraction, rule-based filtering, and format conversion, the cleaned data comes to 70+ million records plus 1+ million image links | [https://opendatalab.org.cn/OpenDataLab/MiChao](https://opendatalab.org.cn/OpenDataLab/MiChao "https://opendatalab.org.cn/OpenDataLab/MiChao") |
9 | | MNBVC | Chinese | 20T so far | A standout of open Chinese data: plain-text Chinese of every kind, including news, essays, novels, books, magazines, papers, scripts, forum posts, wikis, classical poetry, lyrics, product copy, jokes, anecdotes, and chat logs, all collected from the internet | [https://github.com/esbatmop/MNBVC](https://github.com/esbatmop/MNBVC "https://github.com/esbatmop/MNBVC") |
10 | | TigerBot | Chinese & English | ~50G Chinese, ~50G English | Following GPT-3's pretraining data distribution, TigerBot collected Chinese book, web, and encyclopedia data, filtered 20TB down to 2TB via source quality scoring and tf-idf soft deduping while preserving language and category ratios, then open-sourced a random 100G sample | [https://github.com/TigerResearch/TigerBot#开源数据集](https://github.com/TigerResearch/TigerBot#开源数据集 "https://github.com/TigerResearch/TigerBot#开源数据集") |
11 | | CLUECorpus2020 | Chinese | ~100G | 100GB of high-quality Chinese pretraining corpus obtained by cleaning the Chinese portion of Common Crawl | [https://github.com/CLUEbenchmark/CLUECorpus2020/](https://github.com/CLUEbenchmark/CLUECorpus2020/ "https://github.com/CLUEbenchmark/CLUECorpus2020/") |
12 | | FinCorpus | Chinese | ~60G | Chinese financial news dataset open-sourced by Du Xiaoman | [https://huggingface.co/datasets/Duxiaoman-DI/FinCorpus](https://huggingface.co/datasets/Duxiaoman-DI/FinCorpus "https://huggingface.co/datasets/Duxiaoman-DI/FinCorpus") |
13 | | Chinese\_book\_dataset | Chinese | 133,000 books | A broadly crawled Chinese book classification dataset, collected from major e-book sites | [https://github.com/JiangYanting/Chinese\_book\_dataset](https://github.com/JiangYanting/Chinese_book_dataset "https://github.com/JiangYanting/Chinese_book_dataset") |
14 | | CulturaX | Multilingual, mainly English | 27T total, ~1T Chinese | A multilingual dataset covering 167 languages for training large language models, with a fairly thorough cleaning stage | [https://huggingface.co/datasets/uonlp/CulturaX](https://huggingface.co/datasets/uonlp/CulturaX "https://huggingface.co/datasets/uonlp/CulturaX") |
15 | | Bloom | Multilingual, mainly English | 1.6T total, ~10G Chinese | BLOOM was trained on the ROOTS corpus, which is composed of 498 Hugging Face datasets: 1.61TB of text spanning 46 natural languages and 13 programming languages | [https://huggingface.co/bigscience-data](https://huggingface.co/bigscience-data "https://huggingface.co/bigscience-data") |
16 | | Common Crawl | Multilingual, mainly English | Updated monthly | Common Crawl publishes a monthly snapshot of raw web pages obtained from randomly searched and sampled URLs | [https://commoncrawl.org/](https://commoncrawl.org/ "https://commoncrawl.org/") |
17 | | Colossal Clean Crawled Corpus (C4) | Multilingual, mainly English | ~17T latest | Derived from cleaned Common Crawl data; originally used by Google to train T5; the latest release is version 3.1.0 from April 2023 | [https://www.tensorflow.org/datasets/catalog/c4](https://www.tensorflow.org/datasets/catalog/c4 "https://www.tensorflow.org/datasets/catalog/c4") |
18 | | The Pile | Mainly English | 825G | A pretraining dataset assembled from 22 high-quality datasets and further processed | [https://pile.eleuther.ai/](https://pile.eleuther.ai/ "https://pile.eleuther.ai/") |
19 | | RedPajama | Mainly English | ~5T | A reproduction of LLaMA's pretraining dataset | [https://github.com/togethercomputer/RedPajama-Data](https://github.com/togethercomputer/RedPajama-Data "https://github.com/togethercomputer/RedPajama-Data") |
20 | | Wikipedia | Mainly English | Continuously updated | Wikipedia dumps | [https://huggingface.co/datasets/wikipedia](https://huggingface.co/datasets/wikipedia "https://huggingface.co/datasets/wikipedia") |
21 | | WebText2 | Mainly English | ~65G | Documents scraped from URLs submitted to Reddit | [https://github.com/EleutherAI/openwebtext2](https://github.com/EleutherAI/openwebtext2 "https://github.com/EleutherAI/openwebtext2") |
22 | | BookCorpus | English | ~3G | English books | [https://github.com/soskek/bookcorpus](https://github.com/soskek/bookcorpus "https://github.com/soskek/bookcorpus") |
23 | | ArXiv | English | ~1.7 million papers | English academic papers | [https://huggingface.co/datasets/arxiv\_dataset](https://huggingface.co/datasets/arxiv_dataset "https://huggingface.co/datasets/arxiv_dataset") |
24 | | Dataset platforms | | | | |
25 | | CLUEDatasetSearch | Chinese | Multi-task NLP collections | A collection of Chinese and English NLP task datasets | [https://github.com/CLUEbenchmark/CLUEDatasetSearch](https://github.com/CLUEbenchmark/CLUEDatasetSearch "https://github.com/CLUEbenchmark/CLUEDatasetSearch") |
26 | | OpenDataLab | Chinese & English | Multi-task NLP collections | Dataset platform | [https://opendatalab.org.cn/home](https://opendatalab.org.cn/home "https://opendatalab.org.cn/home") |
27 | | Huggingface datasets | Multilingual | Many datasets | Dataset platform | [https://huggingface.co/datasets](https://huggingface.co/datasets "https://huggingface.co/datasets") |
28 | | Qianyan (千言) | Chinese | Multi-task NLP collections | Dataset platform | [https://www.luge.ai/#/](https://www.luge.ai/#/ "https://www.luge.ai/#/") |
29 | | Tianchi (天池) | Chinese | Multi-task NLP collections | Dataset platform | [https://tianchi.aliyun.com/dataset/](https://tianchi.aliyun.com/dataset/ "https://tianchi.aliyun.com/dataset/") |
30 | | kaggle | Mainly English | Multi-task NLP collections | Dataset platform | [https://www.kaggle.com/](https://www.kaggle.com/ "https://www.kaggle.com/") |
31 | | hyper | Chinese | Various datasets | Dataset platform | [https://hyper.ai/datasets](https://hyper.ai/datasets "https://hyper.ai/datasets") |
32 | 
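Most of the Hub-hosted entries above can be sanity-checked by streaming a few documents before committing to a full download. A minimal sketch, assuming the Hugging Face `datasets` library is installed (`pip install datasets`); the dataset ID and config below are taken from the Wikipedia row and may change over time:

```python
from datasets import load_dataset

# Stream instead of downloading the whole dump; "20220301.en" is one of the
# preprocessed Wikipedia configs hosted on the Hub.
ds = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)

# Peek at the first three articles.
for i, example in enumerate(ds):
    print(example["title"])
    print(example["text"][:200])  # first 200 characters only
    if i >= 2:
        break
```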
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM-Travel
2 | ![Author](https://img.shields.io/badge/Author-Glan-red.svg) [![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE) ![python_version](https://img.shields.io/badge/Python-3.x%2B-green.svg)
3 | 
4 | 
5 | ## Introduction
6 | 
7 | Welcome to the "LLM-Travel" repository! Explore the mysteries of large language models (LLMs) 🚀. This project is dedicated to understanding, discussing, and implementing the techniques, principles, and applications of large models in depth.
8 | Articles are published on Zhihu: https://www.zhihu.com/people/allenvery/posts
9 | 
10 | ### What you'll find here 🌟
11 | 
12 | - **Technical explanations**: Clear, in-depth articles that unpack the techniques behind large language models and the math, algorithms, and architectures beneath them, to help you understand how they work.
13 | 
14 | - **Practical code**: Every hands-on article ships with matching practice code to make the ideas easier to understand and reproduce.
15 | 
16 | - **Questions and discussion**: Feel free to raise questions, share ideas, and suggest topics you would like to see; let's explore large language models together!
17 | 
18 | ### Join the journey 🌏
19 | 
20 | Hop aboard the "LLM-Travel" train and explore the wonderful world of large language models together!
21 | 
22 | ## Updates
23 | 
24 | Date| Title (Zhihu link)| Code| Note
25 | :---|---|---|---
26 | 2024-06-23|[LLMs: Hallucination](https://zhuanlan.zhihu.com/p/703034375)|[None]()|On hallucination in large language models
27 | 2024-06-03|[LLMs: A summary of distributed training](https://zhuanlan.zhihu.com/p/699938704)|[None]()|A summary of distributed training for LLMs
28 | 2024-05-10|[LLMs: Training optimization methods](https://zhuanlan.zhihu.com/p/698787661)|[None]()|Training optimization methods for LLMs
29 | 2024-04-09|[Transformer in practice](https://www.zhihu.com/question/445556653/answer/3460351120)|[Transformer_torch](https://github.com/Glanvery/LLM-Travel/blob/main/transformer_torch.ipynb)|Hands-on Transformer implementation
30 | 2023-12-16|[LLMs: DeepSpeed in practice](https://www.zhihu.com/question/371094177/answer/3330130413)|[None]()|Hands-on DeepSpeed
31 | 2023-11-11|[LLMs: Data quality](https://zhuanlan.zhihu.com/p/670365989)|[quality_hash.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/quality_hash.ipynb)|Large-scale text-quality practice, part one
32 | 2023-11-04|[LLMs: Trainer](https://zhuanlan.zhihu.com/p/662619853)|[None]()|The Trainer class and its training arguments
33 | 2023-10-14|[LLMs: Data processing II](https://zhuanlan.zhihu.com/p/661421095)|[None]()|Large-scale data processing tools: installing a Hadoop/Spark cluster
34 | 2023-10-10|[LLMs: Open-source data roundup](https://www.zhihu.com/question/609604943/answer/3248054165)|[LLM_Pretrain_Datasets](https://github.com/Glanvery/LLM-Travel/blob/main/LLM_Pretrain_Datasets.md)|Open-source datasets usable for LLM pretraining
35 | 2023-10-10|[LLMs: Data processing I](https://zhuanlan.zhihu.com/p/660806587)|[None]()|Large-scale data processing tools: an introduction to Hadoop/Spark clusters
36 | 2023-09-30|[LLMs: GPU memory usage](https://zhuanlan.zhihu.com/p/658343628)|[memory_precision.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/memory_precision.ipynb)|Memory footprint at different precisions, and converting between them
37 | 2023-09-29|[LLMs: Numeric precision explained](https://zhuanlan.zhihu.com/p/657886517)|[precision.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/precision.ipynb)|FP16, FP32, and BF16 explained with hands-on examples
38 | 2023-09-24|[LLMs: Embedding initialization](https://zhuanlan.zhihu.com/p/656335338)|[embedding_init.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/embedding_init.ipynb)|Initializing the embedding and lm_head layers after expanding the vocabulary
39 | 2023-09-23|[LLMs: Vocabulary expansion](https://zhuanlan.zhihu.com/p/655281268)|[sentencepiece.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/sentencepiece.ipynb)|Expanding LLaMA's Chinese vocabulary with SentencePiece
40 | 2023-09-16|[LLMs: generate() parameters explained](https://zhuanlan.zhihu.com/p/653926703)|[generate_parameter.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/generate_parameter.ipynb)|Generation/inference parameters and decoding strategies: theory and code
41 | 2023-09-09|[LLMs: Tokenization methods](https://zhuanlan.zhihu.com/p/652520262)|[tokenization.ipynb](https://github.com/Glanvery/LLM-Travel/blob/main/tokenization.ipynb)|WordPiece, Byte-Pair Encoding (BPE), and byte-level BPE (BBPE): theory and code
42 | 
--------------------------------------------------------------------------------
/embedding_init.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "1789c00b-6928-4167-b677-93a4a0dbd08c",
6 |    "metadata": {},
7 |    "source": [
8 |     "### 加载模型"
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "code",
13 |    "execution_count": null,
14 |    "id": "15417e78-d2f4-4d0b-b2c0-07890c080485",
15 |    "metadata": {},
16 |    "outputs": [],
17 |    "source": [
18 |     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
19 |     "import torch\n",
20 |     "model_name = \"/path/llama-2-7b-hf\" # 你模型的位置\n",
21 |     "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=torch.float16)\n",
22 |     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
23 |     "# 新的分词器\n",
24 |     "new_tokenizer = 
AutoTokenizer.from_pretrained(\"/path/to/merged_tokenizer_hf\") # 你保存分词器的位置\n", 25 | "model" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "0c76fb61-efc3-49d5-a7cb-32f9e610f0b7", 31 | "metadata": {}, 32 | "source": [ 33 | "### 随机扩充" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "b52abe0c-0e22-4b38-9496-8c4860047d59", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# 获取原先的embedding\n", 44 | "embeddings = model.get_input_embeddings()\n", 45 | "print(embeddings)\n", 46 | "print(embeddings(torch.LongTensor([31999])))\n", 47 | "\n", 48 | "# 扩充\n", 49 | "model.resize_token_embeddings(40114)\n", 50 | "new_embeddings = model.get_input_embeddings()\n", 51 | "print(new_embeddings)\n", 52 | "print(new_embeddings(torch.LongTensor([31999])))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "8013cb8b-3634-441f-840b-65bcd91fa5c1", 58 | "metadata": {}, 59 | "source": [ 60 | "### 均值扩充" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "a1930339-2cc5-4ef6-80f7-0b0d5912ae8c", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "# 新增的token和在原来token相对应的字典\n", 71 | "token_mapping = {}\n", 72 | "for i in range(32000, len(new_tokenizer)):\n", 73 | " # 使用 tokenizer 的 convert_ids_to_tokens 方法将索引转换为对应的 token\n", 74 | " token = new_tokenizer.convert_ids_to_tokens(i)\n", 75 | " # 原来的token\n", 76 | " input_ids = tokenizer(token, return_tensors=\"pt\").input_ids[0]\n", 77 | " # 判断是否为_\n", 78 | " if input_ids[1] == 29871:\n", 79 | " new_input_ids = input_ids[2:]\n", 80 | " else:\n", 81 | " new_input_ids = input_ids[1:] \n", 82 | " token_mapping[i] = new_input_ids\n", 83 | "\n", 84 | "# 原始输入embedding\n", 85 | "embeddings = model.get_input_embeddings()\n", 86 | "# 新完全初始化的embedding\n", 87 | "new_vocab_size = len(new_tokenizer)\n", 88 | "embedding_dim = 4096\n", 89 | "new_embedding = torch.nn.Embedding(new_vocab_size, embedding_dim)\n", 90 | "\n", 91 | "# 将现有Embedding层的权重赋值给新的Embedding层的前32000行\n", 92 | "num_to_copy = min(new_vocab_size, len(embeddings.weight))\n", 93 | "new_embedding.weight.data[:num_to_copy, :] = embeddings.weight.data[:num_to_copy, :]\n", 94 | "\n", 95 | "# 开始新增\n", 96 | "for new_token, original_tokens in token_mapping.items():\n", 97 | " original_embeddings = embeddings(original_tokens)\n", 98 | " mean_embedding = torch.mean(original_embeddings, dim=0)\n", 99 | " new_embedding.weight.data[new_token] = mean_embedding\n", 100 | "\n", 101 | "# 更换嵌入层\n", 102 | "model.set_input_embeddings(new_embedding)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "562725e4-abaa-4d89-a7e0-9473e9fa208c", 108 | "metadata": {}, 109 | "source": [ 110 | "#### 扩充lm_head" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "d65e0e2e-91ff-4a73-bf8e-8775266f015c", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "output_size = 32000\n", 121 | "new_output_size = 40114\n", 122 | "lm_head = model.lm_head\n", 123 | "# 新的lm_head\n", 124 | "new_lm_head = torch.nn.Linear(in_features=4096, out_features=new_output_size, bias=False)\n", 125 | "# 前32000个向量不变\n", 126 | "new_lm_head.weight.data[:output_size, :] = lm_head.weight.data[:output_size, :]\n", 127 | "\n", 128 | "# 新增\n", 129 | "for new_token, original_tokens in token_mapping.items():\n", 130 | " original = 0\n", 131 | " for i in original_tokens:\n", 132 | " original += lm_head.weight.data[i]\n", 133 | " mean_para = original / len(original_tokens)\n", 134 | " 
new_lm_head.weight.data[new_token] = mean_para\n",
135 |     "\n",
136 |     "# 替换模型原来的lm_head\n",
137 |     "model.lm_head = new_lm_head\n",
138 |     "\n",
139 |     "# 最后完成了embedding和lm_head替换后,保存模型\n",
140 |     "model.save_pretrained(\"llama-2-7b-extent\", max_shard_size=\"8GB\")"
141 |    ]
142 |   }
143 |  ],
144 |  "metadata": {
145 |   "kernelspec": {
146 |    "display_name": "Python 3 (ipykernel)",
147 |    "language": "python",
148 |    "name": "python3"
149 |   },
150 |   "language_info": {
151 |    "codemirror_mode": {
152 |     "name": "ipython",
153 |     "version": 3
154 |    },
155 |    "file_extension": ".py",
156 |    "mimetype": "text/x-python",
157 |    "name": "python",
158 |    "nbconvert_exporter": "python",
159 |    "pygments_lexer": "ipython3",
160 |    "version": "3.8.13"
161 |   }
162 |  },
163 |  "nbformat": 4,
164 |  "nbformat_minor": 5
165 | }
166 | 
--------------------------------------------------------------------------------
/generate_parameter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "35813d22-9f27-4d87-8623-53397599409a",
6 |    "metadata": {},
7 |    "source": [
8 |     "### 原始生成"
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "code",
13 |    "execution_count": null,
14 |    "id": "8fc31193-4f3d-49f9-adea-3c74f52de9c0",
15 |    "metadata": {},
16 |    "outputs": [],
17 |    "source": [
18 |     "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
19 |     "import torch\n",
20 |     "\n",
21 |     "model_name = \"llama-2-7b-hf\" # 用你下载的模型的文件夹位置\n",
22 |     "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n",
23 |     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
24 |     "\n",
25 |     "text = \"say\"\n",
26 |     "inputs = tokenizer(text, return_tensors=\"pt\")\n",
27 |     "print(f\"inputs:{inputs}\")\n",
28 |     "\n",
29 |     "# 从inputs中取出input_ids再前向计算\n", "input_ids = inputs[\"input_ids\"]\n", "logits = model.forward(input_ids)\n",
30 |     "print(\"Logits Shape:\", logits.logits.shape)\n",
31 |     "print(f\"logits:{logits.logits}\")\n",
32 |     "\n",
33 |     "next_token = torch.argmax(logits.logits, dim=-1).reshape(-1)[1]\n",
34 |     "print(f\"next_token:{next_token}\")\n",
35 |     "\n",
36 |     "next_word = tokenizer.decode(next_token)\n",
37 |     "print(f\"next_word:{next_word}\")"
38 |    ]
39 |   },
40 |   {
41 |    "cell_type": "markdown",
42 |    "id": "d3f4e0e0-3057-49bc-aedd-4704b021408e",
43 |    "metadata": {},
44 |    "source": [
45 |     "### temperature"
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": 4,
51 |    "id": "1a96f09f-0799-46af-9627-ef03b9844039",
52 |    "metadata": {},
53 |    "outputs": [
54 |     {
55 |      "name": "stdout",
56 |      "output_type": "stream",
57 |      "text": [
58 |       "probs:tensor([[0.2559, 0.5154, 0.0571, 0.1716]])\n",
59 |       "probs_low:tensor([[0.1800, 0.7301, 0.0090, 0.0809]])\n",
60 |       "probs_high:tensor([[0.2695, 0.3825, 0.1273, 0.2207]])\n"
61 |      ]
62 |     }
63 |    ],
64 |    "source": [
65 |     "import torch\n",
66 |     "logits = torch.tensor([[0.5, 1.2, -1.0, 0.1]])\n",
67 |     "# 无temperature\n",
68 |     "probs = torch.softmax(logits, dim=-1)\n",
69 |     "# temperature low 0.5\n",
70 |     "probs_low = torch.softmax(logits / 0.5, dim=-1)\n",
71 |     "# temperature high 2\n",
72 |     "probs_high = torch.softmax(logits / 2, dim=-1)\n",
73 |     "\n",
74 |     "print(f\"probs:{probs}\")\n",
75 |     "print(f\"probs_low:{probs_low}\")\n",
76 |     "print(f\"probs_high:{probs_high}\")"
77 |    ]
78 |   },
79 |   {
80 |    "cell_type": "markdown",
81 |    "id": "2a2d69b9-0209-405d-94a6-800fa4622459",
82 |    "metadata": {},
83 |    "source": [
84 |     "### top_p"
85 |    ]
86 |   },
87 |   {
88 |    "cell_type": "code",
89 |    "execution_count": null,
90 |    "id": 
"c5e77008-380e-4a23-a609-c2a58b639873", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "import torch\n", 95 | "# 样例:probs: tensor([[0.2559, 0.5154, 0.0571, 0.1716]])\n", 96 | "probs = torch.tensor([[0.2559, 0.5154, 0.0571, 0.1716]])\n", 97 | "# 第一步进行排序\n", 98 | "probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n", 99 | "# 结果\n", 100 | "probs_sort: tensor([[0.5154, 0.2559, 0.1716, 0.0571]])\n", 101 | "probs_idx: tensor([[1, 0, 3, 2]])\n", 102 | "\n", 103 | "# 第二步概率的累积和\n", 104 | "probs_sum = torch.cumsum(probs_sort, dim=-1)\n", 105 | "# 结果\n", 106 | "probs_sum: tensor([[0.5154, 0.7713, 0.9429, 1.0000]])\n", 107 | "\n", 108 | "# 第三步找到第一个大于阈值p的位置,假设p=0.9,并将后面的概率值置为0:\n", 109 | "mask = probs_sum - probs_sort > p\n", 110 | "probs_sort[mask] = 0.0\n", 111 | "# 结果\n", 112 | "probs_sort: tensor([[0.5154, 0.2559, 0.1716, 0.0000]])\n", 113 | "\n", 114 | "# 第四步复原原序列\n", 115 | "new_probs = probs_sort.scatter(1, probs_idx, probs_sort)\n", 116 | "# 结果\n", 117 | "new_probs: tensor([[0.2559, 0.5154, 0.0000, 0.1716]])\n", 118 | "\n", 119 | "# 注:在真实实现中一般会把舍弃的概率置为-inf,即\n", 120 | "zero_indices = (new_probs == 0)\n", 121 | "new_probs[zero_indices] = float('-inf')\n", 122 | "# 结果\n", 123 | "new_probs: tensor([[0.2559, 0.5154, -inf, 0.1716]])\n", 124 | "# 完整代码\n", 125 | "def sample_top_p(probs, p):\n", 126 | " probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n", 127 | " probs_sum = torch.cumsum(probs_sort, dim=-1)\n", 128 | " mask = probs_sum - probs_sort > p\n", 129 | " probs_sort[mask] = 0.0\n", 130 | " new_probs = probs_sort.scatter(1, probs_idx, probs_sort)\n", 131 | " zero_indices = (new_probs == 0)\n", 132 | " new_probs[zero_indices] = float('-inf')\n", 133 | " return new_probs" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "id": "9bb3d974-75e3-43df-915a-eb22771f86b3", 139 | "metadata": {}, 140 | "source": [ 141 | "### top_k" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 5, 147 | "id": "c575d66f-059d-49ea-81e8-5ba948521254", 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "new_probs: tensor([[0.2559, 0.5154, -inf, -inf]])\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "import torch\n", 160 | "filter_value = -float(\"Inf\")\n", 161 | "top_k = 2\n", 162 | "probs = torch.tensor([[0.2559, 0.5154, 0.0571, 0.1716]])\n", 163 | "indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]\n", 164 | "new_probs = probs.masked_fill(indices_to_remove, filter_value)\n", 165 | "print(\"new_probs:\", new_probs)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "88402e6a-2909-4628-9d4c-4111053f38b9", 171 | "metadata": {}, 172 | "source": [ 173 | "### repetition_penalty" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 6, 179 | "id": "d0b4101f-a916-4068-b1cf-f2c328429b07", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "原始概率分布: [0.3 0.1 0.3 0.1 0.2]\n", 187 | "调整后的概率分布: [0.33333333 0.11111111 0.26666667 0.11111111 0.17777778]\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "import numpy as np\n", 193 | "def apply_repetition_penalty(probs, repetition_penalty, prev_tokens):\n", 194 | " adjusted_probs = np.copy(probs)\n", 195 | " for token in set(prev_tokens):\n", 196 | " adjusted_probs[token] *= repetition_penalty\n", 197 | " adjusted_probs /= np.sum(adjusted_probs) \n", 198 | " return adjusted_probs\n", 199 
| "# 示例概率分布,索引对应不同词语\n", 200 | "original_probs = np.array([0.3, 0.1, 0.3, 0.1, 0.2])\n", 201 | "# 示例先前生成的词语\n", 202 | "previous_tokens = [2, 4, 2]\n", 203 | "# 重复惩罚系数\n", 204 | "repetition_penalty = 0.8\n", 205 | "# 应用重复惩罚,得到调整后的概率分布\n", 206 | "adjusted_probs = apply_repetition_penalty(original_probs, repetition_penalty, previous_tokens)\n", 207 | "\n", 208 | "print(\"原始概率分布:\", original_probs)\n", 209 | "print(\"调整后的概率分布:\", adjusted_probs)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "152f8cdb-97f9-4a66-9206-c82f9306045b", 215 | "metadata": {}, 216 | "source": [ 217 | "### do_sample" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 9, 223 | "id": "d1b6b1a8-ccf8-49d6-90b2-6a8b79ce0ad5", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "next_token: tensor([[1]])\n" 231 | ] 232 | } 233 | ], 234 | "source": [ 235 | "import torch\n", 236 | "probs = torch.tensor([[0.2559, 0.5154, 0.0571, 0.1716]])\n", 237 | "next_token = torch.multinomial(probs, num_samples=1)\n", 238 | "print(\"next_token:\", next_token)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "id": "714e70d5-bbf2-46c8-97c4-b6a0428e065a", 244 | "metadata": {}, 245 | "source": [ 246 | "### num_beams" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 12, 252 | "id": "cc0c0b32-d922-4ee5-98a8-dedcf0ab49fc", 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "Sentence 1: I like\n", 260 | "Sentence 2: I apple\n", 261 | "Sentence 3: I peach\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "class BeamSearchNode:\n", 267 | " def __init__(self, sequence, score):\n", 268 | " self.sequence = sequence # 生成的序列\n", 269 | " self.score = score # 分数(概率)\n", 270 | " \n", 271 | "# 示例:下一个token的概率函数,简单使用固定概率\n", 272 | "def simple_next_word_probs(sequence):\n", 273 | " if sequence[-1] == \"\":\n", 274 | " return {}\n", 275 | " return {\"apple\": 0.3, \"like\": 0.35, \"peach\": 0.2, \"banana\": 0.15}\n", 276 | "\n", 277 | "\n", 278 | "def beam_search(initial_sequence, next_word_probs_func, num_beams, max_sequence_length):\n", 279 | " # 初始化初始节点,且分数为1\n", 280 | " initial_node = BeamSearchNode(sequence=initial_sequence, score=1.0)\n", 281 | " candidates = [initial_node]\n", 282 | "\n", 283 | " final_candidates = [] # 最终的候选序列\n", 284 | " # 只要候选节点列表不为空,且 final_candidates 中的候选节点数量还没有达到指定的束宽度,就继续进行搜索\n", 285 | " while candidates and len(final_candidates) < num_beams:\n", 286 | " # 候选节点排序\n", 287 | " candidates.sort(key=lambda x: -x.score)\n", 288 | " current_node = candidates.pop(0)\n", 289 | " # 当节点序列末尾生成结束符号(如\"\"),或者当生成的序列长度达到最大限制时终止节点的扩展\n", 290 | " if current_node.sequence[-1] == \"\" or len(current_node.sequence) >= max_sequence_length:\n", 291 | " final_candidates.append(current_node)\n", 292 | " else:\n", 293 | " # 获取下一个token的概率,我们的例子返回的是固定的概率\n", 294 | " next_words_probs = next_word_probs_func(current_node.sequence) \n", 295 | " # 生成新的候选序列,并计算分数 \n", 296 | " for next_word, next_word_prob in next_words_probs.items():\n", 297 | " new_sequence = current_node.sequence + [next_word]\n", 298 | " new_score = current_node.score * next_word_prob\n", 299 | " new_node = BeamSearchNode(sequence=new_sequence, score=new_score)\n", 300 | " candidates.append(new_node)\n", 301 | "\n", 302 | " return [candidate.sequence for candidate in final_candidates]\n", 303 | "\n", 304 | "# 开始使用:\n", 305 | "\n", 306 | "initial_sequence = 
[\"\", \"I\"]\n", 307 | "num_beams = 3\n", 308 | "max_sequence_length = 3\n", 309 | "result = beam_search(initial_sequence, simple_next_word_probs, num_beams, max_sequence_length)\n", 310 | "\n", 311 | "for idx, sequence in enumerate(result):\n", 312 | " print(f\"Sentence {idx + 1}: {' '.join(sequence)}\")" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "id": "6fa70db8-d21e-4e7c-b0d0-7e4f4ad17e08", 318 | "metadata": {}, 319 | "source": [ 320 | "### constrained beam-search decoding" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "id": "746241eb-5f34-47bd-af4a-ff1a29811db0", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n", 331 | "import torch\n", 332 | "model_name = \"llama-2-7b-hf\" # 你模型的位置\n", 333 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n", 334 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 335 | "\n", 336 | "text = \"say hello to\"\n", 337 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 338 | "print(f\"inputs:{inputs}\")\n", 339 | "input_ids = inputs[\"input_ids\"].to(\"cuda\")\n", 340 | "\n", 341 | "# generate实现\n", 342 | "generation_output = model.generate(\n", 343 | " input_ids=input_ids,\n", 344 | " num_beams = 3,\n", 345 | " num_return_sequences=3,\n", 346 | " return_dict_in_generate=True,\n", 347 | " max_new_tokens=3,\n", 348 | ")\n", 349 | "\n", 350 | "print(\"query:\", text)\n", 351 | "for i, output_sequence in enumerate(generation_output.sequences):\n", 352 | " output_text = tokenizer.decode(output_sequence, skip_special_tokens=True)\n", 353 | " print(f\"Generated sequence {i+1}: {output_text}\")" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "id": "deb0a45c-2255-4254-b8e5-a0748b5f71ad", 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n", 364 | "import torch\n", 365 | "model_name = \"llama-2-7b-hf\" # 你模型的位置\n", 366 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n", 367 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 368 | "\n", 369 | "text = \"say hello to\"\n", 370 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 371 | "print(f\"inputs:{inputs}\")\n", 372 | "input_ids = inputs[\"input_ids\"].to(\"cuda\")\n", 373 | "\n", 374 | "force_words = [\"my\"]\n", 375 | "force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids\n", 376 | "\n", 377 | "generation_output = model.generate(\n", 378 | " input_ids=input_ids,\n", 379 | " force_words_ids = force_words_ids,\n", 380 | " num_beams = 3,\n", 381 | " num_return_sequences=3,\n", 382 | " return_dict_in_generate=True,\n", 383 | " max_new_tokens=3,\n", 384 | ")\n", 385 | "\n", 386 | "print(\"query:\", text)\n", 387 | "for i, output_sequence in enumerate(generation_output.sequences):\n", 388 | " output_text = tokenizer.decode(output_sequence, skip_special_tokens=True)\n", 389 | " print(f\"Generated sequence {i+1}: {output_text}\")" 390 | ] 391 | }, 392 | { 393 | "cell_type": "markdown", 394 | "id": "f0589bee-997d-449b-9db3-a6b21b2c530c", 395 | "metadata": {}, 396 | "source": [ 397 | "### contrastive search" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "id": "831debf4-b0d8-4c6f-b7b0-710d92286d00", 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "from 
transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n", 408 | "import torch\n", 409 | "model_name = \"llama-2-7b-hf\" # 你模型的位置\n", 410 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n", 411 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 412 | "\n", 413 | "\n", 414 | "text = \"say hello to\"\n", 415 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 416 | "print(f\"inputs:{inputs}\")\n", 417 | "input_ids = inputs[\"input_ids\"].to(\"cuda\")\n", 418 | "\n", 419 | "\n", 420 | "generation_output = model.generate(\n", 421 | " input_ids=input_ids,\n", 422 | " penalty_alpha = 0.5,\n", 423 | " top_k = 30,\n", 424 | " return_dict_in_generate=True,\n", 425 | " max_new_tokens=3,\n", 426 | ")\n", 427 | "\n", 428 | "# 直接使用其函数\n", 429 | "# generation_output = model.contrastive_search(\n", 430 | "# input_ids=input_ids,\n", 431 | "# penalty_alpha = 0.5,\n", 432 | "# top_k = 30,\n", 433 | "# return_dict_in_generate=True,\n", 434 | "# max_new_tokens=3,\n", 435 | "# )\n", 436 | "\n", 437 | "print(\"query:\", text)\n", 438 | "for i, output_sequence in enumerate(generation_output.sequences):\n", 439 | " output_text = tokenizer.decode(output_sequence, skip_special_tokens=True)\n", 440 | " print(f\"Generated sequence {i+1}: {output_text}\")" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "id": "749fa43c-4712-4e35-92d5-c29e8cc5f940", 446 | "metadata": {}, 447 | "source": [ 448 | "### greedy decoding" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "id": "fb414f73-7d16-4262-99b9-ce104c2add78", 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n", 459 | "import torch\n", 460 | "model_name = \"llama-2-7b-hf\" # 你模型的位置\n", 461 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n", 462 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 463 | "\n", 464 | "text = \"say hello to\"\n", 465 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 466 | "print(f\"inputs:{inputs}\")\n", 467 | "input_ids = inputs[\"input_ids\"].to(\"cuda\")\n", 468 | "\n", 469 | "\n", 470 | "generation_output = model.generate(\n", 471 | " input_ids=input_ids,\n", 472 | " num_beams = 1,\n", 473 | " do_sample = False,\n", 474 | " return_dict_in_generate=True,\n", 475 | " max_new_tokens=3,\n", 476 | ")\n", 477 | "# 直接指定使用其函数\n", 478 | "# generation_output = model.greedy_search(\n", 479 | "# input_ids=input_ids,\n", 480 | "# num_beams = 1,\n", 481 | "# do_sample = False,\n", 482 | "# return_dict_in_generate=True,\n", 483 | "# max_length = 7\n", 484 | "# )\n", 485 | "\n", 486 | "\n", 487 | "print(\"query:\", text)\n", 488 | "for i, output_sequence in enumerate(generation_output.sequences):\n", 489 | " output_text = tokenizer.decode(output_sequence, skip_special_tokens=True)\n", 490 | " print(f\"Generated sequence {i+1}: {output_text}\")" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "id": "0bd85cb9-0cb2-41d8-ad8e-a1c24fa3e2bb", 496 | "metadata": {}, 497 | "source": [ 498 | "### multinomial sampling" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "id": "2870cf87-304a-471a-a35a-c63fe714edf2", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n", 509 | "from transformers import (\n", 510 | " LogitsProcessorList,\n", 511 | " 
TopKLogitsWarper,\n", 512 | " TopPLogitsWarper,\n", 513 | " TemperatureLogitsWarper,\n", 514 | " )\n", 515 | "\n", 516 | "import torch\n", 517 | "model_name = \"llama-2-7b-hf\" # 你模型的位置\n", 518 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n", 519 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 520 | "\n", 521 | "text = \"say hello to\"\n", 522 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 523 | "print(f\"inputs:{inputs}\")\n", 524 | "input_ids = inputs[\"input_ids\"].to(\"cuda\")\n", 525 | "\n", 526 | "\n", 527 | "generation_output = model.generate(\n", 528 | " input_ids=input_ids,\n", 529 | " num_beams = 1,\n", 530 | " do_sample = True,\n", 531 | " temperature = 1.2,\n", 532 | " top_k = 100,\n", 533 | " top_p = 0.6,\n", 534 | " return_dict_in_generate=True,\n", 535 | " max_length=7,\n", 536 | ")\n", 537 | "\n", 538 | "\n", 539 | "# sample实现\n", 540 | "# logits_warper = LogitsProcessorList(\n", 541 | "# [\n", 542 | "# TopKLogitsWarper(100),\n", 543 | "# TemperatureLogitsWarper(1.2),\n", 544 | "# TopPLogitsWarper(0.6)\n", 545 | "# ]\n", 546 | "# )\n", 547 | "# generation_output = model.sample(\n", 548 | "# input_ids=input_ids,\n", 549 | "# logits_warper=logits_warper,\n", 550 | "# return_dict_in_generate=True,\n", 551 | "# max_length=7,\n", 552 | "# )\n", 553 | "\n", 554 | "\n", 555 | "print(\"query:\", text)\n", 556 | "for i, output_sequence in enumerate(generation_output.sequences):\n", 557 | " output_text = tokenizer.decode(output_sequence, skip_special_tokens=True)\n", 558 | " print(f\"Generated sequence {i+1}: {output_text}\")" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "id": "fcf86750-0f17-49a0-a9fd-0e94b355e4d0", 564 | "metadata": {}, 565 | "source": [ 566 | "### assisted decoding" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "id": "fa2dad77-d546-4460-9536-1e1c1232d695", 573 | "metadata": {}, 574 | "outputs": [], 575 | "source": [ 576 | "from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig\n", 577 | "import torch\n", 578 | "model_name = \"llama-2-13b-hf\" # 你自己模型的位置\n", 579 | "assistant_model_name = \"llama-2-7b-hf\" # 你自己模型的位置\n", 580 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n", 581 | "assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_name, device_map=\"auto\")\n", 582 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 583 | "\n", 584 | "text = \"say hello to\"\n", 585 | "inputs = tokenizer(text, return_tensors=\"pt\")\n", 586 | "print(f\"inputs:{inputs}\")\n", 587 | "input_ids = inputs[\"input_ids\"].to(\"cuda\")\n", 588 | "\n", 589 | "\n", 590 | "generation_output = model.generate(\n", 591 | " assistant_model=assistant_model,\n", 592 | " input_ids=input_ids,\n", 593 | " num_beams = 1,\n", 594 | " do_sample = False,\n", 595 | " return_dict_in_generate=True,\n", 596 | " max_length=7,\n", 597 | ")\n", 598 | "\n", 599 | "\n", 600 | "print(\"query:\", text)\n", 601 | "for i, output_sequence in enumerate(generation_output.sequences):\n", 602 | " output_text = tokenizer.decode(output_sequence, skip_special_tokens=True)\n", 603 | " print(f\"Generated sequence {i+1}: {output_text}\")" 604 | ] 605 | } 606 | ], 607 | "metadata": { 608 | "kernelspec": { 609 | "display_name": "Python 3 (ipykernel)", 610 | "language": "python", 611 | "name": "python3" 612 | }, 613 | "language_info": { 614 | "codemirror_mode": { 615 | "name": "ipython", 616 | "version": 3 617 | }, 618 | 
"file_extension": ".py", 619 | "mimetype": "text/x-python", 620 | "name": "python", 621 | "nbconvert_exporter": "python", 622 | "pygments_lexer": "ipython3", 623 | "version": "3.8.13" 624 | } 625 | }, 626 | "nbformat": 4, 627 | "nbformat_minor": 5 628 | } 629 | -------------------------------------------------------------------------------- /memory_precision.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b2d6451b-02c9-4ac4-8ecd-eeab4c0c62ac", 6 | "metadata": {}, 7 | "source": [ 8 | "### 版本信息" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "59607366-f50d-4a12-a6f5-8f4d99a1b0e6", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import transformers\n", 19 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n", 20 | "import torch\n", 21 | "\n", 22 | "# 打印版本号\n", 23 | "print(\"transformers version:\", transformers.__version__)\n", 24 | "print(\"torch version:\", torch.__version__)\n", 25 | "\n", 26 | "# 检查系统中是否有可用的 GPU\n", 27 | "if torch.cuda.is_available():\n", 28 | " # 获取可用的 GPU 设备数量\n", 29 | " num_devices = torch.cuda.device_count()\n", 30 | " print(\"可用 GPU 数量:\", num_devices)\n", 31 | "\n", 32 | " # 遍历所有可用的 GPU 设备并打印详细信息\n", 33 | " for i in range(num_devices):\n", 34 | " device = torch.cuda.get_device_properties(i)\n", 35 | " print(f\"\\nGPU {i} 的详细信息:\")\n", 36 | " print(\"名称:\", device.name)\n", 37 | " print(\"计算能力:\", f\"{device.major}.{device.minor}\")\n", 38 | " print(\"内存总量 (GB):\", round(device.total_memory / (1024**3), 1))\n", 39 | "else:\n", 40 | " print(\"没有可用的 GPU\")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "ac275fe8-0db3-40b4-a1da-727ede64bca0", 46 | "metadata": {}, 47 | "source": [ 48 | "### FP16显存占用" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "88b66c08-79a8-40fe-ba87-25a759999e6a", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# 加载模型\n", 59 | "model_name = \"/path/to/llama-2-7b-hf\" # 你模型存放的位置\n", 60 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"cuda:0\", torch_dtype=torch.float16)\n", 61 | "# 模型总参数\n", 62 | "total_parameters = model.num_parameters()\n", 63 | "print(\"Total parameters in the model:\", total_parameters)\n", 64 | "\n", 65 | "# 计算结果\n", 66 | "size_per_parameter_bytes = 2\n", 67 | "# 计算模型在显存中的总空间(以字节为单位)\n", 68 | "total_memory_bytes = total_parameters * size_per_parameter_bytes\n", 69 | "# 将字节转换为更常见的单位(GB)\n", 70 | "total_memory_gb = total_memory_bytes / (1024**3)\n", 71 | "print(\"Total memory occupied by the model in MB:\", total_memory_gb)\n", 72 | "\n", 73 | "# torch显示结果\n", 74 | "memory_allocated = torch.cuda.memory_allocated(device='cuda:0')\n", 75 | "# 将字节转换为更常见的单位(GB)\n", 76 | "memory_allocated_gb = memory_allocated / (1024**3)\n", 77 | "print(\"Memory allocated by the model in GB:\", memory_allocated_gb)\n", 78 | "\n", 79 | "# 显卡显示结果\n", 80 | "nvidia-smi" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "6dde2aff-53d4-487f-a68a-8d8692979a1b", 86 | "metadata": {}, 87 | "source": [ 88 | "### FP32显存占用" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "4f9522a8-7ff5-4201-8bf3-d599dc8a94f8", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# 和上述是一模一样的代码,就是两个地方不一样\n", 99 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"cuda:0\", torch_dtype=torch.float32)\n", 100 | "...\n", 101 | 
"size_per_parameter_bytes = 4" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "id": "8015c4de-29c3-4f26-abde-e42c33d49dc0", 107 | "metadata": {}, 108 | "source": [ 109 | "### BF16显存占用" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "7414a663-a4b4-485b-95fc-48a4015a279b", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# 和上述是一模一样的代码,就是两个地方不一样\n", 120 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"cuda:0\", torch_dtype=torch.bfloat16)\n", 121 | "...\n", 122 | "size_per_parameter_bytes = 2" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "4ca433d2-b8b2-4511-b341-347f57e1d4e6", 128 | "metadata": {}, 129 | "source": [ 130 | "### 模型精度转换" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "78f675c7-2516-4e19-95fd-790037e08d66", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "# 以float32加载\n", 141 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"cuda:0\", torch_dtype=torch.float32)\n", 142 | "# 计算模型的显存占用\n", 143 | "memory_allocated = torch.cuda.memory_allocated(device='cuda:0')\n", 144 | "# 将字节转换为更常见的单位(GB)\n", 145 | "memory_allocated_gb = memory_allocated / (1024**3)\n", 146 | "print(\"Memory allocated by the model in GB:\", memory_allocated_gb)\n", 147 | "\n", 148 | "# 转为float16\n", 149 | "model.half()\n", 150 | "# 计算模型的显存占用\n", 151 | "memory_allocated = torch.cuda.memory_allocated(device='cuda:0')\n", 152 | "# 将字节转换为更常见的单位(GB)\n", 153 | "memory_allocated_gb = memory_allocated / (1024**3)\n", 154 | "print(\"Memory allocated by the model in GB:\", memory_allocated_gb)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "711f44be-4908-4975-9fd2-640339e9afd1", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "# 以float16加载\n", 165 | "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"cuda:0\", torch_dtype=torch.float16)\n", 166 | "# 计算模型的显存占用\n", 167 | "memory_allocated = torch.cuda.memory_allocated(device='cuda:0')\n", 168 | "# 将字节转换为更常见的单位(GB)\n", 169 | "memory_allocated_gb = memory_allocated / (1024**3)\n", 170 | "print(\"Memory allocated by the model in GB:\", memory_allocated_gb)\n", 171 | "\n", 172 | "# 转为float16\n", 173 | "model.float()\n", 174 | "# 计算模型的显存占用\n", 175 | "memory_allocated = torch.cuda.memory_allocated(device='cuda:0')\n", 176 | "# 将字节转换为更常见的单位(GB)\n", 177 | "memory_allocated_gb = memory_allocated / (1024**3)\n", 178 | "print(\"Memory allocated by the model in GB:\", memory_allocated_gb)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "id": "3d51a524-9dfa-4f5e-b9c2-48486cd3713d", 184 | "metadata": {}, 185 | "source": [ 186 | "### 如何转换" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 3, 192 | "id": "b87b4f9e-e46d-458c-b601-aaa62e286662", 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "Original Tensor:\n", 200 | " tensor([3.1400])\n", 201 | "Half-Precision Tensor:\n", 202 | " tensor([3.1406], dtype=torch.float16)\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "import torch\n", 208 | "# 创建一个单精度浮点数的张量\n", 209 | "float_tensor = torch.tensor([3.14], dtype=torch.float32)\n", 210 | "# 将张量转换为半精度浮点数\n", 211 | "half_tensor = float_tensor.half()\n", 212 | "# 打印转换后的张量及其数据类型\n", 213 | "print(\"Original Tensor:\\n\", float_tensor)\n", 214 | "print(\"Half-Precision 
Tensor:\\n\", half_tensor)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "040f3c1d-20f5-42f1-b3d6-e5dd69aefafc", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3 (ipykernel)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.13" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 5 247 | } 248 | -------------------------------------------------------------------------------- /precision.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5beaa751-1ea2-467f-96b6-d29a436f001c", 6 | "metadata": {}, 7 | "source": [ 8 | "### FP16" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 2, 14 | "id": "778f1b3a-9a43-448f-8490-c1e5f1f6d44a", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/plain": [ 20 | "finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)" 21 | ] 22 | }, 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "output_type": "execute_result" 26 | } 27 | ], 28 | "source": [ 29 | "import torch\n", 30 | "torch.finfo(torch.float16)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 6, 36 | "id": "e55197aa-b5af-4958-896d-a9cd5e4605a9", 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "3.141: 3.140625\n", 44 | "3.1415: 3.140625\n", 45 | "3.142: 3.142578125\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "# 把10进制数转化为 torch.float16\n", 51 | "num = 3.141\n", 52 | "num_fp16 = torch.tensor(num).half()\n", 53 | "print(f\"3.141: {num_fp16}\")\n", 54 | "\n", 55 | "num = 3.1415\n", 56 | "num_fp16 = torch.tensor(num).half()\n", 57 | "print(f\"3.1415: {num_fp16}\")\n", 58 | "\n", 59 | "num = 3.142\n", 60 | "num_fp16 = torch.tensor(num).half()\n", 61 | "print(f\"3.142: {num_fp16}\")" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 7, 67 | "id": "e155cdb8-7a1d-426c-b834-65e52813abce", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "3.141: 3.140625\n", 75 | "二进制: 0100001001001000\n", 76 | "3.1415: 3.140625\n", 77 | "二进制: 0100001001001000\n", 78 | "3.142: 3.142578125\n", 79 | "二进制: 0100001001001001\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "# float16变成2进制\n", 85 | "import struct\n", 86 | "def float16_to_bin(num):\n", 87 | " # 将float16数打包为2字节16位,使用struct.pack\n", 88 | " packed_num = struct.pack('e', num)\n", 89 | "\n", 90 | " # 解包打包后的字节以获取整数表示\n", 91 | " int_value = struct.unpack('H', packed_num)[0]\n", 92 | "\n", 93 | " # 将整数表示转换为二进制\n", 94 | " binary_representation = bin(int_value)[2:].zfill(16)\n", 95 | " return binary_representation\n", 96 | "\n", 97 | "num = 3.141\n", 98 | "num_fp16 = torch.tensor(num).half()\n", 99 | "print(f\"3.141: {num_fp16}\")\n", 100 | "binary_representation = float16_to_bin(num_fp16)\n", 101 | "print(f\"二进制: {binary_representation}\") # 打印二进制表示\n", 102 | "\n", 103 | "num = 3.1415\n", 104 | "num_fp16 = 
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e155cdb8-7a1d-426c-b834-65e52813abce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.141: 3.140625\n",
      "Binary: 0100001001001000\n",
      "3.1415: 3.140625\n",
      "Binary: 0100001001001000\n",
      "3.142: 3.142578125\n",
      "Binary: 0100001001001001\n"
     ]
    }
   ],
   "source": [
    "# Convert float16 to its binary representation\n",
    "import struct\n",
    "def float16_to_bin(num):\n",
    "    # Pack the float16 number into 2 bytes (16 bits) with struct.pack\n",
    "    packed_num = struct.pack('e', num)\n",
    "\n",
    "    # Unpack the packed bytes to get the integer representation\n",
    "    int_value = struct.unpack('H', packed_num)[0]\n",
    "\n",
    "    # Convert the integer representation to binary\n",
    "    binary_representation = bin(int_value)[2:].zfill(16)\n",
    "    return binary_representation\n",
    "\n",
    "num = 3.141\n",
    "num_fp16 = torch.tensor(num).half()\n",
    "print(f\"3.141: {num_fp16}\")\n",
    "binary_representation = float16_to_bin(num_fp16)\n",
    "print(f\"Binary: {binary_representation}\")  # print the binary representation\n",
    "\n",
    "num = 3.1415\n",
    "num_fp16 = torch.tensor(num).half()\n",
    "print(f\"3.1415: {num_fp16}\")\n",
    "binary_representation = float16_to_bin(num_fp16)\n",
    "print(f\"Binary: {binary_representation}\")  # print the binary representation\n",
    "\n",
    "num = 3.142\n",
    "num_fp16 = torch.tensor(num).half()\n",
    "print(f\"3.142: {num_fp16}\")\n",
    "binary_representation = float16_to_bin(num_fp16)\n",
    "print(f\"Binary: {binary_representation}\")  # print the binary representation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6aeadf20-ae44-4fa7-b73d-9e38c6e0c36a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Float16 value recovered from binary: 3.140625\n",
      "Float16 value recovered from binary: 3.140625\n",
      "Float16 value recovered from binary: 3.142578125\n"
     ]
    }
   ],
   "source": [
    "# Convert a binary string back to float16\n",
    "def binary_to_float16(binary_string):\n",
    "    # Check that the input is a valid 16-bit binary string\n",
    "    if len(binary_string) != 16:\n",
    "        raise ValueError(\"The input binary string must be 16 bits long\")\n",
    "\n",
    "    # Extract the components: sign, exponent, mantissa\n",
    "    sign = int(binary_string[0])  # sign bit\n",
    "    exponent = int(binary_string[1:6], 2)  # exponent bits\n",
    "    mantissa = int(binary_string[6:], 2) / 1024.0  # mantissa bits, divided by 2^10 (i.e. 1024) for 10 bits of precision\n",
    "\n",
    "    # Compute the float16 value from sign, exponent and mantissa\n",
    "    value = (-1) ** sign * (1 + mantissa) * 2 ** (exponent - 15)\n",
    "    return value\n",
    "\n",
    "# Decimal 3.141 corresponds to float16 3.1406\n",
    "binary_representation = \"0100001001001000\"\n",
    "# Convert the binary representation to float16\n",
    "float16_value = binary_to_float16(binary_representation)\n",
    "print(\"Float16 value recovered from binary:\", float16_value)\n",
    "\n",
    "# Decimal 3.1415 corresponds to float16 3.1406\n",
    "binary_representation = \"0100001001001000\"\n",
    "# Convert the binary representation to float16\n",
    "float16_value = binary_to_float16(binary_representation)\n",
    "print(\"Float16 value recovered from binary:\", float16_value)\n",
    "\n",
    "# Decimal 3.142 corresponds to float16 3.1426\n",
    "binary_representation = \"0100001001001001\"\n",
    "# Convert the binary representation to float16\n",
    "float16_value = binary_to_float16(binary_representation)\n",
    "print(\"Float16 value recovered from binary:\", float16_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f5c7e7d-c01c-4ed8-a1f7-e9d582f8c799",
   "metadata": {},
   "source": [
    "### BF16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8c4bb38f-c408-4f85-a5ec-d6fba45e52c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.finfo(torch.bfloat16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed76269d-2b88-45fd-a862-b572382aada9",
   "metadata": {},
   "source": [
    "### FP32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "aa294bf7-d748-4cb4-819b-c8a8ba8d4257",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": 
"execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "torch.finfo(torch.float32)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "f6e76055-8e76-4cff-8172-5f350a9bf218", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [] 233 | } 234 | ], 235 | "metadata": { 236 | "kernelspec": { 237 | "display_name": "Python 3 (ipykernel)", 238 | "language": "python", 239 | "name": "python3" 240 | }, 241 | "language_info": { 242 | "codemirror_mode": { 243 | "name": "ipython", 244 | "version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.8.13" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 5 256 | } 257 | -------------------------------------------------------------------------------- /quality_hash.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pyspark.sql import SparkSession\n", 10 | "from pyspark.ml.feature import HashingTF\n", 11 | "from pyspark.ml.classification import LogisticRegression\n", 12 | "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n", 13 | "from pyspark.sql.functions import udf\n", 14 | "from pyspark.sql.types import ArrayType, StringType\n", 15 | "import jieba\n", 16 | "import time\n", 17 | "\n", 18 | "def initialize_spark():\n", 19 | " return SparkSession.builder \\\n", 20 | " .appName(\"TextClassification\") \\\n", 21 | " .master(\"local[12]\") \\\n", 22 | " .config(\"spark.executor.memory\", \"4g\") \\\n", 23 | " .config(\"spark.executor.cores\", \"5\") \\\n", 24 | " .config(\"spark.task.cpus\", \"1\") \\\n", 25 | " .config(\"spark.executor.instances\", \"2\") \\\n", 26 | " .config(\"spark.sql.auto.repartition\", \"true\") \\\n", 27 | " .config(\"spark.driver.memory\", \"8g\") \\\n", 28 | " .getOrCreate()\n", 29 | "\n", 30 | "def load_data(spark, data_dir):\n", 31 | " return spark.read.json(data_dir)\n", 32 | "\n", 33 | "def jieba_tokenizer(text):\n", 34 | " words = jieba.cut(text)\n", 35 | " return [word for word in words]\n", 36 | "\n", 37 | "def register_jieba_udf(data):\n", 38 | " jieba_udf = udf(jieba_tokenizer, ArrayType(StringType()))\n", 39 | " return data.withColumn(\"words\", jieba_udf(data[\"content\"]))\n", 40 | "\n", 41 | "\n", 42 | "\n", 43 | "# Train and evaluate a logistic regression model\n", 44 | "def train_and_evaluate_model(train_data, test_data):\n", 45 | " # Feature extraction using HashingTF\n", 46 | " hashingTF = HashingTF(numFeatures=2**18, inputCol=\"words\", outputCol=\"features\")\n", 47 | " train_data = hashingTF.transform(train_data)\n", 48 | " test_data = hashingTF.transform(test_data)\n", 49 | "\n", 50 | " # Train a logistic regression model\n", 51 | " lr = LogisticRegression(maxIter=10, regParam=0.02, featuresCol=\"features\")\n", 52 | " model = lr.fit(train_data)\n", 53 | "\n", 54 | " # Evaluate the model\n", 55 | " predictions = model.transform(test_data)\n", 56 | " evaluator = BinaryClassificationEvaluator(rawPredictionCol=\"rawPrediction\", labelCol=\"label\")\n", 57 | " auc = evaluator.evaluate(predictions)\n", 58 | " accuracy = predictions.filter(predictions[\"prediction\"] == predictions[\"label\"]).count() / predictions.count()\n", 59 | "\n", 60 | " return model, auc, accuracy\n", 61 | "\n", 62 | "# Save the trained model 
to disk\n", 63 | "def save_model(model, model_path):\n", 64 | " model.save(model_path)\n", 65 | "\n", 66 | "if __name__ == \"__main__\":\n", 67 | " start_time = time.time()\n", 68 | " \n", 69 | " # Initialize\n", 70 | " jieba.initialize() \n", 71 | " spark = initialize_spark()\n", 72 | "\n", 73 | " # Load data\n", 74 | " data_dir = \"/path/to/data\"\n", 75 | " data = load_data(spark, data_dir)\n", 76 | "\n", 77 | " # Split data into train and test sets\n", 78 | " train_data, test_data = data.randomSplit([0.8, 0.2], seed=123)\n", 79 | "\n", 80 | " # Register Jieba tokenizer as a UDF\n", 81 | " train_data = register_jieba_udf(train_data)\n", 82 | " test_data = register_jieba_udf(test_data)\n", 83 | "\n", 84 | " # Train and evaluate the model\n", 85 | " model, auc, accuracy = train_and_evaluate_model(train_data, test_data)\n", 86 | " print(\"Area Under the ROC Curve (AUC):\", auc)\n", 87 | " print(\"Accuracy:\", accuracy)\n", 88 | " \n", 89 | " # Get size of wights\n", 90 | " num_weights = len(model.coefficients)\n", 91 | " print(\"Number of Weights (Coefficients) in the Model:\", num_weights)\n", 92 | " \n", 93 | " # Save the model to disk\n", 94 | " model_path = \"/path/to/model\"\n", 95 | " save_model(model, model_path) \n", 96 | "\n", 97 | " end_time = time.time()\n", 98 | " execution_time = end_time - start_time\n", 99 | " print(\"Code execution time:\", execution_time, \"seconds\")" 100 | ] 101 | } 102 | ], 103 | "metadata": { 104 | "language_info": { 105 | "name": "python" 106 | } 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 2 110 | } 111 | -------------------------------------------------------------------------------- /sentencepiece.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "662360be-0285-4885-a477-2a563b6cd6ed", 6 | "metadata": {}, 7 | "source": [ 8 | "### 准备数据" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "68790045-f219-4831-8aeb-04c316ee7f7f", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "# 读取.parquet文件\n", 20 | "parquet_file = '/path/file_name.parquet'\n", 21 | "df = pd.read_parquet(parquet_file)\n", 22 | "\n", 23 | "# 获取text列的前1万条数据,只用10000条来做测试\n", 24 | "text_col = df['text'][:10000]\n", 25 | "\n", 26 | "# 指定要写入的txt文件\n", 27 | "txt_file = '/path/file_name.txt'\n", 28 | "\n", 29 | "# 将数据追加写入txt文件\n", 30 | "with open(txt_file, 'a') as file:\n", 31 | " content_col.to_csv(file, sep='\\t', index=False, header=False)\n", 32 | "print(f'前1万条content数据已写入到 {txt_file}')" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "f68c7254-dc01-4835-83e3-fec476587a28", 38 | "metadata": {}, 39 | "source": [ 40 | "### 开始训练" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "fbdb3dc6-c0bb-4820-8b02-661e95a05f88", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "pip install sentencepiece\n", 51 | "\n", 52 | "nohup spm_train --input '/path/file_name.txt' \\\n", 53 | "--input_format text \\\n", 54 | "--model_prefix bpe_test \\\n", 55 | "--model_type bpe \\\n", 56 | "--vocab_size 10000 \\\n", 57 | "--character_coverage 0.9995 \\\n", 58 | "--num_threads 32 \\\n", 59 | "--split_digits True \\\n", 60 | "--byte_fallback True \\\n", 61 | "--max_sentence_length 24000 > bpe_test.log &" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "efda410b-6836-4638-a70d-d77f188be2f3", 67 | "metadata": {}, 68 | "source": [ 69 | "### 
开始使用" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "9f64b771-f659-4c10-8472-d991bb6d6423", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "import sentencepiece as spm\n", 80 | "sp_bpe = spm.SentencePieceProcessor() \n", 81 | "sp_bpe.load('bpe_test.model')\n", 82 | "print('*** BPE ***')\n", 83 | "print(sp_bpe.encode_as_pieces('The excellence of a translation can only be judged by noting'))\n", 84 | "print(len(sp_bpe.encode_as_pieces('The excellence of a translation can only be judged by noting')))\n", 85 | "print(sp_bpe.encode_as_pieces('麒麟,是中国古代神话中的一种瑞兽'))\n", 86 | "print(len(sp_bpe.encode_as_pieces('麒麟,是中国古代神话中的一种瑞兽')))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "6e47c652-b5eb-4973-b19a-3a5f35708110", 92 | "metadata": {}, 93 | "source": [ 94 | "### 合并LLaMa词表" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "c222b17b-8e8f-4823-86d6-b0e8a13f3e58", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "import os\n", 105 | "os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"]=\"python\"\n", 106 | "from transformers import LlamaTokenizer\n", 107 | "from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model\n", 108 | "import sentencepiece as spm\n", 109 | "\n", 110 | "# 位置\n", 111 | "llama_tokenizer_dir = \"/path/llama-2-7b-hf\" # 换成你自己模型的位置\n", 112 | "chinese_sp_model_file =\"/path/bpe_test.model\" # 刚才训练的模型\n", 113 | "\n", 114 | "# 加载\n", 115 | "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n", 116 | "chinese_sp_model = spm.SentencePieceProcessor()\n", 117 | "chinese_sp_model.Load(chinese_sp_model_file)\n", 118 | "llama_spm = sp_pb2_model.ModelProto()\n", 119 | "llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())\n", 120 | "chinese_spm = sp_pb2_model.ModelProto()\n", 121 | "chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())\n", 122 | "\n", 123 | "\n", 124 | "# 打印两个词表的大小和原llama的特殊token\n", 125 | "print(len(llama_tokenizer),len(chinese_sp_model))\n", 126 | "print(llama_tokenizer.all_special_tokens)\n", 127 | "print(llama_tokenizer.all_special_ids)\n", 128 | "print(llama_tokenizer.special_tokens_map)\n", 129 | "\n", 130 | "# 结果\n", 131 | "32000 10000\n", 132 | "['', '', '']\n", 133 | "[1, 2, 0]\n", 134 | "{'bos_token': '', 'eos_token': '', 'unk_token': ''}\n", 135 | "\n", 136 | "# 开始往llama词表里添加,这里你也可以直接加入你想要加入词表的词,或者是领域内的特殊词\n", 137 | "llama_spm_tokens_set=set(p.piece for p in llama_spm.pieces)\n", 138 | "print(len(llama_spm_tokens_set))\n", 139 | "print(f\"Before:{len(llama_spm_tokens_set)}\")\n", 140 | "for p in chinese_spm.pieces:\n", 141 | " piece = p.piece\n", 142 | " if piece not in llama_spm_tokens_set:\n", 143 | " new_p = sp_pb2_model.ModelProto().SentencePiece()\n", 144 | " new_p.piece = piece\n", 145 | " new_p.score = 0\n", 146 | " llama_spm.pieces.append(new_p)\n", 147 | "print(f\"New model pieces: {len(llama_spm.pieces)}\")\n", 148 | "\n", 149 | "# 结果\n", 150 | "32000\n", 151 | "Before:32000\n", 152 | "New model pieces: 40114\n", 153 | "# 我们中文词表原来有1万,去重添加后,添加了8114个词。\n", 154 | "\n", 155 | "# 保存合并后的模型\n", 156 | "output_sp_dir = 'merged_tokenizer_sp_test'\n", 157 | "output_hf_dir = 'merged_tokenizer_hf_test'\n", 158 | "os.makedirs(output_sp_dir,exist_ok=True)\n", 159 | "with open(output_sp_dir+'/chinese_llama.model', 'wb') as f:\n", 160 | " f.write(llama_spm.SerializeToString())\n", 161 | "tokenizer = LlamaTokenizer(vocab_file=output_sp_dir+'/chinese_llama.model')\n", 162 | 
"\n", 163 | "tokenizer.save_pretrained(output_hf_dir)\n", 164 | "print(f\"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}\")\n", 165 | "\n", 166 | "# 看一下效果\n", 167 | "llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)\n", 168 | "chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)\n", 169 | "\n", 170 | "\n", 171 | "text = \"The excellence of a translation can only be judged by noting\"\n", 172 | "print(\"Test text:\\n\",text)\n", 173 | "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n", 174 | "print(f\"Tokenized length by LLaMA tokenizer:{len(llama_tokenizer.tokenize(text))}\")\n", 175 | "print(f\"Tokenized by chinese_llama tokenizer:{chinese_llama_tokenizer.tokenize(text)}\")\n", 176 | "print(f\"Tokenized length by LLaMA-extent-1 tokenizer:{len(chinese_llama_tokenizer.tokenize(text))}\")\n", 177 | "\n", 178 | "\n", 179 | "text = \"麒麟,是中国古代神话中的一种瑞兽\"\n", 180 | "print(\"Test text:\\n\",text)\n", 181 | "print(f\"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}\")\n", 182 | "print(f\"Tokenized length by LLaMA tokenizer:{len(llama_tokenizer.tokenize(text))}\")\n", 183 | "print(f\"Tokenized by chinese_llama tokenizer:{chinese_llama_tokenizer.tokenize(text)}\")\n", 184 | "print(f\"Tokenized length by chinese_llama tokenizer:{len(chinese_llama_tokenizer.tokenize(text))}\")\n" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "kernelspec": { 190 | "display_name": "Python 3 (ipykernel)", 191 | "language": "python", 192 | "name": "python3" 193 | }, 194 | "language_info": { 195 | "codemirror_mode": { 196 | "name": "ipython", 197 | "version": 3 198 | }, 199 | "file_extension": ".py", 200 | "mimetype": "text/x-python", 201 | "name": "python", 202 | "nbconvert_exporter": "python", 203 | "pygments_lexer": "ipython3", 204 | "version": "3.8.13" 205 | } 206 | }, 207 | "nbformat": 4, 208 | "nbformat_minor": 5 209 | } 210 | -------------------------------------------------------------------------------- /tokenization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "49100733-1c5e-4ff2-9051-208065a6a86f", 6 | "metadata": {}, 7 | "source": [ 8 | "### Wordpiece" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "5f0f8aa8-e3db-42a5-be41-cde4d6373463", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "stats: defaultdict(, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})\n", 22 | "alphabet: ['##a', '##e', '##g', '##h', '##i', '##k', '##l', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##y', '##果', '##欢', '##派', 'I', 'S', 'a', 'c', 'e', 'g', 'h', 'l', 't', 'v', 'y', '不', '他', '吃', '喜', '我', '苹']\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "sentences = [\n", 28 | " \"我\",\n", 29 | " \"喜欢\",\n", 30 | " \"吃\",\n", 31 | " \"苹果\",\n", 32 | " \"他\",\n", 33 | " \"不\",\n", 34 | " \"喜欢\",\n", 35 | " \"吃\",\n", 36 | " \"苹果派\",\n", 37 | " \"I like to eat apples\",\n", 38 | " \"She has a cute cat\",\n", 39 | " \"you are very cute\",\n", 40 | " \"give you a hug\",\n", 41 | "]\n", 42 | "\n", 43 | "from collections import defaultdict\n", 44 | "# 构建频率统计\n", 45 | "def build_stats(sentences):\n", 46 | " stats = defaultdict(int)\n", 47 | " for sentence in sentences:\n", 48 
| " symbols = sentence.split()\n", 49 | " for symbol in symbols:\n", 50 | " stats[symbol] += 1\n", 51 | " return stats\n", 52 | "\n", 53 | "stats = build_stats(sentences)\n", 54 | "print(\"stats:\", stats)\n", 55 | "\n", 56 | "alphabet = []\n", 57 | "for word in stats.keys():\n", 58 | " if word[0] not in alphabet:\n", 59 | " alphabet.append(word[0])\n", 60 | " for letter in word[1:]:\n", 61 | " if f\"##{letter}\" not in alphabet:\n", 62 | " alphabet.append(f\"##{letter}\")\n", 63 | "\n", 64 | "alphabet.sort()\n", 65 | "# 初始词表\n", 66 | "vocab = alphabet.copy()\n", 67 | "print(\"alphabet:\", alphabet)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 2, 73 | "id": "d6282ac2-5549-4f55-9aac-bad034c7a0f7", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "splits: {'我': ['我'], '喜欢': ['喜', '##欢'], '吃': ['吃'], '苹果': ['苹', '##果'], '他': ['他'], '不': ['不'], '苹果派': ['苹', '##果', '##派'], 'I': ['I'], 'like': ['l', '##i', '##k', '##e'], 'to': ['t', '##o'], 'eat': ['e', '##a', '##t'], 'apples': ['a', '##p', '##p', '##l', '##e', '##s'], 'She': ['S', '##h', '##e'], 'has': ['h', '##a', '##s'], 'a': ['a'], 'cute': ['c', '##u', '##t', '##e'], 'cat': ['c', '##a', '##t'], 'you': ['y', '##o', '##u'], 'are': ['a', '##r', '##e'], 'very': ['v', '##e', '##r', '##y'], 'give': ['g', '##i', '##v', '##e'], 'hug': ['h', '##u', '##g']}\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "splits = {\n", 86 | " word: [c if i == 0 else f\"##{c}\" for i, c in enumerate(word)]\n", 87 | " for word in stats.keys()\n", 88 | "}\n", 89 | "print(\"splits:\", splits)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "id": "1e1bc015-9b09-4172-ac2e-5a744c3e43d4", 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "('喜', '##欢'): 0.5\n", 103 | "('苹', '##果'): 0.5\n", 104 | "('##果', '##派'): 0.5\n", 105 | "('l', '##i'): 0.5\n", 106 | "('##i', '##k'): 0.5\n", 107 | "('##k', '##e'): 0.125\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "def compute_pair_scores(splits):\n", 113 | " letter_freqs = defaultdict(int)\n", 114 | " pair_freqs = defaultdict(int)\n", 115 | " for word, freq in stats.items():\n", 116 | " split = splits[word]\n", 117 | " if len(split) == 1:\n", 118 | " letter_freqs[split[0]] += freq\n", 119 | " continue\n", 120 | " for i in range(len(split) - 1):\n", 121 | " pair = (split[i], split[i + 1])\n", 122 | " letter_freqs[split[i]] += freq\n", 123 | " pair_freqs[pair] += freq\n", 124 | " letter_freqs[split[-1]] += freq\n", 125 | "\n", 126 | " scores = {\n", 127 | " pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])\n", 128 | " for pair, freq in pair_freqs.items()\n", 129 | " }\n", 130 | " return scores\n", 131 | "\n", 132 | "pair_scores = compute_pair_scores(splits)\n", 133 | "for i, key in enumerate(pair_scores.keys()):\n", 134 | " print(f\"{key}: {pair_scores[key]}\")\n", 135 | " if i >= 5:\n", 136 | " break" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 4, 142 | "id": "5f0e1163-21ee-4105-bd48-66482bdd59e0", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "('S', '##h') 1.0\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "best_pair = \"\"\n", 155 | "max_score = None\n", 156 | "for pair, score in pair_scores.items():\n", 157 | " if max_score is None or max_score < score:\n", 158 | " best_pair = pair\n", 159 | " max_score 
= score\n", 160 | "\n", 161 | "print(best_pair, max_score)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 5, 167 | "id": "4084f9cb-0246-4bc0-b17a-c01bee69c3ba", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "vocab: ['##a', '##e', '##g', '##h', '##i', '##k', '##l', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##y', '##果', '##欢', '##派', 'I', 'S', 'a', 'c', 'e', 'g', 'h', 'l', 't', 'v', 'y', '不', '他', '吃', '喜', '我', '苹', 'Sh', '喜欢', '苹果', '苹果派', 'li', 'lik', 'gi', 'giv', '##pl', '##ppl', '##ry', 'to', 'yo', 'ea', 'eat']\n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "def merge_pair(a, b, splits):\n", 180 | " for word in stats:\n", 181 | " split = splits[word]\n", 182 | " if len(split) == 1:\n", 183 | " continue\n", 184 | " i = 0\n", 185 | " while i < len(split) - 1:\n", 186 | " if split[i] == a and split[i + 1] == b:\n", 187 | " merge = a + b[2:] if b.startswith(\"##\") else a + b\n", 188 | " split = split[:i] + [merge] + split[i + 2 :]\n", 189 | " else:\n", 190 | " i += 1\n", 191 | " splits[word] = split\n", 192 | " return splits\n", 193 | "\n", 194 | "vocab_size = 50\n", 195 | "while len(vocab) < vocab_size:\n", 196 | " scores = compute_pair_scores(splits)\n", 197 | " best_pair, max_score = \"\", None\n", 198 | " for pair, score in scores.items():\n", 199 | " if max_score is None or max_score < score:\n", 200 | " best_pair = pair\n", 201 | " max_score = score\n", 202 | " splits = merge_pair(*best_pair, splits)\n", 203 | " new_token = (\n", 204 | " best_pair[0] + best_pair[1][2:]\n", 205 | " if best_pair[1].startswith(\"##\")\n", 206 | " else best_pair[0] + best_pair[1]\n", 207 | " )\n", 208 | " vocab.append(new_token)\n", 209 | "\n", 210 | "print(\"vocab:\", vocab)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "id": "2cad0484-c620-4074-88ef-a86a64297c83", 216 | "metadata": {}, 217 | "source": [ 218 | "### Byte-Pair Encoding (BPE)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 6, 224 | "id": "9048c7d1-5f7c-49b6-baff-6abd65f65639", 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "sentences = [\n", 229 | " \"我\",\n", 230 | " \"喜欢\",\n", 231 | " \"吃\",\n", 232 | " \"苹果\",\n", 233 | " \"他\",\n", 234 | " \"不\",\n", 235 | " \"喜欢\",\n", 236 | " \"吃\",\n", 237 | " \"苹果派\",\n", 238 | " \"I like to eat apples\",\n", 239 | " \"She has a cute cat\",\n", 240 | " \"you are very cute\",\n", 241 | " \"give you a hug\",\n", 242 | "]" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 7, 248 | "id": "3ecdc5f7-e981-4f4d-9ef9-0aa7a2e807c2", 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "stats: defaultdict(, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})\n", 256 | "alphabet: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹']\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "# 构建频率统计\n", 262 | "def build_stats(sentences):\n", 263 | " stats = defaultdict(int)\n", 264 | " for sentence in sentences:\n", 265 | " symbols = sentence.split()\n", 266 | " for symbol in symbols:\n", 267 | " stats[symbol] += 1\n", 268 | " return stats\n", 269 | "\n", 270 | "stats = 
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9048c7d1-5f7c-49b6-baff-6abd65f65639",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = [\n",
    "    \"我\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果\",\n",
    "    \"他\",\n",
    "    \"不\",\n",
    "    \"喜欢\",\n",
    "    \"吃\",\n",
    "    \"苹果派\",\n",
    "    \"I like to eat apples\",\n",
    "    \"She has a cute cat\",\n",
    "    \"you are very cute\",\n",
    "    \"give you a hug\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3ecdc5f7-e981-4f4d-9ef9-0aa7a2e807c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stats: defaultdict(<class 'int'>, {'我': 1, '喜欢': 2, '吃': 2, '苹果': 1, '他': 1, '不': 1, '苹果派': 1, 'I': 1, 'like': 1, 'to': 1, 'eat': 1, 'apples': 1, 'She': 1, 'has': 1, 'a': 2, 'cute': 2, 'cat': 1, 'you': 2, 'are': 1, 'very': 1, 'give': 1, 'hug': 1})\n",
      "alphabet: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹']\n"
     ]
    }
   ],
   "source": [
    "# Build frequency statistics\n",
    "def build_stats(sentences):\n",
    "    stats = defaultdict(int)\n",
    "    for sentence in sentences:\n",
    "        symbols = sentence.split()\n",
    "        for symbol in symbols:\n",
    "            stats[symbol] += 1\n",
    "    return stats\n",
    "\n",
    "stats = build_stats(sentences)\n",
    "print(\"stats:\", stats)\n",
    "\n",
    "alphabet = []\n",
    "for word in stats.keys():\n",
    "    for letter in word:\n",
    "        if letter not in alphabet:\n",
    "            alphabet.append(letter)\n",
    "alphabet.sort()\n",
    "\n",
    "# Initial vocabulary\n",
    "vocab = alphabet.copy()\n",
    "print(\"alphabet:\", alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e438e15-dce4-479b-9447-083ca242c716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "splits: {'我': ['我'], '喜欢': ['喜', '欢'], '吃': ['吃'], '苹果': ['苹', '果'], '他': ['他'], '不': ['不'], '苹果派': ['苹', '果', '派'], 'I': ['I'], 'like': ['l', 'i', 'k', 'e'], 'to': ['t', 'o'], 'eat': ['e', 'a', 't'], 'apples': ['a', 'p', 'p', 'l', 'e', 's'], 'She': ['S', 'h', 'e'], 'has': ['h', 'a', 's'], 'a': ['a'], 'cute': ['c', 'u', 't', 'e'], 'cat': ['c', 'a', 't'], 'you': ['y', 'o', 'u'], 'are': ['a', 'r', 'e'], 'very': ['v', 'e', 'r', 'y'], 'give': ['g', 'i', 'v', 'e'], 'hug': ['h', 'u', 'g']}\n",
      "('喜', '欢'): 2\n",
      "('苹', '果'): 2\n",
      "('果', '派'): 1\n",
      "('l', 'i'): 1\n",
      "('i', 'k'): 1\n",
      "('k', 'e'): 1\n"
     ]
    }
   ],
   "source": [
    "splits = {word: [c for c in word] for word in stats.keys()}\n",
    "print(\"splits:\", splits)\n",
    "\n",
    "def compute_pair_freqs(splits):\n",
    "    pair_freqs = defaultdict(int)\n",
    "    for word, freq in stats.items():\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            continue\n",
    "        for i in range(len(split) - 1):\n",
    "            pair = (split[i], split[i + 1])\n",
    "            pair_freqs[pair] += freq\n",
    "    return pair_freqs\n",
    "pair_freqs = compute_pair_freqs(splits)\n",
    "\n",
    "for i, key in enumerate(pair_freqs.keys()):\n",
    "    print(f\"{key}: {pair_freqs[key]}\")\n",
    "    if i >= 5:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "14b12e7e-8910-40dc-8217-1f46d603d33f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('喜', '欢') 2\n"
     ]
    }
   ],
   "source": [
    "best_pair = \"\"\n",
    "max_freq = None\n",
    "for pair, freq in pair_freqs.items():\n",
    "    if max_freq is None or max_freq < freq:\n",
    "        best_pair = pair\n",
    "        max_freq = freq\n",
    "\n",
    "print(best_pair, max_freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "68c3d67a-fb1b-4573-a629-29cb7c09e5a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "merges: {('喜', '欢'): '喜欢', ('苹', '果'): '苹果', ('a', 't'): 'at', ('c', 'u'): 'cu', ('cu', 't'): 'cut', ('cut', 'e'): 'cute', ('y', 'o'): 'yo', ('yo', 'u'): 'you', ('v', 'e'): 've', ('苹果', '派'): '苹果派', ('l', 'i'): 'li', ('li', 'k'): 'lik', ('lik', 'e'): 'like', ('t', 'o'): 'to', ('e', 'at'): 'eat', ('a', 'p'): 'ap', ('ap', 'p'): 'app', ('app', 'l'): 'appl', ('appl', 'e'): 'apple', ('apple', 's'): 'apples', ('S', 'h'): 'Sh', ('Sh', 'e'): 'She', ('h', 'a'): 'ha'}\n",
      "vocab: ['I', 'S', 'a', 'c', 'e', 'g', 'h', 'i', 'k', 'l', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', '不', '他', '吃', '喜', '我', '果', '欢', '派', '苹', '喜欢', '苹果', 'at', 'cu', 'cut', 'cute', 'yo', 'you', 've', '苹果派', 'li', 'lik', 
'like', 'to', 'eat', 'ap', 'app', 'appl', 'apple', 'apples', 'Sh', 'She', 'ha']\n" 364 | ] 365 | } 366 | ], 367 | "source": [ 368 | "def merge_pair(a, b, splits):\n", 369 | " for word in stats:\n", 370 | " split = splits[word]\n", 371 | " if len(split) == 1:\n", 372 | " continue\n", 373 | "\n", 374 | " i = 0\n", 375 | " while i < len(split) - 1:\n", 376 | " if split[i] == a and split[i + 1] == b:\n", 377 | " split = split[:i] + [a + b] + split[i + 2 :]\n", 378 | " else:\n", 379 | " i += 1\n", 380 | " splits[word] = split\n", 381 | " return splits\n", 382 | "\n", 383 | "# 假设我们想要的词典为50\n", 384 | "merges = {}\n", 385 | "vocab_size = 50\n", 386 | "\n", 387 | "while len(vocab) < vocab_size:\n", 388 | " pair_freqs = compute_pair_freqs(splits)\n", 389 | " best_pair = \"\"\n", 390 | " max_freq = None\n", 391 | " for pair, freq in pair_freqs.items():\n", 392 | " if max_freq is None or max_freq < freq:\n", 393 | " best_pair = pair\n", 394 | " max_freq = freq\n", 395 | " splits = merge_pair(*best_pair, splits)\n", 396 | " merges[best_pair] = best_pair[0] + best_pair[1]\n", 397 | " vocab.append(best_pair[0] + best_pair[1])\n", 398 | "\n", 399 | "print(\"merges:\", merges)\n", 400 | "print(\"vocab:\", vocab)" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "id": "4589d599-0c4f-467c-8c03-bec879f93982", 406 | "metadata": {}, 407 | "source": [ 408 | "### Byte-level BPE(BBPE)" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 11, 414 | "id": "14cdfdb4-a3c6-414e-9a02-34e3e04398b0", 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "name": "stdout", 419 | "output_type": "stream", 420 | "text": [ 421 | "initial_vocab: [b'\\x00', b'\\x01', b'\\x02', b'\\x03', b'\\x04', b'\\x05', b'\\x06', b'\\x07', b'\\x08', b'\\t', b'\\n', b'\\x0b', b'\\x0c', b'\\r', b'\\x0e', b'\\x0f', b'\\x10', b'\\x11', b'\\x12', b'\\x13', b'\\x14', b'\\x15', b'\\x16', b'\\x17', b'\\x18', b'\\x19', b'\\x1a', b'\\x1b', b'\\x1c', b'\\x1d', b'\\x1e', b'\\x1f', b' ', b'!', b'\"', b'#', b'$', b'%', b'&', b\"'\", b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', b'<', b'=', b'>', b'?', b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', b'[', b'\\\\', b']', b'^', b'_', b'`', b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', b'{', b'|', b'}', b'~', b'\\x7f', b'\\x80', b'\\x81', b'\\x82', b'\\x83', b'\\x84', b'\\x85', b'\\x86', b'\\x87', b'\\x88', b'\\x89', b'\\x8a', b'\\x8b', b'\\x8c', b'\\x8d', b'\\x8e', b'\\x8f', b'\\x90', b'\\x91', b'\\x92', b'\\x93', b'\\x94', b'\\x95', b'\\x96', b'\\x97', b'\\x98', b'\\x99', b'\\x9a', b'\\x9b', b'\\x9c', b'\\x9d', b'\\x9e', b'\\x9f', b'\\xa0', b'\\xa1', b'\\xa2', b'\\xa3', b'\\xa4', b'\\xa5', b'\\xa6', b'\\xa7', b'\\xa8', b'\\xa9', b'\\xaa', b'\\xab', b'\\xac', b'\\xad', b'\\xae', b'\\xaf', b'\\xb0', b'\\xb1', b'\\xb2', b'\\xb3', b'\\xb4', b'\\xb5', b'\\xb6', b'\\xb7', b'\\xb8', b'\\xb9', b'\\xba', b'\\xbb', b'\\xbc', b'\\xbd', b'\\xbe', b'\\xbf', b'\\xc0', b'\\xc1', b'\\xc2', b'\\xc3', b'\\xc4', b'\\xc5', b'\\xc6', b'\\xc7', b'\\xc8', b'\\xc9', b'\\xca', b'\\xcb', b'\\xcc', b'\\xcd', b'\\xce', b'\\xcf', b'\\xd0', b'\\xd1', b'\\xd2', b'\\xd3', b'\\xd4', b'\\xd5', b'\\xd6', b'\\xd7', b'\\xd8', b'\\xd9', b'\\xda', b'\\xdb', b'\\xdc', b'\\xdd', b'\\xde', b'\\xdf', b'\\xe0', b'\\xe1', 
b'\\xe2', b'\\xe3', b'\\xe4', b'\\xe5', b'\\xe6', b'\\xe7', b'\\xe8', b'\\xe9', b'\\xea', b'\\xeb', b'\\xec', b'\\xed', b'\\xee', b'\\xef', b'\\xf0', b'\\xf1', b'\\xf2', b'\\xf3', b'\\xf4', b'\\xf5', b'\\xf6', b'\\xf7', b'\\xf8', b'\\xf9', b'\\xfa', b'\\xfb', b'\\xfc', b'\\xfd', b'\\xfe', b'\\xff']\n", 422 | "vocab: [b'\\x00', b'\\x01', b'\\x02', b'\\x03', b'\\x04', b'\\x05', b'\\x06', b'\\x07', b'\\x08', b'\\t', b'\\n', b'\\x0b', b'\\x0c', b'\\r', b'\\x0e', b'\\x0f', b'\\x10', b'\\x11', b'\\x12', b'\\x13', b'\\x14', b'\\x15', b'\\x16', b'\\x17', b'\\x18', b'\\x19', b'\\x1a', b'\\x1b', b'\\x1c', b'\\x1d', b'\\x1e', b'\\x1f', b' ', b'!', b'\"', b'#', b'$', b'%', b'&', b\"'\", b'(', b')', b'*', b'+', b',', b'-', b'.', b'/', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b':', b';', b'<', b'=', b'>', b'?', b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W', b'X', b'Y', b'Z', b'[', b'\\\\', b']', b'^', b'_', b'`', b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', b'z', b'{', b'|', b'}', b'~', b'\\x7f', b'\\x80', b'\\x81', b'\\x82', b'\\x83', b'\\x84', b'\\x85', b'\\x86', b'\\x87', b'\\x88', b'\\x89', b'\\x8a', b'\\x8b', b'\\x8c', b'\\x8d', b'\\x8e', b'\\x8f', b'\\x90', b'\\x91', b'\\x92', b'\\x93', b'\\x94', b'\\x95', b'\\x96', b'\\x97', b'\\x98', b'\\x99', b'\\x9a', b'\\x9b', b'\\x9c', b'\\x9d', b'\\x9e', b'\\x9f', b'\\xa0', b'\\xa1', b'\\xa2', b'\\xa3', b'\\xa4', b'\\xa5', b'\\xa6', b'\\xa7', b'\\xa8', b'\\xa9', b'\\xaa', b'\\xab', b'\\xac', b'\\xad', b'\\xae', b'\\xaf', b'\\xb0', b'\\xb1', b'\\xb2', b'\\xb3', b'\\xb4', b'\\xb5', b'\\xb6', b'\\xb7', b'\\xb8', b'\\xb9', b'\\xba', b'\\xbb', b'\\xbc', b'\\xbd', b'\\xbe', b'\\xbf', b'\\xc0', b'\\xc1', b'\\xc2', b'\\xc3', b'\\xc4', b'\\xc5', b'\\xc6', b'\\xc7', b'\\xc8', b'\\xc9', b'\\xca', b'\\xcb', b'\\xcc', b'\\xcd', b'\\xce', b'\\xcf', b'\\xd0', b'\\xd1', b'\\xd2', b'\\xd3', b'\\xd4', b'\\xd5', b'\\xd6', b'\\xd7', b'\\xd8', b'\\xd9', b'\\xda', b'\\xdb', b'\\xdc', b'\\xdd', b'\\xde', b'\\xdf', b'\\xe0', b'\\xe1', b'\\xe2', b'\\xe3', b'\\xe4', b'\\xe5', b'\\xe6', b'\\xe7', b'\\xe8', b'\\xe9', b'\\xea', b'\\xeb', b'\\xec', b'\\xed', b'\\xee', b'\\xef', b'\\xf0', b'\\xf1', b'\\xf2', b'\\xf3', b'\\xf4', b'\\xf5', b'\\xf6', b'\\xf7', b'\\xf8', b'\\xf9', b'\\xfa', b'\\xfb', b'\\xfc', b'\\xfd', b'\\xfe', b'\\xff']\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "from collections import defaultdict\n", 428 | "sentences = [\n", 429 | " \"我\",\n", 430 | " \"喜欢\",\n", 431 | " \"吃\",\n", 432 | " \"苹果\",\n", 433 | " \"他\",\n", 434 | " \"不\",\n", 435 | " \"喜欢\",\n", 436 | " \"吃\",\n", 437 | " \"苹果派\",\n", 438 | " \"I like to eat apples\",\n", 439 | " \"She has a cute cat\",\n", 440 | " \"you are very cute\",\n", 441 | " \"give you a hug\",\n", 442 | "]\n", 443 | "# 构建初始词汇表,包含一个字节的256个表示\n", 444 | "initial_vocab = [bytes([byte]) for byte in range(256)]\n", 445 | "vocab = initial_vocab.copy()\n", 446 | "print(\"initial_vocab:\", initial_vocab)\n", 447 | "\n", 448 | "# 构建频率统计\n", 449 | "def build_stats(sentences):\n", 450 | " stats = defaultdict(int)\n", 451 | " for sentence in sentences:\n", 452 | " symbols = sentence.split()\n", 453 | " for symbol in symbols:\n", 454 | " stats[symbol.encode(\"utf-8\")] += 1\n", 455 | " return stats\n", 456 | "stats = build_stats(sentences)\n", 457 | "\n", 458 | "splits = {word: [byte for byte in word] for word in 
stats.keys()}\n", 459 | "def compute_pair_freqs(splits):\n", 460 | " pair_freqs = defaultdict(int)\n", 461 | " for word, freq in stats.items():\n", 462 | " split = splits[word]\n", 463 | " if len(split) == 1:\n", 464 | " continue\n", 465 | " for i in range(len(split) - 1):\n", 466 | " pair = (split[i], split[i + 1])\n", 467 | " pair_freqs[pair] += freq\n", 468 | " return pair_freqs\n", 469 | "\n", 470 | "pair_freqs = compute_pair_freqs(splits)\n", 471 | "\n", 472 | "def merge_pair(pair, splits):\n", 473 | " merged_byte = bytes(pair)\n", 474 | " for word in stats:\n", 475 | " split = splits[word]\n", 476 | " if len(split) == 1:\n", 477 | " continue\n", 478 | " i = 0\n", 479 | " while i < len(split) - 1:\n", 480 | " if split[i:i+2] == pair: # 检查分割中是否有这对字节\n", 481 | " split = split[:i] + [merged_byte] + split[i + 2 :]\n", 482 | " else:\n", 483 | " i += 1\n", 484 | " splits[word] = split\n", 485 | " return splits\n", 486 | "\n", 487 | "vocab_size = 50\n", 488 | "while len(vocab) < vocab_size:\n", 489 | " pair_freqs = compute_pair_freqs(splits)\n", 490 | " best_pair = ()\n", 491 | " max_freq = None\n", 492 | " for pair, freq in pair_freqs.items():\n", 493 | " if max_freq is None or max_freq < freq:\n", 494 | " best_pair = pair\n", 495 | " max_freq = freq\n", 496 | " splits = merge_pair(best_pair, splits)\n", 497 | " merged_byte = bytes(best_pair)\n", 498 | "\n", 499 | "print(\"vocab:\", vocab)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "id": "71a225fd-f5cf-4e11-a178-30d5f96c7cc8", 506 | "metadata": {}, 507 | "outputs": [], 508 | "source": [] 509 | } 510 | ], 511 | "metadata": { 512 | "kernelspec": { 513 | "display_name": "Python 3 (ipykernel)", 514 | "language": "python", 515 | "name": "python3" 516 | }, 517 | "language_info": { 518 | "codemirror_mode": { 519 | "name": "ipython", 520 | "version": 3 521 | }, 522 | "file_extension": ".py", 523 | "mimetype": "text/x-python", 524 | "name": "python", 525 | "nbconvert_exporter": "python", 526 | "pygments_lexer": "ipython3", 527 | "version": "3.8.13" 528 | } 529 | }, 530 | "nbformat": 4, 531 | "nbformat_minor": 5 532 | } 533 | -------------------------------------------------------------------------------- /transformer_torch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c73fdb87-4574-4dd4-a8ae-67b4040fa3e1", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# 后续代码所需的依赖\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.functional as F\n", 15 | "import math, copy, time\n", 16 | "\n", 17 | "class TokenEmbedding(nn.Embedding):\n", 18 | " \"\"\"\n", 19 | " 使用torch.nn的Embedding模块\n", 20 | " \"\"\"\n", 21 | "\n", 22 | " def __init__(self, vocab_size, d_model):\n", 23 | " \"\"\"\n", 24 | " TokenEmbedding类\n", 25 | "\n", 26 | " :param vocab_size: 词汇表的大小\n", 27 | " :param d_model: 模型的维度\n", 28 | " :padding的索引为1,即token索引为1时,Embedding补0\n", 29 | " \"\"\"\n", 30 | " super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "956b9cb9-3aab-481b-88dd-34b92f9563d2", 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "模块中的参数数量为: 512000\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "tok_emb = TokenEmbedding(1000, 512)\n", 49 | "num_params = sum(p.numel() for 
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "956b9cb9-3aab-481b-88dd-34b92f9563d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of parameters in the module: 512000\n"
     ]
    }
   ],
   "source": [
    "tok_emb = TokenEmbedding(1000, 512)\n",
    "num_params = sum(p.numel() for p in tok_emb.parameters())\n",
    "print(\"Number of parameters in the module:\", num_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bf5a9738-b4d0-4ab1-be42-108c6fffb969",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "res: tensor([[[ 0.1994,  0.9800,  1.3162,  ...,  0.5627,  0.4720,  0.9675],\n",
      "         [ 0.5742, -0.6087,  0.5124,  ...,  1.2031,  0.7531, -0.6712],\n",
      "         [ 1.5465, -0.5008,  0.9264,  ...,  0.0815, -0.2846,  0.0975]],\n",
      "\n",
      "        [[ 0.7377, -0.0586, -1.3138,  ...,  0.4524, -0.2046,  1.6616],\n",
      "         [ 1.1249, -1.7420, -1.6807,  ..., -0.0495,  0.3476,  1.1462],\n",
      "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],\n",
      "       grad_fn=<EmbeddingBackward0>)\n",
      "res.shape: torch.Size([2, 3, 512])\n"
     ]
    }
   ],
   "source": [
    "# x has batch_size 2 and seq_len 3; index 1 is padded to zeros\n",
    "x = torch.LongTensor([[6, 5, 4], [3, 2, 1]])\n",
    "res = tok_emb(x)\n",
    "print(\"res:\", res)\n",
    "print(\"res.shape:\", res.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5a45fb17-e01b-434e-9003-f8193d8beb6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEncoding(nn.Module):\n",
    "    \"\"\"\n",
    "    Computes the sine/cosine positional encoding.\n",
    "    \"\"\"\n",
    "    def __init__(self, d_model, max_len):\n",
    "        \"\"\"\n",
    "        Sinusoidal positional encoding class\n",
    "\n",
    "        :param d_model: model dimension\n",
    "        :param max_len: maximum sequence length\n",
    "        \"\"\"\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "\n",
    "        # Initialize the positional encoding matrix\n",
    "        self.encoding = torch.zeros(max_len, d_model)\n",
    "        self.encoding.requires_grad = False  # no gradients needed\n",
    "\n",
    "        pos = torch.arange(0, max_len)\n",
    "        pos = pos.float().unsqueeze(dim=1)\n",
    "\n",
    "        # 'i' indexes d_model (e.g. with embedding size 50, 'i' = [0, 50])\n",
    "        # \"step=2\" multiplies 'i' by two (same as 2 * i)\n",
    "        _2i = torch.arange(0, d_model, step=2).float()\n",
    "        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))\n",
    "        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # self.encoding\n",
    "        # [max_len = 512, d_model = 512]\n",
    "\n",
    "        batch_size, seq_len = x.size()\n",
    "        # [batch_size = 8, seq_len = 30]\n",
    "\n",
    "        return self.encoding[:seq_len, :]\n",
    "        # [seq_len = 30, d_model = 512]\n",
    "        # will be added to tok_emb: [8, 30, 512]"
   ]
  },
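  {
   "cell_type": "markdown",
   "id": "pe-formula-note",
   "metadata": {},
   "source": [
    "The class above fills the encoding matrix according to the sinusoidal formulas from the original paper:\n",
    "\n",
    "$$PE_{(pos, 2i)} = \\sin\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right), \\qquad PE_{(pos, 2i+1)} = \\cos\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right)$$\n",
    "\n",
    "The table is fixed rather than learned, which is why the parameter count in the next cell is 0."
   ]
  },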
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2f2e1eb5-18eb-4b50-92c3-0ed6c5d29f01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of parameters in the module: 0\n"
     ]
    }
   ],
   "source": [
    "pe = PositionalEncoding(512,512)\n",
    "num_params = sum(p.numel() for p in pe.parameters())\n",
    "print(\"Number of parameters in the module:\", num_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6d60b122-114f-48cf-b241-fd57766e0afe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "res: tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,\n",
      "          0.0000e+00,  1.0000e+00],\n",
      "        [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,\n",
      "          1.0366e-04,  1.0000e+00],\n",
      "        [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,\n",
      "          2.0733e-04,  1.0000e+00]])\n",
      "res.shape: torch.Size([3, 512])\n"
     ]
    }
   ],
   "source": [
    "# x has batch_size 2 and seq_len 3\n",
    "x = torch.LongTensor([[6, 5, 4], [3, 2, 1]])\n",
    "res = pe.forward(x)\n",
    "print(\"res:\", res)\n",
    "# the returned shape is [seq_len = 3, d_model = 512]\n",
    "print(\"res.shape:\", res.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5cf833ab-9ab5-47ec-923e-3e812c7adfeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerEmbedding(nn.Module):\n",
    "    \"\"\"\n",
    "    token embedding + positional encoding\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, d_model, max_len, drop_prob):\n",
    "        \"\"\"\n",
    "        Class combining the Embedding and the positional encoding\n",
    "\n",
    "        :param vocab_size: vocabulary size\n",
    "        :param d_model: model dimension\n",
    "        :param max_len: maximum sequence length\n",
    "        :param drop_prob: dropout probability, regularization against overfitting\n",
    "        \"\"\"\n",
    "        super(TransformerEmbedding, self).__init__()\n",
    "        self.tok_emb = TokenEmbedding(vocab_size, d_model)\n",
    "        self.pos_emb = PositionalEncoding(d_model, max_len)\n",
    "        self.drop_out = nn.Dropout(p=drop_prob)\n",
    "\n",
    "    def forward(self, x):\n",
    "        tok_emb = self.tok_emb(x)\n",
    "        pos_emb = self.pos_emb(x)\n",
    "        return self.drop_out(tok_emb + pos_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2e05f345-5473-4d66-b8c8-f9ed5bf08185",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "res: tensor([[[ 1.6016e+00,  0.0000e+00,  7.3131e-01,  ...,  3.4665e-01,\n",
      "           1.2576e-01,  1.7100e+00],\n",
      "         [ 5.4879e-01,  7.7141e-01,  1.6291e+00,  ...,  0.0000e+00,\n",
      "          -0.0000e+00,  2.7418e+00],\n",
      "         [ 0.0000e+00,  1.3316e+00,  0.0000e+00,  ..., -1.0642e+00,\n",
      "          -6.8040e-02, -4.0913e-01]],\n",
      "\n",
      "        [[-1.1485e+00,  3.1954e-01,  0.0000e+00,  ...,  0.0000e+00,\n",
      "           1.0274e+00,  1.3045e+00],\n",
      "         [-5.3425e-02,  2.9367e-01,  5.0048e-01,  ...,  1.2340e+00,\n",
      "           1.0360e+00, -7.7892e-01],\n",
      "         [ 1.0103e+00, -4.6239e-01,  1.0405e+00,  ...,  1.1111e+00,\n",
      "           2.3036e-04,  1.1111e+00]]], grad_fn=<MulBackward0>)\n",
      "res.shape: torch.Size([2, 3, 512])\n"
     ]
    }
   ],
   "source": [
    "te = TransformerEmbedding(1000,512,512,0.1)\n",
    "# x has batch_size 2 and seq_len 3\n",
    "x = torch.LongTensor([[6, 5, 4], [3, 2, 1]])\n",
    "res = te.forward(x)\n",
    "print(\"res:\", res)\n",
    "# the returned shape is [batch_size = 2, seq_len = 3, d_model = 512]\n",
    "print(\"res.shape:\", res.shape)"
   ]
  },
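  {
   "cell_type": "markdown",
   "id": "sdpa-formula-note",
   "metadata": {},
   "source": [
    "The next cell implements scaled dot-product attention:\n",
    "\n",
    "$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{QK^{T}}{\\sqrt{d_k}}\\right)V$$\n",
    "\n",
    "Masked positions have their scores set to -10000 before the softmax, which drives their attention weights to (almost) zero."
   ]
  },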
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5ae6f88d-bbd1-4789-9f8f-7b17d9df5d6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ScaleDotProductAttention(nn.Module):\n",
    "    \"\"\"\n",
    "    Computes scaled dot-product attention for a single head group\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(ScaleDotProductAttention, self).__init__()\n",
    "        self.softmax = nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None, e=1e-12):\n",
    "        # the input is a 4-dimensional tensor\n",
    "        # [batch_size, head, length, d_tensor]\n",
    "        batch_size, head, length, d_tensor = k.size()\n",
    "\n",
    "    

    "        # 1. dot product of Query with the transpose of Key\n",
    "        k_t = k.transpose(2, 3)  # transpose\n",
    "        score = (q @ k_t) / math.sqrt(d_tensor)  # scaled dot product\n",
    "\n",
    "        # 2. apply the mask; the encoder needs no mask, the decoder does\n",
    "        if mask is not None:\n",
    "            score = score.masked_fill(mask == 0, -10000)\n",
    "\n",
    "        # 3. softmax the scores into the range [0, 1]\n",
    "        score = self.softmax(score)\n",
    "\n",
    "        # 4. multiply by Value\n",
    "        v = score @ v\n",
    "\n",
    "        return v, score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c3d594e3-af75-45a0-8558-8174ab9bf1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadAttention(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, n_head):\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        self.n_head = n_head\n",
    "        self.attention = ScaleDotProductAttention()\n",
    "        self.w_q = nn.Linear(d_model, d_model)\n",
    "        self.w_k = nn.Linear(d_model, d_model)\n",
    "        self.w_v = nn.Linear(d_model, d_model)\n",
    "        self.w_concat = nn.Linear(d_model, d_model)\n",
    "\n",
    "    def forward(self, q, k, v, mask=None):\n",
    "        # 1. project with the corresponding weight matrices\n",
    "        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)\n",
    "\n",
    "        # 2. split the dimensions by the number of heads\n",
    "        q, k, v = self.split(q), self.split(k), self.split(v)\n",
    "\n",
    "        # 3. compute attention\n",
    "        out, attention = self.attention(q, k, v, mask=mask)\n",
    "\n",
    "        # 4. concatenate the split heads back together\n",
    "        out = self.concat(out)\n",
    "        out = self.w_concat(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "    def split(self, tensor):\n",
    "        \"\"\"\n",
    "        Split the tensor by the number of heads\n",
    "\n",
    "        :param tensor: [batch_size, length, d_model]\n",
    "        :return: [batch_size, head, length, d_tensor]\n",
    "        \"\"\"\n",
    "        batch_size, length, d_model = tensor.size()\n",
    "\n",
    "        d_tensor = d_model // self.n_head\n",
    "        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)\n",
    "\n",
    "        return tensor\n",
    "\n",
    "    def concat(self, tensor):\n",
    "        \"\"\"\n",
    "        Concatenate the split heads back together\n",
    "\n",
    "        :param tensor: [batch_size, head, length, d_tensor]\n",
    "        :return: [batch_size, length, d_model]\n",
    "        \"\"\"\n",
    "        batch_size, head, length, d_tensor = tensor.size()\n",
    "        d_model = head * d_tensor\n",
    "\n",
    "        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)\n",
    "        return tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2eeb6cc0-69b8-4ae3-b398-d9fc09bfef4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionwiseFeedForward(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, hidden, drop_prob=0.1):\n",
    "        super(PositionwiseFeedForward, self).__init__()\n",
    "        self.linear1 = nn.Linear(d_model, hidden)\n",
    "        self.linear2 = nn.Linear(hidden, d_model)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.dropout = nn.Dropout(p=drop_prob)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.linear1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.dropout(x)\n",
    "        x = self.linear2(x)\n",
    "        return x"
   ]
  },
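  {
   "cell_type": "markdown",
   "id": "layernorm-formula-note",
   "metadata": {},
   "source": [
    "The LayerNorm cell below normalizes over the last (feature) dimension and then applies a learned affine transform:\n",
    "\n",
    "$$y = \\alpha \\cdot \\frac{x - \\mu}{\\sqrt{\\sigma^2 + \\epsilon}} + b$$\n",
    "\n",
    "where $\\mu$ and $\\sigma^2$ are the mean and (biased) variance of each token's feature vector, and $\\alpha$, $b$ are learned per-feature scale and shift parameters."
   ]
  },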
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d91d170e-5bca-4b0b-9a0e-ff289055620c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LayerNorm(nn.Module):\n",
    "    def __init__(self, d_model, eps=1e-12):\n",
    "        super(LayerNorm, self).__init__()\n",
    "        self.alpha = nn.Parameter(torch.ones(d_model))\n",
    "        self.bias = nn.Parameter(torch.zeros(d_model))\n",
    "        self.eps = eps\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean = x.mean(-1, keepdim=True)\n",
    "        var = x.var(-1, unbiased=False, keepdim=True)\n",
    "\n",
    "        out = (x - mean) / torch.sqrt(var + self.eps)\n",
    "        out = self.alpha * out + self.bias\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0f0dc0bf-972d-482e-91ad-538e8d6c4b76",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderLayer(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):\n",
    "        super(EncoderLayer, self).__init__()\n",
    "        self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)\n",
    "        self.norm1 = LayerNorm(d_model=d_model)\n",
    "        self.dropout1 = nn.Dropout(p=drop_prob)\n",
    "\n",
    "        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)\n",
    "        self.norm2 = LayerNorm(d_model=d_model)\n",
    "        self.dropout2 = nn.Dropout(p=drop_prob)\n",
    "\n",
    "    def forward(self, x, src_mask):\n",
    "        # 1. compute self-attention\n",
    "        _x = x\n",
    "        x = self.attention(q=x, k=x, v=x, mask=src_mask)\n",
    "\n",
    "        # 2. residual connection and layer normalization\n",
    "        x = self.dropout1(x)\n",
    "        x = self.norm1(x + _x)\n",
    "\n",
    "        # 3. feed-forward layer\n",
    "        _x = x\n",
    "        x = self.ffn(x)\n",
    "\n",
    "        # 4. final residual connection and layer normalization\n",
    "        x = self.dropout2(x)\n",
    "        x = self.norm2(x + _x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e739d3da-ad98-4e77-bcc3-6a1e5e2af421",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "\n",
    "    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):\n",
    "        super().__init__()\n",
    "        self.emb = TransformerEmbedding(d_model=d_model,\n",
    "                                        max_len=max_len,\n",
    "                                        vocab_size=enc_voc_size,\n",
    "                                        drop_prob=drop_prob)\n",
    "\n",
    "        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,\n",
    "                                                  ffn_hidden=ffn_hidden,\n",
    "                                                  n_head=n_head,\n",
    "                                                  drop_prob=drop_prob)\n",
    "                                     for _ in range(n_layers)])\n",
    "\n",
    "    def forward(self, x, src_mask):\n",
    "        x = self.emb(x)\n",
    "\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, src_mask)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1ae30ffe-e0a6-487e-a035-96bbf2abe09a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderLayer(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):\n",
    "        super(DecoderLayer, self).__init__()\n",
    "        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)\n",
    "        self.norm1 = LayerNorm(d_model=d_model)\n",
    "        self.dropout1 = nn.Dropout(p=drop_prob)\n",
    "\n",
    "        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)\n",
    "        self.norm2 = LayerNorm(d_model=d_model)\n",
    "        self.dropout2 = nn.Dropout(p=drop_prob)\n",
    "\n",
    "        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)\n",
    "        self.norm3 = LayerNorm(d_model=d_model)\n",
    "        self.dropout3 = nn.Dropout(p=drop_prob)\n",
    "\n",
    "    def forward(self, dec, enc, trg_mask, src_mask):\n",
    "        # 1. corresponds to point 1 described above\n",
    "        _x = dec\n",
    "        x = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)\n",
    "\n",
    "        # 2. residual connection and layer normalization\n",
    "        x = self.dropout1(x)\n",
    "        x = self.norm1(x + _x)\n",
    "\n",
    "        if enc is not None:\n",
    "            # 3. corresponds to point 2 described above\n",
    "            _x = x\n",
    "            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)\n",
    "\n",
    "            # 4. residual connection and layer normalization\n",
    "            x = self.dropout2(x)\n",
    "            x = self.norm2(x + _x)\n",
    "\n",
    "        # 5. feed-forward layer\n",
    "        _x = x\n",
    "        x = self.ffn(x)\n",
    "\n",
    "        # 6. residual connection and layer normalization\n",
    "        x = self.dropout3(x)\n",
    "        x = self.norm3(x + _x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a462ceab-1c6c-47af-96ee-09c8b7bcd835",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):\n",
    "        super().__init__()\n",
    "        self.emb = TransformerEmbedding(d_model=d_model,\n",
    "                                        drop_prob=drop_prob,\n",
    "                                        max_len=max_len,\n",
    "                                        vocab_size=dec_voc_size)\n",
    "\n",
    "        self.layers = nn.ModuleList([DecoderLayer(d_model=d_model,\n",
    "                                                  ffn_hidden=ffn_hidden,\n",
    "                                                  n_head=n_head,\n",
    "                                                  drop_prob=drop_prob)\n",
    "                                     for _ in range(n_layers)])\n",
    "\n",
    "        self.linear = nn.Linear(d_model, dec_voc_size)\n",
    "\n",
    "    def forward(self, trg, src, trg_mask, src_mask):\n",
    "        trg = self.emb(trg)\n",
    "\n",
    "        for layer in self.layers:\n",
    "            trg = layer(trg, src, trg_mask, src_mask)\n",
    "\n",
    "        # finally pass through a linear layer\n",
    "        output = self.linear(trg)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9b6c8832-6e42-457b-a472-913223f9d415",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Transformer(nn.Module):\n",
    "\n",
    "    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,\n",
    "                 ffn_hidden, n_layers, drop_prob):\n",
    "        super().__init__()\n",
    "        self.src_pad_idx = src_pad_idx\n",
    "        self.trg_pad_idx = trg_pad_idx\n",
    "        self.trg_sos_idx = trg_sos_idx\n",
    "        self.encoder = Encoder(d_model=d_model,\n",
    "                               n_head=n_head,\n",
    "                               max_len=max_len,\n",
    "                               ffn_hidden=ffn_hidden,\n",
    "                               enc_voc_size=enc_voc_size,\n",
    "                               drop_prob=drop_prob,\n",
    "                               n_layers=n_layers)\n",
    "\n",
    "        self.decoder = Decoder(d_model=d_model,\n",
    "                               n_head=n_head,\n",
    "                               max_len=max_len,\n",
    "                               ffn_hidden=ffn_hidden,\n",
    "                               dec_voc_size=dec_voc_size,\n",
    "                               drop_prob=drop_prob,\n",
    "                               n_layers=n_layers)\n",
    "\n",
    "    def forward(self, src, trg):\n",
    "        src_mask = self.make_src_mask(src)\n",
    "        trg_mask = self.make_trg_mask(trg)\n",
    "        enc_src = self.encoder(src, src_mask)\n",
    "        output = self.decoder(trg, enc_src, trg_mask, src_mask)\n",
    "        return output\n",
    "\n",
    "    def make_src_mask(self, src):\n",
    "        \"\"\"\n",
    "        Build the source (src) mask: padded positions are set to False\n",
    "        \"\"\"\n",
    "        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n",
    "        return src_mask\n",
    "\n",
    "    def make_trg_mask(self, trg):\n",
    "        \"\"\"\n",
    "        Build the target (trg) mask: 1. padded positions are set to False;\n",
    "        2. a lower-triangular matrix whose diagonal and below are True and all other positions False,\n",
    "        so that during training the model only relies on current and past information, never the future\n",
    "        \"\"\"\n",
    "        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)\n",
    "        trg_len = trg.shape[1]\n",
    "        # build the causal mask on the same device as trg, so no separate device attribute is needed\n",
    "        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len, device=trg.device)).bool()\n",
    "        trg_mask = trg_pad_mask & trg_sub_mask\n",
    "        return trg_mask"
   ]
  },
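  {
   "cell_type": "markdown",
   "id": "mask-demo-note",
   "metadata": {},
   "source": [
    "A quick sanity check of the two mask helpers on a toy batch (a sketch; the `tiny` model and its hyperparameters below are illustrative, not from the original configuration): `make_src_mask` only hides padding, while `make_trg_mask` additionally applies the causal lower-triangular constraint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mask-demo-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy batch: the pad index is 1, so the trailing 1s are padding\n",
    "tiny = Transformer(src_pad_idx=1, trg_pad_idx=1, trg_sos_idx=2,\n",
    "                   enc_voc_size=100, dec_voc_size=100, d_model=8, n_head=2,\n",
    "                   max_len=16, ffn_hidden=16, n_layers=1, drop_prob=0.1)\n",
    "src = torch.LongTensor([[5, 6, 7, 1]])\n",
    "trg = torch.LongTensor([[2, 5, 6, 1]])\n",
    "print(tiny.make_src_mask(src).shape)      # torch.Size([1, 1, 1, 4])\n",
    "print(tiny.make_trg_mask(trg).squeeze())  # 4x4 causal mask with the padded last column False"
   ]
  },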
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd10f763-e89a-4668-bf0d-290a9de3a018",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base-model hyperparameters following the original paper\n",
    "max_len = 256\n",
    "d_model = 512\n",
    "n_layers = 6\n",
    "n_heads = 8\n",
    "ffn_hidden = 2048\n",
    "drop_prob = 0.1\n",
    "\n",
    "# Tokenizer / vocabulary parameters. They depend on the tokenizer trained on the\n",
    "# dataset, so only rough values are given here. The original paper trains on\n",
    "# WMT14 EN-DE; here enc_voc_size is 32000 and dec_voc_size is 25000. A differently\n",
    "# trained vocabulary yields different sizes, and hence a different final parameter\n",
    "# count, since these two values determine the embedding-layer parameters.\n",
    "src_pad_idx = 1\n",
    "trg_pad_idx = 1\n",
    "trg_sos_idx = 2\n",
    "enc_voc_size = 32000\n",
    "dec_voc_size = 25000\n",
    "\n",
    "model = Transformer(src_pad_idx=src_pad_idx,\n",
    "                    trg_pad_idx=trg_pad_idx,\n",
    "                    trg_sos_idx=trg_sos_idx,\n",
    "                    d_model=d_model,\n",
    "                    enc_voc_size=enc_voc_size,\n",
    "                    dec_voc_size=dec_voc_size,\n",
    "                    max_len=max_len,\n",
    "                    ffn_hidden=ffn_hidden,\n",
    "                    n_head=n_heads,\n",
    "                    n_layers=n_layers,\n",
    "                    drop_prob=drop_prob)\n",
    "\n",
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "# Number of trainable parameters\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')\n",
    "\n",
    "# Model architecture\n",
    "print(model)"
   ]
  },
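  {
   "cell_type": "markdown",
   "id": "param-estimate-md",
   "metadata": {},
   "source": [
    "A back-of-the-envelope estimate of the parameter count printed above, so the total is\n",
    "not a black box. This is a sketch assuming standard layer shapes: linear layers with\n",
    "bias, two trainable vectors (gamma, beta) per LayerNorm, and sinusoidal positional\n",
    "encodings stored as non-trainable buffers. The exact figure from `count_parameters`\n",
    "may differ slightly depending on the implementations defined earlier in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "param-estimate",
   "metadata": {},
   "outputs": [],
   "source": [
    "d, f, L = d_model, ffn_hidden, n_layers\n",
    "\n",
    "mha = 4 * (d * d + d)               # Wq, Wk, Wv, Wo projections (with bias)\n",
    "ffn = d * f + f + f * d + d         # the two linear layers of the FFN\n",
    "ln = 2 * d                          # gamma and beta\n",
    "\n",
    "enc_layer = mha + ffn + 2 * ln      # self-attention block + FFN block\n",
    "dec_layer = 2 * mha + ffn + 3 * ln  # plus the cross-attention block\n",
    "\n",
    "emb = enc_voc_size * d + dec_voc_size * d  # token embedding tables\n",
    "head = d * dec_voc_size + dec_voc_size     # final projection onto the vocab\n",
    "\n",
    "total = L * enc_layer + L * dec_layer + emb + head\n",
    "print(f'estimated trainable parameters: {total:,}')  # roughly 86 million"
   ]
  },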
"mimetype": "text/x-python", 687 | "name": "python", 688 | "nbconvert_exporter": "python", 689 | "pygments_lexer": "ipython3", 690 | "version": "3.10.6" 691 | } 692 | }, 693 | "nbformat": 4, 694 | "nbformat_minor": 5 695 | } 696 | --------------------------------------------------------------------------------