├── sllm ├── __init__.py ├── ops │ ├── __init__.py │ └── matmul.cpp ├── nn │ ├── __init__.py │ ├── autodiff.py │ ├── transformer.py │ └── layers.py ├── common.py ├── utils.py ├── config.py └── train.py ├── requirements.txt ├── .DS_Store ├── assets └── logo.png ├── pyproject.toml ├── example.py ├── setup.py ├── .gitignore ├── README.md └── LICENSE /sllm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | numpy 3 | torch 4 | transformers 5 | -------------------------------------------------------------------------------- /sllm/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .matmul import bundled_scaled_matmul 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HenryNdubuaku/super-lazy-autograd/HEAD/.DS_Store -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HenryNdubuaku/super-lazy-autograd/HEAD/assets/logo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel", "torch", "pybind11"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /sllm/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from sllm.nn.layers import (Attention, Linear, LoraLinear, 2 | MLP) 3 | from sllm.nn.transformer import SuperLazyLanguageModel 4 | -------------------------------------------------------------------------------- /sllm/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from platformdirs import user_cache_dir 5 | 6 | CACHE_DIR = user_cache_dir("sllm") 7 | WEIGHT_DIR = f"{CACHE_DIR}/weights" 8 | GRADIENT_DIR = f"{CACHE_DIR}/gradient_checkpoints" 9 | os.makedirs(GRADIENT_DIR, exist_ok=True) 10 | 11 | MAX_CONCURRENT_THREADS = os.cpu_count() * 2 12 | MINI_BATCH_SIZE = 2 13 | 14 | DTYPE = torch.float32 15 | MAX_GRAD_NORM = 0.01 16 | 17 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 18 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset 3 | 4 | from sllm.nn import SuperLazyLanguageModel 5 | from sllm.train import prepare_dataset, sft 6 | 7 | torch.manual_seed(42) 8 | 9 | name = "Qwen/Qwen2-0.5B-Instruct" 10 | dataset = load_dataset("yahma/alpaca-cleaned", split="train[:20]") 11 | 12 | dataset = prepare_dataset( 13 | model_name=name, 14 | instructions=dataset["instruction"], 15 | responses=dataset["output"], 16 | inputs=dataset["input"], 17 | max_seq_len=256, 18 | ) 19 | 20 | model = SuperLazyLanguageModel( 21 | name=name, 22 | lora_alpha=16, 23 | lora_r=4, 24 | lora_dropout=0.1, 25 | ) 26 | 27 | optimizer = 
torch.optim.AdamW(model.parameters(), lr=1e-5) 28 | sft(model=model, dataset=dataset, optimizer=optimizer, batch_size=2, epochs=3) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | ext_modules = [ 5 | CppExtension( 6 | name='sllm.ops.matmul', 7 | sources=['sllm/ops/matmul.cpp'], 8 | extra_compile_args=['-O3'] 9 | ) 10 | ] 11 | 12 | setup( 13 | name="sllm-lib", 14 | version="0.0.1", 15 | description="Super Lazy Language Model", 16 | long_description=open("README.md").read(), 17 | long_description_content_type="text/markdown", 18 | author="Henry Ndubuaku", 19 | packages=find_packages(), 20 | install_requires=[ 21 | "numpy", 22 | "transformers", 23 | "platformdirs", 24 | "tqdm", 25 | "pybind11" 26 | ], 27 | ext_modules=ext_modules, 28 | cmdclass={"build_ext": BuildExtension}, 29 | classifiers=[ 30 | "Programming Language :: Python :: 3", 31 | "License :: OSI Approved :: MIT License", 32 | "Operating System :: OS Independent", 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .nox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | *.py,cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Jupyter Notebook checkpoints 52 | .ipynb_checkpoints 53 | 54 | # IPython 55 | profile_default/ 56 | ipython_config.py 57 | 58 | # Environments 59 | .env 60 | .venv 61 | env/ 62 | venv/ 63 | ENV/ 64 | env.bak/ 65 | venv.bak/ 66 | 67 | # IDEs and editors 68 | .vscode/ 69 | .idea/ 70 | *.sublime-workspace 71 | *.sublime-project 72 | 73 | # PyCharm 74 | *.iml 75 | 76 | # System files 77 | .DS_Store 78 | Thumbs.db 79 | 80 | # Python cache files 81 | __pycache__/ 82 | *.py[cod] 83 | 84 | # mypy 85 | .mypy_cache/ 86 | .dmypy.json 87 | dmypy.json 88 | 89 | # Pyre type checker 90 | .pyre/ 91 | 92 | # pyright 93 | pyrightconfig.json 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Alt text 3 |

4 | 5 | ![License](https://img.shields.io/github/license/hmunachi/SuperLazyLanguageModel?style=flat-square)[![LinkedIn](https://img.shields.io/badge/-LinkedIn-blue?style=flat-square&logo=linkedin&logoColor=white)](https://www.linkedin.com//company/80434055) [![Twitter](https://img.shields.io/twitter/follow/hmunachii?style=social)](https://twitter.com/hmunachii) 6 | 7 | Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/) 8 | 9 | ## Overview 10 | 11 | I mean, do not train or fine-tune LLMs on your laptop. Training is done at much higher precision than inference (float32 or bfloat16), and additional memory is needed for the gradients, optimizer states, and the batch itself. So budget 4-6x the model size, roughly 8-24 GB of RAM per 1B params. 12 | 13 | HOWEVER, if you must do so on a laptop for whatever weird reason, this library implements the supported language models such that only the weights for the current layer are loaded into RAM, and it implements LoRA fine-tuning such that the frozen params are memory-mapped rather than loaded. 14 | 15 | Note the following: 16 | 1) Compute intensity = computation time / communication time, and maximising this means maximising hardware utilisation. 17 | 2) Many computations in transformer models can be parallelised, QKV projections for example. 18 | 3) Most operations in transformers follow the signature A @ B * Scale, a.k.a. a scaled dot product. 19 | 4) Q @ K.T / sqrt(dimK) is obviously equivalent to Q @ K.T * dimK^(-1/2). 20 | 5) Likewise, Lora_A @ Lora_B = Lora_A @ Lora_B * 1, and A * B = I @ A * B, and so on. 21 | 22 | We express the transformer forward pass and the backward vector-Jacobian products for each layer as a bunch of scaled matmuls, which are bundled together and executed in parallel across different CPU cores in a C++ extension to bypass the GIL. This design also enables an upcoming feature, where each bundle could be distributed across your friends' laptops, since they would only ever execute one operation, the Bundled Scaled Matmul. You're welcome. 23 | 24 | ## Limitations 25 | 26 | 1) Gradient accumulation, gradient checkpointing and lazy execution trade time-complexity for memory-efficiency, but you have no choice, do you? 27 | 2) Yeah... your laptop will definitely heat up. GPUs run hot in data centers and cost a lot to cool; your laptop is not special. 28 | 29 | ## Supported Models 30 | 31 | - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B 32 | - Qwen/Qwen2.5-0.5B 33 | - Qwen/Qwen2.5-0.5B-Instruct 34 | - Qwen/Qwen2.5-1.5B 35 | - Qwen/Qwen2.5-1.5B-Instruct 36 | - Qwen/Qwen2.5-3B 37 | - Qwen/Qwen2.5-3B-Instruct 38 | 39 | ## Getting Started 40 | 41 | 1. ```bash 42 | pip install sllm-lib 43 | ``` 44 | 2. Initialize the model: 45 | ```python 46 | from sllm.nn import SuperLazyLanguageModel 47 | from sllm.config import Config 48 | 49 | config = Config( 50 | model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", 51 | lora_alpha=32, 52 | lora_r=8, 53 | lora_dropout=0.1, 54 | ) 55 | 56 | model = SuperLazyLanguageModel(config) 57 | 58 | # Train like a normal PyTorch model 59 | ``` 60 | 3. 
You can use SLLM functionalities: 61 | ```python 62 | import torch 63 | from datasets import load_dataset 64 | 65 | from sllm.nn import SuperLazyLanguageModel 66 | from sllm.train import sft, prepare_dataset 67 | 68 | torch.manual_seed(42) 69 | 70 | name = "Qwen/Qwen2-0.5B-Instruct" 71 | dataset = load_dataset("yahma/alpaca-cleaned", split="train[:200]") 72 | 73 | dataset = prepare_dataset( 74 | model_name=name, 75 | instructions=dataset["instruction"], 76 | responses=dataset["output"], 77 | inputs=dataset["input"], 78 | max_seq_len=256, 79 | ) 80 | 81 | model = SuperLazyLanguageModel( 82 | name=name, 83 | lora_alpha=32, 84 | lora_r=8, 85 | lora_dropout=0.1, 86 | ) 87 | 88 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) 89 | sft(model=model, dataset=dataset, optimizer=optimizer, batch_size=8, epochs=3) 90 | ``` 91 | 92 | ## Contributing 93 | Whether you’re improving documentation, optimizing kernels, or adding new features, your help is invaluable. 94 | 95 | 1. Create a feature branch (`git checkout -b feature/awesome-improvement`). 96 | 2. Commit your changes (`git commit -m 'Add awesome feature'`). 97 | 3. Push to the branch (`git push origin feature/awesome-improvement`). 98 | 4. Open a Pull Request. 99 | -------------------------------------------------------------------------------- /sllm/utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import math 3 | import os 4 | import random 5 | import shutil 6 | 7 | import numpy as np 8 | import torch 9 | from transformers import AutoModelForCausalLM 10 | 11 | from sllm.common import DTYPE, GRADIENT_DIR 12 | 13 | 14 | def load_tensor_from_storage(weight_path, shape, dtype=DTYPE, to_ram=False): 15 | """ 16 | Load a tensor from a binary file. 17 | 18 | This function reads raw data from a file using torch.from_file and reshapes it into the desired tensor shape. 19 | If the flag 'to_ram' is set to True, the tensor is cloned into RAM to avoid memory-mapping. 20 | 21 | Args: 22 | weight_path (str): The file path to the binary weight file. 23 | shape (tuple or int): Desired shape of the tensor. If a tuple is provided, the total number 24 | of elements is computed as the product of the tuple components; otherwise, the shape is treated as the total size. 25 | dtype (torch.dtype, optional): Data type of the tensor. Defaults to DTYPE. 26 | to_ram (bool, optional): If True, clones the tensor to ensure it resides in RAM. Defaults to False. 27 | 28 | Returns: 29 | torch.Tensor: The tensor with the specified shape loaded from the file. 30 | """ 31 | if isinstance(shape, tuple): 32 | size = math.prod(shape) 33 | else: 34 | size = shape 35 | 36 | data = torch.from_file(weight_path, shared=False, size=size, dtype=dtype) 37 | data = data.view(shape) 38 | 39 | if to_ram: 40 | data = data.clone() 41 | 42 | return data 43 | 44 | 45 | @torch._dynamo.disable 46 | def save_tensor_to_storage(weight_path, data): 47 | """ 48 | Save a tensor to a binary file. 49 | 50 | The function first ensures the tensor is contiguous and detached from the computation graph. 51 | It then converts the tensor to a NumPy array and writes it to the specified file in binary format. 52 | 53 | Args: 54 | weight_path (str): The file path where the tensor will be saved. 55 | data (torch.Tensor): The tensor to be saved. 
56 | 57 | Returns: 58 | None 59 | """ 60 | data.contiguous().detach().numpy().tofile(weight_path) 61 | 62 | 63 | def download_weights(weight_dir, model_name): 64 | """ 65 | Download and save the pretrained model weights to a local directory. 66 | 67 | This function checks if the local weight directory already exists. If not, it creates the directory, 68 | loads the pretrained model using Hugging Face's AutoModelForCausalLM with the specified torch dtype, 69 | and saves each model parameter as a binary file in the directory. After saving, the model is deleted 70 | and garbage is collected to free memory. 71 | 72 | Args: 73 | weight_dir (str): The local directory path where weights will be stored. 74 | model_name (str): The name or identifier of the pretrained model to download. 75 | 76 | Returns: 77 | None 78 | """ 79 | if os.path.exists(weight_dir): 80 | return 81 | 82 | os.makedirs(weight_dir, exist_ok=True) 83 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=DTYPE) 84 | 85 | for name, param in model.named_parameters(): 86 | file_path = f"{weight_dir}/{name}.bin" 87 | param.detach().numpy().tofile(file_path) 88 | 89 | del model 90 | gc.collect() 91 | 92 | 93 | def remove_weights(weight_dir): 94 | """ 95 | Remove the weights stored in a specified directory or file. 96 | 97 | If the provided path is a directory, the entire directory is deleted recursively. 98 | Otherwise, if it is a file, the file is removed. 99 | 100 | Args: 101 | weight_dir (str): The file or directory path where the weights are stored. 102 | 103 | Returns: 104 | None 105 | """ 106 | if os.path.exists(weight_dir): 107 | if os.path.isdir(weight_dir): 108 | shutil.rmtree(weight_dir) 109 | else: 110 | os.remove(weight_dir) 111 | 112 | 113 | def clear_gradient_dir(): 114 | """ 115 | Clear the gradient directory. 116 | 117 | This function removes the entire gradient directory (if it exists) and then creates an empty directory with the same path. 118 | This is used to manage temporary storage during gradient computations. 119 | 120 | Returns: 121 | None 122 | """ 123 | if os.path.exists(GRADIENT_DIR): 124 | shutil.rmtree(GRADIENT_DIR) 125 | os.makedirs(GRADIENT_DIR, exist_ok=True) 126 | -------------------------------------------------------------------------------- /sllm/ops/matmul.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file matmul.cpp 3 | * @brief PyTorch extension module for performing scaled matrix multiplication. 4 | * 5 | * This file implements functions to perform matrix multiplications with a scaling factor. 6 | * It provides both a simple operation and a bundled version that can execute multiple operations 7 | * concurrently using OpenMP for batch processing. The module is exposed to Python via PyBind11. 8 | */ 9 | 10 | #include 11 | #include 12 | #include 13 | #ifdef _OPENMP 14 | // OpenMP header included when available. 15 | #endif 16 | 17 | namespace py = pybind11; 18 | 19 | /** 20 | * @brief Performs scaled matrix multiplication. 21 | * 22 | * Computes the matrix product of A and B using torch::matmul and then scales the resulting tensor 23 | * by the provided scale factor. 24 | * 25 | * @param A A torch::Tensor representing the first matrix. 26 | * @param B A torch::Tensor representing the second matrix. 27 | * @param scale A double representing the scaling factor to be applied to the product. 28 | * @return torch::Tensor The scaled result of the matrix multiplication (i.e., (A * B) * scale). 
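 * For example (illustrative note, not from the original source; Q and Kt are hypothetical tensors), matmul(Q, Kt, 1.0 / std::sqrt(64.0)) returns (Q @ Kt) scaled by 64^(-1/2), i.e. scaled attention logits for a head dimension of 64.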
29 | */ 30 | torch::Tensor matmul(const torch::Tensor& A, const torch::Tensor& B, double scale) { 31 | return torch::matmul(A, B) * scale; 32 | } 33 | 34 | /** 35 | * @brief Executes a set of scaled matrix multiplications bundled together. 36 | * 37 | * Processes a vector of tuples where each tuple (referred to as a "bundle") contains three elements: 38 | * - The first element is a torch::Tensor representing the matrix A. 39 | * - The second element is a torch::Tensor representing the matrix B. 40 | * - The third element is a double specifying the scaling factor for the multiplication. 41 | * 42 | * The function reshapes each input tensor to separate batch dimensions from the matrix dimensions, 43 | * then performs the scaled matrix multiplication over the batch elements. When B_reshaped has a single 44 | * batch element, it is broadcasted across the entire batch of A_reshaped. OpenMP is used for parallelization 45 | * over the batch dimension if it is available. 46 | * 47 | * The final resulting tensor is reshaped to match the original input dimensions (except for the last dimension 48 | * which becomes the last dimension of B). The function returns a vector of results with each result corresponding 49 | * to a bundle in the input. 50 | * 51 | * @param matmul_bundles A vector of py::tuple objects, each containing: 52 | * - A torch::Tensor for matrix A, 53 | * - A torch::Tensor for matrix B, 54 | * - A double value for the scaling factor. 55 | * @return std::vector A vector containing the resulting tensors after performing the 56 | * bundled scaled matrix multiplications. 57 | */ 58 | std::vector bundled_scaled_matmul(const std::vector& matmul_bundles) { 59 | std::vector results; 60 | 61 | for (const auto& bundle : matmul_bundles) { 62 | auto original_A = bundle[0].cast(); 63 | auto B = bundle[1].cast(); 64 | double scale = bundle[2].cast(); 65 | 66 | auto A_shape = original_A.sizes(); 67 | auto B_shape = B.sizes(); 68 | 69 | // Reshape to isolate the batch dimension from matrix dimensions. 70 | auto A_reshaped = original_A.reshape({-1, A_shape[A_shape.size()-2], A_shape[A_shape.size()-1]}); 71 | auto B_reshaped = B.reshape({-1, B_shape[B_shape.size()-2], B_shape[B_shape.size()-1]}); 72 | int64_t batch = A_reshaped.size(0); 73 | auto C = torch::empty({batch, A_reshaped.size(1), B_reshaped.size(2)}, torch::kFloat32); 74 | 75 | // Parallelize over the batch dimension using OpenMP. 76 | #pragma omp parallel for 77 | for (int64_t i = 0; i < batch; ++i) { 78 | auto A_i = A_reshaped[i]; 79 | // Use the same B for all if B_reshaped only has one element; otherwise, index accordingly. 80 | auto b_tensor = (B_reshaped.size(0) == 1) ? B_reshaped[0] : B_reshaped[i]; 81 | C[i] = matmul(A_i, b_tensor, scale); 82 | } 83 | 84 | // Restore the original batch shape with the new matrix multiplication dimensions. 85 | std::vector result_shape; 86 | for (size_t j = 0; j < A_shape.size()-1; ++j) { 87 | result_shape.push_back(A_shape[j]); 88 | } 89 | result_shape.push_back(B_shape[B_shape.size()-1]); 90 | results.push_back(C.reshape(result_shape)); 91 | } 92 | return results; 93 | } 94 | 95 | /** 96 | * @brief PyBind11 module definition for the matrix multiplication extension. 97 | * 98 | * Exposes the bundled_scaled_matmul function to Python, allowing it to be called as a regular Python function. 99 | * The module is named "matmul". 
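 * Illustrative Python-side usage (added note, not from the original source; x, W_q and W_k are hypothetical tensors):
 *   from sllm.ops import bundled_scaled_matmul
 *   q, k = bundled_scaled_matmul([(x, W_q, 1.0), (x, W_k, 1.0)])  # one result tensor per (A, B, scale) bundle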
100 | */ 101 | PYBIND11_MODULE(matmul, m) { 102 | m.def("bundled_scaled_matmul", &bundled_scaled_matmul, "Distributes bundles of scaled matrix multiplication operations to remote devices concurrently"); 103 | } 104 | -------------------------------------------------------------------------------- /sllm/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines the configuration class for the transformer-based model. 3 | It encapsulates all the hyperparameters and weight configurations required for the model 4 | and handles downloading and merging with the pretrained configuration from Hugging Face. 5 | 6 | Key functionalities include: 7 | - Setting model hyperparameters such as hidden size, number of layers, attention heads, etc. 8 | - Configuring LoRA (Low-Rank Adaptation) parameters for efficient fine-tuning. 9 | - Downloading necessary weights for the specified model. 10 | - Overwriting default parameters with those from the pretrained configuration. 11 | """ 12 | 13 | import os 14 | 15 | from transformers import AutoConfig 16 | 17 | from sllm.common import WEIGHT_DIR 18 | from sllm.utils import download_weights, remove_weights 19 | 20 | 21 | class Config: 22 | """ 23 | A configuration object for initializing the transformer-based model. 24 | 25 | This class holds various hyperparameters required by the model. It downloads pretrained 26 | weights if needed and loads configurations from a pretrained model, overriding any matching 27 | attributes provided in the constructor. In addition to standard model settings, it also 28 | includes configuration details for LoRA adaptation. 29 | 30 | Class Attributes: 31 | model_type (str): Default type of the model, set to "transformer". 32 | keys_to_ignore_at_inference (List[str]): List of keys (e.g., "past_key_values") to be 33 | ignored during inference. 34 | """ 35 | 36 | model_type = "transformer" 37 | keys_to_ignore_at_inference = ["past_key_values"] 38 | 39 | def __init__( 40 | self, 41 | model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", 42 | vocab_size=151936, 43 | bos_token_id=151643, 44 | eos_token_id=151643, 45 | model_type="qwen2", 46 | hidden_size=1536, 47 | intermediate_size=8960, 48 | num_hidden_layers=28, 49 | num_attention_heads=12, 50 | num_key_value_heads=2, 51 | hidden_act="silu", 52 | max_position_embeddings=131072, 53 | initializer_range=0.02, 54 | rms_norm_eps=1e-6, 55 | use_cache=True, 56 | tie_word_embeddings=False, 57 | rope_theta=10000.0, 58 | rope_scaling=None, 59 | use_mrope=False, 60 | use_sliding_window=False, 61 | sliding_window=None, 62 | max_window_layers=21, 63 | attention_dropout=0.0, 64 | torch_dtype="bfloat16", 65 | pad_token_id=None, 66 | output_attentions=False, 67 | output_hidden_states=False, 68 | use_return_dict=True, 69 | lora_alpha=32, 70 | lora_r=8, 71 | lora_dropout=0.1, 72 | **kwargs, 73 | ): 74 | """ 75 | Initialize the Config object with the provided hyperparameters and then update with the pretrained configuration. 76 | 77 | The initialization process involves: 78 | 1. Setting initial values for many model hyperparameters and LoRA parameters. 79 | 2. Determining the local weight directory based on the provided model name. 80 | 3. Downloading the pretrained weights (if not already present) into the specified directory. 81 | 4. Loading additional configuration parameters from the pretrained model via AutoConfig, 82 | and updating the current configuration for any matching attributes. 
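        Example (illustrative sketch, not from the original source; note that constructing a
        Config downloads the pretrained weights on first use):

            config = Config(model_name="Qwen/Qwen2.5-0.5B", lora_r=8, lora_alpha=32)
            print(config.weight_dir)  # e.g. "<cache dir>/weights/Qwen2.5-0.5B"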
83 | 84 | Side Effects: 85 | - Downloads the pretrained weights (if not already available) to a local weight directory. 86 | - Updates attributes of this Config instance based on the pretrained model's configuration. 87 | """ 88 | self.vocab_size = vocab_size 89 | self.max_position_embeddings = max_position_embeddings 90 | self.hidden_size = hidden_size 91 | self.intermediate_size = intermediate_size 92 | self.num_hidden_layers = num_hidden_layers 93 | self.num_attention_heads = num_attention_heads 94 | self.use_sliding_window = use_sliding_window 95 | self.sliding_window = sliding_window 96 | self.max_window_layers = max_window_layers 97 | 98 | if num_key_value_heads is None: 99 | num_key_value_heads = num_attention_heads 100 | 101 | self.num_key_value_heads = num_key_value_heads 102 | self.hidden_act = hidden_act 103 | self.initializer_range = initializer_range 104 | self.rms_norm_eps = rms_norm_eps 105 | self.use_cache = use_cache 106 | self.rope_theta = rope_theta 107 | self.rope_scaling = rope_scaling 108 | self.attention_dropout = attention_dropout 109 | 110 | self.use_mrope = use_mrope 111 | self.torch_dtype = torch_dtype 112 | self.tie_word_embeddings = tie_word_embeddings 113 | self.bos_token_id = bos_token_id 114 | self.eos_token_id = eos_token_id 115 | self.model_type = model_type 116 | self.pad_token_id = pad_token_id 117 | self.output_attentions = output_attentions 118 | self.output_hidden_states = output_hidden_states 119 | self.use_return_dict = use_return_dict 120 | 121 | self.lora_alpha = lora_alpha 122 | self.lora_r = lora_r 123 | self.lora_dropout = lora_dropout 124 | 125 | # Determine local weight storage directory based on the model name. 126 | model_dir = model_name.split("/")[-1] 127 | self.weight_dir = f"{WEIGHT_DIR}/{model_dir}" 128 | 129 | # Ensure that the pretrained weights are downloaded locally. 130 | download_weights(self.weight_dir, model_name) 131 | 132 | # Load additional configuration parameters from the pretrained model. 133 | config = AutoConfig.from_pretrained(model_name, **kwargs) 134 | 135 | # Overwrite any matching attributes in this Config instance with the pretrained values. 136 | for key, value in config.__dict__.items(): 137 | if hasattr(self, key): 138 | setattr(self, key, value) 139 | -------------------------------------------------------------------------------- /sllm/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import Dataset 3 | from tqdm import tqdm 4 | from transformers import AutoTokenizer 5 | import time # add time module if not imported 6 | 7 | from sllm.utils import clear_gradient_dir 8 | from sllm.common import MINI_BATCH_SIZE 9 | 10 | 11 | def format_example(example): 12 | """ 13 | Format a single example into a prompt and full training text. 14 | 15 | The function builds a prompt with a header (containing the instruction and, if available, 16 | additional input) followed by a "Response:" marker. It then concatenates the prompt with 17 | the output to form the full text for training. 18 | 19 | Args: 20 | example (dict): A dictionary with keys 'instruction', 'input', and 'output'. 21 | 22 | Returns: 23 | dict: A dictionary with keys: 24 | - "text": A string that concatenates the prompt and the output. 25 | - "prompt_text": The text before adding the output. 
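    Example (illustrative, not from the original source):
        >>> format_example({"instruction": "Add the numbers", "input": "2 and 3", "output": "5"})
        {'text': 'Instruction: Add the numbers\nInput: 2 and 3\nResponse: 5', 'prompt_text': 'Instruction: Add the numbers\nInput: 2 and 3\nResponse:'}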
26 | """ 27 | prompt_text = f"Instruction: {example['instruction']}\n" 28 | if example["input"]: 29 | prompt_text += f"Input: {example['input']}\n" 30 | prompt_text += "Response:" 31 | full_text = prompt_text + " " + example["output"] 32 | return {"text": full_text, "prompt_text": prompt_text} 33 | 34 | 35 | def prepare_dataset(model_name, instructions, responses, inputs=None, max_seq_len=256): 36 | """ 37 | Prepare and tokenize a dataset for training. 38 | 39 | The function creates a Hugging Face Dataset from the provided instructions, responses, and optional inputs. 40 | It then formats each example using `format_example`, tokenizes the text (padding/truncating to `max_seq_len`), 41 | and creates a special label mask in which the tokens corresponding to the prompt are replaced by -100 (to be ignored 42 | during loss computation). Finally, the tokenized dataset is split into mini-batches. 43 | 44 | Args: 45 | model_name (str): The pretrained model name used to load the tokenizer. 46 | instructions (List[str]): A list of instruction strings. 47 | responses (List[str]): A list of response strings. 48 | inputs (List[str] or None, optional): A list of additional input strings for examples; defaults to None. 49 | max_seq_len (int, optional): Maximum sequence length for tokenization; defaults to 256. 50 | 51 | Returns: 52 | dict: A dictionary with keys 'input_ids', 'attention_mask', and 'labels', 53 | where each value is a list of tensors split into mini-batches of size MINI_BATCH_SIZE. 54 | """ 55 | def tokenize_fn(example): 56 | tokenized = tokenizer( 57 | example["text"], 58 | truncation=True, 59 | padding="max_length", 60 | max_length=max_seq_len, 61 | ) 62 | prompt_tokens = tokenizer( 63 | example["prompt_text"], 64 | truncation=False, 65 | add_special_tokens=False, 66 | )["input_ids"] 67 | tokenized["prompt_length"] = len(prompt_tokens) 68 | return tokenized 69 | 70 | def mask_labels(example): 71 | # Copy the tokenized input_ids to labels. 72 | labels = example["input_ids"].copy() 73 | prompt_length = example["prompt_length"] 74 | # Mask the tokens corresponding to the prompt (set to -100) so that they do not contribute to the loss. 75 | for i in range(min(prompt_length, len(labels))): 76 | labels[i] = -100 77 | example["labels"] = labels 78 | return example 79 | 80 | dataset = Dataset.from_dict( 81 | { 82 | "instruction": instructions, 83 | "input": inputs, 84 | "output": responses, 85 | } 86 | ) 87 | dataset = dataset.map(format_example) 88 | 89 | tokenizer = AutoTokenizer.from_pretrained(model_name) 90 | dataset = dataset.map(tokenize_fn) 91 | dataset = dataset.map(mask_labels) 92 | 93 | dataset = dataset.shuffle(seed=42) 94 | 95 | input_ids = torch.tensor(dataset["input_ids"]) 96 | attention_masks = torch.tensor(dataset["attention_mask"]) 97 | labels = torch.tensor(dataset["labels"]) 98 | print(f"Training on {input_ids.numel() // 1000}k tokens") 99 | return { 100 | "input_ids": input_ids.split(MINI_BATCH_SIZE), 101 | "attention_mask": attention_masks.split(MINI_BATCH_SIZE), 102 | "labels": labels.split(MINI_BATCH_SIZE), 103 | } 104 | 105 | 106 | def sft(model, dataset, optimizer, batch_size=1, epochs=1): 107 | """ 108 | Perform supervised fine-tuning (SFT) on a language model. 109 | 110 | This function trains the model on the provided dataset using gradient accumulation. 111 | The effective batch size is determined as the product of MINI_BATCH_SIZE and grad_accum_steps. 
112 | 113 | Gradient Accumulation Derivation: 114 | If the provided batch_size is larger than MINI_BATCH_SIZE (the size of each mini-batch), 115 | the gradients of several forward/backward passes are accumulated before performing an optimizer step. 116 | In order to ensure that the effective gradient is equivalent to that computed on the entire batch, 117 | the loss of each mini-batch is divided by grad_accum_steps. That is, 118 | loss_effective = (loss_mini_batch / grad_accum_steps) 119 | This scaling ensures that when the gradients from grad_accum_steps mini-batches are summed, 120 | the resulting update is equivalent to the gradient of the average loss over the full batch. 121 | 122 | During training, the function reports the average loss, seconds per sample, and seconds elapsed per epoch. 123 | 124 | Args: 125 | model (torch.nn.Module): The model to be fine-tuned. 126 | dataset (dict): A dictionary with keys 'input_ids', 'attention_mask', and 'labels', 127 | where each value is a list of tensors representing mini-batches. 128 | optimizer (torch.optim.Optimizer): The optimizer for the model. 129 | batch_size (int, optional): The total batch size for each optimizer update; defaults to 1. 130 | epochs (int, optional): Number of training epochs; defaults to 1. 131 | 132 | Returns: 133 | None 134 | """ 135 | grad_accum_steps = batch_size // MINI_BATCH_SIZE 136 | if grad_accum_steps < 1: 137 | grad_accum_steps = 1 138 | 139 | total_batches = len(dataset["input_ids"]) 140 | for epoch in range(epochs): 141 | epoch_loss = 0 142 | model.train() 143 | optimizer.zero_grad() 144 | clear_gradient_dir() 145 | accum_steps = 0 146 | 147 | epoch_start_time = time.time() 148 | 149 | with tqdm(total=total_batches, desc=f"Epoch {epoch+1}") as pbar: 150 | zipped_dataset = zip( 151 | dataset["input_ids"], dataset["attention_mask"], dataset["labels"] 152 | ) 153 | for i, batch in enumerate(zipped_dataset, start=1): 154 | batch_input, batch_mask, batch_labels = batch 155 | output = model( 156 | input_ids=batch_input, 157 | attention_mask=batch_mask, 158 | labels=batch_labels, 159 | ) 160 | loss = output.loss 161 | # Scale the mini-batch loss to average over grad_accum_steps 162 | loss = loss / grad_accum_steps 163 | loss.backward() 164 | epoch_loss += loss.item() * grad_accum_steps 165 | accum_steps += 1 166 | 167 | if accum_steps % grad_accum_steps == 0: 168 | optimizer.step() 169 | optimizer.zero_grad() 170 | accum_steps = 0 171 | 172 | avg_loss = epoch_loss / i 173 | elapsed = time.time() - epoch_start_time 174 | sec_per_sample = elapsed / (i * MINI_BATCH_SIZE) 175 | sec_per_epoch = time.time() - epoch_start_time 176 | pbar.set_postfix( 177 | loss=f"{avg_loss:.1f}", 178 | sec_per_sample=f"{sec_per_sample:.2f}", 179 | sec_per_epoch=f"{sec_per_epoch:.2f}", 180 | ) 181 | pbar.update(1) 182 | 183 | # Apply any remaining accumulated gradients. 184 | if accum_steps > 0: 185 | optimizer.step() 186 | optimizer.zero_grad() 187 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /sllm/nn/autodiff.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements custom autograd functions for performing matrix multiplications, 3 | including LoRA-adapted operations for efficient fine-tuning. It leverages lazy weight loading, 4 | bundled scaled matrix multiplications, and gradient clipping to manage memory and computational 5 | efficiency. 6 | 7 | Functions: 8 | clip_grad(grad): Clips a gradient tensor based on MAX_GRAD_NORM. 9 | 10 | Classes: 11 | MatmulFunction: Custom autograd function for computing scaled matrix multiplications. 12 | BundledMatmulFunction: Custom autograd function that computes multiple scaled matrix multiplications together. 13 | LoraFunction: Custom autograd function to perform a LoRA-adapted linear operation. 14 | LoraQKVLinearFunction: Custom autograd function for query, key, and value projections with LoRA adaptation. 15 | """ 16 | 17 | import os 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | from sllm.common import GRADIENT_DIR, MAX_GRAD_NORM 23 | from sllm.ops import bundled_scaled_matmul 24 | from sllm.utils import load_tensor_from_storage, save_tensor_to_storage 25 | 26 | 27 | def clip_grad(grad): 28 | """ 29 | Clip a gradient tensor to prevent exploding gradients. 30 | 31 | If the norm of `grad` exceeds MAX_GRAD_NORM, scale it down accordingly. 32 | 33 | Args: 34 | grad (Tensor or None): The gradient tensor to be clipped. 35 | 36 | Returns: 37 | Tensor or None: The clipped gradient tensor, or None if input is None. 38 | """ 39 | if grad is None: 40 | return None 41 | norm = grad.norm() 42 | if norm > MAX_GRAD_NORM: 43 | grad = grad * (MAX_GRAD_NORM / (norm + 1e-6)) 44 | return grad 45 | 46 | 47 | class MatmulFunction(torch.autograd.Function): 48 | """ 49 | Custom autograd function for performing a scaled matrix multiplication. 50 | 51 | Forward: 52 | Computes Y = bundled_scaled_matmul([(A, B, scale)])[0], which is equivalent to 53 | Y = (A @ B) * scale. 54 | 55 | Backward Derivation: 56 | Given Y = scale * (A @ B), standard matrix calculus yields: 57 | - dY/dA = scale * grad_output @ B^T 58 | - dY/dB = scale * A^T @ grad_output 59 | Since the scale is a constant and non-differentiable here, its derivative is not returned 60 | (i.e. None). In this implementation, the bundled_scaled_matmul function is used to compute: 61 | grad_A = grad_output @ B^T and grad_B = A^T @ grad_output, 62 | assuming that the scale factor is applied in the forward pass and treated as a constant. 63 | """ 64 | 65 | @staticmethod 66 | def forward(ctx, A, B, scale): 67 | """ 68 | Forward pass for the scaled matrix multiplication. 69 | 70 | Args: 71 | A (Tensor): Left-hand side matrix. 72 | B (Tensor): Right-hand side matrix. 73 | scale (float): Scaling factor applied after matrix multiplication. 74 | 75 | Returns: 76 | Tensor: Result of the scaled matrix multiplication. 
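        Example (illustrative sketch, not from the original source; assumes the sllm.ops C++ extension is built):
            >>> import math, torch
            >>> A, B = torch.randn(2, 4, 8), torch.randn(2, 8, 16)
            >>> Y = MatmulFunction.apply(A, B, 1.0 / math.sqrt(8))
            >>> torch.allclose(Y, (A @ B) / math.sqrt(8), atol=1e-5)
            True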
77 | """ 78 | ctx.save_for_backward(A, B) 79 | return bundled_scaled_matmul([(A, B, scale)])[0] 80 | 81 | @staticmethod 82 | def backward(ctx, grad_output): 83 | """ 84 | Backward pass for the scaled matrix multiplication. 85 | 86 | Derivation: 87 | Let Y = scale * (A @ B). Then, by the chain rule: 88 | dL/dA = dL/dY @ dY/dA = grad_output @ (B^T) * scale, 89 | dL/dB = (A^T) @ grad_output * scale. 90 | Here, the implementation uses bundled_scaled_matmul with a scaling factor of 1.0, 91 | effectively treating the scale as a constant whose derivative is omitted. 92 | 93 | Args: 94 | grad_output (Tensor): Gradient tensor propagated from subsequent layers. 95 | 96 | Returns: 97 | Tuple[Tensor, Tensor, None]: Gradients with respect to A, B, and None for scale. 98 | """ 99 | A, B = ctx.saved_tensors 100 | bundles = [ 101 | (grad_output, B.transpose(-2, -1), 1.0), # represents grad_output @ B^T 102 | (A.transpose(-2, -1), grad_output, 1.0), # represents A^T @ grad_output 103 | ] 104 | grad_A, grad_B = bundled_scaled_matmul(bundles) 105 | grad_A = clip_grad(grad_A) 106 | grad_B = clip_grad(grad_B) 107 | return grad_A, grad_B, None 108 | 109 | 110 | class BundledMatmulFunction(torch.autograd.Function): 111 | """ 112 | Custom autograd function for performing multiple scaled matrix multiplications in batch. 113 | 114 | Forward: 115 | Accepts a list of bundles where each bundle is a tuple (M, N, scale) and returns a list 116 | of results computed by bundled_scaled_matmul. 117 | 118 | Backward Derivation: 119 | For each bundle where the forward operation is Y_i = M_i @ N_i * scale_i, the gradient 120 | derivation uses the standard identities: 121 | dL/dM_i = grad_output @ (N_i^T) 122 | dL/dN_i = (M_i^T) @ grad_output 123 | Here, a new bundle is constructed for each stored tuple in the forward pass with a fixed scale (1.0) 124 | and processed in batch. 125 | """ 126 | 127 | @staticmethod 128 | def forward(ctx, bundles): 129 | """ 130 | Forward pass for the bundled scaled matrix multiplications. 131 | 132 | Args: 133 | bundles (List[Tuple[Tensor, Tensor, float]]): A list of tuples, each containing two tensors 134 | and a scaling factor. 135 | 136 | Returns: 137 | List[Tensor]: A list of tensors resulting from the matrix multiplications. 138 | """ 139 | ctx.save_for_backward(*bundles) 140 | return bundled_scaled_matmul(bundles) 141 | 142 | @staticmethod 143 | def backward(ctx, grad_output): 144 | """ 145 | Backward pass for the bundled matrix multiplications. 146 | 147 | Derivation: 148 | For each bundle (M, N, scale) corresponding to Y = M @ N * scale, the gradient with respect 149 | to M is grad_output @ (N^T) and with respect to N is (M^T) @ grad_output. The backward pass 150 | constructs a new list of bundles where for each saved bundle it forms: 151 | (grad_output, N^T, 1.0) 152 | and computes the corresponding gradients in one batched call. 153 | 154 | Args: 155 | grad_output (Tensor): Gradient tensor propagated from subsequent layers. 156 | 157 | Returns: 158 | Tuple[List[Tensor]]: Gradients corresponding to the input bundles. 159 | """ 160 | bundles = ctx.saved_tensors 161 | triple_list = [ 162 | (grad_output, bundle[1].transpose(-2, -1), 1.0) for bundle in bundles 163 | ] 164 | grad_bundles = bundled_scaled_matmul(triple_list) 165 | return grad_bundles 166 | 167 | 168 | class LoraFunction(torch.autograd.Function): 169 | """ 170 | Custom autograd function for LoRA-adapted linear operations. 
171 | 172 | Forward: 173 | Computes the effective weight as W_eff = W + (A @ B * scale) and then computes: 174 | Y = bundled_scaled_matmul([(x, W_eff^T, 1.0)])[0] 175 | Optionally adds a bias. 176 | 177 | Backward Derivation: 178 | Let W_eff = W + (A @ B * scale) and Y = x @ (W_eff)^T. 179 | Using the chain rule: 180 | dL/dx = grad_output @ W_eff 181 | dL/dW_eff = x^T @ grad_output 182 | Then, the derivatives with respect to A and B are computed via: 183 | dW_eff/dA = B * scale, dW_eff/dB = A * scale. 184 | The backward pass first computes an intermediate gradient E from the matrix multiplication, 185 | then derives: 186 | grad_A = E @ (B^T) * scale 187 | grad_B = (A^T) @ E * scale. 188 | The pre-trained weight W is assumed fixed, so its gradient is not computed. 189 | """ 190 | 191 | @staticmethod 192 | def forward(ctx, x, A, B, W_path, scale, bias=None): 193 | """ 194 | Forward pass for the LoRA function. 195 | 196 | Args: 197 | x (Tensor): Input tensor. 198 | A (Tensor): LoRA parameter A. 199 | B (Tensor): LoRA parameter B. 200 | W_path (str): File path to the pre-trained weight tensor. 201 | scale (float): Scaling factor for the LoRA update. 202 | bias (Tensor, optional): Bias tensor to add to the output. 203 | 204 | Returns: 205 | Tensor: Output tensor after applying the effective weight and bias. 206 | """ 207 | ctx.save_for_backward(A, B) 208 | ctx.scale = scale 209 | ctx.x_path = os.path.join(GRADIENT_DIR, W_path.split("/")[-1] + ".x.bin") 210 | save_tensor_to_storage(ctx.x_path, x) 211 | 212 | W = load_tensor_from_storage( 213 | weight_path=W_path, 214 | shape=(A.shape[0], B.shape[1]), 215 | dtype=A.dtype, 216 | to_ram=False, 217 | ) 218 | 219 | effective_W = W + (A @ B * scale) 220 | Wx = bundled_scaled_matmul([(x, effective_W.transpose(-2, -1), 1.0)])[0] 221 | 222 | ctx.effecttive_weight = effective_W 223 | 224 | if bias is not None: 225 | Wx = Wx + bias 226 | return Wx 227 | 228 | @staticmethod 229 | def backward(ctx, grad_output): 230 | """ 231 | Backward pass for the LoRA function. 232 | 233 | Derivation: 234 | Given: 235 | Y = x @ (W_eff)^T, with W_eff = W + (A @ B) * scale. 236 | Then: 237 | dL/dx = grad_output @ W_eff. 238 | dL/dW_eff = x^T @ grad_output. 239 | Since W is fixed, we only differentiate through the LoRA update. 240 | The gradient of W_eff with respect to A is: dW_eff/dA = B * scale, 241 | and with respect to B is: dW_eff/dB = A * scale. 242 | Therefore, an intermediate gradient E = dL/dW_eff is computed and then: 243 | grad_A = E @ (B^T) * scale, 244 | grad_B = (A^T) @ E * scale. 245 | 246 | Args: 247 | grad_output (Tensor): Gradient tensor from subsequent operations. 248 | 249 | Returns: 250 | Tuple: Gradients with respect to x, A, B, None (for weight), None (for scale), and None (for bias). 
251 | """ 252 | x = load_tensor_from_storage( 253 | ctx.x_path, shape=grad_output.shape, dtype=grad_output.dtype, to_ram=False 254 | ) 255 | A, B = ctx.saved_tensors 256 | scale = ctx.scale 257 | 258 | effective_W = ctx.effecttive_weight 259 | 260 | bundles = [ 261 | (grad_output, effective_W, 1.0), # dL/dx = grad_output @ W_eff 262 | (x.transpose(-2, -1), grad_output, 1.0), # dL/dW_eff = x^T @ grad_output 263 | ] 264 | grad_x, E = bundled_scaled_matmul(bundles) 265 | 266 | bundles = [ 267 | (E, B.transpose(-2, -1), scale), # grad_A = E @ B^T * scale 268 | (A.transpose(-2, -1), E, scale) # grad_B = A^T @ E * scale 269 | ] 270 | grad_A, grad_B = bundled_scaled_matmul(bundles) 271 | 272 | grad_w = None 273 | grad_scale = None 274 | grad_b = None 275 | 276 | return grad_x, grad_A, grad_B, grad_w, grad_scale, grad_b 277 | 278 | 279 | class LoraQKVLinearFunction(torch.autograd.Function): 280 | """ 281 | Custom autograd function for LoRA-adapted query/key/value projections. 282 | 283 | Forward: 284 | For each projection (Q, K, V), the effective weight is computed as: 285 | effective = weight + (LoRA_A @ LoRA_B * scaling) 286 | and the projections are computed as: 287 | Projection = bundled_scaled_matmul([(x, effective, 1.0)]) 288 | Biases are added if provided. 289 | 290 | Backward Derivation: 291 | Let Q = x @ (q_effective)^T + bias, with q_effective = q_weight + (q_proj_lora_A @ q_proj_lora_B * scaling). 292 | By the chain rule: 293 | dL/dx = sum_{proj in {Q, K, V}} [ (x^T)' from that projection ], 294 | where for each projection the gradients with respect to the effective weights are obtained by: 295 | dL/d(effective) = x^T @ grad_projection. 296 | Then, the gradients with respect to the low-rank parameters (LoRA_A and LoRA_B) are computed 297 | using: 298 | grad_LoRA_A = dL/d(effective) @ (LoRA_B)^T * scaling, 299 | grad_LoRA_B = (LoRA_A)^T @ dL/d(effective) * scaling. 300 | The input x gradient is computed as the sum of contributions from Q, K, and V pathways. 301 | """ 302 | 303 | @staticmethod 304 | def forward( 305 | ctx, 306 | x, 307 | q_proj_weight_path, 308 | k_proj_weight_path, 309 | v_proj_weight_path, 310 | q_proj_bias, 311 | k_proj_bias, 312 | v_proj_bias, 313 | q_proj_lora_A, 314 | q_proj_lora_B, 315 | k_proj_lora_A, 316 | k_proj_lora_B, 317 | v_proj_lora_A, 318 | v_proj_lora_B, 319 | scaling, 320 | ): 321 | """ 322 | Forward pass for the LoRA QKV function. 323 | 324 | Args: 325 | x (Tensor): Input tensor. 326 | q_proj_weight_path (str): File path for the query projection weight. 327 | k_proj_weight_path (str): File path for the key projection weight. 328 | v_proj_weight_path (str): File path for the value projection weight. 329 | q_proj_bias (Tensor or None): Bias tensor for the query projection. 330 | k_proj_bias (Tensor or None): Bias tensor for the key projection. 331 | v_proj_bias (Tensor or None): Bias tensor for the value projection. 332 | q_proj_lora_A (Tensor): LoRA parameter A for query projection. 333 | q_proj_lora_B (Tensor): LoRA parameter B for query projection. 334 | k_proj_lora_A (Tensor): LoRA parameter A for key projection. 335 | k_proj_lora_B (Tensor): LoRA parameter B for key projection. 336 | v_proj_lora_A (Tensor): LoRA parameter A for value projection. 337 | v_proj_lora_B (Tensor): LoRA parameter B for value projection. 338 | scaling (float): Scaling factor for the LoRA update. 339 | 340 | Returns: 341 | Tuple[Tensor, Tensor, Tensor]: The projected query, key, and value tensors. 
342 | """ 343 | ctx.save_for_backward( 344 | q_proj_lora_A, 345 | q_proj_lora_B, 346 | k_proj_lora_A, 347 | k_proj_lora_B, 348 | v_proj_lora_A, 349 | v_proj_lora_B, 350 | ) 351 | 352 | ctx.scaling = scaling 353 | ctx.q_bias_flag = q_proj_bias is not None 354 | ctx.k_bias_flag = k_proj_bias is not None 355 | ctx.v_bias_flag = v_proj_bias is not None 356 | 357 | # Save input x shape for later and store x on disk. 358 | ctx.x_shape = x.shape 359 | ctx.x_path = os.path.join( 360 | GRADIENT_DIR, q_proj_weight_path.split("/")[-1] + ".x.bin" 361 | ) 362 | save_tensor_to_storage(ctx.x_path, x) 363 | 364 | q_shape = (q_proj_lora_A.shape[0], x.shape[-1]) 365 | kv_shape = (k_proj_lora_B.shape[-1], x.shape[-1]) 366 | ctx.q_shape = q_shape 367 | ctx.kv_shape = kv_shape 368 | 369 | q_proj_weight = load_tensor_from_storage( 370 | weight_path=q_proj_weight_path, 371 | shape=q_shape, 372 | dtype=q_proj_lora_A.dtype, 373 | to_ram=False, 374 | ).transpose(-2, -1) 375 | k_proj_weight = load_tensor_from_storage( 376 | weight_path=k_proj_weight_path, 377 | shape=kv_shape, 378 | dtype=k_proj_lora_A.dtype, 379 | to_ram=False, 380 | ).transpose(-2, -1) 381 | v_proj_weight = load_tensor_from_storage( 382 | weight_path=v_proj_weight_path, 383 | shape=kv_shape, 384 | dtype=v_proj_lora_A.dtype, 385 | to_ram=False, 386 | ).transpose(-2, -1) 387 | 388 | # Compute effective weights with LoRA update. 389 | q_effective = q_proj_weight + (q_proj_lora_A @ q_proj_lora_B * scaling) 390 | k_effective = k_proj_weight + (k_proj_lora_A @ k_proj_lora_B * scaling) 391 | v_effective = v_proj_weight + (v_proj_lora_A @ v_proj_lora_B * scaling) 392 | 393 | ctx.q_effective = q_effective 394 | ctx.k_effective = k_effective 395 | ctx.v_effective = v_effective 396 | 397 | bundles = [ 398 | (x, q_effective, 1.0), 399 | (x, k_effective, 1.0), 400 | (x, v_effective, 1.0), 401 | ] 402 | 403 | Q, K, V = bundled_scaled_matmul(bundles) 404 | 405 | if q_proj_bias is not None: 406 | Q = Q + q_proj_bias 407 | if k_proj_bias is not None: 408 | K = K + k_proj_bias 409 | if v_proj_bias is not None: 410 | V = V + v_proj_bias 411 | 412 | return Q, K, V 413 | 414 | @staticmethod 415 | def backward(ctx, grad_Q, grad_K, grad_V): 416 | """ 417 | Backward pass for the LoRA QKV function. 418 | 419 | Derivation: 420 | For each projection (Q, K, V), let: 421 | effective = weight + (LoRA_A @ LoRA_B * scaling) 422 | and the forward operation is: 423 | Projection = x @ (effective)^T (+ bias). 424 | The gradients are computed as follows: 425 | 1. Compute dL/d(effective) for each projection by: 426 | dL/d(effective) = x^T @ grad_projection 427 | 2. The gradient with respect to x is obtained by summing over contributions: 428 | grad_x = grad_xQ + grad_xK + grad_xV 429 | 3. Using the chain rule and the linearity of the LoRA update: 430 | grad_LoRA_A = dL/d(effective) @ (LoRA_B)^T * scaling, 431 | grad_LoRA_B = (LoRA_A)^T @ dL/d(effective) * scaling. 432 | Here, the bundled_scaled_matmul function is used to compute both the gradients for x 433 | (from each of Q, K, V) and the gradients for the effective weights, which are then propagated 434 | to the low-rank LoRA parameters. 435 | 436 | Args: 437 | grad_Q (Tensor): Gradient with respect to the query output. 438 | grad_K (Tensor): Gradient with respect to the key output. 439 | grad_V (Tensor): Gradient with respect to the value output. 440 | 441 | Returns: 442 | Tuple: Gradients for each input parameter in the same order as in the forward pass. 
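        Example (a tiny autograd check of the sum rule used for grad_x; all tensors
        and sizes are assumptions):

            import torch

            x = torch.randn(2, 5, 8, requires_grad=True)
            Wq, Wk, Wv = (torch.randn(8, 8) for _ in range(3))
            loss = (x @ Wq).sum() + (x @ Wk).sum() + (x @ Wv).sum()
            loss.backward()
            # x.grad accumulates one contribution per projection, mirroring
            # grad_x = grad_xQ + grad_xK + grad_xV above.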
443 | """ 444 | ( 445 | q_proj_lora_A, 446 | q_proj_lora_B, 447 | k_proj_lora_A, 448 | k_proj_lora_B, 449 | v_proj_lora_A, 450 | v_proj_lora_B, 451 | ) = ctx.saved_tensors 452 | scale = ctx.scaling 453 | x = load_tensor_from_storage( 454 | ctx.x_path, shape=ctx.x_shape, dtype=grad_Q.dtype, to_ram=False 455 | ) 456 | 457 | q_effective = ctx.q_effective 458 | k_effective = ctx.k_effective 459 | v_effective = ctx.v_effective 460 | 461 | bundles = [ 462 | (grad_Q, q_effective.transpose(-2, -1), 1.0), 463 | (grad_K, k_effective.transpose(-2, -1), 1.0), 464 | (grad_V, v_effective.transpose(-2, -1), 1.0), 465 | (x.transpose(-2, -1), grad_Q, 1.0), 466 | (x.transpose(-2, -1), grad_K, 1.0), 467 | (x.transpose(-2, -1), grad_V, 1.0), 468 | ] 469 | ( 470 | grad_xQ, 471 | grad_xK, 472 | grad_xV, 473 | grad_effective_q, 474 | grad_effective_k, 475 | grad_effective_v, 476 | ) = bundled_scaled_matmul(bundles) 477 | 478 | grad_x = grad_xQ + grad_xK + grad_xV 479 | 480 | grad_q_bias = grad_Q.sum(dim=0) if ctx.q_bias_flag else None 481 | grad_k_bias = grad_K.sum(dim=0) if ctx.k_bias_flag else None 482 | grad_v_bias = grad_V.sum(dim=0) if ctx.v_bias_flag else None 483 | 484 | bundles = [ 485 | (grad_effective_q, q_proj_lora_B.transpose(-2, -1), scale), 486 | (q_proj_lora_A.transpose(-2, -1), grad_effective_q, scale), 487 | (grad_effective_k, k_proj_lora_B.transpose(-2, -1), scale), 488 | (k_proj_lora_A.transpose(-2, -1), grad_effective_k, scale), 489 | (grad_effective_v, v_proj_lora_B.transpose(-2, -1), scale), 490 | (v_proj_lora_A.transpose(-2, -1), grad_effective_v, scale), 491 | ] 492 | grad_q_A, grad_q_B, grad_k_A, grad_k_B, grad_v_A, grad_v_B = bundled_scaled_matmul( 493 | bundles 494 | ) 495 | 496 | grad_x = clip_grad(grad_x) 497 | grad_q_bias = clip_grad(grad_q_bias) 498 | grad_k_bias = clip_grad(grad_k_bias) 499 | grad_v_bias = clip_grad(grad_v_bias) 500 | grad_q_A = clip_grad(grad_q_A) 501 | grad_q_B = clip_grad(grad_q_B) 502 | grad_k_A = clip_grad(grad_k_A) 503 | grad_k_B = clip_grad(grad_k_B) 504 | grad_v_A = clip_grad(grad_v_A) 505 | grad_v_B = clip_grad(grad_v_B) 506 | grad_effective_q = clip_grad(grad_effective_q) 507 | grad_effective_k = clip_grad(grad_effective_k) 508 | grad_effective_v = clip_grad(grad_effective_v) 509 | grad_xQ = clip_grad(grad_xQ) 510 | grad_xK = clip_grad(grad_xK) 511 | grad_xV = clip_grad(grad_xV) 512 | grad_Q = clip_grad(grad_Q) 513 | grad_K = clip_grad(grad_K) 514 | grad_V = clip_grad(grad_V) 515 | 516 | grad_q_weight_path = None 517 | grad_k_weight_path = None 518 | grad_v_weight_path = None 519 | grad_scale = None 520 | 521 | return ( 522 | grad_x, 523 | grad_q_weight_path, 524 | grad_k_weight_path, 525 | grad_v_weight_path, 526 | grad_q_bias, 527 | grad_k_bias, 528 | grad_v_bias, 529 | grad_q_A, 530 | grad_q_B, 531 | grad_k_A, 532 | grad_k_B, 533 | grad_v_A, 534 | grad_v_B, 535 | grad_scale, 536 | ) 537 | -------------------------------------------------------------------------------- /sllm/nn/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements a Transformer-based language model with support for decoder layers, 3 | gradient checkpointing, and caching mechanisms. It defines the following classes: 4 | 5 | - DecoderLayer: Implements a single layer of the transformer decoder. 6 | - Transformer: Composes multiple decoder layers into a full transformer model. 7 | - SuperLazyLanguageModel: Encapsulates the transformer and head for sequence generation tasks. 
8 | 9 | The implementation leverages PyTorch for tensor computations and Hugging Face transformers 10 | for cache management. 11 | """ 12 | 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch import nn 18 | from transformers import Cache, DynamicCache, SlidingWindowCache, StaticCache 19 | from transformers.modeling_outputs import (BaseModelOutputWithPast, 20 | CausalLMOutputWithPast) 21 | 22 | from sllm.common import DTYPE 23 | from sllm.config import Config 24 | from sllm.nn.layers import (Attention, Embedding, Linear, 25 | MLP, RMSNorm, RotaryEmbedding) 26 | 27 | 28 | class DecoderLayer(nn.Module): 29 | """ 30 | A single decoder layer of the Transformer. 31 | 32 | This layer consists of: 33 | - RMS normalization applied to the input (input_layernorm). 34 | - A self-attention mechanism (self_attn). 35 | - A residual connection adding the result back to the input. 36 | - A second RMS normalization (post_attention_layernorm) followed by a feed-forward MLP, 37 | with an additional residual connection. 38 | 39 | Args: 40 | config (Config): Configuration object containing model hyperparameters. 41 | layer_idx (int): Index of the layer to load layer-specific weights. 42 | """ 43 | 44 | def __init__(self, config: Config, layer_idx: int): 45 | """ 46 | Initialize the decoder layer. 47 | 48 | Loads layer-specific weights for normalization from the provided weight directory. 49 | 50 | Args: 51 | config (Config): Model configuration. 52 | layer_idx (int): Index for the current layer. 53 | """ 54 | super().__init__() 55 | self.hidden_size = config.hidden_size 56 | self.self_attn = Attention(config=config, layer_idx=layer_idx) 57 | self.mlp = MLP(config, layer_idx) 58 | 59 | input_layer_norm_weight_path = ( 60 | f"{config.weight_dir}/model.layers.{layer_idx}.input_layernorm.weight.bin" 61 | ) 62 | post_attention_layer_norm_weight_path = f"{config.weight_dir}/model.layers.{layer_idx}.post_attention_layernorm.weight.bin" 63 | 64 | self.input_layernorm = RMSNorm( 65 | hidden_size=config.hidden_size, 66 | weight_path=input_layer_norm_weight_path, 67 | eps=config.rms_norm_eps, 68 | ) 69 | self.post_attention_layernorm = RMSNorm( 70 | hidden_size=config.hidden_size, 71 | weight_path=post_attention_layer_norm_weight_path, 72 | eps=config.rms_norm_eps, 73 | ) 74 | 75 | def forward( 76 | self, 77 | hidden_states: torch.Tensor, 78 | attention_mask: Optional[torch.Tensor] = None, 79 | position_ids: Optional[torch.LongTensor] = None, 80 | past_key_value: Optional[Cache] = None, 81 | output_attentions: Optional[bool] = False, 82 | use_cache: Optional[bool] = False, 83 | cache_position: Optional[torch.LongTensor] = None, 84 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 85 | **kwargs, 86 | ) -> Tuple[ 87 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] 88 | ]: 89 | """ 90 | Perform a forward pass through the decoder layer. 91 | 92 | The input passes through an initial layer normalization, self-attention block, and 93 | a feed-forward MLP with residual connections. Optionally, attention weights can be returned. 94 | 95 | Args: 96 | hidden_states (torch.Tensor): Input tensor with shape (batch_size, seq_length, hidden_size). 97 | attention_mask (Optional[torch.Tensor], optional): Attention mask for self-attention. 98 | position_ids (Optional[torch.LongTensor], optional): Tensor containing position indices. 99 | past_key_value (Optional[Cache], optional): Cached past key and value tensors. 
100 | output_attentions (Optional[bool], optional): If True, returns self-attention weights. 101 | use_cache (Optional[bool], optional): If True, enables caching for inference. 102 | cache_position (Optional[torch.LongTensor], optional): Positions for caching tokens. 103 | position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): Pre-computed position embeddings. 104 | **kwargs: Additional keyword arguments. 105 | 106 | Returns: 107 | Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 108 | - The output tensor of shape (batch_size, seq_length, hidden_size). 109 | - Optionally, a tuple of self-attention weights. 110 | """ 111 | residual = hidden_states 112 | 113 | hidden_states = self.input_layernorm(hidden_states) 114 | 115 | hidden_states, self_attn_weights = self.self_attn( 116 | hidden_states=hidden_states, 117 | attention_mask=attention_mask, 118 | position_ids=position_ids, 119 | past_key_value=past_key_value, 120 | output_attentions=output_attentions, 121 | use_cache=use_cache, 122 | cache_position=cache_position, 123 | position_embeddings=position_embeddings, 124 | **kwargs, 125 | ) 126 | hidden_states = residual + hidden_states 127 | 128 | residual = hidden_states 129 | hidden_states = self.post_attention_layernorm(hidden_states) 130 | hidden_states = self.mlp(hidden_states) 131 | hidden_states = residual + hidden_states 132 | 133 | outputs = (hidden_states,) 134 | if output_attentions: 135 | outputs += (self_attn_weights,) 136 | 137 | return outputs 138 | 139 | 140 | class Transformer(nn.Module): 141 | """ 142 | Transformer decoder composed of multiple decoder layers. 143 | 144 | This module includes: 145 | - Token embedding. 146 | - A stack of decoder layers. 147 | - Rotary positional embeddings. 148 | - Final normalization. 149 | - Optional support for gradient checkpointing and caching. 150 | 151 | Attributes: 152 | padding_idx (int): Padding index for token embeddings. 153 | vocab_size (int): Size of the vocabulary. 154 | embed_tokens (Embedding): Token embedding layer. 155 | layers (nn.ModuleList): A list of decoder layers. 156 | norm (RMSNorm): Normalization applied after the last decoder layer. 157 | rotary_emb (RotaryEmbedding): Module for rotary position embeddings. 158 | gradient_checkpointing (bool): Enables gradient checkpointing if True. 159 | config (Config): Model configuration. 160 | """ 161 | 162 | def __init__(self, config: Config): 163 | """ 164 | Initialize the Transformer decoder. 165 | 166 | Loads weights for token embeddings, each decoder layer, and the final normalization. 167 | 168 | Args: 169 | config (Config): Configuration object with model hyperparameters. 
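        Example (the on-disk naming convention the constructor relies on; the
        directory below is hypothetical):

            weight_dir = "/path/to/sllm/weights"                           # hypothetical cache directory
            embed_path = f"{weight_dir}/model.embed_tokens.weight.bin"     # token embedding table
            layer0_norm = f"{weight_dir}/model.layers.0.input_layernorm.weight.bin"
            final_norm = f"{weight_dir}/model.norm.weight.bin"             # RMSNorm after the last layer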
170 | """ 171 | super().__init__() 172 | self.padding_idx = config.pad_token_id 173 | self.vocab_size = config.vocab_size 174 | 175 | self.embed_tokens = Embedding( 176 | vocab_size=config.vocab_size, 177 | hidden_size=config.hidden_size, 178 | padding_idx=self.padding_idx, 179 | weight_path=f"{config.weight_dir}/model.embed_tokens.weight.bin", 180 | ) 181 | self.layers = nn.ModuleList( 182 | [ 183 | DecoderLayer(config, layer_idx) 184 | for layer_idx in range(config.num_hidden_layers) 185 | ] 186 | ) 187 | norm_weight_path = f"{config.weight_dir}/model.norm.weight.bin" 188 | self.norm = RMSNorm( 189 | config.hidden_size, weight_path=norm_weight_path, eps=config.rms_norm_eps 190 | ) 191 | self.rotary_emb = RotaryEmbedding(config=config) 192 | self.gradient_checkpointing = False 193 | self.config = config 194 | 195 | def get_input_embeddings(self): 196 | """ 197 | Retrieve the input embeddings. 198 | 199 | Returns: 200 | Embedding: The embedding layer used for token lookup. 201 | """ 202 | return self.embed_tokens 203 | 204 | def set_input_embeddings(self, value): 205 | """ 206 | Set the input embeddings. 207 | 208 | Args: 209 | value (Embedding): A new embedding layer. 210 | """ 211 | self.embed_tokens = value 212 | 213 | def forward( 214 | self, 215 | input_ids: torch.LongTensor = None, 216 | attention_mask: Optional[torch.Tensor] = None, 217 | position_ids: Optional[torch.LongTensor] = None, 218 | past_key_values: Optional[Cache] = None, 219 | inputs_embeds: Optional[torch.FloatTensor] = None, 220 | use_cache: Optional[bool] = None, 221 | output_attentions: Optional[bool] = None, 222 | output_hidden_states: Optional[bool] = None, 223 | return_dict: Optional[bool] = None, 224 | cache_position: Optional[torch.LongTensor] = None, 225 | ) -> Union[Tuple, BaseModelOutputWithPast]: 226 | """ 227 | Perform a forward pass through the Transformer decoder. 228 | 229 | Delegates embedding lookup to the embedding layer, applies rotary position embeddings, 230 | and passes data sequentially through each decoder layer. Also manages caching and generates 231 | an updated causal mask based on past key values. 232 | 233 | Args: 234 | input_ids (torch.LongTensor, optional): Input token IDs. 235 | attention_mask (Optional[torch.Tensor], optional): Attention mask for the sequence. 236 | position_ids (Optional[torch.LongTensor], optional): Position IDs for the sequence. 237 | past_key_values (Optional[Cache], optional): Cached key/value pairs. 238 | inputs_embeds (Optional[torch.FloatTensor], optional): Pre-computed input embeddings. 239 | use_cache (Optional[bool], optional): If True, enables caching. 240 | output_attentions (Optional[bool], optional): If True, outputs attention weights. 241 | output_hidden_states (Optional[bool], optional): If True, outputs hidden states. 242 | return_dict (Optional[bool], optional): If True, returns a dict-like object instead of a tuple. 243 | cache_position (Optional[torch.LongTensor], optional): Cache position indices. 244 | 245 | Returns: 246 | Union[Tuple, BaseModelOutputWithPast]: 247 | - If return_dict is False, returns a tuple with logits and optionally additional outputs. 248 | - Otherwise, returns a `BaseModelOutputWithPast` with last hidden state, cached key values, 249 | hidden states, and attentions. 
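        Example (how the default cache_position and position_ids are derived when a
        KV cache already holds tokens; the token counts are assumptions):

            import torch

            past_seen_tokens = 10   # tokens already stored in the cache
            new_tokens = 4          # length of the chunk being decoded now
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
            position_ids = cache_position.unsqueeze(0)
            # cache_position -> tensor([10, 11, 12, 13]); position_ids has shape (1, 4).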
250 | """ 251 | output_attentions = ( 252 | output_attentions 253 | if output_attentions is not None 254 | else self.config.output_attentions 255 | ) 256 | output_hidden_states = ( 257 | output_hidden_states 258 | if output_hidden_states is not None 259 | else self.config.output_hidden_states 260 | ) 261 | use_cache = use_cache if use_cache is not None else self.config.use_cache 262 | return_dict = ( 263 | return_dict if return_dict is not None else self.config.use_return_dict 264 | ) 265 | 266 | if inputs_embeds is None: 267 | inputs_embeds = self.embed_tokens(input_ids) 268 | 269 | if use_cache and past_key_values is None: 270 | past_key_values = DynamicCache() 271 | 272 | if cache_position is None: 273 | past_seen_tokens = ( 274 | past_key_values.get_seq_length() if past_key_values is not None else 0 275 | ) 276 | cache_position = torch.arange( 277 | past_seen_tokens, 278 | past_seen_tokens + inputs_embeds.shape[1], 279 | device=inputs_embeds.device, 280 | ) 281 | 282 | if position_ids is None: 283 | position_ids = cache_position.unsqueeze(0) 284 | 285 | causal_mask = self._update_causal_mask( 286 | attention_mask, inputs_embeds, cache_position, past_key_values 287 | ) 288 | 289 | hidden_states = inputs_embeds 290 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 291 | all_hidden_states = () if output_hidden_states else None 292 | all_self_attns = () if output_attentions else None 293 | 294 | for decoder_layer in self.layers[: self.config.num_hidden_layers]: 295 | if output_hidden_states: 296 | all_hidden_states += (hidden_states,) 297 | 298 | if self.gradient_checkpointing and self.training: 299 | layer_outputs = self._gradient_checkpointing_func( 300 | decoder_layer.__call__, 301 | hidden_states, 302 | causal_mask, 303 | position_ids, 304 | past_key_values, 305 | output_attentions, 306 | use_cache, 307 | cache_position, 308 | position_embeddings, 309 | ) 310 | else: 311 | layer_outputs = decoder_layer( 312 | hidden_states, 313 | attention_mask=causal_mask, 314 | position_ids=position_ids, 315 | past_key_value=past_key_values, 316 | output_attentions=output_attentions, 317 | use_cache=use_cache, 318 | cache_position=cache_position, 319 | position_embeddings=position_embeddings, 320 | ) 321 | 322 | hidden_states = layer_outputs[0] 323 | 324 | if output_attentions: 325 | all_self_attns += (layer_outputs[1],) 326 | 327 | hidden_states = self.norm(hidden_states) 328 | 329 | if output_hidden_states: 330 | all_hidden_states += (hidden_states,) 331 | 332 | output = BaseModelOutputWithPast( 333 | last_hidden_state=hidden_states, 334 | past_key_values=past_key_values if use_cache else None, 335 | hidden_states=all_hidden_states, 336 | attentions=all_self_attns, 337 | ) 338 | return output if return_dict else output.to_tuple() 339 | 340 | def _update_causal_mask( 341 | self, 342 | attention_mask: torch.Tensor, 343 | input_tensor: torch.Tensor, 344 | cache_position: torch.Tensor, 345 | past_key_values: dict, 346 | ): 347 | """ 348 | Generate a causal attention mask accounting for cached tokens. 349 | 350 | The mask is created based on the current input tensor, the cache positions, and the type of caching 351 | used (static or sliding window). This allows the model to correctly attend to previous tokens and 352 | manage attention when using caches. 353 | 354 | Args: 355 | attention_mask (torch.Tensor): Original attention mask. 356 | input_tensor (torch.Tensor): Input tensor with shape (batch_size, seq_length, hidden_size). 
357 | cache_position (torch.Tensor): Tensor indicating positions for cached tokens. 358 | past_key_values (dict): Cached past key values, possibly an instance of StaticCache or SlidingWindowCache. 359 | 360 | Returns: 361 | torch.Tensor: A 4D causal attention mask. 362 | """ 363 | past_seen_tokens = ( 364 | past_key_values.get_seq_length() if past_key_values is not None else 0 365 | ) 366 | using_static_cache = isinstance(past_key_values, StaticCache) 367 | using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) 368 | 369 | dtype, device = input_tensor.dtype, input_tensor.device 370 | sequence_length = input_tensor.shape[1] 371 | 372 | if using_sliding_window_cache or using_static_cache: 373 | target_length = past_key_values.get_max_cache_shape() 374 | else: 375 | target_length = ( 376 | attention_mask.shape[-1] 377 | if isinstance(attention_mask, torch.Tensor) 378 | else past_seen_tokens + sequence_length + 1 379 | ) 380 | 381 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( 382 | attention_mask, 383 | sequence_length=sequence_length, 384 | target_length=target_length, 385 | dtype=dtype, 386 | device=device, 387 | cache_position=cache_position, 388 | batch_size=input_tensor.shape[0], 389 | config=self.config, 390 | past_key_values=past_key_values, 391 | ) 392 | 393 | return causal_mask 394 | 395 | @staticmethod 396 | def _prepare_4d_causal_attention_mask_with_cache_position( 397 | attention_mask: torch.Tensor, 398 | sequence_length: int, 399 | target_length: int, 400 | dtype: torch.dtype, 401 | device: torch.device, 402 | cache_position: torch.Tensor, 403 | batch_size: int, 404 | config: Config, 405 | past_key_values: dict, 406 | ): 407 | """ 408 | Create a 4D causal attention mask with cache positions. 409 | 410 | This function prepares a mask of shape (batch_size, 1, query_length, key_value_length) from a 411 | 2D mask or creates a new one if needed. For sliding window configurations, additional masking is applied. 412 | 413 | Args: 414 | attention_mask (torch.Tensor): Input attention mask. 415 | sequence_length (int): Length of the current sequence. 416 | target_length (int): The target length (e.g., including cached tokens). 417 | dtype (torch.dtype): Data type for the mask. 418 | device (torch.device): Device for the mask tensor. 419 | cache_position (torch.Tensor): Tensor indicating positions for caching. 420 | batch_size (int): Batch size. 421 | config (Config): Model configuration (may include sliding window settings). 422 | past_key_values (dict): Cached past key values. 423 | 424 | Returns: 425 | torch.Tensor: A 4D causal attention mask tensor of shape (batch_size, 1, query_length, key_value_length). 
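        Example (the core masking pattern, built the same way as the body below;
        entries equal to the dtype minimum block attention and zeros allow it; the
        sizes are assumptions):

            import torch

            sequence_length, target_length = 3, 5
            min_dtype = torch.finfo(torch.float32).min
            cache_position = torch.arange(2, 2 + sequence_length)   # queries at absolute positions 2..4
            causal_mask = torch.full((sequence_length, target_length), min_dtype)
            future = torch.arange(target_length) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask * future
            # Row 0 can attend keys 0..2, row 1 keys 0..3, row 2 keys 0..4.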
426 | """ 427 | if attention_mask is not None and attention_mask.dim() == 4: 428 | causal_mask = attention_mask 429 | 430 | else: 431 | min_dtype = torch.finfo(dtype).min 432 | causal_mask = torch.full( 433 | (sequence_length, target_length), 434 | fill_value=min_dtype, 435 | dtype=dtype, 436 | device=device, 437 | ) 438 | diagonal_attend_mask = torch.arange( 439 | target_length, device=device 440 | ) > cache_position.reshape(-1, 1) 441 | 442 | if config.sliding_window is not None: 443 | if ( 444 | not isinstance(past_key_values, SlidingWindowCache) 445 | or sequence_length > target_length 446 | ): 447 | sliding_attend_mask = torch.arange( 448 | target_length, device=device 449 | ) <= (cache_position.reshape(-1, 1) - config.sliding_window) 450 | diagonal_attend_mask.bitwise_or_(sliding_attend_mask) 451 | causal_mask *= diagonal_attend_mask 452 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) 453 | 454 | if attention_mask is not None: 455 | causal_mask = causal_mask.clone() 456 | if attention_mask.shape[-1] > target_length: 457 | attention_mask = attention_mask[:, :target_length] 458 | mask_length = attention_mask.shape[-1] 459 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ 460 | :, None, None, : 461 | ].to(causal_mask.device) 462 | padding_mask = padding_mask == 0 463 | causal_mask[:, :, :, :mask_length] = causal_mask[ 464 | :, :, :, :mask_length 465 | ].masked_fill(padding_mask, min_dtype) 466 | 467 | return causal_mask 468 | 469 | 470 | class SuperLazyLanguageModel(nn.Module): 471 | """ 472 | A super lazy language model that encapsulates a Transformer decoder with optional LoRA parameters. 473 | 474 | This model wraps the Transformer decoder, defines the head for vocabulary logits, 475 | and computes the loss when labels are provided. It supports caching for efficient generation. 476 | 477 | Attributes: 478 | config (Config): Model configuration containing hyperparameters. 479 | model (Transformer): Underlying Transformer-based decoder. 480 | loss_function (nn.CrossEntropyLoss): Loss function for training. 481 | vocab_size (int): Size of the model vocabulary. 482 | lm_head (Linear): Linear projection layer from hidden states to vocabulary logits. 483 | """ 484 | 485 | def __init__(self, name, lora_alpha=16, lora_r=4, lora_dropout=0.1): 486 | """ 487 | Initialize the SuperLazyLanguageModel. 488 | 489 | Loads configuration and initializes the transformer decoder along with the language modeling head. 490 | Optionally, applies LoRA modifications if enabled in the configuration. 491 | 492 | Args: 493 | name (str): Identifier or name of the model. 494 | lora_alpha (int, optional): LoRA alpha hyperparameter. Defaults to 16. 495 | lora_r (int, optional): LoRA rank. Defaults to 4. 496 | lora_dropout (float, optional): LoRA dropout rate. Defaults to 0.1. 
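        Example (the weight-tying decision for the LM head, mirrored in plain
        Python; the flag value and directory are assumptions):

            tie_word_embeddings = True
            weight_dir = "/path/to/sllm/weights"   # hypothetical cache directory
            lm_head_path = (
                f"{weight_dir}/model.embed_tokens.weight.bin"
                if tie_word_embeddings
                else f"{weight_dir}/lm_head.weight.bin"
            )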
497 | """ 498 | super().__init__() 499 | 500 | self.config = Config( 501 | model_name=name, 502 | lora_alpha=lora_alpha, 503 | lora_r=lora_r, 504 | lora_dropout=lora_dropout, 505 | ) 506 | 507 | self.model = Transformer(self.config) 508 | self.loss_function = nn.CrossEntropyLoss() 509 | self.vocab_size = self.config.vocab_size 510 | 511 | if self.config.tie_word_embeddings: 512 | lm_head_weight_path = ( 513 | f"{self.config.weight_dir}/model.embed_tokens.weight.bin" 514 | ) 515 | else: 516 | lm_head_weight_path = f"{self.config.weight_dir}/lm_head.weight.bin" 517 | 518 | self.lm_head = Linear( 519 | self.config.hidden_size, 520 | self.config.vocab_size, 521 | weight_path=lm_head_weight_path, 522 | bias_path=None, 523 | ) 524 | 525 | def get_input_embeddings(self): 526 | """ 527 | Retrieve the input embeddings from the model. 528 | 529 | Returns: 530 | Embedding: Input embedding layer. 531 | """ 532 | return self.model.embed_tokens 533 | 534 | def set_input_embeddings(self, value): 535 | """ 536 | Set the input embeddings for the model. 537 | 538 | Args: 539 | value (Embedding): New input embedding layer. 540 | """ 541 | self.model.embed_tokens = value 542 | 543 | def get_output_embeddings(self): 544 | """ 545 | Retrieve the output embeddings (language modeling head). 546 | 547 | Returns: 548 | Linear: The linear layer projecting to vocabulary logits. 549 | """ 550 | return self.lm_head 551 | 552 | def set_output_embeddings(self, new_embeddings): 553 | """ 554 | Set the output embeddings (language modeling head) for the model. 555 | 556 | Args: 557 | new_embeddings (Linear): New output embedding layer. 558 | """ 559 | self.lm_head = new_embeddings 560 | 561 | def set_decoder(self, decoder): 562 | """ 563 | Replace the current Transformer decoder with a new one. 564 | 565 | Args: 566 | decoder (Transformer): A new transformer decoder module. 567 | """ 568 | self.model = decoder 569 | 570 | def get_decoder(self): 571 | """ 572 | Retrieve the current Transformer decoder. 573 | 574 | Returns: 575 | Transformer: The underlying transformer decoder. 576 | """ 577 | return self.model 578 | 579 | def forward( 580 | self, 581 | input_ids: torch.LongTensor = None, 582 | attention_mask: Optional[torch.Tensor] = None, 583 | position_ids: Optional[torch.LongTensor] = None, 584 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 585 | inputs_embeds: Optional[torch.FloatTensor] = None, 586 | labels: Optional[torch.LongTensor] = None, 587 | use_cache: Optional[bool] = None, 588 | output_attentions: Optional[bool] = None, 589 | output_hidden_states: Optional[bool] = None, 590 | return_dict: Optional[bool] = None, 591 | cache_position: Optional[torch.LongTensor] = None, 592 | logits_to_keep: Union[int, torch.Tensor] = 0, 593 | **kwargs, 594 | ) -> Union[Tuple, CausalLMOutputWithPast]: 595 | """ 596 | Perform a forward pass through the SuperLazyLanguageModel. 597 | 598 | The method computes embeddings (or uses precomputed ones), obtains outputs from the Transformer decoder, 599 | projects hidden states to vocabulary logits using the lm_head, and computes the cross-entropy loss if labels are provided. 600 | It supports caching for efficient autoregressive generation and can return outputs as a tuple or as a dict-like object. 601 | 602 | Args: 603 | input_ids (torch.LongTensor, optional): Token IDs as input. 604 | attention_mask (Optional[torch.Tensor], optional): Mask to avoid attending to padding tokens. 
605 | position_ids (Optional[torch.LongTensor], optional): Position IDs corresponding to the tokens. 606 | past_key_values (Optional[Union[Cache, List[torch.FloatTensor]]], optional): Cached key/value tensors. 607 | inputs_embeds (Optional[torch.FloatTensor], optional): Pre-computed input embeddings. 608 | labels (Optional[torch.LongTensor], optional): Ground truth labels for computing the loss. 609 | use_cache (Optional[bool], optional): Whether to use past key values for caching. 610 | output_attentions (Optional[bool], optional): Whether to return attention weights. 611 | output_hidden_states (Optional[bool], optional): Whether to return hidden states. 612 | return_dict (Optional[bool], optional): Whether to return a dict-like output. 613 | cache_position (Optional[torch.LongTensor], optional): Cache positions for tokens. 614 | logits_to_keep (Union[int, torch.Tensor], optional): Slicing index or tensor for logits computation. 615 | **kwargs: Additional keyword arguments. 616 | 617 | Returns: 618 | Union[Tuple, CausalLMOutputWithPast]: 619 | - If return_dict is False, returns a tuple containing logits and optionally other outputs. 620 | - Otherwise, returns a `CausalLMOutputWithPast` with loss (if computed), logits, past key values, 621 | hidden states, and attention weights. 622 | """ 623 | output_attentions = ( 624 | output_attentions 625 | if output_attentions is not None 626 | else self.config.output_attentions 627 | ) 628 | output_hidden_states = ( 629 | output_hidden_states 630 | if output_hidden_states is not None 631 | else self.config.output_hidden_states 632 | ) 633 | return_dict = ( 634 | return_dict if return_dict is not None else self.config.use_return_dict 635 | ) 636 | 637 | outputs = self.model( 638 | input_ids=input_ids, 639 | attention_mask=attention_mask, 640 | position_ids=position_ids, 641 | past_key_values=past_key_values, 642 | inputs_embeds=inputs_embeds, 643 | use_cache=use_cache, 644 | output_attentions=output_attentions, 645 | output_hidden_states=output_hidden_states, 646 | return_dict=return_dict, 647 | cache_position=cache_position, 648 | **kwargs, 649 | ) 650 | 651 | hidden_states = outputs[0] 652 | slice_indices = ( 653 | slice(-logits_to_keep, None) 654 | if isinstance(logits_to_keep, int) 655 | else logits_to_keep 656 | ) 657 | logits = self.lm_head(hidden_states[:, slice_indices, :]) 658 | 659 | loss = None 660 | if labels is not None: 661 | shifted_logits = logits[:, :-1, :] 662 | shifted_labels = labels[:, 1:] 663 | loss = self.loss_function( 664 | shifted_logits.reshape(-1, shifted_logits.size(-1)), 665 | shifted_labels.reshape(-1), 666 | ) 667 | 668 | if not return_dict: 669 | output = (logits,) + outputs[1:] 670 | return (loss,) + output if loss is not None else output 671 | 672 | return CausalLMOutputWithPast( 673 | loss=loss, 674 | logits=logits, 675 | past_key_values=outputs.past_key_values, 676 | hidden_states=outputs.hidden_states, 677 | attentions=outputs.attentions, 678 | ) 679 | -------------------------------------------------------------------------------- /sllm/nn/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements various neural network layers and functions that are used for 3 | lazy weight loading and efficient computation in a transformer-based model. The layers include: 4 | 5 | - Embedding: Token embedding that lazily loads weights from disk. 6 | - RotaryEmbedding: Implements rotary position embeddings with dynamic frequency updates. 
7 | - RMSNorm: Root-mean-square layer normalization with lazy weight loading. 8 | - Linear: A linear layer that lazily loads large weight matrices. 9 | - MLP: Feed-forward neural network layer using bundled matrix multiplications. 10 | - LoraLinear: Linear layer with LoRA adaptation for efficient fine-tuning. 11 | - LoraQKVLinear: Specialized LoRA linear layer for query/key/value projections. 12 | - Attention: Multi-head attention layer with support for rotary embeddings and caching. 13 | 14 | Each layer leverages lazy weight loading via the `load_tensor_from_storage` utility to 15 | optimize memory usage. 16 | """ 17 | 18 | import gc 19 | from typing import Any, Optional, Tuple 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | from torch import nn 24 | 25 | from sllm.common import DTYPE 26 | from sllm.config import Config 27 | from sllm.nn.autodiff import (BundledMatmulFunction, 28 | LoraFunction, 29 | LoraQKVLinearFunction, 30 | MatmulFunction) 31 | from sllm.utils import load_tensor_from_storage 32 | 33 | 34 | class Embedding(torch.nn.Module): 35 | """ 36 | Token embedding layer with lazy weight loading. 37 | 38 | This module loads the weight matrix from storage on each forward pass, 39 | avoiding the need to hold the full embedding matrix in RAM. 40 | """ 41 | 42 | def __init__(self, weight_path, vocab_size, hidden_size, padding_idx=None): 43 | """ 44 | Initialize the Embedding layer. 45 | 46 | Args: 47 | weight_path (str): Path to the weight file. 48 | vocab_size (int): Number of tokens in the vocabulary. 49 | hidden_size (int): Dimensionality of the embedding vectors. 50 | padding_idx (optional): Padding index for tokens; defaults to None. 51 | """ 52 | super().__init__() 53 | self.weight_path = weight_path 54 | self.vocab_size = vocab_size 55 | self.hidden_size = hidden_size 56 | self.padding_idx = padding_idx 57 | 58 | def forward(self, input_ids): 59 | """ 60 | Perform a forward pass to retrieve embeddings for the provided token IDs. 61 | 62 | Args: 63 | input_ids (Tensor): Tensor containing token IDs with shape (...). 64 | 65 | Returns: 66 | Tensor: Embedding vectors corresponding to the input IDs. 67 | """ 68 | with torch.no_grad(): 69 | weight = load_tensor_from_storage( 70 | weight_path=self.weight_path, 71 | shape=(self.vocab_size, self.hidden_size), 72 | to_ram=False, 73 | ) 74 | return weight[input_ids] 75 | 76 | 77 | class RotaryEmbedding(nn.Module): 78 | """ 79 | Rotary position embedding module with dynamic frequency updates. 80 | 81 | Computes rotary embeddings and applies attention scaling. It updates the frequency 82 | parameters if the input sequence length exceeds cached values. 83 | """ 84 | 85 | def __init__(self, config: Config): 86 | """ 87 | Initialize the RotaryEmbedding module. 88 | 89 | Args: 90 | config (Config): Model configuration containing parameters such as 91 | max_position_embeddings and rope_theta. 92 | """ 93 | super().__init__() 94 | self.max_seq_len_cached = config.max_position_embeddings 95 | self.original_max_seq_len = config.max_position_embeddings 96 | self.config = config 97 | inv_freq, self.attention_scaling = self.compute_rope_parameters(self.config) 98 | self.register_buffer("inv_freq", inv_freq, persistent=False) 99 | self.original_inv_freq = self.inv_freq 100 | 101 | def _dynamic_frequency_update(self, position_ids, device): 102 | """ 103 | Dynamically update the frequency parameters if needed. 104 | 105 | The update occurs if: 106 | 1. The sequence length exceeds the cached maximum. 107 | 2. 
The current sequence length is smaller than the original max sequence length 108 | while the cached maximum is larger. 109 | 110 | Args: 111 | position_ids (Tensor): Tensor of position indices. 112 | device (torch.device): The device on which to perform the computation. 113 | """ 114 | seq_len = torch.max(position_ids) + 1 115 | if seq_len > self.max_seq_len_cached: 116 | inv_freq, self.attention_scaling = self.rope_init_fn( 117 | self.config, device, seq_len=seq_len 118 | ) 119 | self.register_buffer("inv_freq", inv_freq, persistent=False) 120 | self.max_seq_len_cached = seq_len 121 | 122 | if ( 123 | seq_len < self.original_max_seq_len 124 | and self.max_seq_len_cached > self.original_max_seq_len 125 | ): 126 | self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 127 | self.max_seq_len_cached = self.original_max_seq_len 128 | 129 | @torch.no_grad() 130 | def forward(self, x, position_ids): 131 | """ 132 | Compute rotary embeddings for the input tensor. 133 | 134 | Args: 135 | x (Tensor): Input tensor of shape (batch_size, seq_length, ...). 136 | position_ids (Tensor): Position indices for each token. 137 | 138 | Returns: 139 | Tuple[Tensor, Tensor]: A tuple (cos, sin) where each tensor contains 140 | the cosine and sine components of the rotary embeddings. 141 | """ 142 | inv_freq_expanded = ( 143 | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 144 | ) 145 | position_ids_expanded = position_ids[:, None, :].float() 146 | device_type = x.device.type 147 | device_type = ( 148 | device_type 149 | if isinstance(device_type, str) and device_type != "mps" 150 | else "cpu" 151 | ) 152 | 153 | with torch.autocast(device_type=device_type, enabled=False): 154 | freqs = ( 155 | inv_freq_expanded.float() @ position_ids_expanded.float() 156 | ).transpose(1, 2) 157 | emb = torch.cat((freqs, freqs), dim=-1) 158 | cos = emb.cos() 159 | sin = emb.sin() 160 | 161 | # Apply attention scaling for advanced RoPE types (e.g., yarn) 162 | cos = cos * self.attention_scaling 163 | sin = sin * self.attention_scaling 164 | 165 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 166 | 167 | def compute_rope_parameters( 168 | self, 169 | config: Optional[Config] = None, 170 | ) -> Tuple["torch.Tensor", float]: 171 | """ 172 | Compute the rotary embedding parameters. 173 | 174 | This method calculates the inverse frequency tensor and returns the corresponding attention scaling 175 | factor. 176 | 177 | Args: 178 | config (Optional[Config], optional): The model configuration. If not provided, 179 | defaults to the instance's configuration. 180 | 181 | Returns: 182 | Tuple[Tensor, float]: A tuple (inv_freq, attention_factor) where: 183 | - inv_freq is the inverse frequency tensor. 184 | - attention_factor is the scaling factor for attention. 185 | """ 186 | partial_rotary_factor = ( 187 | config.partial_rotary_factor 188 | if hasattr(config, "partial_rotary_factor") 189 | else 1.0 190 | ) 191 | head_dim = getattr( 192 | config, "head_dim", config.hidden_size // config.num_attention_heads 193 | ) 194 | dim = int(head_dim * partial_rotary_factor) 195 | attention_factor = 1.0 196 | inv_freq = 1.0 / ( 197 | config.rope_theta 198 | ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) 199 | ) 200 | return inv_freq, attention_factor 201 | 202 | 203 | class RMSNorm(torch.nn.Module): 204 | """ 205 | Root-mean-square layer normalization with lazy-loaded weights. 206 | 207 | Applies RMS normalization to the input tensor using pre-loaded weights. 
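    Example (the normalization arithmetic written out with stand-in tensors; the
    shapes, weight, and eps are assumptions):

        import torch

        x = torch.randn(2, 4, 8)                          # (batch, seq, hidden)
        weight = torch.ones(8)                            # stand-in for the lazily loaded weight
        eps = 1e-6
        variance = x.pow(2).mean(-1, keepdim=True)
        y = weight * (x * torch.rsqrt(variance + eps))    # same shape as x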
208 | """ 209 | 210 | def __init__(self, hidden_size, weight_path, eps=1e-6): 211 | """ 212 | Initialize the RMSNorm layer. 213 | 214 | Args: 215 | hidden_size (int or tuple): The shape of the weight tensor. 216 | weight_path (str): Path to the weight file. 217 | eps (float, optional): A small value to avoid division by zero. Defaults to 1e-6. 218 | """ 219 | super().__init__() 220 | weight = load_tensor_from_storage( 221 | weight_path=weight_path, shape=hidden_size, to_ram=True 222 | ) 223 | self.register_buffer("weight", weight) 224 | self.variance_epsilon = eps 225 | 226 | def forward(self, hidden_states): 227 | """ 228 | Apply RMS normalization to the input tensor. 229 | 230 | Args: 231 | hidden_states (Tensor): Input tensor with shape (..., hidden_size). 232 | 233 | Returns: 234 | Tensor: The normalized tensor scaled by the registered weight. 235 | """ 236 | input_dtype = hidden_states.dtype 237 | hidden_states = hidden_states.to(torch.float32) 238 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 239 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 240 | return self.weight * hidden_states.to(input_dtype) 241 | 242 | def extra_repr(self): 243 | """ 244 | Return additional string representation of the RMSNorm module. 245 | 246 | Returns: 247 | str: A string representation showing the weight shape and epsilon value. 248 | """ 249 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" 250 | 251 | 252 | class Linear(nn.Module): 253 | """ 254 | Linear layer that lazily loads a large weight matrix from disk. 255 | 256 | The weight matrix is loaded on every forward pass with gradients disabled. 257 | The bias, if provided, is loaded entirely into RAM. 258 | """ 259 | 260 | def __init__(self, in_features, out_features, weight_path, bias_path=None): 261 | """ 262 | Initialize the Linear layer. 263 | 264 | Args: 265 | in_features (int): Size of each input sample. 266 | out_features (int): Size of each output sample. 267 | weight_path (str): Path to the weight file. 268 | bias_path (str, optional): Path to the bias file. Defaults to None. 269 | """ 270 | super().__init__() 271 | self.in_features = in_features 272 | self.out_features = out_features 273 | self.weight_path = weight_path 274 | self.bias_path = bias_path 275 | 276 | if self.bias_path is not None: 277 | with torch.no_grad(): 278 | bias = load_tensor_from_storage( 279 | weight_path=self.bias_path, shape=(self.out_features,), to_ram=True 280 | ) 281 | self.register_buffer("bias", bias) 282 | else: 283 | self.bias = None 284 | 285 | def forward(self, x): 286 | """ 287 | Perform the forward pass of the linear layer. 288 | 289 | Loads the weight matrix from disk and computes the matrix multiplication. 290 | If bias is present, adds it to the result. 291 | 292 | Args: 293 | x (Tensor): Input tensor of shape (..., in_features). 294 | 295 | Returns: 296 | Tensor: Output tensor of shape (..., out_features). 297 | """ 298 | weight = load_tensor_from_storage( 299 | weight_path=self.weight_path, 300 | shape=(self.out_features, self.in_features), 301 | to_ram=False, 302 | ).t() 303 | 304 | Wx = MatmulFunction.apply(x, weight, 1.0) 305 | if self.bias is not None: 306 | Wx += self.bias 307 | return Wx 308 | 309 | 310 | class MLP(nn.Module): 311 | """ 312 | Feed-forward network (MLP) module using bundled matrix multiplications. 
313 | 314 | This layer implements a gated MLP where the input is projected with two parallel 315 | linear layers, one of which is passed through a non-linearity and then elementwise 316 | multiplied, followed by a final projection. 317 | """ 318 | 319 | def __init__(self, config, layer_idx): 320 | """ 321 | Initialize the MLP module. 322 | 323 | Args: 324 | config (Config): Model configuration including hidden and intermediate sizes. 325 | layer_idx (int): Index of the current layer for loading weight files. 326 | """ 327 | super().__init__() 328 | self.config = config 329 | self.hidden_size = config.hidden_size 330 | self.intermediate_size = config.intermediate_size 331 | self.act_fn = nn.SiLU() 332 | self.gate_proj_path = ( 333 | f"{config.weight_dir}/model.layers.{layer_idx}.mlp.gate_proj.weight.bin" 334 | ) 335 | self.up_proj_path = ( 336 | f"{config.weight_dir}/model.layers.{layer_idx}.mlp.up_proj.weight.bin" 337 | ) 338 | self.down_proj_path = ( 339 | f"{config.weight_dir}/model.layers.{layer_idx}.mlp.down_proj.weight.bin" 340 | ) 341 | 342 | def forward(self, x): 343 | """ 344 | Compute the forward pass of the MLP. 345 | 346 | Projects the input using two parallel linear transformations (gate and up projections), 347 | applies a non-linearity and elementwise multiplication, and finally projects back down. 348 | 349 | Args: 350 | x (Tensor): Input tensor with shape (..., hidden_size). 351 | 352 | Returns: 353 | Tensor: Output tensor with shape (..., hidden_size). 354 | """ 355 | gate_proj = load_tensor_from_storage( 356 | weight_path=self.gate_proj_path, 357 | shape=(self.intermediate_size, self.hidden_size), 358 | to_ram=False, 359 | ) 360 | up_proj = load_tensor_from_storage( 361 | weight_path=self.up_proj_path, 362 | shape=(self.intermediate_size, self.hidden_size), 363 | to_ram=False, 364 | ) 365 | down_proj = load_tensor_from_storage( 366 | weight_path=self.down_proj_path, 367 | shape=(self.hidden_size, self.intermediate_size), 368 | to_ram=False, 369 | ) 370 | 371 | bundles = [ 372 | (x, gate_proj.t(), 1.0), 373 | (x, up_proj.t(), 1.0), 374 | ] 375 | gate_proj_out, up_proj_out = BundledMatmulFunction.apply(bundles) 376 | activated_gate_proj = self.act_fn(gate_proj_out) * up_proj_out 377 | return MatmulFunction.apply(activated_gate_proj, down_proj.t(), 1.0) 378 | 379 | 380 | class LoraLinear(nn.Module): 381 | """ 382 | LoRA adapted linear layer for efficient fine-tuning. 383 | 384 | This module applies LoRA by incorporating trainable low-rank matrices (lora_A, lora_B) 385 | to adjust the pre-trained weight matrix. 386 | """ 387 | 388 | def __init__( 389 | self, 390 | in_features, 391 | out_features, 392 | r, 393 | alpha, 394 | weight_path, 395 | bias_path=None, 396 | lora_dropout=0.0, 397 | ): 398 | """ 399 | Initialize the LoraLinear layer. 400 | 401 | Args: 402 | in_features (int): Number of input features. 403 | out_features (int): Number of output features. 404 | r (int): LoRA rank. 405 | alpha (int): LoRA scaling factor. 406 | weight_path (str): Path to the pre-trained weight file. 407 | bias_path (str, optional): Path to the bias file. Defaults to None. 408 | lora_dropout (float, optional): Dropout probability for LoRA. Defaults to 0.0. 
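        Example (parameter shapes and the alpha/r scaling only; how the update is
        applied is defined by LoraFunction, so this sketch stops at showing that
        lora_B @ lora_A matches the dense weight's shape; all sizes are assumptions):

            import torch
            from torch import nn

            in_features, out_features, r, alpha = 32, 64, 4, 16
            scaling = alpha / r                                  # 4.0
            lora_A = nn.Parameter(torch.randn(r, in_features))   # (4, 32)
            lora_B = nn.Parameter(torch.randn(out_features, r))  # (64, 4)
            delta_shape = (lora_B @ lora_A).shape                # (64, 32), same as the dense weight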
409 | """ 410 | super().__init__() 411 | self.in_features = in_features 412 | self.out_features = out_features 413 | self.r = r 414 | self.alpha = alpha 415 | self.scaling = self.alpha / self.r 416 | self.weight_path = weight_path 417 | self.lora_A = nn.Parameter(torch.randn(r, in_features)) 418 | self.lora_B = nn.Parameter(torch.randn(out_features, r)) 419 | self.lora_dropout = nn.Dropout(lora_dropout) 420 | self.bias = None 421 | if bias_path is not None: 422 | with torch.no_grad(): 423 | bias = load_tensor_from_storage( 424 | weight_path=bias_path, shape=(self.out_features,), to_ram=True 425 | ) 426 | self.register_buffer("bias", bias) 427 | 428 | def forward(self, x): 429 | """ 430 | Forward pass for the LoraLinear layer. 431 | 432 | Applies dropout to the input, then uses the LoRA function to compute the modified linear transformation. 433 | 434 | Args: 435 | x (Tensor): Input tensor of shape (..., in_features). 436 | 437 | Returns: 438 | Tensor: Output tensor of shape (..., out_features). 439 | """ 440 | x_dropped = self.lora_dropout(x) 441 | return LoraFunction.apply( 442 | x_dropped, 443 | self.lora_A, 444 | self.lora_B, 445 | self.weight_path, 446 | self.scaling, 447 | self.bias, 448 | ) 449 | 450 | 451 | class LoraQKVLinear(nn.Module): 452 | """ 453 | LoRA adapted linear layer for query, key, and value projections in attention. 454 | 455 | This layer applies LoRA to the query, key, and value projections and lazily loads 456 | the corresponding weight matrices. 457 | """ 458 | 459 | def __init__( 460 | self, 461 | config, 462 | head_dim, 463 | q_weight_path, 464 | k_weight_path, 465 | v_weight_path, 466 | q_bias_path=None, 467 | k_bias_path=None, 468 | v_bias_path=None, 469 | ): 470 | """ 471 | Initialize the LoraQKVLinear layer. 472 | 473 | Args: 474 | config (Config): Model configuration with LoRA parameters. 475 | head_dim (int): Dimension of each attention head. 476 | q_weight_path (str): Path to the query projection weight file. 477 | k_weight_path (str): Path to the key projection weight file. 478 | v_weight_path (str): Path to the value projection weight file. 479 | q_bias_path (str, optional): Path to the query projection bias file. 480 | k_bias_path (str, optional): Path to the key projection bias file. 481 | v_bias_path (str, optional): Path to the value projection bias file. 
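        Example (LoRA parameter shapes under grouped key/value heads; the
        Qwen2-0.5B-like sizes are assumptions):

            import torch

            hidden_size, head_dim = 896, 64
            num_attention_heads, num_key_value_heads, lora_r = 14, 2, 4
            q_A = torch.zeros(hidden_size, lora_r)                        # (896, 4)
            q_B = torch.zeros(lora_r, num_attention_heads * head_dim)     # (4, 896)
            k_B = torch.zeros(lora_r, num_key_value_heads * head_dim)     # (4, 128), grouped KV heads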
482 | """ 483 | super().__init__() 484 | self.scaling = config.lora_alpha / config.lora_r 485 | self.lora_dropout = nn.Dropout(config.lora_dropout) 486 | 487 | self.q_proj_lora_A = nn.Parameter( 488 | torch.randn(config.hidden_size, config.lora_r, dtype=DTYPE) 489 | ) 490 | nn.init.kaiming_uniform_(self.q_proj_lora_A, nonlinearity="linear") 491 | self.q_proj_lora_B = nn.Parameter( 492 | torch.zeros( 493 | config.lora_r, config.num_attention_heads * head_dim, dtype=DTYPE 494 | ) 495 | ) 496 | self.k_proj_lora_A = nn.Parameter( 497 | torch.randn(config.hidden_size, config.lora_r, dtype=DTYPE) 498 | ) 499 | nn.init.kaiming_uniform_(self.k_proj_lora_A, nonlinearity="linear") 500 | self.k_proj_lora_B = nn.Parameter( 501 | torch.zeros( 502 | config.lora_r, config.num_key_value_heads * head_dim, dtype=DTYPE 503 | ) 504 | ) 505 | self.v_proj_lora_A = nn.Parameter( 506 | torch.randn(config.hidden_size, config.lora_r, dtype=DTYPE) 507 | ) 508 | nn.init.kaiming_uniform_(self.v_proj_lora_A, nonlinearity="linear") 509 | self.v_proj_lora_B = nn.Parameter( 510 | torch.zeros( 511 | config.lora_r, config.num_key_value_heads * head_dim, dtype=DTYPE 512 | ) 513 | ) 514 | 515 | self.q_weight_path = q_weight_path 516 | self.k_weight_path = k_weight_path 517 | self.v_weight_path = v_weight_path 518 | 519 | self.q_proj_bias = None 520 | self.k_proj_bias = None 521 | self.v_proj_bias = None 522 | 523 | q_dim = config.num_attention_heads * head_dim 524 | kv_dim = config.num_key_value_heads * head_dim 525 | 526 | if q_bias_path is not None: 527 | self.q_proj_bias = load_tensor_from_storage( 528 | weight_path=q_bias_path, shape=(q_dim,), to_ram=True 529 | ) 530 | 531 | if k_bias_path is not None: 532 | self.k_proj_bias = load_tensor_from_storage( 533 | weight_path=k_bias_path, shape=(kv_dim,), to_ram=True 534 | ) 535 | 536 | if v_bias_path is not None: 537 | self.v_proj_bias = load_tensor_from_storage( 538 | weight_path=v_bias_path, shape=(kv_dim,), to_ram=True 539 | ) 540 | 541 | def forward(self, x): 542 | """ 543 | Forward pass for the LoRA QKV linear layer. 544 | 545 | Applies dropout to the input and computes the LoRA adapted query, key, and value projections. 546 | 547 | Args: 548 | x (Tensor): Input tensor of shape (..., hidden_size). 549 | 550 | Returns: 551 | Tuple[Tensor, Tensor, Tensor]: A tuple containing the projected query, key, and value tensors. 552 | """ 553 | x_dropped = self.lora_dropout(x) 554 | return LoraQKVLinearFunction.apply( 555 | x_dropped, 556 | self.q_weight_path, 557 | self.k_weight_path, 558 | self.v_weight_path, 559 | self.q_proj_bias, 560 | self.k_proj_bias, 561 | self.v_proj_bias, 562 | self.q_proj_lora_A, 563 | self.q_proj_lora_B, 564 | self.k_proj_lora_A, 565 | self.k_proj_lora_B, 566 | self.v_proj_lora_A, 567 | self.v_proj_lora_B, 568 | self.scaling, 569 | ) 570 | 571 | 572 | class Attention(nn.Module): 573 | """ 574 | Multi-head attention layer with support for rotary embeddings and caching. 575 | 576 | This layer computes queries, keys, and values using a LoRA adapted linear layer, 577 | applies rotary position embeddings, and then performs scaled dot-product attention. 578 | It also supports updating caches via past key/value mechanisms. 579 | """ 580 | 581 | def __init__(self, config: Config, layer_idx: int): 582 | """ 583 | Initialize the Attention layer. 584 | 585 | Args: 586 | config (Config): Model configuration. 587 | layer_idx (int): Index of the current layer for loading layer-specific weights. 
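        Example (the derived attention constants; the model sizes are assumptions):

            hidden_size, num_attention_heads, num_key_value_heads = 896, 14, 2
            head_dim = hidden_size // num_attention_heads                       # 64
            num_key_value_groups = num_attention_heads // num_key_value_heads   # 7 query heads per KV head
            scaling = head_dim ** -0.5                                          # 0.125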
588 | """ 589 | super().__init__() 590 | self.config = config 591 | self.layer_idx = layer_idx 592 | self.head_dim = getattr( 593 | config, "head_dim", config.hidden_size // config.num_attention_heads 594 | ) 595 | self.num_key_value_groups = ( 596 | config.num_attention_heads // config.num_key_value_heads 597 | ) 598 | self.scaling = self.head_dim**-0.5 599 | self.attention_dropout = config.attention_dropout 600 | self.is_causal = True 601 | 602 | self.q_proj_weight_path = ( 603 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.q_proj.weight.bin" 604 | ) 605 | self.k_proj_weight_path = ( 606 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.k_proj.weight.bin" 607 | ) 608 | self.v_proj_weight_path = ( 609 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.v_proj.weight.bin" 610 | ) 611 | self.q_proj_bias_path = ( 612 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.q_proj.bias.bin" 613 | ) 614 | self.k_proj_bias_path = ( 615 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.k_proj.bias.bin" 616 | ) 617 | self.v_proj_bias_path = ( 618 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.v_proj.bias.bin" 619 | ) 620 | self.o_proj_weight_path = ( 621 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.o_proj.weight.bin" 622 | ) 623 | 624 | self.lora = LoraQKVLinear( 625 | config, 626 | self.head_dim, 627 | self.q_proj_weight_path, 628 | self.k_proj_weight_path, 629 | self.v_proj_weight_path, 630 | self.q_proj_bias_path, 631 | self.k_proj_bias_path, 632 | self.v_proj_bias_path, 633 | ) 634 | 635 | def forward( 636 | self, 637 | hidden_states: torch.Tensor, 638 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 639 | attention_mask: Optional[torch.Tensor], 640 | past_key_value: Optional[Any] = None, 641 | cache_position: Optional[torch.LongTensor] = None, 642 | **kwargs, 643 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 644 | """ 645 | Forward pass for the Attention layer. 646 | 647 | Projects inputs into queries, keys, and values, applies rotary embeddings, and computes 648 | the attention output along with optional attention weights. 649 | 650 | Args: 651 | hidden_states (Tensor): Input tensor with shape (batch_size, seq_length, hidden_size). 652 | position_embeddings (Tuple[Tensor, Tensor]): Tuple containing cosine and sine embeddings. 653 | attention_mask (Optional[Tensor]): Attention mask to apply. 654 | past_key_value (Optional[Any]): Cached past key/value states. 655 | cache_position (Optional[LongTensor]): Cache positions for tokens. 656 | **kwargs: Additional keyword arguments. 657 | 658 | Returns: 659 | Tuple: 660 | - attn_output (Tensor): Output tensor after attention. 661 | - attn_weights (Optional[Tensor]): Attention weights if computed. 
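        Example (the reshape from a flat projection to per-head tensors; the sizes
        are assumptions):

            import torch

            batch, seq, hidden, head_dim = 2, 5, 896, 64
            q = torch.randn(batch, seq, hidden)                    # output of the query projection
            q = q.view(batch, seq, -1, head_dim).transpose(1, 2)   # (2, 14, 5, 64): (batch, heads, seq, head_dim)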
662 | """ 663 | input_shape = hidden_states.shape[:-1] 664 | hidden_shape = (*input_shape, -1, self.head_dim) 665 | 666 | query_states, key_states, value_states = self.lora(hidden_states) 667 | query_states = query_states.view(hidden_shape).transpose(1, 2) 668 | key_states = key_states.view(hidden_shape).transpose(1, 2) 669 | value_states = value_states.view(hidden_shape).transpose(1, 2) 670 | 671 | cos, sin = position_embeddings 672 | query_states, key_states = self.apply_rotary_pos_emb( 673 | query_states, key_states, cos, sin 674 | ) 675 | 676 | if past_key_value is not None: 677 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 678 | key_states, value_states = past_key_value.update( 679 | key_states, value_states, self.layer_idx, cache_kwargs 680 | ) 681 | 682 | sliding_window = None 683 | if ( 684 | self.config.use_sliding_window 685 | and getattr(self.config, "sliding_window", None) is not None 686 | and self.layer_idx >= self.config.max_window_layers 687 | ): 688 | sliding_window = self.config.sliding_window 689 | 690 | attn_output, attn_weights = self.attention_forward( 691 | self, 692 | query_states, 693 | key_states, 694 | value_states, 695 | attention_mask, 696 | dropout=0.0 if not self.training else self.attention_dropout, 697 | scaling=self.scaling, 698 | sliding_window=sliding_window, 699 | **kwargs, 700 | ) 701 | 702 | attn_output = attn_output.reshape(*input_shape, -1).contiguous() 703 | 704 | with torch.no_grad(): 705 | o_proj_weight = load_tensor_from_storage( 706 | weight_path=self.o_proj_weight_path, 707 | shape=(self.config.hidden_size, self.config.hidden_size), 708 | to_ram=False, 709 | ) 710 | attn_output = MatmulFunction.apply(attn_output, o_proj_weight.t(), 1.0) 711 | gc.collect() 712 | return attn_output, attn_weights 713 | 714 | def attention_forward( 715 | self, 716 | module: nn.Module, 717 | query: torch.Tensor, 718 | key: torch.Tensor, 719 | value: torch.Tensor, 720 | attention_mask: Optional[torch.Tensor], 721 | scaling: float, 722 | dropout: float = 0.0, 723 | **kwargs, 724 | ): 725 | """ 726 | Compute the core attention mechanism. 727 | 728 | Applies scaled dot-product attention over the queries, keys, and values, taking into account 729 | the attention mask and dropout. 730 | 731 | Args: 732 | module (nn.Module): Reference to the current module (used for configuration). 733 | query (Tensor): Query tensor. 734 | key (Tensor): Key tensor. 735 | value (Tensor): Value tensor. 736 | attention_mask (Optional[Tensor]): Mask to apply to the attention weights. 737 | scaling (float): Scaling factor applied to the dot product. 738 | dropout (float, optional): Dropout probability for the attention weights. 739 | **kwargs: Additional keyword arguments. 740 | 741 | Returns: 742 | Tuple: 743 | - attn_output (Tensor): The attention output tensor. 744 | - attn_weights (Tensor): The computed attention weights. 
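        Example (the core computation with plain tensors; plain matmuls stand in
        for MatmulFunction and the sizes are assumptions):

            import torch
            from torch import nn

            q = torch.randn(2, 14, 5, 64)                     # (batch, heads, q_len, head_dim)
            k = torch.randn(2, 14, 7, 64)                     # k_len includes cached tokens
            v = torch.randn(2, 14, 7, 64)
            scaling = 64 ** -0.5

            scores = q @ k.transpose(2, 3) * scaling          # (2, 14, 5, 7)
            weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
            out = (weights @ v).transpose(1, 2).contiguous()  # (2, 5, 14, 64)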
745 | """ 746 | key_states = self.repeat_kv(key, module.num_key_value_groups) 747 | value_states = self.repeat_kv(value, module.num_key_value_groups) 748 | 749 | attn_weights = MatmulFunction.apply( 750 | query, key_states.transpose(2, 3), scaling 751 | ) 752 | if attention_mask is not None: 753 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] 754 | attn_weights = attn_weights + causal_mask 755 | 756 | attn_weights = nn.functional.softmax( 757 | attn_weights, dim=-1, dtype=torch.float32 758 | ).to(query.dtype) 759 | attn_weights = nn.functional.dropout( 760 | attn_weights, p=dropout, training=module.training 761 | ) 762 | attn_output = MatmulFunction.apply(attn_weights, value_states, 1.0) 763 | attn_output = attn_output.transpose(1, 2).contiguous() 764 | return attn_output, attn_weights 765 | 766 | def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 767 | """ 768 | Repeat key/value hidden states to match the required number of groups. 769 | 770 | Args: 771 | hidden_states (Tensor): Hidden states tensor with shape (batch, num_heads, seq_len, head_dim). 772 | n_rep (int): Number of repetitions. 773 | 774 | Returns: 775 | Tensor: Reshaped tensor with repeated key/value states. 776 | """ 777 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 778 | if n_rep == 1: 779 | return hidden_states 780 | hidden_states = hidden_states[:, :, None, :, :].expand( 781 | batch, num_key_value_heads, n_rep, slen, head_dim 782 | ) 783 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 784 | 785 | def rotate_half(self, x): 786 | """ 787 | Rotate half of the tensor for rotary embedding application. 788 | 789 | Args: 790 | x (Tensor): Input tensor. 791 | 792 | Returns: 793 | Tensor: Tensor with rotated halves. 794 | """ 795 | x1 = x[..., : x.shape[-1] // 2] 796 | x2 = x[..., x.shape[-1] // 2 :] 797 | return torch.cat((-x2, x1), dim=-1) 798 | 799 | def apply_rotary_pos_emb(self, queries, keys, cos, sin, unsqueeze_dim=1): 800 | """ 801 | Apply rotary positional embeddings to queries and keys. 802 | 803 | Args: 804 | queries (Tensor): Query tensor. 805 | keys (Tensor): Key tensor. 806 | cos (Tensor): Cosine embedding. 807 | sin (Tensor): Sine embedding. 808 | unsqueeze_dim (int, optional): Dimension along which to unsqueeze the cosine and sine embeddings. 809 | Defaults to 1. 810 | 811 | Returns: 812 | Tuple[Tensor, Tensor]: Rotated queries and keys with positional embeddings applied. 813 | """ 814 | cos = cos.unsqueeze(unsqueeze_dim) 815 | sin = sin.unsqueeze(unsqueeze_dim) 816 | queries_embed = (queries * cos) + (self.rotate_half(queries) * sin) 817 | keys_embed = (keys * cos) + (self.rotate_half(keys) * sin) 818 | return queries_embed, keys_embed 819 | --------------------------------------------------------------------------------