├── sllm
│   ├── __init__.py
│   ├── ops
│   │   ├── __init__.py
│   │   └── matmul.cpp
│   ├── nn
│   │   ├── __init__.py
│   │   ├── autodiff.py
│   │   ├── transformer.py
│   │   └── layers.py
│   ├── common.py
│   ├── utils.py
│   ├── config.py
│   └── train.py
├── requirements.txt
├── assets
│   └── logo.png
├── pyproject.toml
├── example.py
├── setup.py
├── .gitignore
├── README.md
└── LICENSE
/sllm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets
2 | numpy
3 | platformdirs
4 | torch
5 | tqdm
6 | transformers
7 | 
--------------------------------------------------------------------------------
/sllm/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .matmul import bundled_scaled_matmul
2 |
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HenryNdubuaku/super-lazy-autograd/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel", "torch", "pybind11"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/sllm/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from sllm.nn.layers import (Attention, Linear, LoraLinear,
2 | MLP)
3 | from sllm.nn.transformer import SuperLazyLanguageModel
4 |
--------------------------------------------------------------------------------
/sllm/common.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from platformdirs import user_cache_dir
5 |
6 | CACHE_DIR = user_cache_dir("sllm")
7 | WEIGHT_DIR = f"{CACHE_DIR}/weights"
8 | GRADIENT_DIR = f"{CACHE_DIR}/gradient_checkpoints"
9 | os.makedirs(GRADIENT_DIR, exist_ok=True)
10 |
11 | MAX_CONCURRENT_THREADS = os.cpu_count() * 2
12 | MINI_BATCH_SIZE = 2
13 |
14 | DTYPE = torch.float32
15 | MAX_GRAD_NORM = 0.01
16 |
17 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
18 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import load_dataset
3 |
4 | from sllm.nn import SuperLazyLanguageModel
5 | from sllm.train import prepare_dataset, sft
6 |
7 | torch.manual_seed(42)
8 |
9 | name = "Qwen/Qwen2-0.5B-Instruct"
10 | dataset = load_dataset("yahma/alpaca-cleaned", split="train[:20]")
11 |
12 | dataset = prepare_dataset(
13 | model_name=name,
14 | instructions=dataset["instruction"],
15 | responses=dataset["output"],
16 | inputs=dataset["input"],
17 | max_seq_len=256,
18 | )
19 |
20 | model = SuperLazyLanguageModel(
21 | name=name,
22 | lora_alpha=16,
23 | lora_r=4,
24 | lora_dropout=0.1,
25 | )
26 |
27 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
28 | sft(model=model, dataset=dataset, optimizer=optimizer, batch_size=2, epochs=3)
29 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | from torch.utils.cpp_extension import BuildExtension, CppExtension
3 |
4 | ext_modules = [
5 | CppExtension(
6 | name='sllm.ops.matmul',
7 | sources=['sllm/ops/matmul.cpp'],
8 | extra_compile_args=['-O3']
9 | )
10 | ]
11 |
12 | setup(
13 | name="sllm-lib",
14 | version="0.0.1",
15 | description="Super Lazy Language Model",
16 | long_description=open("README.md").read(),
17 | long_description_content_type="text/markdown",
18 | author="Henry Ndubuaku",
19 | packages=find_packages(),
20 | install_requires=[
21 | "numpy",
22 | "transformers",
23 | "platformdirs",
24 | "tqdm",
25 | "pybind11"
26 | ],
27 | ext_modules=ext_modules,
28 | cmdclass={"build_ext": BuildExtension},
29 | classifiers=[
30 | "Programming Language :: Python :: 3",
31 |         "License :: OSI Approved :: Apache Software License",
32 | "Operating System :: OS Independent",
33 | ],
34 | )
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 |
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .nox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | *.py,cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Jupyter Notebook checkpoints
52 | .ipynb_checkpoints
53 |
54 | # IPython
55 | profile_default/
56 | ipython_config.py
57 |
58 | # Environments
59 | .env
60 | .venv
61 | env/
62 | venv/
63 | ENV/
64 | env.bak/
65 | venv.bak/
66 |
67 | # IDEs and editors
68 | .vscode/
69 | .idea/
70 | *.sublime-workspace
71 | *.sublime-project
72 |
73 | # PyCharm
74 | *.iml
75 |
76 | # System files
77 | .DS_Store
78 | Thumbs.db
79 |
80 | # Python cache files
81 | __pycache__/
82 | *.py[cod]
83 |
84 | # mypy
85 | .mypy_cache/
86 | .dmypy.json
87 | dmypy.json
88 |
89 | # Pyre type checker
90 | .pyre/
91 |
92 | # pyright
93 | pyrightconfig.json
94 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | [](https://www.linkedin.com//company/80434055) [](https://twitter.com/hmunachii)
6 |
7 | Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/)
8 |
9 | ## Overview
10 |
11 | I mean, do not train or fine-tune LLMs on your laptop. Training is done at much higher precision than inference (float32 or bfloat16), and extra memory is needed for the gradients, optimizer states, and the batch itself, so budget roughly 4-6x the model size. As a rule of thumb, that is around 8-24 GB of RAM per 1B params.
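For a rough sanity check of that rule of thumb (illustrative numbers only: float32 weights trained with AdamW):

```python
params = 1e9                        # 1B parameters
bytes_per_param = 4                 # float32
weights = params * bytes_per_param  # ~4 GB of weights
grads = weights                     # ~4 GB of gradients
adam_states = 2 * weights           # ~8 GB for AdamW's two moment buffers
print((weights + grads + adam_states) / 1e9)  # ~16 GB, before activations and batching
```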
12 |
13 | HOWEVER, if you must do so on a laptop for whatever weird reason, this library implements most language models such that only the weights for the current layer are loaded into RAM, and it implements LoRA fine-tuning such that the frozen params are memory-mapped rather than loaded.
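For intuition, here is a minimal sketch of that lazy-loading idea, mirroring `load_tensor_from_storage` in `sllm/utils.py` (the function name, path and shape below are illustrative):

```python
import math

import torch

def load_layer_weight(weight_path, shape, dtype=torch.float32):
    # Map the layer's raw .bin file and view it with the right shape,
    # instead of keeping the whole model resident in RAM.
    data = torch.from_file(weight_path, shared=False, size=math.prod(shape), dtype=dtype)
    return data.view(shape)
```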
14 |
15 | Note the following:
16 | 1) Compute intensity = computation time / communication time, and maximising this means maximising GPU utilisation.
17 | 2) Many computations in transformer models can be parallelised, QKV projections for example.
18 | 3) Most operations in transformers follow the signature A @ B * Scale, a.k.a. the scaled dot-product.
19 | 4) Q @ K.T / sqrt(dimK) is obviously equivalent to Q @ K.T * dimK^(-1/2).
20 | 5) Even unscaled products fit the pattern: Lora_A @ Lora_B = Lora_A @ Lora_B * 1, A * B = I @ A * B, and so on.
21 |
22 | We express the transformer forward pass and the per-layer backward vector-Jacobian products as a set of scaled matmuls, which are bundled together and executed in parallel across CPU cores in a C++ extension that bypasses the GIL. This also paves the way for an upcoming feature where each bundle could be distributed across your friends' laptops, since every worker only ever executes one operation, Bundled Scaled Matmul. You're welcome.
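As a concrete sketch of the bundle API exposed by `sllm.ops` (tensor names and shapes here are arbitrary):

```python
import math

import torch

from sllm.ops import bundled_scaled_matmul

Q = torch.randn(2, 8, 64)
K = torch.randn(2, 8, 64)
x = torch.randn(2, 8, 64)
W_up = torch.randn(64, 128)
W_gate = torch.randn(64, 128)

# Each bundle is (A, B, scale); all bundles are dispatched together in C++.
bundles = [
    (Q, K.transpose(-2, -1), 1 / math.sqrt(K.shape[-1])),  # attention scores
    (x, W_up, 1.0),                                        # plain projection (scale = 1)
    (x, W_gate, 1.0),                                      # another projection
]
scores, up, gate = bundled_scaled_matmul(bundles)
```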
23 |
24 | ## Limitations
25 |
26 | 1) Gradient accumulation, gradient checkpointing and lazy execution trade time complexity for memory efficiency, but you don't have much of a choice, do you?
27 | 2) Yeah... your laptop will definitely heat up. GPUs burn up in data centres and cost a lot to cool; your laptop is not special.
28 |
29 | ## Supported Models
30 |
31 | - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
32 | - Qwen/Qwen2.5-0.5B
33 | - Qwen/Qwen2.5-0.5B-Instruct
34 | - Qwen/Qwen2.5-1.5B
35 | - Qwen/Qwen2.5-1.5B-Instruct
36 | - Qwen/Qwen2.5-3B
37 | - Qwen/Qwen2.5-3B-Instruct
38 |
39 | ## Getting Started
40 |
41 | 1. ```bash
42 | pip install sllm-lib
43 | ```
44 | 2. Initialize the model:
45 | ```python
46 | from sllm.nn import SuperLazyLanguageModel
47 | from sllm.config import Config
48 |
49 | config = Config(
50 | model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
51 | lora_alpha=32,
52 | lora_r=8,
53 | lora_dropout=0.1,
54 | )
55 |
56 | model = SuperLazyLanguageModel(config)
57 |
58 | # Train like a normal pytorch model
59 | ```
60 | 3. Or use the SLLM training utilities end to end:
61 | ```python
62 | import torch
63 | from datasets import load_dataset
64 |
65 | from sllm.nn import SuperLazyLanguageModel
66 | from sllm.train import sft, prepare_dataset
67 |
68 | torch.manual_seed(42)
69 |
70 | name = "Qwen/Qwen2-0.5B-Instruct"
71 | dataset = load_dataset("yahma/alpaca-cleaned", split="train[:200]")
72 |
73 | dataset = prepare_dataset(
74 | model_name=name,
75 | instructions=dataset["instruction"],
76 | responses=dataset["output"],
77 | inputs=dataset["input"],
78 | max_seq_len=256,
79 | )
80 |
81 | model = SuperLazyLanguageModel(
82 | name=name,
83 | lora_alpha=32,
84 | lora_r=8,
85 | lora_dropout=0.1,
86 | )
87 |
88 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
89 | sft(model=model, dataset=dataset, optimizer=optimizer, batch_size=8, epochs=3)
90 | ```
91 |
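If you are installing from a clone rather than PyPI, note that `setup.py` builds the C++ extension in `sllm/ops` with `torch.utils.cpp_extension`, so `torch` and `pybind11` must be available at build time. A minimal editable install might look like:

```bash
pip install torch pybind11
pip install -e .
```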
92 | ## Contributing
93 | Whether you’re improving documentation, optimizing kernels, or adding new features, your help is invaluable.
94 |
95 | 1. Create a feature branch (`git checkout -b feature/awesome-improvement`).
96 | 2. Commit your changes (`git commit -m 'Add awesome feature'`).
97 | 3. Push to the branch (`git push origin feature/awesome-improvement`).
98 | 4. Open a Pull Request.
99 |
--------------------------------------------------------------------------------
/sllm/utils.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import math
3 | import os
4 | import random
5 | import shutil
6 |
7 | import numpy as np
8 | import torch
9 | from transformers import AutoModelForCausalLM
10 |
11 | from sllm.common import DTYPE, GRADIENT_DIR
12 |
13 |
14 | def load_tensor_from_storage(weight_path, shape, dtype=DTYPE, to_ram=False):
15 | """
16 | Load a tensor from a binary file.
17 |
18 | This function reads raw data from a file using torch.from_file and reshapes it into the desired tensor shape.
19 | If the flag 'to_ram' is set to True, the tensor is cloned into RAM to avoid memory-mapping.
20 |
21 | Args:
22 | weight_path (str): The file path to the binary weight file.
23 | shape (tuple or int): Desired shape of the tensor. If a tuple is provided, the total number
24 | of elements is computed as the product of the tuple components; otherwise, the shape is treated as the total size.
25 | dtype (torch.dtype, optional): Data type of the tensor. Defaults to DTYPE.
26 | to_ram (bool, optional): If True, clones the tensor to ensure it resides in RAM. Defaults to False.
27 |
28 | Returns:
29 | torch.Tensor: The tensor with the specified shape loaded from the file.
30 | """
31 | if isinstance(shape, tuple):
32 | size = math.prod(shape)
33 | else:
34 | size = shape
35 |
36 | data = torch.from_file(weight_path, shared=False, size=size, dtype=dtype)
37 | data = data.view(shape)
38 |
39 | if to_ram:
40 | data = data.clone()
41 |
42 | return data
43 |
44 |
45 | @torch._dynamo.disable
46 | def save_tensor_to_storage(weight_path, data):
47 | """
48 | Save a tensor to a binary file.
49 |
50 | The function first ensures the tensor is contiguous and detached from the computation graph.
51 | It then converts the tensor to a NumPy array and writes it to the specified file in binary format.
52 |
53 | Args:
54 | weight_path (str): The file path where the tensor will be saved.
55 | data (torch.Tensor): The tensor to be saved.
56 |
57 | Returns:
58 | None
59 | """
60 | data.contiguous().detach().numpy().tofile(weight_path)
61 |
62 |
63 | def download_weights(weight_dir, model_name):
64 | """
65 | Download and save the pretrained model weights to a local directory.
66 |
67 | This function checks if the local weight directory already exists. If not, it creates the directory,
68 | loads the pretrained model using Hugging Face's AutoModelForCausalLM with the specified torch dtype,
69 | and saves each model parameter as a binary file in the directory. After saving, the model is deleted
70 | and garbage is collected to free memory.
71 |
72 | Args:
73 | weight_dir (str): The local directory path where weights will be stored.
74 | model_name (str): The name or identifier of the pretrained model to download.
75 |
76 | Returns:
77 | None
78 | """
79 | if os.path.exists(weight_dir):
80 | return
81 |
82 | os.makedirs(weight_dir, exist_ok=True)
83 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=DTYPE)
84 |
85 | for name, param in model.named_parameters():
86 | file_path = f"{weight_dir}/{name}.bin"
87 | param.detach().numpy().tofile(file_path)
88 |
89 | del model
90 | gc.collect()
91 |
92 |
93 | def remove_weights(weight_dir):
94 | """
95 | Remove the weights stored in a specified directory or file.
96 |
97 | If the provided path is a directory, the entire directory is deleted recursively.
98 | Otherwise, if it is a file, the file is removed.
99 |
100 | Args:
101 | weight_dir (str): The file or directory path where the weights are stored.
102 |
103 | Returns:
104 | None
105 | """
106 | if os.path.exists(weight_dir):
107 | if os.path.isdir(weight_dir):
108 | shutil.rmtree(weight_dir)
109 | else:
110 | os.remove(weight_dir)
111 |
112 |
113 | def clear_gradient_dir():
114 | """
115 | Clear the gradient directory.
116 |
117 | This function removes the entire gradient directory (if it exists) and then creates an empty directory with the same path.
118 | This is used to manage temporary storage during gradient computations.
119 |
120 | Returns:
121 | None
122 | """
123 | if os.path.exists(GRADIENT_DIR):
124 | shutil.rmtree(GRADIENT_DIR)
125 | os.makedirs(GRADIENT_DIR, exist_ok=True)
126 |
--------------------------------------------------------------------------------
/sllm/ops/matmul.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * @file matmul.cpp
3 | * @brief PyTorch extension module for performing scaled matrix multiplication.
4 | *
5 | * This file implements functions to perform matrix multiplications with a scaling factor.
6 | * It provides both a simple operation and a bundled version that can execute multiple operations
7 | * concurrently using OpenMP for batch processing. The module is exposed to Python via PyBind11.
8 | */
9 |
10 | #include <torch/extension.h>
11 | #include <pybind11/stl.h>
12 | #include <vector>
13 | #ifdef _OPENMP
14 | #include <omp.h>  // OpenMP header, included when available.
15 | #endif
16 |
17 | namespace py = pybind11;
18 |
19 | /**
20 | * @brief Performs scaled matrix multiplication.
21 | *
22 | * Computes the matrix product of A and B using torch::matmul and then scales the resulting tensor
23 | * by the provided scale factor.
24 | *
25 | * @param A A torch::Tensor representing the first matrix.
26 | * @param B A torch::Tensor representing the second matrix.
27 | * @param scale A double representing the scaling factor to be applied to the product.
28 | * @return torch::Tensor The scaled result of the matrix multiplication (i.e., (A * B) * scale).
29 | */
30 | torch::Tensor matmul(const torch::Tensor& A, const torch::Tensor& B, double scale) {
31 | return torch::matmul(A, B) * scale;
32 | }
33 |
34 | /**
35 | * @brief Executes a set of scaled matrix multiplications bundled together.
36 | *
37 | * Processes a vector of tuples where each tuple (referred to as a "bundle") contains three elements:
38 | * - The first element is a torch::Tensor representing the matrix A.
39 | * - The second element is a torch::Tensor representing the matrix B.
40 | * - The third element is a double specifying the scaling factor for the multiplication.
41 | *
42 | * The function reshapes each input tensor to separate batch dimensions from the matrix dimensions,
43 | * then performs the scaled matrix multiplication over the batch elements. When B_reshaped has a single
44 | * batch element, it is broadcasted across the entire batch of A_reshaped. OpenMP is used for parallelization
45 | * over the batch dimension if it is available.
46 | *
47 | * The final resulting tensor is reshaped to match the original input dimensions (except for the last dimension
48 | * which becomes the last dimension of B). The function returns a vector of results with each result corresponding
49 | * to a bundle in the input.
50 | *
51 | * @param matmul_bundles A vector of py::tuple objects, each containing:
52 | * - A torch::Tensor for matrix A,
53 | * - A torch::Tensor for matrix B,
54 | * - A double value for the scaling factor.
55 |  * @return std::vector<torch::Tensor> A vector containing the resulting tensors after performing the
56 | * bundled scaled matrix multiplications.
57 | */
58 | std::vector<torch::Tensor> bundled_scaled_matmul(const std::vector<py::tuple>& matmul_bundles) {
59 |     std::vector<torch::Tensor> results;
60 | 
61 |     for (const auto& bundle : matmul_bundles) {
62 |         auto original_A = bundle[0].cast<torch::Tensor>();
63 |         auto B = bundle[1].cast<torch::Tensor>();
64 |         double scale = bundle[2].cast<double>();
65 |
66 | auto A_shape = original_A.sizes();
67 | auto B_shape = B.sizes();
68 |
69 | // Reshape to isolate the batch dimension from matrix dimensions.
70 | auto A_reshaped = original_A.reshape({-1, A_shape[A_shape.size()-2], A_shape[A_shape.size()-1]});
71 | auto B_reshaped = B.reshape({-1, B_shape[B_shape.size()-2], B_shape[B_shape.size()-1]});
72 | int64_t batch = A_reshaped.size(0);
73 | auto C = torch::empty({batch, A_reshaped.size(1), B_reshaped.size(2)}, torch::kFloat32);
74 |
75 | // Parallelize over the batch dimension using OpenMP.
76 | #pragma omp parallel for
77 | for (int64_t i = 0; i < batch; ++i) {
78 | auto A_i = A_reshaped[i];
79 | // Use the same B for all if B_reshaped only has one element; otherwise, index accordingly.
80 | auto b_tensor = (B_reshaped.size(0) == 1) ? B_reshaped[0] : B_reshaped[i];
81 | C[i] = matmul(A_i, b_tensor, scale);
82 | }
83 |
84 | // Restore the original batch shape with the new matrix multiplication dimensions.
85 |         std::vector<int64_t> result_shape;
86 | for (size_t j = 0; j < A_shape.size()-1; ++j) {
87 | result_shape.push_back(A_shape[j]);
88 | }
89 | result_shape.push_back(B_shape[B_shape.size()-1]);
90 | results.push_back(C.reshape(result_shape));
91 | }
92 | return results;
93 | }
94 |
95 | /**
96 | * @brief PyBind11 module definition for the matrix multiplication extension.
97 | *
98 | * Exposes the bundled_scaled_matmul function to Python, allowing it to be called as a regular Python function.
99 | * The module is named "matmul".
100 | */
101 | PYBIND11_MODULE(matmul, m) {
102 |     m.def("bundled_scaled_matmul", &bundled_scaled_matmul, "Executes bundles of scaled matrix multiplications concurrently across CPU threads");
103 | }
104 |
--------------------------------------------------------------------------------
/sllm/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This module defines the configuration class for the transformer-based model.
3 | It encapsulates all the hyperparameters and weight configurations required for the model
4 | and handles downloading and merging with the pretrained configuration from Hugging Face.
5 |
6 | Key functionalities include:
7 | - Setting model hyperparameters such as hidden size, number of layers, attention heads, etc.
8 | - Configuring LoRA (Low-Rank Adaptation) parameters for efficient fine-tuning.
9 | - Downloading necessary weights for the specified model.
10 | - Overwriting default parameters with those from the pretrained configuration.
11 | """
12 |
13 | import os
14 |
15 | from transformers import AutoConfig
16 |
17 | from sllm.common import WEIGHT_DIR
18 | from sllm.utils import download_weights, remove_weights
19 |
20 |
21 | class Config:
22 | """
23 | A configuration object for initializing the transformer-based model.
24 |
25 | This class holds various hyperparameters required by the model. It downloads pretrained
26 | weights if needed and loads configurations from a pretrained model, overriding any matching
27 | attributes provided in the constructor. In addition to standard model settings, it also
28 | includes configuration details for LoRA adaptation.
29 |
30 | Class Attributes:
31 | model_type (str): Default type of the model, set to "transformer".
32 | keys_to_ignore_at_inference (List[str]): List of keys (e.g., "past_key_values") to be
33 | ignored during inference.
34 | """
35 |
36 | model_type = "transformer"
37 | keys_to_ignore_at_inference = ["past_key_values"]
38 |
39 | def __init__(
40 | self,
41 | model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
42 | vocab_size=151936,
43 | bos_token_id=151643,
44 | eos_token_id=151643,
45 | model_type="qwen2",
46 | hidden_size=1536,
47 | intermediate_size=8960,
48 | num_hidden_layers=28,
49 | num_attention_heads=12,
50 | num_key_value_heads=2,
51 | hidden_act="silu",
52 | max_position_embeddings=131072,
53 | initializer_range=0.02,
54 | rms_norm_eps=1e-6,
55 | use_cache=True,
56 | tie_word_embeddings=False,
57 | rope_theta=10000.0,
58 | rope_scaling=None,
59 | use_mrope=False,
60 | use_sliding_window=False,
61 | sliding_window=None,
62 | max_window_layers=21,
63 | attention_dropout=0.0,
64 | torch_dtype="bfloat16",
65 | pad_token_id=None,
66 | output_attentions=False,
67 | output_hidden_states=False,
68 | use_return_dict=True,
69 | lora_alpha=32,
70 | lora_r=8,
71 | lora_dropout=0.1,
72 | **kwargs,
73 | ):
74 | """
75 | Initialize the Config object with the provided hyperparameters and then update with the pretrained configuration.
76 |
77 | The initialization process involves:
78 | 1. Setting initial values for many model hyperparameters and LoRA parameters.
79 | 2. Determining the local weight directory based on the provided model name.
80 | 3. Downloading the pretrained weights (if not already present) into the specified directory.
81 | 4. Loading additional configuration parameters from the pretrained model via AutoConfig,
82 | and updating the current configuration for any matching attributes.
83 |
84 | Side Effects:
85 | - Downloads the pretrained weights (if not already available) to a local weight directory.
86 | - Updates attributes of this Config instance based on the pretrained model's configuration.
87 | """
88 | self.vocab_size = vocab_size
89 | self.max_position_embeddings = max_position_embeddings
90 | self.hidden_size = hidden_size
91 | self.intermediate_size = intermediate_size
92 | self.num_hidden_layers = num_hidden_layers
93 | self.num_attention_heads = num_attention_heads
94 | self.use_sliding_window = use_sliding_window
95 | self.sliding_window = sliding_window
96 | self.max_window_layers = max_window_layers
97 |
98 | if num_key_value_heads is None:
99 | num_key_value_heads = num_attention_heads
100 |
101 | self.num_key_value_heads = num_key_value_heads
102 | self.hidden_act = hidden_act
103 | self.initializer_range = initializer_range
104 | self.rms_norm_eps = rms_norm_eps
105 | self.use_cache = use_cache
106 | self.rope_theta = rope_theta
107 | self.rope_scaling = rope_scaling
108 | self.attention_dropout = attention_dropout
109 |
110 | self.use_mrope = use_mrope
111 | self.torch_dtype = torch_dtype
112 | self.tie_word_embeddings = tie_word_embeddings
113 | self.bos_token_id = bos_token_id
114 | self.eos_token_id = eos_token_id
115 | self.model_type = model_type
116 | self.pad_token_id = pad_token_id
117 | self.output_attentions = output_attentions
118 | self.output_hidden_states = output_hidden_states
119 | self.use_return_dict = use_return_dict
120 |
121 | self.lora_alpha = lora_alpha
122 | self.lora_r = lora_r
123 | self.lora_dropout = lora_dropout
124 |
125 | # Determine local weight storage directory based on the model name.
126 | model_dir = model_name.split("/")[-1]
127 | self.weight_dir = f"{WEIGHT_DIR}/{model_dir}"
128 |
129 | # Ensure that the pretrained weights are downloaded locally.
130 | download_weights(self.weight_dir, model_name)
131 |
132 | # Load additional configuration parameters from the pretrained model.
133 | config = AutoConfig.from_pretrained(model_name, **kwargs)
134 |
135 | # Overwrite any matching attributes in this Config instance with the pretrained values.
136 | for key, value in config.__dict__.items():
137 | if hasattr(self, key):
138 | setattr(self, key, value)
139 |
--------------------------------------------------------------------------------
/sllm/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets import Dataset
3 | from tqdm import tqdm
4 | from transformers import AutoTokenizer
5 | import time
6 |
7 | from sllm.utils import clear_gradient_dir
8 | from sllm.common import MINI_BATCH_SIZE
9 |
10 |
11 | def format_example(example):
12 | """
13 | Format a single example into a prompt and full training text.
14 |
15 | The function builds a prompt with a header (containing the instruction and, if available,
16 | additional input) followed by a "Response:" marker. It then concatenates the prompt with
17 | the output to form the full text for training.
18 |
19 | Args:
20 | example (dict): A dictionary with keys 'instruction', 'input', and 'output'.
21 |
22 | Returns:
23 | dict: A dictionary with keys:
24 | - "text": A string that concatenates the prompt and the output.
25 | - "prompt_text": The text before adding the output.
26 | """
27 | prompt_text = f"Instruction: {example['instruction']}\n"
28 | if example["input"]:
29 | prompt_text += f"Input: {example['input']}\n"
30 | prompt_text += "Response:"
31 | full_text = prompt_text + " " + example["output"]
32 | return {"text": full_text, "prompt_text": prompt_text}
33 |
34 |
35 | def prepare_dataset(model_name, instructions, responses, inputs=None, max_seq_len=256):
36 | """
37 | Prepare and tokenize a dataset for training.
38 |
39 | The function creates a Hugging Face Dataset from the provided instructions, responses, and optional inputs.
40 | It then formats each example using `format_example`, tokenizes the text (padding/truncating to `max_seq_len`),
41 | and creates a special label mask in which the tokens corresponding to the prompt are replaced by -100 (to be ignored
42 | during loss computation). Finally, the tokenized dataset is split into mini-batches.
43 |
44 | Args:
45 | model_name (str): The pretrained model name used to load the tokenizer.
46 | instructions (List[str]): A list of instruction strings.
47 | responses (List[str]): A list of response strings.
48 | inputs (List[str] or None, optional): A list of additional input strings for examples; defaults to None.
49 | max_seq_len (int, optional): Maximum sequence length for tokenization; defaults to 256.
50 |
51 | Returns:
52 | dict: A dictionary with keys 'input_ids', 'attention_mask', and 'labels',
53 | where each value is a list of tensors split into mini-batches of size MINI_BATCH_SIZE.
54 | """
55 | def tokenize_fn(example):
56 | tokenized = tokenizer(
57 | example["text"],
58 | truncation=True,
59 | padding="max_length",
60 | max_length=max_seq_len,
61 | )
62 | prompt_tokens = tokenizer(
63 | example["prompt_text"],
64 | truncation=False,
65 | add_special_tokens=False,
66 | )["input_ids"]
67 | tokenized["prompt_length"] = len(prompt_tokens)
68 | return tokenized
69 |
70 | def mask_labels(example):
71 | # Copy the tokenized input_ids to labels.
72 | labels = example["input_ids"].copy()
73 | prompt_length = example["prompt_length"]
74 | # Mask the tokens corresponding to the prompt (set to -100) so that they do not contribute to the loss.
75 | for i in range(min(prompt_length, len(labels))):
76 | labels[i] = -100
77 | example["labels"] = labels
78 | return example
79 |
80 | dataset = Dataset.from_dict(
81 | {
82 | "instruction": instructions,
83 |             "input": inputs if inputs is not None else [""] * len(instructions),
84 | "output": responses,
85 | }
86 | )
87 | dataset = dataset.map(format_example)
88 |
89 | tokenizer = AutoTokenizer.from_pretrained(model_name)
90 | dataset = dataset.map(tokenize_fn)
91 | dataset = dataset.map(mask_labels)
92 |
93 | dataset = dataset.shuffle(seed=42)
94 |
95 | input_ids = torch.tensor(dataset["input_ids"])
96 | attention_masks = torch.tensor(dataset["attention_mask"])
97 | labels = torch.tensor(dataset["labels"])
98 | print(f"Training on {input_ids.numel() // 1000}k tokens")
99 | return {
100 | "input_ids": input_ids.split(MINI_BATCH_SIZE),
101 | "attention_mask": attention_masks.split(MINI_BATCH_SIZE),
102 | "labels": labels.split(MINI_BATCH_SIZE),
103 | }
104 |
105 |
106 | def sft(model, dataset, optimizer, batch_size=1, epochs=1):
107 | """
108 | Perform supervised fine-tuning (SFT) on a language model.
109 |
110 | This function trains the model on the provided dataset using gradient accumulation.
111 | The effective batch size is determined as the product of MINI_BATCH_SIZE and grad_accum_steps.
112 |
113 | Gradient Accumulation Derivation:
114 | If the provided batch_size is larger than MINI_BATCH_SIZE (the size of each mini-batch),
115 | the gradients of several forward/backward passes are accumulated before performing an optimizer step.
116 | In order to ensure that the effective gradient is equivalent to that computed on the entire batch,
117 | the loss of each mini-batch is divided by grad_accum_steps. That is,
118 | loss_effective = (loss_mini_batch / grad_accum_steps)
119 | This scaling ensures that when the gradients from grad_accum_steps mini-batches are summed,
120 | the resulting update is equivalent to the gradient of the average loss over the full batch.
121 |
122 | During training, the function reports the average loss, seconds per sample, and seconds elapsed per epoch.
123 |
124 | Args:
125 | model (torch.nn.Module): The model to be fine-tuned.
126 | dataset (dict): A dictionary with keys 'input_ids', 'attention_mask', and 'labels',
127 | where each value is a list of tensors representing mini-batches.
128 | optimizer (torch.optim.Optimizer): The optimizer for the model.
129 | batch_size (int, optional): The total batch size for each optimizer update; defaults to 1.
130 | epochs (int, optional): Number of training epochs; defaults to 1.
131 |
132 | Returns:
133 | None
134 | """
135 | grad_accum_steps = batch_size // MINI_BATCH_SIZE
136 | if grad_accum_steps < 1:
137 | grad_accum_steps = 1
138 |
139 | total_batches = len(dataset["input_ids"])
140 | for epoch in range(epochs):
141 | epoch_loss = 0
142 | model.train()
143 | optimizer.zero_grad()
144 | clear_gradient_dir()
145 | accum_steps = 0
146 |
147 | epoch_start_time = time.time()
148 |
149 | with tqdm(total=total_batches, desc=f"Epoch {epoch+1}") as pbar:
150 | zipped_dataset = zip(
151 | dataset["input_ids"], dataset["attention_mask"], dataset["labels"]
152 | )
153 | for i, batch in enumerate(zipped_dataset, start=1):
154 | batch_input, batch_mask, batch_labels = batch
155 | output = model(
156 | input_ids=batch_input,
157 | attention_mask=batch_mask,
158 | labels=batch_labels,
159 | )
160 | loss = output.loss
161 | # Scale the mini-batch loss to average over grad_accum_steps
162 | loss = loss / grad_accum_steps
163 | loss.backward()
164 | epoch_loss += loss.item() * grad_accum_steps
165 | accum_steps += 1
166 |
167 | if accum_steps % grad_accum_steps == 0:
168 | optimizer.step()
169 | optimizer.zero_grad()
170 | accum_steps = 0
171 |
172 | avg_loss = epoch_loss / i
173 | elapsed = time.time() - epoch_start_time
174 | sec_per_sample = elapsed / (i * MINI_BATCH_SIZE)
175 | sec_per_epoch = time.time() - epoch_start_time
176 | pbar.set_postfix(
177 | loss=f"{avg_loss:.1f}",
178 | sec_per_sample=f"{sec_per_sample:.2f}",
179 | sec_per_epoch=f"{sec_per_epoch:.2f}",
180 | )
181 | pbar.update(1)
182 |
183 | # Apply any remaining accumulated gradients.
184 | if accum_steps > 0:
185 | optimizer.step()
186 | optimizer.zero_grad()
187 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/sllm/nn/autodiff.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements custom autograd functions for performing matrix multiplications,
3 | including LoRA-adapted operations for efficient fine-tuning. It leverages lazy weight loading,
4 | bundled scaled matrix multiplications, and gradient clipping to manage memory and computational
5 | efficiency.
6 |
7 | Functions:
8 | clip_grad(grad): Clips a gradient tensor based on MAX_GRAD_NORM.
9 |
10 | Classes:
11 | MatmulFunction: Custom autograd function for computing scaled matrix multiplications.
12 | BundledMatmulFunction: Custom autograd function that computes multiple scaled matrix multiplications together.
13 | LoraFunction: Custom autograd function to perform a LoRA-adapted linear operation.
14 | LoraQKVLinearFunction: Custom autograd function for query, key, and value projections with LoRA adaptation.
15 | """
16 |
17 | import os
18 |
19 | import torch
20 | import torch.nn.functional as F
21 |
22 | from sllm.common import GRADIENT_DIR, MAX_GRAD_NORM
23 | from sllm.ops import bundled_scaled_matmul
24 | from sllm.utils import load_tensor_from_storage, save_tensor_to_storage
25 |
26 |
27 | def clip_grad(grad):
28 | """
29 | Clip a gradient tensor to prevent exploding gradients.
30 |
31 | If the norm of `grad` exceeds MAX_GRAD_NORM, scale it down accordingly.
32 |
33 | Args:
34 | grad (Tensor or None): The gradient tensor to be clipped.
35 |
36 | Returns:
37 | Tensor or None: The clipped gradient tensor, or None if input is None.
38 | """
39 | if grad is None:
40 | return None
41 | norm = grad.norm()
42 | if norm > MAX_GRAD_NORM:
43 | grad = grad * (MAX_GRAD_NORM / (norm + 1e-6))
44 | return grad
45 |
46 |
47 | class MatmulFunction(torch.autograd.Function):
48 | """
49 | Custom autograd function for performing a scaled matrix multiplication.
50 |
51 | Forward:
52 | Computes Y = bundled_scaled_matmul([(A, B, scale)])[0], which is equivalent to
53 | Y = (A @ B) * scale.
54 |
55 | Backward Derivation:
56 | Given Y = scale * (A @ B), standard matrix calculus yields:
57 | - dY/dA = scale * grad_output @ B^T
58 | - dY/dB = scale * A^T @ grad_output
59 | Since the scale is a constant and non-differentiable here, its derivative is not returned
60 | (i.e. None). In this implementation, the bundled_scaled_matmul function is used to compute:
61 | grad_A = grad_output @ B^T and grad_B = A^T @ grad_output,
62 | assuming that the scale factor is applied in the forward pass and treated as a constant.
63 | """
64 |
65 | @staticmethod
66 | def forward(ctx, A, B, scale):
67 | """
68 | Forward pass for the scaled matrix multiplication.
69 |
70 | Args:
71 | A (Tensor): Left-hand side matrix.
72 | B (Tensor): Right-hand side matrix.
73 | scale (float): Scaling factor applied after matrix multiplication.
74 |
75 | Returns:
76 | Tensor: Result of the scaled matrix multiplication.
77 | """
78 | ctx.save_for_backward(A, B)
79 | return bundled_scaled_matmul([(A, B, scale)])[0]
80 |
81 | @staticmethod
82 | def backward(ctx, grad_output):
83 | """
84 | Backward pass for the scaled matrix multiplication.
85 |
86 | Derivation:
87 | Let Y = scale * (A @ B). Then, by the chain rule:
88 | dL/dA = dL/dY @ dY/dA = grad_output @ (B^T) * scale,
89 | dL/dB = (A^T) @ grad_output * scale.
90 | Here, the implementation uses bundled_scaled_matmul with a scaling factor of 1.0,
91 | effectively treating the scale as a constant whose derivative is omitted.
92 |
93 | Args:
94 | grad_output (Tensor): Gradient tensor propagated from subsequent layers.
95 |
96 | Returns:
97 | Tuple[Tensor, Tensor, None]: Gradients with respect to A, B, and None for scale.
98 | """
99 | A, B = ctx.saved_tensors
100 | bundles = [
101 | (grad_output, B.transpose(-2, -1), 1.0), # represents grad_output @ B^T
102 | (A.transpose(-2, -1), grad_output, 1.0), # represents A^T @ grad_output
103 | ]
104 | grad_A, grad_B = bundled_scaled_matmul(bundles)
105 | grad_A = clip_grad(grad_A)
106 | grad_B = clip_grad(grad_B)
107 | return grad_A, grad_B, None
108 |
109 |
110 | class BundledMatmulFunction(torch.autograd.Function):
111 | """
112 | Custom autograd function for performing multiple scaled matrix multiplications in batch.
113 |
114 | Forward:
115 | Accepts a list of bundles where each bundle is a tuple (M, N, scale) and returns a list
116 | of results computed by bundled_scaled_matmul.
117 |
118 | Backward Derivation:
119 | For each bundle where the forward operation is Y_i = M_i @ N_i * scale_i, the gradient
120 | derivation uses the standard identities:
121 | dL/dM_i = grad_output @ (N_i^T)
122 | dL/dN_i = (M_i^T) @ grad_output
123 | Here, a new bundle is constructed for each stored tuple in the forward pass with a fixed scale (1.0)
124 | and processed in batch.
125 | """
126 |
127 | @staticmethod
128 | def forward(ctx, bundles):
129 | """
130 | Forward pass for the bundled scaled matrix multiplications.
131 |
132 | Args:
133 | bundles (List[Tuple[Tensor, Tensor, float]]): A list of tuples, each containing two tensors
134 | and a scaling factor.
135 |
136 | Returns:
137 | List[Tensor]: A list of tensors resulting from the matrix multiplications.
138 | """
139 | ctx.save_for_backward(*bundles)
140 | return bundled_scaled_matmul(bundles)
141 |
142 | @staticmethod
143 | def backward(ctx, grad_output):
144 | """
145 | Backward pass for the bundled matrix multiplications.
146 |
147 | Derivation:
148 | For each bundle (M, N, scale) corresponding to Y = M @ N * scale, the gradient with respect
149 | to M is grad_output @ (N^T) and with respect to N is (M^T) @ grad_output. The backward pass
150 | constructs a new list of bundles where for each saved bundle it forms:
151 | (grad_output, N^T, 1.0)
152 | and computes the corresponding gradients in one batched call.
153 |
154 | Args:
155 | grad_output (Tensor): Gradient tensor propagated from subsequent layers.
156 |
157 | Returns:
158 | Tuple[List[Tensor]]: Gradients corresponding to the input bundles.
159 | """
160 | bundles = ctx.saved_tensors
161 | triple_list = [
162 | (grad_output, bundle[1].transpose(-2, -1), 1.0) for bundle in bundles
163 | ]
164 | grad_bundles = bundled_scaled_matmul(triple_list)
165 | return grad_bundles
166 |
167 |
168 | class LoraFunction(torch.autograd.Function):
169 | """
170 | Custom autograd function for LoRA-adapted linear operations.
171 |
172 | Forward:
173 | Computes the effective weight as W_eff = W + (A @ B * scale) and then computes:
174 | Y = bundled_scaled_matmul([(x, W_eff^T, 1.0)])[0]
175 | Optionally adds a bias.
176 |
177 | Backward Derivation:
178 | Let W_eff = W + (A @ B * scale) and Y = x @ (W_eff)^T.
179 | Using the chain rule:
180 | dL/dx = grad_output @ W_eff
181 | dL/dW_eff = x^T @ grad_output
182 | Then, the derivatives with respect to A and B are computed via:
183 | dW_eff/dA = B * scale, dW_eff/dB = A * scale.
184 | The backward pass first computes an intermediate gradient E from the matrix multiplication,
185 | then derives:
186 | grad_A = E @ (B^T) * scale
187 | grad_B = (A^T) @ E * scale.
188 | The pre-trained weight W is assumed fixed, so its gradient is not computed.
189 | """
190 |
191 | @staticmethod
192 | def forward(ctx, x, A, B, W_path, scale, bias=None):
193 | """
194 | Forward pass for the LoRA function.
195 |
196 | Args:
197 | x (Tensor): Input tensor.
198 | A (Tensor): LoRA parameter A.
199 | B (Tensor): LoRA parameter B.
200 | W_path (str): File path to the pre-trained weight tensor.
201 | scale (float): Scaling factor for the LoRA update.
202 | bias (Tensor, optional): Bias tensor to add to the output.
203 |
204 | Returns:
205 | Tensor: Output tensor after applying the effective weight and bias.
206 | """
207 | ctx.save_for_backward(A, B)
208 | ctx.scale = scale
209 | ctx.x_path = os.path.join(GRADIENT_DIR, W_path.split("/")[-1] + ".x.bin")
210 | save_tensor_to_storage(ctx.x_path, x)
211 |
212 | W = load_tensor_from_storage(
213 | weight_path=W_path,
214 | shape=(A.shape[0], B.shape[1]),
215 | dtype=A.dtype,
216 | to_ram=False,
217 | )
218 |
219 | effective_W = W + (A @ B * scale)
220 | Wx = bundled_scaled_matmul([(x, effective_W.transpose(-2, -1), 1.0)])[0]
221 |
222 |         ctx.effective_weight = effective_W
223 |
224 | if bias is not None:
225 | Wx = Wx + bias
226 | return Wx
227 |
228 | @staticmethod
229 | def backward(ctx, grad_output):
230 | """
231 | Backward pass for the LoRA function.
232 |
233 | Derivation:
234 | Given:
235 | Y = x @ (W_eff)^T, with W_eff = W + (A @ B) * scale.
236 | Then:
237 | dL/dx = grad_output @ W_eff.
238 | dL/dW_eff = x^T @ grad_output.
239 | Since W is fixed, we only differentiate through the LoRA update.
240 | The gradient of W_eff with respect to A is: dW_eff/dA = B * scale,
241 | and with respect to B is: dW_eff/dB = A * scale.
242 | Therefore, an intermediate gradient E = dL/dW_eff is computed and then:
243 | grad_A = E @ (B^T) * scale,
244 | grad_B = (A^T) @ E * scale.
245 |
246 | Args:
247 | grad_output (Tensor): Gradient tensor from subsequent operations.
248 |
249 | Returns:
250 | Tuple: Gradients with respect to x, A, B, None (for weight), None (for scale), and None (for bias).
251 | """
252 | x = load_tensor_from_storage(
253 | ctx.x_path, shape=grad_output.shape, dtype=grad_output.dtype, to_ram=False
254 | )
255 | A, B = ctx.saved_tensors
256 | scale = ctx.scale
257 |
258 |         effective_W = ctx.effective_weight
259 |
260 | bundles = [
261 | (grad_output, effective_W, 1.0), # dL/dx = grad_output @ W_eff
262 | (x.transpose(-2, -1), grad_output, 1.0), # dL/dW_eff = x^T @ grad_output
263 | ]
264 | grad_x, E = bundled_scaled_matmul(bundles)
265 |
266 | bundles = [
267 | (E, B.transpose(-2, -1), scale), # grad_A = E @ B^T * scale
268 | (A.transpose(-2, -1), E, scale) # grad_B = A^T @ E * scale
269 | ]
270 | grad_A, grad_B = bundled_scaled_matmul(bundles)
271 |
272 | grad_w = None
273 | grad_scale = None
274 | grad_b = None
275 |
276 | return grad_x, grad_A, grad_B, grad_w, grad_scale, grad_b
277 |
278 |
279 | class LoraQKVLinearFunction(torch.autograd.Function):
280 | """
281 | Custom autograd function for LoRA-adapted query/key/value projections.
282 |
283 | Forward:
284 | For each projection (Q, K, V), the effective weight is computed as:
285 | effective = weight + (LoRA_A @ LoRA_B * scaling)
286 | and the projections are computed as:
287 | Projection = bundled_scaled_matmul([(x, effective, 1.0)])
288 | Biases are added if provided.
289 |
290 | Backward Derivation:
291 | Let Q = x @ (q_effective)^T + bias, with q_effective = q_weight + (q_proj_lora_A @ q_proj_lora_B * scaling).
292 | By the chain rule:
293 | dL/dx = sum_{proj in {Q, K, V}} [ (x^T)' from that projection ],
294 | where for each projection the gradients with respect to the effective weights are obtained by:
295 | dL/d(effective) = x^T @ grad_projection.
296 | Then, the gradients with respect to the low-rank parameters (LoRA_A and LoRA_B) are computed
297 | using:
298 | grad_LoRA_A = dL/d(effective) @ (LoRA_B)^T * scaling,
299 | grad_LoRA_B = (LoRA_A)^T @ dL/d(effective) * scaling.
300 | The input x gradient is computed as the sum of contributions from Q, K, and V pathways.
301 | """
302 |
303 | @staticmethod
304 | def forward(
305 | ctx,
306 | x,
307 | q_proj_weight_path,
308 | k_proj_weight_path,
309 | v_proj_weight_path,
310 | q_proj_bias,
311 | k_proj_bias,
312 | v_proj_bias,
313 | q_proj_lora_A,
314 | q_proj_lora_B,
315 | k_proj_lora_A,
316 | k_proj_lora_B,
317 | v_proj_lora_A,
318 | v_proj_lora_B,
319 | scaling,
320 | ):
321 | """
322 | Forward pass for the LoRA QKV function.
323 |
324 | Args:
325 | x (Tensor): Input tensor.
326 | q_proj_weight_path (str): File path for the query projection weight.
327 | k_proj_weight_path (str): File path for the key projection weight.
328 | v_proj_weight_path (str): File path for the value projection weight.
329 | q_proj_bias (Tensor or None): Bias tensor for the query projection.
330 | k_proj_bias (Tensor or None): Bias tensor for the key projection.
331 | v_proj_bias (Tensor or None): Bias tensor for the value projection.
332 | q_proj_lora_A (Tensor): LoRA parameter A for query projection.
333 | q_proj_lora_B (Tensor): LoRA parameter B for query projection.
334 | k_proj_lora_A (Tensor): LoRA parameter A for key projection.
335 | k_proj_lora_B (Tensor): LoRA parameter B for key projection.
336 | v_proj_lora_A (Tensor): LoRA parameter A for value projection.
337 | v_proj_lora_B (Tensor): LoRA parameter B for value projection.
338 | scaling (float): Scaling factor for the LoRA update.
339 |
340 | Returns:
341 | Tuple[Tensor, Tensor, Tensor]: The projected query, key, and value tensors.
342 | """
343 | ctx.save_for_backward(
344 | q_proj_lora_A,
345 | q_proj_lora_B,
346 | k_proj_lora_A,
347 | k_proj_lora_B,
348 | v_proj_lora_A,
349 | v_proj_lora_B,
350 | )
351 |
352 | ctx.scaling = scaling
353 | ctx.q_bias_flag = q_proj_bias is not None
354 | ctx.k_bias_flag = k_proj_bias is not None
355 | ctx.v_bias_flag = v_proj_bias is not None
356 |
357 | # Save input x shape for later and store x on disk.
358 | ctx.x_shape = x.shape
359 | ctx.x_path = os.path.join(
360 | GRADIENT_DIR, q_proj_weight_path.split("/")[-1] + ".x.bin"
361 | )
362 | save_tensor_to_storage(ctx.x_path, x)
363 |
364 |         q_shape = (q_proj_lora_B.shape[-1], x.shape[-1])
365 | kv_shape = (k_proj_lora_B.shape[-1], x.shape[-1])
366 | ctx.q_shape = q_shape
367 | ctx.kv_shape = kv_shape
368 |
369 | q_proj_weight = load_tensor_from_storage(
370 | weight_path=q_proj_weight_path,
371 | shape=q_shape,
372 | dtype=q_proj_lora_A.dtype,
373 | to_ram=False,
374 | ).transpose(-2, -1)
375 | k_proj_weight = load_tensor_from_storage(
376 | weight_path=k_proj_weight_path,
377 | shape=kv_shape,
378 | dtype=k_proj_lora_A.dtype,
379 | to_ram=False,
380 | ).transpose(-2, -1)
381 | v_proj_weight = load_tensor_from_storage(
382 | weight_path=v_proj_weight_path,
383 | shape=kv_shape,
384 | dtype=v_proj_lora_A.dtype,
385 | to_ram=False,
386 | ).transpose(-2, -1)
387 |
388 | # Compute effective weights with LoRA update.
389 | q_effective = q_proj_weight + (q_proj_lora_A @ q_proj_lora_B * scaling)
390 | k_effective = k_proj_weight + (k_proj_lora_A @ k_proj_lora_B * scaling)
391 | v_effective = v_proj_weight + (v_proj_lora_A @ v_proj_lora_B * scaling)
392 |
393 | ctx.q_effective = q_effective
394 | ctx.k_effective = k_effective
395 | ctx.v_effective = v_effective
396 |
397 | bundles = [
398 | (x, q_effective, 1.0),
399 | (x, k_effective, 1.0),
400 | (x, v_effective, 1.0),
401 | ]
402 |
403 | Q, K, V = bundled_scaled_matmul(bundles)
404 |
405 | if q_proj_bias is not None:
406 | Q = Q + q_proj_bias
407 | if k_proj_bias is not None:
408 | K = K + k_proj_bias
409 | if v_proj_bias is not None:
410 | V = V + v_proj_bias
411 |
412 | return Q, K, V
413 |
414 | @staticmethod
415 | def backward(ctx, grad_Q, grad_K, grad_V):
416 | """
417 | Backward pass for the LoRA QKV function.
418 |
419 | Derivation:
420 | For each projection (Q, K, V), let:
421 | effective = weight + (LoRA_A @ LoRA_B * scaling)
422 | and the forward operation is:
423 |                 Projection = x @ effective (+ bias).
424 | The gradients are computed as follows:
425 | 1. Compute dL/d(effective) for each projection by:
426 | dL/d(effective) = x^T @ grad_projection
427 |             2. The gradient with respect to x sums the per-projection contributions grad_projection @ effective^T:
428 | grad_x = grad_xQ + grad_xK + grad_xV
429 | 3. Using the chain rule and the linearity of the LoRA update:
430 | grad_LoRA_A = dL/d(effective) @ (LoRA_B)^T * scaling,
431 | grad_LoRA_B = (LoRA_A)^T @ dL/d(effective) * scaling.
432 | Here, the bundled_scaled_matmul function is used to compute both the gradients for x
433 | (from each of Q, K, V) and the gradients for the effective weights, which are then propagated
434 | to the low-rank LoRA parameters.
435 |
436 | Args:
437 | grad_Q (Tensor): Gradient with respect to the query output.
438 | grad_K (Tensor): Gradient with respect to the key output.
439 | grad_V (Tensor): Gradient with respect to the value output.
440 |
441 | Returns:
442 | Tuple: Gradients for each input parameter in the same order as in the forward pass.
443 | """
444 | (
445 | q_proj_lora_A,
446 | q_proj_lora_B,
447 | k_proj_lora_A,
448 | k_proj_lora_B,
449 | v_proj_lora_A,
450 | v_proj_lora_B,
451 | ) = ctx.saved_tensors
452 | scale = ctx.scaling
453 | x = load_tensor_from_storage(
454 | ctx.x_path, shape=ctx.x_shape, dtype=grad_Q.dtype, to_ram=False
455 | )
456 |
457 | q_effective = ctx.q_effective
458 | k_effective = ctx.k_effective
459 | v_effective = ctx.v_effective
460 |
461 | bundles = [
462 | (grad_Q, q_effective.transpose(-2, -1), 1.0),
463 | (grad_K, k_effective.transpose(-2, -1), 1.0),
464 | (grad_V, v_effective.transpose(-2, -1), 1.0),
465 | (x.transpose(-2, -1), grad_Q, 1.0),
466 | (x.transpose(-2, -1), grad_K, 1.0),
467 | (x.transpose(-2, -1), grad_V, 1.0),
468 | ]
469 | (
470 | grad_xQ,
471 | grad_xK,
472 | grad_xV,
473 | grad_effective_q,
474 | grad_effective_k,
475 | grad_effective_v,
476 | ) = bundled_scaled_matmul(bundles)
477 |
478 | grad_x = grad_xQ + grad_xK + grad_xV
479 |
480 |         grad_q_bias = grad_Q.sum(dim=tuple(range(grad_Q.dim() - 1))) if ctx.q_bias_flag else None
481 |         grad_k_bias = grad_K.sum(dim=tuple(range(grad_K.dim() - 1))) if ctx.k_bias_flag else None
482 |         grad_v_bias = grad_V.sum(dim=tuple(range(grad_V.dim() - 1))) if ctx.v_bias_flag else None
483 |
484 | bundles = [
485 | (grad_effective_q, q_proj_lora_B.transpose(-2, -1), scale),
486 | (q_proj_lora_A.transpose(-2, -1), grad_effective_q, scale),
487 | (grad_effective_k, k_proj_lora_B.transpose(-2, -1), scale),
488 | (k_proj_lora_A.transpose(-2, -1), grad_effective_k, scale),
489 | (grad_effective_v, v_proj_lora_B.transpose(-2, -1), scale),
490 | (v_proj_lora_A.transpose(-2, -1), grad_effective_v, scale),
491 | ]
492 | grad_q_A, grad_q_B, grad_k_A, grad_k_B, grad_v_A, grad_v_B = bundled_scaled_matmul(
493 | bundles
494 | )
495 |
496 | grad_x = clip_grad(grad_x)
497 | grad_q_bias = clip_grad(grad_q_bias)
498 | grad_k_bias = clip_grad(grad_k_bias)
499 | grad_v_bias = clip_grad(grad_v_bias)
500 | grad_q_A = clip_grad(grad_q_A)
501 | grad_q_B = clip_grad(grad_q_B)
502 | grad_k_A = clip_grad(grad_k_A)
503 | grad_k_B = clip_grad(grad_k_B)
504 | grad_v_A = clip_grad(grad_v_A)
505 | grad_v_B = clip_grad(grad_v_B)
506 | grad_effective_q = clip_grad(grad_effective_q)
507 | grad_effective_k = clip_grad(grad_effective_k)
508 | grad_effective_v = clip_grad(grad_effective_v)
509 | grad_xQ = clip_grad(grad_xQ)
510 | grad_xK = clip_grad(grad_xK)
511 | grad_xV = clip_grad(grad_xV)
512 | grad_Q = clip_grad(grad_Q)
513 | grad_K = clip_grad(grad_K)
514 | grad_V = clip_grad(grad_V)
515 |
516 | grad_q_weight_path = None
517 | grad_k_weight_path = None
518 | grad_v_weight_path = None
519 | grad_scale = None
520 |
521 | return (
522 | grad_x,
523 | grad_q_weight_path,
524 | grad_k_weight_path,
525 | grad_v_weight_path,
526 | grad_q_bias,
527 | grad_k_bias,
528 | grad_v_bias,
529 | grad_q_A,
530 | grad_q_B,
531 | grad_k_A,
532 | grad_k_B,
533 | grad_v_A,
534 | grad_v_B,
535 | grad_scale,
536 | )
537 |
--------------------------------------------------------------------------------
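
The backward passes in autodiff.py above come down to two low-rank identities: with E = x^T @ grad_output, grad_A = E @ B^T * scale and grad_B = A^T @ E * scale. The following minimal sketch checks those identities against torch.autograd using plain in-memory tensors with made-up shapes (it deliberately avoids the repo's disk-backed MatmulFunction/LoraFunction):

import torch

torch.manual_seed(0)
batch, d_in, d_out, r, scale = 3, 8, 6, 2, 4.0
x = torch.randn(batch, d_in)
W = torch.randn(d_in, d_out)                    # frozen base weight, already in (in, out) layout
A = torch.randn(d_in, r, requires_grad=True)    # LoRA A
B = torch.randn(r, d_out, requires_grad=True)   # LoRA B

y = x @ (W + A @ B * scale)                     # effective weight = W + A @ B * scale
y.sum().backward()

grad_output = torch.ones_like(y)                # dL/dy for loss = y.sum()
E = x.t() @ grad_output                         # dL/d(effective weight)
assert torch.allclose(A.grad, E @ B.t() * scale, atol=1e-4)
assert torch.allclose(B.grad, A.t() @ E * scale, atol=1e-4)
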
/sllm/nn/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements a Transformer-based language model with support for decoder layers,
3 | gradient checkpointing, and caching mechanisms. It defines the following classes:
4 |
5 | - DecoderLayer: Implements a single layer of the transformer decoder.
6 | - Transformer: Composes multiple decoder layers into a full transformer model.
7 | - SuperLazyLanguageModel: Encapsulates the transformer and head for sequence generation tasks.
8 |
9 | The implementation leverages PyTorch for tensor computations and Hugging Face transformers
10 | for cache management.
11 | """
12 |
13 | from typing import List, Optional, Tuple, Union
14 |
15 | import torch
16 | import torch.nn.functional as F
17 | from torch import nn
18 | from transformers import Cache, DynamicCache, SlidingWindowCache, StaticCache
19 | from transformers.modeling_outputs import (BaseModelOutputWithPast,
20 | CausalLMOutputWithPast)
21 |
22 | from sllm.common import DTYPE
23 | from sllm.config import Config
24 | from sllm.nn.layers import (Attention, Embedding, Linear,
25 | MLP, RMSNorm, RotaryEmbedding)
26 |
27 |
28 | class DecoderLayer(nn.Module):
29 | """
30 | A single decoder layer of the Transformer.
31 |
32 | This layer consists of:
33 | - RMS normalization applied to the input (input_layernorm).
34 | - A self-attention mechanism (self_attn).
35 | - A residual connection adding the result back to the input.
36 | - A second RMS normalization (post_attention_layernorm) followed by a feed-forward MLP,
37 | with an additional residual connection.
38 |
39 | Args:
40 | config (Config): Configuration object containing model hyperparameters.
41 | layer_idx (int): Index of the layer to load layer-specific weights.
42 | """
43 |
44 | def __init__(self, config: Config, layer_idx: int):
45 | """
46 | Initialize the decoder layer.
47 |
48 | Loads layer-specific weights for normalization from the provided weight directory.
49 |
50 | Args:
51 | config (Config): Model configuration.
52 | layer_idx (int): Index for the current layer.
53 | """
54 | super().__init__()
55 | self.hidden_size = config.hidden_size
56 | self.self_attn = Attention(config=config, layer_idx=layer_idx)
57 | self.mlp = MLP(config, layer_idx)
58 |
59 | input_layer_norm_weight_path = (
60 | f"{config.weight_dir}/model.layers.{layer_idx}.input_layernorm.weight.bin"
61 | )
62 | post_attention_layer_norm_weight_path = f"{config.weight_dir}/model.layers.{layer_idx}.post_attention_layernorm.weight.bin"
63 |
64 | self.input_layernorm = RMSNorm(
65 | hidden_size=config.hidden_size,
66 | weight_path=input_layer_norm_weight_path,
67 | eps=config.rms_norm_eps,
68 | )
69 | self.post_attention_layernorm = RMSNorm(
70 | hidden_size=config.hidden_size,
71 | weight_path=post_attention_layer_norm_weight_path,
72 | eps=config.rms_norm_eps,
73 | )
74 |
75 | def forward(
76 | self,
77 | hidden_states: torch.Tensor,
78 | attention_mask: Optional[torch.Tensor] = None,
79 | position_ids: Optional[torch.LongTensor] = None,
80 | past_key_value: Optional[Cache] = None,
81 | output_attentions: Optional[bool] = False,
82 | use_cache: Optional[bool] = False,
83 | cache_position: Optional[torch.LongTensor] = None,
84 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
85 | **kwargs,
86 | ) -> Tuple[
87 | torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
88 | ]:
89 | """
90 | Perform a forward pass through the decoder layer.
91 |
92 | The input passes through an initial layer normalization, self-attention block, and
93 | a feed-forward MLP with residual connections. Optionally, attention weights can be returned.
94 |
95 | Args:
96 | hidden_states (torch.Tensor): Input tensor with shape (batch_size, seq_length, hidden_size).
97 | attention_mask (Optional[torch.Tensor], optional): Attention mask for self-attention.
98 | position_ids (Optional[torch.LongTensor], optional): Tensor containing position indices.
99 | past_key_value (Optional[Cache], optional): Cached past key and value tensors.
100 | output_attentions (Optional[bool], optional): If True, returns self-attention weights.
101 | use_cache (Optional[bool], optional): If True, enables caching for inference.
102 | cache_position (Optional[torch.LongTensor], optional): Positions for caching tokens.
103 | position_embeddings (Optional[Tuple[torch.Tensor, torch.Tensor]], optional): Pre-computed position embeddings.
104 | **kwargs: Additional keyword arguments.
105 |
106 | Returns:
107 | Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
108 | - The output tensor of shape (batch_size, seq_length, hidden_size).
109 | - Optionally, a tuple of self-attention weights.
110 | """
111 | residual = hidden_states
112 |
113 | hidden_states = self.input_layernorm(hidden_states)
114 |
115 | hidden_states, self_attn_weights = self.self_attn(
116 | hidden_states=hidden_states,
117 | attention_mask=attention_mask,
118 | position_ids=position_ids,
119 | past_key_value=past_key_value,
120 | output_attentions=output_attentions,
121 | use_cache=use_cache,
122 | cache_position=cache_position,
123 | position_embeddings=position_embeddings,
124 | **kwargs,
125 | )
126 | hidden_states = residual + hidden_states
127 |
128 | residual = hidden_states
129 | hidden_states = self.post_attention_layernorm(hidden_states)
130 | hidden_states = self.mlp(hidden_states)
131 | hidden_states = residual + hidden_states
132 |
133 | outputs = (hidden_states,)
134 | if output_attentions:
135 | outputs += (self_attn_weights,)
136 |
137 | return outputs
138 |
139 |
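
DecoderLayer.forward below follows the standard pre-norm residual pattern: normalize, apply the sub-block, add back the input. A tiny stand-in sketch of that control flow (the lambdas are placeholders, not the real RMSNorm/Attention/MLP modules):

import torch

norm1 = norm2 = lambda t: t             # stand-ins for the two RMSNorms
attn = lambda t: (0.5 * t, None)        # stand-in for self_attn, which returns (output, weights)
mlp = lambda t: 0.25 * t                # stand-in for the gated MLP

h = torch.randn(1, 4, 8)                # (batch, seq_len, hidden_size)
h = h + attn(norm1(h))[0]               # residual connection around attention
h = h + mlp(norm2(h))                   # residual connection around the MLP
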
140 | class Transformer(nn.Module):
141 | """
142 | Transformer decoder composed of multiple decoder layers.
143 |
144 | This module includes:
145 | - Token embedding.
146 | - A stack of decoder layers.
147 | - Rotary positional embeddings.
148 | - Final normalization.
149 | - Optional support for gradient checkpointing and caching.
150 |
151 | Attributes:
152 | padding_idx (int): Padding index for token embeddings.
153 | vocab_size (int): Size of the vocabulary.
154 | embed_tokens (Embedding): Token embedding layer.
155 | layers (nn.ModuleList): A list of decoder layers.
156 | norm (RMSNorm): Normalization applied after the last decoder layer.
157 | rotary_emb (RotaryEmbedding): Module for rotary position embeddings.
158 | gradient_checkpointing (bool): Enables gradient checkpointing if True.
159 | config (Config): Model configuration.
160 | """
161 |
162 | def __init__(self, config: Config):
163 | """
164 | Initialize the Transformer decoder.
165 |
166 | Loads weights for token embeddings, each decoder layer, and the final normalization.
167 |
168 | Args:
169 | config (Config): Configuration object with model hyperparameters.
170 | """
171 | super().__init__()
172 | self.padding_idx = config.pad_token_id
173 | self.vocab_size = config.vocab_size
174 |
175 | self.embed_tokens = Embedding(
176 | vocab_size=config.vocab_size,
177 | hidden_size=config.hidden_size,
178 | padding_idx=self.padding_idx,
179 | weight_path=f"{config.weight_dir}/model.embed_tokens.weight.bin",
180 | )
181 | self.layers = nn.ModuleList(
182 | [
183 | DecoderLayer(config, layer_idx)
184 | for layer_idx in range(config.num_hidden_layers)
185 | ]
186 | )
187 | norm_weight_path = f"{config.weight_dir}/model.norm.weight.bin"
188 | self.norm = RMSNorm(
189 | config.hidden_size, weight_path=norm_weight_path, eps=config.rms_norm_eps
190 | )
191 | self.rotary_emb = RotaryEmbedding(config=config)
192 | self.gradient_checkpointing = False
193 | self.config = config
194 |
195 | def get_input_embeddings(self):
196 | """
197 | Retrieve the input embeddings.
198 |
199 | Returns:
200 | Embedding: The embedding layer used for token lookup.
201 | """
202 | return self.embed_tokens
203 |
204 | def set_input_embeddings(self, value):
205 | """
206 | Set the input embeddings.
207 |
208 | Args:
209 | value (Embedding): A new embedding layer.
210 | """
211 | self.embed_tokens = value
212 |
213 | def forward(
214 | self,
215 | input_ids: torch.LongTensor = None,
216 | attention_mask: Optional[torch.Tensor] = None,
217 | position_ids: Optional[torch.LongTensor] = None,
218 | past_key_values: Optional[Cache] = None,
219 | inputs_embeds: Optional[torch.FloatTensor] = None,
220 | use_cache: Optional[bool] = None,
221 | output_attentions: Optional[bool] = None,
222 | output_hidden_states: Optional[bool] = None,
223 | return_dict: Optional[bool] = None,
224 | cache_position: Optional[torch.LongTensor] = None,
225 | ) -> Union[Tuple, BaseModelOutputWithPast]:
226 | """
227 | Perform a forward pass through the Transformer decoder.
228 |
229 | Delegates embedding lookup to the embedding layer, applies rotary position embeddings,
230 | and passes data sequentially through each decoder layer. Also manages caching and generates
231 | an updated causal mask based on past key values.
232 |
233 | Args:
234 | input_ids (torch.LongTensor, optional): Input token IDs.
235 | attention_mask (Optional[torch.Tensor], optional): Attention mask for the sequence.
236 | position_ids (Optional[torch.LongTensor], optional): Position IDs for the sequence.
237 | past_key_values (Optional[Cache], optional): Cached key/value pairs.
238 | inputs_embeds (Optional[torch.FloatTensor], optional): Pre-computed input embeddings.
239 | use_cache (Optional[bool], optional): If True, enables caching.
240 | output_attentions (Optional[bool], optional): If True, outputs attention weights.
241 | output_hidden_states (Optional[bool], optional): If True, outputs hidden states.
242 | return_dict (Optional[bool], optional): If True, returns a dict-like object instead of a tuple.
243 | cache_position (Optional[torch.LongTensor], optional): Cache position indices.
244 |
245 | Returns:
246 | Union[Tuple, BaseModelOutputWithPast]:
247 |                 - If return_dict is False, returns a tuple with the last hidden state and optionally additional outputs.
248 | - Otherwise, returns a `BaseModelOutputWithPast` with last hidden state, cached key values,
249 | hidden states, and attentions.
250 | """
251 | output_attentions = (
252 | output_attentions
253 | if output_attentions is not None
254 | else self.config.output_attentions
255 | )
256 | output_hidden_states = (
257 | output_hidden_states
258 | if output_hidden_states is not None
259 | else self.config.output_hidden_states
260 | )
261 | use_cache = use_cache if use_cache is not None else self.config.use_cache
262 | return_dict = (
263 | return_dict if return_dict is not None else self.config.use_return_dict
264 | )
265 |
266 | if inputs_embeds is None:
267 | inputs_embeds = self.embed_tokens(input_ids)
268 |
269 | if use_cache and past_key_values is None:
270 | past_key_values = DynamicCache()
271 |
272 | if cache_position is None:
273 | past_seen_tokens = (
274 | past_key_values.get_seq_length() if past_key_values is not None else 0
275 | )
276 | cache_position = torch.arange(
277 | past_seen_tokens,
278 | past_seen_tokens + inputs_embeds.shape[1],
279 | device=inputs_embeds.device,
280 | )
281 |
282 | if position_ids is None:
283 | position_ids = cache_position.unsqueeze(0)
284 |
285 | causal_mask = self._update_causal_mask(
286 | attention_mask, inputs_embeds, cache_position, past_key_values
287 | )
288 |
289 | hidden_states = inputs_embeds
290 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
291 | all_hidden_states = () if output_hidden_states else None
292 | all_self_attns = () if output_attentions else None
293 |
294 | for decoder_layer in self.layers[: self.config.num_hidden_layers]:
295 | if output_hidden_states:
296 | all_hidden_states += (hidden_states,)
297 |
298 | if self.gradient_checkpointing and self.training:
299 | layer_outputs = self._gradient_checkpointing_func(
300 | decoder_layer.__call__,
301 | hidden_states,
302 | causal_mask,
303 | position_ids,
304 | past_key_values,
305 | output_attentions,
306 | use_cache,
307 | cache_position,
308 | position_embeddings,
309 | )
310 | else:
311 | layer_outputs = decoder_layer(
312 | hidden_states,
313 | attention_mask=causal_mask,
314 | position_ids=position_ids,
315 | past_key_value=past_key_values,
316 | output_attentions=output_attentions,
317 | use_cache=use_cache,
318 | cache_position=cache_position,
319 | position_embeddings=position_embeddings,
320 | )
321 |
322 | hidden_states = layer_outputs[0]
323 |
324 | if output_attentions:
325 | all_self_attns += (layer_outputs[1],)
326 |
327 | hidden_states = self.norm(hidden_states)
328 |
329 | if output_hidden_states:
330 | all_hidden_states += (hidden_states,)
331 |
332 | output = BaseModelOutputWithPast(
333 | last_hidden_state=hidden_states,
334 | past_key_values=past_key_values if use_cache else None,
335 | hidden_states=all_hidden_states,
336 | attentions=all_self_attns,
337 | )
338 | return output if return_dict else output.to_tuple()
339 |
340 | def _update_causal_mask(
341 | self,
342 | attention_mask: torch.Tensor,
343 | input_tensor: torch.Tensor,
344 | cache_position: torch.Tensor,
345 | past_key_values: dict,
346 | ):
347 | """
348 | Generate a causal attention mask accounting for cached tokens.
349 |
350 | The mask is created based on the current input tensor, the cache positions, and the type of caching
351 | used (static or sliding window). This allows the model to correctly attend to previous tokens and
352 | manage attention when using caches.
353 |
354 | Args:
355 | attention_mask (torch.Tensor): Original attention mask.
356 | input_tensor (torch.Tensor): Input tensor with shape (batch_size, seq_length, hidden_size).
357 | cache_position (torch.Tensor): Tensor indicating positions for cached tokens.
358 | past_key_values (dict): Cached past key values, possibly an instance of StaticCache or SlidingWindowCache.
359 |
360 | Returns:
361 | torch.Tensor: A 4D causal attention mask.
362 | """
363 | past_seen_tokens = (
364 | past_key_values.get_seq_length() if past_key_values is not None else 0
365 | )
366 | using_static_cache = isinstance(past_key_values, StaticCache)
367 | using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
368 |
369 | dtype, device = input_tensor.dtype, input_tensor.device
370 | sequence_length = input_tensor.shape[1]
371 |
372 | if using_sliding_window_cache or using_static_cache:
373 | target_length = past_key_values.get_max_cache_shape()
374 | else:
375 | target_length = (
376 | attention_mask.shape[-1]
377 | if isinstance(attention_mask, torch.Tensor)
378 | else past_seen_tokens + sequence_length + 1
379 | )
380 |
381 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
382 | attention_mask,
383 | sequence_length=sequence_length,
384 | target_length=target_length,
385 | dtype=dtype,
386 | device=device,
387 | cache_position=cache_position,
388 | batch_size=input_tensor.shape[0],
389 | config=self.config,
390 | past_key_values=past_key_values,
391 | )
392 |
393 | return causal_mask
394 |
395 | @staticmethod
396 | def _prepare_4d_causal_attention_mask_with_cache_position(
397 | attention_mask: torch.Tensor,
398 | sequence_length: int,
399 | target_length: int,
400 | dtype: torch.dtype,
401 | device: torch.device,
402 | cache_position: torch.Tensor,
403 | batch_size: int,
404 | config: Config,
405 | past_key_values: dict,
406 | ):
407 | """
408 | Create a 4D causal attention mask with cache positions.
409 |
410 | This function prepares a mask of shape (batch_size, 1, query_length, key_value_length) from a
411 | 2D mask or creates a new one if needed. For sliding window configurations, additional masking is applied.
412 |
413 | Args:
414 | attention_mask (torch.Tensor): Input attention mask.
415 | sequence_length (int): Length of the current sequence.
416 | target_length (int): The target length (e.g., including cached tokens).
417 | dtype (torch.dtype): Data type for the mask.
418 | device (torch.device): Device for the mask tensor.
419 | cache_position (torch.Tensor): Tensor indicating positions for caching.
420 | batch_size (int): Batch size.
421 | config (Config): Model configuration (may include sliding window settings).
422 | past_key_values (dict): Cached past key values.
423 |
424 | Returns:
425 | torch.Tensor: A 4D causal attention mask tensor of shape (batch_size, 1, query_length, key_value_length).
426 | """
427 | if attention_mask is not None and attention_mask.dim() == 4:
428 | causal_mask = attention_mask
429 |
430 | else:
431 | min_dtype = torch.finfo(dtype).min
432 | causal_mask = torch.full(
433 | (sequence_length, target_length),
434 | fill_value=min_dtype,
435 | dtype=dtype,
436 | device=device,
437 | )
438 | diagonal_attend_mask = torch.arange(
439 | target_length, device=device
440 | ) > cache_position.reshape(-1, 1)
441 |
442 | if config.sliding_window is not None:
443 | if (
444 | not isinstance(past_key_values, SlidingWindowCache)
445 | or sequence_length > target_length
446 | ):
447 | sliding_attend_mask = torch.arange(
448 | target_length, device=device
449 | ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
450 | diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
451 | causal_mask *= diagonal_attend_mask
452 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
453 |
454 |             if attention_mask is not None:
455 |                 causal_mask = causal_mask.clone()
456 |                 if attention_mask.shape[-1] > target_length:
457 |                     attention_mask = attention_mask[:, :target_length]
458 |                 mask_length = attention_mask.shape[-1]
459 |                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
460 |                     :, None, None, :
461 |                 ].to(causal_mask.device)
462 |                 padding_mask = padding_mask == 0
463 |                 causal_mask[:, :, :, :mask_length] = causal_mask[
464 |                     :, :, :, :mask_length
465 |                 ].masked_fill(padding_mask, min_dtype)
466 |
467 | return causal_mask
468 |
469 |
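
To make _prepare_4d_causal_attention_mask_with_cache_position above concrete, here is a minimal sketch of the mask it builds for a tiny uncached case with no sliding window and no padding mask (sizes are illustrative):

import torch

seq_len = target_len = 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(seq_len)

causal_mask = torch.full((seq_len, target_len), fill_value=min_dtype, dtype=dtype)
future = torch.arange(target_len) > cache_position.reshape(-1, 1)   # True strictly above the diagonal
causal_mask = causal_mask * future                                  # 0 where attending is allowed, min_dtype on future tokens
causal_mask = causal_mask[None, None, :, :]                         # (1, 1, 4, 4); added to attention scores before softmax
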
470 | class SuperLazyLanguageModel(nn.Module):
471 | """
472 | A super lazy language model that encapsulates a Transformer decoder with optional LoRA parameters.
473 |
474 | This model wraps the Transformer decoder, defines the head for vocabulary logits,
475 | and computes the loss when labels are provided. It supports caching for efficient generation.
476 |
477 | Attributes:
478 | config (Config): Model configuration containing hyperparameters.
479 | model (Transformer): Underlying Transformer-based decoder.
480 | loss_function (nn.CrossEntropyLoss): Loss function for training.
481 | vocab_size (int): Size of the model vocabulary.
482 | lm_head (Linear): Linear projection layer from hidden states to vocabulary logits.
483 | """
484 |
485 | def __init__(self, name, lora_alpha=16, lora_r=4, lora_dropout=0.1):
486 | """
487 | Initialize the SuperLazyLanguageModel.
488 |
489 | Loads configuration and initializes the transformer decoder along with the language modeling head.
490 | Optionally, applies LoRA modifications if enabled in the configuration.
491 |
492 | Args:
493 | name (str): Identifier or name of the model.
494 | lora_alpha (int, optional): LoRA alpha hyperparameter. Defaults to 16.
495 | lora_r (int, optional): LoRA rank. Defaults to 4.
496 | lora_dropout (float, optional): LoRA dropout rate. Defaults to 0.1.
497 | """
498 | super().__init__()
499 |
500 | self.config = Config(
501 | model_name=name,
502 | lora_alpha=lora_alpha,
503 | lora_r=lora_r,
504 | lora_dropout=lora_dropout,
505 | )
506 |
507 | self.model = Transformer(self.config)
508 | self.loss_function = nn.CrossEntropyLoss()
509 | self.vocab_size = self.config.vocab_size
510 |
511 | if self.config.tie_word_embeddings:
512 | lm_head_weight_path = (
513 | f"{self.config.weight_dir}/model.embed_tokens.weight.bin"
514 | )
515 | else:
516 | lm_head_weight_path = f"{self.config.weight_dir}/lm_head.weight.bin"
517 |
518 | self.lm_head = Linear(
519 | self.config.hidden_size,
520 | self.config.vocab_size,
521 | weight_path=lm_head_weight_path,
522 | bias_path=None,
523 | )
524 |
525 | def get_input_embeddings(self):
526 | """
527 | Retrieve the input embeddings from the model.
528 |
529 | Returns:
530 | Embedding: Input embedding layer.
531 | """
532 | return self.model.embed_tokens
533 |
534 | def set_input_embeddings(self, value):
535 | """
536 | Set the input embeddings for the model.
537 |
538 | Args:
539 | value (Embedding): New input embedding layer.
540 | """
541 | self.model.embed_tokens = value
542 |
543 | def get_output_embeddings(self):
544 | """
545 | Retrieve the output embeddings (language modeling head).
546 |
547 | Returns:
548 | Linear: The linear layer projecting to vocabulary logits.
549 | """
550 | return self.lm_head
551 |
552 | def set_output_embeddings(self, new_embeddings):
553 | """
554 | Set the output embeddings (language modeling head) for the model.
555 |
556 | Args:
557 | new_embeddings (Linear): New output embedding layer.
558 | """
559 | self.lm_head = new_embeddings
560 |
561 | def set_decoder(self, decoder):
562 | """
563 | Replace the current Transformer decoder with a new one.
564 |
565 | Args:
566 | decoder (Transformer): A new transformer decoder module.
567 | """
568 | self.model = decoder
569 |
570 | def get_decoder(self):
571 | """
572 | Retrieve the current Transformer decoder.
573 |
574 | Returns:
575 | Transformer: The underlying transformer decoder.
576 | """
577 | return self.model
578 |
579 | def forward(
580 | self,
581 | input_ids: torch.LongTensor = None,
582 | attention_mask: Optional[torch.Tensor] = None,
583 | position_ids: Optional[torch.LongTensor] = None,
584 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
585 | inputs_embeds: Optional[torch.FloatTensor] = None,
586 | labels: Optional[torch.LongTensor] = None,
587 | use_cache: Optional[bool] = None,
588 | output_attentions: Optional[bool] = None,
589 | output_hidden_states: Optional[bool] = None,
590 | return_dict: Optional[bool] = None,
591 | cache_position: Optional[torch.LongTensor] = None,
592 | logits_to_keep: Union[int, torch.Tensor] = 0,
593 | **kwargs,
594 | ) -> Union[Tuple, CausalLMOutputWithPast]:
595 | """
596 | Perform a forward pass through the SuperLazyLanguageModel.
597 |
598 | The method computes embeddings (or uses precomputed ones), obtains outputs from the Transformer decoder,
599 | projects hidden states to vocabulary logits using the lm_head, and computes the cross-entropy loss if labels are provided.
600 | It supports caching for efficient autoregressive generation and can return outputs as a tuple or as a dict-like object.
601 |
602 | Args:
603 | input_ids (torch.LongTensor, optional): Token IDs as input.
604 | attention_mask (Optional[torch.Tensor], optional): Mask to avoid attending to padding tokens.
605 | position_ids (Optional[torch.LongTensor], optional): Position IDs corresponding to the tokens.
606 | past_key_values (Optional[Union[Cache, List[torch.FloatTensor]]], optional): Cached key/value tensors.
607 | inputs_embeds (Optional[torch.FloatTensor], optional): Pre-computed input embeddings.
608 | labels (Optional[torch.LongTensor], optional): Ground truth labels for computing the loss.
609 | use_cache (Optional[bool], optional): Whether to use past key values for caching.
610 | output_attentions (Optional[bool], optional): Whether to return attention weights.
611 | output_hidden_states (Optional[bool], optional): Whether to return hidden states.
612 | return_dict (Optional[bool], optional): Whether to return a dict-like output.
613 | cache_position (Optional[torch.LongTensor], optional): Cache positions for tokens.
614 | logits_to_keep (Union[int, torch.Tensor], optional): Slicing index or tensor for logits computation.
615 | **kwargs: Additional keyword arguments.
616 |
617 | Returns:
618 | Union[Tuple, CausalLMOutputWithPast]:
619 | - If return_dict is False, returns a tuple containing logits and optionally other outputs.
620 | - Otherwise, returns a `CausalLMOutputWithPast` with loss (if computed), logits, past key values,
621 | hidden states, and attention weights.
622 | """
623 | output_attentions = (
624 | output_attentions
625 | if output_attentions is not None
626 | else self.config.output_attentions
627 | )
628 | output_hidden_states = (
629 | output_hidden_states
630 | if output_hidden_states is not None
631 | else self.config.output_hidden_states
632 | )
633 | return_dict = (
634 | return_dict if return_dict is not None else self.config.use_return_dict
635 | )
636 |
637 | outputs = self.model(
638 | input_ids=input_ids,
639 | attention_mask=attention_mask,
640 | position_ids=position_ids,
641 | past_key_values=past_key_values,
642 | inputs_embeds=inputs_embeds,
643 | use_cache=use_cache,
644 | output_attentions=output_attentions,
645 | output_hidden_states=output_hidden_states,
646 | return_dict=return_dict,
647 | cache_position=cache_position,
648 | **kwargs,
649 | )
650 |
651 | hidden_states = outputs[0]
652 | slice_indices = (
653 | slice(-logits_to_keep, None)
654 | if isinstance(logits_to_keep, int)
655 | else logits_to_keep
656 | )
657 | logits = self.lm_head(hidden_states[:, slice_indices, :])
658 |
659 | loss = None
660 | if labels is not None:
661 | shifted_logits = logits[:, :-1, :]
662 | shifted_labels = labels[:, 1:]
663 | loss = self.loss_function(
664 | shifted_logits.reshape(-1, shifted_logits.size(-1)),
665 | shifted_labels.reshape(-1),
666 | )
667 |
668 | if not return_dict:
669 | output = (logits,) + outputs[1:]
670 | return (loss,) + output if loss is not None else output
671 |
672 | return CausalLMOutputWithPast(
673 | loss=loss,
674 | logits=logits,
675 | past_key_values=outputs.past_key_values,
676 | hidden_states=outputs.hidden_states,
677 | attentions=outputs.attentions,
678 | )
679 |
--------------------------------------------------------------------------------
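
SuperLazyLanguageModel.forward above computes the next-token loss by shifting logits and labels by one position before the cross-entropy. A self-contained sketch of that step with toy tensors (batch size, sequence length, and vocabulary size are made up):

import torch
from torch import nn

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

shifted_logits = logits[:, :-1, :]      # predictions for positions 0 .. seq_len - 2
shifted_labels = labels[:, 1:]          # targets are the next tokens
loss = nn.CrossEntropyLoss()(
    shifted_logits.reshape(-1, vocab),
    shifted_labels.reshape(-1),
)
print(loss.item())
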
/sllm/nn/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements various neural network layers and functions that are used for
3 | lazy weight loading and efficient computation in a transformer-based model. The layers include:
4 |
5 | - Embedding: Token embedding that lazily loads weights from disk.
6 | - RotaryEmbedding: Implements rotary position embeddings with dynamic frequency updates.
7 | - RMSNorm: Root-mean-square layer normalization with lazy weight loading.
8 | - Linear: A linear layer that lazily loads large weight matrices.
9 | - MLP: Feed-forward neural network layer using bundled matrix multiplications.
10 | - LoraLinear: Linear layer with LoRA adaptation for efficient fine-tuning.
11 | - LoraQKVLinear: Specialized LoRA linear layer for query/key/value projections.
12 | - Attention: Multi-head attention layer with support for rotary embeddings and caching.
13 |
14 | Each layer leverages lazy weight loading via the `load_tensor_from_storage` utility to
15 | optimize memory usage.
16 | """
17 |
18 | import gc
19 | from typing import Any, Optional, Tuple
20 |
21 | import torch
22 | import torch.nn.functional as F
23 | from torch import nn
24 |
25 | from sllm.common import DTYPE
26 | from sllm.config import Config
27 | from sllm.nn.autodiff import (BundledMatmulFunction,
28 | LoraFunction,
29 | LoraQKVLinearFunction,
30 | MatmulFunction)
31 | from sllm.utils import load_tensor_from_storage
32 |
33 |
34 | class Embedding(torch.nn.Module):
35 | """
36 | Token embedding layer with lazy weight loading.
37 |
38 | This module loads the weight matrix from storage on each forward pass,
39 | avoiding the need to hold the full embedding matrix in RAM.
40 | """
41 |
42 | def __init__(self, weight_path, vocab_size, hidden_size, padding_idx=None):
43 | """
44 | Initialize the Embedding layer.
45 |
46 | Args:
47 | weight_path (str): Path to the weight file.
48 | vocab_size (int): Number of tokens in the vocabulary.
49 | hidden_size (int): Dimensionality of the embedding vectors.
50 | padding_idx (optional): Padding index for tokens; defaults to None.
51 | """
52 | super().__init__()
53 | self.weight_path = weight_path
54 | self.vocab_size = vocab_size
55 | self.hidden_size = hidden_size
56 | self.padding_idx = padding_idx
57 |
58 | def forward(self, input_ids):
59 | """
60 | Perform a forward pass to retrieve embeddings for the provided token IDs.
61 |
62 | Args:
63 | input_ids (Tensor): Tensor containing token IDs with shape (...).
64 |
65 | Returns:
66 | Tensor: Embedding vectors corresponding to the input IDs.
67 | """
68 | with torch.no_grad():
69 | weight = load_tensor_from_storage(
70 | weight_path=self.weight_path,
71 | shape=(self.vocab_size, self.hidden_size),
72 | to_ram=False,
73 | )
74 | return weight[input_ids]
75 |
76 |
77 | class RotaryEmbedding(nn.Module):
78 | """
79 | Rotary position embedding module with dynamic frequency updates.
80 |
81 | Computes rotary embeddings and applies attention scaling. It updates the frequency
82 | parameters if the input sequence length exceeds cached values.
83 | """
84 |
85 | def __init__(self, config: Config):
86 | """
87 | Initialize the RotaryEmbedding module.
88 |
89 | Args:
90 | config (Config): Model configuration containing parameters such as
91 | max_position_embeddings and rope_theta.
92 | """
93 | super().__init__()
94 | self.max_seq_len_cached = config.max_position_embeddings
95 | self.original_max_seq_len = config.max_position_embeddings
96 | self.config = config
97 | inv_freq, self.attention_scaling = self.compute_rope_parameters(self.config)
98 | self.register_buffer("inv_freq", inv_freq, persistent=False)
99 | self.original_inv_freq = self.inv_freq
100 |
101 | def _dynamic_frequency_update(self, position_ids, device):
102 | """
103 | Dynamically update the frequency parameters if needed.
104 |
105 | The update occurs if:
106 | 1. The sequence length exceeds the cached maximum.
107 | 2. The current sequence length is smaller than the original max sequence length
108 | while the cached maximum is larger.
109 |
110 | Args:
111 | position_ids (Tensor): Tensor of position indices.
112 | device (torch.device): The device on which to perform the computation.
113 | """
114 | seq_len = torch.max(position_ids) + 1
115 | if seq_len > self.max_seq_len_cached:
116 |             inv_freq, self.attention_scaling = self.compute_rope_parameters(
117 |                 self.config
118 |             )
119 | self.register_buffer("inv_freq", inv_freq, persistent=False)
120 | self.max_seq_len_cached = seq_len
121 |
122 | if (
123 | seq_len < self.original_max_seq_len
124 | and self.max_seq_len_cached > self.original_max_seq_len
125 | ):
126 | self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
127 | self.max_seq_len_cached = self.original_max_seq_len
128 |
129 | @torch.no_grad()
130 | def forward(self, x, position_ids):
131 | """
132 | Compute rotary embeddings for the input tensor.
133 |
134 | Args:
135 | x (Tensor): Input tensor of shape (batch_size, seq_length, ...).
136 | position_ids (Tensor): Position indices for each token.
137 |
138 | Returns:
139 | Tuple[Tensor, Tensor]: A tuple (cos, sin) where each tensor contains
140 | the cosine and sine components of the rotary embeddings.
141 | """
142 | inv_freq_expanded = (
143 | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
144 | )
145 | position_ids_expanded = position_ids[:, None, :].float()
146 | device_type = x.device.type
147 | device_type = (
148 | device_type
149 | if isinstance(device_type, str) and device_type != "mps"
150 | else "cpu"
151 | )
152 |
153 | with torch.autocast(device_type=device_type, enabled=False):
154 | freqs = (
155 | inv_freq_expanded.float() @ position_ids_expanded.float()
156 | ).transpose(1, 2)
157 | emb = torch.cat((freqs, freqs), dim=-1)
158 | cos = emb.cos()
159 | sin = emb.sin()
160 |
161 | # Apply attention scaling for advanced RoPE types (e.g., yarn)
162 | cos = cos * self.attention_scaling
163 | sin = sin * self.attention_scaling
164 |
165 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
166 |
167 | def compute_rope_parameters(
168 | self,
169 | config: Optional[Config] = None,
170 | ) -> Tuple["torch.Tensor", float]:
171 | """
172 | Compute the rotary embedding parameters.
173 |
174 | This method calculates the inverse frequency tensor and returns the corresponding attention scaling
175 | factor.
176 |
177 | Args:
178 | config (Optional[Config], optional): The model configuration. If not provided,
179 | defaults to the instance's configuration.
180 |
181 | Returns:
182 | Tuple[Tensor, float]: A tuple (inv_freq, attention_factor) where:
183 | - inv_freq is the inverse frequency tensor.
184 | - attention_factor is the scaling factor for attention.
185 | """
186 | partial_rotary_factor = (
187 | config.partial_rotary_factor
188 | if hasattr(config, "partial_rotary_factor")
189 | else 1.0
190 | )
191 | head_dim = getattr(
192 | config, "head_dim", config.hidden_size // config.num_attention_heads
193 | )
194 | dim = int(head_dim * partial_rotary_factor)
195 | attention_factor = 1.0
196 | inv_freq = 1.0 / (
197 | config.rope_theta
198 | ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
199 | )
200 | return inv_freq, attention_factor
201 |
202 |
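
For reference, what RotaryEmbedding computes reduces to the following with plain tensors (head_dim, rope_theta, and the sequence length are illustrative; this mirrors compute_rope_parameters and forward above with attention_scaling = 1.0):

import torch

head_dim, rope_theta, seq_len = 8, 10000.0, 5
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))   # (head_dim / 2,)
position_ids = torch.arange(seq_len)[None, :]                                        # (1, seq_len)

freqs = (inv_freq[None, :, None] @ position_ids[:, None, :].float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)   # (1, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()           # what Attention.apply_rotary_pos_emb consumes
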
203 | class RMSNorm(torch.nn.Module):
204 | """
205 | Root-mean-square layer normalization with lazy-loaded weights.
206 |
207 | Applies RMS normalization to the input tensor using pre-loaded weights.
208 | """
209 |
210 | def __init__(self, hidden_size, weight_path, eps=1e-6):
211 | """
212 | Initialize the RMSNorm layer.
213 |
214 | Args:
215 | hidden_size (int or tuple): The shape of the weight tensor.
216 | weight_path (str): Path to the weight file.
217 | eps (float, optional): A small value to avoid division by zero. Defaults to 1e-6.
218 | """
219 | super().__init__()
220 | weight = load_tensor_from_storage(
221 | weight_path=weight_path, shape=hidden_size, to_ram=True
222 | )
223 | self.register_buffer("weight", weight)
224 | self.variance_epsilon = eps
225 |
226 | def forward(self, hidden_states):
227 | """
228 | Apply RMS normalization to the input tensor.
229 |
230 | Args:
231 | hidden_states (Tensor): Input tensor with shape (..., hidden_size).
232 |
233 | Returns:
234 | Tensor: The normalized tensor scaled by the registered weight.
235 | """
236 | input_dtype = hidden_states.dtype
237 | hidden_states = hidden_states.to(torch.float32)
238 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
239 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
240 | return self.weight * hidden_states.to(input_dtype)
241 |
242 | def extra_repr(self):
243 | """
244 | Return additional string representation of the RMSNorm module.
245 |
246 | Returns:
247 | str: A string representation showing the weight shape and epsilon value.
248 | """
249 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
250 |
251 |
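
As a quick worked example of the normalization in RMSNorm.forward above (taking the weight as all ones):

import torch

x = torch.tensor([[3.0, 4.0]])
eps = 1e-6
variance = x.pow(2).mean(-1, keepdim=True)    # mean(9, 16) = 12.5
print(x * torch.rsqrt(variance + eps))        # ≈ [[0.8485, 1.1314]]
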
252 | class Linear(nn.Module):
253 | """
254 | Linear layer that lazily loads a large weight matrix from disk.
255 |
256 | The weight matrix is loaded on every forward pass with gradients disabled.
257 | The bias, if provided, is loaded entirely into RAM.
258 | """
259 |
260 | def __init__(self, in_features, out_features, weight_path, bias_path=None):
261 | """
262 | Initialize the Linear layer.
263 |
264 | Args:
265 | in_features (int): Size of each input sample.
266 | out_features (int): Size of each output sample.
267 | weight_path (str): Path to the weight file.
268 | bias_path (str, optional): Path to the bias file. Defaults to None.
269 | """
270 | super().__init__()
271 | self.in_features = in_features
272 | self.out_features = out_features
273 | self.weight_path = weight_path
274 | self.bias_path = bias_path
275 |
276 | if self.bias_path is not None:
277 | with torch.no_grad():
278 | bias = load_tensor_from_storage(
279 | weight_path=self.bias_path, shape=(self.out_features,), to_ram=True
280 | )
281 | self.register_buffer("bias", bias)
282 | else:
283 | self.bias = None
284 |
285 | def forward(self, x):
286 | """
287 | Perform the forward pass of the linear layer.
288 |
289 | Loads the weight matrix from disk and computes the matrix multiplication.
290 | If bias is present, adds it to the result.
291 |
292 | Args:
293 | x (Tensor): Input tensor of shape (..., in_features).
294 |
295 | Returns:
296 | Tensor: Output tensor of shape (..., out_features).
297 | """
298 | weight = load_tensor_from_storage(
299 | weight_path=self.weight_path,
300 | shape=(self.out_features, self.in_features),
301 | to_ram=False,
302 | ).t()
303 |
304 | Wx = MatmulFunction.apply(x, weight, 1.0)
305 | if self.bias is not None:
306 | Wx += self.bias
307 | return Wx
308 |
309 |
310 | class MLP(nn.Module):
311 | """
312 | Feed-forward network (MLP) module using bundled matrix multiplications.
313 |
314 | This layer implements a gated MLP where the input is projected with two parallel
315 | linear layers, one of which is passed through a non-linearity and then elementwise
316 | multiplied, followed by a final projection.
317 | """
318 |
319 | def __init__(self, config, layer_idx):
320 | """
321 | Initialize the MLP module.
322 |
323 | Args:
324 | config (Config): Model configuration including hidden and intermediate sizes.
325 | layer_idx (int): Index of the current layer for loading weight files.
326 | """
327 | super().__init__()
328 | self.config = config
329 | self.hidden_size = config.hidden_size
330 | self.intermediate_size = config.intermediate_size
331 | self.act_fn = nn.SiLU()
332 | self.gate_proj_path = (
333 | f"{config.weight_dir}/model.layers.{layer_idx}.mlp.gate_proj.weight.bin"
334 | )
335 | self.up_proj_path = (
336 | f"{config.weight_dir}/model.layers.{layer_idx}.mlp.up_proj.weight.bin"
337 | )
338 | self.down_proj_path = (
339 | f"{config.weight_dir}/model.layers.{layer_idx}.mlp.down_proj.weight.bin"
340 | )
341 |
342 | def forward(self, x):
343 | """
344 | Compute the forward pass of the MLP.
345 |
346 | Projects the input using two parallel linear transformations (gate and up projections),
347 | applies a non-linearity and elementwise multiplication, and finally projects back down.
348 |
349 | Args:
350 | x (Tensor): Input tensor with shape (..., hidden_size).
351 |
352 | Returns:
353 | Tensor: Output tensor with shape (..., hidden_size).
354 | """
355 | gate_proj = load_tensor_from_storage(
356 | weight_path=self.gate_proj_path,
357 | shape=(self.intermediate_size, self.hidden_size),
358 | to_ram=False,
359 | )
360 | up_proj = load_tensor_from_storage(
361 | weight_path=self.up_proj_path,
362 | shape=(self.intermediate_size, self.hidden_size),
363 | to_ram=False,
364 | )
365 | down_proj = load_tensor_from_storage(
366 | weight_path=self.down_proj_path,
367 | shape=(self.hidden_size, self.intermediate_size),
368 | to_ram=False,
369 | )
370 |
371 | bundles = [
372 | (x, gate_proj.t(), 1.0),
373 | (x, up_proj.t(), 1.0),
374 | ]
375 | gate_proj_out, up_proj_out = BundledMatmulFunction.apply(bundles)
376 | activated_gate_proj = self.act_fn(gate_proj_out) * up_proj_out
377 | return MatmulFunction.apply(activated_gate_proj, down_proj.t(), 1.0)
378 |
379 |
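
With ordinary in-memory weights, MLP.forward above reduces to the usual SwiGLU-style gated MLP. A sketch with made-up sizes, using F.linear in place of the disk-backed bundled matmuls:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, intermediate_size = 8, 16
x = torch.randn(2, 3, hidden_size)
gate_w = torch.randn(intermediate_size, hidden_size)   # stored as (intermediate_size, hidden_size)
up_w = torch.randn(intermediate_size, hidden_size)
down_w = torch.randn(hidden_size, intermediate_size)

out = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)
print(out.shape)                                       # torch.Size([2, 3, 8])
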
380 | class LoraLinear(nn.Module):
381 | """
382 | LoRA adapted linear layer for efficient fine-tuning.
383 |
384 | This module applies LoRA by incorporating trainable low-rank matrices (lora_A, lora_B)
385 | to adjust the pre-trained weight matrix.
386 | """
387 |
388 | def __init__(
389 | self,
390 | in_features,
391 | out_features,
392 | r,
393 | alpha,
394 | weight_path,
395 | bias_path=None,
396 | lora_dropout=0.0,
397 | ):
398 | """
399 | Initialize the LoraLinear layer.
400 |
401 | Args:
402 | in_features (int): Number of input features.
403 | out_features (int): Number of output features.
404 | r (int): LoRA rank.
405 | alpha (int): LoRA scaling factor.
406 | weight_path (str): Path to the pre-trained weight file.
407 | bias_path (str, optional): Path to the bias file. Defaults to None.
408 | lora_dropout (float, optional): Dropout probability for LoRA. Defaults to 0.0.
409 | """
410 | super().__init__()
411 | self.in_features = in_features
412 | self.out_features = out_features
413 | self.r = r
414 | self.alpha = alpha
415 | self.scaling = self.alpha / self.r
416 | self.weight_path = weight_path
417 | self.lora_A = nn.Parameter(torch.randn(r, in_features))
418 | self.lora_B = nn.Parameter(torch.randn(out_features, r))
419 | self.lora_dropout = nn.Dropout(lora_dropout)
420 |         if bias_path is not None:
421 |             bias = load_tensor_from_storage(
422 |                 weight_path=bias_path, shape=(self.out_features,), to_ram=True
423 |             )
424 |             self.register_buffer("bias", bias)
425 |         else:
426 |             self.bias = None
427 |
428 | def forward(self, x):
429 | """
430 | Forward pass for the LoraLinear layer.
431 |
432 | Applies dropout to the input, then uses the LoRA function to compute the modified linear transformation.
433 |
434 | Args:
435 | x (Tensor): Input tensor of shape (..., in_features).
436 |
437 | Returns:
438 | Tensor: Output tensor of shape (..., out_features).
439 | """
440 | x_dropped = self.lora_dropout(x)
441 | return LoraFunction.apply(
442 | x_dropped,
443 | self.lora_A,
444 | self.lora_B,
445 | self.weight_path,
446 | self.scaling,
447 | self.bias,
448 | )
449 |
450 |
451 | class LoraQKVLinear(nn.Module):
452 | """
453 | LoRA adapted linear layer for query, key, and value projections in attention.
454 |
455 | This layer applies LoRA to the query, key, and value projections and lazily loads
456 | the corresponding weight matrices.
457 | """
458 |
459 | def __init__(
460 | self,
461 | config,
462 | head_dim,
463 | q_weight_path,
464 | k_weight_path,
465 | v_weight_path,
466 | q_bias_path=None,
467 | k_bias_path=None,
468 | v_bias_path=None,
469 | ):
470 | """
471 | Initialize the LoraQKVLinear layer.
472 |
473 | Args:
474 | config (Config): Model configuration with LoRA parameters.
475 | head_dim (int): Dimension of each attention head.
476 | q_weight_path (str): Path to the query projection weight file.
477 | k_weight_path (str): Path to the key projection weight file.
478 | v_weight_path (str): Path to the value projection weight file.
479 | q_bias_path (str, optional): Path to the query projection bias file.
480 | k_bias_path (str, optional): Path to the key projection bias file.
481 | v_bias_path (str, optional): Path to the value projection bias file.
482 | """
483 | super().__init__()
484 | self.scaling = config.lora_alpha / config.lora_r
485 | self.lora_dropout = nn.Dropout(config.lora_dropout)
486 |
487 | self.q_proj_lora_A = nn.Parameter(
488 | torch.randn(config.hidden_size, config.lora_r, dtype=DTYPE)
489 | )
490 | nn.init.kaiming_uniform_(self.q_proj_lora_A, nonlinearity="linear")
491 | self.q_proj_lora_B = nn.Parameter(
492 | torch.zeros(
493 | config.lora_r, config.num_attention_heads * head_dim, dtype=DTYPE
494 | )
495 | )
496 | self.k_proj_lora_A = nn.Parameter(
497 | torch.randn(config.hidden_size, config.lora_r, dtype=DTYPE)
498 | )
499 | nn.init.kaiming_uniform_(self.k_proj_lora_A, nonlinearity="linear")
500 | self.k_proj_lora_B = nn.Parameter(
501 | torch.zeros(
502 | config.lora_r, config.num_key_value_heads * head_dim, dtype=DTYPE
503 | )
504 | )
505 | self.v_proj_lora_A = nn.Parameter(
506 | torch.randn(config.hidden_size, config.lora_r, dtype=DTYPE)
507 | )
508 | nn.init.kaiming_uniform_(self.v_proj_lora_A, nonlinearity="linear")
509 | self.v_proj_lora_B = nn.Parameter(
510 | torch.zeros(
511 | config.lora_r, config.num_key_value_heads * head_dim, dtype=DTYPE
512 | )
513 | )
514 |
515 | self.q_weight_path = q_weight_path
516 | self.k_weight_path = k_weight_path
517 | self.v_weight_path = v_weight_path
518 |
519 | self.q_proj_bias = None
520 | self.k_proj_bias = None
521 | self.v_proj_bias = None
522 |
523 | q_dim = config.num_attention_heads * head_dim
524 | kv_dim = config.num_key_value_heads * head_dim
525 |
526 | if q_bias_path is not None:
527 | self.q_proj_bias = load_tensor_from_storage(
528 | weight_path=q_bias_path, shape=(q_dim,), to_ram=True
529 | )
530 |
531 | if k_bias_path is not None:
532 | self.k_proj_bias = load_tensor_from_storage(
533 | weight_path=k_bias_path, shape=(kv_dim,), to_ram=True
534 | )
535 |
536 | if v_bias_path is not None:
537 | self.v_proj_bias = load_tensor_from_storage(
538 | weight_path=v_bias_path, shape=(kv_dim,), to_ram=True
539 | )
540 |
541 | def forward(self, x):
542 | """
543 | Forward pass for the LoRA QKV linear layer.
544 |
545 | Applies dropout to the input and computes the LoRA adapted query, key, and value projections.
546 |
547 | Args:
548 | x (Tensor): Input tensor of shape (..., hidden_size).
549 |
550 | Returns:
551 | Tuple[Tensor, Tensor, Tensor]: A tuple containing the projected query, key, and value tensors.
552 | """
553 | x_dropped = self.lora_dropout(x)
554 | return LoraQKVLinearFunction.apply(
555 | x_dropped,
556 | self.q_weight_path,
557 | self.k_weight_path,
558 | self.v_weight_path,
559 | self.q_proj_bias,
560 | self.k_proj_bias,
561 | self.v_proj_bias,
562 | self.q_proj_lora_A,
563 | self.q_proj_lora_B,
564 | self.k_proj_lora_A,
565 | self.k_proj_lora_B,
566 | self.v_proj_lora_A,
567 | self.v_proj_lora_B,
568 | self.scaling,
569 | )
570 |
571 |
572 | class Attention(nn.Module):
573 | """
574 | Multi-head attention layer with support for rotary embeddings and caching.
575 |
576 | This layer computes queries, keys, and values using a LoRA adapted linear layer,
577 | applies rotary position embeddings, and then performs scaled dot-product attention.
578 | It also supports updating caches via past key/value mechanisms.
579 | """
580 |
581 | def __init__(self, config: Config, layer_idx: int):
582 | """
583 | Initialize the Attention layer.
584 |
585 | Args:
586 | config (Config): Model configuration.
587 | layer_idx (int): Index of the current layer for loading layer-specific weights.
588 | """
589 | super().__init__()
590 | self.config = config
591 | self.layer_idx = layer_idx
592 | self.head_dim = getattr(
593 | config, "head_dim", config.hidden_size // config.num_attention_heads
594 | )
595 | self.num_key_value_groups = (
596 | config.num_attention_heads // config.num_key_value_heads
597 | )
598 | self.scaling = self.head_dim**-0.5
599 | self.attention_dropout = config.attention_dropout
600 | self.is_causal = True
601 |
602 | self.q_proj_weight_path = (
603 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.q_proj.weight.bin"
604 | )
605 | self.k_proj_weight_path = (
606 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.k_proj.weight.bin"
607 | )
608 | self.v_proj_weight_path = (
609 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.v_proj.weight.bin"
610 | )
611 | self.q_proj_bias_path = (
612 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.q_proj.bias.bin"
613 | )
614 | self.k_proj_bias_path = (
615 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.k_proj.bias.bin"
616 | )
617 | self.v_proj_bias_path = (
618 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.v_proj.bias.bin"
619 | )
620 | self.o_proj_weight_path = (
621 | f"{config.weight_dir}/model.layers.{layer_idx}.self_attn.o_proj.weight.bin"
622 | )
623 |
624 | self.lora = LoraQKVLinear(
625 | config,
626 | self.head_dim,
627 | self.q_proj_weight_path,
628 | self.k_proj_weight_path,
629 | self.v_proj_weight_path,
630 | self.q_proj_bias_path,
631 | self.k_proj_bias_path,
632 | self.v_proj_bias_path,
633 | )
634 |
635 | def forward(
636 | self,
637 | hidden_states: torch.Tensor,
638 | position_embeddings: Tuple[torch.Tensor, torch.Tensor],
639 | attention_mask: Optional[torch.Tensor],
640 | past_key_value: Optional[Any] = None,
641 | cache_position: Optional[torch.LongTensor] = None,
642 | **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
644 | """
645 | Forward pass for the Attention layer.
646 |
647 | Projects inputs into queries, keys, and values, applies rotary embeddings, and computes
648 | the attention output along with optional attention weights.
649 |
650 | Args:
651 | hidden_states (Tensor): Input tensor with shape (batch_size, seq_length, hidden_size).
652 | position_embeddings (Tuple[Tensor, Tensor]): Tuple containing cosine and sine embeddings.
653 | attention_mask (Optional[Tensor]): Attention mask to apply.
654 | past_key_value (Optional[Any]): Cached past key/value states.
655 | cache_position (Optional[LongTensor]): Cache positions for tokens.
656 | **kwargs: Additional keyword arguments.
657 |
658 | Returns:
659 | Tuple:
660 | - attn_output (Tensor): Output tensor after attention.
661 | - attn_weights (Optional[Tensor]): Attention weights if computed.
662 | """
663 | input_shape = hidden_states.shape[:-1]
664 | hidden_shape = (*input_shape, -1, self.head_dim)
665 |
666 | query_states, key_states, value_states = self.lora(hidden_states)
667 | query_states = query_states.view(hidden_shape).transpose(1, 2)
668 | key_states = key_states.view(hidden_shape).transpose(1, 2)
669 | value_states = value_states.view(hidden_shape).transpose(1, 2)
670 |
671 | cos, sin = position_embeddings
672 | query_states, key_states = self.apply_rotary_pos_emb(
673 | query_states, key_states, cos, sin
674 | )
675 |
676 | if past_key_value is not None:
677 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
678 | key_states, value_states = past_key_value.update(
679 | key_states, value_states, self.layer_idx, cache_kwargs
680 | )
681 |
682 | sliding_window = None
683 | if (
684 | self.config.use_sliding_window
685 | and getattr(self.config, "sliding_window", None) is not None
686 | and self.layer_idx >= self.config.max_window_layers
687 | ):
688 | sliding_window = self.config.sliding_window
689 |
690 | attn_output, attn_weights = self.attention_forward(
691 | self,
692 | query_states,
693 | key_states,
694 | value_states,
695 | attention_mask,
696 | dropout=0.0 if not self.training else self.attention_dropout,
697 | scaling=self.scaling,
698 | sliding_window=sliding_window,
699 | **kwargs,
700 | )
701 |
702 | attn_output = attn_output.reshape(*input_shape, -1).contiguous()
703 |
704 | with torch.no_grad():
705 | o_proj_weight = load_tensor_from_storage(
706 | weight_path=self.o_proj_weight_path,
707 | shape=(self.config.hidden_size, self.config.hidden_size),
708 | to_ram=False,
709 | )
710 | attn_output = MatmulFunction.apply(attn_output, o_proj_weight.t(), 1.0)
711 | gc.collect()
712 | return attn_output, attn_weights
713 |
714 | def attention_forward(
715 | self,
716 | module: nn.Module,
717 | query: torch.Tensor,
718 | key: torch.Tensor,
719 | value: torch.Tensor,
720 | attention_mask: Optional[torch.Tensor],
721 | scaling: float,
722 | dropout: float = 0.0,
723 | **kwargs,
724 | ):
725 | """
726 | Compute the core attention mechanism.
727 |
728 | Applies scaled dot-product attention over the queries, keys, and values, taking into account
729 | the attention mask and dropout.
730 |
731 | Args:
732 | module (nn.Module): Reference to the current module (used for configuration).
733 | query (Tensor): Query tensor.
734 | key (Tensor): Key tensor.
735 | value (Tensor): Value tensor.
736 | attention_mask (Optional[Tensor]): Mask to apply to the attention weights.
737 | scaling (float): Scaling factor applied to the dot product.
738 | dropout (float, optional): Dropout probability for the attention weights.
739 | **kwargs: Additional keyword arguments.
740 |
741 | Returns:
742 | Tuple:
743 | - attn_output (Tensor): The attention output tensor.
744 | - attn_weights (Tensor): The computed attention weights.
745 | """
746 | key_states = self.repeat_kv(key, module.num_key_value_groups)
747 | value_states = self.repeat_kv(value, module.num_key_value_groups)
748 |
749 | attn_weights = MatmulFunction.apply(
750 | query, key_states.transpose(2, 3), scaling
751 | )
752 | if attention_mask is not None:
753 | causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
754 | attn_weights = attn_weights + causal_mask
755 |
756 | attn_weights = nn.functional.softmax(
757 | attn_weights, dim=-1, dtype=torch.float32
758 | ).to(query.dtype)
759 | attn_weights = nn.functional.dropout(
760 | attn_weights, p=dropout, training=module.training
761 | )
762 | attn_output = MatmulFunction.apply(attn_weights, value_states, 1.0)
763 | attn_output = attn_output.transpose(1, 2).contiguous()
764 | return attn_output, attn_weights
765 |
766 | def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
767 | """
768 |         Repeat key/value heads so that the key/value states match the number of query heads (grouped-query attention).
769 |
770 | Args:
771 |             hidden_states (Tensor): Hidden states tensor with shape (batch, num_key_value_heads, seq_len, head_dim).
772 | n_rep (int): Number of repetitions.
773 |
774 | Returns:
775 | Tensor: Reshaped tensor with repeated key/value states.
776 | """
777 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
778 | if n_rep == 1:
779 | return hidden_states
780 | hidden_states = hidden_states[:, :, None, :, :].expand(
781 | batch, num_key_value_heads, n_rep, slen, head_dim
782 | )
783 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
784 |
785 | def rotate_half(self, x):
786 | """
787 |         Rotate the two halves of the last dimension, as required for rotary position embeddings.
788 |
789 | Args:
790 | x (Tensor): Input tensor.
791 |
792 | Returns:
793 | Tensor: Tensor with rotated halves.
794 | """
795 | x1 = x[..., : x.shape[-1] // 2]
796 | x2 = x[..., x.shape[-1] // 2 :]
797 | return torch.cat((-x2, x1), dim=-1)
798 |
799 | def apply_rotary_pos_emb(self, queries, keys, cos, sin, unsqueeze_dim=1):
800 | """
801 | Apply rotary positional embeddings to queries and keys.
802 |
803 | Args:
804 | queries (Tensor): Query tensor.
805 | keys (Tensor): Key tensor.
806 | cos (Tensor): Cosine embedding.
807 | sin (Tensor): Sine embedding.
808 | unsqueeze_dim (int, optional): Dimension along which to unsqueeze the cosine and sine embeddings.
809 | Defaults to 1.
810 |
811 | Returns:
812 | Tuple[Tensor, Tensor]: Rotated queries and keys with positional embeddings applied.
813 | """
814 | cos = cos.unsqueeze(unsqueeze_dim)
815 | sin = sin.unsqueeze(unsqueeze_dim)
816 | queries_embed = (queries * cos) + (self.rotate_half(queries) * sin)
817 | keys_embed = (keys * cos) + (self.rotate_half(keys) * sin)
818 | return queries_embed, keys_embed
819 |
--------------------------------------------------------------------------------