├── data └── alpaca │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json ├── requirements.txt ├── config └── ds_config.json ├── LICENSE ├── single_layer.py ├── arguments.py ├── tokenize_dataset_rows.py ├── configuration_chatglm.py ├── README.md ├── infer.py ├── finetune.py ├── tokenization_chatglm.py ├── quantization.py └── modeling_chatglm.py /data/alpaca/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uglyghost/ChatGLM-Peft-Tuning/HEAD/data/alpaca/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # int8 2 | bitsandbytes 3 | accelerate==0.17.1 4 | 5 | # chatglm 6 | protobuf>=3.19.5,<3.20.1 7 | transformers>=4.26.1 8 | icetk 9 | cpm_kernels 10 | torch>=1.10 11 | 12 | # 13 | datasets 14 | git+https://github.com/mymusise/peft.git@54b6ce2c0e19eed7fbba1f101c278adde64952ab -------------------------------------------------------------------------------- /data/alpaca/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "input_ids": { 6 | "feature": { 7 | "dtype": "int32", 8 | "_type": "Value" 9 | }, 10 | "_type": "Sequence" 11 | } 12 | }, 13 | "homepage": "", 14 | "license": "" 15 | } -------------------------------------------------------------------------------- /data/alpaca/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "79ec697ac499d766", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /config/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 32, 3 | "train_micro_batch_size_per_gpu": 8, 4 | "gradient_accumulation_steps": 2, 5 | "fp16": { 6 | "enabled": true, 7 | "loss_scale": 0, 8 | "initial_scale_power": 16, 9 | "loss_scale_window": 1000, 10 | "hysteresis": 2, 11 | "min_loss_scale": 1 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | } 22 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chengxi Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /single_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class single_layer(torch.nn.Module): 5 | def __init__(self, in_features, out_features): 6 | super(single_layer, self).__init__() 7 | self.linear_q = torch.nn.Linear(in_features, out_features // 3) 8 | self.linear_k = torch.nn.Linear(in_features, out_features // 3) 9 | self.linear_v = torch.nn.Linear(in_features, out_features // 3) 10 | 11 | def update(self, target_layer): 12 | self.linear_q.weight.data = target_layer.weight[:target_layer.out_features // 3, :].data 13 | self.linear_q.bias.data = target_layer.bias[:target_layer.out_features // 3].data 14 | 15 | self.linear_k.weight.data = target_layer.weight[ 16 | target_layer.out_features // 3:target_layer.out_features // 3 * 2, :].data 17 | self.linear_k.bias.data = target_layer.bias[ 18 | target_layer.out_features // 3:target_layer.out_features // 3 * 2].data 19 | 20 | self.linear_v.weight.data = target_layer.weight[target_layer.out_features // 3 * 2:, :].data 21 | self.linear_v.bias.data = target_layer.bias[target_layer.out_features // 3 * 2:].data 22 | 23 | def forward(self, x): 24 | q = self.linear_q(x) 25 | k = self.linear_k(x) 26 | v = self.linear_v(x) 27 | return torch.concat([q, k, v], dim=-1) 28 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | # 导入argparse模块 2 | import argparse 3 | 4 | # 创建ArgumentParser对象 5 | parser = argparse.ArgumentParser(description="ChatGLM model fine-tuning") 6 | 7 | '''tokenize_dataset_rows.py 参数配置''' 8 | # 对微调数据集做预处理,tokenized -> convert to json -> save to binary file. 
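# NOTE (added comment, illustrative): despite the ".jsonl" suffix, load_dataset() in
# tokenize_dataset_rows.py reads this file with json.load(), so it expects a single
# JSON array of Alpaca-style records rather than line-delimited JSON. A minimal
# sketch of the assumed record format:
#
#   [
#     {"instruction": "...", "input": "", "output": "..."},
#     {"instruction": "...", "input": "...", "output": "..."}
#   ]
#
# Records whose "input" is empty are rendered with PROMPT_DICT["prompt_no_input"],
# all others with PROMPT_DICT["prompt_input"].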
9 | parser.add_argument("--jsonl_path", type=str, default="data/change_name.jsonl.txt") # json格式的数据集 10 | parser.add_argument("--save_path", type=str, default="data/alpaca") # 用于训练数据集的存储路径 11 | # parser.add_argument("--dataset_path", type=str, default="data/alpaca/dataset.pkl") 12 | parser.add_argument("--max_seq_length", type=int, default=512) # 样本文本的最大长度 13 | 14 | '''finetune.py 参数配置''' 15 | # training 16 | parser.add_argument("--continue_training", type=bool, default=True) # 是否在微调模型上继续训练 17 | parser.add_argument("--checkpoint_enable", type=bool, default=True) # 是否开启checkpoint功能 18 | parser.add_argument("--grads_enable", type=bool, default=True) # 启用输入梯度计算功能,支持高阶导数 19 | 20 | # LoRA是一种低秩适应大型语言模型的方法 21 | parser.add_argument("--lora_rank", type=int, default=8) # 低秩矩阵的秩 22 | parser.add_argument("--lora_alpha", type=int, default=32) # 控制低秩矩阵和原始矩阵之间权重平衡的系数 23 | parser.add_argument("--lora_dropout", type=float, default=0.1) # 防止过拟合的概率 24 | 25 | parser.add_argument("--per_device_train_batch_size", type=int, default=1) # 每个设备上的数据批次,显存足够可增加 26 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) # 多次计算得到的梯度值进行累加,一次性进行参数更新 27 | parser.add_argument("--max_steps", type=int, default=10000) # 最大训练迭代次数 28 | parser.add_argument("--save_steps", type=int, default=10) # checkpoint保存步长 29 | parser.add_argument("--save_total_limit", type=int, default=2) # 保存条目数量上限 30 | parser.add_argument("--learning_rate", type=float, default=2e-4) # 模型学习率 31 | parser.add_argument("--logging_steps", type=int, default=50) # 日志输出间隔 32 | parser.add_argument("--output_dir", type=str, default="output") # finetune模型 & checkpoint 存储目录 33 | 34 | '''infer.py 参数配置''' 35 | # parser.add_argument("--peft_path", type=str, default="output/chatglm-lora.pt") # finetune模型存储地址 36 | parser.add_argument("--max_length", type=int, default=512) # 最大输出长度 37 | parser.add_argument("--temperature", type=int, default=0) # 情感 38 | 39 | 40 | def get_args(): 41 | # 解析ArgumentParser对象,获得argparse.Namespace对象 42 | arguments = parser.parse_args() 43 | return arguments -------------------------------------------------------------------------------- /tokenize_dataset_rows.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data import Dataset 5 | from transformers import AutoTokenizer 6 | from arguments import get_args 7 | 8 | 9 | class AlpacaDataset(Dataset): 10 | def __init__(self, pairs, tokenizer, device) -> None: 11 | super().__init__() 12 | self.pairs = pairs 13 | self.tokenizer = tokenizer 14 | self.device = device 15 | self.EOS_ID = 150005 16 | 17 | def __getitem__(self, index): 18 | prompt = self.tokenizer.encode(self.pairs[index]['prompt']) 19 | completion = self.tokenizer.encode(self.pairs[index]['completion'], add_special_tokens=False) + [self.EOS_ID] 20 | 21 | seq = prompt + completion 22 | context_length = seq.index(150004) + 1 23 | 24 | attention_mask = torch.ones((len(seq), len(seq)), device=self.device ) 25 | attention_mask.tril_() 26 | attention_mask[..., :context_length - 1] = 1 27 | attention_mask.unsqueeze_(0) 28 | attention_mask = (attention_mask < 0.5).bool() 29 | 30 | position_ids = torch.stack([torch.arange(0, len(seq), device=self.device ), torch.concat( 31 | [torch.zeros(context_length - 2, device=self.device ), 32 | torch.arange(0, len(seq) - context_length + 2, device=self.device )])]).long() 33 | labels = torch.tensor([-100] * len(prompt) + completion, device=self.device ).long() 
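# Added explanatory comments (not in the original source):
# - 150004 / 150005 are bos_token_id / eos_token_id in configuration_chatglm.py, so
#   context_length marks the end of the prompt (everything up to and including BOS).
# - The attention mask is a prefix-LM mask: every position may attend to the prompt,
#   while completion tokens attend causally; the final "< 0.5" comparison turns it
#   into a boolean mask in which True marks positions to be blocked.
# - position_ids stacks ChatGLM's 2D positions: absolute indices plus block positions
#   that stay at 0 inside the prompt and count up over the completion
#   (cf. position_encoding_2d=True in the config).
# - labels uses -100 (the CrossEntropyLoss ignore_index) over the prompt, so the loss
#   is computed only on the completion and the trailing EOS token.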
34 | 35 | return {'input_ids': seq, 'attention_mask': attention_mask, "labels": labels, 'position_ids': position_ids} 36 | 37 | def __len__(self): 38 | return len(self.pairs) 39 | 40 | 41 | def collate_fn(batch): 42 | input_ids = [] 43 | attention_mask = [] 44 | labels = [] 45 | position_ids = [] 46 | # TODO: padding for batch training 47 | for obj in batch: 48 | input_ids.append(obj['input_ids']) 49 | attention_mask.append(obj['attention_mask']) 50 | labels.append(obj['labels']) 51 | position_ids.append(obj['position_ids']) 52 | return {'input_ids': torch.tensor(input_ids).long(), 53 | 'attention_mask': torch.stack(attention_mask), 54 | 'labels': torch.stack(labels), 55 | 'position_ids':torch.stack(position_ids)} 56 | 57 | 58 | # 定义一个函数main,不接收任何参数 59 | def load_dataset(): 60 | # 解析命令行参数并赋值给args变量 61 | args = get_args() 62 | 63 | device = 'cuda' 64 | 65 | PROMPT_DICT = { 66 | "prompt_input": ( 67 | #"下面的指令介绍了一个任务问题,并且提供了上下文的输入。" 68 | #"请写一个合适的回复,回复指令中描述的问题。\n\n" 69 | "### 指令:\n{instruction}\n\n### 输入:\n{input}\n\n### 回复:" 70 | ), 71 | "prompt_no_input": ( 72 | #"下面的指令介绍了一个任务问题。" 73 | #"请写一个合适的回复,回复指令中描述的问题。\n\n" 74 | "### 指令:\n{instruction}\n\n### 回复:" 75 | ) 76 | } 77 | 78 | with open(args.jsonl_path, 'r') as f: 79 | content = json.load(f) 80 | 81 | pairs = [] 82 | 83 | for line in content: 84 | if line['input'] == '': 85 | prompt = PROMPT_DICT['prompt_no_input'].format_map(line) 86 | else: 87 | prompt = PROMPT_DICT['prompt_input'].format_map(line) 88 | completion = line['output'] 89 | pairs.append({'prompt': prompt, 'completion': completion}) 90 | 91 | # 从预训练模型"THUDM/chatglm-6b"加载分词器tokenizer,并信任远程代码 92 | tokenizer = AutoTokenizer.from_pretrained( 93 | "THUDM/chatglm-6b", trust_remote_code=True 94 | ) 95 | 96 | train_dataset = AlpacaDataset(pairs, tokenizer=tokenizer, device=device) 97 | # print(pairs) 98 | train_dataloader = DataLoader(dataset=train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=1) 99 | 100 | # 打印生成了多少个样本 101 | print(f"Generated {len(train_dataloader.dataset)} samples.") 102 | 103 | return train_dataloader -------------------------------------------------------------------------------- /configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | """ ChatGLM model configuration """ 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class ChatGLMConfig(PretrainedConfig): 10 | r""" 11 | This is the configuration class to store the configuration of a [`~ChatGLMModel`]. 12 | It is used to instantiate an ChatGLM model according to the specified arguments, defining the model 13 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 14 | the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. 15 | 16 | Configuration objects inherit from [`PretrainedConfig`] and can be used 17 | to control the model outputs. Read the documentation from [`PretrainedConfig`] 18 | for more information. 19 | 20 | 21 | Args: 22 | vocab_size (`int`, *optional*, defaults to 150528): 23 | Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the 24 | `inputs_ids` passed when calling [`~ChatGLMModel`] or 25 | [`~TFChatGLMModel`]. 26 | hidden_size (`int`, *optional*, defaults to 4096): 27 | Dimension of the encoder layers and the pooler layer. 
28 | num_hidden_layers (`int`, *optional*, defaults to 28): 29 | Number of hidden layers in the Transformer encoder. 30 | num_attention_heads (`int`, *optional*, defaults to 32): 31 | Number of attention heads for each attention layer in the Transformer encoder. 32 | inner_hidden_size (`int`, *optional*, defaults to 16384): 33 | Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 34 | max_sequence_length (`int`, *optional*, defaults to 512): 35 | The maximum sequence length that this model might ever be used with. 36 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 37 | layernorm_epsilon (`float`, *optional*, defaults to 1e-5): 38 | The epsilon used by the layer normalization layers. 39 | use_cache (`bool`, *optional*, defaults to `True`): 40 | Whether the model should return the last key/values attentions (not used by all models). 41 | Example: 42 | 43 | ```python 44 | >>> from configuration_chatglm import ChatGLMConfig 45 | >>> from modeling_chatglm import ChatGLMModel 46 | 47 | >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration 48 | >>> configuration = ChatGLMConfig() 49 | 50 | >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration 51 | >>> model = ChatGLMModel(configuration) 52 | 53 | >>> # Accessing the model configuration 54 | >>> configuration = model.config 55 | ``` 56 | """ 57 | model_type = "chatglm" 58 | 59 | def __init__( 60 | self, 61 | vocab_size=150528, 62 | hidden_size=4096, 63 | num_layers=28, 64 | num_attention_heads=32, 65 | layernorm_epsilon=1e-5, 66 | use_cache=False, 67 | bos_token_id=150004, 68 | eos_token_id=150005, 69 | pad_token_id=0, 70 | max_sequence_length=2048, 71 | inner_hidden_size=16384, 72 | position_encoding_2d=True, 73 | **kwargs 74 | ): 75 | self.num_layers = num_layers 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_attention_heads = num_attention_heads 79 | self.max_sequence_length = max_sequence_length 80 | self.layernorm_epsilon = layernorm_epsilon 81 | self.inner_hidden_size = inner_hidden_size 82 | self.use_cache = use_cache 83 | self.bos_token_id = bos_token_id 84 | self.eos_token_id = eos_token_id 85 | self.pad_token_id = pad_token_id 86 | self.position_encoding_2d = position_encoding_2d 87 | super().__init__( 88 | pad_token_id=pad_token_id, 89 | bos_token_id=bos_token_id, 90 | eos_token_id=eos_token_id, 91 | **kwargs 92 | ) 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChatGLM-Peft-Tuning 2 | 3 | 该项目基于清华的 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 进行finetune. 4 | 基于项目 [mymusise](https://github.com/mymusise/ChatGLM-Tuning) 修改 5 | 6 | 7 | 特别鸣谢! 8 | 9 | 10 | ## 测试环境 11 | 12 | - 显卡: GTX 3090 (24G) & A100 (40G) 13 | - 系统: Windows 11 & Ubuntu 18.04 14 | - 建议: python >=3.8, CUDA 11.2+ 15 | - 环境: pip install -r requirements.txt 16 | 17 | ### Windows 注意 18 | - Windows环境运行train.py时,bitsandbytes会报错: 19 | 20 | `CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!` 21 | 22 | 23 | - bitsandbytes: 轻量级的CUDA自定义函数的包装器,主要提供了8位优化器、矩阵乘法(LLM.int8())和量化函数12。它可以用于PyTorch框架,提高深度学习模型的训练速度和效率2。 24 | 25 | 26 | - 解决方法: 27 | 1. put `libbitsandbytes_cuda116.dll` in 28 | 29 | `C:\Users\xxx\miniconda3\envs\textgen\lib\site-packages\bitsandbytes\` 30 | 2. edit `\bitsandbytes\cuda_setup\main.py`. 
search for: 31 | 32 | `if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None` 33 | 34 | replace with: 35 | 36 | ` 37 | if torch.cuda.is_available(): return 'libbitsandbytes_cuda116.dll', None, None, None, None 38 | ` 39 | 3. search for this twice: 40 | 41 | `self.lib = ct.cdll.LoadLibrary(binary_path)` 42 | 43 | replace with: 44 | 45 | `self.lib = ct.cdll.LoadLibrary(str(binary_path))` 46 | 47 | 48 | ### 项目概述 49 | - data 50 | - alpaca 51 | - ... 52 | - data.json 53 | - output 54 | - checkpoint 55 | - ... 56 | - chatglm-lora.pt 57 | - ... 58 | - tokenize_dataset_rows.py 59 | - finetune.py 60 | - infer.py 61 | - ... 62 | 63 | ## Pretreatment 64 | 65 | ```bash 66 | python tokenize_dataset_rows.py 67 | ``` 68 | 69 | 配置参数见 `arguments.py` '''tokenize_dataset_rows.py 参数配置''' 70 | - `--jsonl_path` 微调的数据路径, 格式jsonl, 对每行的['text']字段进行encode 71 | - `--save_path` 用于训练数据集的存储路径 72 | - `--max_seq_length` 样本文本的最大长度 73 | 74 | ## Finetune 75 | 76 | ```bash 77 | python finetune.py --save_total_limit 2 --dataset_path data/alpaca --lora_rank 8 --per_device_train_batch_size 1 --gradient_accumulation_steps 1 --max_steps 52000 --save_steps 1000 --learning_rate 2e-5 --logging_steps 50 --output_dir output 78 | ``` 79 | 80 | 配置参数见 `arguments.py` '''finetune.py 参数配置''' 81 | - `--dataset_path` 字符串集合的json格式的数据集 82 | - `--per_device_train_batch_size` 每个设备上的数据批次,显存足够可增加 83 | - `--gradient_accumulation_steps` 多次计算得到的梯度值进行累加,一次性进行参数更新 84 | - `--max_steps` 最大训练迭代次数 85 | - `--save_steps` checkpoint保存步长 86 | - `--save_total_limit` 保存条目数量上限 87 | - `--learning_rate` 模型学习率 88 | - `--logging_steps` 日志输出间隔 89 | - `--output_dir` finetune模型 & checkpoint 存储目录 90 | - `--fp16` 91 | 92 | LORA 93 | 94 | - `--lora_rank` 低秩矩阵的秩 95 | - `--lora_alpha` 控制低秩矩阵和原始矩阵之间权重平衡的系数 96 | - `--lora_dropout` 防止过拟合 97 | 98 | 99 | # Infer 100 | ```bash 101 | python infer.py 102 | ``` 103 | 104 | 配置参数见 `arguments.py` '''infer.py 参数配置''' 105 | - `--peft_path` finetune模型存储地址 106 | - `--max_length` 最大输出长度 107 | - `--temperature` 情感 108 | 109 | # Datasets 110 | - 数据集: [alpaca](https://github.com/tatsu-lab/stanford_alpaca) 111 | - 中文财经类QA数据集 112 | 113 | # TODO: 114 | 115 | - 数据/模型/张量并行 使用GLM pretrain 和 finetune 实现? 
116 | - 使用RLHF 参考 [trlx](https://github.com/CarperAI/trlx) 117 | 118 | 119 | ## Cite 120 | 121 | 清华开源项目,参考引用下列论文 122 | 123 | ``` 124 | @inproceedings{ 125 | zeng2023glm-130b, 126 | title={{GLM}-130B: An Open Bilingual Pre-trained Model}, 127 | author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang}, 128 | booktitle={The Eleventh International Conference on Learning Representations (ICLR)}, 129 | year={2023}, 130 | url={https://openreview.net/forum?id=-Aw0rrrPUF} 131 | } 132 | ``` 133 | 134 | ``` 135 | @inproceedings{du2022glm, 136 | title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling}, 137 | author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie}, 138 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 139 | pages={320--335}, 140 | year={2022} 141 | } 142 | ``` -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from modeling_chatglm import ChatGLMForConditionalGeneration 3 | import torch 4 | from peft import get_peft_model, LoraConfig, TaskType, tuners 5 | from arguments import get_args 6 | from single_layer import single_layer 7 | import os 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 9 | 10 | 11 | # 定义主函数 12 | def main(): 13 | # 解析命令行参数并赋值给args变量 14 | args = get_args() 15 | 16 | # reload the model 17 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) 18 | # model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) 19 | model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, 20 | device_map='auto') 21 | 22 | # 设置peft配置,包括任务类型、推理模式、秩、alpha值和dropout率等参数 23 | peft_config = LoraConfig( 24 | peft_type="LORA", 25 | task_type="SEQ_2_SEQ_LM", 26 | # inference_mode=False, 27 | r=args.lora_rank, 28 | lora_alpha=args.lora_alpha, 29 | lora_dropout=args.lora_dropout, 30 | target_modules=["q", "k", "v"] 31 | ) 32 | # convert it again 33 | for key, module in model.named_modules(): 34 | if key.endswith('attention'): 35 | try: 36 | qkv_layer = single_layer(module.query_key_value.in_features, module.query_key_value.out_features) 37 | qkv_layer.update(module.query_key_value) 38 | module.query_key_value = qkv_layer 39 | except: 40 | print('no') 41 | pass 42 | module.query_key_value = tuners.lora.LoraModel(peft_config, module.query_key_value) 43 | 44 | # load the LoRA checkpoint 45 | model.load_state_dict(torch.load('output_finetune_99.pt'), strict=False) 46 | 47 | model.half().cuda().eval() 48 | 49 | # Let's chat! 
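# Added notes (not in the original source):
# - The model.chat() demo below is left commented out; this script instead builds
#   prompts with a "### 指令 / 输入 / 回复" (instruction/input/reply) template similar
#   to the one used for fine-tuning and calls model.generate() directly.
# - The LoRA checkpoint path 'output_finetune_99.pt' is hard-coded here; the
#   --peft_path argument in arguments.py is commented out and therefore unused.
# - args.temperature defaults to 0 (declared as int). In recent transformers
#   versions temperature only takes effect when do_sample=True is passed to
#   generate(), so decoding here is effectively greedy by default.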
50 | ''' 51 | response, history = model.chat(tokenizer, "你是谁?", history=[]) 52 | print(response) 53 | response, history = model.chat(tokenizer, "西南财经大学副校长是谁?", history=[]) 54 | print(response) 55 | response, history = model.chat(tokenizer, "西南财经大学校长是谁?", history=[]) 56 | print(response) 57 | ''' 58 | 59 | # return 60 | 61 | # torch.set_default_tensor_type(torch.cuda.HalfTensor) 62 | # model = ChatGLMForConditionalGeneration.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map='auto') 63 | ''' 64 | peft_config = LoraConfig( 65 | task_type=TaskType.CAUSAL_LM, 66 | inference_mode=False, 67 | r=args.lora_rank, 68 | lora_alpha=args.lora_alpha, 69 | lora_dropout=args.lora_dropout 70 | ) 71 | 72 | model = get_peft_model(model, peft_config) 73 | model.load_state_dict(torch.load(args.peft_path), strict=False) 74 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 75 | 76 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) 77 | ''' 78 | # instructions = json.load(open("data/alpaca_data.json")) 79 | 80 | instructions = [ 81 | { 82 | 'instruction': "西南财经大学校长是谁?", 83 | "output": "西南财经大学的校长是卓志。他于2018年1月开始担任这一职务,并且是经济学博士和教授。他主要从事商业保险、风险管理和精算等领域的研究和高教管理。", 84 | }, 85 | { 86 | 'instruction': "西南财经大学副校长是谁?", 87 | "output": "西南财经大学现有两位副校长,分别是张邦富1和李志生。张邦富是党委常委、副校长,主要负责学校的教学、科研、人才培养等工作。李志生是党委常委、副校长,于2022年7月28日正式任职,主要负责学校的发展规划、国际合作与交流等工作。", 88 | } 89 | ] 90 | 91 | answers = [] 92 | 93 | with torch.no_grad(): 94 | for idx, item in enumerate(instructions[:5]): 95 | input_text = f"### {idx+1}.指令:\n{item['instruction']}\n\n" 96 | if item.get('input'): 97 | input_text += f"### {idx+1}.输入:\n{item['input']}\n\n" 98 | input_text += f"### {idx+1}.回复:" 99 | # print(input_text) 100 | batch = tokenizer(input_text, return_tensors="pt") 101 | out = model.generate( 102 | input_ids=batch["input_ids"], 103 | attention_mask=torch.ones_like(batch["input_ids"]).bool(), 104 | max_length=args.max_length, 105 | temperature=args.temperature 106 | ) 107 | out_text = tokenizer.decode(out[0]) 108 | answer = out_text.replace(input_text, "").replace("\nEND", "").strip() 109 | item['infer_answer'] = answer 110 | print(out_text) 111 | # print(f"### {idx+1}.Answer:\n", item.get('output'), '\n\n') 112 | answers.append({'index': idx, **item}) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from transformers import Trainer, HfArgumentParser, get_linear_schedule_with_warmup, AutoModel 3 | from modeling_chatglm import ChatGLMForConditionalGeneration 4 | import torch 5 | import torch.nn as nn 6 | from peft import get_peft_model, LoraConfig, tuners 7 | from dataclasses import dataclass, field 8 | import os 9 | from arguments import get_args 10 | from single_layer import single_layer 11 | import loralib as lora 12 | import numpy as np 13 | from tokenize_dataset_rows import load_dataset 14 | from torch.cuda.amp import autocast 15 | import tqdm 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 18 | 19 | 20 | @dataclass # 使用dataclass装饰器,自动生成__init__等特殊方法 21 | class FinetuneArguments: # 定义一个数据类,用于存储微调的参数 22 | dataset_path: str = field(default="data/alpaca") # 数据集路径,默认为"data/alpaca" 23 | model_path: str = field(default="output") # 模型路径,默认为"output" 24 | lora_rank: int = field(default=8) # LoRA的秩,默认为8 25 | 26 | 27 | class CastOutputToFloat(nn.Sequential): # 
定义一个继承自nn.Sequential的类,用于将输出转换为浮点类型 28 | def forward(self, x): return super().forward(x).to( 29 | torch.float32) # 重写forward方法,调用父类的forward方法,并将结果转换为torch.float32类型 30 | 31 | 32 | class ModifiedTrainer(Trainer): # 定义一个继承自Trainer的类,用于修改计算损失函数 33 | 34 | def compute_loss(self, model, inputs, return_outputs=False): # 重写compute_loss方法,输入模型和输入数据 35 | return model( # 返回模型的输出 36 | input_ids=inputs["input_ids"], # 输入id 37 | attention_mask=torch.ones_like(inputs["input_ids"]).bool(), # 注意力掩码,全1矩阵 38 | labels=inputs["input_ids"], # 标签和输入id相同 39 | ).loss # 输出损失值 40 | 41 | 42 | def data_collator(features: list) -> dict: # 定义一个函数,用于将特征列表转换为字典格式 43 | return { 44 | "input_ids": torch.stack([ # 返回一个键为"input_ids"的字典,值为特征列表中每个元素的"input_ids"属性组成的张量堆叠 45 | torch.LongTensor(f["input_ids"]) 46 | for f in features 47 | ]) 48 | } 49 | 50 | 51 | def save_tunable_parameters(model, path): # 定义一个函数,用于保存模型中可调节的参数到指定路径 52 | saved_params = { # 创建一个字典,存储模型中需要梯度的参数 53 | k: v.to("cpu") # 将参数值转换为cpu类型 54 | for k, v in model.named_parameters() # 遍历模型中命名的参数 55 | if v.requires_grad # 如果参数需要梯度 56 | } 57 | torch.save(saved_params, path) # 使用torch.save函数,将字典保存到路径 58 | 59 | 60 | # 定义主函数 61 | def main(): 62 | # 解析命令行参数并赋值给args变量 63 | args = get_args() 64 | 65 | # finetune_args, training_args = HfArgumentParser( 66 | # (FinetuneArguments, TrainingArguments)).parse_args_into_dataclasses() 67 | 68 | ''' 69 | training_args = TrainingArguments( 70 | output_dir=args.output_dir, 71 | per_device_train_batch_size=args.per_device_train_batch_size, 72 | logging_dir="./logs", # directory for storing logs 73 | fp16=True, 74 | do_train=True, 75 | gradient_accumulation_steps=args.gradient_accumulation_steps, 76 | learning_rate=args.learning_rate, 77 | save_steps=args.save_steps, 78 | max_steps=args.max_steps, 79 | ) 80 | ''' 81 | 82 | # 从预训练模型"THUDM/chatglm-6b"加载模型,并设置一些参数 83 | model = ChatGLMForConditionalGeneration.from_pretrained( 84 | # model = AutoModel.from_pretrained( 85 | "THUDM/chatglm-6b", 86 | # load_in_8bit=True, # 使用8位精度加载模型,节省内存 87 | trust_remote_code=True, # 信任远程代码,允许执行自定义操作 88 | device_map='auto') # 自动分配设备映射 89 | 90 | if args.checkpoint_enable: 91 | model.gradient_checkpointing_enable() # 启用梯度检查点功能,减少内存占用 92 | 93 | if args.grads_enable: 94 | model.enable_input_require_grads() # 启用输入梯度计算功能,支持高阶导数 95 | 96 | model.is_parallelizable = True # 设置模型为可并行化 97 | model.model_parallel = True # 设置模型为并行模式 98 | model.lm_head = CastOutputToFloat(model.lm_head) # 将输出层转换为浮点类型,提高精度 99 | model.config.use_cache = False # 关闭缓存功能 100 | 101 | # 设置peft配置,包括任务类型、推理模式、秩、alpha值和dropout率等参数 102 | peft_config = LoraConfig( 103 | peft_type="LORA", 104 | task_type="SEQ_2_SEQ_LM", 105 | inference_mode=False, 106 | r=args.lora_rank, 107 | lora_alpha=args.lora_alpha, 108 | lora_dropout=args.lora_dropout, 109 | target_modules=["q", "k", "v"] 110 | ) 111 | 112 | for key, module in model.named_modules(): 113 | if key.endswith('attention'): 114 | try: 115 | # Here we split the query_key_value layer into three linear layer for LoRA. But you can also use merged linear. 
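# Added explanatory comments (not in the original source):
# ChatGLM fuses Q, K and V into one `query_key_value` Linear per attention block.
# single_layer (see single_layer.py) copies its weight/bias slices into three
# separate Linears named linear_q / linear_k / linear_v, so the LoraConfig above
# (target_modules=["q", "k", "v"]) can match them by name when
# tuners.lora.LoraModel wraps the module right after this try/except.
# The bare except simply skips modules that no longer expose a plain fused Linear
# (e.g. on a second pass); lora.mark_only_lora_as_trainable() below then freezes
# everything except the injected LoRA A/B matrices.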
116 | qkv_layer = single_layer(module.query_key_value.in_features, module.query_key_value.out_features) 117 | qkv_layer.update(module.query_key_value) 118 | module.query_key_value = qkv_layer 119 | except: 120 | pass 121 | module.query_key_value = tuners.lora.LoraModel(peft_config, module.query_key_value) 122 | 123 | lora.mark_only_lora_as_trainable(model) 124 | 125 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 126 | trainable_params = sum([np.prod(p.size()) for p in model_parameters]) 127 | 128 | model_parameters = filter(lambda p: not p.requires_grad, model.parameters()) 129 | non_trainable_params = sum([np.prod(p.size()) for p in model_parameters]) 130 | 131 | print('trainable_params:{} ({:.2f}%), non_trainable_params:{}'.format(trainable_params, 132 | trainable_params / non_trainable_params * 100, 133 | non_trainable_params)) 134 | 135 | # 获取peft模型,即使用低秩逼近技术优化后的模型 136 | # model = get_peft_model(model, peft_config) 137 | if args.continue_training: 138 | model.load_state_dict(torch.load('output_finetune_99.pt'), strict=False) 139 | 140 | # 从指定路径加载数据集,并转换为torch格式 141 | train_dataset = load_dataset() 142 | # dataset = datasets.load_from_disk(args.dataset_path) 143 | 144 | NUM_EPOCHS = args.save_steps 145 | accumulate_step = 32 146 | 147 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) 148 | 149 | lr_scheduler = get_linear_schedule_with_warmup( 150 | optimizer=optimizer, 151 | num_warmup_steps=int(len(train_dataset) / accumulate_step), 152 | num_training_steps=(int(len(train_dataset) / accumulate_step) * NUM_EPOCHS), 153 | ) 154 | 155 | model.train() 156 | 157 | version = "finetune" 158 | with autocast(dtype=torch.bfloat16): 159 | for epoch in range(NUM_EPOCHS): 160 | torch.cuda.empty_cache() 161 | total_loss = 0 162 | for step, batch in enumerate(t := tqdm.tqdm(train_dataset)): 163 | batch = {k: v for k, v in batch.items()} 164 | outputs = model(**batch) 165 | loss_d = outputs.loss.detach().float() 166 | t.set_description(f"loss: {loss_d}") 167 | total_loss += loss_d 168 | loss = outputs.loss / accumulate_step 169 | loss.backward() 170 | if (step + 1) % accumulate_step == 0: 171 | optimizer.step() 172 | lr_scheduler.step() 173 | optimizer.zero_grad() 174 | torch.cuda.empty_cache() 175 | peft_model_id = f"{args.output_dir}_{version}_{epoch}" 176 | print(peft_model_id) 177 | torch.save(lora.lora_state_dict(model), peft_model_id + '.pt') 178 | print(epoch, total_loss / (step + 1)) 179 | 180 | ''' 181 | # 开始训练过程,使用ModifiedTrainer类创建训练器对象,并传入模型、数据集、参数和数据整理器等参数 182 | trainer = ModifiedTrainer( 183 | model=model, 184 | train_dataset=train_dataset, 185 | args=training_args, 186 | data_collator=data_collator, 187 | 188 | ) 189 | trainer.train() # 调用train方法进行训练 190 | ''' 191 | 192 | # 保存训练后的模型参数到指定路径下的文件中 193 | # save_tunable_parameters(model, os.path.join(args.output_dir, "chatglm-lora.pt")) 194 | 195 | 196 | if __name__ == "__main__": 197 | main() -------------------------------------------------------------------------------- /tokenization_chatglm.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for ChatGLM.""" 2 | import sys 3 | import unicodedata 4 | from typing import List, Optional, Union 5 | from functools import lru_cache 6 | import os 7 | import collections 8 | import re 9 | 10 | from transformers.tokenization_utils import PreTrainedTokenizer 11 | from icetk.text_tokenizer import TextTokenizer 12 | from icetk.utils import auto_create 13 | import 
icetk.sentencepiece_model_pb2 as sp_model 14 | from transformers.utils import logging 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 19 | "THUDM/chatglm-6b": 2048, 20 | } 21 | 22 | 23 | class SPTokenizer: 24 | def __init__( 25 | self, 26 | vocab_file, 27 | max_blank_length=80, 28 | byte_fallback=True, 29 | ): 30 | assert vocab_file is not None 31 | self.vocab_file = vocab_file 32 | self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] 33 | self.max_blank_length = max_blank_length 34 | self.byte_fallback = byte_fallback 35 | self.text_tokenizer = self._build_text_tokenizer(encode_special_tokens=False) 36 | self.special_text_tokenizer = self._build_text_tokenizer(encode_special_tokens=True) 37 | 38 | @staticmethod 39 | def _configure_tokenizer( 40 | text_tokenizer: TextTokenizer, 41 | special_tokens: List[str], 42 | max_blank_length: int, 43 | byte_fallback: bool, 44 | encode_special_tokens=False, 45 | ): 46 | # special token 47 | special_token_type = 4 if encode_special_tokens else 3 # 3 - CONTROL, 4 - USER_DEFINE 48 | for token in special_tokens: 49 | text_tokenizer.proto.pieces.append( 50 | sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=special_token_type) 51 | ) 52 | # whitespaces 53 | for token in [SPTokenizer.get_tab_token()] + [ 54 | SPTokenizer.get_blank_token(i) for i in range(2, max_blank_length + 1) 55 | ]: 56 | text_tokenizer.proto.pieces.append(sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=4)) 57 | # byte fallback 58 | if byte_fallback: 59 | text_tokenizer.proto.trainer_spec.byte_fallback = True 60 | for i in range(256): 61 | text_tokenizer.proto.pieces.append( 62 | sp_model.ModelProto.SentencePiece(piece="<0x{:02X}>".format(i), score=0.0, type=6) 63 | ) 64 | text_tokenizer.refresh() 65 | 66 | def _build_text_tokenizer(self, encode_special_tokens=False): 67 | tokenizer = TextTokenizer(self.vocab_file) 68 | self._configure_tokenizer( 69 | tokenizer, self.special_tokens, self.max_blank_length, self.byte_fallback, encode_special_tokens 70 | ) 71 | return tokenizer 72 | 73 | def _get_text_tokenizer(self, encode_special_tokens=False): 74 | if encode_special_tokens: 75 | return self.special_text_tokenizer 76 | else: 77 | return self.text_tokenizer 78 | 79 | @staticmethod 80 | def get_blank_token(length: int): 81 | assert length >= 2 82 | return f"<|blank_{length}|>" 83 | 84 | @staticmethod 85 | def get_tab_token(): 86 | return f"<|tab|>" 87 | 88 | @property 89 | def num_image_tokens(self): 90 | return 20000 91 | 92 | @property 93 | def num_text_tokens(self): 94 | return self.text_tokenizer.num_tokens 95 | 96 | @property 97 | def num_tokens(self): 98 | return self.num_image_tokens + self.num_text_tokens 99 | 100 | @staticmethod 101 | def _encode_whitespaces(text: str, max_len: int = 80): 102 | text = text.replace("\t", SPTokenizer.get_tab_token()) 103 | for i in range(max_len, 1, -1): 104 | text = text.replace(" " * i, SPTokenizer.get_blank_token(i)) 105 | return text 106 | 107 | def _preprocess(self, text: str, linebreak=True, whitespaces=True): 108 | if linebreak: 109 | text = text.replace("\n", "") 110 | if whitespaces: 111 | text = self._encode_whitespaces(text, max_len=self.max_blank_length) 112 | return text 113 | 114 | def encode( 115 | self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True 116 | ) -> List[int]: 117 | """ 118 | @param text: Text to encode. 119 | @param linebreak: Whether to encode newline (\n) in text. 
120 | @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. 121 | @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. 122 | @param add_dummy_prefix: Whether to add dummy blank space in the beginning. 123 | """ 124 | text = self._preprocess(text, linebreak, whitespaces) 125 | if not add_dummy_prefix: 126 | text = "" + text 127 | tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text) 128 | tokens = [x + self.num_image_tokens for x in tmp] 129 | return tokens if add_dummy_prefix else tokens[2:] 130 | 131 | def decode(self, text_ids: List[int], special_tokens=False) -> str: 132 | ids = [int(_id) - self.num_image_tokens for _id in text_ids] 133 | text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids) 134 | text = text.replace("", "\n") 135 | text = text.replace(SPTokenizer.get_tab_token(), "\t") 136 | for i in range(2, self.max_blank_length + 1): 137 | text = text.replace(self.get_blank_token(i), " " * i) 138 | return text 139 | 140 | def tokenize( 141 | self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True 142 | ) -> List[str]: 143 | """ 144 | @param text: Text to encode. 145 | @param linebreak: Whether to encode newline (\n) in text. 146 | @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. 147 | @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. 148 | @param add_dummy_prefix: Whether to add dummy blank space in the beginning. 149 | """ 150 | text = self._preprocess(text, linebreak, whitespaces) 151 | if not add_dummy_prefix: 152 | text = "" + text 153 | tokens = self._get_text_tokenizer(encode_special_tokens=special_tokens).tokenize(text) 154 | return tokens if add_dummy_prefix else tokens[2:] 155 | 156 | def __getitem__(self, x: Union[int, str]): 157 | if isinstance(x, int): 158 | if x < self.num_image_tokens: 159 | return "".format(x) 160 | else: 161 | return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) 162 | elif isinstance(x, str): 163 | if x.startswith("") and x[7:-1].isdigit(): 164 | return int(x[7:-1]) 165 | else: 166 | return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens 167 | else: 168 | raise ValueError("The key should be str or int.") 169 | 170 | 171 | class ChatGLMTokenizer(PreTrainedTokenizer): 172 | """ 173 | Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. 174 | 175 | Args: 176 | vocab_file (`str`): 177 | Path to the vocabulary file. 
178 | """ 179 | 180 | vocab_files_names = {"vocab_file": "ice_text.model"} 181 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 182 | model_input_names = ["input_ids"] 183 | 184 | def __init__( 185 | self, 186 | vocab_file, 187 | do_lower_case=False, 188 | remove_space=False, 189 | bos_token='sop', 190 | eos_token='eos', 191 | eop_token='eop', 192 | mask_token='[MASK]', 193 | gmask_token='[gMASK]', 194 | padding_side="left", 195 | **kwargs 196 | ) -> None: 197 | super().__init__( 198 | do_lower_case=do_lower_case, 199 | remove_space=remove_space, 200 | padding_side=padding_side, 201 | **kwargs 202 | ) 203 | 204 | self.do_lower_case = do_lower_case 205 | self.remove_space = remove_space 206 | self.vocab_file = vocab_file 207 | 208 | self.bos_token = bos_token 209 | self.eos_token = eos_token 210 | self.eop_token = eop_token 211 | self.mask_token = mask_token 212 | self.gMASK_token = gmask_token 213 | 214 | self.sp_tokenizer = SPTokenizer(vocab_file) 215 | 216 | """ Initialisation """ 217 | 218 | @property 219 | def eop_token_id(self) -> Optional[int]: 220 | """ 221 | `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been 222 | set. 223 | """ 224 | if self.eop_token is None: 225 | return None 226 | return self.convert_tokens_to_ids(self.eop_token) 227 | 228 | @property 229 | def vocab_size(self): 230 | """ Returns vocab size """ 231 | return self.sp_tokenizer.num_tokens 232 | 233 | def get_vocab(self): 234 | """ Returns vocab as a dict """ 235 | vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} 236 | vocab.update(self.added_tokens_encoder) 237 | return vocab 238 | 239 | def preprocess_text(self, inputs): 240 | if self.remove_space: 241 | outputs = " ".join(inputs.strip().split()) 242 | else: 243 | outputs = inputs 244 | 245 | if self.do_lower_case: 246 | outputs = outputs.lower() 247 | 248 | return outputs 249 | 250 | def _tokenize(self, text, **kwargs): 251 | """ Returns a tokenized string. """ 252 | text = self.preprocess_text(text) 253 | 254 | seq = self.sp_tokenizer.tokenize(text) 255 | 256 | return seq 257 | 258 | def decode( 259 | self, 260 | token_ids: Union[List[int], List[List[int]]], 261 | skip_special_tokens: bool = False, 262 | clean_up_tokenization_spaces: bool = True, 263 | spaces_between_special_tokens: bool = True, 264 | **kwargs 265 | ) -> str: 266 | if isinstance(token_ids[0], list): 267 | tokens = [] 268 | for single_token_ids in token_ids: 269 | if self.pad_token_id in single_token_ids: # remove pad 270 | single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids)) 271 | tokens.append(self.sp_tokenizer.decode(single_token_ids)) 272 | return (tokens) 273 | else: 274 | if self.pad_token_id in token_ids: # remove pad 275 | token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) 276 | return self.sp_tokenizer.decode(token_ids) 277 | 278 | def _convert_token_to_id(self, token): 279 | """ Converts a token (str) in an id using the vocab. """ 280 | return self.sp_tokenizer[token] 281 | 282 | def _convert_id_to_token(self, index): 283 | """Converts an index (integer) in a token (str) using the vocab.""" 284 | return self.sp_tokenizer[index] 285 | 286 | def save_vocabulary(self, save_directory, filename_prefix=None): 287 | """ 288 | Save the vocabulary and special tokens file to a directory. 289 | 290 | Args: 291 | save_directory (`str`): 292 | The directory in which to save the vocabulary. 
293 | filename_prefix (`str`, *optional*): 294 | An optional prefix to add to the named of the saved files. 295 | 296 | Returns: 297 | `Tuple(str)`: Paths to the files saved. 298 | """ 299 | if os.path.isdir(save_directory): 300 | vocab_file = os.path.join( 301 | save_directory, VOCAB_FILES_NAMES["vocab_file"] 302 | ) 303 | else: 304 | vocab_file = save_directory 305 | 306 | with open(self.vocab_file, 'rb') as fin: 307 | proto_str = fin.read() 308 | 309 | with open(vocab_file, "wb") as writer: 310 | writer.write(proto_str) 311 | 312 | return (vocab_file,) 313 | 314 | def build_inputs_with_special_tokens( 315 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 316 | ) -> List[int]: 317 | """ 318 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 319 | adding special tokens. A BERT sequence has the following format: 320 | 321 | - single sequence: `[CLS] X [SEP]` 322 | - pair of sequences: `[CLS] A [SEP] B [SEP]` 323 | 324 | Args: 325 | token_ids_0 (`List[int]`): 326 | List of IDs to which the special tokens will be added. 327 | token_ids_1 (`List[int]`, *optional*): 328 | Optional second list of IDs for sequence pairs. 329 | 330 | Returns: 331 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 332 | """ 333 | if token_ids_1 is not None: 334 | token_ids_0 += token_ids_1 335 | mask_ids = self.sp_tokenizer[self.mask_token] 336 | gmask_ids = self.sp_tokenizer[self.gMASK_token] 337 | if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: 338 | token_ids_0 += [gmask_ids] 339 | 340 | if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids: 341 | token_ids_0 += [self.sp_tokenizer[self.eos_token]] 342 | 343 | token_ids_0 += [self.sp_tokenizer[self.bos_token]] 344 | 345 | return token_ids_0 346 | -------------------------------------------------------------------------------- /quantization.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear 2 | from torch.nn.parameter import Parameter 3 | 4 | import bz2 5 | import torch 6 | import base64 7 | import ctypes 8 | 9 | from typing import List 10 | from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up 11 | 12 | 13 | class W8A16Linear(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): 16 | ctx.inp_shape = inp.size() 17 | ctx.weight_shape = quant_w.size() 18 | ctx.weight_bit_width = weight_bit_width 19 | out_features = quant_w.size(0) 20 | inp = inp.contiguous().view(-1, inp.size(-1)) 21 | weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) 22 | output = inp.mm(weight.t()) 23 | ctx.save_for_backward(inp, quant_w, scale_w) 24 | return output.view(*(ctx.inp_shape[:-1] + (out_features,))) 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output: torch.Tensor): 28 | inp, quant_w, scale_w = ctx.saved_tensors 29 | weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) 30 | grad_output = grad_output.contiguous().view(-1, weight.size(0)) 31 | grad_input = grad_output.mm(weight) 32 | grad_weight = grad_output.t().mm(inp) 33 | return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None 34 | 35 | 36 | class Kernel: 37 | def __init__(self, code: bytes, function_names: List[str]): 38 | self.code = code 39 | self._function_names = function_names 40 | self._cmodule = 
LazyKernelCModule(self.code) 41 | 42 | for name in self._function_names: 43 | setattr(self, name, KernelFunction(self._cmodule, name)) 44 | 45 | 46 | quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76
vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e2
0fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" 47 | 48 | kernels = Kernel( 49 | bz2.decompress(base64.b64decode(quantization_code)), 50 | [ 51 | "int4WeightCompression", 52 | "int4WeightExtractionFloat", 53 | "int4WeightExtractionHalf", 54 | "int8WeightExtractionFloat", 55 | "int8WeightExtractionHalf", 56 | ], 57 | ) 58 | 59 | 60 | def compress_int4_weight(weight: torch.Tensor): # (n, m) 61 | with torch.cuda.device(weight.device): 62 | n, m = weight.size(0), weight.size(1) 63 | assert m % 2 == 0 64 | m = m // 2 65 | out = torch.empty(n, m, dtype=torch.int8, device="cuda") 66 | stream = torch.cuda.current_stream() 67 | 68 | gridDim = (n, 1, 1) 69 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 70 | 71 | kernels.int4WeightCompression( 72 | gridDim, 73 | blockDim, 74 | 0, 75 | stream, 76 | [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], 77 | ) 78 | return out 79 | 80 | 81 | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): 82 | if source_bit_width == 8: 83 | func = kernels.int8WeightExtractionHalf 84 | elif source_bit_width == 4: 85 | func = kernels.int4WeightExtractionHalf 86 | else: 87 | assert False, "Unsupported bit-width" 88 | 89 | with torch.cuda.device(weight.device): 90 | n, m = weight.size(0), weight.size(1) 91 | out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda") 92 | stream = torch.cuda.current_stream() 93 | 94 | gridDim = (n, 1, 1) 95 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 96 | 97 | func( 98 | gridDim, 99 | blockDim, 100 | 0, 101 | stream, 102 | [ 103 | ctypes.c_void_p(weight.data_ptr()), 104 | ctypes.c_void_p(scale_list.data_ptr()), 105 | ctypes.c_void_p(out.data_ptr()), 106 | ctypes.c_int32(n), 107 | ctypes.c_int32(m), 108 | ], 109 | ) 110 | return out 111 | 112 | 113 | class QuantizedLinear(Linear): 114 | def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, *args, **kwargs): 115 | super(QuantizedLinear, self).__init__(*args, **kwargs) 116 | self.weight_bit_width = weight_bit_width 117 | 118 | shape = self.weight.shape 119 | del self.weight 120 | 121 | if weight_tensor is None: 122 | self.weight = torch.empty( 123 | shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] 124 | ) 125 | self.weight_scale = torch.empty(shape[0], dtype=kwargs["params_dtype"], device=kwargs["device"]) 126 | else: 127 | self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() 128 | self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) 129 | if weight_bit_width == 4: 130 | self.weight = compress_int4_weight(self.weight) 131 | 132 | self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) 133 | self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) 134 | self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) 135 | 136 | def forward(self, input): 137 | output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) 138 | if self.bias is not None: 139 | output = 
output + self.bias 140 | return output 141 | 142 | 143 | def quantize(model, weight_bit_width): 144 | """Replace fp16 linear with quantized linear""" 145 | 146 | for layer in model.layers: 147 | layer.attention.query_key_value = QuantizedLinear( 148 | weight_bit_width=weight_bit_width, 149 | weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()), 150 | bias_tensor=layer.attention.query_key_value.bias, 151 | in_features=layer.attention.query_key_value.in_features, 152 | out_features=layer.attention.query_key_value.out_features, 153 | bias=True, 154 | dtype=torch.half, 155 | device=layer.attention.query_key_value.weight.device, 156 | ) 157 | layer.attention.dense = QuantizedLinear( 158 | weight_bit_width=weight_bit_width, 159 | weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()), 160 | bias_tensor=layer.attention.dense.bias, 161 | in_features=layer.attention.dense.in_features, 162 | out_features=layer.attention.dense.out_features, 163 | bias=True, 164 | dtype=torch.half, 165 | device=layer.attention.dense.weight.device, 166 | ) 167 | layer.mlp.dense_h_to_4h = QuantizedLinear( 168 | weight_bit_width=weight_bit_width, 169 | weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), 170 | bias_tensor=layer.mlp.dense_h_to_4h.bias, 171 | in_features=layer.mlp.dense_h_to_4h.in_features, 172 | out_features=layer.mlp.dense_h_to_4h.out_features, 173 | bias=True, 174 | dtype=torch.half, 175 | device=layer.mlp.dense_h_to_4h.weight.device, 176 | ) 177 | layer.mlp.dense_4h_to_h = QuantizedLinear( 178 | weight_bit_width=weight_bit_width, 179 | weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), 180 | bias_tensor=layer.mlp.dense_4h_to_h.bias, 181 | in_features=layer.mlp.dense_4h_to_h.in_features, 182 | out_features=layer.mlp.dense_4h_to_h.out_features, 183 | bias=True, 184 | dtype=torch.half, 185 | device=layer.mlp.dense_4h_to_h.weight.device, 186 | ) 187 | return model 188 | -------------------------------------------------------------------------------- /modeling_chatglm.py: -------------------------------------------------------------------------------- 1 | """ PyTorch ChatGLM model. 
""" 2 | 3 | import math 4 | import copy 5 | import os 6 | import time 7 | 8 | import torch 9 | import torch.utils.checkpoint 10 | import torch.nn.functional as F 11 | from torch import nn 12 | from torch.nn import CrossEntropyLoss, LayerNorm 13 | from torch.nn.utils import skip_init 14 | from typing import Optional, Tuple, Union, List 15 | 16 | from transformers.utils import ( 17 | add_code_sample_docstrings, 18 | add_start_docstrings, 19 | add_start_docstrings_to_model_forward, 20 | ) 21 | from transformers.modeling_outputs import ( 22 | BaseModelOutputWithPast, 23 | CausalLMOutputWithPast, 24 | BaseModelOutputWithPastAndCrossAttentions, 25 | ) 26 | from transformers.modeling_utils import PreTrainedModel 27 | from transformers.utils import logging 28 | from transformers.generation.logits_process import LogitsProcessor 29 | from transformers.generation.utils import LogitsProcessorList 30 | 31 | from configuration_chatglm import ChatGLMConfig 32 | 33 | # flags required to enable jit fusion kernels 34 | torch._C._jit_set_profiling_mode(False) 35 | torch._C._jit_set_profiling_executor(False) 36 | torch._C._jit_override_can_fuse_on_cpu(True) 37 | torch._C._jit_override_can_fuse_on_gpu(True) 38 | 39 | logger = logging.get_logger(__name__) 40 | 41 | _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" 42 | _CONFIG_FOR_DOC = "ChatGLM6BConfig" 43 | 44 | CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ 45 | "THUDM/chatglm-6b", 46 | # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm 47 | ] 48 | 49 | 50 | class InvalidScoreLogitsProcessor(LogitsProcessor): 51 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 52 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 53 | scores.zero_() 54 | scores[..., 20005] = 1e5 55 | return scores 56 | 57 | 58 | def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): 59 | """Load tf checkpoints in a pytorch model.""" 60 | try: 61 | import re 62 | 63 | import numpy as np 64 | import tensorflow as tf 65 | except ImportError: 66 | logger.error( 67 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 68 | "https://www.tensorflow.org/install/ for installation instructions." 
69 | ) 70 | raise 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 73 | # Load weights from TF model 74 | init_vars = tf.train.list_variables(tf_path) 75 | names = [] 76 | arrays = [] 77 | for name, shape in init_vars: 78 | logger.info(f"Loading TF weight {name} with shape {shape}") 79 | array = tf.train.load_variable(tf_path, name) 80 | names.append(name) 81 | arrays.append(array) 82 | 83 | for name, array in zip(names, arrays): 84 | name = name.split("/") 85 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 86 | # which are not required for using pretrained model 87 | if any( 88 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 89 | for n in name 90 | ): 91 | logger.info(f"Skipping {'/'.join(name)}") 92 | continue 93 | pointer = model 94 | for m_name in name: 95 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 96 | scope_names = re.split(r"_(\d+)", m_name) 97 | else: 98 | scope_names = [m_name] 99 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 100 | pointer = getattr(pointer, "weight") 101 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 102 | pointer = getattr(pointer, "bias") 103 | elif scope_names[0] == "output_weights": 104 | pointer = getattr(pointer, "weight") 105 | elif scope_names[0] == "squad": 106 | pointer = getattr(pointer, "classifier") 107 | else: 108 | try: 109 | pointer = getattr(pointer, scope_names[0]) 110 | except AttributeError: 111 | logger.info(f"Skipping {'/'.join(name)}") 112 | continue 113 | if len(scope_names) >= 2: 114 | num = int(scope_names[1]) 115 | pointer = pointer[num] 116 | if m_name[-11:] == "_embeddings": 117 | pointer = getattr(pointer, "weight") 118 | elif m_name == "kernel": 119 | array = np.transpose(array) 120 | try: 121 | assert ( 122 | pointer.shape == array.shape 123 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 124 | except AssertionError as e: 125 | e.args += (pointer.shape, array.shape) 126 | raise 127 | logger.info(f"Initialize PyTorch weight {name}") 128 | pointer.data = torch.from_numpy(array) 129 | return model 130 | 131 | 132 | @torch.jit.script 133 | def gelu_impl(x): 134 | """OpenAI's gelu implementation.""" 135 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 136 | (1.0 + 0.044715 * x * x))) 137 | 138 | 139 | def gelu(x): 140 | return gelu_impl(x) 141 | 142 | 143 | class RotaryEmbedding(torch.nn.Module): 144 | def __init__(self, dim, base=10000, precision=torch.half, learnable=False): 145 | super().__init__() 146 | inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) 147 | inv_freq = inv_freq.half() 148 | self.learnable = learnable 149 | if learnable: 150 | self.inv_freq = torch.nn.Parameter(inv_freq) 151 | self.max_seq_len_cached = None 152 | else: 153 | self.register_buffer('inv_freq', inv_freq) 154 | self.max_seq_len_cached = None 155 | self.cos_cached = None 156 | self.sin_cached = None 157 | self.precision = precision 158 | 159 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 160 | error_msgs): 161 | pass 162 | 163 | def forward(self, x, seq_dim=1, seq_len=None): 164 | if seq_len is None: 165 | seq_len = x.shape[seq_dim] 166 | if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): 167 | self.max_seq_len_cached = None if self.learnable else seq_len 168 | t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) 169 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 170 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 171 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 172 | if self.precision == torch.bfloat16: 173 | emb = emb.float() 174 | 175 | # [sx, 1 (b * np), hn] 176 | cos_cached = emb.cos()[:, None, :] 177 | sin_cached = emb.sin()[:, None, :] 178 | if self.precision == torch.bfloat16: 179 | cos_cached = cos_cached.bfloat16() 180 | sin_cached = sin_cached.bfloat16() 181 | if self.learnable: 182 | return cos_cached, sin_cached 183 | self.cos_cached, self.sin_cached = cos_cached, sin_cached 184 | return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] 185 | 186 | 187 | def rotate_half(x): 188 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 189 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 190 | 191 | 192 | @torch.jit.script 193 | def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): 194 | # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] 195 | cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ 196 | F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) 197 | q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 198 | return q, k 199 | 200 | 201 | def attention_fn( 202 | self, 203 | query_layer, 204 | key_layer, 205 | value_layer, 206 | attention_mask, 207 | hidden_size_per_partition, 208 | layer_id, 209 | layer_past=None, 210 | scaling_attention_score=True, 211 | use_cache=False, 212 | ): 213 | if layer_past is not None: 214 | past_key, past_value = layer_past 215 | key_layer = torch.cat((past_key, key_layer), dim=0) 216 | value_layer = torch.cat((past_value, value_layer), dim=0) 217 | 218 | # seqlen, batch, num_attention_heads, hidden_size_per_attention_head 219 | seq_len, b, nh, hidden_size = key_layer.shape 220 | 221 | if use_cache: 222 | present = (key_layer, value_layer) 223 | else: 224 | present = None 225 | 226 | query_key_layer_scaling_coeff = float(layer_id + 1) 227 | if scaling_attention_score: 228 | query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) 229 | 230 | # =================================== 231 | # Raw attention scores. 
[b, np, s, s] 232 | # =================================== 233 | 234 | # [b, np, sq, sk] 235 | output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) 236 | 237 | # [sq, b, np, hn] -> [sq, b * np, hn] 238 | query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) 239 | # [sk, b, np, hn] -> [sk, b * np, hn] 240 | key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) 241 | 242 | matmul_result = torch.empty( 243 | output_size[0] * output_size[1], 244 | output_size[2], 245 | output_size[3], 246 | dtype=query_layer.dtype, 247 | device=query_layer.device, 248 | ) 249 | 250 | matmul_result = torch.baddbmm( 251 | matmul_result, 252 | query_layer.transpose(0, 1), # [b * np, sq, hn] 253 | key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] 254 | beta=0.0, 255 | alpha=1.0, 256 | ) 257 | 258 | # change view to [b, np, sq, sk] 259 | attention_scores = matmul_result.view(*output_size) 260 | 261 | if self.scale_mask_softmax: 262 | self.scale_mask_softmax.scale = query_key_layer_scaling_coeff 263 | attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) 264 | else: 265 | if not (attention_mask == 0).all(): 266 | # if auto-regressive, skip 267 | attention_scores.masked_fill_(attention_mask, -10000.0) 268 | dtype = attention_scores.type() 269 | attention_scores = attention_scores.float() 270 | attention_scores = attention_scores * query_key_layer_scaling_coeff 271 | 272 | attention_probs = F.softmax(attention_scores, dim=-1) 273 | 274 | attention_probs = attention_probs.type(dtype) 275 | 276 | # ========================= 277 | # Context layer. [sq, b, hp] 278 | # ========================= 279 | 280 | # value_layer -> context layer. 281 | # [sk, b, np, hn] --> [b, np, sq, hn] 282 | 283 | # context layer shape: [b, np, sq, hn] 284 | output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) 285 | 286 | # change view [sk, b * np, hn] 287 | value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) 288 | 289 | # change view [b * np, sq, sk] 290 | attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) 291 | 292 | # matmul: [b * np, sq, hn] 293 | context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) 294 | 295 | # change view [b, np, sq, hn] 296 | context_layer = context_layer.view(*output_size) 297 | 298 | # [b, np, sq, hn] --> [sq, b, np, hn] 299 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous() 300 | 301 | # [sq, b, np, hn] --> [sq, b, hp] 302 | new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) 303 | context_layer = context_layer.view(*new_context_layer_shape) 304 | 305 | outputs = (context_layer, present, attention_probs) 306 | 307 | return outputs 308 | 309 | 310 | class SelfAttention(torch.nn.Module): 311 | def __init__(self, hidden_size, num_attention_heads, 312 | layer_id, hidden_size_per_attention_head=None, bias=True, 313 | params_dtype=torch.float, position_encoding_2d=True): 314 | super(SelfAttention, self).__init__() 315 | 316 | self.layer_id = layer_id 317 | self.hidden_size = hidden_size 318 | self.hidden_size_per_partition = hidden_size 319 | self.num_attention_heads = num_attention_heads 320 | self.num_attention_heads_per_partition = num_attention_heads 321 | self.position_encoding_2d = position_encoding_2d 322 | self.rotary_emb = RotaryEmbedding( 323 | self.hidden_size // 
(self.num_attention_heads * 2) 324 | if position_encoding_2d 325 | else self.hidden_size // self.num_attention_heads, 326 | base=10000, 327 | precision=torch.half, 328 | learnable=False, 329 | ) 330 | 331 | self.scale_mask_softmax = None 332 | 333 | if hidden_size_per_attention_head is None: 334 | self.hidden_size_per_attention_head = hidden_size // num_attention_heads 335 | else: 336 | self.hidden_size_per_attention_head = hidden_size_per_attention_head 337 | 338 | self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head 339 | 340 | # Strided linear layer. 341 | self.query_key_value = skip_init( 342 | torch.nn.Linear, 343 | hidden_size, 344 | 3 * self.inner_hidden_size, 345 | bias=bias, 346 | dtype=params_dtype, 347 | ) 348 | 349 | self.dense = skip_init( 350 | torch.nn.Linear, 351 | self.inner_hidden_size, 352 | hidden_size, 353 | bias=bias, 354 | dtype=params_dtype, 355 | ) 356 | 357 | @staticmethod 358 | def attention_mask_func(attention_scores, attention_mask): 359 | attention_scores.masked_fill_(attention_mask, -10000.0) 360 | return attention_scores 361 | 362 | def split_tensor_along_last_dim(self, tensor, num_partitions, 363 | contiguous_split_chunks=False): 364 | """Split a tensor along its last dimension. 365 | Arguments: 366 | tensor: input tensor. 367 | num_partitions: number of partitions to split the tensor 368 | contiguous_split_chunks: If True, make each chunk contiguous 369 | in memory. 370 | """ 371 | # Get the size and dimension. 372 | last_dim = tensor.dim() - 1 373 | last_dim_size = tensor.size()[last_dim] // num_partitions 374 | # Split. 375 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 376 | # Note: torch.split does not create contiguous tensors by default. 377 | if contiguous_split_chunks: 378 | return tuple(chunk.contiguous() for chunk in tensor_list) 379 | 380 | return tensor_list 381 | 382 | def forward( 383 | self, 384 | hidden_states: torch.Tensor, 385 | position_ids, 386 | attention_mask: torch.Tensor, 387 | layer_id, 388 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 389 | use_cache: bool = False, 390 | output_attentions: bool = False, 391 | ): 392 | """ 393 | hidden_states: [seq_len, batch, hidden_size] 394 | attention_mask: [(1, 1), seq_len, seq_len] 395 | """ 396 | 397 | # [seq_len, batch, 3 * hidden_size] 398 | mixed_raw_layer = self.query_key_value(hidden_states) 399 | 400 | # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] 401 | new_tensor_shape = mixed_raw_layer.size()[:-1] + ( 402 | self.num_attention_heads_per_partition, 403 | 3 * self.hidden_size_per_attention_head, 404 | ) 405 | mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) 406 | 407 | # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] 408 | (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) 409 | 410 | if self.position_encoding_2d: 411 | q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) 412 | k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) 413 | cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) 414 | position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ 415 | position_ids[:, 1, :].transpose(0, 1).contiguous() 416 | q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) 417 | q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) 418 | query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) 419 | key_layer 
= torch.concat([k1, k2], dim=(k1.ndim - 1)) 420 | else: 421 | position_ids = position_ids.transpose(0, 1) 422 | cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) 423 | # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] 424 | query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) 425 | 426 | # [seq_len, batch, hidden_size] 427 | context_layer, present, attention_probs = attention_fn( 428 | self=self, 429 | query_layer=query_layer, 430 | key_layer=key_layer, 431 | value_layer=value_layer, 432 | attention_mask=attention_mask, 433 | hidden_size_per_partition=self.hidden_size_per_partition, 434 | layer_id=layer_id, 435 | layer_past=layer_past, 436 | use_cache=use_cache 437 | ) 438 | 439 | output = self.dense(context_layer) 440 | 441 | outputs = (output, present) 442 | 443 | if output_attentions: 444 | outputs += (attention_probs,) 445 | 446 | return outputs # output, present, attention_probs 447 | 448 | 449 | class GEGLU(torch.nn.Module): 450 | def __init__(self): 451 | super().__init__() 452 | self.activation_fn = F.gelu 453 | 454 | def forward(self, x): 455 | # dim=-1 breaks in jit for pt<1.10 456 | x1, x2 = x.chunk(2, dim=(x.ndim - 1)) 457 | return x1 * self.activation_fn(x2) 458 | 459 | 460 | class GLU(torch.nn.Module): 461 | def __init__(self, hidden_size, inner_hidden_size=None, 462 | layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float): 463 | super(GLU, self).__init__() 464 | self.layer_id = layer_id 465 | self.activation_func = activation_func 466 | 467 | # Project to 4h. 468 | self.hidden_size = hidden_size 469 | if inner_hidden_size is None: 470 | inner_hidden_size = 4 * hidden_size 471 | self.inner_hidden_size = inner_hidden_size 472 | self.dense_h_to_4h = skip_init( 473 | torch.nn.Linear, 474 | self.hidden_size, 475 | self.inner_hidden_size, 476 | bias=bias, 477 | dtype=params_dtype, 478 | ) 479 | # Project back to h. 480 | self.dense_4h_to_h = skip_init( 481 | torch.nn.Linear, 482 | self.inner_hidden_size, 483 | self.hidden_size, 484 | bias=bias, 485 | dtype=params_dtype, 486 | ) 487 | 488 | def forward(self, hidden_states): 489 | """ 490 | hidden_states: [seq_len, batch, hidden_size] 491 | """ 492 | 493 | # [seq_len, batch, inner_hidden_size] 494 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 495 | 496 | intermediate_parallel = self.activation_func(intermediate_parallel) 497 | 498 | output = self.dense_4h_to_h(intermediate_parallel) 499 | 500 | return output 501 | 502 | 503 | class GLMBlock(torch.nn.Module): 504 | def __init__( 505 | self, 506 | hidden_size, 507 | num_attention_heads, 508 | layernorm_epsilon, 509 | layer_id, 510 | inner_hidden_size=None, 511 | hidden_size_per_attention_head=None, 512 | layernorm=LayerNorm, 513 | use_bias=True, 514 | params_dtype=torch.float, 515 | num_layers=28, 516 | position_encoding_2d=True 517 | ): 518 | super(GLMBlock, self).__init__() 519 | # Set output layer initialization if not provided. 520 | 521 | self.layer_id = layer_id 522 | 523 | # Layernorm on the input data. 524 | self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) 525 | 526 | self.position_encoding_2d = position_encoding_2d 527 | 528 | # Self attention. 
529 | self.attention = SelfAttention( 530 | hidden_size, 531 | num_attention_heads, 532 | layer_id, 533 | hidden_size_per_attention_head=hidden_size_per_attention_head, 534 | bias=use_bias, 535 | params_dtype=params_dtype, 536 | position_encoding_2d=self.position_encoding_2d 537 | ) 538 | 539 | # Layernorm on the input data. 540 | self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) 541 | 542 | self.num_layers = num_layers 543 | 544 | # GLU 545 | self.mlp = GLU( 546 | hidden_size, 547 | inner_hidden_size=inner_hidden_size, 548 | bias=use_bias, 549 | layer_id=layer_id, 550 | params_dtype=params_dtype, 551 | ) 552 | 553 | def forward( 554 | self, 555 | hidden_states: torch.Tensor, 556 | position_ids, 557 | attention_mask: torch.Tensor, 558 | layer_id, 559 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 560 | use_cache: bool = False, 561 | output_attentions: bool = False, 562 | ): 563 | """ 564 | hidden_states: [seq_len, batch, hidden_size] 565 | attention_mask: [(1, 1), seq_len, seq_len] 566 | """ 567 | 568 | # Layer norm at the begining of the transformer layer. 569 | # [seq_len, batch, hidden_size] 570 | attention_input = self.input_layernorm(hidden_states) 571 | 572 | # Self attention. 573 | attention_outputs = self.attention( 574 | attention_input, 575 | position_ids, 576 | attention_mask=attention_mask, 577 | layer_id=layer_id, 578 | layer_past=layer_past, 579 | use_cache=use_cache, 580 | output_attentions=output_attentions 581 | ) 582 | 583 | attention_output = attention_outputs[0] 584 | 585 | outputs = attention_outputs[1:] 586 | 587 | # Residual connection. 588 | alpha = (2 * self.num_layers) ** 0.5 589 | hidden_states = attention_input * alpha + attention_output 590 | 591 | mlp_input = self.post_attention_layernorm(hidden_states) 592 | 593 | # MLP. 594 | mlp_output = self.mlp(mlp_input) 595 | 596 | # Second residual connection. 597 | output = mlp_input * alpha + mlp_output 598 | 599 | if use_cache: 600 | outputs = (output,) + outputs 601 | else: 602 | outputs = (output,) + outputs[1:] 603 | 604 | return outputs # hidden_states, present, attentions 605 | 606 | 607 | class ChatGLMPreTrainedModel(PreTrainedModel): 608 | """ 609 | An abstract class to handle weights initialization and 610 | a simple interface for downloading and loading pretrained models. 611 | """ 612 | 613 | is_parallelizable = True 614 | supports_gradient_checkpointing = True 615 | config_class = ChatGLMConfig 616 | base_model_prefix = "transformer" 617 | _no_split_modules = ["GLM6BBlock"] 618 | 619 | def __init__(self, *inputs, **kwargs): 620 | super().__init__(*inputs, **kwargs) 621 | 622 | def _init_weights(self, module): 623 | return 624 | std = self.config.initializer_range 625 | if isinstance(module, nn.Linear): 626 | module.weight.data.normal_(mean=0.0, std=std) 627 | if module.bias is not None: 628 | module.bias.data.zero_() 629 | elif isinstance(module, nn.Embedding): 630 | module.weight.data.normal_(mean=0.0, std=std) 631 | if module.padding_idx is not None: 632 | module.weight.data[module.padding_idx].zero_() 633 | 634 | def _set_gradient_checkpointing(self, module, value=False): 635 | if isinstance(module, (GLMBlock)): 636 | module.gradient_checkpointing = value 637 | 638 | 639 | CHATGLM_6B_START_DOCSTRING = r""" 640 | This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
641 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general 642 | usage and behavior. 643 | 644 | Parameters: 645 | config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. 646 | Initializing with a config file does not load the weights associated with the model, only the configuration. 647 | Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 648 | """ 649 | 650 | CHATGLM_6B_INPUTS_DOCSTRING = r""" 651 | Args: 652 | input_ids (`torch.LongTensor` of shape `({0})`): 653 | Indices of input sequence tokens in the vocabulary. 654 | 655 | Indices can be obtained using [`ChatGLM6BTokenizer`]. 656 | See [`PreTrainedTokenizer.encode`] and 657 | [`PreTrainedTokenizer.__call__`] for details. 658 | 659 | [What are input IDs?](../glossary#input-ids) 660 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 661 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 662 | 663 | - 1 for tokens that are **not masked**, 664 | - 0 for tokens that are **masked**. 665 | 666 | [What are attention masks?](../glossary#attention-mask) 667 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 668 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: 669 | 670 | - 0 corresponds to a *sentence A* token, 671 | - 1 corresponds to a *sentence B* token. 672 | 673 | [What are token type IDs?](../glossary#token-type-ids) 674 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 675 | Indices of positions of each input sequence tokens in the position embeddings. 676 | Selected in the range `[0, config.max_position_embeddings - 1]`. 677 | 678 | [What are position IDs?](../glossary#position-ids) 679 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 680 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 681 | 682 | - 1 indicates the head is **not masked**, 683 | - 0 indicates the head is **masked**. 684 | 685 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 686 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 687 | This is useful if you want more control over how to convert *input_ids* indices into associated vectors 688 | than the model's internal embedding lookup matrix. 689 | output_attentions (`bool`, *optional*): 690 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 691 | tensors for more detail. 692 | output_hidden_states (`bool`, *optional*): 693 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 694 | more detail. 695 | return_dict (`bool`, *optional*): 696 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
697 | """ 698 | 699 | 700 | @add_start_docstrings( 701 | "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", 702 | CHATGLM_6B_START_DOCSTRING, 703 | ) 704 | class ChatGLMModel(ChatGLMPreTrainedModel): 705 | """ 706 | 707 | The model can behave as an encoder (with only self-attention) as well 708 | as a decoder, in which case a layer of cross-attention is added between 709 | the self-attention layers, following the architecture described in [Attention is 710 | all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, 711 | Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 712 | 713 | To behave as an decoder the model needs to be initialized with the 714 | `is_decoder` argument of the configuration set to `True`. 715 | To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` 716 | argument and `add_cross_attention` set to `True`; an 717 | `encoder_hidden_states` is then expected as an input to the forward pass. 718 | """ 719 | 720 | def __init__(self, config: ChatGLMConfig): 721 | super().__init__(config) 722 | 723 | # recording parameters 724 | self.max_sequence_length = config.max_sequence_length 725 | self.hidden_size = config.hidden_size 726 | self.params_dtype = torch.half 727 | self.num_attention_heads = config.num_attention_heads 728 | self.vocab_size = config.vocab_size 729 | self.num_layers = config.num_layers 730 | self.layernorm_epsilon = config.layernorm_epsilon 731 | self.inner_hidden_size = config.inner_hidden_size 732 | self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads 733 | self.position_encoding_2d = config.position_encoding_2d 734 | self.model_parallel = True 735 | 736 | self.word_embeddings = skip_init( 737 | torch.nn.Embedding, 738 | num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, 739 | dtype=self.params_dtype 740 | ) 741 | 742 | def get_layer(layer_id): 743 | return GLMBlock( 744 | self.hidden_size, 745 | self.num_attention_heads, 746 | self.layernorm_epsilon, 747 | layer_id, 748 | inner_hidden_size=self.inner_hidden_size, 749 | hidden_size_per_attention_head=self.hidden_size_per_attention_head, 750 | layernorm=LayerNorm, 751 | use_bias=True, 752 | params_dtype=self.params_dtype, 753 | position_encoding_2d=self.position_encoding_2d, 754 | ) 755 | 756 | self.layers = torch.nn.ModuleList( 757 | [get_layer(layer_id) for layer_id in range(self.num_layers)] 758 | ) 759 | 760 | # Final layer norm before output. 
761 | self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) 762 | 763 | def get_input_embeddings(self): 764 | return self.word_embeddings 765 | 766 | def set_input_embeddings(self, new_embeddings: torch.Tensor): 767 | self.word_embeddings = new_embeddings 768 | 769 | @staticmethod 770 | def get_masks(seq, device): 771 | context_length = seq.index(150004) + 1 772 | 773 | attention_mask = torch.ones((1, len(seq), len(seq)), device=device) 774 | attention_mask.tril_() 775 | attention_mask[..., :context_length - 1] = 1 776 | attention_mask.unsqueeze_(1) 777 | attention_mask = (attention_mask < 0.5).bool() 778 | 779 | return attention_mask 780 | 781 | def get_position_ids(self, seq, mask_position, device, gmask=False): 782 | context_length = seq.index(150004) + 1 783 | if self.position_encoding_2d: 784 | seq_length = seq.index(150004) 785 | position_ids = torch.arange(context_length, dtype=torch.long, device=device) 786 | if not gmask: 787 | position_ids[seq_length:] = mask_position 788 | block_position_ids = torch.cat(( 789 | torch.zeros(seq_length, dtype=torch.long, device=device), 790 | torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 791 | )) 792 | position_ids = torch.stack((position_ids, block_position_ids), dim=0) 793 | else: 794 | position_ids = torch.arange(context_length, dtype=torch.long, device=device) 795 | if not gmask: 796 | position_ids[context_length - 1:] = mask_position 797 | 798 | position_ids = position_ids.unsqueeze(0) 799 | 800 | return position_ids 801 | 802 | @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 803 | @add_code_sample_docstrings( 804 | checkpoint=_CHECKPOINT_FOR_DOC, 805 | output_type=BaseModelOutputWithPastAndCrossAttentions, 806 | config_class=_CONFIG_FOR_DOC, 807 | ) 808 | def forward( 809 | self, 810 | input_ids: Optional[torch.LongTensor] = None, 811 | position_ids: Optional[torch.LongTensor] = None, 812 | attention_mask: Optional[torch.Tensor] = None, 813 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 814 | inputs_embeds: Optional[torch.LongTensor] = None, 815 | use_cache: Optional[bool] = None, 816 | output_attentions: Optional[bool] = None, 817 | output_hidden_states: Optional[bool] = None, 818 | return_dict: Optional[bool] = None, 819 | ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: 820 | 821 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 822 | output_hidden_states = ( 823 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 824 | ) 825 | use_cache = use_cache if use_cache is not None else self.config.use_cache 826 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 827 | 828 | if input_ids is not None and inputs_embeds is not None: 829 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 830 | elif input_ids is not None: 831 | batch_size, seq_length = input_ids.shape[:2] 832 | elif inputs_embeds is not None: 833 | batch_size, seq_length, _ = inputs_embeds.shape[:2] 834 | else: 835 | raise ValueError("You have to specify either input_ids or inputs_embeds") 836 | 837 | if past_key_values is None: 838 | past_key_values = tuple([None] * len(self.layers)) 839 | 840 | MASK, gMASK = 150000, 150001 841 | mask_token = MASK if MASK in input_ids else gMASK 842 | use_gmask = False if MASK in input_ids else gMASK 843 | seq = 
input_ids[0].tolist() 844 | 845 | mask_position = seq.index(mask_token) 846 | 847 | if attention_mask is None: 848 | attention_mask = self.get_masks( 849 | seq=seq, 850 | device=input_ids.device 851 | ) 852 | 853 | if position_ids is None: 854 | position_ids = self.get_position_ids( 855 | seq=seq, 856 | mask_position=mask_position, 857 | device=input_ids.device, 858 | gmask=use_gmask 859 | ) 860 | 861 | if inputs_embeds is None: 862 | inputs_embeds = self.word_embeddings(input_ids) 863 | 864 | # [seq_len, batch, hidden_size] 865 | hidden_states = inputs_embeds.transpose(0, 1) 866 | 867 | presents = () if use_cache else None 868 | all_self_attentions = () if output_attentions else None 869 | all_hidden_states = () if output_hidden_states else None 870 | 871 | seq_length_with_past = seq_length 872 | past_key_values_length = 0 873 | if past_key_values[0] is not None: 874 | past_key_values_length = past_key_values[0][0].shape[0] 875 | seq_length_with_past = seq_length_with_past + past_key_values_length 876 | if attention_mask is None: 877 | attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() 878 | 879 | else: 880 | attention_mask = attention_mask.to(input_ids.device) 881 | 882 | for i, layer in enumerate(self.layers): 883 | 884 | if output_hidden_states: 885 | all_hidden_states = all_hidden_states + (hidden_states,) 886 | 887 | layer_ret = layer( 888 | hidden_states, 889 | position_ids=position_ids, 890 | attention_mask=attention_mask, 891 | layer_id=torch.tensor(i), 892 | layer_past=past_key_values[i], 893 | use_cache=use_cache, 894 | output_attentions=output_attentions 895 | ) 896 | 897 | hidden_states = layer_ret[0] 898 | 899 | if use_cache: 900 | presents = presents + (layer_ret[1],) 901 | 902 | if output_attentions: 903 | all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) 904 | 905 | # Final layer norm. 
906 | hidden_states = self.final_layernorm(hidden_states) 907 | 908 | if output_hidden_states: 909 | all_hidden_states = all_hidden_states + (hidden_states,) 910 | 911 | if not return_dict: 912 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 913 | 914 | return BaseModelOutputWithPast( 915 | last_hidden_state=hidden_states, 916 | past_key_values=presents, 917 | hidden_states=all_hidden_states, 918 | attentions=all_self_attentions, 919 | ) 920 | 921 | 922 | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): 923 | def __init__(self, config): 924 | super().__init__(config) 925 | 926 | # self.hidden_size = config.hidden_size 927 | # self.params_dtype = torch.half 928 | # self.vocab_size = config.vocab_size 929 | self.max_sequence_length = config.max_sequence_length 930 | 931 | self.position_encoding_2d = config.position_encoding_2d 932 | 933 | self.transformer = ChatGLMModel(config) 934 | 935 | self.lm_head = skip_init( 936 | nn.Linear, 937 | config.hidden_size, 938 | config.vocab_size, 939 | bias=False, 940 | dtype=torch.half 941 | ) 942 | 943 | def get_output_embeddings(self): 944 | return self.lm_head 945 | 946 | def set_output_embeddings(self, new_embeddings): 947 | self.lm_head = new_embeddings 948 | 949 | def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False): 950 | attention_mask = torch.ones((1, context_length, context_length), device=device) 951 | attention_mask.tril_() 952 | attention_mask[..., :context_length - 1] = 1 953 | attention_mask.unsqueeze_(1) 954 | attention_mask = (attention_mask < 0.5).bool() 955 | 956 | if self.position_encoding_2d: 957 | seq_length = seq.index(150004) 958 | position_ids = torch.arange(context_length, dtype=torch.long, device=device) 959 | if not gmask: 960 | position_ids[seq_length:] = mask_position 961 | block_position_ids = torch.cat(( 962 | torch.zeros(seq_length, dtype=torch.long, device=device), 963 | torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 964 | )) 965 | position_ids = torch.stack((position_ids, block_position_ids), dim=0) 966 | else: 967 | position_ids = torch.arange(context_length, dtype=torch.long, device=device) 968 | if not gmask: 969 | position_ids[context_length - 1:] = mask_position 970 | 971 | position_ids = position_ids.unsqueeze(0) 972 | 973 | return attention_mask, position_ids 974 | 975 | def prepare_inputs_for_generation( 976 | self, 977 | input_ids: torch.LongTensor, 978 | past: Optional[torch.Tensor] = None, 979 | past_key_values: Optional[torch.Tensor] = None, 980 | attention_mask: Optional[torch.Tensor] = None, 981 | **kwargs 982 | ) -> dict: 983 | 984 | MASK, gMASK = 150000, 150001 985 | mask_token = MASK if MASK in input_ids else gMASK 986 | use_gmask = False if MASK in input_ids else gMASK 987 | seq = input_ids[0].tolist() 988 | mask_position = seq.index(mask_token) 989 | 990 | if mask_token not in seq: 991 | raise ValueError("You have to add either [MASK] or [gMASK] in your input") 992 | 993 | # only last token for input_ids if past is not None 994 | if past is not None or past_key_values is not None: 995 | context_length = seq.index(150004) 996 | last_token = input_ids[:, -1].unsqueeze(-1) 997 | if self.position_encoding_2d: 998 | position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long, 999 | device=input_ids.device) 1000 | else: 1001 | position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device) 1002 | 1003 
| if past is None: 1004 | past = past_key_values 1005 | return { 1006 | "input_ids": last_token, 1007 | "past_key_values": past, 1008 | "position_ids": position_ids, 1009 | } 1010 | else: 1011 | attention_mask, position_ids = self.get_masks_and_position_ids( 1012 | seq=seq, 1013 | mask_position=mask_position, 1014 | context_length=len(seq), 1015 | device=input_ids.device, 1016 | gmask=use_gmask 1017 | ) 1018 | 1019 | return { 1020 | "input_ids": input_ids, 1021 | "past_key_values": past, 1022 | "position_ids": position_ids, 1023 | "attention_mask": attention_mask 1024 | } 1025 | 1026 | def forward( 1027 | self, 1028 | input_ids: Optional[torch.Tensor] = None, 1029 | position_ids: Optional[torch.Tensor] = None, 1030 | attention_mask: Optional[torch.Tensor] = None, 1031 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 1032 | inputs_embeds: Optional[torch.Tensor] = None, 1033 | labels: Optional[torch.Tensor] = None, 1034 | use_cache: Optional[bool] = None, 1035 | output_attentions: Optional[bool] = None, 1036 | output_hidden_states: Optional[bool] = None, 1037 | return_dict: Optional[bool] = None, 1038 | ): 1039 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1040 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1041 | 1042 | transformer_outputs = self.transformer( 1043 | input_ids=input_ids, 1044 | position_ids=position_ids, 1045 | attention_mask=attention_mask, 1046 | past_key_values=past_key_values, 1047 | inputs_embeds=inputs_embeds, 1048 | use_cache=use_cache, 1049 | output_attentions=output_attentions, 1050 | output_hidden_states=output_hidden_states, 1051 | return_dict=return_dict, 1052 | ) 1053 | 1054 | hidden_states = transformer_outputs[0] 1055 | 1056 | lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() 1057 | 1058 | loss = None 1059 | if labels is not None: 1060 | lm_logits = lm_logits.to(torch.float32) 1061 | 1062 | # Shift so that tokens < n predict n 1063 | shift_logits = lm_logits[..., :-1, :].contiguous() 1064 | shift_labels = labels[..., 1:].contiguous() 1065 | # Flatten the tokens 1066 | loss_fct = CrossEntropyLoss() 1067 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1068 | 1069 | lm_logits = lm_logits.to(hidden_states.dtype) 1070 | loss = loss.to(hidden_states.dtype) 1071 | 1072 | if not return_dict: 1073 | output = (lm_logits,) + transformer_outputs[1:] 1074 | return ((loss,) + output) if loss is not None else output 1075 | 1076 | return CausalLMOutputWithPast( 1077 | loss=loss, 1078 | logits=lm_logits, 1079 | past_key_values=transformer_outputs.past_key_values, 1080 | hidden_states=transformer_outputs.hidden_states, 1081 | attentions=transformer_outputs.attentions, 1082 | ) 1083 | 1084 | @staticmethod 1085 | def _reorder_cache( 1086 | past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor 1087 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: 1088 | """ 1089 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1090 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1091 | beam_idx at every generation step. 1092 | 1093 | Output shares the same memory storage as `past`. 
1094 | """ 1095 | return tuple( 1096 | ( 1097 | layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), 1098 | layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), 1099 | ) 1100 | for layer_past in past 1101 | ) 1102 | 1103 | @torch.no_grad() 1104 | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, 1105 | do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): 1106 | if history is None: 1107 | history = [] 1108 | if logits_processor is None: 1109 | logits_processor = LogitsProcessorList() 1110 | logits_processor.append(InvalidScoreLogitsProcessor()) 1111 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, 1112 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} 1113 | if not history: 1114 | prompt = query 1115 | else: 1116 | prompt = "" 1117 | for i, (old_query, response) in enumerate(history): 1118 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 1119 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 1120 | input_ids = tokenizer([prompt], return_tensors="pt", padding=True) 1121 | input_ids = input_ids.to(self.device) 1122 | outputs = self.generate(**input_ids, **gen_kwargs) 1123 | outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:] 1124 | response = tokenizer.decode(outputs) 1125 | response = response.strip() 1126 | response = response.replace("[[训练时间]]", "2023年") 1127 | history = history + [(query, response)] 1128 | return response, history 1129 | 1130 | @torch.no_grad() 1131 | def generate( 1132 | self, 1133 | **kwargs, 1134 | ): 1135 | MASK, gMASK = 150000, 150001 1136 | bos, eos = 150004, 150005 1137 | 1138 | if "eos_token_id" not in kwargs: 1139 | kwargs["eos_token_id"] = eos 1140 | 1141 | stop = False 1142 | 1143 | return_seqs = [] 1144 | 1145 | while True: 1146 | output_ids = super().generate(**kwargs) 1147 | 1148 | return_seqs = [] 1149 | max_length = 0 1150 | 1151 | for i in range(output_ids.shape[0]): 1152 | output_seq = output_ids[i].tolist() 1153 | mask_token = MASK if MASK in output_seq else gMASK 1154 | mask_position = output_seq.index(mask_token) 1155 | bos_position = output_seq.index(bos) 1156 | if eos in output_seq: 1157 | eos_position = output_seq.index(eos) 1158 | else: 1159 | eos_position = len(output_seq) 1160 | 1161 | return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[ 1162 | mask_position + 1:bos_position] 1163 | max_length = max(max_length, len(return_seq)) 1164 | return_seqs.append(return_seq) 1165 | 1166 | for i in range(output_ids.shape[0]): 1167 | return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding 1168 | if mask_token not in return_seqs[i]: 1169 | stop = True 1170 | 1171 | if stop: 1172 | break 1173 | 1174 | for return_seq in return_seqs: 1175 | return_seq += [bos] 1176 | 1177 | kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device) 1178 | 1179 | return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device) 1180 | 1181 | def quantize(self, bits: int): 1182 | from .quantization import quantize 1183 | self.transformer = quantize(self.transformer, bits) 1184 | return self 1185 | --------------------------------------------------------------------------------