├── loratorch ├── __init__.py ├── utils.py └── layers.py ├── setup.py ├── LICENSE ├── examples ├── embedding.ipynb ├── linear.ipynb ├── mergedlinear.ipynb └── Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb └── README.md /loratorch/__init__.py: -------------------------------------------------------------------------------- 1 | name = "lora" 2 | 3 | from .layers import * 4 | from .utils import * -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="loratorch", 8 | version="0.1.0", 9 | author="Baijiong Lin", 10 | author_email="bj.lin.email@gmail.com", 11 | description="PyTorch reimplementation of low-rank adaptation (LoRA).", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/Baijiong-Lin/LoRA-Torch", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.6', 22 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Baijiong Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The 
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /loratorch/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin. 
3 | # ------------------------------------------------------------------------------------------ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from typing import Dict 8 | 9 | from .layers import LoRALayer 10 | 11 | def register_model_param_after_backward(model: nn.Module) -> None: 12 | for m in model.modules(): 13 | if isinstance(m, LoRALayer): 14 | m.register_weight_after_backward() 15 | 16 | def print_trainable_parameters(model): 17 | r"""Prints the number of trainable parameters in the model.""" 18 | trainable_params = 0 19 | all_param = 0 20 | for _, param in model.named_parameters(): 21 | all_param += param.numel() 22 | if param.requires_grad: 23 | trainable_params += param.numel() 24 | print( 25 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}" 26 | ) 27 | 28 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: 29 | for n, p in model.named_parameters(): 30 | if 'lora_' not in n: 31 | p.requires_grad = False 32 | if bias == 'none': 33 | return 34 | elif bias == 'all': 35 | for n, p in model.named_parameters(): 36 | if 'bias' in n: 37 | p.requires_grad = True 38 | elif bias == 'lora_only': 39 | for m in model.modules(): 40 | if isinstance(m, LoRALayer) and \ 41 | hasattr(m, 'bias') and \ 42 | m.bias is not None: 43 | m.bias.requires_grad = True 44 | else: 45 | raise NotImplementedError 46 | print_trainable_parameters(model) 47 | 48 | 49 | def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: 50 | my_state_dict = model.state_dict() 51 | if bias == 'none': 52 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} 53 | elif bias == 'all': 54 | return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k} 55 | elif bias == 'lora_only': 56 | to_return = {} 57 | for k in my_state_dict: 58 | if 'lora_' in k: 59 | to_return[k] = my_state_dict[k] 60 | bias_name = k.split('lora_')[0]+'bias' 61 | if 
bias_name in my_state_dict: 62 | to_return[bias_name] = my_state_dict[bias_name] 63 | return to_return 64 | else: 65 | raise NotImplementedError 66 | -------------------------------------------------------------------------------- /examples/embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyNh+WHHXHiw1v/WeZj1rcht", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "colab": { 34 | "base_uri": "https://localhost:8080/" 35 | }, 36 | "id": "330HoZ7N_S3x", 37 | "outputId": "1d7f8b83-b37f-4b55-ad77-aa099b236083" 38 | }, 39 | "outputs": [ 40 | { 41 | "output_type": "stream", 42 | "name": "stdout", 43 | "text": [ 44 | "Collecting git+https://github.com/microsoft/LoRA\n", 45 | " Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-qrvzzdp0\n", 46 | " Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-qrvzzdp0\n", 47 | " Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n", 48 | " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 49 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n", 50 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-ro2rj2g8\n", 51 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-ro2rj2g8\n", 52 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 34286e640a4d15dfe23c361f6d95b7cc55a8ec6b\n", 53 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!pip install git+https://github.com/microsoft/LoRA\n", 59 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "source": [ 65 | "import torch, loralib, loratorch, copy\n", 66 | "import torch.nn as nn" 67 | ], 68 | "metadata": { 69 | "id": "cqadI5Fu_WUX" 70 | }, 71 | "execution_count": 2, 72 | "outputs": [] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "source": [ 77 | "model_lib = loralib.Embedding(10, 3, r=4)\n", 78 | "\n", 79 | "for k, v in model_lib.named_parameters():\n", 80 | " print(k, v.size())\n", 81 | "\n", 82 | "\n", 83 | "loralib.mark_only_lora_as_trainable(model_lib)\n", 84 | "\n", 85 | "optimizer_lib = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n", 86 | "\n", 87 | "for _ in range(3):\n", 88 | " model_lib.train()\n", 89 | "\n", 90 | " x = torch.rand(5, 10).long()\n", 91 | "\n", 92 | " loss_lib = model_lib(x).sum()\n", 93 | " optimizer_lib.zero_grad()\n", 94 | " loss_lib.backward()\n", 95 | " optimizer_lib.step()\n", 96 | "\n", 97 | " x_test = torch.rand(*x.size()).long()\n", 98 | " model_lib.eval()\n", 99 | " print(model_lib(x_test).size())" 100 | ], 101 | "metadata": { 102 | "colab": { 103 | "base_uri": "https://localhost:8080/" 104 | }, 105 | "id": "91GlkwcQD7YG", 106 | "outputId": "e87d2630-4ba5-456e-c8b3-63c25b36a9cb" 107 | }, 108 | "execution_count": 3, 109 | "outputs": [ 110 | { 111 | "output_type": "stream", 112 | "name": "stdout", 
113 | "text": [ 114 | "weight torch.Size([10, 3])\n", 115 | "lora_A torch.Size([4, 10])\n", 116 | "lora_B torch.Size([3, 4])\n", 117 | "torch.Size([5, 10, 3])\n", 118 | "torch.Size([5, 10, 3])\n", 119 | "torch.Size([5, 10, 3])\n" 120 | ] 121 | } 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "source": [ 127 | "model_torch = loratorch.Embedding(10, 3, r=4)\n", 128 | "\n", 129 | "for k, v in model_torch.named_parameters():\n", 130 | " print(k, v.size())\n", 131 | "\n", 132 | "\n", 133 | "loratorch.mark_only_lora_as_trainable(model_torch)\n", 134 | "\n", 135 | "optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n", 136 | "\n", 137 | "for _ in range(3):\n", 138 | " model_torch.train()\n", 139 | "\n", 140 | " x = torch.rand(5, 10).long()\n", 141 | "\n", 142 | " loss_torch = model_torch(x).sum()\n", 143 | " optimizer_torch.zero_grad()\n", 144 | " loss_torch.backward()\n", 145 | " optimizer_torch.step()\n", 146 | "\n", 147 | " x_test = torch.rand(*x.size()).long()\n", 148 | " model_torch.eval()\n", 149 | " print(model_torch(x_test).size())" 150 | ], 151 | "metadata": { 152 | "colab": { 153 | "base_uri": "https://localhost:8080/" 154 | }, 155 | "id": "NsLzmMXSC3RQ", 156 | "outputId": "48c1e55c-1236-42c5-ac87-c3d55e6bcf87" 157 | }, 158 | "execution_count": 4, 159 | "outputs": [ 160 | { 161 | "output_type": "stream", 162 | "name": "stdout", 163 | "text": [ 164 | "weight torch.Size([10, 3])\n", 165 | "w_lora_A torch.Size([4, 3])\n", 166 | "w_lora_B torch.Size([10, 4])\n", 167 | "torch.Size([5, 10, 3])\n", 168 | "torch.Size([5, 10, 3])\n", 169 | "torch.Size([5, 10, 3])\n" 170 | ] 171 | } 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "source": [], 177 | "metadata": { 178 | "id": "q7yt9fq1GrnA" 179 | }, 180 | "execution_count": 4, 181 | "outputs": [] 182 | } 183 | ] 184 | } -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # LoRA-Torch 2 | 3 | [![Made With Love](https://img.shields.io/badge/Made%20With-Love-orange.svg)](https://github.com/Baijiong-Lin/LoRA-Torch) 4 | 5 | This codebase reimplements [LoRA: Low-Rank Adaptation of Large Language Models (ICLR 2022)](https://openreview.net/forum?id=nZeVKeeFYf9) and is reconstructed based on [loralib](https://github.com/microsoft/LoRA). 6 | 7 | 8 | 9 | ## Features 10 | 11 | **The implementations of ``loratorch`` and ``loralib`` are very different.** We take ``nn.Linear`` as an example. 12 | 13 | 1. For ``loralib``, 14 | $h = x W_0^\top + \frac{\alpha}{r} x(BA)^\top,$ 15 | 16 | where $x\in\mathbb{R}^{k\times n}$ is the input matrix, $W_0\in\mathbb{R}^{m\times n}$ is the pre-trained weight matrix, $r$ is the predefined LoRA rank, $B\in\mathbb{R}^{m\times r}$ and $A\in \mathbb{R}^{r\times n}$ are the LoRA matrices, and $\alpha$ is a hyper-parameter. 17 | 18 | 2. For ``loratorch``, 19 | $h = x (W_0 + \frac{\alpha}{r} BA)^\top.$ 20 | 21 | 22 | 23 | ``loralib`` computes $xW_0^\top$ and $x(BA)^\top$ separately and then sums the results, whereas ``loratorch`` first merges the pre-trained weight $W_0$ with its LoRA weight $BA$ and then computes the output by simply calling ``nn.Linear.forward()``. For linear layers the two approaches are equivalent. However, for non-linear or more complex layers it is not guaranteed that a layer $L$ satisfies $L(x, W_0)+L(x, BA) = L(x, W_0+BA)$, so it is difficult to extend ``loralib`` to such layers. In contrast, merging the weights first, as ``loratorch`` does, is more general and extensible: you only need to call ``merge_lora_param()`` in ``loratorch`` to merge the weights and then call ``forward()`` of the original layer to compute the output. With ``loratorch``, you can easily apply LoRA to any type of layer in ``torch.nn``.
24 | 25 | 26 | 27 | ## Supported Layers 28 | 29 | | | ``loralib`` | ``loratorch`` | | 30 | | ------------------------- |:--------------:|:--------------:| -------------------------------------------------- | 31 | | ``nn.Linear`` | ✓ | ✓ | [linear.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/linear.ipynb) | 32 | | ``nn.Embedding`` | ✓ | ✓ | [embedding.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/embedding.ipynb) | 33 | | ``nn.Conv1d`` | ✓ | ✓ | | 34 | | ``nn.Conv2d`` | ✓ | ✓ | | 35 | | ``nn.Conv3d`` | ✓ | ✓ | | 36 | | ``nn.MultiheadAttention`` | ✘ | ✓ | [Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb) | 37 | | ``MergedLinear`` | ✓ (Error) | ✓ | [mergedlinear.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/mergedlinear.ipynb) | 38 | | $\cdots$ | hard to extend | easy to extend | | 39 | 40 | *We compare the results of ``loralib`` and ``loratorch`` in [examples](./examples) to demonstrate the correctness of the implementation in ``loratorch``.* 41 | 42 | 43 | 44 | ## Quick Start 45 | 46 | :bangbang: We have provided an [example](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb) to demonstrate how to apply LoRA-Torch to ``nn.MultiheadAttention`` in OpenCLIP. We greatly appreciate [@vietvo89](https://github.com/vietvo89)'s valuable contribution. 47 | 48 | **The usage of ``loratorch`` is the same as that of ``loralib``.** 49 | 50 | 1. Install ``loratorch``. 51 | 52 | ```bash 53 | pip install git+https://github.com/Baijiong-Lin/LoRA-Torch 54 | # Alternatively for developers 55 | # git clone https://github.com/Baijiong-Lin/LoRA-Torch 56 | # cd LoRA-Torch 57 | # pip install -e . 58 | ``` 59 | 60 | 2. Replace the layers you want to adapt with LoRA by their ``loratorch`` counterparts.
61 | 62 | ```python 63 | # ===== Before ===== 64 | # layer = nn.Linear(in_features, out_features) 65 | 66 | # ===== After ====== 67 | import loratorch as lora 68 | # Add a pair of low-rank adaptation matrices with rank r=16 and alpha=32 69 | layer = lora.Linear(in_features, out_features, r=16, lora_alpha=32) 70 | ``` 71 | 72 | 3. Mark only LoRA parameters as trainable before the training loop. 73 | 74 | ```python 75 | model = Model() 76 | # (!!!) This sets requires_grad to False for all parameters without the string "lora_" in their names 77 | lora.mark_only_lora_as_trainable(model) 78 | 79 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 80 | # Training loop 81 | for batch in dataloader: 82 | model.train() 83 | # forward process 84 | loss = forward_fun(model, batch) 85 | # backward process 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | # (!!!) Re-register the model parameters so that they appear in model.state_dict() and model.parameters() 90 | # (!!!) Without this line, model performance is not affected, but some weights will be missing from model.state_dict() and model.parameters() 91 | lora.register_model_param_after_backward(model) 92 | ``` 93 | 94 | 4. Save the LoRA model (only the LoRA matrices are saved). 95 | 96 | ```python 97 | # ===== Before ===== 98 | # torch.save(model.state_dict(), checkpoint_path) 99 | # ===== After ===== 100 | torch.save(lora.lora_state_dict(model), checkpoint_path) 101 | ``` 102 | 103 | 5. Load the LoRA model (the pre-trained model must be loaded first). 104 | 105 | ```python 106 | # Load the pre-trained checkpoint first 107 | model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False) 108 | # Then load the LoRA checkpoint 109 | model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False) 110 | ``` 111 | 112 | ## Contributor 113 | 114 | ``loratorch`` is developed and maintained by [Baijiong Lin](https://baijiong-lin.github.io).
115 | 116 | ## Contact Us 117 | 118 | If you have any questions or suggestions, please feel free to contact us by [raising an issue](https://github.com/Baijiong-Lin/LoRA-Torch/issues) or sending an email to ``bj.lin.email@gmail.com``. 119 | 120 | ## Acknowledgements 121 | 122 | ``loratorch`` is heavily based on ``loralib``. We thank its authors for their wonderful open-source codebase. 123 | 124 | ## Citation 125 | 126 | If you find ``loratorch`` useful for your research or development, please cite the following: 127 | 128 | ```BibTeX 129 | @inproceedings{hu2022lora, 130 | title={Lo{RA}: Low-Rank Adaptation of Large Language Models}, 131 | author={Edward J Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen}, 132 | booktitle={International Conference on Learning Representations}, 133 | year={2022}, 134 | } 135 | 136 | @software{lin2023loratorch, 137 | author = {Baijiong Lin}, 138 | title = {{LoRA-Torch}: {PyTorch} Reimplementation of {LoRA}}, 139 | url = {https://github.com/Baijiong-Lin/LoRA-Torch}, 140 | year = {2023} 141 | } 142 | ``` 143 | 144 | ## License 145 | 146 | ``loratorch`` is released under the [MIT](./LICENSE) license.
147 | 148 | 149 | -------------------------------------------------------------------------------- /examples/linear.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyNIy6uRivivL3VYs0eNmj01", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "colab": { 34 | "base_uri": "https://localhost:8080/" 35 | }, 36 | "id": "-j81b-ktkWH0", 37 | "outputId": "ca0a1e95-ae6c-44ac-cee0-5bf75f8fd3b1" 38 | }, 39 | "outputs": [ 40 | { 41 | "output_type": "stream", 42 | "name": "stdout", 43 | "text": [ 44 | "Collecting git+https://github.com/microsoft/LoRA\n", 45 | " Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-e5ulip2v\n", 46 | " Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-e5ulip2v\n", 47 | " Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n", 48 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 49 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n", 50 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-l221wgno\n", 51 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-l221wgno\n", 52 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 4b550558ff1c7ef6bd1009b80d30e52069396f69\n", 53 | " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!pip install git+https://github.com/microsoft/LoRA\n", 59 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "source": [ 65 | "import torch, loralib, loratorch, copy\n", 66 | "import torch.nn as nn" 67 | ], 68 | "metadata": { 69 | "id": "tDo6S6QPkiXI" 70 | }, 71 | "execution_count": 2, 72 | "outputs": [] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "source": [ 77 | "model_lib = loralib.Linear(5, 6, r=4, lora_alpha=1)\n", 78 | "model_torch = loratorch.Linear(5, 6, r=4, lora_alpha=1)\n", 79 | "\n", 80 | "model_lib.weight.data = copy.deepcopy(model_torch.weight.data)\n", 81 | "model_lib.lora_A.data = copy.deepcopy(model_torch.w_lora_A.data)\n", 82 | "model_lib.lora_B.data = copy.deepcopy(model_torch.w_lora_B.data)\n", 83 | "model_lib.bias.data = copy.deepcopy(model_torch.bias.data)" 84 | ], 85 | "metadata": { 86 | "id": "_2ceOeiIk8xp" 87 | }, 88 | "execution_count": 3, 89 | "outputs": [] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "source": [ 94 | "loralib.mark_only_lora_as_trainable(model_lib)\n", 95 | "loratorch.mark_only_lora_as_trainable(model_torch)" 96 | ], 97 | "metadata": { 98 | "id": "CF_PCLIslcJl" 99 | }, 100 | "execution_count": 4, 101 | "outputs": [] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "source": [ 106 | "optimizer_lib = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n", 107 | "optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n", 108 | "\n", 109 | "for _ in range(3):\n", 110 | " model_lib.train()\n", 111 | " model_torch.train()\n", 112 | " x = torch.rand(2, 5)\n", 113 | "\n", 114 | " loss1 = model_lib(x).sum()\n", 115 | " optimizer_lib.zero_grad()\n", 116 | " loss1.backward()\n", 117 | " optimizer_lib.step()\n", 118 | "\n", 119 | " loss2 = model_torch(x).sum()\n", 120 | " optimizer_torch.zero_grad()\n", 121 | " loss2.backward()\n", 122 | " optimizer_torch.step()\n", 123 | "\n", 124 
| " # for k, v in model_lib.named_parameters():\n", 125 | " # print(k, v.grad)\n", 126 | "\n", 127 | " # for k, v in model_torch.named_parameters():\n", 128 | " # print(k, v.grad)\n", 129 | " print(torch.isclose(model_lib.lora_A.grad, model_torch.w_lora_A.grad))\n", 130 | " print(torch.isclose(model_lib.lora_B.grad, model_torch.w_lora_B.grad))\n", 131 | "\n", 132 | " x_test = torch.rand(3, 5)\n", 133 | " model_lib.eval()\n", 134 | " model_torch.eval()\n", 135 | " print(torch.isclose(model_lib(x_test), model_torch(x_test)))" 136 | ], 137 | "metadata": { 138 | "colab": { 139 | "base_uri": "https://localhost:8080/" 140 | }, 141 | "id": "GdmSBAenkvP1", 142 | "outputId": "7e186013-f7a2-4fc6-90c3-854f1f1fc2fb" 143 | }, 144 | "execution_count": 5, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stdout", 149 | "text": [ 150 | "tensor([[True, True, True, True, True],\n", 151 | " [True, True, True, True, True],\n", 152 | " [True, True, True, True, True],\n", 153 | " [True, True, True, True, True]])\n", 154 | "tensor([[True, True, True, True],\n", 155 | " [True, True, True, True],\n", 156 | " [True, True, True, True],\n", 157 | " [True, True, True, True],\n", 158 | " [True, True, True, True],\n", 159 | " [True, True, True, True]])\n", 160 | "tensor([[True, True, True, True, True, True],\n", 161 | " [True, True, True, True, True, True],\n", 162 | " [True, True, True, True, True, True]])\n", 163 | "tensor([[True, True, True, True, True],\n", 164 | " [True, True, True, True, True],\n", 165 | " [True, True, True, True, True],\n", 166 | " [True, True, True, True, True]])\n", 167 | "tensor([[True, True, True, True],\n", 168 | " [True, True, True, True],\n", 169 | " [True, True, True, True],\n", 170 | " [True, True, True, True],\n", 171 | " [True, True, True, True],\n", 172 | " [True, True, True, True]])\n", 173 | "tensor([[True, True, True, True, True, True],\n", 174 | " [True, True, True, True, True, True],\n", 175 | " [True, True, True, True, True, 
True]])\n", 176 | "tensor([[True, True, True, True, True],\n", 177 | " [True, True, True, True, True],\n", 178 | " [True, True, True, True, True],\n", 179 | " [True, True, True, True, True]])\n", 180 | "tensor([[True, True, True, True],\n", 181 | " [True, True, True, True],\n", 182 | " [True, True, True, True],\n", 183 | " [True, True, True, True],\n", 184 | " [True, True, True, True],\n", 185 | " [True, True, True, True]])\n", 186 | "tensor([[True, True, True, True, True, True],\n", 187 | " [True, True, True, True, True, True],\n", 188 | " [True, True, True, True, True, True]])\n" 189 | ] 190 | } 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "source": [], 196 | "metadata": { 197 | "id": "Ul2JBtuQmIuG" 198 | }, 199 | "execution_count": 5, 200 | "outputs": [] 201 | } 202 | ] 203 | } -------------------------------------------------------------------------------- /examples/mergedlinear.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "authorship_tag": "ABX9TyOaHR81EoRiuDjhTvY9PZ+l", 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "colab": { 34 | "base_uri": "https://localhost:8080/" 35 | }, 36 | "id": "3YYQD-Le8iaB", 37 | "outputId": "1d5241cf-ce37-4abd-e282-170cba75f654" 38 | }, 39 | "outputs": [ 40 | { 41 | "output_type": "stream", 42 | "name": "stdout", 43 | "text": [ 44 | "Collecting git+https://github.com/microsoft/LoRA\n", 45 | " Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-b6kloy6q\n", 46 | 
" Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-b6kloy6q\n", 47 | " Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n", 48 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 49 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n", 50 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-pdln00a3\n", 51 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-pdln00a3\n", 52 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit f2dccb93ac60313d15251fee12e8b7be5649471a\n", 53 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!pip install git+https://github.com/microsoft/LoRA\n", 59 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "source": [ 65 | "import torch, loralib, loratorch, copy\n", 66 | "import torch.nn as nn" 67 | ], 68 | "metadata": { 69 | "id": "t-KzlZJT81XD" 70 | }, 71 | "execution_count": 2, 72 | "outputs": [] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "source": [ 77 | "model_lib = loralib.MergedLinear(10, 3*10, bias=False, r=4, enable_lora=[True, False, True], fan_in_fan_out=False)\n", 78 | "loralib.mark_only_lora_as_trainable(model_lib)\n", 79 | "optimizer = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n", 80 | "loss_fn = nn.MSELoss()\n", 81 | "\n", 82 | "for _ in range(3):\n", 83 | " model_lib.train()\n", 84 | " x = torch.rand(5, 10)\n", 85 | " y = torch.rand(5, 3*10)\n", 86 | " loss = loss_fn(model_lib(x), y)\n", 87 | " optimizer.zero_grad()\n", 88 | " loss.backward()\n", 89 | " optimizer.step()\n", 90 | "\n", 91 | " model_lib.eval()\n", 92 | " print(model_lib(x).size())" 93 | ], 94 | "metadata": { 95 | "colab": { 96 | "base_uri": "https://localhost:8080/", 97 | "height": 366 98 | }, 99 | 
"id": "LAyOBzAm82dD", 100 | "outputId": "bf2541a6-3632-41fe-cb18-e5adc63eb2fb" 101 | }, 102 | "execution_count": 3, 103 | "outputs": [ 104 | { 105 | "output_type": "error", 106 | "ename": "RuntimeError", 107 | "evalue": "ignored", 108 | "traceback": [ 109 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 110 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 111 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mmodel_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_lib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 112 | "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36meval\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 2305\u001b[0m \u001b[0mModule\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2306\u001b[0m \"\"\"\n\u001b[0;32m-> 2307\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2309\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mrequires_grad_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 113 | "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/loralib/layers.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0mgroups\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m ).squeeze(0)\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_pad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdelta_w\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscaling\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmerged\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 114 | "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/loralib/layers.py\u001b[0m in \u001b[0;36mzero_pad\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_zeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m result[:, self.lora_ind] = x.reshape(\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m )\n", 115 | "\u001b[0;31mRuntimeError\u001b[0m: shape mismatch: value tensor of shape [10, 20] cannot be broadcast to indexing result of shape [20, 20]" 116 | ] 117 | } 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "source": [ 123 | "model_torch = loratorch.MergedLinear(10, 3*10, bias=False, r=4, enable_lora=[True, False, True], fan_in_fan_out=False)\n", 124 | "loralib.mark_only_lora_as_trainable(model_torch)\n", 125 | "optimizer = 
# ------------------------------------------------------------------------------------------
# This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin.
# ------------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import operator
from typing import Optional, List, Sequence


def set_param(curr_mod, name, param=None, mode='update', with_nn=False):
    r"""Get or replace the (possibly nested) attribute ``name`` of ``curr_mod``.

    Refer to https://github.com/Baijiong-Lin/MOML/blob/main/MTL/utils.py

    Args:
        curr_mod: module owning the attribute, directly or through child modules.
        name: attribute name; a dotted path such as ``'out_proj.weight'`` is
            resolved through ``curr_mod.named_children()``.
        param: replacement value, only used when ``mode == 'update'``.
        mode: ``'update'`` deletes the attribute and sets ``param`` in its place;
            ``'get'`` returns the current value.
        with_nn: in ``'update'`` mode, wrap ``param`` in ``nn.Parameter`` so it is
            registered as a parameter instead of stored as a plain attribute.

    Returns:
        The attribute value in ``'get'`` mode; ``None`` in ``'update'`` mode or
        when the attribute / child module cannot be found.
    """
    if '.' in name:
        module_name, rest = name.split('.', 1)
        for child_name, child in curr_mod.named_children():
            if child_name == module_name:
                return set_param(child, rest, param, mode=mode, with_nn=with_nn)
        return None
    if mode == 'update':
        delattr(curr_mod, name)
        setattr(curr_mod, name, nn.Parameter(param) if with_nn else param)
    elif mode == 'get':
        if hasattr(curr_mod, name):
            return getattr(curr_mod, name)
    return None


def _get_attr(obj, dotted_name: str):
    r"""Return ``obj.<dotted_name>``, following dots (e.g. ``'out_proj.weight'``)."""
    return operator.attrgetter(dotted_name)(obj)


def _linear_forward(layer, x: torch.Tensor, **kwargs):
    r"""``nn.Linear`` forward that honours ``layer.fan_in_fan_out``.

    With ``fan_in_fan_out`` the weight is stored as ``(in_features, out_features)``
    and has to be transposed back before ``F.linear``.
    """
    if layer.fan_in_fan_out:
        return F.linear(x, layer.transpose(layer.weight), layer.bias, **kwargs)
    return nn.Linear.forward(layer, x, **kwargs)


class LoRALayer:
    r"""Mixin adding low-rank factors to selected weights of an ``nn.Module``.

    Subclasses fill ``self.params_with_lora`` with ``{param_name: lora_name}``.
    For each entry the factors live in the parameters ``{lora_name}_lora_A`` and
    ``{lora_name}_lora_B`` and the effective weight is
    ``param + scaling * B @ A`` with ``scaling = lora_alpha / r``.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        fan_in_fan_out: bool = False,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        if self.r > 0:
            self.scaling = self.lora_alpha / self.r
        # Mark the weight as unmerged
        self.merged = False
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        self.fan_in_fan_out = fan_in_fan_out
        # define params that require LoRA {'param_name': 'lora_name'}
        self.params_with_lora = {}

    def register_weight_after_backward(self):
        r"""Re-register adapted weights that are not ``nn.Parameter`` objects.

        ``forward`` puts the original parameters back itself, so this is
        normally a no-op; it is kept for backward compatibility. Weights that
        already are parameters are left alone: re-wrapping them would replace
        the object and reset ``requires_grad`` to ``True`` (unfreezing them).
        """
        for param_name in self.params_with_lora:
            p = set_param(self, param_name, mode='get')
            if p is None or isinstance(p, nn.Parameter):
                continue
            set_param(self, param_name, param=p, mode='update', with_nn=True)

    def register_lora_param(self):
        r"""Register zero-filled factors A (r x cols) and B (rows x r) for every
        2-D weight in ``params_with_lora`` and freeze that weight."""
        for param_name, lora_name in self.params_with_lora.items():
            weight = _get_attr(self, param_name)
            assert weight.dim() == 2
            rows, cols = weight.size()
            self.register_parameter(f'{lora_name}_lora_A',
                                    nn.Parameter(weight.new_zeros((self.r, cols))))
            self.register_parameter(f'{lora_name}_lora_B',
                                    nn.Parameter(weight.new_zeros((rows, self.r))))
            weight.requires_grad = False

    def init_lora_param(self):
        r"""Kaiming-uniform A (as ``nn.Linear`` does) and zero B, so B @ A starts at 0."""
        for lora_name in self.params_with_lora.values():
            if hasattr(self, f'{lora_name}_lora_A'):
                nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
                nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

    def transpose(self, w: torch.Tensor):
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_BA(self, param_name: str):
        r"""Return ``B @ A`` laid out like the stored weight ``param_name``.

        The transpose (for ``fan_in_fan_out``) is applied before reshaping so
        that the product keeps its ``(out, in)`` meaning.
        """
        lora_name = self.params_with_lora[param_name]
        delta = getattr(self, f'{lora_name}_lora_B') @ getattr(self, f'{lora_name}_lora_A')
        return self.transpose(delta).reshape(_get_attr(self, param_name).shape)

    def merge_lora_param(self):
        r"""p_new = p + scaling * B @ A and keep differentiable to A and B.

        Returns:
            dict mapping each adapted ``param_name`` to the tensor it replaced,
            so the caller can restore it.
        """
        originals = {}
        for param_name in self.params_with_lora:
            p = set_param(self, param_name, mode='get')
            originals[param_name] = p
            # detach() is very important here: no gradient reaches the frozen weight
            p_new = p.detach() + self.merge_BA(param_name) * self.scaling
            set_param(self, param_name, param=p_new, mode='update')
        return originals

    def _restore_lora_param(self, originals):
        r"""Put back the tensors returned by :meth:`merge_lora_param`."""
        for param_name, p in originals.items():
            set_param(self, param_name, param=p, mode='update')

    def _forward_with_lora(self, forward_fn, *args, **kwargs):
        r"""Call ``forward_fn`` with every adapted weight replaced by ``W + scaling * B @ A``.

        The original parameter objects are restored afterwards (also when
        ``forward_fn`` raises) rather than subtracting the update through
        ``.data``: that in-place write would also modify the merged tensor that
        autograd saved for backward (so gradients w.r.t. the layer input would
        ignore the LoRA update) and would let the frozen weight drift through
        repeated rounding.
        """
        if self.r > 0 and not self.merged:
            originals = self.merge_lora_param()
            try:
                return forward_fn(*args, **kwargs)
            finally:
                self._restore_lora_param(originals)
        return forward_fn(*args, **kwargs)

    def add_lora_data(self):
        r"""NOT differentiable"""
        with torch.no_grad():
            for param_name in self.params_with_lora:
                _get_attr(self, param_name).data += self.merge_BA(param_name) * self.scaling

    def sub_lora_data(self):
        r"""NOT differentiable"""
        with torch.no_grad():
            for param_name in self.params_with_lora:
                _get_attr(self, param_name).data -= self.merge_BA(param_name) * self.scaling

    def lora_train(self, mode: bool = True):
        if mode:
            if self.merged and self.r > 0:
                # Make sure that the weights are not merged
                self.sub_lora_data()
                self.merged = False
        else:
            if not self.merged and self.r > 0:
                # Merge the weights and mark it
                self.add_lora_data()
                self.merged = True


class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in a Embedding layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Embedding.reset_parameters(self)
        self.init_lora_param()

    def init_lora_param(self):
        if hasattr(self, 'w_lora_A'):
            # initialize A to zero and B from a normal distribution, so B @ A starts at 0
            nn.init.zeros_(self.w_lora_A)
            nn.init.normal_(self.w_lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self, x: torch.Tensor, **kwargs):
        return self._forward_with_lora(nn.Embedding.forward, self, x, **kwargs)


class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a Linear layer.
    # With fan_in_fan_out=True the weight is stored as (in_features, out_features).
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)

        # Actual trainable parameters (A/B shaped from the (out, in) weight)
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self, x: torch.Tensor, **kwargs):
        return self._forward_with_lora(_linear_forward, self, x, **kwargs)


class Conv1d(nn.Conv1d, LoRALayer):
    # LoRA implemented in a Conv1d layer
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # B @ A has as many elements as the (out, in // groups, k) kernel and
            # is reshaped onto it in merge_BA
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels//self.groups, r*kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv1d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        nn.Conv1d.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self, x: torch.Tensor, **kwargs):
        return self._forward_with_lora(nn.Conv1d.forward, self, x, **kwargs)


class Conv2d(nn.Conv2d, LoRALayer):
    # LoRA implemented in a Conv2d layer
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # B @ A has as many elements as the (out, in // groups, k, k) kernel
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv2d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        nn.Conv2d.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self, x: torch.Tensor, **kwargs):
        return self._forward_with_lora(nn.Conv2d.forward, self, x, **kwargs)


class Conv3d(nn.Conv3d, LoRALayer):
    # LoRA implemented in a Conv3d layer
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # B @ A has as many elements as the (out, in // groups, k, k, k) kernel
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size*kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv3d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        nn.Conv3d.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self, x: torch.Tensor, **kwargs):
        return self._forward_with_lora(nn.Conv3d.forward, self, x, **kwargs)


class MultiheadAttention(nn.MultiheadAttention, LoRALayer):
    # LoRA implemented in a MultiheadAttention layer.
    # enable_lora selects among 'q', 'k', 'v' (input projections) and 'o' (out_proj).
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        enable_lora: Sequence[str] = ('q', 'k', 'v', 'o'),
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.MultiheadAttention.__init__(self, embed_dim, num_heads, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        # Which of the q / k / v projections receive a LoRA update
        self.lora_qkv_enabled = [n in enable_lora for n in ('q', 'k', 'v')]

        # Actual trainable parameters
        if self.r > 0:
            if 'o' in enable_lora:
                self.params_with_lora.update({'out_proj.weight': 'o'})

            if not self._qkv_same_embed_dim:
                for n, enabled in zip(('q', 'k', 'v'), self.lora_qkv_enabled):
                    if enabled:
                        self.params_with_lora.update({f'{n}_proj_weight': n})
            elif any(self.lora_qkv_enabled):
                # One pair of factors on the packed in_proj_weight, named after the
                # enabled projections (e.g. 'qv'); merge_BA masks the other rows.
                lora_name = ''.join(
                    n for n, enabled in zip(('q', 'k', 'v'), self.lora_qkv_enabled) if enabled
                )
                self.params_with_lora.update({'in_proj_weight': lora_name})
            self.register_lora_param()
            nn.MultiheadAttention._reset_parameters(self)
            self.init_lora_param()

    def init_lora_param_o(self):
        lora_name = 'o'
        if hasattr(self, f'{lora_name}_lora_A'):
            nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
            nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

    def init_lora_param_qkv(self, enable_lora_bool=None):
        r"""Initialise the packed q/k/v factors (kaiming A, zero B).

        ``enable_lora_bool`` is accepted for backward compatibility and ignored:
        disabled projections are excluded by the row mask in :meth:`merge_BA`.
        """
        lora_name = self.params_with_lora.get('in_proj_weight')
        if lora_name is None:
            return
        nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
        nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

    def merge_BA(self, param_name: str):
        delta = LoRALayer.merge_BA(self, param_name)
        if param_name == 'in_proj_weight' and not all(self.lora_qkv_enabled):
            # nn.MultiheadAttention splits in_proj_weight into W_q, W_k, W_v along
            # dim 0, so zero the row blocks of projections without LoRA. The mask
            # also keeps the gradient of the matching rows of B at zero.
            row_mask = delta.new_tensor(self.lora_qkv_enabled).repeat_interleave(delta.shape[0] // 3)
            delta = delta * row_mask.unsqueeze(1)
        return delta

    def train(self, mode: bool = True):
        nn.MultiheadAttention.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                **kwargs):
        return self._forward_with_lora(
            nn.MultiheadAttention.forward, self, query, key, value, **kwargs
        )


class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer whose output is split into
    # len(enable_lora) equal groups; only the groups flagged True are adapted.
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        enable_lora: Sequence[bool] = (False,),
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        # fan_in_fan_out must reach LoRALayer, otherwise transpose() is a no-op
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)

        enable_lora = list(enable_lora)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        if r > 0 and any(enable_lora):
            # Only adapt 'weight' when factors actually exist; with no enabled
            # group the layer behaves like a plain nn.Linear.
            self.params_with_lora = {'weight': 'w'}
            # Actual trainable parameters
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            ) # weights for Conv1D with groups=sum(enable_lora)
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices of output rows that belong to enabled groups
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)

    def zero_pad(self, x):
        r"""Scatter the rows of ``x`` into the enabled rows of an out_features-row zero tensor."""
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_BA(self, param_name: str):
        lora_name = self.params_with_lora[param_name]
        # A grouped 1x1 conv computes B_i @ A_i for every enabled group at once
        delta_w = F.conv1d(
            getattr(self, f'{lora_name}_lora_A').unsqueeze(0),
            getattr(self, f'{lora_name}_lora_B').unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return self.transpose(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        self.lora_train(mode)
        return self

    def forward(self, x: torch.Tensor, **kwargs):
        return self._forward_with_lora(_linear_forward, self, x, **kwargs)
"include_colab_link": true 24 | }, 25 | "accelerator": "GPU", 26 | "kaggle": { 27 | "accelerator": "nvidiaTeslaT4", 28 | "dataSources": [], 29 | "dockerImageVersionId": 31041, 30 | "isInternetEnabled": true, 31 | "language": "python", 32 | "sourceType": "notebook", 33 | "isGpuEnabled": true 34 | } 35 | }, 36 | "nbformat_minor": 0, 37 | "nbformat": 4, 38 | "cells": [ 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "view-in-github", 43 | "colab_type": "text" 44 | }, 45 | "source": [ 46 | "\"Open" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "source": [ 52 | "### This example demonstrates how to apply LoRA-Torch to ``nn.MultiheadAttention`` in OpenCLIP. We greatly appreciate [Viet Q. Vo](https://vietvo89.github.io/)'s valuable contribution." 53 | ], 54 | "metadata": { 55 | "id": "Yhxyu7vivWl2" 56 | } 57 | }, 58 | { 59 | "cell_type": "code", 60 | "source": [ 61 | "!pip install open-clip-torch\n", 62 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch" 63 | ], 64 | "metadata": { 65 | "id": "753R63XXhzqE", 66 | "outputId": "ad3892df-aeb6-4ebe-b7ae-2400535ec3ab", 67 | "trusted": true, 68 | "execution": { 69 | "iopub.status.busy": "2025-06-09T01:09:26.555991Z", 70 | "iopub.execute_input": "2025-06-09T01:09:26.556177Z", 71 | "iopub.status.idle": "2025-06-09T01:10:48.370674Z", 72 | "shell.execute_reply.started": "2025-06-09T01:09:26.556160Z", 73 | "shell.execute_reply": "2025-06-09T01:10:48.369896Z" 74 | }, 75 | "scrolled": true, 76 | "colab": { 77 | "base_uri": "https://localhost:8080/" 78 | } 79 | }, 80 | "outputs": [ 81 | { 82 | "output_type": "stream", 83 | "name": "stdout", 84 | "text": [ 85 | "Requirement already satisfied: open-clip-torch in /usr/local/lib/python3.11/dist-packages (2.32.0)\n", 86 | "Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (2.6.0+cu124)\n", 87 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from 
open-clip-torch) (0.21.0+cu124)\n", 88 | "Requirement already satisfied: regex in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (2024.11.6)\n", 89 | "Requirement already satisfied: ftfy in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (6.3.1)\n", 90 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (4.67.1)\n", 91 | "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.32.4)\n", 92 | "Requirement already satisfied: safetensors in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.5.3)\n", 93 | "Requirement already satisfied: timm in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (1.0.15)\n", 94 | "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.18.0)\n", 95 | "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (4.14.0)\n", 96 | "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.5)\n", 97 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.1.6)\n", 98 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (2025.3.2)\n", 99 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n", 100 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n", 101 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n", 102 | 
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (9.1.0.70)\n", 103 | "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.5.8)\n", 104 | "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (11.2.1.3)\n", 105 | "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (10.3.5.147)\n", 106 | "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (11.6.1.9)\n", 107 | "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.3.1.170)\n", 108 | "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (0.6.2)\n", 109 | "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (2.21.5)\n", 110 | "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n", 111 | "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n", 112 | "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.2.0)\n", 113 | "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (1.13.1)\n", 114 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in 
/usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.9.0->open-clip-torch) (1.3.0)\n", 115 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from ftfy->open-clip-torch) (0.2.13)\n", 116 | "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (24.2)\n", 117 | "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (6.0.2)\n", 118 | "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (2.32.3)\n", 119 | "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (1.1.2)\n", 120 | "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision->open-clip-torch) (2.0.2)\n", 121 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision->open-clip-torch) (11.2.1)\n", 122 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.9.0->open-clip-torch) (3.0.2)\n", 123 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (3.4.2)\n", 124 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (3.10)\n", 125 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (2.4.0)\n", 126 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (2025.4.26)\n", 127 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n", 128 | " 
Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-1a5zm6vx\n", 129 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-1a5zm6vx\n", 130 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 3b6f10a3bdebfb0da1abeb4c265f914ed06759e4\n", 131 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n" 132 | ] 133 | } 134 | ], 135 | "execution_count": 1 136 | }, 137 | { 138 | "cell_type": "code", 139 | "source": [ 140 | "import torch\n", 141 | "import torch.nn as nn\n", 142 | "from torchvision import transforms\n", 143 | "from torch.utils.data import DataLoader\n", 144 | "import open_clip\n", 145 | "import loratorch as lora\n", 146 | "from tqdm import tqdm" 147 | ], 148 | "metadata": { 149 | "id": "T9hDV9jQiU0L", 150 | "trusted": true, 151 | "execution": { 152 | "iopub.status.busy": "2025-06-09T01:10:48.376803Z", 153 | "iopub.execute_input": "2025-06-09T01:10:48.376979Z", 154 | "iopub.status.idle": "2025-06-09T01:11:15.903375Z", 155 | "shell.execute_reply.started": "2025-06-09T01:10:48.376955Z", 156 | "shell.execute_reply": "2025-06-09T01:11:15.902530Z" 157 | } 158 | }, 159 | "outputs": [], 160 | "execution_count": 2 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "source": [ 165 | "### A. 
Load Pre-trained Model" 166 | ], 167 | "metadata": { 168 | "id": "kDBpUYgYjrGw" 169 | } 170 | }, 171 | { 172 | "cell_type": "code", 173 | "source": [ 174 | "model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')\n", 175 | "model = model.cuda()\n", 176 | "tokenizer = open_clip.get_tokenizer('ViT-B-32')" 177 | ], 178 | "metadata": { 179 | "trusted": true, 180 | "execution": { 181 | "iopub.status.busy": "2025-06-09T01:11:15.904141Z", 182 | "iopub.execute_input": "2025-06-09T01:11:15.904827Z", 183 | "iopub.status.idle": "2025-06-09T01:11:20.771415Z", 184 | "shell.execute_reply.started": "2025-06-09T01:11:15.904798Z", 185 | "shell.execute_reply": "2025-06-09T01:11:20.770632Z" 186 | }, 187 | "id": "1bqndlUex2HS", 188 | "colab": { 189 | "base_uri": "https://localhost:8080/" 190 | }, 191 | "outputId": "9bf32ef6-6a5f-45c0-a8ea-64fb9381a6dc" 192 | }, 193 | "outputs": [ 194 | { 195 | "output_type": "stream", 196 | "name": "stderr", 197 | "text": [ 198 | "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", 199 | "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", 200 | "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", 201 | "You will be able to reuse this secret in all of your notebooks.\n", 202 | "Please note that authentication is recommended but still optional to access public models or datasets.\n", 203 | " warnings.warn(\n", 204 | "/usr/local/lib/python3.11/dist-packages/open_clip/factory.py:388: UserWarning: These pretrained weights were trained with QuickGELU activation but the model config does not have that enabled. 
Consider using a model config with a \"-quickgelu\" suffix or enable with a flag.\n", 205 | " warnings.warn(\n" 206 | ] 207 | } 208 | ], 209 | "execution_count": 3 210 | }, 211 | { 212 | "cell_type": "code", 213 | "source": [ 214 | "# prompt: count trainable parameters of model?\n", 215 | "\n", 216 | "def count_parameters(model):\n", 217 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 218 | " vision_params = sum(p.numel() for p in model.visual.transformer.parameters() if p.requires_grad)\n", 219 | " text_params = sum(p.numel() for p in model.transformer.parameters() if p.requires_grad)\n", 220 | " embed_params = sum(p.numel() for p in model.token_embedding.parameters() if p.requires_grad)\n", 221 | " total_params = sum(p.numel() for p in model.parameters())\n", 222 | " print(f\"Total parameters: {total_params:,}\")\n", 223 | " print(f\"Trainable parameters - Full model: {trainable_params:,}\")\n", 224 | " print(f\"Trainable parameters - Vision: {vision_params:,}\")\n", 225 | " print(f\"Trainable parameters - Text: {text_params:,}\")\n", 226 | " print(f\"Trainable parameters - embedding: {embed_params:,}\")" 227 | ], 228 | "metadata": { 229 | "id": "VQoMyCMio4Kg", 230 | "trusted": true, 231 | "execution": { 232 | "iopub.status.busy": "2025-06-09T01:21:06.763946Z", 233 | "iopub.execute_input": "2025-06-09T01:21:06.764449Z", 234 | "iopub.status.idle": "2025-06-09T01:21:06.769661Z", 235 | "shell.execute_reply.started": "2025-06-09T01:21:06.764428Z", 236 | "shell.execute_reply": "2025-06-09T01:21:06.768894Z" 237 | } 238 | }, 239 | "outputs": [], 240 | "execution_count": 4 241 | }, 242 | { 243 | "cell_type": "code", 244 | "source": [ 245 | "print('Original model before adding lora')\n", 246 | "count_parameters(model)" 247 | ], 248 | "metadata": { 249 | "trusted": true, 250 | "colab": { 251 | "base_uri": "https://localhost:8080/" 252 | }, 253 | "id": "eJXm5sjmx2HW", 254 | "outputId": "77426425-634d-468e-e379-21db4d9312ec" 255 | }, 
256 | "outputs": [ 257 | { 258 | "output_type": "stream", 259 | "name": "stdout", 260 | "text": [ 261 | "Original model before adding lora\n", 262 | "Total parameters: 151,277,313\n", 263 | "Trainable parameters - Full model: 151,277,313\n", 264 | "Trainable parameters - Vision: 85,054,464\n", 265 | "Trainable parameters - Text: 37,828,608\n", 266 | "Trainable parameters - embedding: 25,296,896\n" 267 | ] 268 | } 269 | ], 270 | "execution_count": 5 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "source": [ 275 | "### B. Load CIFAR-10" 276 | ], 277 | "metadata": { 278 | "id": "OMHX0z87i5_9" 279 | } 280 | }, 281 | { 282 | "cell_type": "code", 283 | "source": [ 284 | "# prompt: load cifar10 dataset\n", 285 | "\n", 286 | "from torchvision.datasets import CIFAR10\n", 287 | "\n", 288 | "train_dataset = CIFAR10(\n", 289 | " root=\"./data\", train=True, download=True,\n", 290 | " transform=preprocess\n", 291 | ")\n", 292 | "test_dataset = CIFAR10(\n", 293 | " root=\"./data\", train=False, download=True,\n", 294 | " transform=preprocess\n", 295 | ")\n", 296 | "\n", 297 | "batch_size_train = 256\n", 298 | "batch_size_test = 256\n", 299 | "train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=4)\n", 300 | "test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=4)\n" 301 | ], 302 | "metadata": { 303 | "id": "mA1W_9IIiA_P", 304 | "outputId": "5d786a0d-d462-4034-ad85-3ce88c030ca8", 305 | "trusted": true, 306 | "execution": { 307 | "iopub.status.busy": "2025-06-09T01:11:20.772338Z", 308 | "iopub.execute_input": "2025-06-09T01:11:20.772628Z", 309 | "iopub.status.idle": "2025-06-09T01:11:26.195815Z", 310 | "shell.execute_reply.started": "2025-06-09T01:11:20.772603Z", 311 | "shell.execute_reply": "2025-06-09T01:11:26.195198Z" 312 | }, 313 | "colab": { 314 | "base_uri": "https://localhost:8080/" 315 | } 316 | }, 317 | "outputs": [ 318 | { 319 | "output_type": "stream", 320 | "name": "stderr", 321 
| "text": [ 322 | "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:624: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", 323 | " warnings.warn(\n" 324 | ] 325 | } 326 | ], 327 | "execution_count": 6 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "source": [ 332 | "### C. Fine-tune OpenCLIP with LoRA\n", 333 | "\n", 334 | "_Note:_\n", 335 | "\n", 336 | "Please make sure `loratorch.MultiheadAttention` is constructed with the same argument values as the [`nn.MultiheadAttention`](https://docs.pytorch.org/docs/2.6/generated/torch.nn.MultiheadAttention.html#multiheadattention) layer it replaces.\n", 337 | "\n", 338 | "For example, `batch_first` defaults to `False` in `nn.MultiheadAttention`, but `open_clip` sets it to `True` in some `attn` layers.
The discussion of this can be found [here](https://github.com/Baijiong-Lin/LoRA-Torch/issues/6#issuecomment-2954122864).\n", 339 | "\n", 340 | "The best way of employing `loratorch.MultiheadAttention` is the following:\n", 341 | "```python\n", 342 | "lora_multihead = lora.MultiheadAttention(r=r,\n", 343 | " lora_alpha=lora_alpha,\n", 344 | " enable_lora=enable_lora,\n", 345 | " embed_dim=multihead.embed_dim,\n", 346 | " num_heads=multihead.num_heads,\n", 347 | " dropout=multihead.dropout,\n", 348 | " bias=True if hasattr(multihead, \"in_proj_bias\") else False,\n", 349 | " add_bias_kv=False if multihead.bias_k==None else True,\n", 350 | " add_zero_attn=multihead.add_zero_attn,\n", 351 | " kdim=multihead.kdim,\n", 352 | " vdim=multihead.vdim,\n", 353 | " batch_first=multihead.batch_first)\n", 354 | "```" 355 | ], 356 | "metadata": { 357 | "id": "Ahjx5_MPlbvU" 358 | } 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "source": [ 363 | "#### Apply LoRA to `attn` and `mlp`" 364 | ], 365 | "metadata": { 366 | "id": "Y1sLkX5tDZGz" 367 | } 368 | }, 369 | { 370 | "cell_type": "code", 371 | "source": [ 372 | "def apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True):\n", 373 | " if encoder_type == 'visual':\n", 374 | " encoder = model.visual.transformer\n", 375 | " elif encoder_type == 'text':\n", 376 | " encoder = model.transformer\n", 377 | " else:\n", 378 | " raise ValueError(\"Invalid encoder_type. 
Choose 'visual' or 'text'.\")\n", 379 | "\n", 380 | " enable_lora=['q', 'k', 'v', 'o']\n", 381 | " for i, resblock in enumerate(encoder.resblocks):\n", 382 | " if hasattr(resblock, 'attn') and attn:\n", 383 | " multihead = resblock.attn\n", 384 | " lora_multihead = lora.MultiheadAttention(r=rank,\n", 385 | " lora_alpha=lora_alpha,\n", 386 | " enable_lora=enable_lora,\n", 387 | " embed_dim=multihead.embed_dim,\n", 388 | " num_heads=multihead.num_heads,\n", 389 | " dropout=multihead.dropout,\n", 390 | " bias=True if hasattr(multihead, \"in_proj_bias\") else False,\n", 391 | " add_bias_kv=False if multihead.bias_k==None else True,\n", 392 | " add_zero_attn=multihead.add_zero_attn,\n", 393 | " kdim=multihead.kdim,\n", 394 | " vdim=multihead.vdim,\n", 395 | " batch_first=multihead.batch_first)\n", 396 | " lora_multihead.load_state_dict(multihead.state_dict(), strict=False)\n", 397 | " resblock.attn = lora_multihead\n", 398 | "\n", 399 | " if hasattr(resblock, 'mlp') and mlp:\n", 400 | " old_mlp_fc=resblock.mlp.c_fc\n", 401 | " old_mlp_proj=resblock.mlp.c_proj\n", 402 | " new_mlp_fc = lora.Linear(\n", 403 | " old_mlp_fc.in_features,\n", 404 | " old_mlp_fc.out_features,\n", 405 | " bias=True if hasattr(old_mlp_fc, \"bias\") else False,\n", 406 | " r=rank,\n", 407 | " lora_alpha=lora_alpha,\n", 408 | " )\n", 409 | " new_mlp_proj = lora.Linear(\n", 410 | " old_mlp_proj.in_features,\n", 411 | " old_mlp_proj.out_features,\n", 412 | " bias=True if hasattr(old_mlp_proj, \"bias\") else False,\n", 413 | " r=rank,\n", 414 | " lora_alpha=lora_alpha,\n", 415 | " )\n", 416 | " new_mlp_fc.load_state_dict(old_mlp_fc.state_dict(),strict=False)\n", 417 | " new_mlp_proj.load_state_dict(old_mlp_proj.state_dict(),strict=False)\n", 418 | " resblock.mlp.c_fc = new_mlp_fc\n", 419 | " resblock.mlp.c_proj = new_mlp_proj\n", 420 | "\n", 421 | " lora.mark_only_lora_as_trainable(model)\n", 422 | " return model" 423 | ], 424 | "metadata": { 425 | "trusted": true, 426 | "execution": { 427 | 
"iopub.status.busy": "2025-06-09T01:20:16.374517Z", 428 | "iopub.execute_input": "2025-06-09T01:20:16.375239Z", 429 | "iopub.status.idle": "2025-06-09T01:20:16.382721Z", 430 | "shell.execute_reply.started": "2025-06-09T01:20:16.375208Z", 431 | "shell.execute_reply": "2025-06-09T01:20:16.382180Z" 432 | }, 433 | "id": "UQ8c81GOx2HZ" 434 | }, 435 | "outputs": [], 436 | "execution_count": 7 437 | }, 438 | { 439 | "cell_type": "code", 440 | "source": [ 441 | "apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True)\n", 442 | "tokenizer = open_clip.get_tokenizer('ViT-B-32')" 443 | ], 444 | "metadata": { 445 | "trusted": true, 446 | "execution": { 447 | "iopub.status.busy": "2025-06-09T01:20:17.384221Z", 448 | "iopub.execute_input": "2025-06-09T01:20:17.384480Z", 449 | "iopub.status.idle": "2025-06-09T01:20:18.691620Z", 450 | "shell.execute_reply.started": "2025-06-09T01:20:17.384462Z", 451 | "shell.execute_reply": "2025-06-09T01:20:18.690864Z" 452 | }, 453 | "id": "WThYJ3zzx2HZ" 454 | }, 455 | "outputs": [], 456 | "execution_count": 8 457 | }, 458 | { 459 | "cell_type": "code", 460 | "source": [ 461 | "for name, param in model.visual.transformer.resblocks[0].named_parameters():\n", 462 | " print(name, param.requires_grad)\n", 463 | "\n", 464 | "# after adding lora\n", 465 | "print(\"\\nAfter adding LoRA to Attn+MLP:\")\n", 466 | "count_parameters(model)" 467 | ], 468 | "metadata": { 469 | "trusted": true, 470 | "execution": { 471 | "iopub.status.busy": "2025-06-09T01:21:17.260662Z", 472 | "iopub.execute_input": "2025-06-09T01:21:17.260951Z", 473 | "iopub.status.idle": "2025-06-09T01:21:17.267985Z", 474 | "shell.execute_reply.started": "2025-06-09T01:21:17.260934Z", 475 | "shell.execute_reply": "2025-06-09T01:21:17.267184Z" 476 | }, 477 | "colab": { 478 | "base_uri": "https://localhost:8080/" 479 | }, 480 | "id": "DT_NoSaTx2HZ", 481 | "outputId": "2427de82-3257-4b98-fe21-64c9330ee75c" 482 | }, 483 | "outputs": [ 484 | { 485 | 
"output_type": "stream", 486 | "name": "stdout", 487 | "text": [ 488 | "ln_1.weight False\n", 489 | "ln_1.bias False\n", 490 | "attn.in_proj_weight False\n", 491 | "attn.in_proj_bias False\n", 492 | "attn.o_lora_A True\n", 493 | "attn.o_lora_B True\n", 494 | "attn.qkv_lora_A True\n", 495 | "attn.qkv_lora_B True\n", 496 | "attn.out_proj.weight False\n", 497 | "attn.out_proj.bias False\n", 498 | "ln_2.weight False\n", 499 | "ln_2.bias False\n", 500 | "mlp.c_fc.weight False\n", 501 | "mlp.c_fc.bias False\n", 502 | "mlp.c_fc.w_lora_A True\n", 503 | "mlp.c_fc.w_lora_B True\n", 504 | "mlp.c_proj.weight False\n", 505 | "mlp.c_proj.bias False\n", 506 | "mlp.c_proj.w_lora_A True\n", 507 | "mlp.c_proj.w_lora_B True\n", 508 | "\n", 509 | "After adding LoRA to Attn+MLP:\n", 510 | "Total parameters: 153,636,609\n", 511 | "Trainable parameters - Full model: 2,359,296\n", 512 | "Trainable parameters - Vision: 2,359,296\n", 513 | "Trainable parameters - Text: 0\n", 514 | "Trainable parameters - embedding: 0\n" 515 | ] 516 | } 517 | ], 518 | "execution_count": 9 519 | }, 520 | { 521 | "cell_type": "code", 522 | "source": [ 523 | "# Tokenizer and text embeddings\n", 524 | "model.cuda()\n", 525 | "\n", 526 | "tokenizer = open_clip.get_tokenizer(\"ViT-B-32\")\n", 527 | "classnames = train_dataset.classes\n", 528 | "text_inputs = tokenizer([f\"a photo of a {label}\" for label in classnames]).cuda()\n", 529 | "with torch.no_grad():\n", 530 | " text_features = model.encode_text(text_inputs)\n", 531 | " text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n", 532 | "\n", 533 | "# Optimizer\n", 534 | "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)" 535 | ], 536 | "metadata": { 537 | "trusted": true, 538 | "execution": { 539 | "iopub.status.busy": "2025-06-09T01:21:34.494110Z", 540 | "iopub.execute_input": "2025-06-09T01:21:34.494363Z", 541 | "iopub.status.idle": "2025-06-09T01:21:35.033731Z", 542 | "shell.execute_reply.started": 
"2025-06-09T01:21:34.494344Z", 543 | "shell.execute_reply": "2025-06-09T01:21:35.032959Z" 544 | }, 545 | "id": "RdVjI-pvx2HZ" 546 | }, 547 | "outputs": [], 548 | "execution_count": 10 549 | }, 550 | { 551 | "cell_type": "code", 552 | "source": [ 553 | "# Train loop\n", 554 | "model.train()\n", 555 | "for epoch in range(3):\n", 556 | " total_loss = 0\n", 557 | " correct = 0\n", 558 | " total = 0\n", 559 | " for images, labels in tqdm(train_loader):\n", 560 | " images, labels = images.cuda(), labels.cuda()\n", 561 | " image_features = model.encode_image(images)\n", 562 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n", 563 | " logits = image_features @ text_features.t()\n", 564 | " loss = nn.CrossEntropyLoss()(logits, labels)\n", 565 | "\n", 566 | " optimizer.zero_grad()\n", 567 | " loss.backward()\n", 568 | " optimizer.step()\n", 569 | " total_loss += loss.item()\n", 570 | "\n", 571 | " preds = logits.argmax(dim=1)\n", 572 | " correct += (preds == labels).sum().item()\n", 573 | " total += labels.size(0)\n", 574 | " # (!!!) Re-register model parameters so that they appear in model.state_dict() and model.parameters()\n", 575 | " # (!!!)
Without this line, performance is not affected, but some weights will be missing from model.state_dict() and model.parameters()\n", 576 | " lora.register_model_param_after_backward(model)\n", 577 | "\n", 578 | " acc = correct / total\n", 579 | " print(f\"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.4f}\")" 580 | ], 581 | "metadata": { 582 | "outputId": "153083d4-3f77-4b1d-bc7e-a9798cfb1139", 583 | "trusted": true, 584 | "execution": { 585 | "iopub.status.busy": "2025-06-09T01:21:35.369183Z", 586 | "iopub.execute_input": "2025-06-09T01:21:35.369507Z", 587 | "iopub.status.idle": "2025-06-09T01:54:15.576408Z", 588 | "shell.execute_reply.started": "2025-06-09T01:21:35.369485Z", 589 | "shell.execute_reply": "2025-06-09T01:54:15.575561Z" 590 | }, 591 | "id": "UCgTKj1qx2HZ", 592 | "colab": { 593 | "base_uri": "https://localhost:8080/" 594 | } 595 | }, 596 | "outputs": [ 597 | { 598 | "output_type": "stream", 599 | "name": "stderr", 600 | "text": [ 601 | "100%|██████████| 196/196 [06:43<00:00, 2.06s/it]\n" 602 | ] 603 | }, 604 | { 605 | "output_type": "stream", 606 | "name": "stdout", 607 | "text": [ 608 | "Epoch 1: Loss=397.7533, Accuracy=0.9529\n" 609 | ] 610 | }, 611 | { 612 | "output_type": "stream", 613 | "name": "stderr", 614 | "text": [ 615 | "100%|██████████| 196/196 [06:40<00:00, 2.04s/it]\n" 616 | ] 617 | }, 618 | { 619 | "output_type": "stream", 620 | "name": "stdout", 621 | "text": [ 622 | "Epoch 2: Loss=387.0708, Accuracy=0.9796\n" 623 | ] 624 | }, 625 | { 626 | "output_type": "stream", 627 | "name": "stderr", 628 | "text": [ 629 | "100%|██████████| 196/196 [06:39<00:00, 2.04s/it]" 630 | ] 631 | }, 632 | { 633 | "output_type": "stream", 634 | "name": "stdout", 635 | "text": [ 636 | "Epoch 3: Loss=386.2361, Accuracy=0.9901\n" 637 | ] 638 | }, 639 | { 640 | "output_type": "stream", 641 | "name": "stderr", 642 | "text": [ 643 | "\n" 644 | ] 645 | } 646 | ], 647 | "execution_count": 11 648 | }, 649 | { 650 | "cell_type": "code", 651
| "source": [], 652 | "metadata": { 653 | "trusted": true, 654 | "id": "pMS8k5sXx2Ha" 655 | }, 656 | "outputs": [], 657 | "execution_count": 11 658 | } 659 | ] 660 | } --------------------------------------------------------------------------------