├── LICENSE ├── README.md ├── advanced_usage.ipynb ├── demo.ipynb ├── minlora ├── __init__.py ├── model.py └── utils.py └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # minLoRA 2 | 3 | 4 | A minimal but versatile PyTorch re-implementation of [LoRA](https://github.com/microsoft/LoRA). In only ~100 lines of code, minLoRA supports the following features: 5 | 6 | ### Features 7 | 8 | - Functional, no need to modify the model definition 9 | - Works everywhere, as long as you use `torch.nn.Module` 10 | - PyTorch native, uses PyTorch's `torch.nn.utils.parametrize` to do all the heavy lifting 11 | - Easily extendable, you can add your own LoRA parameterization 12 | - Supports training, inference, and inference with multiple LoRA models 13 | 14 | ## Demo 15 | 16 | - `demo.ipynb` shows the basic usage of the library 17 | - `advanced_usage.ipynb` shows how you can add LoRA to other layers such as embedding, and how to tie weights 18 | 19 | ## Examples 20 | 21 | - Finetuning GPT using LoRA + nanoGPT: https://github.com/cccntu/LoRAnanoGPT/pull/1/files 22 | 23 | ## Library Installation 24 | 25 | If you want to `import minlora` into your project: 26 | 27 | ``` 28 | git clone https://github.com/cccntu/minLoRA.git 29 | cd minLoRA 30 | pip install -e . 31 | ``` 32 | 33 | ## Usage 34 | 35 | ```python 36 | import torch 37 | from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, get_lora_state_dict, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora 38 | ``` 39 | 40 | ### Training a model with minLoRA 41 | 42 | ```python 43 | model = torch.nn.Linear(in_features=5, out_features=3) 44 | # Step 1: Add LoRA to the model 45 | add_lora(model) 46 | 47 | # Step 2: Collect the parameters, pass them to the optimizer 48 | 49 | parameters = [ 50 | {"params": list(get_lora_params(model))}, 51 | ] 52 | optimizer = torch.optim.AdamW(parameters, lr=1e-3) 53 | 54 | # Step 3: Train the model 55 | # ...
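# A minimal sketch of what the loop in Step 3 could look like. The dummy inputs
# `x` and targets `y_target` below are illustrative assumptions, not part of the
# original example; only the LoRA parameters handed to the optimizer are updated.
x = torch.randn(16, 5)
y_target = torch.randn(16, 3)
for _ in range(100):
    loss = torch.nn.functional.mse_loss(model(x), y_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()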
56 | 57 | # Step 4: export the LoRA parameters 58 | lora_state_dict = get_lora_state_dict(model) 59 | ``` 60 | 61 | ### Loading and Inferencing with minLoRA 62 | 63 | ```python 64 | # Step 1: Add LoRA to your model 65 | add_lora(model) 66 | 67 | # Step 2: Load the LoRA parameters 68 | _ = model.load_state_dict(lora_state_dict, strict=False) 69 | 70 | # Step 3: Merge the LoRA parameters into the model 71 | merge_lora(model) 72 | ``` 73 | 74 | ### Inferencing with multiple LoRA models 75 | 76 | ```python 77 | # to avoid re-adding lora to the model when rerun the cell, remove lora first 78 | remove_lora(model) 79 | # Step 1: Add LoRA to your model 80 | add_lora(model) 81 | 82 | # Step 2: Load the LoRA parameters 83 | 84 | # load three sets of LoRA parameters 85 | lora_state_dicts = [lora_state_dict_0, lora_state_dict_1, lora_state_dict_2] 86 | 87 | load_multiple_lora(model, lora_state_dicts) 88 | 89 | 90 | # Step 3: Select which LoRA to use at inference time 91 | Y0 = select_lora(model, 0)(x) 92 | Y1 = select_lora(model, 1)(x) 93 | Y2 = select_lora(model, 2)(x) 94 | ``` 95 | ### References 96 | 97 | - [microsoft/LoRA](https://github.com/microsoft/LoRA) has the official implementation of LoRA, in PyTorch 98 | - [karpathy/minGPT](https://github.com/karpathy/minGPT) the structure of the repo is adapted from minGPT 99 | 100 | 101 | ### TODO 102 | - [x] A notebook to show how to configure LoRA parameters 103 | - [x] Real training & inference examples 104 | -------------------------------------------------------------------------------- /advanced_usage.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "92a4ce86", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from functools import partial\n", 11 | "\n", 12 | "import torch\n", 13 | "from minlora import (\n", 14 | " LoRAParametrization,\n", 15 | " add_lora,\n", 16 | " apply_to_lora,\n", 17 | " merge_lora,\n", 18 | ")\n", 19 | "from torch import nn\n", 20 | "\n", 21 | "_ = torch.set_grad_enabled(False)" 22 | ] 23 | }, 24 | { 25 | "attachments": {}, 26 | "cell_type": "markdown", 27 | "id": "2baccfbd", 28 | "metadata": {}, 29 | "source": [ 30 | "## Adding LoRA to layers other than nn.Linear" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "ec04a954", 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "Sequential(\n", 43 | " (0): ParametrizedEmbedding(\n", 44 | " 3, 2\n", 45 | " (parametrizations): ModuleDict(\n", 46 | " (weight): ParametrizationList(\n", 47 | " (0): LoRAParametrization()\n", 48 | " )\n", 49 | " )\n", 50 | " )\n", 51 | " (1): ParametrizedLinear(\n", 52 | " in_features=2, out_features=3, bias=True\n", 53 | " (parametrizations): ModuleDict(\n", 54 | " (weight): ParametrizationList(\n", 55 | " (0): LoRAParametrization()\n", 56 | " )\n", 57 | " )\n", 58 | " )\n", 59 | ")" 60 | ] 61 | }, 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | } 66 | ], 67 | "source": [ 68 | "## add_lora supports an optional `lora_config` argument of type Dict[Type[nn.Module], Dict[str, Callable]]\n", 69 | "## it specifies how to apply lora to each layer\n", 70 | "\n", 71 | "## Currently, there are support for nn.Embedding, nn.Linear, and nn.Conv2d\n", 72 | "\n", 73 | "lora_config = {\n", 74 | " nn.Embedding: {\n", 75 | " \"weight\": partial(LoRAParametrization.from_embedding, rank=4),\n", 76 | " },\n", 77 | " 
nn.Linear: {\n", 78 | " \"weight\": partial(LoRAParametrization.from_linear, rank=4),\n", 79 | " },\n", 80 | "}\n", 81 | "\n", 82 | "model = nn.Sequential(\n", 83 | " nn.Embedding(num_embeddings=3, embedding_dim=2),\n", 84 | " nn.Linear(in_features=2, out_features=3),\n", 85 | ")\n", 86 | "add_lora(model, lora_config=lora_config)\n", 87 | "model" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "6d96024c", 93 | "metadata": {}, 94 | "source": [ 95 | "## Tying weights" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "id": "5b649fe9", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "True\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "# let's see if this works\n", 114 | "linear = nn.Linear(in_features=2, out_features=3)\n", 115 | "embedding = nn.Embedding(num_embeddings=3, embedding_dim=2)\n", 116 | "# tie the weights of the linear layer and the embedding layer\n", 117 | "embedding.weight = linear.weight\n", 118 | "print(torch.allclose(embedding.weight, linear.weight))\n", 119 | "# so far so good" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 4, 125 | "id": "c7d00069", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "False\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "# now, add lora to the linear layer\n", 138 | "add_lora(linear)\n", 139 | "# and update the lora weights to make it non-zero\n", 140 | "linear.apply(apply_to_lora(lambda x: nn.init.ones_(x.lora_B)))\n", 141 | "# and the weights are no longer the same\n", 142 | "print(torch.allclose(embedding.weight, linear.weight))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 5, 148 | "id": "3e6e9bfe", 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | " \n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "# because adding lora makes the `weight` a computed property that returns a tensor.\n", 161 | "# It's not a Parameter anymore\n", 162 | "print(type(linear.weight), type(embedding.weight))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 6, 168 | "id": "d02d2819", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# to tie the weights, we need to add lora to the embedding layer as well\n", 173 | "# let's add lora to the embedding layer\n", 174 | "\n", 175 | "add_lora(embedding, lora_config=lora_config)\n", 176 | "# tie the lora weights\n", 177 | "# because the fan_in and fan_out are opposite to each other, we need to swap the lora weights A and B\n", 178 | "# here we assign the linear layer's A to the embedding layer's B, and vice versa\n", 179 | "# you can do it the other way around as well, but the initialization will be different\n", 180 | "embedding.parametrizations.weight[0].lora_A = linear.parametrizations.weight[0].lora_B\n", 181 | "embedding.parametrizations.weight[0].lora_B = linear.parametrizations.weight[0].lora_A\n", 182 | "linear.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_B)))\n", 183 | "linear.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_B)))\n", 184 | "assert torch.allclose(linear.weight, embedding.weight)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 7, 190 | "id": "f34b1f1d", 191 | "metadata": { 192 | "lines_to_next_cell": 2 193 | }, 194 | "outputs": [ 195 | { 196 | "name": 
"stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "torch.Size([3, 4]) torch.Size([4, 2]) torch.Size([4, 2]) torch.Size([3, 4])\n" 200 | ] 201 | } 202 | ], 203 | "source": [ 204 | "# although the shape of the weight is the same, the lora parameters have different shapes\n", 205 | "print(\n", 206 | " embedding.parametrizations.weight[0].lora_A.shape,\n", 207 | " linear.parametrizations.weight[0].lora_A.shape,\n", 208 | " embedding.parametrizations.weight[0].lora_B.shape,\n", 209 | " linear.parametrizations.weight[0].lora_B.shape,\n", 210 | ")" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 8, 216 | "id": "34ada79d", 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "True\n", 224 | "True\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "# update to the linear layer will also update the embedding layer\n", 230 | "linear.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_B)))\n", 231 | "linear.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_A)))\n", 232 | "print(torch.allclose(linear.weight, embedding.weight))\n", 233 | "# vice versa\n", 234 | "embedding.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_B)))\n", 235 | "embedding.apply(apply_to_lora(lambda x: nn.init.uniform_(x.lora_A)))\n", 236 | "print(torch.allclose(linear.weight, embedding.weight))\n", 237 | "# embedding.apply(apply_to_lora(lambda x: print(x.lora_B, x.lora_A)))\n", 238 | "# linear.apply(apply_to_lora(lambda x: print(x.lora_B, x.lora_A)))" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 9, 244 | "id": "51c20159", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "# we can put the logic of tying the weights in a function\n", 249 | "def tie_weights(linear: nn.Linear, embedding: nn.Embedding):\n", 250 | " \"\"\"tie the weights of the linear layer and the embedding layer both with the same lora\"\"\"\n", 251 | " # this line below is optional if the original is already tied\n", 252 | " embedding.parametrizations.weight.original = linear.parametrizations.weight.original\n", 253 | " embedding.parametrizations.weight[0].lora_A = linear.parametrizations.weight[0].lora_B\n", 254 | " embedding.parametrizations.weight[0].lora_B = linear.parametrizations.weight[0].lora_A\n", 255 | "# you can import this function directly:\n", 256 | "from minlora import tie_weights, untie_weights" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 10, 262 | "id": "9d5cc0b4", 263 | "metadata": { 264 | "lines_to_next_cell": 0 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "# now back to our first model with lora\n", 269 | "tie_weights(model[0], model[1])\n", 270 | "# update the lora weights of the linear layer\n", 271 | "apply_to_lora(lambda x: nn.init.uniform_(x.lora_B))(model[1])\n", 272 | "# and the weights are the still the same\n", 273 | "assert torch.allclose(model[0].weight, model[1].weight)\n", 274 | "merge_lora(model)\n", 275 | "# even after merging lora, the weights are still the same\n", 276 | "assert torch.allclose(model[0].weight, model[1].weight)" 277 | ] 278 | } 279 | ], 280 | "metadata": { 281 | "jupytext": { 282 | "cell_metadata_filter": "-all", 283 | "main_language": "python", 284 | "notebook_metadata_filter": "-all" 285 | }, 286 | "kernelspec": { 287 | "display_name": "lora", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "language_info": { 292 | "codemirror_mode": { 293 | "name": "ipython", 294 | "version": 3 
295 | }, 296 | "file_extension": ".py", 297 | "mimetype": "text/x-python", 298 | "name": "python", 299 | "nbconvert_exporter": "python", 300 | "pygments_lexer": "ipython3", 301 | "version": "3.9.16" 302 | }, 303 | "vscode": { 304 | "interpreter": { 305 | "hash": "cd38cab5b092fbce1866c43acaed152c77b80a12cd5e2b7fb23112c1a171e061" 306 | } 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 5 311 | } 312 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "8209bea0", 7 | "metadata": { 8 | "lines_to_next_cell": 2 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import torch\n", 13 | "from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora, get_lora_state_dict\n", 14 | "_ = torch.set_grad_enabled(False)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "492093a9", 21 | "metadata": { 22 | "lines_to_next_cell": 0 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "tensor([[-0.3555, -0.0929, 0.6221]])\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "# a simple model\n", 35 | "model = torch.nn.Sequential(\n", 36 | " torch.nn.Linear(in_features=5, out_features=7),\n", 37 | " torch.nn.Linear(in_features=7, out_features=3),\n", 38 | ")\n", 39 | "\n", 40 | "x = torch.randn(1, 5)\n", 41 | "y = model(x)\n", 42 | "print(y)\n", 43 | "Y0 = y" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "id": "98584a8c", 50 | "metadata": { 51 | "lines_to_next_cell": 0 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# add lora to the model\n", 56 | "# becase B is initialized to 0, the output is the same as before\n", 57 | "add_lora(model)\n", 58 | "y = model(x)\n", 59 | "assert torch.allclose(y, Y0)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": "c0251891", 66 | "metadata": { 67 | "lines_to_next_cell": 0 68 | }, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "tensor([[-0.4703, -0.3157, 0.4262]])\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "# to make the output different, we need to initialize B to something non-zero\n", 80 | "model.apply(apply_to_lora(lambda x: torch.nn.init.ones_(x.lora_B)))\n", 81 | "y = model(x)\n", 82 | "print(y)\n", 83 | "assert not torch.allclose(y, Y0)\n", 84 | "Y1 = y" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "id": "196087bc", 91 | "metadata": { 92 | "lines_to_next_cell": 0 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "# now let's try to disable lora, the output is the same as before lora is added\n", 97 | "disable_lora(model)\n", 98 | "y = model(x)\n", 99 | "assert torch.allclose(y, Y0)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "id": "1e9cba3c", 106 | "metadata": { 107 | "lines_to_next_cell": 0 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "# enable lora again\n", 112 | "enable_lora(model)\n", 113 | "y = model(x)\n", 114 | "assert torch.allclose(y, Y1)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 7, 120 | "id": "57f19300", 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | 
"dict_keys(['0.parametrizations.weight.0.lora_A', '0.parametrizations.weight.0.lora_B', '1.parametrizations.weight.0.lora_A', '1.parametrizations.weight.0.lora_B'])" 127 | ] 128 | }, 129 | "execution_count": 7, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "# let's save the state dict for later use\n", 136 | "state_dict_to_save = get_lora_state_dict(model)\n", 137 | "state_dict_to_save.keys()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "id": "19a06b21", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "# you can remove lora from the model\n", 148 | "remove_lora(model)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 9, 154 | "id": "522e71f1", 155 | "metadata": { 156 | "lines_to_next_cell": 0 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "# lets try to load the lora back\n", 161 | "# first we need to add lora to the model\n", 162 | "add_lora(model)\n", 163 | "# then we can load the lora parameters\n", 164 | "# strict=False is needed because we are loading a subset of the parameters\n", 165 | "_ = model.load_state_dict(state_dict_to_save, strict=False) \n", 166 | "y = model(x)\n", 167 | "assert torch.allclose(y, Y1)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 10, 173 | "id": "9f0c8570", 174 | "metadata": { 175 | "lines_to_next_cell": 0 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "# we can merge it to make it a normal linear layer, so there is no overhead for inference\n", 180 | "merge_lora(model)\n", 181 | "y = model(x)\n", 182 | "assert torch.allclose(y, Y1)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 11, 188 | "id": "ee283143", 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "Sequential(\n", 195 | " (0): Linear(in_features=5, out_features=7, bias=True)\n", 196 | " (1): Linear(in_features=7, out_features=3, bias=True)\n", 197 | ")" 198 | ] 199 | }, 200 | "execution_count": 11, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "# model now has no lora parameters\n", 207 | "model" 208 | ] 209 | }, 210 | { 211 | "attachments": {}, 212 | "cell_type": "markdown", 213 | "id": "f3c246e1", 214 | "metadata": {}, 215 | "source": [ 216 | "## Training a model" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 12, 222 | "id": "edfaee1e", 223 | "metadata": { 224 | "lines_to_next_cell": 0 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "model = torch.nn.Linear(in_features=5, out_features=3)\n", 229 | "# Step 1: Add LoRA to the model\n", 230 | "add_lora(model)\n", 231 | "\n", 232 | "# Step 2: Collect the parameters, pass them to the optimizer\n", 233 | "\n", 234 | "parameters = [\n", 235 | " {\"params\": list(get_lora_params(model))},\n", 236 | "]\n", 237 | "optimizer = torch.optim.AdamW(parameters, lr=1e-3)\n", 238 | "\n", 239 | "# Step 3: Train the model\n", 240 | "# ...\n", 241 | "# simulate training, update the LoRA parameters\n", 242 | "model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_A)))\n", 243 | "model.apply(apply_to_lora(lambda x: torch.nn.init.normal_(x.lora_B)))\n", 244 | "\n", 245 | "# Step 4: export the LoRA parameters\n", 246 | "state_dict = model.state_dict()\n", 247 | "lora_state_dict = {k: v for k, v in state_dict.items() if name_is_lora(k)}" 248 | ] 249 | }, 250 | { 251 | "attachments": {}, 252 | "cell_type": 
"markdown", 253 | "id": "539e7d19", 254 | "metadata": {}, 255 | "source": [ 256 | "## Loading and Inferencing with LoRA" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 13, 262 | "id": "1a9836de", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# Step 1: Add LoRA to your model\n", 267 | "add_lora(model)\n", 268 | "\n", 269 | "# Step 2: Load the LoRA parameters\n", 270 | "_ = model.load_state_dict(lora_state_dict, strict=False)\n", 271 | "\n", 272 | "# Step 3: Merge the LoRA parameters into the model\n", 273 | "merge_lora(model)" 274 | ] 275 | }, 276 | { 277 | "attachments": {}, 278 | "cell_type": "markdown", 279 | "id": "ccba9d68", 280 | "metadata": {}, 281 | "source": [ 282 | "## Inferencing with multiple LoRA models" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 14, 288 | "id": "a0ef4b28", 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "# to avoid re-adding lora to the model when rerun the cell, remove lora first \n", 293 | "remove_lora(model)\n", 294 | "# Step 1: Add LoRA to your model\n", 295 | "add_lora(model)\n", 296 | "\n", 297 | "# Step 2: Load the LoRA parameters\n", 298 | "\n", 299 | "# fake 3 sets of LoRA parameters\n", 300 | "lora_state_dict_0 = lora_state_dict\n", 301 | "lora_state_dict_1 = {k: torch.ones_like(v) for k, v in lora_state_dict.items()}\n", 302 | "lora_state_dict_2 = {k: torch.zeros_like(v) for k, v in lora_state_dict.items()}\n", 303 | "lora_state_dicts = [lora_state_dict_0, lora_state_dict_1, lora_state_dict_2]\n", 304 | "\n", 305 | "load_multiple_lora(model, lora_state_dicts)\n", 306 | "\n", 307 | "# Step 3: Select which LoRA to use at inference time\n", 308 | "Y0 = select_lora(model, 0)(x)\n", 309 | "Y1 = select_lora(model, 1)(x)\n", 310 | "Y2 = select_lora(model, 2)(x)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 15, 316 | "id": "c67602a3", 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "data": { 321 | "text/plain": [ 322 | "(tensor([[-0.6234, 2.2759, -0.9541]]),\n", 323 | " tensor([[2.4910, 3.4921, 1.5635]]),\n", 324 | " tensor([[ 0.3378, 1.3389, -0.5897]]))" 325 | ] 326 | }, 327 | "execution_count": 15, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | } 331 | ], 332 | "source": [ 333 | "Y0, Y1, Y2" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 16, 339 | "id": "537c5c6c", 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "name": "stdout", 344 | "output_type": "stream", 345 | "text": [ 346 | "tensor([[-0.6234, 2.2759, -0.9541]])\n", 347 | "tensor([[2.4910, 3.4921, 1.5635]])\n", 348 | "tensor([[ 0.3378, 1.3389, -0.5897]])\n" 349 | ] 350 | } 351 | ], 352 | "source": [ 353 | "remove_lora(model)\n", 354 | "init_state_dict = model.state_dict()\n", 355 | "# verify that it's the same as if we load the lora parameters one by one\n", 356 | "for state_dict in lora_state_dicts:\n", 357 | " remove_lora(model)\n", 358 | " _ = model.load_state_dict(init_state_dict, strict=False)\n", 359 | " add_lora(model)\n", 360 | " _ = model.load_state_dict(state_dict, strict=False)\n", 361 | " merge_lora(model)\n", 362 | " y = model(x)\n", 363 | " print(y)" 364 | ] 365 | } 366 | ], 367 | "metadata": { 368 | "jupytext": { 369 | "cell_metadata_filter": "-all", 370 | "main_language": "python", 371 | "notebook_metadata_filter": "-all" 372 | }, 373 | "kernelspec": { 374 | "display_name": "lora", 375 | "language": "python", 376 | "name": "python3" 377 | }, 378 | "language_info": { 379 | 
"codemirror_mode": { 380 | "name": "ipython", 381 | "version": 3 382 | }, 383 | "file_extension": ".py", 384 | "mimetype": "text/x-python", 385 | "name": "python", 386 | "nbconvert_exporter": "python", 387 | "pygments_lexer": "ipython3", 388 | "version": "3.9.16" 389 | }, 390 | "vscode": { 391 | "interpreter": { 392 | "hash": "cd38cab5b092fbce1866c43acaed152c77b80a12cd5e2b7fb23112c1a171e061" 393 | } 394 | } 395 | }, 396 | "nbformat": 4, 397 | "nbformat_minor": 5 398 | } 399 | -------------------------------------------------------------------------------- /minlora/__init__.py: -------------------------------------------------------------------------------- 1 | from minlora.model import LoRAParametrization, add_lora, default_lora_config, merge_lora, remove_lora 2 | from minlora.utils import ( 3 | apply_to_lora, 4 | disable_lora, 5 | enable_lora, 6 | get_bias_params, 7 | get_lora_params, 8 | get_lora_state_dict, 9 | load_multiple_lora, 10 | name_is_lora, 11 | select_lora, 12 | tie_weights, 13 | untie_weights, 14 | ) 15 | -------------------------------------------------------------------------------- /minlora/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | References: 3 | 1) the official LoRA implementation released by Microsoft: 4 | https://github.com/microsoft/LoRA/blob/main/loralib/layers.py 5 | """ 6 | 7 | import math 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn.utils.parametrize as parametrize 12 | from torch import nn 13 | 14 | 15 | class LoRAParametrization(nn.Module): 16 | def __init__(self, fan_in, fan_out, fan_in_fan_out=False, rank=4, lora_dropout_p=0.0, lora_alpha=1): 17 | super().__init__() 18 | # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x 19 | # otherwise, it's x(W + AB). 
This allows us to tie the weights between linear layers and embeddings 20 | self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x) 21 | self.lora_A = nn.Parameter(torch.zeros(self.swap((rank, fan_in)))) 22 | self.lora_B = nn.Parameter(torch.zeros(self.swap((fan_out, rank)))) 23 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 24 | self.lora_alpha, self.rank = lora_alpha, rank 25 | self.scaling = lora_alpha / rank 26 | self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x 27 | self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x 28 | self.register_buffer("lora_dropout_mask", torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype)) 29 | self.forward_fn = self.lora_forward 30 | 31 | def _dropout(self, A): 32 | # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x 33 | return A * self.lora_dropout(self.lora_dropout_mask) 34 | 35 | def lora_forward(self, X): 36 | return X + torch.matmul(*self.swap((self.lora_B, self.dropout_fn(self.lora_A)))).view(X.shape) * self.scaling 37 | 38 | def forward(self, X): 39 | return self.forward_fn(X) 40 | 41 | def disable_lora(self): 42 | self.forward_fn = lambda x: x 43 | 44 | def enable_lora(self): 45 | self.forward_fn = self.lora_forward 46 | 47 | @classmethod 48 | def from_linear(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1): 49 | fan_out, fan_in = layer.weight.shape 50 | return cls( 51 | fan_in, fan_out, fan_in_fan_out=False, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha 52 | ) 53 | 54 | @classmethod 55 | def from_conv2d(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1): 56 | fan_out, fan_in = layer.weight.view(layer.weight.shape[0], -1).shape 57 | return cls( 58 | fan_in, fan_out, fan_in_fan_out=False, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha 59 | ) 60 | 61 | @classmethod 62 | def from_embedding(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1): 63 | fan_in, fan_out = layer.weight.shape 64 | return cls( 65 | fan_in, fan_out, fan_in_fan_out=True, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha 66 | ) 67 | 68 | 69 | default_lora_config = { # specify which layers to add lora to, by default only add to linear layers 70 | nn.Linear: { 71 | "weight": partial(LoRAParametrization.from_linear, rank=4), 72 | }, 73 | } 74 | 75 | 76 | def apply_lora(layer, register=True, merge=False, lora_config=default_lora_config): 77 | """add lora parametrization to a layer, designed to be used with model.apply""" 78 | if register: 79 | if type(layer) in lora_config: 80 | for attr_name, parametrization in lora_config[type(layer)].items(): 81 | parametrize.register_parametrization(layer, attr_name, parametrization(layer)) 82 | else: # this will remove all parametrizations, use with caution 83 | if hasattr(layer, "parametrizations"): 84 | for attr_name in layer.parametrizations.keys(): 85 | parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=merge) 86 | 87 | 88 | def add_lora(model, lora_config=default_lora_config): 89 | """add lora parametrization to all layers in a model. 
Calling it twice will add lora twice""" 90 | model.apply(partial(apply_lora, lora_config=lora_config)) 91 | 92 | 93 | def add_lora_by_name(model, target_module_names, lora_config=default_lora_config): 94 | """Add LoRA parameterization to specific layers in a model by names""" 95 | for name, layer in model.named_modules(): 96 | if any([m in name for m in target_module_names]): 97 | add_lora(layer, lora_config=lora_config) 98 | 99 | 100 | def merge_lora(model): 101 | """merge lora parametrization to all layers in a model. This will remove all parametrization""" 102 | model.apply(partial(apply_lora, register=False, merge=True)) 103 | 104 | 105 | def remove_lora(model): 106 | """remove lora parametrization to all layers in a model. This will remove all parametrization""" 107 | model.apply(partial(apply_lora, register=False, merge=False)) 108 | -------------------------------------------------------------------------------- /minlora/utils.py: -------------------------------------------------------------------------------- 1 | from minlora import LoRAParametrization 2 | from torch import nn 3 | 4 | 5 | def apply_to_lora(fn): 6 | """apply a function to LoRAParametrization layers, designed to be used with model.apply""" 7 | 8 | def apply_fn(layer): 9 | if isinstance(layer, LoRAParametrization): 10 | fn(layer) 11 | 12 | return apply_fn 13 | 14 | 15 | enable_lora = lambda model: model.apply(apply_to_lora(lambda x: x.enable_lora())) 16 | disable_lora = lambda model: model.apply(apply_to_lora(lambda x: x.disable_lora())) 17 | 18 | 19 | # ------------------- helper function for collecting parameters for training/saving ------------------- 20 | 21 | 22 | def name_is_lora(name): 23 | return ( 24 | len(name.split(".")) >= 4 25 | and (name.split(".")[-4]) == "parametrizations" 26 | and name.split(".")[-1] in ["lora_A", "lora_B"] 27 | ) 28 | 29 | 30 | def name_is_bias(name): 31 | return name.split(".")[-1] == "bias" 32 | 33 | 34 | def get_params_by_name(model, print_shapes=False, name_filter=None): 35 | for n, p in model.named_parameters(): 36 | if name_filter is None or name_filter(n): 37 | if print_shapes: 38 | print(n, p.shape) 39 | yield p 40 | 41 | 42 | def get_lora_params(model, print_shapes=False): 43 | return get_params_by_name(model, print_shapes=print_shapes, name_filter=name_is_lora) 44 | 45 | 46 | def get_bias_params(model, print_shapes=False): 47 | return get_params_by_name(model, print_shapes=print_shapes, name_filter=name_is_bias) 48 | 49 | 50 | def get_lora_state_dict(model): 51 | return {k: v for k, v in model.state_dict().items() if name_is_lora(k)} 52 | 53 | 54 | # ------------------- helper function for inferencing with multiple lora ------------------- 55 | 56 | 57 | def _prepare_for_multiple_lora(lora_layer): 58 | lora_layer.lora_As = [] 59 | lora_layer.lora_Bs = [] 60 | 61 | 62 | def _append_lora(lora_layer): 63 | lora_layer.lora_As.append(nn.Parameter(lora_layer.lora_A.clone())) 64 | lora_layer.lora_Bs.append(nn.Parameter(lora_layer.lora_B.clone())) 65 | 66 | 67 | def load_multiple_lora(model, lora_state_dicts): 68 | model.apply(apply_to_lora(_prepare_for_multiple_lora)) 69 | for state_dict in lora_state_dicts: 70 | _ = model.load_state_dict(state_dict, strict=False) 71 | model.apply(apply_to_lora(_append_lora)) 72 | return model 73 | 74 | 75 | def _select_lora(lora_layer, index): 76 | lora_layer.lora_A = lora_layer.lora_As[index] 77 | lora_layer.lora_B = lora_layer.lora_Bs[index] 78 | 79 | 80 | def select_lora(model, index): 81 | model.apply(apply_to_lora(lambda x: _select_lora(x, 
index))) 82 | return model 83 | 84 | 85 | # ------------------- helper function for tying and untieing weights ------------------- 86 | 87 | 88 | def tie_weights(linear: nn.Linear, embedding: nn.Embedding): 89 | """tie the weights of the linear layer and the embedding layer both with the same lora""" 90 | # this line below is optional if the original is already tied 91 | embedding.parametrizations.weight.original = linear.parametrizations.weight.original 92 | embedding.parametrizations.weight[0].lora_A = linear.parametrizations.weight[0].lora_B 93 | embedding.parametrizations.weight[0].lora_B = linear.parametrizations.weight[0].lora_A 94 | 95 | 96 | def untie_weights(linear: nn.Linear, embedding: nn.Embedding): 97 | """untie the weights of the linear layer and the embedding layer""" 98 | embedding.parametrizations.weight.original = nn.Parameter(embedding.weight.original.clone()) 99 | embedding.parametrizations.weight[0].lora_A = nn.Parameter(embedding.parametrizations.weight[0].lora_A.clone()) 100 | embedding.parametrizations.weight[0].lora_B = nn.Parameter(embedding.parametrizations.weight[0].lora_B.clone()) 101 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="minLoRA", 5 | version="0.1.0", 6 | author="Jonathan Chang", 7 | packages=["minlora"], 8 | description="A PyTorch re-implementation of LoRA", 9 | license="MIT", 10 | install_requires=[ 11 | "torch>=1.9.0", 12 | ], 13 | ) 14 | --------------------------------------------------------------------------------
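Beyond what the README and notebooks show, `minlora/model.py` also provides `add_lora_by_name` (apply LoRA only to modules whose name contains a given substring) and `minlora/utils.py` provides `get_bias_params` (collect bias parameters, e.g. for LoRA-plus-bias finetuning). Below is a minimal usage sketch of both; the toy module names (`attn`, `mlp`) and the learning rates are illustrative assumptions, and `add_lora_by_name` is imported from `minlora.model` because it is not re-exported in `minlora/__init__.py`.

```python
import torch
from torch import nn

from minlora import get_bias_params, get_lora_params
from minlora.model import add_lora_by_name  # not re-exported in minlora/__init__.py

# a toy model with named submodules (the names are only for illustration)
model = nn.ModuleDict({
    "attn": nn.Linear(8, 8),
    "mlp": nn.Linear(8, 8),
})

# add LoRA only to modules whose name contains "attn"; "mlp" keeps a plain weight
add_lora_by_name(model, ["attn"])

# optimize the LoRA parameters and, optionally, the biases as well
parameters = [
    {"params": list(get_lora_params(model))},
    {"params": list(get_bias_params(model)), "lr": 1e-4},  # assumed smaller lr for biases
]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)
```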