├── LICENSE ├── README.md ├── constants.py ├── convert.py ├── generate.py ├── gguf_reader.py ├── gpt2_tokenizer.py ├── llamacpp_kernel.cu ├── model.py ├── py_bind.cpp ├── register_lib.py ├── requirements.txt ├── setup.py └── tp.py /LICENSE: -------------------------------------------------------------------------------- 1 | # Copied from gpt-fast repo 2 | Copyright 2023 Meta 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convert [`llama.cpp`](https://github.com/ggerganov/llama.cpp) to Pytorch 2 | 3 | The `llama.cpp` library is a cornerstone in language modeling with a variety of quantization techniques, but it's largely used within its own ecosystem. This repo's aim is to make these methods more accessible to the PyTorch community. 4 | 5 | This repo provides an example for converting GGUF files back into PyTorch state dict, allowing you to run inference purely in PyTorch. Currently supported models: 6 | 7 | * LLaMA / Mistral 8 | * Mixtral 9 | * Qwen / Qwen2 10 | * InternLM2 11 | * StarCoder2 12 | * Orion 13 | * MiniCPM 14 | * Xverse 15 | * Command-r-v01 16 | * StableLM 17 | * Gemma 18 | 19 | The code is largely inspired by the original [`llama.cpp`](https://github.com/ggerganov/llama.cpp) and [`GPT-Fast`](https://github.com/pytorch-labs/gpt-fast). 20 | 21 | ## Getting Started 22 | 23 | * Install the CUDA extension 24 | 25 | ```bash 26 | python setup.py install 27 | ``` 28 | 29 | * Convert GGUF file to torch state dict 30 | 31 | ```bash 32 | python convert.py --input tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --output TinyLlama-Q4_K_M 33 | ``` 34 | 35 | * Running inference 36 | 37 | ```bash 38 | python generate.py --checkpoint_path TinyLlama-Q4_K_M --interactive --compile 39 | ``` 40 | 41 | `torch.compile` will take minutes, you can also run in eager mode without `--compile` flag. 
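The converted checkpoint is also easy to inspect directly from Python: the output directory is just a regular torch state dict plus a `config.json`. A minimal sketch, assuming the example directory from the convert step above and a LLaMA-family GGUF (run it from the repo root so `constants.py` is importable):

```python
import json
import torch

from constants import GGMLQuantizationType

ckpt_dir = "TinyLlama-Q4_K_M"   # output directory of the convert step above

config = json.load(open(f"{ckpt_dir}/config.json"))
state_dict = torch.load(f"{ckpt_dir}/pytorch_model.bin", mmap=True, weights_only=True)

print(config["architecture"], config["n_layer"], config["dim"])

# Quantized weights keep their raw GGUF block layout (uint8); each one also has a
# companion "...weight_type" tensor holding the GGMLQuantizationType of those blocks.
w = state_dict["blk.0.attn_q.weight"]               # key assumes a LLaMA-family GGUF
w_type = int(state_dict["blk.0.attn_q.weight_type"])
print(w.dtype, tuple(w.shape), GGMLQuantizationType(w_type).name)
```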
42 | 43 | 44 | ## Todo 45 | * Add support to more model 46 | * Support partitioned model 47 | * Support new MoE breaking change -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # Modifed from https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/constants.py 2 | # The official constants.py lacks support for newly added quants as of 2024/03 3 | # So I copied it with some modifications 4 | 5 | from __future__ import annotations 6 | 7 | import sys 8 | from enum import Enum, IntEnum, auto 9 | from typing import Any 10 | 11 | # 12 | # constants 13 | # 14 | 15 | GGUF_MAGIC = 0x46554747 # "GGUF" 16 | GGUF_VERSION = 3 17 | GGUF_DEFAULT_ALIGNMENT = 32 18 | 19 | # 20 | # metadata keys 21 | # 22 | 23 | 24 | class Keys: 25 | class General: 26 | ARCHITECTURE = "general.architecture" 27 | QUANTIZATION_VERSION = "general.quantization_version" 28 | ALIGNMENT = "general.alignment" 29 | NAME = "general.name" 30 | AUTHOR = "general.author" 31 | URL = "general.url" 32 | DESCRIPTION = "general.description" 33 | LICENSE = "general.license" 34 | SOURCE_URL = "general.source.url" 35 | SOURCE_HF_REPO = "general.source.huggingface.repository" 36 | FILE_TYPE = "general.file_type" 37 | 38 | class LLM: 39 | CONTEXT_LENGTH = "{arch}.context_length" 40 | EMBEDDING_LENGTH = "{arch}.embedding_length" 41 | BLOCK_COUNT = "{arch}.block_count" 42 | FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" 43 | USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" 44 | TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" 45 | 46 | class Attention: 47 | HEAD_COUNT = "{arch}.attention.head_count" 48 | HEAD_COUNT_KV = "{arch}.attention.head_count_kv" 49 | MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" 50 | CLAMP_KQV = "{arch}.attention.clamp_kqv" 51 | LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" 52 | LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" 53 | 54 | class Rope: 55 | DIMENSION_COUNT = "{arch}.rope.dimension_count" 56 | FREQ_BASE = "{arch}.rope.freq_base" 57 | SCALING_TYPE = "{arch}.rope.scaling.type" 58 | SCALING_FACTOR = "{arch}.rope.scaling.factor" 59 | SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" 60 | SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" 61 | 62 | class Tokenizer: 63 | MODEL = "tokenizer.ggml.model" 64 | LIST = "tokenizer.ggml.tokens" 65 | TOKEN_TYPE = "tokenizer.ggml.token_type" 66 | SCORES = "tokenizer.ggml.scores" 67 | MERGES = "tokenizer.ggml.merges" 68 | BOS_ID = "tokenizer.ggml.bos_token_id" 69 | EOS_ID = "tokenizer.ggml.eos_token_id" 70 | UNK_ID = "tokenizer.ggml.unknown_token_id" 71 | SEP_ID = "tokenizer.ggml.seperator_token_id" 72 | PAD_ID = "tokenizer.ggml.padding_token_id" 73 | ADD_BOS = "tokenizer.ggml.add_bos_token" 74 | ADD_EOS = "tokenizer.ggml.add_eos_token" 75 | HF_JSON = "tokenizer.huggingface.json" 76 | RWKV = "tokenizer.rwkv.world" 77 | CHAT_TEMPLATE = "tokenizer.chat_template" 78 | 79 | 80 | # 81 | # recommended mapping of model tensor names for storage in gguf 82 | # 83 | 84 | 85 | class MODEL_ARCH(IntEnum): 86 | LLAMA = auto() 87 | FALCON = auto() 88 | BAICHUAN = auto() 89 | GPT2 = auto() 90 | GPTJ = auto() 91 | GPTNEOX = auto() 92 | MPT = auto() 93 | STARCODER = auto() 94 | PERSIMMON = auto() 95 | REFACT = auto() 96 | BERT = auto() 97 | BLOOM = auto() 98 | STABLELM = auto() 99 | QWEN = auto() 100 | 101 | 102 | class MODEL_TENSOR(IntEnum): 103 | TOKEN_EMBD = auto() 104 | TOKEN_EMBD_NORM = auto() 105 | 
TOKEN_TYPES = auto() 106 | POS_EMBD = auto() 107 | OUTPUT = auto() 108 | OUTPUT_NORM = auto() 109 | ROPE_FREQS = auto() 110 | ATTN_Q = auto() 111 | ATTN_K = auto() 112 | ATTN_V = auto() 113 | ATTN_QKV = auto() 114 | ATTN_OUT = auto() 115 | ATTN_NORM = auto() 116 | ATTN_NORM_2 = auto() 117 | ATTN_ROT_EMBD = auto() 118 | FFN_GATE = auto() 119 | FFN_DOWN = auto() 120 | FFN_UP = auto() 121 | FFN_NORM = auto() 122 | ATTN_Q_NORM = auto() 123 | ATTN_K_NORM = auto() 124 | 125 | 126 | MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { 127 | MODEL_ARCH.LLAMA: "llama", 128 | MODEL_ARCH.FALCON: "falcon", 129 | MODEL_ARCH.BAICHUAN: "baichuan", 130 | MODEL_ARCH.GPT2: "gpt2", 131 | MODEL_ARCH.GPTJ: "gptj", 132 | MODEL_ARCH.GPTNEOX: "gptneox", 133 | MODEL_ARCH.MPT: "mpt", 134 | MODEL_ARCH.STARCODER: "starcoder", 135 | MODEL_ARCH.PERSIMMON: "persimmon", 136 | MODEL_ARCH.REFACT: "refact", 137 | MODEL_ARCH.BERT: "bert", 138 | MODEL_ARCH.BLOOM: "bloom", 139 | MODEL_ARCH.STABLELM: "stablelm", 140 | MODEL_ARCH.QWEN: "qwen", 141 | } 142 | 143 | TENSOR_NAMES: dict[MODEL_TENSOR, str] = { 144 | MODEL_TENSOR.TOKEN_EMBD: "token_embd", 145 | MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", 146 | MODEL_TENSOR.TOKEN_TYPES: "token_types", 147 | MODEL_TENSOR.POS_EMBD: "position_embd", 148 | MODEL_TENSOR.OUTPUT_NORM: "output_norm", 149 | MODEL_TENSOR.OUTPUT: "output", 150 | MODEL_TENSOR.ROPE_FREQS: "rope_freqs", 151 | MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", 152 | MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", 153 | MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", 154 | MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", 155 | MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", 156 | MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", 157 | MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", 158 | MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", 159 | MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", 160 | MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", 161 | MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", 162 | MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", 163 | MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", 164 | MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", 165 | } 166 | 167 | MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { 168 | MODEL_ARCH.LLAMA: [ 169 | MODEL_TENSOR.TOKEN_EMBD, 170 | MODEL_TENSOR.OUTPUT_NORM, 171 | MODEL_TENSOR.OUTPUT, 172 | MODEL_TENSOR.ROPE_FREQS, 173 | MODEL_TENSOR.ATTN_NORM, 174 | MODEL_TENSOR.ATTN_Q, 175 | MODEL_TENSOR.ATTN_K, 176 | MODEL_TENSOR.ATTN_V, 177 | MODEL_TENSOR.ATTN_OUT, 178 | MODEL_TENSOR.ATTN_ROT_EMBD, 179 | MODEL_TENSOR.FFN_NORM, 180 | MODEL_TENSOR.FFN_GATE, 181 | MODEL_TENSOR.FFN_DOWN, 182 | MODEL_TENSOR.FFN_UP, 183 | ], 184 | MODEL_ARCH.GPTNEOX: [ 185 | MODEL_TENSOR.TOKEN_EMBD, 186 | MODEL_TENSOR.OUTPUT_NORM, 187 | MODEL_TENSOR.OUTPUT, 188 | MODEL_TENSOR.ATTN_NORM, 189 | MODEL_TENSOR.ATTN_QKV, 190 | MODEL_TENSOR.ATTN_OUT, 191 | MODEL_TENSOR.FFN_NORM, 192 | MODEL_TENSOR.FFN_DOWN, 193 | MODEL_TENSOR.FFN_UP, 194 | ], 195 | MODEL_ARCH.FALCON: [ 196 | MODEL_TENSOR.TOKEN_EMBD, 197 | MODEL_TENSOR.OUTPUT_NORM, 198 | MODEL_TENSOR.OUTPUT, 199 | MODEL_TENSOR.ATTN_NORM, 200 | MODEL_TENSOR.ATTN_NORM_2, 201 | MODEL_TENSOR.ATTN_QKV, 202 | MODEL_TENSOR.ATTN_OUT, 203 | MODEL_TENSOR.FFN_DOWN, 204 | MODEL_TENSOR.FFN_UP, 205 | ], 206 | MODEL_ARCH.BAICHUAN: [ 207 | MODEL_TENSOR.TOKEN_EMBD, 208 | MODEL_TENSOR.OUTPUT_NORM, 209 | MODEL_TENSOR.OUTPUT, 210 | MODEL_TENSOR.ROPE_FREQS, 211 | MODEL_TENSOR.ATTN_NORM, 212 | MODEL_TENSOR.ATTN_Q, 213 | MODEL_TENSOR.ATTN_K, 214 | MODEL_TENSOR.ATTN_V, 215 | MODEL_TENSOR.ATTN_OUT, 216 | 
MODEL_TENSOR.ATTN_ROT_EMBD, 217 | MODEL_TENSOR.FFN_NORM, 218 | MODEL_TENSOR.FFN_GATE, 219 | MODEL_TENSOR.FFN_DOWN, 220 | MODEL_TENSOR.FFN_UP, 221 | ], 222 | MODEL_ARCH.STARCODER: [ 223 | MODEL_TENSOR.TOKEN_EMBD, 224 | MODEL_TENSOR.POS_EMBD, 225 | MODEL_TENSOR.OUTPUT_NORM, 226 | MODEL_TENSOR.OUTPUT, 227 | MODEL_TENSOR.ATTN_NORM, 228 | MODEL_TENSOR.ATTN_QKV, 229 | MODEL_TENSOR.ATTN_OUT, 230 | MODEL_TENSOR.FFN_NORM, 231 | MODEL_TENSOR.FFN_DOWN, 232 | MODEL_TENSOR.FFN_UP, 233 | ], 234 | MODEL_ARCH.BERT: [ 235 | MODEL_TENSOR.TOKEN_EMBD, 236 | MODEL_TENSOR.TOKEN_TYPES, 237 | MODEL_TENSOR.POS_EMBD, 238 | MODEL_TENSOR.OUTPUT_NORM, 239 | MODEL_TENSOR.ATTN_NORM, 240 | MODEL_TENSOR.ATTN_Q, 241 | MODEL_TENSOR.ATTN_K, 242 | MODEL_TENSOR.ATTN_V, 243 | MODEL_TENSOR.ATTN_OUT, 244 | MODEL_TENSOR.FFN_NORM, 245 | MODEL_TENSOR.FFN_DOWN, 246 | MODEL_TENSOR.FFN_UP, 247 | ], 248 | MODEL_ARCH.MPT: [ 249 | MODEL_TENSOR.TOKEN_EMBD, 250 | MODEL_TENSOR.OUTPUT_NORM, 251 | MODEL_TENSOR.OUTPUT, 252 | MODEL_TENSOR.ATTN_NORM, 253 | MODEL_TENSOR.ATTN_QKV, 254 | MODEL_TENSOR.ATTN_OUT, 255 | MODEL_TENSOR.FFN_NORM, 256 | MODEL_TENSOR.FFN_DOWN, 257 | MODEL_TENSOR.FFN_UP, 258 | ], 259 | MODEL_ARCH.GPTJ: [ 260 | MODEL_TENSOR.TOKEN_EMBD, 261 | MODEL_TENSOR.OUTPUT_NORM, 262 | MODEL_TENSOR.OUTPUT, 263 | MODEL_TENSOR.ATTN_NORM, 264 | MODEL_TENSOR.ATTN_Q, 265 | MODEL_TENSOR.ATTN_K, 266 | MODEL_TENSOR.ATTN_V, 267 | MODEL_TENSOR.ATTN_OUT, 268 | MODEL_TENSOR.FFN_DOWN, 269 | MODEL_TENSOR.FFN_UP, 270 | ], 271 | MODEL_ARCH.PERSIMMON: [ 272 | MODEL_TENSOR.TOKEN_EMBD, 273 | MODEL_TENSOR.OUTPUT, 274 | MODEL_TENSOR.OUTPUT_NORM, 275 | MODEL_TENSOR.ATTN_NORM, 276 | MODEL_TENSOR.ATTN_QKV, 277 | MODEL_TENSOR.ATTN_OUT, 278 | MODEL_TENSOR.FFN_NORM, 279 | MODEL_TENSOR.FFN_DOWN, 280 | MODEL_TENSOR.FFN_UP, 281 | MODEL_TENSOR.ATTN_Q_NORM, 282 | MODEL_TENSOR.ATTN_K_NORM, 283 | MODEL_TENSOR.ATTN_ROT_EMBD, 284 | ], 285 | MODEL_ARCH.REFACT: [ 286 | MODEL_TENSOR.TOKEN_EMBD, 287 | MODEL_TENSOR.OUTPUT_NORM, 288 | MODEL_TENSOR.OUTPUT, 289 | MODEL_TENSOR.ATTN_NORM, 290 | MODEL_TENSOR.ATTN_Q, 291 | MODEL_TENSOR.ATTN_K, 292 | MODEL_TENSOR.ATTN_V, 293 | MODEL_TENSOR.ATTN_OUT, 294 | MODEL_TENSOR.FFN_NORM, 295 | MODEL_TENSOR.FFN_GATE, 296 | MODEL_TENSOR.FFN_DOWN, 297 | MODEL_TENSOR.FFN_UP, 298 | ], 299 | MODEL_ARCH.BLOOM: [ 300 | MODEL_TENSOR.TOKEN_EMBD, 301 | MODEL_TENSOR.TOKEN_EMBD_NORM, 302 | MODEL_TENSOR.OUTPUT_NORM, 303 | MODEL_TENSOR.OUTPUT, 304 | MODEL_TENSOR.ATTN_NORM, 305 | MODEL_TENSOR.ATTN_QKV, 306 | MODEL_TENSOR.ATTN_OUT, 307 | MODEL_TENSOR.FFN_NORM, 308 | MODEL_TENSOR.FFN_DOWN, 309 | MODEL_TENSOR.FFN_UP, 310 | ], 311 | MODEL_ARCH.STABLELM: [ 312 | MODEL_TENSOR.TOKEN_EMBD, 313 | MODEL_TENSOR.OUTPUT_NORM, 314 | MODEL_TENSOR.OUTPUT, 315 | MODEL_TENSOR.ROPE_FREQS, 316 | MODEL_TENSOR.ATTN_NORM, 317 | MODEL_TENSOR.ATTN_Q, 318 | MODEL_TENSOR.ATTN_K, 319 | MODEL_TENSOR.ATTN_V, 320 | MODEL_TENSOR.ATTN_OUT, 321 | MODEL_TENSOR.FFN_NORM, 322 | MODEL_TENSOR.FFN_GATE, 323 | MODEL_TENSOR.FFN_DOWN, 324 | MODEL_TENSOR.FFN_UP, 325 | ], 326 | MODEL_ARCH.QWEN: [ 327 | MODEL_TENSOR.TOKEN_EMBD, 328 | MODEL_TENSOR.OUTPUT_NORM, 329 | MODEL_TENSOR.OUTPUT, 330 | MODEL_TENSOR.ROPE_FREQS, 331 | MODEL_TENSOR.ATTN_NORM, 332 | MODEL_TENSOR.ATTN_QKV, 333 | MODEL_TENSOR.ATTN_OUT, 334 | MODEL_TENSOR.ATTN_ROT_EMBD, 335 | MODEL_TENSOR.FFN_NORM, 336 | MODEL_TENSOR.FFN_GATE, 337 | MODEL_TENSOR.FFN_DOWN, 338 | MODEL_TENSOR.FFN_UP, 339 | ], 340 | MODEL_ARCH.GPT2: [ 341 | # TODO 342 | ], 343 | # TODO 344 | } 345 | 346 | # tensors that will not be serialized 347 | MODEL_TENSOR_SKIP: 
dict[MODEL_ARCH, list[MODEL_TENSOR]] = { 348 | MODEL_ARCH.LLAMA: [ 349 | MODEL_TENSOR.ROPE_FREQS, 350 | MODEL_TENSOR.ATTN_ROT_EMBD, 351 | ], 352 | MODEL_ARCH.BAICHUAN: [ 353 | MODEL_TENSOR.ROPE_FREQS, 354 | MODEL_TENSOR.ATTN_ROT_EMBD, 355 | ], 356 | MODEL_ARCH.PERSIMMON: [ 357 | MODEL_TENSOR.ROPE_FREQS, 358 | ], 359 | MODEL_ARCH.QWEN: [ 360 | MODEL_TENSOR.ROPE_FREQS, 361 | MODEL_TENSOR.ATTN_ROT_EMBD, 362 | ], 363 | } 364 | 365 | # 366 | # types 367 | # 368 | 369 | 370 | class TokenType(IntEnum): 371 | NORMAL = 1 372 | UNKNOWN = 2 373 | CONTROL = 3 374 | USER_DEFINED = 4 375 | UNUSED = 5 376 | BYTE = 6 377 | 378 | 379 | class RopeScalingType(Enum): 380 | NONE = 'none' 381 | LINEAR = 'linear' 382 | YARN = 'yarn' 383 | 384 | 385 | class GGMLQuantizationType(IntEnum): 386 | F32 = 0 387 | F16 = 1 388 | Q4_0 = 2 389 | Q4_1 = 3 390 | Q5_0 = 6 391 | Q5_1 = 7 392 | Q8_0 = 8 393 | Q8_1 = 9 394 | Q2_K = 10 395 | Q3_K = 11 396 | Q4_K = 12 397 | Q5_K = 13 398 | Q6_K = 14 399 | Q8_K = 15 400 | IQ2_XXS = 16 401 | IQ2_XS = 17 402 | IQ3_XXS = 18, 403 | IQ1_S = 19 404 | IQ4_NL = 20 405 | IQ3_S = 21 406 | IQ2_S = 22 407 | IQ4_XS = 23 408 | 409 | 410 | class GGUFEndian(IntEnum): 411 | LITTLE = 0 412 | BIG = 1 413 | 414 | 415 | class GGUFValueType(IntEnum): 416 | UINT8 = 0 417 | INT8 = 1 418 | UINT16 = 2 419 | INT16 = 3 420 | UINT32 = 4 421 | INT32 = 5 422 | FLOAT32 = 6 423 | BOOL = 7 424 | STRING = 8 425 | ARRAY = 9 426 | UINT64 = 10 427 | INT64 = 11 428 | FLOAT64 = 12 429 | 430 | @staticmethod 431 | def get_type(val: Any) -> GGUFValueType: 432 | if isinstance(val, (str, bytes, bytearray)): 433 | return GGUFValueType.STRING 434 | elif isinstance(val, list): 435 | return GGUFValueType.ARRAY 436 | elif isinstance(val, float): 437 | return GGUFValueType.FLOAT32 438 | elif isinstance(val, bool): 439 | return GGUFValueType.BOOL 440 | elif isinstance(val, int): 441 | return GGUFValueType.INT32 442 | # TODO: need help with 64-bit types in Python 443 | else: 444 | print("Unknown type:", type(val)) 445 | sys.exit() 446 | 447 | 448 | # Note: Does not support GGML_QKK_64 449 | QK_K = 256 450 | # Items here are (block size, type size) 451 | GGML_QUANT_SIZES = { 452 | GGMLQuantizationType.F32: (1, 4), 453 | GGMLQuantizationType.F16: (1, 2), 454 | GGMLQuantizationType.Q4_0: (32, 2 + 16), 455 | GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), 456 | GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), 457 | GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), 458 | GGMLQuantizationType.Q8_0: (32, 2 + 32), 459 | GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), 460 | GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), 461 | GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), 462 | GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), 463 | GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), 464 | GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), 465 | GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), 466 | GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), 467 | GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), 468 | GGMLQuantizationType.IQ3_XXS: (256, 2 + 3 * QK_K // 8), 469 | GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), 470 | GGMLQuantizationType.IQ4_NL: (32, 2 + 32 // 2), 471 | GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 32 + QK_K // 8 + QK_K // 64), 472 | GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 32 + QK_K // 32), 473 | GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 64 + QK_K // 2), 474 | } 475 | 
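For orientation, the `(block size, type size)` pairs above are exactly what `gguf_reader.py` uses to turn a tensor's element count into its byte count (`n_bytes = n_elements * type_size // block_size`). A small sketch with an illustrative 4096x4096 weight:

```python
from constants import GGML_QUANT_SIZES, GGMLQuantizationType

block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q4_K]  # (256, 144)

n_elements = 4096 * 4096                        # illustrative projection weight
n_bytes = n_elements * type_size // block_size  # 65,536 blocks of 144 bytes each
print(n_bytes)                                  # 9437184, i.e. ~9 MiB vs 32 MiB in fp16
```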
476 | 477 | # Aliases for backward compatibility. 478 | 479 | # general 480 | KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE 481 | KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION 482 | KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT 483 | KEY_GENERAL_NAME = Keys.General.NAME 484 | KEY_GENERAL_AUTHOR = Keys.General.AUTHOR 485 | KEY_GENERAL_URL = Keys.General.URL 486 | KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION 487 | KEY_GENERAL_LICENSE = Keys.General.LICENSE 488 | KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL 489 | KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO 490 | KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE 491 | 492 | # LLM 493 | KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH 494 | KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH 495 | KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT 496 | KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH 497 | KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL 498 | KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT 499 | 500 | # attention 501 | KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT 502 | KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV 503 | KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS 504 | KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV 505 | KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS 506 | KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS 507 | 508 | # RoPE 509 | KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT 510 | KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE 511 | KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE 512 | KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR 513 | KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN 514 | KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED 515 | 516 | # tokenization 517 | KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL 518 | KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST 519 | KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE 520 | KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES 521 | KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES 522 | KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID 523 | KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID 524 | KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID 525 | KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID 526 | KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID 527 | KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON 528 | KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV 529 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from gguf_reader import GGUFReader 6 | from sentencepiece import sentencepiece_model_pb2 7 | 8 | 9 | def convert_to_state_dict(checkpoint, save_dir): 10 | if not os.path.exists(save_dir): 11 | os.makedirs(save_dir) 12 | state_dict = {} 13 | result = GGUFReader(checkpoint) 14 | architecture = result.fields['general.architecture'] 15 | architecture = str(bytes(architecture.parts[architecture.data[0]]), encoding = 'utf-8') 16 | if architecture not in ["llama", "qwen2", "internlm2", "starcoder2", "qwen", 17 | "stablelm", "orion", "minicpm", "gemma", "xverse", "command-r"]: 18 | print(f"Unsupported architecture {architecture}") 19 | return 20 | # write tensor 21 | for ts in result.tensors: 22 | if hasattr(ts.data.dtype, 'names') and ts.data.dtype.names is not None: 23 | for name in ts.data.dtype.names: 24 | state_dict[ts.name + "_" + name] = torch.tensor(ts.data[name]) 25 | 
else: 26 | state_dict[ts.name] = torch.tensor(ts.data) 27 | if "weight" in ts.name: 28 | state_dict[ts.name.replace("weight", "weight_type")] = torch.tensor(int(ts.tensor_type), dtype=torch.int) 29 | torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) 30 | # write vocab 31 | # note we ignore added tokens for simplicity 32 | vocab_type = result.fields["tokenizer.ggml.model"] 33 | vocab_type = str(bytes(vocab_type.parts[vocab_type.data[0]]), encoding = 'utf-8') 34 | if vocab_type == "gpt2": 35 | # bpe vocab 36 | merges = result.fields["tokenizer.ggml.merges"] 37 | with open(os.path.join(save_dir, "merges.txt"), 'w') as f: 38 | for idx in merges.data: 39 | data = str(bytes(merges.parts[idx]), encoding = 'utf-8') 40 | f.write(f"{data}\n") 41 | tokens = result.fields['tokenizer.ggml.tokens'] 42 | types = result.fields['tokenizer.ggml.token_type'] 43 | vocab_size = len(tokens.data) 44 | vocab = {} 45 | special_vocab = {} 46 | vocab_list = [] 47 | for i, idx in enumerate(tokens.data): 48 | token = str(bytes(tokens.parts[idx]), encoding='utf-8') 49 | token_type = int(types.parts[types.data[i]]) 50 | #if (token.startswith("[PAD") or token.startswith("": 88 | new_token.piece = rb'\x00' 89 | new_token.type = 1 90 | vocab_list.append(new_token.piece) 91 | vocab.pieces.append(new_token) 92 | if new_token.type == 3: 93 | special_vocab[i] = {"content": new_token.piece, "special": True} 94 | # hf_vocab doesn't correctly set unk token type, so we force one 95 | if not has_unk: 96 | vocab.pieces[0].type = 2 97 | 98 | with open(os.path.join(save_dir, "tokenizer.model"), 'wb') as f: 99 | f.write(vocab.SerializeToString()) 100 | 101 | tokenizer_conf = {} 102 | if 'tokenizer.ggml.bos_token_id' in result.fields: 103 | tokenizer_conf["bos_token"] = vocab_list[int(result.fields['tokenizer.ggml.bos_token_id'].parts[-1])] 104 | if 'tokenizer.ggml.eos_token_id' in result.fields: 105 | tokenizer_conf["eos_token"] = vocab_list[int(result.fields['tokenizer.ggml.eos_token_id'].parts[-1])] 106 | if 'tokenizer.ggml.padding_token_id' in result.fields: 107 | tokenizer_conf["pad_token"] = vocab_list[int(result.fields['tokenizer.ggml.padding_token_id'].parts[-1])] 108 | if 'tokenizer.ggml.unknown_token_id' in result.fields: 109 | tokenizer_conf["unk_token"] = vocab_list[int(result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])] 110 | if 'tokenizer.ggml.add_bos_token' in result.fields: 111 | tokenizer_conf["add_bos_token"] = bool(result.fields['tokenizer.ggml.add_bos_token'].parts[-1]) 112 | if 'tokenizer.ggml.add_eos_token' in result.fields: 113 | tokenizer_conf["add_eos_token"] = bool(result.fields['tokenizer.ggml.add_eos_token'].parts[-1]) 114 | if 'tokenizer.chat_template' in result.fields: 115 | tokenizer_conf['chat_template'] = str(bytes(result.fields['tokenizer.chat_template'].parts[-1]), encoding = 'utf-8') 116 | if special_vocab: 117 | tokenizer_conf["added_tokens_decoder"] = special_vocab 118 | json.dump(tokenizer_conf, open(os.path.join(save_dir, "tokenizer_config.json"), 'w'), indent=2) 119 | 120 | 121 | # write config 122 | context_length = int(result.fields[f'{architecture}.context_length'].parts[-1]) 123 | n_layer = int(result.fields[f'{architecture}.block_count'].parts[-1]) 124 | n_head = int(result.fields[f'{architecture}.attention.head_count'].parts[-1]) 125 | intermediate_size = int(result.fields[f'{architecture}.feed_forward_length'].parts[-1]) 126 | # qwen use ffn_size / 2 for ffn layers 127 | if architecture == "qwen": 128 | intermediate_size = intermediate_size / 2 129 | dim = 
int(result.fields[f'{architecture}.embedding_length'].parts[-1]) 130 | if f'{architecture}.logit_scale' in result.fields: 131 | logit_scale = float(result.fields[f'{architecture}.logit_scale'].parts[-1]) 132 | else: 133 | logit_scale = 1.0 134 | # https://github.com/ggerganov/llama.cpp/blob/9731134296af3a6839cd682e51d9c2109a871de5/llama.cpp#L12301 135 | if architecture in ["qwen2", "gemma", "qwen", "stablelm", "starcoder2", "phi2"]: 136 | rope_type = "neox" 137 | elif architecture in ["llama", "internlm2", "baichuan", "startcoder", "orion", "minicpm", 138 | "xverse", "command-r"]: 139 | rope_type = "norm" 140 | else: 141 | rope_type = "none" 142 | 143 | if architecture in ["starcoder2", "phi2", "gemma"]: 144 | hidden_act = "gelu_tanh" 145 | else: 146 | hidden_act = "silu" 147 | 148 | if architecture in ["starcoder2", "phi2"]: 149 | mlp_gate = False 150 | else: 151 | mlp_gate = True 152 | 153 | if architecture in ["starcoder2", "phi2", "stablelm", "orion", "command-r"]: 154 | layernorm = True 155 | else: 156 | layernorm = False 157 | model_config= { 158 | "architecture": architecture, 159 | "block_size": context_length, 160 | "vocab_size": vocab_size, 161 | "n_layer": n_layer, 162 | "n_head": n_head, 163 | "dim": dim, 164 | "intermediate_size": intermediate_size, 165 | "hidden_act": hidden_act, 166 | "rope_type": rope_type, 167 | "mlp_gate": mlp_gate, 168 | "layernorm": layernorm, 169 | "logit_scale": logit_scale 170 | } 171 | if f'{architecture}.attention.head_count_kv' in result.fields: 172 | model_config['n_local_heads'] = int(result.fields[f'{architecture}.attention.head_count_kv'].parts[-1]) 173 | if f'{architecture}.attention.layer_norm_rms_epsilon' in result.fields: 174 | model_config['norm_eps'] = float(result.fields[f'{architecture}.attention.layer_norm_rms_epsilon'].parts[-1]) 175 | if f'{architecture}.attention.key_length' in result.fields: 176 | model_config['head_dim'] = int(result.fields[f'{architecture}.attention.key_length'].parts[-1]) 177 | if f'{architecture}.rope.freq_base' in result.fields: 178 | model_config['rope_base'] = float(result.fields[f'{architecture}.rope.freq_base'].parts[-1]) 179 | if f'{architecture}.rope_dimension_count' in result.fields: 180 | model_config['rope_dim'] = int(result.fields[f'{architecture}.rope_dimension_count'].parts[-1]) 181 | if f'{architecture}.expert_count' in result.fields: 182 | model_config['num_experts'] = int(result.fields[f'{architecture}.expert_count'].parts[-1]) 183 | model_config['num_experts_per_tok'] = int(result.fields[f'{architecture}.expert_used_count'].parts[-1]) 184 | model_config['moe'] = (model_config['num_experts'] > 1) 185 | 186 | json.dump(model_config, open(os.path.join(save_dir, "config.json"), 'w'), indent=2) 187 | 188 | 189 | if __name__ == '__main__': 190 | import argparse 191 | parser = argparse.ArgumentParser(description='Convert GGUF checkpoints to torch') 192 | 193 | parser.add_argument('--input', type=str, help='The path to GGUF file') 194 | parser.add_argument('--output', type=str, help='The path to output directory') 195 | args = parser.parse_args() 196 | convert_to_state_dict(args.input, args.output) 197 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
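The metadata-decoding pattern that `convert.py` uses throughout the function above, shown in isolation (a sketch; the GGUF path is a placeholder):

```python
from gguf_reader import GGUFReader

reader = GGUFReader("path/to/model.gguf")  # placeholder path

# String fields: `parts` holds the raw numpy pieces, `data` indexes the payload part.
arch_field = reader.fields["general.architecture"]
arch = str(bytes(arch_field.parts[arch_field.data[0]]), encoding="utf-8")

# Scalar fields: convert.py simply takes the last part, e.g. the context length.
ctx_len = int(reader.fields[f"{arch}.context_length"].parts[-1])
print(arch, ctx_len)
```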
4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | import json 8 | import time 9 | from functools import lru_cache 10 | from pathlib import Path 11 | from typing import Optional, Tuple 12 | 13 | import jinja2 14 | import torch 15 | import torch._dynamo.config 16 | import torch._inductor.config 17 | import torch.distributed as dist 18 | from jinja2.sandbox import ImmutableSandboxedEnvironment 19 | from sentencepiece import SentencePieceProcessor 20 | 21 | from gpt2_tokenizer import GPT2Tokenizer 22 | from model import Transformer 23 | from tp import maybe_init_dist, _get_world_size 24 | 25 | torch._inductor.config.coordinate_descent_tuning = True 26 | torch._inductor.config.triton.unique_kernel_names = True 27 | 28 | def device_sync(device): 29 | if "cuda" in device: 30 | torch.cuda.synchronize() 31 | elif "cpu" in device: 32 | pass 33 | else: 34 | print(f"device={device} is not yet suppported") 35 | 36 | @lru_cache 37 | def get_template(chat_template): 38 | jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) 39 | return jinja_env.from_string(chat_template) 40 | 41 | def apply_template(chat_template, query): 42 | compiled_template = get_template(chat_template) 43 | chat = [{"role": "user", "content": query}] 44 | rendered_chat = compiled_template.render(messages=chat, add_generation_prompt=True) 45 | return rendered_chat 46 | 47 | def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization 48 | q = torch.empty_like(probs_sort).exponential_(1) 49 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 50 | 51 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): 52 | logits = logits / max(temperature, 1e-5) 53 | 54 | if top_k is not None: 55 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 56 | pivot = v.select(-1, -1).unsqueeze(-1) 57 | logits = torch.where(logits < pivot, -float("Inf"), logits) 58 | probs = torch.nn.functional.softmax(logits, dim=-1) 59 | return probs 60 | 61 | def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): 62 | probs = logits_to_probs(logits[0, -1], temperature, top_k) 63 | idx_next = multinomial_sample_one_no_sync(probs) 64 | return idx_next, probs 65 | 66 | def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: 67 | # input_pos: [B, S] 68 | logits = model(x, input_pos) 69 | return sample(logits, **sampling_kwargs)[0] 70 | 71 | def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 72 | # input_pos: [B, 1] 73 | assert input_pos.shape[-1] == 1 74 | logits = model(x, input_pos) 75 | return sample(logits, **sampling_kwargs) 76 | 77 | def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): 78 | new_tokens, new_probs = [], [] 79 | for i in range(num_new_tokens): 80 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here 81 | next_token, next_prob = decode_one_token( 82 | model, cur_token, input_pos, **sampling_kwargs 83 | ) 84 | input_pos += 1 85 | new_tokens.append(next_token.clone()) 86 | status = callback(new_tokens[-1]) 87 | if not status: 88 | break 89 | 
new_probs.append(next_prob.clone()) 90 | cur_token = next_token.view(1, -1) 91 | 92 | return new_tokens, new_probs 93 | 94 | 95 | def model_forward(model, x, input_pos): 96 | return model(x, input_pos) 97 | 98 | 99 | @torch.no_grad() 100 | def generate( 101 | model: Transformer, 102 | prompt: torch.Tensor, 103 | max_new_tokens: int, 104 | *, 105 | interactive: bool, 106 | callback = lambda x: x, 107 | **sampling_kwargs 108 | ) -> torch.Tensor: 109 | """ 110 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 111 | """ 112 | # create an empty tensor of the expected final shape and fill in the current tokens 113 | T = prompt.size(0) 114 | T_new = T + max_new_tokens 115 | if interactive: 116 | max_seq_length = 350 117 | else: 118 | max_seq_length = min(T_new, model.config.block_size) 119 | 120 | device, dtype = prompt.device, prompt.dtype 121 | with torch.device(device): 122 | model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) 123 | 124 | # create an empty tensor of the expected final shape and fill in the current tokens 125 | empty = torch.empty(T_new, dtype=dtype, device=device) 126 | empty[:T] = prompt 127 | seq = empty 128 | input_pos = torch.arange(0, T, device=device) 129 | 130 | next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs) 131 | seq[T] = next_token 132 | callback(next_token) 133 | 134 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 135 | 136 | generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) 137 | seq[T + 1: T + 1 + len(generated_tokens)] = torch.cat(generated_tokens) 138 | 139 | return seq[:T + 1 + len(generated_tokens)] 140 | 141 | 142 | def encode_tokens(tokenizer, special_tokens, string, bos=True, device='cuda'): 143 | text_list = [string] 144 | for token in special_tokens: 145 | new_text_list = [] 146 | for text in text_list: 147 | text = text.split(token) 148 | new_text_list.extend([e for pair in zip(text, [token] * (len(text) - 1)) for e in pair] + [text[-1]]) 149 | text_list = new_text_list 150 | 151 | tokens = [] 152 | for text in text_list: 153 | if not text: 154 | continue 155 | if text in special_tokens: 156 | tokens.append(int(special_tokens[text])) 157 | else: 158 | tokens.extend(tokenizer.encode(text)) 159 | if bos: 160 | # some spm model has wrong set bos 161 | bos_id = tokenizer.bos_id() if tokenizer.bos_id() > 0 else 2 162 | tokens = [bos_id] + tokens 163 | return torch.tensor(tokens, dtype=torch.int, device=device) 164 | 165 | 166 | def _load_model(checkpoint_path, device, precision, use_tp): 167 | with torch.device('meta'): 168 | model = Transformer.from_json(str(checkpoint_path / "config.json")) 169 | 170 | checkpoint = torch.load(str(checkpoint_path / "pytorch_model.bin"), mmap=True, weights_only=True) 171 | 172 | model.load_state_dict(checkpoint, strict=False, assign=True) 173 | # Fixed tied embedding 174 | if model.output.weight.device == torch.device("meta"): 175 | model.output.weight = model.token_embd.weight 176 | model.output.weight_type = model.token_embd.weight_type 177 | 178 | #for k,v in list(model.named_parameters()) + list(model.named_buffers()): 179 | # print(k, v.device) 180 | model = model._apply(lambda t: torch.zeros_like(t, device="cpu") 181 | if t.device == torch.device("meta") else t) 182 | for name, module in model.named_modules(): 183 | if hasattr(module, "weight_type"): 184 | module.weight_type_int = int(module.weight_type) 185 | 186 | if 
use_tp: 187 | from tp import apply_tp 188 | apply_tp(model) 189 | 190 | model = model.to(device=device, dtype=precision) 191 | return model.eval() 192 | 193 | 194 | def main( 195 | prompt: str = "Hello, my name is", 196 | interactive: bool = False, 197 | num_samples: int = 5, 198 | max_new_tokens: int = 100, 199 | top_k: int = 20, 200 | temperature: float = 0.8, 201 | checkpoint_path: Path = Path("."), 202 | compile: bool = True, 203 | compile_prefill: bool = False, 204 | device='cuda', 205 | ) -> None: 206 | """Generates text samples based on a pre-trained Transformer model and tokenizer. 207 | """ 208 | tokenizer_path = checkpoint_path / "tokenizer.model" 209 | use_spm = True 210 | if not tokenizer_path.is_file(): 211 | use_spm = False 212 | tokenizer_path = (checkpoint_path / "vocab.json", 213 | checkpoint_path / "merges.txt", 214 | checkpoint_path / "tokenizer_config.json") 215 | 216 | tokenizer_config = json.load(open(checkpoint_path / "tokenizer_config.json")) 217 | chat_template = tokenizer_config["chat_template"] if "chat_template" in tokenizer_config else None 218 | 219 | global print 220 | rank = maybe_init_dist() 221 | use_tp = rank is not None 222 | if use_tp: 223 | if rank != 0: 224 | # only print on rank 0 225 | print = lambda *args, **kwargs: None 226 | 227 | precision = torch.float16 228 | 229 | print("Loading model ...") 230 | t0 = time.time() 231 | model = _load_model(checkpoint_path, device, precision, use_tp) 232 | 233 | device_sync(device=device) # MKG 234 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 235 | 236 | if use_spm: 237 | tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path)) 238 | else: 239 | tokenizer = GPT2Tokenizer(*tokenizer_path) 240 | tokenizer_conf = json.load(open(checkpoint_path / "tokenizer_config.json")) 241 | special_tokens = {v["content"]: k for k, v in tokenizer_conf.get("added_tokens_decoder", {}).items()} 242 | encoded = encode_tokens(tokenizer, special_tokens, prompt, 243 | bos=tokenizer_conf.get("add_bos_token", False), device=device) 244 | prompt_length = encoded.size(0) 245 | 246 | #torch.manual_seed(1234) 247 | if compile: 248 | global decode_one_token, prefill 249 | decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) 250 | 251 | # Uncomment to squeeze more perf out of prefill 252 | if compile_prefill: 253 | prefill = torch.compile(prefill, fullgraph=True, dynamic=True) 254 | 255 | 256 | aggregate_metrics = { 257 | 'tokens_per_sec': [], 258 | } 259 | start = -1 if compile else 0 260 | 261 | for i in range(start, num_samples): 262 | device_sync(device=device) # MKG 263 | if i >= 0 and interactive: 264 | if not use_tp or rank == 0: 265 | prompt = input("What is your prompt? 
") 266 | else: 267 | prompt = "" 268 | if use_tp: 269 | prompt_list = [None for _ in range(_get_world_size())] 270 | dist.all_gather_object(prompt_list, prompt) 271 | prompt = prompt_list[0] 272 | 273 | if chat_template is not None: 274 | prompt = apply_template(chat_template, prompt) 275 | encoded = encode_tokens(tokenizer, special_tokens, prompt.strip(), 276 | bos=tokenizer_conf.get("add_bos_token", False), device=device) 277 | 278 | if interactive and i >= 0: 279 | buffer = [] 280 | token_buffer = [] 281 | period_id = tokenizer.encode('.')[-1] 282 | done_generating = False 283 | def callback(x): 284 | nonlocal done_generating, token_buffer 285 | if done_generating: 286 | return False 287 | token_buffer.extend(x.tolist()) 288 | token = tokenizer.decode([period_id] + token_buffer)[1:] 289 | if token.endswith("�"): 290 | return True 291 | else: 292 | buffer.append(token) 293 | token_buffer = [] 294 | if x.item() == tokenizer.eos_id(): 295 | done_generating = True 296 | if len(buffer) == 4 or done_generating: 297 | print(''.join(buffer), end='', flush=True) 298 | buffer.clear() 299 | return True 300 | else: 301 | callback = lambda x : x 302 | t0 = time.perf_counter() 303 | y = generate( 304 | model, 305 | encoded, 306 | max_new_tokens, 307 | interactive=interactive, 308 | callback=callback, 309 | temperature=temperature, 310 | top_k=top_k, 311 | ) 312 | if i == -1: 313 | print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") 314 | continue 315 | device_sync(device=device) # MKG 316 | t = time.perf_counter() - t0 317 | 318 | if not interactive: 319 | print(tokenizer.decode(y.tolist())) 320 | else: 321 | print() 322 | tokens_generated = y.size(0) - prompt_length 323 | tokens_sec = tokens_generated / t 324 | aggregate_metrics['tokens_per_sec'].append(tokens_sec) 325 | print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") 326 | 327 | 328 | if __name__ == '__main__': 329 | import argparse 330 | parser = argparse.ArgumentParser(description='Your CLI description.') 331 | 332 | parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') 333 | parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') 334 | parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') 335 | parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') 336 | parser.add_argument('--top_k', type=int, default=20, help='Top-k for sampling.') 337 | parser.add_argument('--temperature', type=float, default=1.0, help='Temperature for sampling.') 338 | parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints"), help='Model checkpoint path.') 339 | parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') 340 | parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') 341 | parser.add_argument('--device', type=str, default="cuda", help='device to use') 342 | 343 | args = parser.parse_args() 344 | main( 345 | args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, 346 | args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.device 347 | ) 348 | -------------------------------------------------------------------------------- /gguf_reader.py: -------------------------------------------------------------------------------- 1 | # Copied from 
https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/gguf_reader.py 2 | # 3 | # GGUF file reading/modification support. For API usage information, 4 | # please see the files scripts/ for some fairly simple examples. 5 | # 6 | from __future__ import annotations 7 | 8 | import os 9 | from collections import OrderedDict 10 | from typing import Any, Literal, NamedTuple, TypeVar, Union 11 | 12 | import numpy as np 13 | import numpy.typing as npt 14 | 15 | from constants import ( 16 | GGML_QUANT_SIZES, 17 | GGUF_DEFAULT_ALIGNMENT, 18 | GGUF_MAGIC, 19 | GGUF_VERSION, 20 | GGMLQuantizationType, 21 | GGUFValueType, 22 | ) 23 | 24 | 25 | READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION] 26 | 27 | 28 | class ReaderField(NamedTuple): 29 | # Offset to start of this field. 30 | offset: int 31 | 32 | # Name of the field (not necessarily from file data). 33 | name: str 34 | 35 | # Data parts. Some types have multiple components, such as strings 36 | # that consist of a length followed by the string data. 37 | parts: list[npt.NDArray[Any]] = [] 38 | 39 | # Indexes into parts that we can call the actual data. For example 40 | # an array of strings will be populated with indexes to the actual 41 | # string data. 42 | data: list[int] = [-1] 43 | 44 | types: list[GGUFValueType] = [] 45 | 46 | 47 | class ReaderTensor(NamedTuple): 48 | name: str 49 | tensor_type: GGMLQuantizationType 50 | shape: npt.NDArray[np.uint32] 51 | n_elements: int 52 | n_bytes: int 53 | data_offset: int 54 | data: npt.NDArray[Any] 55 | field: ReaderField 56 | 57 | 58 | class GGUFReader: 59 | # I - same as host, S - swapped 60 | byte_order: Literal['I' | 'S'] = 'I' 61 | alignment: int = GGUF_DEFAULT_ALIGNMENT 62 | 63 | # Note: Internal helper, API may change. 64 | gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { 65 | GGUFValueType.UINT8: np.uint8, 66 | GGUFValueType.INT8: np.int8, 67 | GGUFValueType.UINT16: np.uint16, 68 | GGUFValueType.INT16: np.int16, 69 | GGUFValueType.UINT32: np.uint32, 70 | GGUFValueType.INT32: np.int32, 71 | GGUFValueType.FLOAT32: np.float32, 72 | GGUFValueType.UINT64: np.uint64, 73 | GGUFValueType.INT64: np.int64, 74 | GGUFValueType.FLOAT64: np.float64, 75 | GGUFValueType.BOOL: np.bool_, 76 | } 77 | 78 | def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'): 79 | self.data = np.memmap(path, mode = mode) 80 | offs = 0 81 | if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: 82 | raise ValueError('GGUF magic invalid') 83 | offs += 4 84 | temp_version = self._get(offs, np.uint32) 85 | if temp_version[0] & 65535 == 0: 86 | # If we get 0 here that means it's (probably) a GGUF file created for 87 | # the opposite byte order of the machine this script is running on. 
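            # (For example, version 3 stored little-endian is the byte sequence 03 00 00 00;
            #  read with the wrong byte order that becomes 0x03000000, whose low 16 bits are zero.)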
88 | self.byte_order = 'S' 89 | temp_version = temp_version.newbyteorder(self.byte_order) 90 | version = temp_version[0] 91 | if version not in READER_SUPPORTED_VERSIONS: 92 | raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle') 93 | self.fields: OrderedDict[str, ReaderField] = OrderedDict() 94 | self.tensors: list[ReaderTensor] = [] 95 | offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) 96 | temp_counts = self._get(offs, np.uint64, 2) 97 | offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) 98 | offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) 99 | tensor_count, kv_count = temp_counts 100 | offs = self._build_fields(offs, kv_count) 101 | offs, tensors_fields = self._build_tensors_fields(offs, tensor_count) 102 | new_align = self.fields.get('general.alignment') 103 | if new_align is not None: 104 | if new_align.types != [GGUFValueType.UINT64]: 105 | raise ValueError('Bad type for general.alignment field') 106 | self.alignment = new_align.parts[-1][0] 107 | padding = offs % self.alignment 108 | if padding != 0: 109 | offs += self.alignment - padding 110 | self._build_tensors(offs, tensors_fields) 111 | 112 | _DT = TypeVar('_DT', bound = npt.DTypeLike) 113 | 114 | # Fetch a key/value metadata field by key. 115 | def get_field(self, key: str) -> Union[ReaderField, None]: 116 | return self.fields.get(key, None) 117 | 118 | # Fetch a tensor from the list by index. 119 | def get_tensor(self, idx: int) -> ReaderTensor: 120 | return self.tensors[idx] 121 | 122 | def _get( 123 | self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None, 124 | ) -> npt.NDArray[Any]: 125 | count = int(count) 126 | itemsize = int(np.empty([], dtype = dtype).itemsize) 127 | end_offs = offset + itemsize * count 128 | return ( 129 | self.data[offset:end_offs] 130 | .view(dtype = dtype)[:count] 131 | .newbyteorder(override_order or self.byte_order) 132 | ) 133 | 134 | def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: 135 | if field.name in self.fields: 136 | raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') 137 | self.fields[field.name] = field 138 | return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) 139 | 140 | def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: 141 | slen = self._get(offset, np.uint64) 142 | return slen, self._get(offset + 8, np.uint8, slen[0]) 143 | 144 | def _get_field_parts( 145 | self, orig_offs: int, raw_type: int, 146 | ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]: 147 | offs = orig_offs 148 | types: list[GGUFValueType] = [] 149 | gtype = GGUFValueType(raw_type) 150 | types.append(gtype) 151 | # Handle strings. 152 | if gtype == GGUFValueType.STRING: 153 | sparts: list[npt.NDArray[Any]] = list(self._get_str(offs)) 154 | size = sum(int(part.nbytes) for part in sparts) 155 | return size, sparts, [1], types 156 | # Check if it's a simple scalar type. 157 | nptype = self.gguf_scalar_to_np.get(gtype) 158 | if nptype is not None: 159 | val = self._get(offs, nptype) 160 | return int(val.nbytes), [val], [0], types 161 | # Handle arrays. 
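        # (Array layout: item type as uint32, then length as uint64, then `length` items
        #  packed back to back; nested strings are again a uint64 length plus the raw bytes.)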
162 | if gtype == GGUFValueType.ARRAY: 163 | raw_itype = self._get(offs, np.uint32) 164 | offs += int(raw_itype.nbytes) 165 | alen = self._get(offs, np.uint64) 166 | offs += int(alen.nbytes) 167 | aparts: list[npt.NDArray[Any]] = [raw_itype, alen] 168 | data_idxs: list[int] = [] 169 | for idx in range(alen[0]): 170 | curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0]) 171 | if idx == 0: 172 | types += curr_types 173 | idxs_offs = len(aparts) 174 | aparts += curr_parts 175 | data_idxs += (idx + idxs_offs for idx in curr_idxs) 176 | offs += curr_size 177 | return offs - orig_offs, aparts, data_idxs, types 178 | # We can't deal with this one. 179 | raise ValueError('Unknown/unhandled field type {gtype}') 180 | 181 | def _get_tensor(self, orig_offs: int) -> ReaderField: 182 | offs = orig_offs 183 | name_len, name_data = self._get_str(offs) 184 | offs += int(name_len.nbytes + name_data.nbytes) 185 | n_dims = self._get(offs, np.uint32) 186 | offs += int(n_dims.nbytes) 187 | dims = self._get(offs, np.uint64, n_dims[0]) 188 | offs += int(dims.nbytes) 189 | raw_dtype = self._get(offs, np.uint32) 190 | offs += int(raw_dtype.nbytes) 191 | offset_tensor = self._get(offs, np.uint64) 192 | offs += int(offset_tensor.nbytes) 193 | return ReaderField( 194 | orig_offs, 195 | str(bytes(name_data), encoding = 'utf-8'), 196 | [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor], 197 | [1, 3, 4, 5], 198 | ) 199 | 200 | def _build_fields(self, offs: int, count: int) -> int: 201 | for _ in range(count): 202 | orig_offs = offs 203 | kv_klen, kv_kdata = self._get_str(offs) 204 | offs += int(kv_klen.nbytes + kv_kdata.nbytes) 205 | raw_kv_type = self._get(offs, np.uint32) 206 | offs += int(raw_kv_type.nbytes) 207 | parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] 208 | idxs_offs = len(parts) 209 | field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) 210 | parts += field_parts 211 | self._push_field(ReaderField( 212 | orig_offs, 213 | str(bytes(kv_kdata), encoding = 'utf-8'), 214 | parts, 215 | [idx + idxs_offs for idx in field_idxs], 216 | field_types, 217 | ), skip_sum = True) 218 | offs += field_size 219 | return offs 220 | 221 | def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: 222 | tensor_fields = [] 223 | for _ in range(count): 224 | field = self._get_tensor(offs) 225 | offs += sum(int(part.nbytes) for part in field.parts) 226 | tensor_fields.append(field) 227 | return offs, tensor_fields 228 | 229 | def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: 230 | tensors = [] 231 | for field in fields: 232 | _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts 233 | ggml_type = GGMLQuantizationType(raw_dtype[0]) 234 | n_elems = np.prod(dims) 235 | block_size, type_size = GGML_QUANT_SIZES[ggml_type] 236 | n_bytes = n_elems * type_size // block_size 237 | data_offs = int(start_offs + offset_tensor[0]) 238 | item_type: npt.DTypeLike 239 | if ggml_type == GGMLQuantizationType.F32: 240 | item_count = n_elems 241 | item_type = np.float32 242 | elif ggml_type == GGMLQuantizationType.F16: 243 | item_count = n_elems 244 | item_type = np.float16 245 | else: 246 | item_count = n_bytes 247 | item_type = np.uint8 248 | tensors.append(ReaderTensor( 249 | name = str(bytes(name_data), encoding = 'utf-8'), 250 | tensor_type = ggml_type, 251 | shape = dims, 252 | n_elements = n_elems, 253 | n_bytes = n_bytes, 254 | data_offset = data_offs, 255 | 
data = self._get(data_offs, item_type, item_count), 256 | field = field, 257 | )) 258 | self.tensors = tensors 259 | -------------------------------------------------------------------------------- /gpt2_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/tokenization_gpt2.py 2 | 3 | import json 4 | from functools import lru_cache 5 | 6 | import regex as re 7 | 8 | @lru_cache() 9 | def bytes_to_unicode(): 10 | """ 11 | Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control 12 | characters the bpe code barfs on. 13 | 14 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab 15 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for 16 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup 17 | tables between utf-8 bytes and unicode strings. 18 | """ 19 | bs = ( 20 | list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 21 | ) 22 | cs = bs[:] 23 | n = 0 24 | for b in range(2**8): 25 | if b not in bs: 26 | bs.append(b) 27 | cs.append(2**8 + n) 28 | n += 1 29 | cs = [chr(n) for n in cs] 30 | return dict(zip(bs, cs)) 31 | 32 | 33 | def get_pairs(word): 34 | """ 35 | Return set of symbol pairs in a word. 36 | 37 | Word is represented as tuple of symbols (symbols being variable-length strings). 38 | """ 39 | pairs = set() 40 | prev_char = word[0] 41 | for char in word[1:]: 42 | pairs.add((prev_char, char)) 43 | prev_char = char 44 | return pairs 45 | 46 | class GPT2Tokenizer: 47 | def __init__(self, vocab_file, merges_file, tokenizer_config_file): 48 | with open(vocab_file, encoding="utf-8") as vocab_handle: 49 | self.encoder = json.load(vocab_handle) 50 | self.decoder = {v: k for k, v in self.encoder.items()} 51 | self.byte_encoder = bytes_to_unicode() 52 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 53 | with open(merges_file, encoding="utf-8") as merges_handle: 54 | bpe_merges = merges_handle.read().split("\n") 55 | if bpe_merges[0].startswith("#"): 56 | bpe_merges = bpe_merges[1:] 57 | bpe_merges = [tuple(merge.split()) for merge in bpe_merges] 58 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 59 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 60 | # todo fix qwen 61 | # self.pat = re.compile(r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""") 62 | with open(tokenizer_config_file, encoding="utf-8") as conf: 63 | spec_tokens = json.load(conf) 64 | self.bos_token_id = self.encoder[spec_tokens["bos_token"]] if "bos_token" in spec_tokens else 0 65 | self.eos_token_id = self.encoder[spec_tokens["eos_token"]] if "eos_token" in spec_tokens else 0 66 | self.unk_token_id = self.encoder[spec_tokens["unk_token"]] if "unk_token" in spec_tokens else 0 67 | self.cache = {} 68 | 69 | def bos_id(self): 70 | return self.bos_token_id 71 | 72 | def eos_id(self): 73 | return self.eos_token_id 74 | 75 | def unk_id(self): 76 | return self.unk_token_id 77 | 78 | def bpe(self, token): 79 | if token in self.cache: 80 | return self.cache[token] 81 | word = tuple(token) 82 | pairs = get_pairs(word) 83 | 84 | if not 
pairs: 85 | return token 86 | 87 | while True: 88 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 89 | if bigram not in self.bpe_ranks: 90 | break 91 | first, second = bigram 92 | new_word = [] 93 | i = 0 94 | while i < len(word): 95 | try: 96 | j = word.index(first, i) 97 | except ValueError: 98 | new_word.extend(word[i:]) 99 | break 100 | else: 101 | new_word.extend(word[i:j]) 102 | i = j 103 | 104 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 105 | new_word.append(first + second) 106 | i += 2 107 | else: 108 | new_word.append(word[i]) 109 | i += 1 110 | new_word = tuple(new_word) 111 | word = new_word 112 | if len(word) == 1: 113 | break 114 | else: 115 | pairs = get_pairs(word) 116 | word = " ".join(word) 117 | self.cache[token] = word 118 | return word 119 | 120 | def encode(self, text): 121 | bpe_tokens = [] 122 | for token in re.findall(self.pat, text): 123 | token = "".join( 124 | self.byte_encoder[b] for b in token.encode("utf-8") 125 | ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) 126 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) 127 | return [self.encoder.get(token, self.unk_token_id) for token in bpe_tokens] 128 | 129 | def decode(self, idx): 130 | tokens = [self.decoder.get(index) for index in idx] 131 | text = "".join(tokens) 132 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace") 133 | return text 134 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
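A quick round trip through the byte-level BPE tokenizer above, as a sketch; the checkpoint directory is a placeholder and is assumed to contain the `vocab.json`, `merges.txt` and `tokenizer_config.json` that `convert.py` writes for a "gpt2"-style vocab:

```python
from gpt2_tokenizer import GPT2Tokenizer

ckpt_dir = "path/to/converted-checkpoint"      # placeholder
tok = GPT2Tokenizer(f"{ckpt_dir}/vocab.json",
                    f"{ckpt_dir}/merges.txt",
                    f"{ckpt_dir}/tokenizer_config.json")

ids = tok.encode("Hello, world!")              # byte-level BPE ids
print(ids, tok.decode(ids))                    # decode() restores the original text
print(tok.bos_id(), tok.eos_id())              # ids taken from tokenizer_config.json
```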
7 | import math 8 | import os 9 | import json 10 | from dataclasses import dataclass 11 | from typing import Optional 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch import Tensor 16 | from torch.nn import functional as F 17 | from torch.distributed import _functional_collectives as funcol 18 | 19 | import register_lib 20 | 21 | def find_multiple(n: int, k: int) -> int: 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | @dataclass 27 | class ModelArgs: 28 | architecture: str = "llama" 29 | block_size: int = 2048 30 | vocab_size: int = 32000 31 | n_layer: int = 32 32 | n_head: int = 32 33 | dim: int = 4096 34 | intermediate_size: int = None 35 | n_local_heads: int = -1 36 | head_dim: int = -1 37 | rope_base: float = 10000 38 | rope_type: str = "none" 39 | rope_dim: int = -1 40 | norm_eps: float = 1e-5 41 | moe: bool = False 42 | num_experts: int = 1 43 | num_experts_per_tok: int = 1 44 | hidden_act: str = "silu" 45 | mlp_gate: bool = True 46 | layernorm: bool = False 47 | logit_scale: float = 1.0 48 | 49 | def __post_init__(self): 50 | if self.n_local_heads == -1: 51 | self.n_local_heads = self.n_head 52 | if self.intermediate_size is None: 53 | hidden_dim = 4 * self.dim 54 | n_hidden = int(2 * hidden_dim / 3) 55 | self.intermediate_size = find_multiple(n_hidden, 256) 56 | if self.head_dim == -1: 57 | self.head_dim = self.dim // self.n_head 58 | if self.rope_dim == -1: 59 | self.rope_dim = self.head_dim 60 | 61 | 62 | class KVCache(nn.Module): 63 | def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float16): 64 | super().__init__() 65 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 66 | self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) 67 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) 68 | 69 | def update(self, input_pos, k_val, v_val): 70 | # input_pos: [S], k_val: [B, H, S, D] 71 | assert input_pos.shape[0] == k_val.shape[2] 72 | 73 | k_out = self.k_cache 74 | v_out = self.v_cache 75 | k_out[:, :, input_pos] = k_val 76 | v_out[:, :, input_pos] = v_val 77 | 78 | return k_out, v_out 79 | 80 | 81 | class Transformer(nn.Module): 82 | def __init__(self, config: ModelArgs) -> None: 83 | super().__init__() 84 | self.config = config 85 | 86 | self.token_embd = Embedding(config.vocab_size, config.dim) 87 | self.blk = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) 88 | if config.layernorm: 89 | self.output_norm = nn.LayerNorm(config.dim, eps=config.norm_eps) 90 | else: 91 | self.output_norm = RMSNorm(config.dim, eps=config.norm_eps) 92 | self.output = Linear(config.dim, config.vocab_size, bias=False) 93 | 94 | self.freqs_cis: Optional[Tensor] = None 95 | self.mask_cache: Optional[Tensor] = None 96 | self.max_batch_size = -1 97 | self.max_seq_length = -1 98 | 99 | def setup_caches(self, max_batch_size, max_seq_length): 100 | if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 101 | return 102 | head_dim = self.config.head_dim 103 | max_seq_length = find_multiple(max_seq_length, 8) 104 | self.max_seq_length = max_seq_length 105 | self.max_batch_size = max_batch_size 106 | for b in self.blk: 107 | b.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) 108 | 109 | self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.rope_dim, self.config.rope_base) 110 | self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 111 
| 112 | def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 113 | assert self.freqs_cis is not None, "Caches must be initialized first" 114 | mask = self.causal_mask[None, None, input_pos] 115 | freqs_cis = self.freqs_cis[input_pos] 116 | x = self.token_embd(idx) 117 | if self.config.architecture == "gemma": 118 | x = x * (self.config.dim ** 0.5) 119 | if self.config.architecture == "minicpm": 120 | x = x * 12.0 121 | 122 | for i, layer in enumerate(self.blk): 123 | x = layer(x, input_pos, freqs_cis, mask) 124 | x = self.output_norm(x) 125 | logits = self.output(x) 126 | if self.config.architecture == "minicpm": 127 | logits = logits / (self.config.dim // 256) 128 | logits *= self.config.logit_scale 129 | return logits 130 | 131 | @classmethod 132 | def from_json(cls, name: str): 133 | config = json.load(open(name, 'r')) 134 | return cls(ModelArgs(**config)) 135 | 136 | class TransformerBlock(nn.Module): 137 | def __init__(self, config: ModelArgs): 138 | super().__init__() 139 | self.config = config 140 | assert config.dim % config.n_head == 0 141 | 142 | # Attention norm 143 | if config.layernorm: 144 | self.attn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps) 145 | else: 146 | self.attn_norm = RMSNorm(config.dim, config.norm_eps) 147 | 148 | # Attention layer 149 | # https://github.com/pacman100/llama.cpp/blob/ee5b171250f707b08334aa8dcda259888bc2ccc6/gguf-py/gguf/tensor_mapping.py#L97 150 | if config.architecture in ["qwen", "phi2"]: 151 | self.concat_qkv = True 152 | self.attn_qkv = Linear(config.dim, config.head_dim * (config.n_head + config.n_local_heads * 2), bias=True) 153 | else: 154 | self.concat_qkv = False 155 | self.attn_q = Linear(config.dim, config.n_head * config.head_dim, bias=True) 156 | self.attn_k = Linear(config.dim, config.n_local_heads * config.head_dim, bias=True) 157 | self.attn_v = Linear(config.dim, config.n_local_heads * config.head_dim, bias=True) 158 | self.attn_output = Linear(config.n_head * config.head_dim, config.dim, bias=True) 159 | self.kv_cache = None 160 | 161 | # ffn norm 162 | if config.layernorm: 163 | self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps) 164 | else: 165 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 166 | if config.hidden_act == "gelu_tanh": 167 | self.act_fn = nn.GELU(approximate="tanh") 168 | elif config.hidden_act == "gelu": 169 | self.act_fn = nn.GELU() 170 | else: 171 | self.act_fn = nn.SiLU() 172 | 173 | # ffn layer 174 | if config.moe: 175 | self.ffn_gate_inp = Linear(config.dim, config.num_experts, bias=True) 176 | self.ffn_gate = nn.ModuleList(Linear(config.dim, config.intermediate_size, bias=True) for _ in range(config.num_experts)) 177 | self.ffn_up = nn.ModuleList(Linear(config.dim, config.intermediate_size, bias=True) for _ in range(config.num_experts)) 178 | self.ffn_down = nn.ModuleList(Linear(config.intermediate_size, config.dim, bias=True) for _ in range(config.num_experts)) 179 | else: 180 | if config.mlp_gate: 181 | self.ffn_gate = Linear(config.dim, config.intermediate_size, bias=True) 182 | self.ffn_up = Linear(config.dim, config.intermediate_size, bias=True) 183 | self.ffn_down = Linear(config.intermediate_size, config.dim, bias=True) 184 | 185 | self.n_head = config.n_head 186 | self.head_dim = config.head_dim 187 | self.n_local_heads = config.n_local_heads 188 | self.dim = config.dim 189 | self.moe = config.moe 190 | self.num_experts = config.num_experts 191 | self.num_experts_per_tok = config.num_experts_per_tok 192 | self.rope_type = config.rope_type 193 | 
self.rope_dim = config.rope_dim 194 | self.world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 195 | 196 | def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: 197 | bsz, seqlen, _ = x.shape 198 | 199 | attn_x = self.attn_norm(x) 200 | # attention 201 | if self.concat_qkv: 202 | qkv = self.attn_qkv(attn_x).view(bsz, seqlen, -1, self.head_dim) 203 | q, k, v = qkv.split([self.n_head, self.n_local_heads, self.n_local_heads], 204 | dim=2) 205 | else: 206 | q = self.attn_q(attn_x).view(bsz, seqlen, self.n_head, self.head_dim) 207 | k = self.attn_k(attn_x).view(bsz, seqlen, self.n_local_heads, self.head_dim) 208 | v = self.attn_v(attn_x).view(bsz, seqlen, self.n_local_heads, self.head_dim) 209 | 210 | if self.rope_dim != self.head_dim: 211 | q_rot, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:] 212 | k_rot, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:] 213 | q_rot = apply_rotary_emb(q_rot, freqs_cis, self.rope_type) 214 | k_rot = apply_rotary_emb(k_rot, freqs_cis, self.rope_type) 215 | q = torch.cat((q_rot, q_pass), dim=-1) 216 | k = torch.cat((k_rot, k_pass), dim=-1) 217 | else: 218 | q = apply_rotary_emb(q, freqs_cis, self.rope_type) 219 | k = apply_rotary_emb(k, freqs_cis, self.rope_type) 220 | 221 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 222 | 223 | if self.kv_cache is not None: 224 | k, v = self.kv_cache.update(input_pos, k, v) 225 | 226 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 227 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 228 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 229 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 230 | y = self.attn_output(y) 231 | 232 | if self.world_size > 1: 233 | y = funcol.all_reduce(y, "sum", list(range(self.world_size))) 234 | 235 | if self.config.architecture == "minicpm": 236 | y = y * 1.4 / math.sqrt(self.config.n_layer) 237 | 238 | y = x + y 239 | 240 | if self.config.architecture in ["phi2", "command-r"]: 241 | mlp_y = attn_x 242 | else: 243 | mlp_y = self.ffn_norm(y) 244 | 245 | # mlp 246 | if self.moe: 247 | # reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mixtral.py 248 | # This is inefficient for small batch size since it calculates all experts 249 | mlp_y = mlp_y.view(-1, mlp_y.shape[-1]) 250 | routing_weights = F.softmax(self.ffn_gate_inp(mlp_y), dim=1) 251 | routing_weights, selected_experts = torch.topk(routing_weights, 252 | self.num_experts_per_tok, 253 | dim=-1) 254 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True) 255 | z = None 256 | for idx in range(self.num_experts): 257 | if self.ffn_gate[idx] is None: continue 258 | z_idx = self.ffn_down[idx](self.act_fn(self.ffn_gate[idx](mlp_y)) * self.ffn_up[idx](mlp_y)) 259 | expert_mask = (selected_experts == idx) 260 | expert_weights = (routing_weights * expert_mask).sum(dim=-1, 261 | keepdim=True) 262 | z_idx = z_idx * expert_weights 263 | z = z_idx if z is None else z + z_idx 264 | elif self.config.mlp_gate: 265 | z = self.ffn_down(self.act_fn(self.ffn_gate(mlp_y)) * self.ffn_up(mlp_y)) 266 | else: 267 | z = self.ffn_down(self.act_fn(self.ffn_up(mlp_y))) 268 | 269 | if self.world_size > 1: 270 | z = funcol.all_reduce(z, "sum", list(range(self.world_size))) 271 | 272 | if self.config.architecture == "minicpm": 273 | z = z * 1.4 / math.sqrt(self.config.n_layer) 274 | z = z + y 275 | 276 | return z 277 | 278 | 279 | class RMSNorm(nn.Module): 280 | def __init__(self, dim: int, 
eps: float = 1e-5): 281 | super().__init__() 282 | self.eps = eps 283 | self.weight = nn.Parameter(torch.ones(dim)) 284 | 285 | def _norm(self, x): 286 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 287 | 288 | def forward(self, x: Tensor) -> Tensor: 289 | output = self._norm(x.float()).type_as(x) 290 | return output * self.weight 291 | 292 | 293 | class Linear(nn.Module): 294 | """Quantized linear layer""" 295 | def __init__(self, infeatures, outfeatures, bias, **kwargs): 296 | super().__init__() 297 | self.infeatures = infeatures 298 | self.outfeatures = outfeatures 299 | # Fake weight 300 | self.register_buffer('weight', torch.nn.parameter.UninitializedBuffer()) 301 | self.register_buffer('weight_type', torch.zeros((), dtype=torch.int)) 302 | 303 | if bias: 304 | self.register_buffer( 305 | 'bias', torch.zeros((outfeatures), dtype=torch.float16)) 306 | else: 307 | self.bias = None 308 | 309 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): 310 | if prefix + 'weight' in state_dict: 311 | weight = state_dict[prefix + 'weight'] 312 | self.weight.materialize(weight.shape, dtype=weight.dtype) 313 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 314 | 315 | def forward(self, x): 316 | xshape = x.view(-1, x.shape[-1]) 317 | if self.weight_type_int < 2: 318 | output = xshape @ self.weight.view(self.outfeatures, self.infeatures).T 319 | # Force to use dequant for 2-bit model for now 320 | elif xshape.shape[0] == 1: 321 | output = torch.ops.llama_cpp.ggml_mul_mat_vec_a8(self.weight, xshape, self.weight_type_int, self.outfeatures) 322 | elif xshape.shape[0] < 8 and self.weight_type_int < 16: 323 | output = torch.ops.llama_cpp.ggml_mul_mat_a8(self.weight, xshape, self.weight_type_int, self.outfeatures) 324 | else: 325 | weight = torch.ops.llama_cpp.ggml_dequantize(self.weight, self.weight_type_int, self.outfeatures, self.infeatures) 326 | output = xshape @ weight.T 327 | if self.bias is not None: 328 | output = output + self.bias 329 | output = output.view(*x.shape[:-1], self.outfeatures) 330 | return output 331 | 332 | class Embedding(nn.Module): 333 | """Quantized embedding layer""" 334 | def __init__(self, vocab_size, dim): 335 | super().__init__() 336 | self.vocab_size = vocab_size 337 | self.dim = dim 338 | # Fake weight 339 | self.register_buffer('weight', torch.nn.parameter.UninitializedBuffer()) 340 | self.register_buffer('weight_type', torch.zeros((), dtype=torch.int)) 341 | 342 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): 343 | if prefix + 'weight' in state_dict: 344 | weight = state_dict[prefix + 'weight'] 345 | self.weight.materialize(weight.shape, dtype=weight.dtype) 346 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 347 | 348 | def forward(self, ind): 349 | if self.weight_type_int < 2: 350 | return torch.embedding(self.weight.view(self.vocab_size, self.dim), ind) 351 | ind_flat = ind.flatten() 352 | quant = torch.index_select(self.weight.view(self.vocab_size, -1), dim=0, index=ind_flat) 353 | dequant = torch.ops.llama_cpp.ggml_dequantize(quant, self.weight_type_int, 354 | self.dim, ind_flat.shape[0]) 355 | return dequant.view(*ind.shape, self.dim) 356 | 357 | 358 | def precompute_freqs_cis( 359 | seq_len: int, n_elem: int, base: int = 10000 360 | ) -> Tensor: 361 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) 362 | t = torch.arange(seq_len, device=freqs.device) 363 | freqs = torch.outer(t, freqs) 364 | freqs_cis = 
torch.polar(torch.ones_like(freqs), freqs) 365 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 366 | return cache.to(dtype=torch.float16) 367 | 368 | 369 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, rope_type: str) -> Tensor: 370 | if rope_type == "norm": 371 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 372 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 373 | x_out2 = torch.stack( 374 | [ 375 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 376 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 377 | ], 378 | -1, 379 | ) 380 | else: 381 | xshaped = x.float().reshape(*x.shape[:-1], 2, -1) 382 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(-1), 2) 383 | x_out2 = torch.stack( 384 | [ 385 | xshaped[..., 0, :] * freqs_cis[..., 0] - xshaped[..., 1, :] * freqs_cis[..., 1], 386 | xshaped[..., 1, :] * freqs_cis[..., 0] + xshaped[..., 0, :] * freqs_cis[..., 1], 387 | ], 388 | -1, 389 | ) 390 | 391 | x_out2 = x_out2.flatten(3) 392 | return x_out2.type_as(x) 393 | -------------------------------------------------------------------------------- /py_bind.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <cstdint> 3 | 4 | torch::Tensor ggml_dequantize( 5 | torch::Tensor X, 6 | int8_t type, 7 | int64_t m, 8 | int64_t n 9 | ); 10 | 11 | torch::Tensor ggml_mul_mat_vec( 12 | torch::Tensor W, // quant weight 13 | torch::Tensor X, // input 14 | int8_t type, 15 | int64_t m 16 | ); 17 | 18 | torch::Tensor ggml_mul_mat_vec_a8( 19 | torch::Tensor W, // quant weight 20 | torch::Tensor X, // input 21 | int8_t type, 22 | int64_t row 23 | ); 24 | 25 | torch::Tensor ggml_mul_mat_a8( 26 | torch::Tensor W, // quant weight 27 | torch::Tensor X, // input 28 | int8_t type, 29 | int64_t row 30 | ); 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("ggml_dequantize", &ggml_dequantize, "ggml_dequantize"); 34 | m.def("ggml_mul_mat_vec", &ggml_mul_mat_vec, "ggml_mul_mat_vec"); 35 | m.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8, "ggml_mul_mat_vec_a8"); 36 | m.def("ggml_mul_mat_a8", &ggml_mul_mat_a8, "ggml_mul_mat_a8"); 37 | } 38 | -------------------------------------------------------------------------------- /register_lib.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import llamacpp_cuda 3 | 4 | import torch._custom_ops 5 | from torch import Tensor 6 | 7 | my_lib = torch.library.Library("llama_cpp", "DEF") 8 | 9 | @torch._custom_ops.custom_op("llama_cpp::ggml_dequantize") 10 | def ggml_dequantize(x: Tensor, type: int, m: int, n: int) -> Tensor: 11 | raise NotImplementedError() 12 | 13 | @torch._custom_ops.impl_abstract("llama_cpp::ggml_dequantize") 14 | def ggml_dequantize_abs(x: Tensor, type: int, m: int, n: int) -> Tensor: 15 | return x.new_empty((m, n), dtype=torch.half) 16 | 17 | @torch._custom_ops.impl("llama_cpp::ggml_dequantize", device_types="cuda") 18 | def ggml_dequantize_cuda(x: Tensor, type: int, m: int, n: int) -> Tensor: 19 | return llamacpp_cuda.ggml_dequantize(x, type, m, n) 20 | 21 | @torch._custom_ops.custom_op("llama_cpp::ggml_mul_mat_vec") 22 | def ggml_mul_mat_vec(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 23 | raise NotImplementedError() 24 | 25 | @torch._custom_ops.impl_abstract("llama_cpp::ggml_mul_mat_vec") 26 | def ggml_mul_mat_vec_abs(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 27 | assert x.device == y.device 28 | result = 
x.new_empty((1, m), dtype=torch.half) 29 | return result 30 | 31 | @torch._custom_ops.impl("llama_cpp::ggml_mul_mat_vec", device_types="cuda") 32 | def ggml_mul_mat_vec_cuda(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 33 | return llamacpp_cuda.ggml_mul_mat_vec(x, y, type, m) 34 | 35 | @torch._custom_ops.custom_op("llama_cpp::ggml_mul_mat_vec_a8") 36 | def ggml_mul_mat_vec_a8(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 37 | raise NotImplementedError() 38 | 39 | @torch._custom_ops.impl_abstract("llama_cpp::ggml_mul_mat_vec_a8") 40 | def ggml_mul_mat_vec_a8_abs(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 41 | assert x.device == y.device 42 | result = x.new_empty((1, m), dtype=torch.half) 43 | return result 44 | 45 | @torch._custom_ops.impl("llama_cpp::ggml_mul_mat_vec_a8", device_types="cuda") 46 | def ggml_mul_mat_vec_a8_cuda(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 47 | return llamacpp_cuda.ggml_mul_mat_vec_a8(x, y, type, m) 48 | 49 | @torch._custom_ops.custom_op("llama_cpp::ggml_mul_mat_a8") 50 | def ggml_mul_mat_a8(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 51 | raise NotImplementedError() 52 | 53 | @torch._custom_ops.impl_abstract("llama_cpp::ggml_mul_mat_a8") 54 | def ggml_mul_mat_a8_abs(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 55 | assert x.device == y.device 56 | result = x.new_empty((x.shape[0], m), dtype=torch.half) 57 | return result 58 | 59 | @torch._custom_ops.impl("llama_cpp::ggml_mul_mat_a8", device_types="cuda") 60 | def ggml_mul_mat_a8_cuda(x: Tensor, y: Tensor, type: int, m: int) -> Tensor: 61 | return llamacpp_cuda.ggml_mul_mat_a8(x, y, type, m) 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gguf 2 | jinja2 3 | torch 4 | sentencepiece -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from torch.utils import cpp_extension 3 | 4 | setup(name='llamacpp_cuda', 5 | ext_modules=[cpp_extension.CUDAExtension( 6 | 'llamacpp_cuda', 7 | ['py_bind.cpp', 'llamacpp_kernel.cu'], 8 | extra_compile_args={'cxx': ['-g', '-lineinfo', '-fno-strict-aliasing'], 9 | 'nvcc': ['-O3', '-g', '-Xcompiler', '-rdynamic', '-lineinfo']})], 10 | cmdclass={'build_ext': cpp_extension.BuildExtension}) 11 | -------------------------------------------------------------------------------- /tp.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/pytorch-labs/gpt-fast/blob/main/tp.py 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
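# Usage sketch (added note, hedged): tp.py shards the quantized Linear layers
# column/row-wise and splits MoE experts across ranks. It only takes effect when
# the calling script (e.g. generate.py, assuming it is wired up the same way as
# gpt-fast) calls maybe_init_dist() and apply_tp(model) under a multi-process
# launcher, for example:
#
#   torchrun --standalone --nproc_per_node=2 generate.py --checkpoint_path <converted_dir>
#
# torchrun sets LOCAL_RANK / LOCAL_WORLD_SIZE, which _get_rank() and
# _get_world_size() below read; with a single process maybe_init_dist() returns
# None and tensor parallelism is skipped.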
7 | import os 8 | from typing import Optional 9 | 10 | import numpy as np 11 | import torch 12 | import torch.distributed as dist 13 | from torch import nn 14 | 15 | from model import Transformer, Linear 16 | 17 | QK_K = 256 18 | # Items here are (block size, type size) 19 | GGML_QUANT_SIZES = { 20 | 0: (1, 4), 21 | 1: (1, 2), 22 | 2: (32, 2 + 16), 23 | 3: (32, 2 + 2 + 16), 24 | 6: (32, 2 + 4 + 16), 25 | 7: (32, 2 + 2 + 4 + 16), 26 | 8: (32, 2 + 32), 27 | 9: (32, 4 + 4 + 32), 28 | 10: (256, 2 + 2 + QK_K // 16 + QK_K // 4), 29 | 11: (256, 2 + QK_K // 4 + QK_K // 8 + 12), 30 | 12: (256, 2 + 2 + QK_K // 2 + 12), 31 | 13: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), 32 | 14: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), 33 | 15: (256, 4 + QK_K + QK_K // 8), 34 | 16: (256, 2 + QK_K // 4), 35 | 17: (256, 2 + QK_K // 4 + QK_K // 32), 36 | 18: (256, 2 + 3 * QK_K // 8), 37 | 19: (256, 2 + QK_K // 8 + QK_K // 16), 38 | 20: (32, 2 + 32 // 2), 39 | 21: (256, 2 + QK_K // 4 + QK_K // 32 + QK_K // 8 + QK_K // 64), 40 | 22: (256, 2 + QK_K // 4 + QK_K // 32 + QK_K // 32), 41 | 23: (256, 2 + 2 + QK_K // 64 + QK_K // 2), 42 | } 43 | 44 | 45 | def _get_rank() -> int: 46 | return int(os.environ.get("LOCAL_RANK", "0")) 47 | 48 | def is_local(): 49 | return _get_rank() == 0 50 | 51 | def local_break(): 52 | if is_local(): 53 | breakpoint() 54 | dist.barrier() 55 | 56 | def _get_world_size() -> int: 57 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 58 | 59 | def maybe_init_dist() -> Optional[int]: 60 | try: 61 | # provided by torchrun 62 | rank = _get_rank() 63 | world_size = _get_world_size() 64 | 65 | if world_size < 2: 66 | # too few gpus to parallelize, tp is no-op 67 | return None 68 | except KeyError: 69 | # not run via torchrun, no-op 70 | return None 71 | 72 | torch.cuda.set_device(rank) 73 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 74 | return rank 75 | 76 | 77 | def _apply_tp_linear(linear: Linear, style: str) -> None: 78 | rank = _get_rank() 79 | world_size = _get_world_size() 80 | 81 | block_size = GGML_QUANT_SIZES[linear.weight_type_int][0] 82 | assert linear.infeatures % block_size == 0 83 | 84 | def shard(x, dim): 85 | assert x.size(dim=dim) % world_size == 0 86 | return torch.tensor_split(x, world_size, dim=dim)[rank] 87 | 88 | weight = linear.weight.view(linear.outfeatures, linear.infeatures // block_size, -1) 89 | if style == "colwise": 90 | sharded_weight = shard(weight, 0) 91 | linear.outfeatures = linear.outfeatures // world_size 92 | if linear.bias is not None: 93 | linear.bias = nn.Parameter(shard(linear.bias, 0), requires_grad=False) 94 | else: 95 | sharded_weight = shard(weight, 1) 96 | linear.infeatures = linear.infeatures // world_size 97 | if linear.bias is not None: 98 | linear.bias = linear.bias / world_size 99 | linear.weight = nn.Parameter(sharded_weight.contiguous().view(-1), requires_grad=False) 100 | 101 | 102 | def _apply_tp_Transformer(Transformer: Transformer) -> None: 103 | # overwrite config before Transformer.setup_cache is called 104 | world_size = _get_world_size() 105 | Transformer.config.n_head = Transformer.config.n_head // world_size 106 | Transformer.config.dim = Transformer.config.dim // world_size 107 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size 108 | 109 | 110 | def apply_tp(model: Transformer) -> None: 111 | rank = _get_rank() 112 | world_size = _get_world_size() 113 | _apply_tp_Transformer(model) 114 | for block in model.blk: 115 | if isinstance(block.ffn_gate, nn.ModuleList): 116 | # 
Expert parallelism for MoE: each rank keeps only its own slice of experts 117 | expert_indices = np.array_split(range( 118 | block.num_experts), world_size)[rank].tolist() 119 | block.ffn_gate = nn.ModuleList(block.ffn_gate[i] if i in expert_indices else None for i in range(block.num_experts)) 120 | block.ffn_up = nn.ModuleList(block.ffn_up[i] if i in expert_indices else None for i in range(block.num_experts)) 121 | block.ffn_down = nn.ModuleList(block.ffn_down[i] if i in expert_indices else None for i in range(block.num_experts)) 122 | else: 123 | _apply_tp_linear(block.ffn_gate, "colwise") 124 | _apply_tp_linear(block.ffn_up, "colwise") 125 | _apply_tp_linear(block.ffn_down, "rowwise") 126 | _apply_tp_linear(block.attn_q, "colwise") 127 | _apply_tp_linear(block.attn_k, "colwise") 128 | _apply_tp_linear(block.attn_v, "colwise") 129 | _apply_tp_linear(block.attn_output, "rowwise") 130 | 131 | # overwrite per-block attributes to match the sharded weights 132 | block.n_head = block.n_head // world_size 133 | block.dim = block.dim // world_size 134 | block.head_dim = block.dim // block.n_head 135 | block.n_local_heads = block.n_local_heads // world_size 136 | --------------------------------------------------------------------------------
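
As a quick sanity check of the building blocks in `model.py`, the illustrative sketch below (not part of the repo) exercises `precompute_freqs_cis` / `apply_rotary_emb` and `KVCache` on arbitrary toy shapes. It assumes it is run from the repo root with the CUDA extension already installed, since importing `model` pulls in `register_lib`.

```python
import torch

# Importing model pulls in register_lib, so the compiled llamacpp_cuda
# extension must be installed first (see setup.py).
from model import KVCache, apply_rotary_emb, precompute_freqs_cis

# Rotary embeddings: the cache stores cos/sin pairs per position and channel.
freqs = precompute_freqs_cis(seq_len=2048, n_elem=64)        # [2048, 32, 2], float16
q = torch.randn(1, 8, 4, 64)                                 # [batch, seq, heads, head_dim]
input_pos = torch.arange(8)
q_rot = apply_rotary_emb(q, freqs[input_pos], rope_type="norm")
assert q_rot.shape == q.shape

# KVCache: fixed-size [batch, heads, max_seq, head_dim] buffers, written
# in place at the given positions and returned at full length.
cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=64)
k = torch.randn(1, 4, 8, 64, dtype=torch.float16)
v = torch.randn(1, 4, 8, 64, dtype=torch.float16)
k_all, v_all = cache.update(input_pos, k, v)                 # both [1, 4, 16, 64]
```

Note that the `"norm"` branch of `apply_rotary_emb` rotates interleaved channel pairs, while the fallback branch rotates the two halves of the rope dimension (GPT-NeoX style); both operate on `[batch, seq, heads, head_dim]` tensors before the transpose in `TransformerBlock.forward`.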