├── README.md ├── chatglm ├── config.json ├── configuration_chatglm.py ├── ice_text.model ├── modeling_chatglm.py ├── quantization.py ├── test_modeling_chatglm.py ├── tokenization_chatglm.py └── tokenizer_config.json ├── load_and_predict.py ├── model └── model.weights ├── requirements.txt └── train_and_save.py /README.md: -------------------------------------------------------------------------------- 1 | * [chatglm-tiny: 从头开始训练一个chatglm小模型](https://zhuanlan.zhihu.com/p/642355086) 2 | * [pytorch在cpu上的一个bug:Semi-reproducible random torch.baddbmm NaNs](https://zhuanlan.zhihu.com/p/660594749) 3 | -------------------------------------------------------------------------------- /chatglm/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "THUDM/chatglm-6b", 3 | "architectures": [ 4 | "ChatGLMModel" 5 | ], 6 | "auto_map": { 7 | "AutoConfig": "configuration_chatglm.ChatGLMConfig", 8 | "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration", 9 | "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration" 10 | }, 11 | "bos_token_id": 130004, 12 | "eos_token_id": 130005, 13 | "mask_token_id": 130000, 14 | "gmask_token_id": 130001, 15 | "pad_token_id": 3, 16 | "hidden_size": 32, 17 | "inner_hidden_size": 32, 18 | "layernorm_epsilon": 1e-05, 19 | "max_sequence_length": 2048, 20 | "model_type": "chatglm", 21 | "num_attention_heads": 2, 22 | "num_layers": 2, 23 | "position_encoding_2d": true, 24 | "torch_dtype": "float16", 25 | "transformers_version": "4.23.1", 26 | "use_cache": true, 27 | "vocab_size": 130528 28 | } 29 | -------------------------------------------------------------------------------- /chatglm/configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | """ ChatGLM model configuration """ 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class ChatGLMConfig(PretrainedConfig): 10 | r""" 11 | This is the configuration class to store the configuration of a [`~ChatGLMModel`]. 12 | It is used to instantiate an ChatGLM model according to the specified arguments, defining the model 13 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 14 | the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. 15 | 16 | Configuration objects inherit from [`PretrainedConfig`] and can be used 17 | to control the model outputs. Read the documentation from [`PretrainedConfig`] 18 | for more information. 19 | 20 | 21 | Args: 22 | vocab_size (`int`, *optional*, defaults to 150528): 23 | Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the 24 | `inputs_ids` passed when calling [`~ChatGLMModel`] or 25 | [`~TFChatGLMModel`]. 26 | hidden_size (`int`, *optional*, defaults to 4096): 27 | Dimension of the encoder layers and the pooler layer. 28 | num_hidden_layers (`int`, *optional*, defaults to 28): 29 | Number of hidden layers in the Transformer encoder. 30 | num_attention_heads (`int`, *optional*, defaults to 32): 31 | Number of attention heads for each attention layer in the Transformer encoder. 32 | inner_hidden_size (`int`, *optional*, defaults to 16384): 33 | Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 
34 | max_sequence_length (`int`, *optional*, defaults to 512): 35 | The maximum sequence length that this model might ever be used with. 36 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 37 | layernorm_epsilon (`float`, *optional*, defaults to 1e-5): 38 | The epsilon used by the layer normalization layers. 39 | use_cache (`bool`, *optional*, defaults to `True`): 40 | Whether the model should return the last key/values attentions (not used by all models). 41 | Example: 42 | 43 | ```python 44 | >>> from configuration_chatglm import ChatGLMConfig 45 | >>> from modeling_chatglm import ChatGLMModel 46 | 47 | >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration 48 | >>> configuration = ChatGLMConfig() 49 | 50 | >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration 51 | >>> model = ChatGLMModel(configuration) 52 | 53 | >>> # Accessing the model configuration 54 | >>> configuration = model.config 55 | ``` 56 | """ 57 | model_type = "chatglm" 58 | 59 | def __init__( 60 | self, 61 | vocab_size=150528, 62 | hidden_size=4096, 63 | num_layers=28, 64 | num_attention_heads=32, 65 | layernorm_epsilon=1e-5, 66 | use_cache=False, 67 | bos_token_id=150004, 68 | eos_token_id=150005, 69 | mask_token_id=150000, 70 | gmask_token_id=150001, 71 | pad_token_id=0, 72 | max_sequence_length=2048, 73 | inner_hidden_size=16384, 74 | position_encoding_2d=True, 75 | quantization_bit=0, 76 | pre_seq_len=None, 77 | prefix_projection=False, 78 | **kwargs 79 | ): 80 | self.num_layers = num_layers 81 | self.vocab_size = vocab_size 82 | self.hidden_size = hidden_size 83 | self.num_attention_heads = num_attention_heads 84 | self.max_sequence_length = max_sequence_length 85 | self.layernorm_epsilon = layernorm_epsilon 86 | self.inner_hidden_size = inner_hidden_size 87 | self.use_cache = use_cache 88 | self.bos_token_id = bos_token_id 89 | self.eos_token_id = eos_token_id 90 | self.pad_token_id = pad_token_id 91 | self.mask_token_id = mask_token_id 92 | self.gmask_token_id = gmask_token_id 93 | self.position_encoding_2d = position_encoding_2d 94 | self.quantization_bit = quantization_bit 95 | self.pre_seq_len = pre_seq_len 96 | self.prefix_projection = prefix_projection 97 | 98 | super().__init__( 99 | pad_token_id=pad_token_id, 100 | bos_token_id=bos_token_id, 101 | eos_token_id=eos_token_id, 102 | **kwargs 103 | ) 104 | -------------------------------------------------------------------------------- /chatglm/ice_text.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinsblog/chatglm-tiny/5ac7766647cba6883576e5313cb2c61fb94d2e0f/chatglm/ice_text.model -------------------------------------------------------------------------------- /chatglm/modeling_chatglm.py: -------------------------------------------------------------------------------- 1 | """ PyTorch ChatGLM model. 
""" 2 | 3 | import math 4 | import copy 5 | import os 6 | import warnings 7 | import re 8 | import sys 9 | 10 | import torch 11 | import torch.utils.checkpoint 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.nn import CrossEntropyLoss, LayerNorm 15 | from torch.nn.utils import skip_init 16 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any 17 | 18 | from transformers.utils import ( 19 | add_code_sample_docstrings, 20 | add_start_docstrings, 21 | add_start_docstrings_to_model_forward, 22 | ) 23 | from transformers.modeling_outputs import ( 24 | BaseModelOutputWithPast, 25 | CausalLMOutputWithPast, 26 | BaseModelOutputWithPastAndCrossAttentions, 27 | ) 28 | from transformers.modeling_utils import PreTrainedModel 29 | from transformers.utils import logging 30 | from transformers.generation.logits_process import LogitsProcessor 31 | from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput 32 | 33 | from .configuration_chatglm import ChatGLMConfig 34 | 35 | # flags required to enable jit fusion kernels 36 | 37 | if sys.platform != 'darwin': 38 | torch._C._jit_set_profiling_mode(False) 39 | torch._C._jit_set_profiling_executor(False) 40 | torch._C._jit_override_can_fuse_on_cpu(True) 41 | torch._C._jit_override_can_fuse_on_gpu(True) 42 | 43 | logger = logging.get_logger(__name__) 44 | 45 | _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" 46 | _CONFIG_FOR_DOC = "ChatGLM6BConfig" 47 | 48 | CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ 49 | "THUDM/chatglm-6b", 50 | # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm 51 | ] 52 | 53 | 54 | class InvalidScoreLogitsProcessor(LogitsProcessor): 55 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 56 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 57 | scores.zero_() 58 | scores[..., 5] = 5e4 59 | return scores 60 | 61 | 62 | def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): 63 | """Load tf checkpoints in a pytorch model.""" 64 | try: 65 | import re 66 | 67 | import numpy as np 68 | import tensorflow as tf 69 | except ImportError: 70 | logger.error( 71 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions." 
73 | ) 74 | raise 75 | tf_path = os.path.abspath(tf_checkpoint_path) 76 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 77 | # Load weights from TF model 78 | init_vars = tf.train.list_variables(tf_path) 79 | names = [] 80 | arrays = [] 81 | for name, shape in init_vars: 82 | logger.info(f"Loading TF weight {name} with shape {shape}") 83 | array = tf.train.load_variable(tf_path, name) 84 | names.append(name) 85 | arrays.append(array) 86 | 87 | for name, array in zip(names, arrays): 88 | name = name.split("/") 89 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 90 | # which are not required for using pretrained model 91 | if any( 92 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 93 | for n in name 94 | ): 95 | logger.info(f"Skipping {'/'.join(name)}") 96 | continue 97 | pointer = model 98 | for m_name in name: 99 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 100 | scope_names = re.split(r"_(\d+)", m_name) 101 | else: 102 | scope_names = [m_name] 103 | if scope_names[0] == "kernel" or scope_names[0] == "gamma": 104 | pointer = getattr(pointer, "weight") 105 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 106 | pointer = getattr(pointer, "bias") 107 | elif scope_names[0] == "output_weights": 108 | pointer = getattr(pointer, "weight") 109 | elif scope_names[0] == "squad": 110 | pointer = getattr(pointer, "classifier") 111 | else: 112 | try: 113 | pointer = getattr(pointer, scope_names[0]) 114 | except AttributeError: 115 | logger.info(f"Skipping {'/'.join(name)}") 116 | continue 117 | if len(scope_names) >= 2: 118 | num = int(scope_names[1]) 119 | pointer = pointer[num] 120 | if m_name[-11:] == "_embeddings": 121 | pointer = getattr(pointer, "weight") 122 | elif m_name == "kernel": 123 | array = np.transpose(array) 124 | try: 125 | assert ( 126 | pointer.shape == array.shape 127 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 128 | except AssertionError as e: 129 | e.args += (pointer.shape, array.shape) 130 | raise 131 | logger.info(f"Initialize PyTorch weight {name}") 132 | pointer.data = torch.from_numpy(array) 133 | return model 134 | 135 | 136 | class PrefixEncoder(torch.nn.Module): 137 | """ 138 | The torch.nn model to encode the prefix 139 | Input shape: (batch-size, prefix-length) 140 | Output shape: (batch-size, prefix-length, 2*layers*hidden) 141 | """ 142 | 143 | def __init__(self, config): 144 | super().__init__() 145 | self.prefix_projection = config.prefix_projection 146 | if self.prefix_projection: 147 | # Use a two-layer MLP to encode the prefix 148 | self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) 149 | self.trans = torch.nn.Sequential( 150 | torch.nn.Linear(config.hidden_size, config.hidden_size), 151 | torch.nn.Tanh(), 152 | torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) 153 | ) 154 | else: 155 | self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) 156 | 157 | def forward(self, prefix: torch.Tensor): 158 | if self.prefix_projection: 159 | prefix_tokens = self.embedding(prefix) 160 | past_key_values = self.trans(prefix_tokens) 161 | else: 162 | past_key_values = self.embedding(prefix) 163 | return past_key_values 164 | 165 | 166 | @torch.jit.script 167 | def gelu_impl(x): 168 | """OpenAI's gelu implementation.""" 169 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 170 | (1.0 + 0.044715 * x * x))) 
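# Note (sketch, for reference): gelu_impl above is the tanh approximation of GELU,
#   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))),
# where 0.7978845608028654 == sqrt(2/pi). A quick doctest-style sanity check against the
# exact erf-based GELU (not part of the model, shown only as a comment):
#   >>> x = torch.linspace(-4.0, 4.0, 9)
#   >>> exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
#   >>> torch.allclose(gelu_impl(x), exact, atol=1e-2)
#   True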
171 | 172 | 173 | def gelu(x): 174 | return gelu_impl(x) 175 | 176 | 177 | class RotaryEmbedding(torch.nn.Module): 178 | def __init__(self, dim, base=10000, precision=torch.half, learnable=False): 179 | super().__init__() 180 | inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) 181 | inv_freq = inv_freq.half() 182 | self.learnable = learnable 183 | if learnable: 184 | self.inv_freq = torch.nn.Parameter(inv_freq) 185 | self.max_seq_len_cached = None 186 | else: 187 | self.register_buffer('inv_freq', inv_freq) 188 | self.max_seq_len_cached = None 189 | self.cos_cached = None 190 | self.sin_cached = None 191 | self.precision = precision 192 | 193 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 194 | error_msgs): 195 | pass 196 | 197 | def forward(self, x, seq_dim=1, seq_len=None): 198 | if seq_len is None: 199 | seq_len = x.shape[seq_dim] 200 | if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): 201 | self.max_seq_len_cached = None if self.learnable else seq_len 202 | t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) 203 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 204 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 205 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 206 | if self.precision == torch.bfloat16: 207 | emb = emb.float() 208 | 209 | # [sx, 1 (b * np), hn] 210 | cos_cached = emb.cos()[:, None, :] 211 | sin_cached = emb.sin()[:, None, :] 212 | if self.precision == torch.bfloat16: 213 | cos_cached = cos_cached.bfloat16() 214 | sin_cached = sin_cached.bfloat16() 215 | if self.learnable: 216 | return cos_cached, sin_cached 217 | self.cos_cached, self.sin_cached = cos_cached, sin_cached 218 | return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] 
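    # Note (for reference): forward() caches cos/sin tables of shape [seq_len, 1, dim], built
    # from inv_freq = 1 / base**(2i/dim). apply_rotary_pos_emb_index (defined below) gathers
    # rows of these tables with F.embedding(position_id, ...) so each token uses the cos/sin of
    # its own position id, then rotates q/k as q * cos + rotate_half(q) * sin. In the
    # non-learnable case the cache is extended lazily whenever seq_len exceeds
    # max_seq_len_cached.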
219 | 220 | def _apply(self, fn): 221 | if self.cos_cached is not None: 222 | self.cos_cached = fn(self.cos_cached) 223 | if self.sin_cached is not None: 224 | self.sin_cached = fn(self.sin_cached) 225 | return super()._apply(fn) 226 | 227 | 228 | def rotate_half(x): 229 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 230 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 231 | 232 | 233 | @torch.jit.script 234 | def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): 235 | # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] 236 | cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ 237 | F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) 238 | q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 239 | return q, k 240 | 241 | 242 | def attention_fn( 243 | self, 244 | query_layer, 245 | key_layer, 246 | value_layer, 247 | attention_mask, 248 | hidden_size_per_partition, 249 | layer_id, 250 | layer_past=None, 251 | scaling_attention_score=True, 252 | use_cache=False, 253 | ): 254 | if layer_past is not None: 255 | past_key, past_value = layer_past[0], layer_past[1] 256 | key_layer = torch.cat((past_key, key_layer), dim=0) 257 | value_layer = torch.cat((past_value, value_layer), dim=0) 258 | 259 | # seqlen, batch, num_attention_heads, hidden_size_per_attention_head 260 | seq_len, b, nh, hidden_size = key_layer.shape 261 | 262 | if use_cache: 263 | present = (key_layer, value_layer) 264 | else: 265 | present = None 266 | 267 | query_key_layer_scaling_coeff = float(layer_id + 1) 268 | if scaling_attention_score: 269 | query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) 270 | 271 | # =================================== 272 | # Raw attention scores. [b, np, s, s] 273 | # =================================== 274 | 275 | # [b, np, sq, sk] 276 | output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) 277 | 278 | # [sq, b, np, hn] -> [sq, b * np, hn] 279 | query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) 280 | # [sk, b, np, hn] -> [sk, b * np, hn] 281 | key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) 282 | 283 | matmul_result = torch.zeros( 284 | 1, 1, 1, 285 | dtype=query_layer.dtype, 286 | device=query_layer.device, 287 | ) 288 | 289 | matmul_result = torch.baddbmm( 290 | matmul_result, 291 | query_layer.transpose(0, 1), # [b * np, sq, hn] 292 | key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] 293 | beta=0.0, 294 | alpha=1.0, 295 | ) 296 | 297 | # change view to [b, np, sq, sk] 298 | attention_scores = matmul_result.view(*output_size) 299 | 300 | if self.scale_mask_softmax: 301 | self.scale_mask_softmax.scale = query_key_layer_scaling_coeff 302 | attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) 303 | else: 304 | if not (attention_mask == 0).all(): 305 | # if auto-regressive, skip 306 | attention_scores.masked_fill_(attention_mask, -10000.0) 307 | dtype = attention_scores.dtype 308 | attention_scores = attention_scores.float() 309 | attention_scores = attention_scores * query_key_layer_scaling_coeff 310 | 311 | attention_probs = F.softmax(attention_scores, dim=-1) 312 | 313 | attention_probs = attention_probs.type(dtype) 314 | 315 | # ========================= 316 | # Context layer. [sq, b, hp] 317 | # ========================= 318 | 319 | # value_layer -> context layer. 
320 | # [sk, b, np, hn] --> [b, np, sq, hn] 321 | 322 | # context layer shape: [b, np, sq, hn] 323 | output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) 324 | 325 | # change view [sk, b * np, hn] 326 | value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) 327 | 328 | # change view [b * np, sq, sk] 329 | attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) 330 | 331 | # matmul: [b * np, sq, hn] 332 | context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) 333 | 334 | # change view [b, np, sq, hn] 335 | context_layer = context_layer.view(*output_size) 336 | 337 | # [b, np, sq, hn] --> [sq, b, np, hn] 338 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous() 339 | 340 | # [sq, b, np, hn] --> [sq, b, hp] 341 | new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) 342 | context_layer = context_layer.view(*new_context_layer_shape) 343 | 344 | outputs = (context_layer, present, attention_probs) 345 | 346 | return outputs 347 | 348 | 349 | def default_init(cls, *args, **kwargs): 350 | return cls(*args, **kwargs) 351 | 352 | 353 | class SelfAttention(torch.nn.Module): 354 | def __init__(self, hidden_size, num_attention_heads, 355 | layer_id, hidden_size_per_attention_head=None, bias=True, 356 | params_dtype=torch.float, position_encoding_2d=True, empty_init=True): 357 | if empty_init: 358 | init_method = skip_init 359 | else: 360 | init_method = default_init 361 | super(SelfAttention, self).__init__() 362 | 363 | self.layer_id = layer_id 364 | self.hidden_size = hidden_size 365 | self.hidden_size_per_partition = hidden_size 366 | self.num_attention_heads = num_attention_heads 367 | self.num_attention_heads_per_partition = num_attention_heads 368 | self.position_encoding_2d = position_encoding_2d 369 | self.rotary_emb = RotaryEmbedding( 370 | self.hidden_size // (self.num_attention_heads * 2) 371 | if position_encoding_2d 372 | else self.hidden_size // self.num_attention_heads, 373 | base=10000, 374 | precision=torch.half, 375 | learnable=False, 376 | ) 377 | 378 | self.scale_mask_softmax = None 379 | 380 | if hidden_size_per_attention_head is None: 381 | self.hidden_size_per_attention_head = hidden_size // num_attention_heads 382 | else: 383 | self.hidden_size_per_attention_head = hidden_size_per_attention_head 384 | 385 | self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head 386 | 387 | # Strided linear layer. 388 | self.query_key_value = init_method( 389 | torch.nn.Linear, 390 | hidden_size, 391 | 3 * self.inner_hidden_size, 392 | bias=bias, 393 | dtype=params_dtype, 394 | ) 395 | 396 | self.dense = init_method( 397 | torch.nn.Linear, 398 | self.inner_hidden_size, 399 | hidden_size, 400 | bias=bias, 401 | dtype=params_dtype, 402 | ) 403 | 404 | @staticmethod 405 | def attention_mask_func(attention_scores, attention_mask): 406 | attention_scores.masked_fill_(attention_mask, -10000.0) 407 | return attention_scores 408 | 409 | def split_tensor_along_last_dim(self, tensor, num_partitions, 410 | contiguous_split_chunks=False): 411 | """Split a tensor along its last dimension. 412 | Arguments: 413 | tensor: input tensor. 414 | num_partitions: number of partitions to split the tensor 415 | contiguous_split_chunks: If True, make each chunk contiguous 416 | in memory. 417 | """ 418 | # Get the size and dimension. 
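        # Shape sketch (for reference): forward() below calls
        # split_tensor_along_last_dim(mixed_raw_layer, 3) on a tensor of shape
        # [seq_len, batch, num_heads, 3 * hidden_size_per_head], yielding three
        # [seq_len, batch, num_heads, hidden_size_per_head] chunks for query, key and value.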
419 | last_dim = tensor.dim() - 1 420 | last_dim_size = tensor.size()[last_dim] // num_partitions 421 | # Split. 422 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 423 | # Note: torch.split does not create contiguous tensors by default. 424 | if contiguous_split_chunks: 425 | return tuple(chunk.contiguous() for chunk in tensor_list) 426 | 427 | return tensor_list 428 | 429 | def forward( 430 | self, 431 | hidden_states: torch.Tensor, 432 | position_ids, 433 | attention_mask: torch.Tensor, 434 | layer_id, 435 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 436 | use_cache: bool = False, 437 | output_attentions: bool = False, 438 | ): 439 | """ 440 | hidden_states: [seq_len, batch, hidden_size] 441 | attention_mask: [(1, 1), seq_len, seq_len] 442 | """ 443 | 444 | # [seq_len, batch, 3 * hidden_size] 445 | mixed_raw_layer = self.query_key_value(hidden_states) 446 | 447 | # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] 448 | new_tensor_shape = mixed_raw_layer.size()[:-1] + ( 449 | self.num_attention_heads_per_partition, 450 | 3 * self.hidden_size_per_attention_head, 451 | ) 452 | mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) 453 | 454 | # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] 455 | (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) 456 | 457 | if self.position_encoding_2d: 458 | q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) 459 | k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) 460 | cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) 461 | position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ 462 | position_ids[:, 1, :].transpose(0, 1).contiguous() 463 | q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) 464 | q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) 465 | query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) 466 | key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) 467 | else: 468 | position_ids = position_ids.transpose(0, 1) 469 | cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) 470 | # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] 471 | query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids) 472 | 473 | # [seq_len, batch, hidden_size] 474 | context_layer, present, attention_probs = attention_fn( 475 | self=self, 476 | query_layer=query_layer, 477 | key_layer=key_layer, 478 | value_layer=value_layer, 479 | attention_mask=attention_mask, 480 | hidden_size_per_partition=self.hidden_size_per_partition, 481 | layer_id=layer_id, 482 | layer_past=layer_past, 483 | use_cache=use_cache 484 | ) 485 | 486 | output = self.dense(context_layer) 487 | 488 | outputs = (output, present) 489 | 490 | if output_attentions: 491 | outputs += (attention_probs,) 492 | 493 | return outputs # output, present, attention_probs 494 | 495 | 496 | class GEGLU(torch.nn.Module): 497 | def __init__(self): 498 | super().__init__() 499 | self.activation_fn = F.gelu 500 | 501 | def forward(self, x): 502 | # dim=-1 breaks in jit for pt<1.10 503 | x1, x2 = x.chunk(2, dim=(x.ndim - 1)) 504 | return x1 * self.activation_fn(x2) 505 | 506 | 507 | class GLU(torch.nn.Module): 508 | def __init__(self, hidden_size, inner_hidden_size=None, 509 | layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): 510 | 
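        # Note: this GLU module implements the block's feed-forward MLP rather than a gated
        # unit: dense_h_to_4h projects hidden_size -> inner_hidden_size (4 * hidden_size by
        # default), the activation (gelu by default) is applied, and dense_4h_to_h projects
        # back to hidden_size. The GEGLU class above appears unused in this file.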
super(GLU, self).__init__() 511 | if empty_init: 512 | init_method = skip_init 513 | else: 514 | init_method = default_init 515 | self.layer_id = layer_id 516 | self.activation_func = activation_func 517 | 518 | # Project to 4h. 519 | self.hidden_size = hidden_size 520 | if inner_hidden_size is None: 521 | inner_hidden_size = 4 * hidden_size 522 | self.inner_hidden_size = inner_hidden_size 523 | self.dense_h_to_4h = init_method( 524 | torch.nn.Linear, 525 | self.hidden_size, 526 | self.inner_hidden_size, 527 | bias=bias, 528 | dtype=params_dtype, 529 | ) 530 | # Project back to h. 531 | self.dense_4h_to_h = init_method( 532 | torch.nn.Linear, 533 | self.inner_hidden_size, 534 | self.hidden_size, 535 | bias=bias, 536 | dtype=params_dtype, 537 | ) 538 | 539 | def forward(self, hidden_states): 540 | """ 541 | hidden_states: [seq_len, batch, hidden_size] 542 | """ 543 | 544 | # [seq_len, batch, inner_hidden_size] 545 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 546 | 547 | intermediate_parallel = self.activation_func(intermediate_parallel) 548 | 549 | output = self.dense_4h_to_h(intermediate_parallel) 550 | 551 | return output 552 | 553 | 554 | class GLMBlock(torch.nn.Module): 555 | def __init__( 556 | self, 557 | hidden_size, 558 | num_attention_heads, 559 | layernorm_epsilon, 560 | layer_id, 561 | inner_hidden_size=None, 562 | hidden_size_per_attention_head=None, 563 | layernorm=LayerNorm, 564 | use_bias=True, 565 | params_dtype=torch.float, 566 | num_layers=28, 567 | position_encoding_2d=True, 568 | empty_init=True 569 | ): 570 | super(GLMBlock, self).__init__() 571 | # Set output layer initialization if not provided. 572 | 573 | self.layer_id = layer_id 574 | 575 | # Layernorm on the input data. 576 | self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) 577 | 578 | self.position_encoding_2d = position_encoding_2d 579 | 580 | # Self attention. 581 | self.attention = SelfAttention( 582 | hidden_size, 583 | num_attention_heads, 584 | layer_id, 585 | hidden_size_per_attention_head=hidden_size_per_attention_head, 586 | bias=use_bias, 587 | params_dtype=params_dtype, 588 | position_encoding_2d=self.position_encoding_2d, 589 | empty_init=empty_init 590 | ) 591 | 592 | # Layernorm on the input data. 593 | self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) 594 | 595 | self.num_layers = num_layers 596 | 597 | # GLU 598 | self.mlp = GLU( 599 | hidden_size, 600 | inner_hidden_size=inner_hidden_size, 601 | bias=use_bias, 602 | layer_id=layer_id, 603 | params_dtype=params_dtype, 604 | empty_init=empty_init 605 | ) 606 | 607 | def forward( 608 | self, 609 | hidden_states: torch.Tensor, 610 | position_ids, 611 | attention_mask: torch.Tensor, 612 | layer_id, 613 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 614 | use_cache: bool = False, 615 | output_attentions: bool = False, 616 | ): 617 | """ 618 | hidden_states: [seq_len, batch, hidden_size] 619 | attention_mask: [(1, 1), seq_len, seq_len] 620 | """ 621 | 622 | # Layer norm at the begining of the transformer layer. 623 | # [seq_len, batch, hidden_size] 624 | attention_input = self.input_layernorm(hidden_states) 625 | 626 | # Self attention. 
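        # Note: the block combines its sub-layers with scaled residuals,
        #   a = LN1(x);  h = a * alpha + Attention(a);  m = LN2(h);  y = m * alpha + GLU(m),
        # where alpha = (2 * num_layers) ** 0.5 (see below), instead of the usual
        # x + sublayer(x) residual form.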
627 | attention_outputs = self.attention( 628 | attention_input, 629 | position_ids, 630 | attention_mask=attention_mask, 631 | layer_id=layer_id, 632 | layer_past=layer_past, 633 | use_cache=use_cache, 634 | output_attentions=output_attentions 635 | ) 636 | 637 | attention_output = attention_outputs[0] 638 | 639 | outputs = attention_outputs[1:] 640 | 641 | # Residual connection. 642 | alpha = (2 * self.num_layers) ** 0.5 643 | hidden_states = attention_input * alpha + attention_output 644 | 645 | mlp_input = self.post_attention_layernorm(hidden_states) 646 | 647 | # MLP. 648 | mlp_output = self.mlp(mlp_input) 649 | 650 | # Second residual connection. 651 | output = mlp_input * alpha + mlp_output 652 | 653 | if use_cache: 654 | outputs = (output,) + outputs 655 | else: 656 | outputs = (output,) + outputs[1:] 657 | 658 | return outputs # hidden_states, present, attentions 659 | 660 | 661 | class ChatGLMPreTrainedModel(PreTrainedModel): 662 | """ 663 | An abstract class to handle weights initialization and 664 | a simple interface for downloading and loading pretrained models. 665 | """ 666 | 667 | is_parallelizable = False 668 | supports_gradient_checkpointing = True 669 | config_class = ChatGLMConfig 670 | base_model_prefix = "transformer" 671 | _no_split_modules = ["GLMBlock"] 672 | 673 | def __init__(self, *inputs, **kwargs): 674 | super().__init__(*inputs, **kwargs) 675 | 676 | def _init_weights(self, module: nn.Module): 677 | """Initialize the weights.""" 678 | return 679 | 680 | def get_masks(self, input_ids, device): 681 | batch_size, seq_length = input_ids.shape 682 | context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] 683 | attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) 684 | attention_mask.tril_() 685 | for i, context_length in enumerate(context_lengths): 686 | attention_mask[i, :, :context_length] = 1 687 | attention_mask.unsqueeze_(1) 688 | attention_mask = (attention_mask < 0.5).bool() 689 | 690 | return attention_mask 691 | 692 | def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): 693 | batch_size, seq_length = input_ids.shape 694 | if use_gmasks is None: 695 | use_gmasks = [False] * batch_size 696 | context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] 697 | if self.position_encoding_2d: 698 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) 699 | for i, context_length in enumerate(context_lengths): 700 | position_ids[i, context_length:] = mask_positions[i] 701 | block_position_ids = [torch.cat(( 702 | torch.zeros(context_length, dtype=torch.long, device=device), 703 | torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 704 | )) for context_length in context_lengths] 705 | block_position_ids = torch.stack(block_position_ids, dim=0) 706 | position_ids = torch.stack((position_ids, block_position_ids), dim=1) 707 | else: 708 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) 709 | for i, context_length in enumerate(context_lengths): 710 | if not use_gmasks[i]: 711 | position_ids[i, context_length:] = mask_positions[i] 712 | 713 | return position_ids 714 | 715 | def _set_gradient_checkpointing(self, module, value=False): 716 | if isinstance(module, ChatGLMModel): 717 | module.gradient_checkpointing = value 718 | 719 | 720 | CHATGLM_6B_START_DOCSTRING = r""" 721 | This model is a PyTorch 
[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 722 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general 723 | usage and behavior. 724 | 725 | Parameters: 726 | config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. 727 | Initializing with a config file does not load the weights associated with the model, only the configuration. 728 | Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 729 | """ 730 | 731 | CHATGLM_6B_INPUTS_DOCSTRING = r""" 732 | Args: 733 | input_ids (`torch.LongTensor` of shape `({0})`): 734 | Indices of input sequence tokens in the vocabulary. 735 | 736 | Indices can be obtained using [`ChatGLM6BTokenizer`]. 737 | See [`PreTrainedTokenizer.encode`] and 738 | [`PreTrainedTokenizer.__call__`] for details. 739 | 740 | [What are input IDs?](../glossary#input-ids) 741 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 742 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 743 | 744 | - 1 for tokens that are **not masked**, 745 | - 0 for tokens that are **masked**. 746 | 747 | [What are attention masks?](../glossary#attention-mask) 748 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 749 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: 750 | 751 | - 0 corresponds to a *sentence A* token, 752 | - 1 corresponds to a *sentence B* token. 753 | 754 | [What are token type IDs?](../glossary#token-type-ids) 755 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 756 | Indices of positions of each input sequence tokens in the position embeddings. 757 | Selected in the range `[0, config.max_position_embeddings - 1]`. 758 | 759 | [What are position IDs?](../glossary#position-ids) 760 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 761 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 762 | 763 | - 1 indicates the head is **not masked**, 764 | - 0 indicates the head is **masked**. 765 | 766 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 767 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 768 | This is useful if you want more control over how to convert *input_ids* indices into associated vectors 769 | than the model's internal embedding lookup matrix. 770 | output_attentions (`bool`, *optional*): 771 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 772 | tensors for more detail. 773 | output_hidden_states (`bool`, *optional*): 774 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 775 | more detail. 776 | return_dict (`bool`, *optional*): 777 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
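    Note:
        In this implementation the attention mask built by `get_masks` is a boolean tensor of
        shape `[batch_size, 1, sequence_length, sequence_length]` in which `True` marks the
        positions to be masked; `attention_fn` fills those positions with -10000.0 before the
        softmax.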
778 | """ 779 | 780 | 781 | @add_start_docstrings( 782 | "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", 783 | CHATGLM_6B_START_DOCSTRING, 784 | ) 785 | class ChatGLMModel(ChatGLMPreTrainedModel): 786 | """ 787 | 788 | The model can behave as an encoder (with only self-attention) as well 789 | as a decoder, in which case a layer of cross-attention is added between 790 | the self-attention layers, following the architecture described in [Attention is 791 | all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, 792 | Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 793 | 794 | To behave as an decoder the model needs to be initialized with the 795 | `is_decoder` argument of the configuration set to `True`. 796 | To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` 797 | argument and `add_cross_attention` set to `True`; an 798 | `encoder_hidden_states` is then expected as an input to the forward pass. 799 | """ 800 | 801 | def __init__(self, config: ChatGLMConfig, empty_init=True): 802 | super().__init__(config) 803 | if empty_init: 804 | init_method = skip_init 805 | else: 806 | init_method = default_init 807 | # recording parameters 808 | self.max_sequence_length = config.max_sequence_length 809 | self.hidden_size = config.hidden_size 810 | self.params_dtype = torch.half 811 | self.num_attention_heads = config.num_attention_heads 812 | self.vocab_size = config.vocab_size 813 | self.num_layers = config.num_layers 814 | self.layernorm_epsilon = config.layernorm_epsilon 815 | self.inner_hidden_size = config.inner_hidden_size 816 | self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads 817 | self.position_encoding_2d = config.position_encoding_2d 818 | self.pre_seq_len = config.pre_seq_len 819 | self.prefix_projection = config.prefix_projection 820 | 821 | self.word_embeddings = init_method( 822 | torch.nn.Embedding, 823 | num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, 824 | dtype=self.params_dtype 825 | ) 826 | self.gradient_checkpointing = False 827 | 828 | def get_layer(layer_id): 829 | return GLMBlock( 830 | self.hidden_size, 831 | self.num_attention_heads, 832 | self.layernorm_epsilon, 833 | layer_id, 834 | inner_hidden_size=self.inner_hidden_size, 835 | hidden_size_per_attention_head=self.hidden_size_per_attention_head, 836 | layernorm=LayerNorm, 837 | use_bias=True, 838 | params_dtype=self.params_dtype, 839 | position_encoding_2d=self.position_encoding_2d, 840 | empty_init=empty_init 841 | ) 842 | 843 | self.layers = torch.nn.ModuleList( 844 | [get_layer(layer_id) for layer_id in range(self.num_layers)] 845 | ) 846 | 847 | # Final layer norm before output. 
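        # Note: after the final layer norm set up below, an optional p-tuning v2 prefix is
        # configured. When config.pre_seq_len is not None, every base-model parameter is frozen
        # and a PrefixEncoder (plus dropout) is added; get_prompt() reshapes its output into
        # per-layer past_key_values of shape
        # [num_layers * 2, pre_seq_len, batch, num_heads, head_dim], split into (key, value) pairs.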
848 | self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) 849 | 850 | if self.pre_seq_len is not None: 851 | for param in self.parameters(): 852 | param.requires_grad = False 853 | self.prefix_tokens = torch.arange(self.pre_seq_len).long() 854 | self.prefix_encoder = PrefixEncoder(config) 855 | self.dropout = torch.nn.Dropout(0.1) 856 | 857 | # total_params = sum(p.numel() for p in self.parameters()) 858 | # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 859 | # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) 860 | 861 | def get_input_embeddings(self): 862 | return self.word_embeddings 863 | 864 | def set_input_embeddings(self, new_embeddings: torch.Tensor): 865 | self.word_embeddings = new_embeddings 866 | 867 | def get_prompt(self, batch_size, device, dtype=torch.half): 868 | prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) 869 | past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) 870 | past_key_values = past_key_values.view( 871 | batch_size, 872 | self.pre_seq_len, 873 | self.num_layers * 2, 874 | self.num_attention_heads, 875 | self.hidden_size // self.num_attention_heads 876 | ) 877 | # seq_len, b, nh, hidden_size 878 | past_key_values = self.dropout(past_key_values) 879 | past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) 880 | # past_key_values = [(v[0], v[1]) for v in past_key_values] 881 | return past_key_values 882 | 883 | @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 884 | @add_code_sample_docstrings( 885 | checkpoint=_CHECKPOINT_FOR_DOC, 886 | output_type=BaseModelOutputWithPastAndCrossAttentions, 887 | config_class=_CONFIG_FOR_DOC, 888 | ) 889 | def forward( 890 | self, 891 | input_ids: Optional[torch.LongTensor] = None, 892 | position_ids: Optional[torch.LongTensor] = None, 893 | attention_mask: Optional[torch.Tensor] = None, 894 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 895 | inputs_embeds: Optional[torch.LongTensor] = None, 896 | use_cache: Optional[bool] = None, 897 | output_attentions: Optional[bool] = None, 898 | output_hidden_states: Optional[bool] = None, 899 | return_dict: Optional[bool] = None, 900 | ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: 901 | 902 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 903 | output_hidden_states = ( 904 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 905 | ) 906 | use_cache = use_cache if use_cache is not None else self.config.use_cache 907 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 908 | 909 | if self.gradient_checkpointing and self.training: 910 | if use_cache: 911 | logger.warning_once( 912 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
913 | ) 914 | use_cache = False 915 | 916 | if input_ids is not None and inputs_embeds is not None: 917 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 918 | elif input_ids is not None: 919 | batch_size, seq_length = input_ids.shape[:2] 920 | elif inputs_embeds is not None: 921 | batch_size, seq_length = inputs_embeds.shape[:2] 922 | else: 923 | raise ValueError("You have to specify either input_ids or inputs_embeds") 924 | 925 | if inputs_embeds is None: 926 | inputs_embeds = self.word_embeddings(input_ids) 927 | 928 | if past_key_values is None: 929 | if self.pre_seq_len is not None: 930 | past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, 931 | dtype=inputs_embeds.dtype) 932 | else: 933 | past_key_values = tuple([None] * len(self.layers)) 934 | 935 | if attention_mask is None: 936 | attention_mask = self.get_masks( 937 | input_ids, 938 | device=input_ids.device 939 | ) 940 | 941 | 942 | if position_ids is None: 943 | MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id 944 | seqs = input_ids.tolist() 945 | 946 | mask_positions, use_gmasks = [], [] 947 | for seq in seqs: 948 | mask_token = gMASK if gMASK in seq else MASK 949 | use_gmask = mask_token == gMASK 950 | mask_positions.append(seq.index(mask_token)) 951 | use_gmasks.append(use_gmask) 952 | 953 | position_ids = self.get_position_ids( 954 | input_ids, 955 | mask_positions=mask_positions, 956 | device=input_ids.device, 957 | use_gmasks=use_gmasks 958 | ) 959 | 960 | if self.pre_seq_len is not None and attention_mask is not None: 961 | prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( 962 | attention_mask.device) 963 | prefix_attention_mask = (prefix_attention_mask < 0.5).bool() 964 | attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) 965 | 966 | # [seq_len, batch, hidden_size] 967 | hidden_states = inputs_embeds.transpose(0, 1) 968 | 969 | presents = () if use_cache else None 970 | all_self_attentions = () if output_attentions else None 971 | all_hidden_states = () if output_hidden_states else None 972 | 973 | if attention_mask is None: 974 | attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() 975 | else: 976 | attention_mask = attention_mask.to(hidden_states.device) 977 | 978 | for i, layer in enumerate(self.layers): 979 | 980 | if output_hidden_states: 981 | all_hidden_states = all_hidden_states + (hidden_states,) 982 | layer_past = past_key_values[i] 983 | 984 | if self.gradient_checkpointing and self.training: 985 | layer_ret = torch.utils.checkpoint.checkpoint( 986 | layer, 987 | hidden_states, 988 | position_ids, 989 | attention_mask, 990 | torch.tensor(i), 991 | layer_past, 992 | use_cache, 993 | output_attentions 994 | ) 995 | else: 996 | layer_ret = layer( 997 | hidden_states, 998 | position_ids=position_ids, 999 | attention_mask=attention_mask, 1000 | layer_id=torch.tensor(i), 1001 | layer_past=layer_past, 1002 | use_cache=use_cache, 1003 | output_attentions=output_attentions 1004 | ) 1005 | 1006 | hidden_states = layer_ret[0] 1007 | 1008 | if use_cache: 1009 | presents = presents + (layer_ret[1],) 1010 | 1011 | if output_attentions: 1012 | all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) 1013 | 1014 | # Final layer norm. 
1015 | hidden_states = self.final_layernorm(hidden_states) 1016 | 1017 | if output_hidden_states: 1018 | all_hidden_states = all_hidden_states + (hidden_states,) 1019 | 1020 | if not return_dict: 1021 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 1022 | 1023 | return BaseModelOutputWithPast( 1024 | last_hidden_state=hidden_states, 1025 | past_key_values=presents, 1026 | hidden_states=all_hidden_states, 1027 | attentions=all_self_attentions, 1028 | ) 1029 | 1030 | 1031 | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): 1032 | def __init__(self, config: ChatGLMConfig, empty_init=True): 1033 | super().__init__(config) 1034 | if empty_init: 1035 | init_method = skip_init 1036 | else: 1037 | init_method = default_init 1038 | 1039 | # self.hidden_size = config.hidden_size 1040 | # self.params_dtype = torch.half 1041 | # self.vocab_size = config.vocab_size 1042 | self.max_sequence_length = config.max_sequence_length 1043 | 1044 | self.position_encoding_2d = config.position_encoding_2d 1045 | 1046 | self.transformer = ChatGLMModel(config, empty_init=empty_init) 1047 | 1048 | self.lm_head = init_method( 1049 | nn.Linear, 1050 | config.hidden_size, 1051 | config.vocab_size, 1052 | bias=False, 1053 | dtype=torch.half 1054 | ) 1055 | 1056 | self.config = config 1057 | 1058 | self.quantized = False 1059 | 1060 | if self.config.quantization_bit: 1061 | self.quantize(self.config.quantization_bit, empty_init=True) 1062 | 1063 | def get_output_embeddings(self): 1064 | return self.lm_head 1065 | 1066 | def set_output_embeddings(self, new_embeddings): 1067 | self.lm_head = new_embeddings 1068 | 1069 | def _update_model_kwargs_for_generation( 1070 | self, 1071 | outputs: ModelOutput, 1072 | model_kwargs: Dict[str, Any], 1073 | is_encoder_decoder: bool = False, 1074 | standardize_cache_format: bool = False, 1075 | ) -> Dict[str, Any]: 1076 | # update past_key_values 1077 | model_kwargs["past_key_values"] = self._extract_past_from_model_output( 1078 | outputs, standardize_cache_format=standardize_cache_format 1079 | ) 1080 | 1081 | # update attention mask 1082 | if "attention_mask" in model_kwargs: 1083 | attention_mask = model_kwargs["attention_mask"] 1084 | if attention_mask is not None and attention_mask.dtype == torch.bool: 1085 | attention_mask = torch.cat( 1086 | [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) 1087 | new_attention_mask = attention_mask[:, :, -1:].clone() 1088 | new_attention_mask[..., -1] = False 1089 | model_kwargs["attention_mask"] = torch.cat( 1090 | [attention_mask, new_attention_mask], dim=2 1091 | ) 1092 | 1093 | # update position ids 1094 | if "position_ids" in model_kwargs: 1095 | position_ids = model_kwargs["position_ids"] 1096 | new_position_id = position_ids[..., -1:].clone() 1097 | new_position_id[:, 1, :] += 1 1098 | model_kwargs["position_ids"] = torch.cat( 1099 | [position_ids, new_position_id], dim=-1 1100 | ) 1101 | 1102 | return model_kwargs 1103 | 1104 | def prepare_inputs_for_generation( 1105 | self, 1106 | input_ids: torch.LongTensor, 1107 | past: Optional[torch.Tensor] = None, 1108 | past_key_values: Optional[torch.Tensor] = None, 1109 | attention_mask: Optional[torch.Tensor] = None, 1110 | position_ids: Optional[torch.Tensor] = None, 1111 | **kwargs 1112 | ) -> dict: 1113 | batch_size, seq_length = input_ids.shape 1114 | MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id 1115 | seqs = input_ids.tolist() 1116 | mask_positions, use_gmasks = 
[], [] 1117 | for seq in seqs: 1118 | mask_token = gMASK if gMASK in seq else MASK 1119 | use_gmask = mask_token == gMASK 1120 | mask_positions.append(seq.index(mask_token)) 1121 | use_gmasks.append(use_gmask) 1122 | 1123 | # only last token for input_ids if past is not None 1124 | if past is not None or past_key_values is not None: 1125 | last_token = input_ids[:, -1].unsqueeze(-1) 1126 | if attention_mask is not None and attention_mask.dtype == torch.bool: 1127 | attention_mask = attention_mask[:, :, -1:] 1128 | else: 1129 | attention_mask = None 1130 | if position_ids is not None: 1131 | position_ids = position_ids[..., -1:] 1132 | else: 1133 | context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] 1134 | if self.position_encoding_2d: 1135 | position_ids = torch.tensor( 1136 | [[mask_position, seq_length - context_length] for mask_position, context_length in 1137 | zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) 1138 | else: 1139 | position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, 1140 | device=input_ids.device).unsqueeze(-1) 1141 | 1142 | if past is None: 1143 | past = past_key_values 1144 | return { 1145 | "input_ids": last_token, 1146 | "past_key_values": past, 1147 | "position_ids": position_ids, 1148 | "attention_mask": attention_mask 1149 | } 1150 | else: 1151 | if attention_mask is not None and attention_mask.dtype != torch.bool: 1152 | logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") 1153 | attention_mask = None 1154 | if attention_mask is None: 1155 | attention_mask = self.get_masks( 1156 | input_ids, 1157 | device=input_ids.device 1158 | ) 1159 | if position_ids is None: 1160 | position_ids = self.get_position_ids( 1161 | input_ids, 1162 | device=input_ids.device, 1163 | mask_positions=mask_positions, 1164 | use_gmasks=use_gmasks 1165 | ) 1166 | 1167 | return { 1168 | "input_ids": input_ids, 1169 | "past_key_values": past, 1170 | "position_ids": position_ids, 1171 | "attention_mask": attention_mask 1172 | } 1173 | 1174 | def forward( 1175 | self, 1176 | input_ids: Optional[torch.Tensor] = None, 1177 | position_ids: Optional[torch.Tensor] = None, 1178 | attention_mask: Optional[torch.Tensor] = None, 1179 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 1180 | inputs_embeds: Optional[torch.Tensor] = None, 1181 | labels: Optional[torch.Tensor] = None, 1182 | use_cache: Optional[bool] = None, 1183 | output_attentions: Optional[bool] = None, 1184 | output_hidden_states: Optional[bool] = None, 1185 | return_dict: Optional[bool] = None, 1186 | ): 1187 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1188 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1189 | 1190 | transformer_outputs = self.transformer( 1191 | input_ids=input_ids, 1192 | position_ids=position_ids, 1193 | attention_mask=attention_mask, 1194 | past_key_values=past_key_values, 1195 | inputs_embeds=inputs_embeds, 1196 | use_cache=use_cache, 1197 | output_attentions=output_attentions, 1198 | output_hidden_states=output_hidden_states, 1199 | return_dict=return_dict, 1200 | ) 1201 | 1202 | hidden_states = transformer_outputs[0] 1203 | 1204 | lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() 1205 | 1206 | loss = None 1207 | if labels is not None: 1208 | lm_logits = lm_logits.to(torch.float32) 1209 | 1210 | # Shift so that tokens < n predict n 1211 | shift_logits = 
lm_logits[..., :-1, :].contiguous() 1212 | shift_labels = labels[..., 1:].contiguous() 1213 | # Flatten the tokens 1214 | loss_fct = CrossEntropyLoss(ignore_index=-100) 1215 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1216 | 1217 | lm_logits = lm_logits.to(hidden_states.dtype) 1218 | loss = loss.to(hidden_states.dtype) 1219 | 1220 | if not return_dict: 1221 | output = (lm_logits,) + transformer_outputs[1:] 1222 | return ((loss,) + output) if loss is not None else output 1223 | 1224 | return CausalLMOutputWithPast( 1225 | loss=loss, 1226 | logits=lm_logits, 1227 | past_key_values=transformer_outputs.past_key_values, 1228 | hidden_states=transformer_outputs.hidden_states, 1229 | attentions=transformer_outputs.attentions, 1230 | ) 1231 | 1232 | @staticmethod 1233 | def _reorder_cache( 1234 | past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor 1235 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: 1236 | """ 1237 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1238 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1239 | beam_idx at every generation step. 1240 | 1241 | Output shares the same memory storage as `past`. 1242 | """ 1243 | return tuple( 1244 | ( 1245 | layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), 1246 | layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), 1247 | ) 1248 | for layer_past in past 1249 | ) 1250 | 1251 | def process_response(self, response): 1252 | response = response.strip() 1253 | response = response.replace("[[训练时间]]", "2023年") 1254 | punkts = [ 1255 | [",", ","], 1256 | ["!", "!"], 1257 | [":", ":"], 1258 | [";", ";"], 1259 | ["\?", "?"], 1260 | ] 1261 | for item in punkts: 1262 | response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) 1263 | response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) 1264 | return response 1265 | 1266 | @torch.no_grad() 1267 | def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, 1268 | do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): 1269 | if history is None: 1270 | history = [] 1271 | if logits_processor is None: 1272 | logits_processor = LogitsProcessorList() 1273 | logits_processor.append(InvalidScoreLogitsProcessor()) 1274 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": False, "top_p": top_p, 1275 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} 1276 | if not history: 1277 | prompt = query 1278 | else: 1279 | prompt = "" 1280 | for i, (old_query, response) in enumerate(history): 1281 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 1282 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 1283 | inputs = tokenizer([prompt], return_tensors="pt") 1284 | inputs = inputs.to(self.device) 1285 | outputs = self.generate(**inputs, **gen_kwargs) 1286 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] 1287 | response = tokenizer.decode(outputs) 1288 | response = self.process_response(response) 1289 | history = history + [(query, response)] 1290 | return response, history 1291 | 1292 | @torch.no_grad() 1293 | def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, 1294 | do_sample=True, top_p=0.7, temperature=0.95, 
logits_processor=None, **kwargs): 1295 | if history is None: 1296 | history = [] 1297 | if logits_processor is None: 1298 | logits_processor = LogitsProcessorList() 1299 | logits_processor.append(InvalidScoreLogitsProcessor()) 1300 | gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, 1301 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} 1302 | if not history: 1303 | prompt = query 1304 | else: 1305 | prompt = "" 1306 | for i, (old_query, response) in enumerate(history): 1307 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) 1308 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) 1309 | inputs = tokenizer([prompt], return_tensors="pt") 1310 | inputs = inputs.to(self.device) 1311 | for outputs in self.stream_generate(**inputs, **gen_kwargs): 1312 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] 1313 | response = tokenizer.decode(outputs) 1314 | response = self.process_response(response) 1315 | new_history = history + [(query, response)] 1316 | yield response, new_history 1317 | 1318 | @torch.no_grad() 1319 | def stream_generate( 1320 | self, 1321 | input_ids, 1322 | generation_config: Optional[GenerationConfig] = None, 1323 | logits_processor: Optional[LogitsProcessorList] = None, 1324 | stopping_criteria: Optional[StoppingCriteriaList] = None, 1325 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 1326 | **kwargs, 1327 | ): 1328 | batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] 1329 | 1330 | if generation_config is None: 1331 | generation_config = self.generation_config 1332 | generation_config = copy.deepcopy(generation_config) 1333 | model_kwargs = generation_config.update(**kwargs) 1334 | bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id 1335 | 1336 | if isinstance(eos_token_id, int): 1337 | eos_token_id = [eos_token_id] 1338 | 1339 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None 1340 | if has_default_max_length and generation_config.max_new_tokens is None: 1341 | warnings.warn( 1342 | f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " 1343 | "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" 1344 | " recommend using `max_new_tokens` to control the maximum length of the generation.", 1345 | UserWarning, 1346 | ) 1347 | elif generation_config.max_new_tokens is not None: 1348 | generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length 1349 | if not has_default_max_length: 1350 | logger.warn( 1351 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" 1352 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " 1353 | "Please refer to the documentation for more information. " 1354 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", 1355 | UserWarning, 1356 | ) 1357 | 1358 | if input_ids_seq_length >= generation_config.max_length: 1359 | input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" 1360 | logger.warning( 1361 | f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" 1362 | f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" 1363 | " increasing `max_new_tokens`." 
1364 | ) 1365 | 1366 | # 2. Set generation parameters if not already defined 1367 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 1368 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 1369 | 1370 | logits_processor = self._get_logits_processor( 1371 | generation_config=generation_config, 1372 | input_ids_seq_length=input_ids_seq_length, 1373 | encoder_input_ids=input_ids, 1374 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 1375 | logits_processor=logits_processor, 1376 | ) 1377 | 1378 | stopping_criteria = self._get_stopping_criteria( 1379 | generation_config=generation_config, stopping_criteria=stopping_criteria 1380 | ) 1381 | logits_warper = self._get_logits_warper(generation_config) 1382 | 1383 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 1384 | scores = None 1385 | while True: 1386 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 1387 | # forward pass to get next token 1388 | outputs = self( 1389 | **model_inputs, 1390 | return_dict=True, 1391 | output_attentions=False, 1392 | output_hidden_states=False, 1393 | ) 1394 | 1395 | next_token_logits = outputs.logits[:, -1, :] 1396 | 1397 | # pre-process distribution 1398 | next_token_scores = logits_processor(input_ids, next_token_logits) 1399 | next_token_scores = logits_warper(input_ids, next_token_scores) 1400 | 1401 | # sample 1402 | probs = nn.functional.softmax(next_token_scores, dim=-1) 1403 | if generation_config.do_sample: 1404 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 1405 | else: 1406 | next_tokens = torch.argmax(probs, dim=-1) 1407 | 1408 | # update generated ids, model inputs, and length for next step 1409 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 1410 | model_kwargs = self._update_model_kwargs_for_generation( 1411 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 1412 | ) 1413 | unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) 1414 | 1415 | # stop when each sentence is finished, or if we exceed the maximum length 1416 | if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): 1417 | break 1418 | yield input_ids 1419 | 1420 | def quantize(self, bits: int, empty_init=False, **kwargs): 1421 | if bits == 0: 1422 | return 1423 | 1424 | from .quantization import quantize 1425 | 1426 | if self.quantized: 1427 | logger.info("Already quantized.") 1428 | return self 1429 | 1430 | self.quantized = True 1431 | 1432 | self.config.quantization_bit = bits 1433 | 1434 | self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) 1435 | return self 1436 | -------------------------------------------------------------------------------- /chatglm/quantization.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear 2 | from torch.nn.parameter import Parameter 3 | 4 | import bz2 5 | import torch 6 | import base64 7 | import ctypes 8 | from transformers.utils import logging 9 | 10 | from typing import List 11 | from functools import partial 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | try: 16 | from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up 17 | 18 | class Kernel: 19 | def __init__(self, code: bytes, function_names: List[str]): 20 | self.code = code 21 | self._function_names = function_names 22 | self._cmodule = 
LazyKernelCModule(self.code) 23 | 24 | for name in self._function_names: 25 | setattr(self, name, KernelFunction(self._cmodule, name)) 26 | 27 | quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmH
wGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/
VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" 28 | 29 | kernels = Kernel( 30 | bz2.decompress(base64.b64decode(quantization_code)), 31 | [ 32 | "int4WeightCompression", 33 | "int4WeightExtractionFloat", 34 | "int4WeightExtractionHalf", 35 | "int8WeightExtractionFloat", 36 | "int8WeightExtractionHalf", 37 | ], 38 | ) 39 | except Exception as exception: 40 | kernels = None 41 | logger.warning("Failed to load cpm_kernels:" + str(exception)) 42 | 43 | 44 | class W8A16Linear(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): 47 | ctx.inp_shape = inp.size() 48 | ctx.weight_bit_width = weight_bit_width 49 | out_features = quant_w.size(0) 50 | inp = inp.contiguous().view(-1, inp.size(-1)) 51 | weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) 52 | ctx.weight_shape = weight.size() 53 | output = inp.mm(weight.t()) 54 | ctx.save_for_backward(inp, quant_w, scale_w) 55 | return output.view(*(ctx.inp_shape[:-1] + (out_features,))) 56 | 57 | @staticmethod 58 | def backward(ctx, grad_output: torch.Tensor): 59 | inp, quant_w, scale_w = ctx.saved_tensors 60 | weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) 61 | grad_output = grad_output.contiguous().view(-1, weight.size(0)) 62 | grad_input = grad_output.mm(weight) 63 | grad_weight = grad_output.t().mm(inp) 64 | return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None 65 | 66 | 67 | def compress_int4_weight(weight: torch.Tensor): # (n, m) 68 | with torch.cuda.device(weight.device): 69 | n, m = weight.size(0), weight.size(1) 70 | assert m % 2 == 0 71 | m = m // 2 72 | out = torch.empty(n, m, dtype=torch.int8, device="cuda") 73 | stream = torch.cuda.current_stream() 74 | 75 | gridDim = (n, 1, 1) 76 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 77 | 78 | kernels.int4WeightCompression( 79 | gridDim, 80 | blockDim, 81 | 0, 82 | stream, 83 | [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], 84 | ) 85 | return out 86 | 87 | 88 | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): 89 | if source_bit_width == 8: 90 | func = kernels.int8WeightExtractionHalf 91 | elif source_bit_width == 4: 92 | func = kernels.int4WeightExtractionHalf 93 | else: 94 | assert False, "Unsupported bit-width" 95 | 96 | with torch.cuda.device(weight.device): 97 | n, m = weight.size(0), weight.size(1) 98 | out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda") 99 | stream = torch.cuda.current_stream() 100 | 101 | gridDim = (n, 1, 1) 102 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 103 | 104 | func( 105 | gridDim, 106 | blockDim, 107 | 0, 108 | stream, 109 | [ 110 | ctypes.c_void_p(weight.data_ptr()), 111 | ctypes.c_void_p(scale_list.data_ptr()), 112 | ctypes.c_void_p(out.data_ptr()), 113 | ctypes.c_int32(n), 114 | ctypes.c_int32(m), 115 | ], 116 | ) 117 | return out 118 | 119 | 120 | class QuantizedLinear(Linear): 121 | def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs): 
122 | super(QuantizedLinear, self).__init__(*args, **kwargs) 123 | self.weight_bit_width = weight_bit_width 124 | 125 | shape = self.weight.shape 126 | del self.weight 127 | 128 | if weight_tensor is None or empty_init: 129 | self.weight = torch.empty( 130 | shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] 131 | ) 132 | self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) 133 | else: 134 | self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() 135 | self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) 136 | if weight_bit_width == 4: 137 | self.weight = compress_int4_weight(self.weight) 138 | 139 | self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) 140 | self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) 141 | if bias_tensor is not None: 142 | self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) 143 | else: 144 | self.bias = None 145 | 146 | def forward(self, input): 147 | output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) 148 | if self.bias is not None: 149 | output = output + self.bias 150 | return output 151 | 152 | 153 | def quantize(model, weight_bit_width, empty_init=False, **kwargs): 154 | """Replace fp16 linear with quantized linear""" 155 | 156 | for layer in model.layers: 157 | layer.attention.query_key_value = QuantizedLinear( 158 | weight_bit_width=weight_bit_width, 159 | weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()), 160 | bias_tensor=layer.attention.query_key_value.bias, 161 | in_features=layer.attention.query_key_value.in_features, 162 | out_features=layer.attention.query_key_value.out_features, 163 | bias=True, 164 | dtype=torch.half, 165 | device=layer.attention.query_key_value.weight.device, 166 | empty_init=empty_init 167 | ) 168 | layer.attention.dense = QuantizedLinear( 169 | weight_bit_width=weight_bit_width, 170 | weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()), 171 | bias_tensor=layer.attention.dense.bias, 172 | in_features=layer.attention.dense.in_features, 173 | out_features=layer.attention.dense.out_features, 174 | bias=True, 175 | dtype=torch.half, 176 | device=layer.attention.dense.weight.device, 177 | empty_init=empty_init 178 | ) 179 | layer.mlp.dense_h_to_4h = QuantizedLinear( 180 | weight_bit_width=weight_bit_width, 181 | weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), 182 | bias_tensor=layer.mlp.dense_h_to_4h.bias, 183 | in_features=layer.mlp.dense_h_to_4h.in_features, 184 | out_features=layer.mlp.dense_h_to_4h.out_features, 185 | bias=True, 186 | dtype=torch.half, 187 | device=layer.mlp.dense_h_to_4h.weight.device, 188 | empty_init=empty_init 189 | ) 190 | layer.mlp.dense_4h_to_h = QuantizedLinear( 191 | weight_bit_width=weight_bit_width, 192 | weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), 193 | bias_tensor=layer.mlp.dense_4h_to_h.bias, 194 | in_features=layer.mlp.dense_4h_to_h.in_features, 195 | out_features=layer.mlp.dense_4h_to_h.out_features, 196 | bias=True, 197 | dtype=torch.half, 198 | device=layer.mlp.dense_4h_to_h.weight.device, 199 | empty_init=empty_init 200 | ) 201 | return model 202 | -------------------------------------------------------------------------------- /chatglm/test_modeling_chatglm.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import unittest 4 | import torch 5 | import random 6 | 7 | from transformers import AutoTokenizer, AutoModel 8 | from transformers.testing_utils import require_torch, slow, torch_device 9 | 10 | 11 | def set_random_seed(seed): 12 | import random 13 | 14 | random.seed(seed) 15 | 16 | # pytorch RNGs 17 | import torch 18 | 19 | torch.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | if torch.cuda.is_available(): 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | # numpy RNG 25 | import numpy as np 26 | 27 | np.random.seed(seed) 28 | 29 | 30 | 31 | def ids_tensor(shape, vocab_size): 32 | # Creates a random int32 tensor of the shape within the vocab size 33 | total_dims = 1 34 | for dim in shape: 35 | total_dims *= dim 36 | 37 | values = [] 38 | for _ in range(total_dims): 39 | values.append(random.randint(0, vocab_size - 1)) 40 | 41 | return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() 42 | 43 | 44 | def get_model_and_tokenizer(): 45 | model = AutoModel.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True).half() 46 | model.to(torch_device) 47 | model.eval() 48 | tokenizer = AutoTokenizer.from_pretrained("/mnt/vepfs/workspace/zxdu/chatglm_6b", trust_remote_code=True) 49 | return model, tokenizer 50 | 51 | 52 | @require_torch 53 | class ChatGLMGenerationTest(unittest.TestCase): 54 | def get_generation_kwargs(self): 55 | pass 56 | 57 | def test_chat(self): 58 | model, tokenizer = get_model_and_tokenizer() 59 | prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] 60 | history = [] 61 | set_random_seed(42) 62 | expected_responses = [ 63 | '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 64 | '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', 65 | '清华大学创建于 1911 年。' 66 | ] 67 | for (prompt, expected_response) in zip(prompts, expected_responses): 68 | response, history = model.chat(tokenizer, prompt, history=history) 69 | print(repr(response)) 70 | self.assertEquals(expected_response, response) 71 | 72 | def test_stream_chat(self): 73 | model, tokenizer = get_model_and_tokenizer() 74 | prompts = ["你好", "介绍一下清华大学", "它创建于哪一年"] 75 | history = [] 76 | expected_responses = [ 77 | '你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 78 | '清华大学是中国著名的综合性研究型大学,位于中国北京市海淀区,创建于 1911 年,前身是清华学堂。作为我国顶尖高等教育机构之一,清华大学在科学研究、工程技术、信息技术、经济管理等领域处于领先地位,也是世界上最著名的工程学府之一。\n\n清华大学拥有世界一流的教学设施和科学研究平台,设有多个学院和研究中心,包括工程学院、自然科学学院、社会科学学院、人文学院、法学院、经济管理学院等。学校拥有众多知名教授和研究团队,其中包括多位院士、国家杰出青年科学基金获得者、长江学者等。\n\n清华大学的本科生招生范围为全国中学毕业生,本科生入学要求严格,考试成绩优秀。同时,清华大学也提供研究生和博士生招生,包括硕士研究生和博士研究生。', 79 | '清华大学创建于 1911 年。' 80 | ] 81 | set_random_seed(42) 82 | for prompt, expected_response in zip(prompts, expected_responses): 83 | response = "" 84 | for idx, (response, history) in enumerate(model.stream_chat(tokenizer, prompt, history=history)): 85 | pass 86 | print(repr(response)) 87 | self.assertEquals(expected_response, response) 88 | 89 | def test_generation(self): 90 | model, tokenizer = get_model_and_tokenizer() 91 | sentence = "晚上睡不着怎么办" 92 | parameters = [(False, 2048, 1), 93 | (False, 64, 1), 94 | (True, 2048, 1), 95 | (True, 64, 1), 96 | (True, 2048, 4)] 97 | expected_out_sentences = [ 98 | '晚上睡不着怎么办 
以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。\n\n3. 避免刺激性物质:避免饮用含咖啡因的饮料,如咖啡、茶和可乐,并尽可能减少饮酒。\n\n4. 放松身心:尝试进行放松的活动,如冥想、深呼吸、瑜伽或听轻柔的音乐。\n\n5. 避免在床上做其他事情:例如看电视、使用电脑或智能手机等。\n\n6. 练习放松技巧:例如渐进性肌肉松弛法、冥想或深呼吸练习。\n\n7. 寻求帮助:如果长时间都无法正常入睡,可以考虑咨询医生或专业心理医生,寻求更进一步的帮助。\n\n希望这些方法能有助于入睡。', 99 | '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 保持规律的睡眠时间表:尽量在同一时间上床,并尝试在早上醒来时自然起床。\n\n2. 创建舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,并使用舒适的床垫和枕头。', 100 | '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体释放褪黑素,进而导致难以入睡。建议你在睡前一小时停止使用这些设备。\n\n3. 创建舒适的睡眠环境:确保卧室安静、黑暗、凉爽,舒适的床垫和枕头,保持卧室温度适宜,这有助于让你更容易入睡。\n\n4. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽或轻松的散步,减轻压力和焦虑,让你更容易入睡。\n\n5. 避免咖啡因和酒精:咖啡因和酒精会让大脑更加兴奋,进而干扰身体入睡过程。建议在睡前几小时避免饮用这些物质。\n\n6. 做一些安静的活动:阅读一本书、听轻柔的音乐、绣或者绘画等安静的活动,有助于自己放松身心,进而更容易入睡。\n\n如果采取以上这些方法仍然无法入睡,建议咨询医生或专业的睡眠专家,获取更好的建议和帮助。', 101 | '晚上睡不着怎么办 以下是一些有助于在晚上更好地入睡的方法:\n\n1. 维持规律的睡眠时间:每晚尽可能在同一时间上床,保持规律的睡眠时间表,帮助身体调整并更容易入睡。\n\n2. 避免在床上使用电子设备:手机、平板电脑、电脑等电子设备会发出蓝光,这会干扰身体', 102 | '晚上睡不着怎么办 以下是一些可能有助于在晚上入睡的方法:\n\n1. 建立规律的睡眠时间表:尽量在同一时间入睡和起床,即使在周末和假期也要尽量保持一致。\n\n2. 创造舒适的睡眠环境:保持房间安静、凉爽、黑暗、舒适,使用舒适的床垫和枕头等。\n\n3. 放松身心:尝试进行一些放松的活动,如冥想、深呼吸、瑜伽、听轻柔的音乐等,缓解压力和紧张情绪。\n\n4. 避免刺激性物质:避免饮用咖啡、茶、可乐等含咖啡因的饮料,避免吸烟和饮酒等刺激性物质。\n\n5. 避免躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,就不要躺在床上翻来覆去,而是起床去做一些放松的活动,直到感到困倦为止。\n\n6. 练习放松技巧:如果感到焦虑或紧张,可以尝试进行一些放松技巧,如渐进性肌肉松弛、冥想等。\n\n7. 改善睡眠障碍:如果已经尝试了上述方法仍然无法入睡,可以考虑咨询医生,了解是否存在其他睡眠障碍问题,并接受相应的治疗。'] 103 | for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): 104 | set_random_seed(42) 105 | inputs = tokenizer(sentence, return_tensors="pt") 106 | inputs = inputs.to(torch_device) 107 | 108 | outputs = model.generate( 109 | **inputs, 110 | do_sample=do_sample, 111 | max_length=max_length, 112 | num_beams=num_beams 113 | ) 114 | 115 | outputs = outputs.tolist()[0] 116 | out_sentence = tokenizer.decode(outputs, skip_special_tokens=True) 117 | print(out_sentence) 118 | self.assertEquals(expected_output_sentence, out_sentence) 119 | 120 | def test_batch_generation(self): 121 | model, tokenizer = get_model_and_tokenizer() 122 | sentences = [ 123 | "你好", 124 | "介绍一下清华大学" 125 | ] 126 | parameters = [(False, 2048, 1), 127 | (False, 64, 1), 128 | (True, 2048, 1), 129 | (True, 64, 1), 130 | (True, 2048, 4)] 131 | expected_out_sentences = [ 132 | ['你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 133 | '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回清华园。新中国成立后,清华学校更名为清华大学。\n\n清华大学是中国最顶尖的大学之一,在工程、科学、技术、经济、管理等领域都有很高的学术声誉和影响力。学校拥有世界一流的教学设施和科学研究平台,有多个学院和研究中心,包括工程学院、自然科学学院、人文学院、社会科学学院、经济管理学院、法学院、美术学院、医学院、器学院等。\n\n清华大学的本科生招生始于2000年,实行全面二孩政策后,本科生招生规模不断扩大。截至2022年,清华大学共有本科生近3万人,研究生近2万人,其中国际学生占比约为10%。清华大学的本科生教育注重通识教育和个性化培养,强调实践、创新、国际化和综合素质。'], 134 | [ 135 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 136 | '介绍一下清华大学 清华大学是中国著名的综合性大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,1946年迁回' 137 | ], 138 | [ 139 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 140 | '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后闭校。1949 年 10 月开学复校,成为我国第一个社会主义大学生活了的高校。截至 2023 年,清华学校共管辖 2 个学院、13 个系,有本科专业 60 个,研究生专业 190 个。' 141 | ], 142 | [ 143 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 144 | '介绍一下清华大学 清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路 30 号,其溯源于 1911 年创建的清华学堂, 1925 年更名为清华学校, 1937 年秋抗日战争全面爆发后' 145 | ], 146 | [ 147 | '你好 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。', 148 | '介绍一下清华大学 
清华大学是中国著名的综合性研究型大学,位于北京市海淀区双清路30号,其历史可以追溯到1911年创建的清华学堂,1925年更名为清华学校,1937年抗日战争全面爆发后南迁长沙,与北京大学、南开大学组建国立长沙临时大学,1938年迁至 昆明改名为国立西南联合大学,1946年迁回北京。新中国成立后,清华学校更名为清华大学。' 149 | ] 150 | ] 151 | for (do_sample, max_length, num_beams), expected_output_sentence in zip(parameters, expected_out_sentences): 152 | set_random_seed(42) 153 | inputs = tokenizer(sentences, return_tensors="pt", padding=True) 154 | inputs = inputs.to(torch_device) 155 | 156 | outputs = model.generate( 157 | **inputs, 158 | do_sample=do_sample, 159 | max_length=max_length, 160 | num_beams=num_beams 161 | ) 162 | 163 | batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) 164 | print(batch_out_sentence) 165 | self.assertListEqual(expected_output_sentence, batch_out_sentence) 166 | -------------------------------------------------------------------------------- /chatglm/tokenization_chatglm.py: -------------------------------------------------------------------------------- 1 | """Tokenization classes for ChatGLM.""" 2 | from typing import List, Optional, Union 3 | import os 4 | 5 | from transformers.tokenization_utils import PreTrainedTokenizer 6 | from transformers.utils import logging, PaddingStrategy 7 | from transformers.tokenization_utils_base import EncodedInput, BatchEncoding 8 | from typing import Dict 9 | import sentencepiece as spm 10 | import numpy as np 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 15 | "THUDM/chatglm-6b": 2048, 16 | } 17 | 18 | 19 | class TextTokenizer: 20 | def __init__(self, model_path): 21 | self.sp = spm.SentencePieceProcessor() 22 | self.sp.Load(model_path) 23 | self.num_tokens = self.sp.vocab_size() 24 | 25 | def encode(self, text): 26 | return self.sp.EncodeAsIds(text) 27 | 28 | def decode(self, ids: List[int]): 29 | return self.sp.DecodeIds(ids) 30 | 31 | def tokenize(self, text): 32 | return self.sp.EncodeAsPieces(text) 33 | 34 | def convert_tokens_to_string(self, tokens): 35 | return self.sp.DecodePieces(tokens) 36 | 37 | def convert_tokens_to_ids(self, tokens): 38 | return [self.sp.PieceToId(token) for token in tokens] 39 | 40 | def convert_token_to_id(self, token): 41 | return self.sp.PieceToId(token) 42 | 43 | def convert_id_to_token(self, idx): 44 | return self.sp.IdToPiece(idx) 45 | 46 | def __len__(self): 47 | return self.num_tokens 48 | 49 | 50 | class SPTokenizer: 51 | def __init__( 52 | self, 53 | vocab_file, 54 | num_image_tokens=20000, 55 | max_blank_length=80, 56 | byte_fallback=True, 57 | ): 58 | assert vocab_file is not None 59 | self.vocab_file = vocab_file 60 | self.num_image_tokens = num_image_tokens 61 | self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] 62 | self.max_blank_length = max_blank_length 63 | self.byte_fallback = byte_fallback 64 | self.text_tokenizer = TextTokenizer(vocab_file) 65 | 66 | def _get_text_tokenizer(self): 67 | return self.text_tokenizer 68 | 69 | @staticmethod 70 | def get_blank_token(length: int): 71 | assert length >= 2 72 | return f"<|blank_{length}|>" 73 | 74 | @staticmethod 75 | def get_tab_token(): 76 | return f"<|tab|>" 77 | 78 | @property 79 | def num_text_tokens(self): 80 | return self.text_tokenizer.num_tokens 81 | 82 | @property 83 | def num_tokens(self): 84 | return self.num_image_tokens + self.num_text_tokens 85 | 86 | @staticmethod 87 | def _encode_whitespaces(text: str, max_len: int = 80): 88 | text = text.replace("\t", SPTokenizer.get_tab_token()) 89 | for i in range(max_len, 1, -1): 90 | text = text.replace(" " * i, 
SPTokenizer.get_blank_token(i)) 91 | return text 92 | 93 | def _preprocess(self, text: str, linebreak=True, whitespaces=True): 94 | if linebreak: 95 | text = text.replace("\n", "<n>") 96 | if whitespaces: 97 | text = self._encode_whitespaces(text, max_len=self.max_blank_length) 98 | return text 99 | 100 | def encode( 101 | self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True 102 | ) -> List[int]: 103 | """ 104 | @param text: Text to encode. 105 | @param linebreak: Whether to encode newline (\n) in text. 106 | @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. 107 | @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. 108 | @param add_dummy_prefix: Whether to add dummy blank space in the beginning. 109 | """ 110 | text = self._preprocess(text, linebreak, whitespaces) 111 | if not add_dummy_prefix: 112 | text = "<n>" + text 113 | tmp = self._get_text_tokenizer().encode(text) 114 | tokens = [x + self.num_image_tokens for x in tmp] 115 | return tokens if add_dummy_prefix else tokens[2:] 116 | 117 | def postprocess(self, text): 118 | text = text.replace("<n>", "\n") 119 | text = text.replace(SPTokenizer.get_tab_token(), "\t") 120 | for i in range(2, self.max_blank_length + 1): 121 | text = text.replace(self.get_blank_token(i), " " * i) 122 | return text 123 | 124 | def decode(self, text_ids: List[int]) -> str: 125 | ids = [int(_id) - self.num_image_tokens for _id in text_ids] 126 | ids = [_id for _id in ids if _id >= 0] 127 | text = self._get_text_tokenizer().decode(ids) 128 | text = self.postprocess(text) 129 | return text 130 | 131 | def decode_tokens(self, tokens: List[str]) -> str: 132 | text = self._get_text_tokenizer().convert_tokens_to_string(tokens) 133 | text = self.postprocess(text) 134 | return text 135 | 136 | def tokenize( 137 | self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True 138 | ) -> List[str]: 139 | """ 140 | @param text: Text to encode. 141 | @param linebreak: Whether to encode newline (\n) in text. 142 | @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding. 143 | @param special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text. 144 | @param add_dummy_prefix: Whether to add dummy blank space in the beginning. 145 | """ 146 | text = self._preprocess(text, linebreak, whitespaces) 147 | if not add_dummy_prefix: 148 | text = "<n>" + text 149 | tokens = self._get_text_tokenizer().tokenize(text) 150 | return tokens if add_dummy_prefix else tokens[2:] 151 | 152 | def __getitem__(self, x: Union[int, str]): 153 | if isinstance(x, int): 154 | if x < self.num_image_tokens: 155 | return "<image_{}>".format(x) 156 | else: 157 | return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens) 158 | elif isinstance(x, str): 159 | if x.startswith("<image_") and x[7:-1].isdigit(): 160 | return int(x[7:-1]) 161 | else: 162 | return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens 163 | else: 164 | raise ValueError("The key should be str or int.") 165 | 166 | 167 | class ChatGLMTokenizer(PreTrainedTokenizer): 168 | """ 169 | Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. 170 | 171 | Args: 172 | vocab_file (`str`): 173 | Path to the vocabulary file.
174 | """ 175 | 176 | vocab_files_names = {"vocab_file": "ice_text.model"} 177 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 178 | model_input_names = ["input_ids", "attention_mask", "position_ids"] 179 | 180 | def __init__( 181 | self, 182 | vocab_file, 183 | do_lower_case=False, 184 | remove_space=False, 185 | bos_token='', 186 | eos_token='', 187 | end_token='', 188 | mask_token='[MASK]', 189 | gmask_token='[gMASK]', 190 | padding_side="left", 191 | pad_token="", 192 | unk_token="", 193 | num_image_tokens=20000, 194 | **kwargs 195 | ) -> None: 196 | super().__init__( 197 | do_lower_case=do_lower_case, 198 | remove_space=remove_space, 199 | padding_side=padding_side, 200 | bos_token=bos_token, 201 | eos_token=eos_token, 202 | end_token=end_token, 203 | mask_token=mask_token, 204 | gmask_token=gmask_token, 205 | pad_token=pad_token, 206 | unk_token=unk_token, 207 | num_image_tokens=num_image_tokens, 208 | **kwargs 209 | ) 210 | 211 | self.do_lower_case = do_lower_case 212 | self.remove_space = remove_space 213 | self.vocab_file = vocab_file 214 | 215 | self.bos_token = bos_token 216 | self.eos_token = eos_token 217 | self.end_token = end_token 218 | self.mask_token = mask_token 219 | self.gmask_token = gmask_token 220 | 221 | self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) 222 | 223 | """ Initialisation """ 224 | 225 | @property 226 | def gmask_token_id(self) -> Optional[int]: 227 | if self.gmask_token is None: 228 | return None 229 | return self.convert_tokens_to_ids(self.gmask_token) 230 | 231 | @property 232 | def end_token_id(self) -> Optional[int]: 233 | """ 234 | `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been 235 | set. 236 | """ 237 | if self.end_token is None: 238 | return None 239 | return self.convert_tokens_to_ids(self.end_token) 240 | 241 | @property 242 | def vocab_size(self): 243 | """ Returns vocab size """ 244 | return self.sp_tokenizer.num_tokens 245 | 246 | def get_vocab(self): 247 | """ Returns vocab as a dict """ 248 | vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} 249 | vocab.update(self.added_tokens_encoder) 250 | return vocab 251 | 252 | def preprocess_text(self, inputs): 253 | if self.remove_space: 254 | outputs = " ".join(inputs.strip().split()) 255 | else: 256 | outputs = inputs 257 | 258 | if self.do_lower_case: 259 | outputs = outputs.lower() 260 | 261 | return outputs 262 | 263 | def _tokenize(self, text, **kwargs): 264 | """ Returns a tokenized string. """ 265 | text = self.preprocess_text(text) 266 | 267 | seq = self.sp_tokenizer.tokenize(text) 268 | 269 | return seq 270 | 271 | def convert_tokens_to_string(self, tokens: List[str]) -> str: 272 | return self.sp_tokenizer.decode_tokens(tokens) 273 | 274 | def _decode( 275 | self, 276 | token_ids: Union[int, List[int]], 277 | **kwargs 278 | ) -> str: 279 | if isinstance(token_ids, int): 280 | token_ids = [token_ids] 281 | if len(token_ids) == 0: 282 | return "" 283 | if self.pad_token_id in token_ids: # remove pad 284 | token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) 285 | return super()._decode(token_ids, **kwargs) 286 | 287 | def _convert_token_to_id(self, token): 288 | """ Converts a token (str) in an id using the vocab. 
""" 289 | return self.sp_tokenizer[token] 290 | 291 | def _convert_id_to_token(self, index): 292 | """Converts an index (integer) in a token (str) using the vocab.""" 293 | return self.sp_tokenizer[index] 294 | 295 | def save_vocabulary(self, save_directory, filename_prefix=None): 296 | """ 297 | Save the vocabulary and special tokens file to a directory. 298 | 299 | Args: 300 | save_directory (`str`): 301 | The directory in which to save the vocabulary. 302 | filename_prefix (`str`, *optional*): 303 | An optional prefix to add to the named of the saved files. 304 | 305 | Returns: 306 | `Tuple(str)`: Paths to the files saved. 307 | """ 308 | if os.path.isdir(save_directory): 309 | vocab_file = os.path.join( 310 | save_directory, self.vocab_files_names["vocab_file"] 311 | ) 312 | else: 313 | vocab_file = save_directory 314 | 315 | with open(self.vocab_file, 'rb') as fin: 316 | proto_str = fin.read() 317 | 318 | with open(vocab_file, "wb") as writer: 319 | writer.write(proto_str) 320 | 321 | return (vocab_file,) 322 | 323 | def build_inputs_with_special_tokens( 324 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 325 | ) -> List[int]: 326 | """ 327 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 328 | adding special tokens. A BERT sequence has the following format: 329 | 330 | - single sequence: `[CLS] X [SEP]` 331 | - pair of sequences: `[CLS] A [SEP] B [SEP]` 332 | 333 | Args: 334 | token_ids_0 (`List[int]`): 335 | List of IDs to which the special tokens will be added. 336 | token_ids_1 (`List[int]`, *optional*): 337 | Optional second list of IDs for sequence pairs. 338 | 339 | Returns: 340 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 341 | """ 342 | gmask_id = self.sp_tokenizer[self.gmask_token] 343 | eos_id = self.sp_tokenizer[self.eos_token] 344 | token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] 345 | if token_ids_1 is not None: 346 | token_ids_0 = token_ids_0 + token_ids_1 + [eos_id] 347 | return token_ids_0 348 | 349 | def _pad( 350 | self, 351 | encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], 352 | max_length: Optional[int] = None, 353 | padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, 354 | pad_to_multiple_of: Optional[int] = None, 355 | return_attention_mask: Optional[bool] = None, 356 | ) -> dict: 357 | """ 358 | Pad encoded inputs (on left/right and up to predefined length or max length in the batch) 359 | 360 | Args: 361 | encoded_inputs: 362 | Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). 363 | max_length: maximum length of the returned list and optionally padding length (see below). 364 | Will truncate by taking into account the special tokens. 365 | padding_strategy: PaddingStrategy to use for padding. 366 | 367 | - PaddingStrategy.LONGEST Pad to the longest sequence in the batch 368 | - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) 369 | - PaddingStrategy.DO_NOT_PAD: Do not pad 370 | The tokenizer padding sides are defined in self.padding_side: 371 | 372 | - 'left': pads on the left of the sequences 373 | - 'right': pads on the right of the sequences 374 | pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. 375 | This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability 376 | `>= 7.5` (Volta). 
377 | return_attention_mask: 378 | (optional) Set to False to avoid returning attention mask (default: set to model specifics) 379 | """ 380 | # Load from model defaults 381 | bos_token_id = self.sp_tokenizer[self.bos_token] 382 | mask_token_id = self.sp_tokenizer[self.mask_token] 383 | gmask_token_id = self.sp_tokenizer[self.gmask_token] 384 | assert self.padding_side == "left" 385 | 386 | required_input = encoded_inputs[self.model_input_names[0]] 387 | seq_length = len(required_input) 388 | 389 | if padding_strategy == PaddingStrategy.LONGEST: 390 | max_length = len(required_input) 391 | 392 | if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 393 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 394 | 395 | needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length 396 | 397 | # Initialize attention mask if not present. 398 | if max_length is not None: 399 | if "attention_mask" not in encoded_inputs: 400 | if bos_token_id in required_input: 401 | context_length = required_input.index(bos_token_id) 402 | else: 403 | context_length = seq_length 404 | attention_mask = np.ones((1, seq_length, seq_length)) 405 | attention_mask = np.tril(attention_mask) 406 | attention_mask[:, :, :context_length] = 1 407 | attention_mask = np.bool_(attention_mask < 0.5) 408 | encoded_inputs["attention_mask"] = attention_mask 409 | 410 | if "position_ids" not in encoded_inputs: 411 | if bos_token_id in required_input: 412 | context_length = required_input.index(bos_token_id) 413 | else: 414 | context_length = seq_length 415 | position_ids = np.arange(seq_length, dtype=np.int64) 416 | mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id 417 | if mask_token in required_input: 418 | mask_position = required_input.index(mask_token) 419 | position_ids[context_length:] = mask_position 420 | block_position_ids = np.concatenate( 421 | [np.zeros(context_length, dtype=np.int64), 422 | np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) 423 | encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) 424 | 425 | if needs_to_be_padded: 426 | difference = max_length - len(required_input) 427 | 428 | if "attention_mask" in encoded_inputs: 429 | encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], 430 | pad_width=[(0, 0), (difference, 0), (difference, 0)], 431 | mode='constant', constant_values=True) 432 | if "token_type_ids" in encoded_inputs: 433 | encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ 434 | "token_type_ids" 435 | ] 436 | if "special_tokens_mask" in encoded_inputs: 437 | encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] 438 | if "position_ids" in encoded_inputs: 439 | encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], 440 | pad_width=[(0, 0), (difference, 0)]) 441 | encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input 442 | 443 | return encoded_inputs 444 | -------------------------------------------------------------------------------- /chatglm/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name_or_path": "THUDM/chatglm-6b", 3 | "bos_token": "", 4 | "eos_token": "", 5 | "end_token": "", 6 | "gmask_token": "[gMASK]", 7 | "mask_token": "[MASK]", 8 | "pad_token": "", 9 | "unk_token": 
"", 10 | "remove_space": false, 11 | "do_lower_case": false, 12 | "tokenizer_class": "ChatGLMTokenizer", 13 | "num_image_tokens": 0, 14 | "auto_map": { 15 | "AutoTokenizer": [ 16 | "tokenization_chatglm.ChatGLMTokenizer", 17 | null 18 | ] 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /load_and_predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from chatglm.configuration_chatglm import ChatGLMConfig 4 | from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration 5 | from chatglm.tokenization_chatglm import ChatGLMTokenizer 6 | 7 | 8 | if __name__ == '__main__': 9 | # 初始化模型 10 | config = ChatGLMConfig.from_pretrained("chatglm/config.json") 11 | model = ChatGLMForConditionalGeneration(config=config, empty_init=False).bfloat16() 12 | tokenizer = ChatGLMTokenizer.from_pretrained("chatglm") 13 | 14 | # 加载权重 15 | model.load_state_dict(torch.load("model/model.weights")) 16 | 17 | # 准备测试数据 18 | test_data = [ 19 | ('你好', 'hello'), 20 | ('苹果', 'apple') 21 | ] 22 | 23 | # 测试生成效果 24 | model = model.eval() 25 | for query, _ in test_data: 26 | response, history = model.chat(tokenizer, query, history=[], max_length=20) 27 | print(query, response) 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /model/model.weights: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinsblog/chatglm-tiny/5ac7766647cba6883576e5313cb2c61fb94d2e0f/model/model.weights -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf 2 | transformers==4.27.1 3 | cpm_kernels 4 | torch>=1.10 5 | gradio 6 | mdtex2html 7 | sentencepiece 8 | accelerate 9 | datasets 10 | -------------------------------------------------------------------------------- /train_and_save.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data import Dataset 4 | 5 | from chatglm.configuration_chatglm import ChatGLMConfig 6 | from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration 7 | from chatglm.tokenization_chatglm import ChatGLMTokenizer 8 | 9 | global tokenizer, config, model 10 | 11 | 12 | class TrainDataset(Dataset): 13 | 14 | def __init__(self, train_data, max_seq_length=512): 15 | super().__init__() 16 | self.data = train_data 17 | self.max_seq_length = max_seq_length 18 | 19 | def __getitem__(self, index): 20 | context, target = self.data[index] 21 | context_ids = tokenizer.encode(context, max_length=self.max_seq_length, truncation=True) 22 | target_ids = tokenizer.encode(target, max_length=self.max_seq_length, truncation=True, add_special_tokens=False) 23 | input_ids = context_ids + target_ids + [tokenizer.eos_token_id] 24 | return {'input_ids': input_ids, 'context_len': len(context_ids)} 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | 30 | def collate_fn(batch): 31 | len_ids = [len(d["input_ids"]) for d in batch] 32 | longest = max(len_ids) # 之后按照batch中最长的input_ids进行padding 33 | 34 | input_ids = [] 35 | labels_list = [] 36 | 37 | for length, d in sorted(zip(len_ids, batch), key=lambda x: -x[0]): 38 | ids = d["input_ids"] 39 | context_len = d["context_len"] 40 | 41 | labels = ( 42 | [-100] * (context_len - 1) + 
ids[(context_len - 1):] + [-100] * (longest - length) 43 | ) # -100标志位后面会在计算loss时会被忽略不贡献损失,我们集中优化target部分生成的loss 44 | 45 | ids = ids + [tokenizer.pad_token_id] * (longest - length) 46 | 47 | input_ids.append(torch.LongTensor(ids)) 48 | labels_list.append(torch.LongTensor(labels)) 49 | 50 | input_ids = torch.stack(input_ids) 51 | labels = torch.stack(labels_list) 52 | return { 53 | "input_ids": input_ids, 54 | "labels": labels, 55 | } 56 | 57 | 58 | if __name__ == '__main__': 59 | # 初始化模型 60 | config = ChatGLMConfig.from_pretrained("chatglm/config.json") 61 | model = ChatGLMForConditionalGeneration(config=config, empty_init=False).bfloat16() 62 | tokenizer = ChatGLMTokenizer.from_pretrained("chatglm") 63 | 64 | print(model) 65 | total_params = sum(p.numel() for p in model.parameters()) 66 | print(f'{total_params:,} total parameters.') 67 | total_trainable_params = sum( 68 | p.numel() for p in model.parameters() if p.requires_grad) 69 | print(f'{total_trainable_params:,} training parameters.') 70 | 71 | # 准备训练数据 72 | train_data = [ 73 | ('你好', 'hello'), 74 | ('苹果', 'apple') 75 | ] 76 | train_dataset = TrainDataset(train_data=train_data) 77 | train_dataloader = DataLoader(dataset=train_dataset, collate_fn=collate_fn, shuffle=True, batch_size=2) 78 | 79 | # 开始训练 80 | LR = 1e-2 81 | NUM_EPOCHS = 100 82 | 83 | optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01) 84 | 85 | model.train() 86 | 87 | for epoch in range(NUM_EPOCHS): 88 | model.train() 89 | for step, batch in enumerate(train_dataloader): 90 | batch = {k: v for k, v in batch.items()} 91 | outputs = model(**batch) 92 | loss = outputs.loss.detach().float() 93 | print(f"epoch={epoch}, step={step}, loss={loss}") 94 | outputs.loss.backward() 95 | optimizer.step() 96 | optimizer.zero_grad() 97 | 98 | # 保存权重 99 | torch.save(model.state_dict(), "model/model.weights") 100 | --------------------------------------------------------------------------------
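
A note on how the pieces above fit together: `collate_fn` in train_and_save.py fills each label sequence with `-100` for the first `context_len - 1` positions and for the right padding, and `ChatGLMForConditionalGeneration` then shifts logits and labels by one position before computing `CrossEntropyLoss(ignore_index=-100)`, so the model is only scored on predicting the target tokens and the closing EOS. The snippet below is a minimal, self-contained sketch of that bookkeeping; the tensor sizes and label values are made up purely for illustration and do not come from the repository.

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy shapes for illustration only: batch=1, seq_len=5, vocab=7.
lm_logits = torch.randn(1, 5, 7)

# Suppose the first two positions are context and the last one is padding:
# collate_fn would emit -100 for the masked positions and real ids elsewhere.
labels = torch.tensor([[-100, 4, 2, 6, -100]])

# Same shift as in the model's forward pass: position t predicts token t+1.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Positions labelled -100 are dropped from the average, so context and padding
# contribute no gradient; only the target span (including EOS) is optimized.
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss)
```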