├── .gitignore
├── LICENSE
├── README.md
├── Using-LinearDoRA.ipynb
├── Using-LinearDoRAMerged.ipynb
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sebastian Raschka 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LoRA and DoRA From Scratch 2 | 3 | LoRA and DoRA from Scratch implementations as supplementary material for the article [https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch](https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch). 
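
The two notebooks implement the same methods in two equivalent ways: `Using-LinearDoRA.ipynb` keeps the low-rank update as a separate additive branch (`LinearWithLoRA`, `LinearWithDoRA`), while `Using-LinearDoRAMerged.ipynb` combines the update with the pretrained weight matrix inside the forward pass (`LinearWithLoRAMerged`, `LinearWithDoRAMerged`). For orientation, here is a minimal, self-contained sketch of the LoRA building blocks both notebooks start from (the classes mirror the notebook code; the layer sizes in the usage example are arbitrary):

```python
import torch
import torch.nn as nn


class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A is initialized randomly, B with zeros, so A @ B starts as a zero update
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * (x @ self.A @ self.B)


class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

    def forward(self, x):
        # Frozen pretrained layer plus trainable low-rank update
        return self.linear(x) + self.lora(x)


# Wrap a pretrained layer; before any training, outputs match the original exactly
layer = nn.Linear(10, 2)
layer_lora = LinearWithLoRA(layer, rank=2, alpha=4)
x = torch.randn(1, 10)
print(torch.allclose(layer(x), layer_lora(x)))  # True
```

DoRA builds on this by additionally learning a magnitude vector `m` and normalizing the direction of the adapted weight, as shown in the notebooks.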
4 | -------------------------------------------------------------------------------- /Using-LinearDoRA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d2abd10e-e63e-4904-badf-5a16409503b1", 6 | "metadata": {}, 7 | "source": [ 8 | "# LoRA and DoRA from Scratch -- A Multilayer Perceptron Example\n", 9 | "\n", 10 | "## Using the LinearLoRA and LinearDoRA classes" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "263e27da-47c7-4030-83c6-bf5f7e8bef74", 16 | "metadata": {}, 17 | "source": [ 18 | "This code notebook illustrates how LoRA ([https://arxiv.org/abs/2106.09685](https://arxiv.org/abs/2106.09685)) and DoRA ([https://arxiv.org/abs/2402.09353](https://arxiv.org/abs/2402.09353)) work by implementing these methods from scratch.\n", 19 | "\n", 20 | "Note that this is a companion notebook to my blog article [Improving LoRA: Implementing Weight-Decomposed Low-Rank Adaptation (DoRA) from Scratch](https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch)." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "9e4208cf-39b3-4a0b-a0e0-d7679a2d60c3", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "Author: Sebastian Raschka\n", 34 | "\n", 35 | "Python implementation: CPython\n", 36 | "Python version : 3.10.6\n", 37 | "IPython version : 8.12.0\n", 38 | "\n", 39 | "torch: 2.1.0\n", 40 | "\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "%load_ext watermark\n", 46 | "%watermark -a 'Sebastian Raschka' -v -p torch" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "id": "c1c52f02-94fb-4f45-902e-79126e27347d", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import time\n", 57 | "import numpy as np\n", 58 | "from torchvision import datasets\n", 59 | "from torchvision import transforms\n", 60 | "from torch.utils.data import DataLoader\n", 61 | "import torch.nn.functional as F\n", 62 | "import torch.nn as nn\n", 63 | "import torch\n", 64 | "\n", 65 | "\n", 66 | "if torch.cuda.is_available():\n", 67 | " torch.backends.cudnn.deterministic = True" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "629ec66a-eb81-40a5-ae3d-d5c1d2a7e390", 73 | "metadata": {}, 74 | "source": [ 75 | "## Settings and Dataset" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "id": "4ade5e86-8bd8-4a35-8db1-44451601b292", 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "Image batch dimensions: torch.Size([64, 1, 28, 28])\n", 89 | "Image label dimensions: torch.Size([64])\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "##########################\n", 95 | "### SETTINGS\n", 96 | "##########################\n", 97 | "\n", 98 | "# Device\n", 99 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 100 | "BATCH_SIZE = 64\n", 101 | "\n", 102 | "##########################\n", 103 | "### MNIST DATASET\n", 104 | "##########################\n", 105 | "\n", 106 | "# Note transforms.ToTensor() scales input images\n", 107 | "# to 0-1 range\n", 108 | "train_dataset = datasets.MNIST(root='data', \n", 109 | " train=True, \n", 110 | " transform=transforms.ToTensor(),\n", 111 | " download=True)\n", 112 | "\n", 113 | "test_dataset = datasets.MNIST(root='data', \n", 114 | " train=False, \n", 115 | " transform=transforms.ToTensor())\n", 116 | 
"\n", 117 | "\n", 118 | "train_loader = DataLoader(dataset=train_dataset, \n", 119 | " batch_size=BATCH_SIZE, \n", 120 | " shuffle=True)\n", 121 | "\n", 122 | "test_loader = DataLoader(dataset=test_dataset, \n", 123 | " batch_size=BATCH_SIZE, \n", 124 | " shuffle=False)\n", 125 | "\n", 126 | "# Checking the dataset\n", 127 | "for images, labels in train_loader: \n", 128 | " print('Image batch dimensions:', images.shape)\n", 129 | " print('Image label dimensions:', labels.shape)\n", 130 | " break" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "394e5da8-2978-40f0-bca7-b79e8e35734f", 136 | "metadata": {}, 137 | "source": [ 138 | "# Multilayer Perceptron Model (Without LoRA and DoRA)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 4, 144 | "id": "7e905c42-7f59-4a08-b6c5-a10f99f33e9e", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "##########################\n", 149 | "### MODEL\n", 150 | "##########################\n", 151 | "\n", 152 | "# Hyperparameters\n", 153 | "random_seed = 123\n", 154 | "learning_rate = 0.005\n", 155 | "num_epochs = 2\n", 156 | "\n", 157 | "# Architecture\n", 158 | "num_features = 784\n", 159 | "num_hidden_1 = 128\n", 160 | "num_hidden_2 = 256\n", 161 | "num_classes = 10\n", 162 | "\n", 163 | "\n", 164 | "class MultilayerPerceptron(nn.Module):\n", 165 | "\n", 166 | " def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):\n", 167 | " super().__init__()\n", 168 | "\n", 169 | " self.layers = nn.Sequential(\n", 170 | " nn.Linear(num_features, num_hidden_1),\n", 171 | " nn.ReLU(),\n", 172 | " nn.Linear(num_hidden_1, num_hidden_2),\n", 173 | " nn.ReLU(),\n", 174 | " nn.Linear(num_hidden_2, num_classes)\n", 175 | " )\n", 176 | "\n", 177 | " def forward(self, x):\n", 178 | " x = self.layers(x)\n", 179 | " return x\n", 180 | "\n", 181 | "\n", 182 | "torch.manual_seed(random_seed)\n", 183 | "model_pretrained = MultilayerPerceptron(\n", 184 | " num_features=num_features,\n", 185 | " num_hidden_1=num_hidden_1,\n", 186 | " num_hidden_2=num_hidden_2, \n", 187 | " num_classes=num_classes\n", 188 | ")\n", 189 | "\n", 190 | "model_pretrained.to(DEVICE)\n", 191 | "optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=learning_rate)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 5, 197 | "id": "cf31624a-d950-402f-a564-2e7fb63db8a4", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "def compute_accuracy(model, data_loader, device):\n", 202 | " model.eval()\n", 203 | " correct_pred, num_examples = 0, 0\n", 204 | " with torch.no_grad():\n", 205 | " for features, targets in data_loader:\n", 206 | " features = features.view(-1, 28*28).to(device)\n", 207 | " targets = targets.to(device)\n", 208 | " logits = model(features)\n", 209 | " _, predicted_labels = torch.max(logits, 1)\n", 210 | " num_examples += targets.size(0)\n", 211 | " correct_pred += (predicted_labels == targets).sum()\n", 212 | " return correct_pred.float()/num_examples * 100\n", 213 | "\n", 214 | "\n", 215 | "def train(num_epochs, model, optimizer, train_loader, device):\n", 216 | "\n", 217 | " start_time = time.time()\n", 218 | " for epoch in range(num_epochs):\n", 219 | " model.train()\n", 220 | " for batch_idx, (features, targets) in enumerate(train_loader):\n", 221 | "\n", 222 | " features = features.view(-1, 28*28).to(device)\n", 223 | " targets = targets.to(device)\n", 224 | "\n", 225 | " # FORWARD AND BACK PROP\n", 226 | " logits = model(features)\n", 227 | " 
loss = F.cross_entropy(logits, targets)\n", 228 | " optimizer.zero_grad()\n", 229 | "\n", 230 | " loss.backward()\n", 231 | "\n", 232 | " # UPDATE MODEL PARAMETERS\n", 233 | " optimizer.step()\n", 234 | "\n", 235 | " # LOGGING\n", 236 | " if not batch_idx % 400:\n", 237 | " print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'\n", 238 | " % (epoch+1, num_epochs, batch_idx,\n", 239 | " len(train_loader), loss))\n", 240 | "\n", 241 | " with torch.set_grad_enabled(False):\n", 242 | " print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n", 243 | " epoch+1, num_epochs,\n", 244 | " compute_accuracy(model, train_loader, device)))\n", 245 | "\n", 246 | " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", 247 | "\n", 248 | " print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 6, 254 | "id": "f47cfe4e-65eb-440e-b922-17c6dee7d7e2", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "Epoch: 001/002 | Batch 000/938 | Loss: 2.2971\n", 262 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.1770\n", 263 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.1582\n", 264 | "Epoch: 001/002 training accuracy: 95.62%\n", 265 | "Time elapsed: 0.07 min\n", 266 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.0501\n", 267 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.0408\n", 268 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.0828\n", 269 | "Epoch: 002/002 training accuracy: 97.22%\n", 270 | "Time elapsed: 0.15 min\n", 271 | "Total Training Time: 0.15 min\n", 272 | "Test accuracy: 96.41%\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "train(num_epochs, model_pretrained, optimizer_pretrained, train_loader, DEVICE)\n", 278 | "print(f'Test accuracy: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "id": "fb3480b9-aea5-411e-b252-d7fc8a5dd21d", 284 | "metadata": {}, 285 | "source": [ 286 | "# Multilayer Perceptron with LoRA and DoRA" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "id": "36b9d281-22ba-4120-af95-f6b95adcaa03", 292 | "metadata": {}, 293 | "source": [ 294 | "## Modify model by injecting LoRA and DoRA layers" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 7, 300 | "id": "215795c5-c0d4-4886-b4d6-a5a0e7cc8c7e", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "class LoRALayer(nn.Module):\n", 305 | " def __init__(self, in_dim, out_dim, rank, alpha):\n", 306 | " super().__init__()\n", 307 | " std_dev = 1 / torch.sqrt(torch.tensor(rank).float())\n", 308 | " self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n", 309 | " self.B = nn.Parameter(torch.zeros(rank, out_dim))\n", 310 | " self.alpha = alpha\n", 311 | "\n", 312 | " def forward(self, x):\n", 313 | " x = self.alpha * (x @ self.A @ self.B)\n", 314 | " return x\n", 315 | "\n", 316 | "\n", 317 | "class LinearWithLoRA(nn.Module):\n", 318 | " def __init__(self, linear, rank, alpha):\n", 319 | " super().__init__()\n", 320 | " self.linear = linear\n", 321 | " self.lora = LoRALayer(\n", 322 | " linear.in_features, linear.out_features, rank, alpha\n", 323 | " )\n", 324 | "\n", 325 | " def forward(self, x):\n", 326 | " return self.linear(x) + self.lora(x)\n", 327 | "\n", 328 | "\n", 329 | "class LinearWithDoRA(nn.Module):\n", 330 | " def __init__(self, linear, rank, alpha):\n", 331 | " super().__init__()\n", 332 | " self.linear = linear\n", 333 | 
" self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)\n", 334 | " self.m = nn.Parameter(torch.ones(1, linear.out_features))\n", 335 | "\n", 336 | " def forward(self, x):\n", 337 | " linear_output = self.linear(x)\n", 338 | " lora_output = self.lora(x)\n", 339 | " lora_output_norm = lora_output / (lora_output.norm(p=2, dim=1, keepdim=True) + 1e-9)\n", 340 | " dora_modification = self.m * lora_output_norm\n", 341 | " return linear_output + dora_modification\n" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "id": "06cfda10-ac5a-4958-ad42-910ba73aa639", 347 | "metadata": {}, 348 | "source": [ 349 | "Since the B matrix is initialized to 0's, the initial LoRA and DoRA layers (before training) should not affect the outputs of the forward pass, which we can confirm as follows:" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 8, 355 | "id": "0441c93f-0ee5-4003-acc3-f24541f06c66", 356 | "metadata": {}, 357 | "outputs": [ 358 | { 359 | "name": "stdout", 360 | "output_type": "stream", 361 | "text": [ 362 | "Original output: tensor([[0.6639, 0.4487]], grad_fn=)\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "torch.manual_seed(123)\n", 368 | "\n", 369 | "layer = nn.Linear(10, 2)\n", 370 | "x = torch.randn((1, 10))\n", 371 | "\n", 372 | "print(\"Original output:\", layer(x))" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 9, 378 | "id": "b132184a-f87e-423f-850a-2dc44fe76770", 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "name": "stdout", 383 | "output_type": "stream", 384 | "text": [ 385 | "LoRA output: tensor([[0.6639, 0.4487]], grad_fn=)\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "layer_lora_1 = LinearWithLoRA(layer, rank=2, alpha=4)\n", 391 | "\n", 392 | "print(\"LoRA output:\", layer_lora_1(x))" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 10, 398 | "id": "91bbc702-1955-457f-bfa9-0a547a8c41a5", 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "name": "stdout", 403 | "output_type": "stream", 404 | "text": [ 405 | "DoRA output: tensor([[0.6639, 0.4487]], grad_fn=)\n" 406 | ] 407 | } 408 | ], 409 | "source": [ 410 | "layer_dora_1 = LinearWithDoRA(layer, rank=2, alpha=4)\n", 411 | "\n", 412 | "print(\"DoRA output:\", layer_dora_1(x))" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 11, 418 | "id": "dc66ffa1-5822-4833-b636-d3a8170e84a2", 419 | "metadata": {}, 420 | "outputs": [ 421 | { 422 | "data": { 423 | "text/plain": [ 424 | "MultilayerPerceptron(\n", 425 | " (layers): Sequential(\n", 426 | " (0): Linear(in_features=784, out_features=128, bias=True)\n", 427 | " (1): ReLU()\n", 428 | " (2): Linear(in_features=128, out_features=256, bias=True)\n", 429 | " (3): ReLU()\n", 430 | " (4): Linear(in_features=256, out_features=10, bias=True)\n", 431 | " )\n", 432 | ")" 433 | ] 434 | }, 435 | "execution_count": 11, 436 | "metadata": {}, 437 | "output_type": "execute_result" 438 | } 439 | ], 440 | "source": [ 441 | "model_pretrained" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 12, 447 | "id": "b00a7e8f-09ff-499e-b593-3f6dac87f1bf", 448 | "metadata": {}, 449 | "outputs": [], 450 | "source": [ 451 | "import copy\n", 452 | "\n", 453 | "model_lora = copy.deepcopy(model_pretrained)\n", 454 | "model_dora = copy.deepcopy(model_pretrained)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 13, 460 | "id": "b1e3ef6f-255c-4c71-9da5-7d06d8c439a3", 461 | "metadata": {}, 462 | 
"outputs": [ 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "MultilayerPerceptron(\n", 467 | " (layers): Sequential(\n", 468 | " (0): LinearWithLoRA(\n", 469 | " (linear): Linear(in_features=784, out_features=128, bias=True)\n", 470 | " (lora): LoRALayer()\n", 471 | " )\n", 472 | " (1): ReLU()\n", 473 | " (2): LinearWithLoRA(\n", 474 | " (linear): Linear(in_features=128, out_features=256, bias=True)\n", 475 | " (lora): LoRALayer()\n", 476 | " )\n", 477 | " (3): ReLU()\n", 478 | " (4): LinearWithLoRA(\n", 479 | " (linear): Linear(in_features=256, out_features=10, bias=True)\n", 480 | " (lora): LoRALayer()\n", 481 | " )\n", 482 | " )\n", 483 | ")" 484 | ] 485 | }, 486 | "execution_count": 13, 487 | "metadata": {}, 488 | "output_type": "execute_result" 489 | } 490 | ], 491 | "source": [ 492 | "model_lora.layers[0] = LinearWithLoRA(model_lora.layers[0], rank=4, alpha=8)\n", 493 | "model_lora.layers[2] = LinearWithLoRA(model_lora.layers[2], rank=4, alpha=8)\n", 494 | "model_lora.layers[4] = LinearWithLoRA(model_lora.layers[4], rank=4, alpha=8)\n", 495 | "\n", 496 | "model_lora.to(DEVICE)\n", 497 | "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n", 498 | "model_lora" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 14, 504 | "id": "ec5a4ef4-bfc0-4fdb-bf7d-114aaf89795a", 505 | "metadata": {}, 506 | "outputs": [ 507 | { 508 | "data": { 509 | "text/plain": [ 510 | "MultilayerPerceptron(\n", 511 | " (layers): Sequential(\n", 512 | " (0): LinearWithDoRA(\n", 513 | " (linear): Linear(in_features=784, out_features=128, bias=True)\n", 514 | " (lora): LoRALayer()\n", 515 | " )\n", 516 | " (1): ReLU()\n", 517 | " (2): LinearWithDoRA(\n", 518 | " (linear): Linear(in_features=128, out_features=256, bias=True)\n", 519 | " (lora): LoRALayer()\n", 520 | " )\n", 521 | " (3): ReLU()\n", 522 | " (4): LinearWithDoRA(\n", 523 | " (linear): Linear(in_features=256, out_features=10, bias=True)\n", 524 | " (lora): LoRALayer()\n", 525 | " )\n", 526 | " )\n", 527 | ")" 528 | ] 529 | }, 530 | "execution_count": 14, 531 | "metadata": {}, 532 | "output_type": "execute_result" 533 | } 534 | ], 535 | "source": [ 536 | "model_dora.layers[0] = LinearWithDoRA(model_dora.layers[0], rank=4, alpha=8)\n", 537 | "model_dora.layers[2] = LinearWithDoRA(model_dora.layers[2], rank=4, alpha=8)\n", 538 | "model_dora.layers[4] = LinearWithDoRA(model_dora.layers[4], rank=4, alpha=8)\n", 539 | "\n", 540 | "model_dora.to(DEVICE)\n", 541 | "optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)\n", 542 | "model_dora" 543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "id": "9756742d-f574-400a-8d4e-cc55233df83c", 548 | "metadata": {}, 549 | "source": [ 550 | "We just initialized the LoRA & DoRA layers but haven't trained the LoRA layers yet, so a model with and without initial LoRA weights should have the same predictive performance:" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 15, 556 | "id": "d2ac620b-2fdf-4b94-92bb-f8d00e640306", 557 | "metadata": {}, 558 | "outputs": [ 559 | { 560 | "name": "stdout", 561 | "output_type": "stream", 562 | "text": [ 563 | "Test accuracy orig model: 96.41%\n", 564 | "Test accuracy LoRA model: 96.41%\n", 565 | "Test accuracy DoRA model: 96.41%\n" 566 | ] 567 | } 568 | ], 569 | "source": [ 570 | "print(f'Test accuracy orig model: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')\n", 571 | "print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, 
DEVICE):.2f}%')\n", 572 | "print(f'Test accuracy DoRA model: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')" 573 | ] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "id": "4ceed732-7989-4f01-a5a1-1036eb41512d", 578 | "metadata": {}, 579 | "source": [ 580 | "## Train model with LoRA" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 16, 586 | "id": "a35d4c20-f754-4e82-85a1-ad19a30b3dfe", 587 | "metadata": {}, 588 | "outputs": [], 589 | "source": [ 590 | "def freeze_linear_layers(model):\n", 591 | " for child in model.children():\n", 592 | " if isinstance(child, nn.Linear):\n", 593 | " for param in child.parameters():\n", 594 | " param.requires_grad = False\n", 595 | " else:\n", 596 | " # Recursively freeze linear layers in children modules\n", 597 | " freeze_linear_layers(child)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 17, 603 | "id": "88454690-abe6-49de-986e-9a6fe7883000", 604 | "metadata": {}, 605 | "outputs": [ 606 | { 607 | "name": "stdout", 608 | "output_type": "stream", 609 | "text": [ 610 | "layers.0.linear.weight: False\n", 611 | "layers.0.linear.bias: False\n", 612 | "layers.0.lora.A: True\n", 613 | "layers.0.lora.B: True\n", 614 | "layers.2.linear.weight: False\n", 615 | "layers.2.linear.bias: False\n", 616 | "layers.2.lora.A: True\n", 617 | "layers.2.lora.B: True\n", 618 | "layers.4.linear.weight: False\n", 619 | "layers.4.linear.bias: False\n", 620 | "layers.4.lora.A: True\n", 621 | "layers.4.lora.B: True\n" 622 | ] 623 | } 624 | ], 625 | "source": [ 626 | "freeze_linear_layers(model_lora)\n", 627 | "\n", 628 | "# Check if linear layers are frozen\n", 629 | "for name, param in model_lora.named_parameters():\n", 630 | " print(f\"{name}: {param.requires_grad}\")" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 18, 636 | "id": "1b807c7b-8d4a-4a1e-8a56-42bbdbc82fed", 637 | "metadata": {}, 638 | "outputs": [ 639 | { 640 | "name": "stdout", 641 | "output_type": "stream", 642 | "text": [ 643 | "Epoch: 001/002 | Batch 000/938 | Loss: 0.0223\n", 644 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.3377\n", 645 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.2396\n", 646 | "Epoch: 001/002 training accuracy: 96.89%\n", 647 | "Time elapsed: 0.07 min\n", 648 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.3162\n", 649 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.1081\n", 650 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.1336\n", 651 | "Epoch: 002/002 training accuracy: 97.84%\n", 652 | "Time elapsed: 0.14 min\n", 653 | "Total Training Time: 0.14 min\n", 654 | "Test accuracy LoRA finetune: 96.78%\n" 655 | ] 656 | } 657 | ], 658 | "source": [ 659 | "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n", 660 | "train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)\n", 661 | "print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')" 662 | ] 663 | }, 664 | { 665 | "cell_type": "markdown", 666 | "id": "49c08f9a-3834-4ff7-8608-be29fae19487", 667 | "metadata": {}, 668 | "source": [ 669 | "## Train model with DoRA" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 19, 675 | "id": "94c86eee-91d5-4172-aae1-360ef4bc7d05", 676 | "metadata": {}, 677 | "outputs": [ 678 | { 679 | "name": "stdout", 680 | "output_type": "stream", 681 | "text": [ 682 | "layers.0.m: True\n", 683 | "layers.0.linear.weight: False\n", 684 | "layers.0.linear.bias: False\n", 685 | "layers.0.lora.A: True\n", 686 | "layers.0.lora.B: 
True\n", 687 | "layers.2.m: True\n", 688 | "layers.2.linear.weight: False\n", 689 | "layers.2.linear.bias: False\n", 690 | "layers.2.lora.A: True\n", 691 | "layers.2.lora.B: True\n", 692 | "layers.4.m: True\n", 693 | "layers.4.linear.weight: False\n", 694 | "layers.4.linear.bias: False\n", 695 | "layers.4.lora.A: True\n", 696 | "layers.4.lora.B: True\n" 697 | ] 698 | } 699 | ], 700 | "source": [ 701 | "freeze_linear_layers(model_dora)\n", 702 | "\n", 703 | "# Check if linear layers are frozen\n", 704 | "for name, param in model_dora.named_parameters():\n", 705 | " print(f\"{name}: {param.requires_grad}\")" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": 20, 711 | "id": "690f066c-d7ac-468e-9068-d57922e1582b", 712 | "metadata": {}, 713 | "outputs": [ 714 | { 715 | "name": "stdout", 716 | "output_type": "stream", 717 | "text": [ 718 | "Epoch: 001/002 | Batch 000/938 | Loss: 0.0852\n", 719 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.0059\n", 720 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.0012\n", 721 | "Epoch: 001/002 training accuracy: 98.09%\n", 722 | "Time elapsed: 0.08 min\n", 723 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.0118\n", 724 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.0086\n", 725 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.0049\n", 726 | "Epoch: 002/002 training accuracy: 98.33%\n", 727 | "Time elapsed: 0.15 min\n", 728 | "Total Training Time: 0.15 min\n", 729 | "Test accuracy DoRA finetune: 97.45%\n" 730 | ] 731 | } 732 | ], 733 | "source": [ 734 | "optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)\n", 735 | "train(num_epochs, model_dora, optimizer_dora, train_loader, DEVICE)\n", 736 | "print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "id": "38619de0-1503-4d5a-8806-33b7e64bfb87", 743 | "metadata": {}, 744 | "outputs": [], 745 | "source": [] 746 | } 747 | ], 748 | "metadata": { 749 | "kernelspec": { 750 | "display_name": "Python 3 (ipykernel)", 751 | "language": "python", 752 | "name": "python3" 753 | }, 754 | "language_info": { 755 | "codemirror_mode": { 756 | "name": "ipython", 757 | "version": 3 758 | }, 759 | "file_extension": ".py", 760 | "mimetype": "text/x-python", 761 | "name": "python", 762 | "nbconvert_exporter": "python", 763 | "pygments_lexer": "ipython3", 764 | "version": "3.10.6" 765 | } 766 | }, 767 | "nbformat": 4, 768 | "nbformat_minor": 5 769 | } 770 | -------------------------------------------------------------------------------- /Using-LinearDoRAMerged.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d2abd10e-e63e-4904-badf-5a16409503b1", 6 | "metadata": {}, 7 | "source": [ 8 | "# LoRA and DoRA from Scratch -- A Multilayer Perceptron Example\n", 9 | "\n", 10 | "## Using the LinearLoRAMerged and LinearDoRAMerged classes" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "263e27da-47c7-4030-83c6-bf5f7e8bef74", 16 | "metadata": {}, 17 | "source": [ 18 | "This code notebook illustrates how LoRA ([https://arxiv.org/abs/2106.09685](https://arxiv.org/abs/2106.09685)) and DoRA ([https://arxiv.org/abs/2402.09353](https://arxiv.org/abs/2402.09353)) work by implementing these methods from scratch.\n", 19 | "\n", 20 | "Note that this is a companion notebook to my blog article [Improving LoRA: Implementing Weight-Decomposed Low-Rank Adaptation (DoRA) from 
Scratch](https://magazine.sebastianraschka.com/p/lora-and-dora-from-scratch)." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "id": "9e4208cf-39b3-4a0b-a0e0-d7679a2d60c3", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "Author: Sebastian Raschka\n", 34 | "\n", 35 | "Python implementation: CPython\n", 36 | "Python version : 3.10.6\n", 37 | "IPython version : 8.12.0\n", 38 | "\n", 39 | "torch: 2.1.0\n", 40 | "\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "%load_ext watermark\n", 46 | "%watermark -a 'Sebastian Raschka' -v -p torch" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "id": "c1c52f02-94fb-4f45-902e-79126e27347d", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import time\n", 57 | "import numpy as np\n", 58 | "from torchvision import datasets\n", 59 | "from torchvision import transforms\n", 60 | "from torch.utils.data import DataLoader\n", 61 | "import torch.nn.functional as F\n", 62 | "import torch.nn as nn\n", 63 | "import torch\n", 64 | "\n", 65 | "\n", 66 | "if torch.cuda.is_available():\n", 67 | " torch.backends.cudnn.deterministic = True" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "629ec66a-eb81-40a5-ae3d-d5c1d2a7e390", 73 | "metadata": {}, 74 | "source": [ 75 | "## Settings and Dataset" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "id": "4ade5e86-8bd8-4a35-8db1-44451601b292", 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "Image batch dimensions: torch.Size([64, 1, 28, 28])\n", 89 | "Image label dimensions: torch.Size([64])\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "##########################\n", 95 | "### SETTINGS\n", 96 | "##########################\n", 97 | "\n", 98 | "# Device\n", 99 | "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 100 | "BATCH_SIZE = 64\n", 101 | "\n", 102 | "##########################\n", 103 | "### MNIST DATASET\n", 104 | "##########################\n", 105 | "\n", 106 | "# Note transforms.ToTensor() scales input images\n", 107 | "# to 0-1 range\n", 108 | "train_dataset = datasets.MNIST(root='data', \n", 109 | " train=True, \n", 110 | " transform=transforms.ToTensor(),\n", 111 | " download=True)\n", 112 | "\n", 113 | "test_dataset = datasets.MNIST(root='data', \n", 114 | " train=False, \n", 115 | " transform=transforms.ToTensor())\n", 116 | "\n", 117 | "\n", 118 | "train_loader = DataLoader(dataset=train_dataset, \n", 119 | " batch_size=BATCH_SIZE, \n", 120 | " shuffle=True)\n", 121 | "\n", 122 | "test_loader = DataLoader(dataset=test_dataset, \n", 123 | " batch_size=BATCH_SIZE, \n", 124 | " shuffle=False)\n", 125 | "\n", 126 | "# Checking the dataset\n", 127 | "for images, labels in train_loader: \n", 128 | " print('Image batch dimensions:', images.shape)\n", 129 | " print('Image label dimensions:', labels.shape)\n", 130 | " break" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "394e5da8-2978-40f0-bca7-b79e8e35734f", 136 | "metadata": {}, 137 | "source": [ 138 | "# Multilayer Perceptron Model (Without LoRA and DoRA)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 4, 144 | "id": "7e905c42-7f59-4a08-b6c5-a10f99f33e9e", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "##########################\n", 149 | "### MODEL\n", 150 | "##########################\n", 151 | "\n", 152 
| "# Hyperparameters\n", 153 | "random_seed = 123\n", 154 | "learning_rate = 0.005\n", 155 | "num_epochs = 2\n", 156 | "\n", 157 | "# Architecture\n", 158 | "num_features = 784\n", 159 | "num_hidden_1 = 128\n", 160 | "num_hidden_2 = 256\n", 161 | "num_classes = 10\n", 162 | "\n", 163 | "\n", 164 | "class MultilayerPerceptron(nn.Module):\n", 165 | "\n", 166 | " def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):\n", 167 | " super().__init__()\n", 168 | "\n", 169 | " self.layers = nn.Sequential(\n", 170 | " nn.Linear(num_features, num_hidden_1),\n", 171 | " nn.ReLU(),\n", 172 | " nn.Linear(num_hidden_1, num_hidden_2),\n", 173 | " nn.ReLU(),\n", 174 | " nn.Linear(num_hidden_2, num_classes)\n", 175 | " )\n", 176 | "\n", 177 | " def forward(self, x):\n", 178 | " x = self.layers(x)\n", 179 | " return x\n", 180 | "\n", 181 | "\n", 182 | "torch.manual_seed(random_seed)\n", 183 | "model_pretrained = MultilayerPerceptron(\n", 184 | " num_features=num_features,\n", 185 | " num_hidden_1=num_hidden_1,\n", 186 | " num_hidden_2=num_hidden_2, \n", 187 | " num_classes=num_classes\n", 188 | ")\n", 189 | "\n", 190 | "model_pretrained.to(DEVICE)\n", 191 | "optimizer_pretrained = torch.optim.Adam(model_pretrained.parameters(), lr=learning_rate)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 5, 197 | "id": "cf31624a-d950-402f-a564-2e7fb63db8a4", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "def compute_accuracy(model, data_loader, device):\n", 202 | " model.eval()\n", 203 | " correct_pred, num_examples = 0, 0\n", 204 | " with torch.no_grad():\n", 205 | " for features, targets in data_loader:\n", 206 | " features = features.view(-1, 28*28).to(device)\n", 207 | " targets = targets.to(device)\n", 208 | " logits = model(features)\n", 209 | " _, predicted_labels = torch.max(logits, 1)\n", 210 | " num_examples += targets.size(0)\n", 211 | " correct_pred += (predicted_labels == targets).sum()\n", 212 | " return correct_pred.float()/num_examples * 100\n", 213 | "\n", 214 | "\n", 215 | "def train(num_epochs, model, optimizer, train_loader, device):\n", 216 | "\n", 217 | " start_time = time.time()\n", 218 | " for epoch in range(num_epochs):\n", 219 | " model.train()\n", 220 | " for batch_idx, (features, targets) in enumerate(train_loader):\n", 221 | "\n", 222 | " features = features.view(-1, 28*28).to(device)\n", 223 | " targets = targets.to(device)\n", 224 | "\n", 225 | " # FORWARD AND BACK PROP\n", 226 | " logits = model(features)\n", 227 | " loss = F.cross_entropy(logits, targets)\n", 228 | " optimizer.zero_grad()\n", 229 | "\n", 230 | " loss.backward()\n", 231 | "\n", 232 | " # UPDATE MODEL PARAMETERS\n", 233 | " optimizer.step()\n", 234 | "\n", 235 | " # LOGGING\n", 236 | " if not batch_idx % 400:\n", 237 | " print('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f'\n", 238 | " % (epoch+1, num_epochs, batch_idx,\n", 239 | " len(train_loader), loss))\n", 240 | "\n", 241 | " with torch.set_grad_enabled(False):\n", 242 | " print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n", 243 | " epoch+1, num_epochs,\n", 244 | " compute_accuracy(model, train_loader, device)))\n", 245 | "\n", 246 | " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", 247 | "\n", 248 | " print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 6, 254 | "id": "f47cfe4e-65eb-440e-b922-17c6dee7d7e2", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": 
"stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "Epoch: 001/002 | Batch 000/938 | Loss: 2.2971\n", 262 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.1770\n", 263 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.1582\n", 264 | "Epoch: 001/002 training accuracy: 95.62%\n", 265 | "Time elapsed: 0.08 min\n", 266 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.0501\n", 267 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.0408\n", 268 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.0828\n", 269 | "Epoch: 002/002 training accuracy: 97.22%\n", 270 | "Time elapsed: 0.16 min\n", 271 | "Total Training Time: 0.16 min\n", 272 | "Test accuracy: 96.41%\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "train(num_epochs, model_pretrained, optimizer_pretrained, train_loader, DEVICE)\n", 278 | "print(f'Test accuracy: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "id": "fb3480b9-aea5-411e-b252-d7fc8a5dd21d", 284 | "metadata": {}, 285 | "source": [ 286 | "# Multilayer Perceptron with LoRA and DoRA" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "id": "36b9d281-22ba-4120-af95-f6b95adcaa03", 292 | "metadata": {}, 293 | "source": [ 294 | "## Modify model by injecting LoRA and DoRA layers" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 7, 300 | "id": "215795c5-c0d4-4886-b4d6-a5a0e7cc8c7e", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "class LoRALayer(nn.Module):\n", 305 | " def __init__(self, in_dim, out_dim, rank, alpha):\n", 306 | " super().__init__()\n", 307 | " std_dev = 1 / torch.sqrt(torch.tensor(rank).float())\n", 308 | " self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n", 309 | " self.B = nn.Parameter(torch.zeros(rank, out_dim))\n", 310 | " self.alpha = alpha\n", 311 | "\n", 312 | " def forward(self, x):\n", 313 | " x = self.alpha * (x @ self.A @ self.B)\n", 314 | " return x\n", 315 | "\n", 316 | " \n", 317 | "# This LoRA code is equivalent to LinearWithLoRA\n", 318 | "class LinearWithLoRAMerged(nn.Module):\n", 319 | " def __init__(self, linear, rank, alpha):\n", 320 | " super().__init__()\n", 321 | " self.linear = linear\n", 322 | " self.lora = LoRALayer(\n", 323 | " linear.in_features, linear.out_features, rank, alpha\n", 324 | " )\n", 325 | "\n", 326 | " def forward(self, x):\n", 327 | " lora = self.lora.A @ self.lora.B\n", 328 | " combined_weight = self.linear.weight + self.lora.alpha*lora.T\n", 329 | " return F.linear(x, combined_weight, self.linear.bias)\n", 330 | "\n", 331 | " \n", 332 | "# This DoRA code is equivalent to LinearWithDoRA\n", 333 | "# Code inspired by https://github.com/catid/dora/blob/main/dora.py\n", 334 | "class LinearWithDoRAMerged(nn.Module):\n", 335 | " def __init__(self, linear, rank, alpha):\n", 336 | " super().__init__()\n", 337 | " self.linear = linear\n", 338 | " self.lora = LoRALayer(\n", 339 | " linear.in_features, linear.out_features, rank, alpha\n", 340 | " )\n", 341 | " \n", 342 | " self.m = nn.Parameter(\n", 343 | " self.linear.weight.norm(p=2, dim=0, keepdim=True))\n", 344 | "\n", 345 | " def forward(self, x):\n", 346 | " lora = self.lora.A @ self.lora.B\n", 347 | " numerator = self.linear.weight + self.lora.alpha*lora.T\n", 348 | " denominator = numerator.norm(p=2, dim=0, keepdim=True)\n", 349 | " directional_component = numerator / denominator\n", 350 | " new_weight = self.m * directional_component\n", 351 | " return F.linear(x, new_weight, self.linear.bias)" 352 | ] 353 | }, 354 | { 355 | "cell_type": 
"markdown", 356 | "id": "06cfda10-ac5a-4958-ad42-910ba73aa639", 357 | "metadata": {}, 358 | "source": [ 359 | "Since the B matrix is initialized to 0's, the initial LoRA and DoRA layers (before training) should not affect the outputs of the forward pass, which we can confirm as follows:" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 8, 365 | "id": "0441c93f-0ee5-4003-acc3-f24541f06c66", 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "Original output: tensor([[0.6639, 0.4487]], grad_fn=)\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "torch.manual_seed(123)\n", 378 | "\n", 379 | "layer = nn.Linear(10, 2)\n", 380 | "x = torch.randn((1, 10))\n", 381 | "\n", 382 | "print(\"Original output:\", layer(x))" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 9, 388 | "id": "f555c364-9c5f-4a20-8c8c-feb830131555", 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "LoRA output: tensor([[0.6639, 0.4487]], grad_fn=)\n" 396 | ] 397 | } 398 | ], 399 | "source": [ 400 | "layer_lora_2 = LinearWithLoRAMerged(layer, rank=2, alpha=4)\n", 401 | "print(\"LoRA output:\", layer_lora_2(x))" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 10, 407 | "id": "5356918c-afdb-4bdd-8745-97ba33a9fc86", 408 | "metadata": {}, 409 | "outputs": [ 410 | { 411 | "name": "stdout", 412 | "output_type": "stream", 413 | "text": [ 414 | "DoRA output: tensor([[0.6639, 0.4487]], grad_fn=)\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "layer_dora_2 = LinearWithDoRAMerged(layer, rank=2, alpha=4)\n", 420 | "\n", 421 | "print(\"DoRA output:\", layer_dora_2(x))" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 11, 427 | "id": "dc66ffa1-5822-4833-b636-d3a8170e84a2", 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "data": { 432 | "text/plain": [ 433 | "MultilayerPerceptron(\n", 434 | " (layers): Sequential(\n", 435 | " (0): Linear(in_features=784, out_features=128, bias=True)\n", 436 | " (1): ReLU()\n", 437 | " (2): Linear(in_features=128, out_features=256, bias=True)\n", 438 | " (3): ReLU()\n", 439 | " (4): Linear(in_features=256, out_features=10, bias=True)\n", 440 | " )\n", 441 | ")" 442 | ] 443 | }, 444 | "execution_count": 11, 445 | "metadata": {}, 446 | "output_type": "execute_result" 447 | } 448 | ], 449 | "source": [ 450 | "model_pretrained" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 12, 456 | "id": "b00a7e8f-09ff-499e-b593-3f6dac87f1bf", 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "import copy\n", 461 | "\n", 462 | "model_lora = copy.deepcopy(model_pretrained)\n", 463 | "model_dora = copy.deepcopy(model_pretrained)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 13, 469 | "id": "b1e3ef6f-255c-4c71-9da5-7d06d8c439a3", 470 | "metadata": {}, 471 | "outputs": [ 472 | { 473 | "data": { 474 | "text/plain": [ 475 | "MultilayerPerceptron(\n", 476 | " (layers): Sequential(\n", 477 | " (0): LinearWithLoRAMerged(\n", 478 | " (linear): Linear(in_features=784, out_features=128, bias=True)\n", 479 | " (lora): LoRALayer()\n", 480 | " )\n", 481 | " (1): ReLU()\n", 482 | " (2): LinearWithLoRAMerged(\n", 483 | " (linear): Linear(in_features=128, out_features=256, bias=True)\n", 484 | " (lora): LoRALayer()\n", 485 | " )\n", 486 | " (3): ReLU()\n", 487 | " (4): LinearWithLoRAMerged(\n", 
488 | " (linear): Linear(in_features=256, out_features=10, bias=True)\n", 489 | " (lora): LoRALayer()\n", 490 | " )\n", 491 | " )\n", 492 | ")" 493 | ] 494 | }, 495 | "execution_count": 13, 496 | "metadata": {}, 497 | "output_type": "execute_result" 498 | } 499 | ], 500 | "source": [ 501 | "model_lora.layers[0] = LinearWithLoRAMerged(model_lora.layers[0], rank=4, alpha=8)\n", 502 | "model_lora.layers[2] = LinearWithLoRAMerged(model_lora.layers[2], rank=4, alpha=8)\n", 503 | "model_lora.layers[4] = LinearWithLoRAMerged(model_lora.layers[4], rank=4, alpha=8)\n", 504 | "\n", 505 | "model_lora.to(DEVICE)\n", 506 | "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n", 507 | "model_lora" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": 14, 513 | "id": "ec5a4ef4-bfc0-4fdb-bf7d-114aaf89795a", 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "data": { 518 | "text/plain": [ 519 | "MultilayerPerceptron(\n", 520 | " (layers): Sequential(\n", 521 | " (0): LinearWithDoRAMerged(\n", 522 | " (linear): Linear(in_features=784, out_features=128, bias=True)\n", 523 | " (lora): LoRALayer()\n", 524 | " )\n", 525 | " (1): ReLU()\n", 526 | " (2): LinearWithDoRAMerged(\n", 527 | " (linear): Linear(in_features=128, out_features=256, bias=True)\n", 528 | " (lora): LoRALayer()\n", 529 | " )\n", 530 | " (3): ReLU()\n", 531 | " (4): LinearWithDoRAMerged(\n", 532 | " (linear): Linear(in_features=256, out_features=10, bias=True)\n", 533 | " (lora): LoRALayer()\n", 534 | " )\n", 535 | " )\n", 536 | ")" 537 | ] 538 | }, 539 | "execution_count": 14, 540 | "metadata": {}, 541 | "output_type": "execute_result" 542 | } 543 | ], 544 | "source": [ 545 | "model_dora.layers[0] = LinearWithDoRAMerged(model_dora.layers[0], rank=4, alpha=8)\n", 546 | "model_dora.layers[2] = LinearWithDoRAMerged(model_dora.layers[2], rank=4, alpha=8)\n", 547 | "model_dora.layers[4] = LinearWithDoRAMerged(model_dora.layers[4], rank=4, alpha=8)\n", 548 | "\n", 549 | "model_dora.to(DEVICE)\n", 550 | "optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)\n", 551 | "model_dora" 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "id": "9756742d-f574-400a-8d4e-cc55233df83c", 557 | "metadata": {}, 558 | "source": [ 559 | "We just initialized the LoRA & DoRA layers but haven't trained the LoRA layers yet, so a model with and without initial LoRA weights should have the same predictive performance:" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 15, 565 | "id": "d2ac620b-2fdf-4b94-92bb-f8d00e640306", 566 | "metadata": {}, 567 | "outputs": [ 568 | { 569 | "name": "stdout", 570 | "output_type": "stream", 571 | "text": [ 572 | "Test accuracy orig model: 96.41%\n", 573 | "Test accuracy LoRA model: 96.41%\n", 574 | "Test accuracy DoRA model: 96.41%\n" 575 | ] 576 | } 577 | ], 578 | "source": [ 579 | "print(f'Test accuracy orig model: {compute_accuracy(model_pretrained, test_loader, DEVICE):.2f}%')\n", 580 | "print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')\n", 581 | "print(f'Test accuracy DoRA model: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')" 582 | ] 583 | }, 584 | { 585 | "cell_type": "markdown", 586 | "id": "4ceed732-7989-4f01-a5a1-1036eb41512d", 587 | "metadata": {}, 588 | "source": [ 589 | "## Train model with LoRA" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": 16, 595 | "id": "a35d4c20-f754-4e82-85a1-ad19a30b3dfe", 596 | "metadata": {}, 597 | 
"outputs": [], 598 | "source": [ 599 | "def freeze_linear_layers(model):\n", 600 | " for child in model.children():\n", 601 | " if isinstance(child, nn.Linear):\n", 602 | " for param in child.parameters():\n", 603 | " param.requires_grad = False\n", 604 | " else:\n", 605 | " # Recursively freeze linear layers in children modules\n", 606 | " freeze_linear_layers(child)" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 17, 612 | "id": "88454690-abe6-49de-986e-9a6fe7883000", 613 | "metadata": {}, 614 | "outputs": [ 615 | { 616 | "name": "stdout", 617 | "output_type": "stream", 618 | "text": [ 619 | "layers.0.linear.weight: False\n", 620 | "layers.0.linear.bias: False\n", 621 | "layers.0.lora.A: True\n", 622 | "layers.0.lora.B: True\n", 623 | "layers.2.linear.weight: False\n", 624 | "layers.2.linear.bias: False\n", 625 | "layers.2.lora.A: True\n", 626 | "layers.2.lora.B: True\n", 627 | "layers.4.linear.weight: False\n", 628 | "layers.4.linear.bias: False\n", 629 | "layers.4.lora.A: True\n", 630 | "layers.4.lora.B: True\n" 631 | ] 632 | } 633 | ], 634 | "source": [ 635 | "freeze_linear_layers(model_lora)\n", 636 | "\n", 637 | "# Check if linear layers are frozen\n", 638 | "for name, param in model_lora.named_parameters():\n", 639 | " print(f\"{name}: {param.requires_grad}\")" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": 18, 645 | "id": "1b807c7b-8d4a-4a1e-8a56-42bbdbc82fed", 646 | "metadata": {}, 647 | "outputs": [ 648 | { 649 | "name": "stdout", 650 | "output_type": "stream", 651 | "text": [ 652 | "Epoch: 001/002 | Batch 000/938 | Loss: 0.0223\n", 653 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.1485\n", 654 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.3456\n", 655 | "Epoch: 001/002 training accuracy: 97.20%\n", 656 | "Time elapsed: 0.07 min\n", 657 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.3768\n", 658 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.0851\n", 659 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.1195\n", 660 | "Epoch: 002/002 training accuracy: 97.85%\n", 661 | "Time elapsed: 0.15 min\n", 662 | "Total Training Time: 0.15 min\n", 663 | "Test accuracy LoRA finetune: 96.93%\n" 664 | ] 665 | } 666 | ], 667 | "source": [ 668 | "optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)\n", 669 | "train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)\n", 670 | "print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')" 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "id": "49c08f9a-3834-4ff7-8608-be29fae19487", 676 | "metadata": {}, 677 | "source": [ 678 | "## Train model with DoRA" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 19, 684 | "id": "94c86eee-91d5-4172-aae1-360ef4bc7d05", 685 | "metadata": {}, 686 | "outputs": [ 687 | { 688 | "name": "stdout", 689 | "output_type": "stream", 690 | "text": [ 691 | "layers.0.m: True\n", 692 | "layers.0.linear.weight: False\n", 693 | "layers.0.linear.bias: False\n", 694 | "layers.0.lora.A: True\n", 695 | "layers.0.lora.B: True\n", 696 | "layers.2.m: True\n", 697 | "layers.2.linear.weight: False\n", 698 | "layers.2.linear.bias: False\n", 699 | "layers.2.lora.A: True\n", 700 | "layers.2.lora.B: True\n", 701 | "layers.4.m: True\n", 702 | "layers.4.linear.weight: False\n", 703 | "layers.4.linear.bias: False\n", 704 | "layers.4.lora.A: True\n", 705 | "layers.4.lora.B: True\n" 706 | ] 707 | } 708 | ], 709 | "source": [ 710 | "freeze_linear_layers(model_dora)\n", 711 | "\n", 712 | "# 
Check if linear layers are frozen\n", 713 | "for name, param in model_dora.named_parameters():\n", 714 | " print(f\"{name}: {param.requires_grad}\")" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": 20, 720 | "id": "690f066c-d7ac-468e-9068-d57922e1582b", 721 | "metadata": {}, 722 | "outputs": [ 723 | { 724 | "name": "stdout", 725 | "output_type": "stream", 726 | "text": [ 727 | "Epoch: 001/002 | Batch 000/938 | Loss: 0.0852\n", 728 | "Epoch: 001/002 | Batch 400/938 | Loss: 0.0144\n", 729 | "Epoch: 001/002 | Batch 800/938 | Loss: 0.0036\n", 730 | "Epoch: 001/002 training accuracy: 97.83%\n", 731 | "Time elapsed: 0.10 min\n", 732 | "Epoch: 002/002 | Batch 000/938 | Loss: 0.0393\n", 733 | "Epoch: 002/002 | Batch 400/938 | Loss: 0.0488\n", 734 | "Epoch: 002/002 | Batch 800/938 | Loss: 0.0133\n", 735 | "Epoch: 002/002 training accuracy: 98.12%\n", 736 | "Time elapsed: 0.20 min\n", 737 | "Total Training Time: 0.20 min\n", 738 | "Test accuracy DoRA finetune: 97.12%\n" 739 | ] 740 | } 741 | ], 742 | "source": [ 743 | "optimizer_dora = torch.optim.Adam(model_dora.parameters(), lr=learning_rate)\n", 744 | "train(num_epochs, model_dora, optimizer_dora, train_loader, DEVICE)\n", 745 | "print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora, test_loader, DEVICE):.2f}%')" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": null, 751 | "id": "38619de0-1503-4d5a-8806-33b7e64bfb87", 752 | "metadata": {}, 753 | "outputs": [], 754 | "source": [] 755 | } 756 | ], 757 | "metadata": { 758 | "kernelspec": { 759 | "display_name": "Python 3 (ipykernel)", 760 | "language": "python", 761 | "name": "python3" 762 | }, 763 | "language_info": { 764 | "codemirror_mode": { 765 | "name": "ipython", 766 | "version": 3 767 | }, 768 | "file_extension": ".py", 769 | "mimetype": "text/x-python", 770 | "name": "python", 771 | "nbconvert_exporter": "python", 772 | "pygments_lexer": "ipython3", 773 | "version": "3.10.6" 774 | } 775 | }, 776 | "nbformat": 4, 777 | "nbformat_minor": 5 778 | } 779 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0 2 | torchvision>=0.16.0 3 | watermark>=2.3.1 --------------------------------------------------------------------------------
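
A closing note on merging: `LinearWithLoRAMerged` and `LinearWithDoRAMerged` recompute the combined weight `W + alpha * (A @ B)^T` on every forward pass. After training, the same combination can be folded into the pretrained layer once, so inference runs on a plain `nn.Linear`. A minimal sketch under that assumption (`merge_lora` is a hypothetical helper, not defined in the notebooks):

```python
import torch


@torch.no_grad()
def merge_lora(layer_with_lora):
    # Hypothetical helper: folds a trained low-rank update into the wrapped
    # nn.Linear, mirroring the combination used in LinearWithLoRAMerged:
    #     W' = W + alpha * (A @ B)^T
    linear = layer_with_lora.linear
    lora = layer_with_lora.lora
    linear.weight += lora.alpha * (lora.A @ lora.B).T
    return linear


# Example (after training): replace an adapted layer with an equivalent plain layer
# model_lora.layers[0] = merge_lora(model_lora.layers[0])
```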