├── loratorch
├── __init__.py
├── utils.py
└── layers.py
├── setup.py
├── LICENSE
├── examples
├── embedding.ipynb
├── linear.ipynb
├── mergedlinear.ipynb
└── Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb
└── README.md
/loratorch/__init__.py:
--------------------------------------------------------------------------------
# Package entry point: star-imports the ``layers`` and ``utils`` submodules so
# their public names are reachable directly from the package (the examples use
# e.g. ``loratorch.Linear`` and ``loratorch.mark_only_lora_as_trainable``).
# NOTE(review): this module-level ``name`` is "lora", whereas setup.py declares the
# distribution name "loratorch"; confirm whether anything reads it before changing.
name = "lora"

from .layers import *
from .utils import *
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Packaging script for the ``loratorch`` distribution."""
import pathlib

import setuptools

# Resolve README.md next to this file instead of relative to the current
# working directory, so reading the long description does not depend on where
# the build is launched from.
_HERE = pathlib.Path(__file__).resolve().parent
long_description = (_HERE / "README.md").read_text(encoding="utf-8")

setuptools.setup(
    name="loratorch",
    version="0.1.0",
    author="Baijiong Lin",
    author_email="bj.lin.email@gmail.com",
    description="PyTorch reimplementation of low-rank adaptation (LoRA).",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Baijiong-Lin/LoRA-Torch",
    packages=setuptools.find_packages(),
    # The package imports torch at module level (loratorch/utils.py), so it is a
    # hard runtime dependency; declare it so pip can resolve it.
    install_requires=["torch"],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.6',
)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Baijiong Lin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/loratorch/utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------
2 | # This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin.
3 | # ------------------------------------------------------------------------------------------
4 | import torch
5 | import torch.nn as nn
6 |
7 | from typing import Dict
8 |
9 | from .layers import LoRALayer
10 |
def register_model_param_after_backward(model: nn.Module) -> None:
    """Re-register the weights of every LoRA layer contained in ``model``.

    Walks ``model.modules()`` (which yields ``model`` itself followed by all of
    its submodules) and calls ``register_weight_after_backward()`` on each one
    that is a ``LoRALayer``. Per the project README, this is called after
    ``optimizer.step()`` so that the affected weights are listed again in
    ``model.state_dict()`` and ``model.parameters()``.

    Args:
        model: the module tree to scan.
    """
    lora_layers = (module for module in model.modules() if isinstance(module, LoRALayer))
    for layer in lora_layers:
        layer.register_weight_after_backward()
15 |
def print_trainable_parameters(model: nn.Module):
    r"""Print how many of ``model``'s parameters are trainable.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Parameters shared between submodules are counted once, because
    ``model.parameters()`` de-duplicates them.

    Args:
        model: the module whose parameters are inspected.

    Returns:
        A ``(trainable_params, all_params)`` tuple of element counts, so callers
        can use the numbers without parsing the printed line.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        numel = param.numel()
        all_param += numel
        if param.requires_grad:
            trainable_params += numel
    # A module without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )
    return trainable_params, all_param
27 |
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every parameter of ``model`` except the LoRA ones (and optionally biases).

    Every parameter whose qualified name does not contain ``'lora_'`` gets
    ``requires_grad = False``. Depending on ``bias``, some bias parameters are
    then re-enabled. Finally, the trainable/total parameter counts are printed
    with ``print_trainable_parameters``.

    Args:
        model: the module to modify in place.
        bias: which biases stay trainable:
            ``'none'``      -- no biases;
            ``'all'``       -- every parameter whose name contains ``'bias'``;
            ``'lora_only'`` -- the ``bias`` attribute of ``LoRALayer`` modules
                               that have a non-``None`` bias.

    Raises:
        NotImplementedError: if ``bias`` is not one of the values above. The
            check happens before any parameter is frozen, so ``model`` is left
            unchanged in that case.
    """
    if bias not in ('none', 'all', 'lora_only'):
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")

    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False

    if bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            # A LoRA layer may have no bias (or a bias of None), e.g. when
            # constructed with bias=False.
            if isinstance(m, LoRALayer) and getattr(m, 'bias', None) is not None:
                m.bias.requires_grad = True

    # Report the result for every mode, including the default 'none'.
    print_trainable_parameters(model)
47 |
48 |
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return the subset of ``model.state_dict()`` needed to restore the LoRA weights.

    Args:
        model: the module to extract entries from.
        bias: which bias tensors to include alongside the ``'lora_'`` entries:
            ``'none'``      -- none;
            ``'all'``       -- every entry whose key contains ``'bias'``;
            ``'lora_only'`` -- the ``bias`` entry of each module that owns a
                               ``'lora_'`` entry, when such an entry exists.

    Returns:
        A dict mapping state-dict keys to tensors, in ``state_dict()`` order.

    Raises:
        NotImplementedError: if ``bias`` is not one of the values above.
    """
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: v for k, v in my_state_dict.items() if 'lora_' in k}
    elif bias == 'all':
        return {k: v for k, v in my_state_dict.items() if 'lora_' in k or 'bias' in k}
    elif bias == 'lora_only':
        to_return = {}
        for k, v in my_state_dict.items():
            if 'lora_' not in k:
                continue
            to_return[k] = v
            # Look up the bias on the module that owns this LoRA tensor by
            # replacing the last path component: "layer.w_lora_A" -> "layer.bias".
            # (Splitting on 'lora_' would give "layer.w_bias" for parameters
            # named like "w_lora_A".) Assumes the LoRA tensor is registered
            # directly on the module that owns the bias.
            module_prefix = k.rpartition('.')[0]
            bias_name = f"{module_prefix}.bias" if module_prefix else 'bias'
            if bias_name in my_state_dict:
                to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
66 |
--------------------------------------------------------------------------------
/examples/embedding.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyNh+WHHXHiw1v/WeZj1rcht",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {
33 | "colab": {
34 | "base_uri": "https://localhost:8080/"
35 | },
36 | "id": "330HoZ7N_S3x",
37 | "outputId": "1d7f8b83-b37f-4b55-ad77-aa099b236083"
38 | },
39 | "outputs": [
40 | {
41 | "output_type": "stream",
42 | "name": "stdout",
43 | "text": [
44 | "Collecting git+https://github.com/microsoft/LoRA\n",
45 | " Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-qrvzzdp0\n",
46 | " Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-qrvzzdp0\n",
47 | " Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n",
48 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
49 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
50 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-ro2rj2g8\n",
51 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-ro2rj2g8\n",
52 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 34286e640a4d15dfe23c361f6d95b7cc55a8ec6b\n",
53 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "!pip install git+https://github.com/microsoft/LoRA\n",
59 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "source": [
65 | "import torch, loralib, loratorch, copy\n",
66 | "import torch.nn as nn"
67 | ],
68 | "metadata": {
69 | "id": "cqadI5Fu_WUX"
70 | },
71 | "execution_count": 2,
72 | "outputs": []
73 | },
74 | {
75 | "cell_type": "code",
76 | "source": [
77 | "model_lib = loralib.Embedding(10, 3, r=4)\n",
78 | "\n",
79 | "for k, v in model_lib.named_parameters():\n",
80 | " print(k, v.size())\n",
81 | "\n",
82 | "\n",
83 | "loralib.mark_only_lora_as_trainable(model_lib)\n",
84 | "\n",
85 | "optimizer_lib = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n",
86 | "\n",
87 | "for _ in range(3):\n",
88 | " model_lib.train()\n",
89 | "\n",
90 | " x = torch.rand(5, 10).long()\n",
91 | "\n",
92 | " loss_lib = model_lib(x).sum()\n",
93 | " optimizer_lib.zero_grad()\n",
94 | " loss_lib.backward()\n",
95 | " optimizer_lib.step()\n",
96 | "\n",
97 | " x_test = torch.rand(*x.size()).long()\n",
98 | " model_lib.eval()\n",
99 | " print(model_lib(x_test).size())"
100 | ],
101 | "metadata": {
102 | "colab": {
103 | "base_uri": "https://localhost:8080/"
104 | },
105 | "id": "91GlkwcQD7YG",
106 | "outputId": "e87d2630-4ba5-456e-c8b3-63c25b36a9cb"
107 | },
108 | "execution_count": 3,
109 | "outputs": [
110 | {
111 | "output_type": "stream",
112 | "name": "stdout",
113 | "text": [
114 | "weight torch.Size([10, 3])\n",
115 | "lora_A torch.Size([4, 10])\n",
116 | "lora_B torch.Size([3, 4])\n",
117 | "torch.Size([5, 10, 3])\n",
118 | "torch.Size([5, 10, 3])\n",
119 | "torch.Size([5, 10, 3])\n"
120 | ]
121 | }
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "source": [
127 | "model_torch = loratorch.Embedding(10, 3, r=4)\n",
128 | "\n",
129 | "for k, v in model_torch.named_parameters():\n",
130 | " print(k, v.size())\n",
131 | "\n",
132 | "\n",
133 | "loratorch.mark_only_lora_as_trainable(model_torch)\n",
134 | "\n",
135 | "optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n",
136 | "\n",
137 | "for _ in range(3):\n",
138 | " model_torch.train()\n",
139 | "\n",
140 | " x = torch.rand(5, 10).long()\n",
141 | "\n",
142 | " loss_torch = model_torch(x).sum()\n",
143 | " optimizer_torch.zero_grad()\n",
144 | " loss_torch.backward()\n",
145 | " optimizer_torch.step()\n",
146 | "\n",
147 | " x_test = torch.rand(*x.size()).long()\n",
148 | " model_torch.eval()\n",
149 | " print(model_torch(x_test).size())"
150 | ],
151 | "metadata": {
152 | "colab": {
153 | "base_uri": "https://localhost:8080/"
154 | },
155 | "id": "NsLzmMXSC3RQ",
156 | "outputId": "48c1e55c-1236-42c5-ac87-c3d55e6bcf87"
157 | },
158 | "execution_count": 4,
159 | "outputs": [
160 | {
161 | "output_type": "stream",
162 | "name": "stdout",
163 | "text": [
164 | "weight torch.Size([10, 3])\n",
165 | "w_lora_A torch.Size([4, 3])\n",
166 | "w_lora_B torch.Size([10, 4])\n",
167 | "torch.Size([5, 10, 3])\n",
168 | "torch.Size([5, 10, 3])\n",
169 | "torch.Size([5, 10, 3])\n"
170 | ]
171 | }
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "source": [],
177 | "metadata": {
178 | "id": "q7yt9fq1GrnA"
179 | },
180 | "execution_count": 4,
181 | "outputs": []
182 | }
183 | ]
184 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LoRA-Torch
2 |
3 | [](https://github.com/Baijiong-Lin/LoRA-Torch)
4 |
5 | This codebase reimplements [LoRA: Low-Rank Adaptation of Large Language Models (ICLR 2022)](https://openreview.net/forum?id=nZeVKeeFYf9) and is reconstructed based on [loralib](https://github.com/microsoft/LoRA).
6 |
7 |
8 |
9 | ## Features
10 |
11 | **The implementations of ``loratorch`` and ``loralib`` are very different.** We use ``nn.Linear`` as an example.
12 |
13 | 1. For ``loralib``,
14 | $h = x W_0^\top + \frac{\alpha}{r} x(BA)^\top,$
15 |
16 | where $x\in\mathbb{R}^{k\times n}$ is the input matrix, $W_0\in\mathbb{R}^{m\times n}$ is the pre-trained weight matrix, $r$ is the predefined LoRA rank, $B\in\mathbb{R}^{m\times r}$ and $A\in\mathbb{R}^{r\times n}$ are the LoRA matrices, and $\alpha$ is a hyper-parameter.
17 |
18 | 2. For ``loratorch``,
19 | $h = x (W_0 + \frac{\alpha}{r} BA)^\top.$
20 |
21 |
22 |
23 | ``loralib`` computes $xW_0^\top$ and $x(BA)^\top$ separately and then adds the results, whereas ``loratorch`` first merges the pre-trained weight $W_0$ with its LoRA weight $BA$ and then computes the output by simply calling ``nn.Linear.forward()``. For linear layers, the two approaches are equivalent. For non-linear or more complex layers, however, it is not clear whether $L(x, W_0)+L(x, BA) = L(x, W_0+BA)$ holds, which makes it difficult to extend ``loralib`` to such layers. In contrast, merging the weights first, as ``loratorch`` does, is more general and extensible: you only need to call ``merge_lora_param()`` in ``loratorch`` to merge the weights and then call ``forward()`` of the original layer to compute the output. With ``loratorch``, you can easily apply LoRA to any type of layer in ``torch.nn``.
24 |
25 |
26 |
27 | ## Supported Layers
28 |
29 | | | ``loralib`` | ``loratorch`` | |
30 | | ------------------------- |:--------------:|:--------------:| -------------------------------------------------- |
31 | | ``nn.Linear`` | ✓ | ✓ | [linear.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/linear.ipynb) |
32 | | ``nn.Embedding`` | ✓ | ✓ | [embedding.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/embedding.ipynb) |
33 | | ``nn.Conv1d`` | ✓ | ✓ | |
34 | | ``nn.Conv2d`` | ✓ | ✓ | |
35 | | ``nn.Conv3d`` | ✓ | ✓ | |
36 | | ``nn.MultiheadAttention`` | ✘ | ✓ | [Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb) |
37 | | ``MergedLinear`` | ✓ (Error) | ✓ | [mergedlinear.ipynb](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/mergedlinear.ipynb) |
38 | | $\cdots$ | hard to extend | easy to extend | |
39 |
40 | *We compare the results of ``loralib`` and ``loratorch`` in [examples](./examples) to demonstrate the correctness of the implementation in ``loratorch``.*
41 |
42 |
43 |
44 | ## Quick Start
45 |
46 | :bangbang: We have provided an [example](https://github.com/Baijiong-Lin/LoRA-Torch/blob/main/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb) to demonstrate how to apply LoRA-Torch to ``nn.MultiheadAttention`` in OpenCLIP. We greatly appreciate [@vietvo89](https://github.com/vietvo89)'s valuable contribution.
47 |
48 | **The usage of ``loratorch`` is the same as ``loralib``.**
49 |
50 | 1. Install ``loratorch``.
51 |
52 | ```bash
53 | pip install git+https://github.com/Baijiong-Lin/LoRA-Torch
54 | # Alternatively for developers
55 | # git clone https://github.com/Baijiong-Lin/LoRA-Torch
56 | # cd LoRA-Torch
57 | # pip install -e .
58 | ```
59 |
60 | 2. Replace the layers where you would like to use LoRA by using ``loratorch``.
61 |
62 | ```python
63 | # ===== Before =====
64 | # layer = nn.Linear(in_features, out_features)
65 |
66 | # ===== After ======
67 | import loratorch as lora
68 | # Add a pair of low-rank adaptation matrices with rank r=16 and alpha=32
69 | layer = lora.Linear(in_features, out_features, r=16, lora_alpha=32)
70 | ```
71 |
72 | 3. Mark only LoRA parameters as trainable before the training loop.
73 |
74 | ```python
75 | model = Model()
76 | # (!!!) This sets requires_grad to False for all parameters without the string "lora_" in their names
77 | lora.mark_only_lora_as_trainable(model)
78 |
79 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
80 | # Training loop
81 | for batch in dataloader:
82 | model.train()
83 | # forward process
84 | loss = forward_fun(model, batch)
85 | # backward process
86 | optimizer.zero_grad()
87 | loss.backward()
88 | optimizer.step()
89 | # (!!!) re-register model parameters to ensure they appear in model.state_dict() and model.parameters()
90 | # (!!!) Without this line, training is not affected, but some weights will be missing from model.state_dict() and model.parameters()
91 | lora.register_model_param_after_backward(model)
92 | ```
93 |
94 | 4. Save the LoRA model (only the LoRA matrices are saved).
95 |
96 | ```python
97 | # ===== Before =====
98 | # torch.save(model.state_dict(), checkpoint_path)
99 | # ===== After =====
100 | torch.save(lora.lora_state_dict(model), checkpoint_path)
101 | ```
102 |
103 | 5. Load the LoRA model (the pre-trained model must be loaded first).
104 |
105 | ```python
106 | # Load the pre-trained checkpoint first
107 | model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
108 | # Then load the LoRA checkpoint
109 | model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
110 | ```
111 |
112 | ## Contributor
113 |
114 | ``loratorch`` is developed and maintained by [Baijiong Lin](https://baijiong-lin.github.io).
115 |
116 | ## Contact Us
117 |
118 | If you have any question or suggestion, please feel free to contact us by [raising an issue](https://github.com/Baijiong-Lin/LoRA-Torch/issues) or sending an email to ``bj.lin.email@gmail.com``.
119 |
120 | ## Acknowledgements
121 |
122 | ``loratorch`` is heavily based on ``loralib``. We thank its authors for their wonderful and open-source codebase.
123 |
124 | ## Citation
125 |
126 | If you find ``loratorch`` useful for your research or development, please cite the following:
127 |
128 | ```BibTeX
129 | @inproceedings{hu2022lora,
130 | title={Lo{RA}: Low-Rank Adaptation of Large Language Models},
131 | author={Edward J Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen},
132 | booktitle={International Conference on Learning Representations},
133 | year={2022},
134 | }
135 |
136 | @software{lin2023loratorch,
137 | author = {Baijiong Lin},
138 | title = {{LoRA-Torch}: {PyTorch} Reimplementation of {LoRA}},
139 | url = {https://github.com/Baijiong-Lin/LoRA-Torch},
140 | year = {2023}
141 | }
142 | ```
143 |
144 | ## License
145 |
146 | ``loratorch`` is released under the [MIT](./LICENSE) license.
147 |
148 |
149 |
--------------------------------------------------------------------------------
/examples/linear.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyNIy6uRivivL3VYs0eNmj01",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {
33 | "colab": {
34 | "base_uri": "https://localhost:8080/"
35 | },
36 | "id": "-j81b-ktkWH0",
37 | "outputId": "ca0a1e95-ae6c-44ac-cee0-5bf75f8fd3b1"
38 | },
39 | "outputs": [
40 | {
41 | "output_type": "stream",
42 | "name": "stdout",
43 | "text": [
44 | "Collecting git+https://github.com/microsoft/LoRA\n",
45 | " Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-e5ulip2v\n",
46 | " Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-e5ulip2v\n",
47 | " Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n",
48 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
49 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
50 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-l221wgno\n",
51 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-l221wgno\n",
52 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 4b550558ff1c7ef6bd1009b80d30e52069396f69\n",
53 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "!pip install git+https://github.com/microsoft/LoRA\n",
59 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "source": [
65 | "import torch, loralib, loratorch, copy\n",
66 | "import torch.nn as nn"
67 | ],
68 | "metadata": {
69 | "id": "tDo6S6QPkiXI"
70 | },
71 | "execution_count": 2,
72 | "outputs": []
73 | },
74 | {
75 | "cell_type": "code",
76 | "source": [
77 | "model_lib = loralib.Linear(5, 6, r=4, lora_alpha=1)\n",
78 | "model_torch = loratorch.Linear(5, 6, r=4, lora_alpha=1)\n",
79 | "\n",
80 | "model_lib.weight.data = copy.deepcopy(model_torch.weight.data)\n",
81 | "model_lib.lora_A.data = copy.deepcopy(model_torch.w_lora_A.data)\n",
82 | "model_lib.lora_B.data = copy.deepcopy(model_torch.w_lora_B.data)\n",
83 | "model_lib.bias.data = copy.deepcopy(model_torch.bias.data)"
84 | ],
85 | "metadata": {
86 | "id": "_2ceOeiIk8xp"
87 | },
88 | "execution_count": 3,
89 | "outputs": []
90 | },
91 | {
92 | "cell_type": "code",
93 | "source": [
94 | "loralib.mark_only_lora_as_trainable(model_lib)\n",
95 | "loratorch.mark_only_lora_as_trainable(model_torch)"
96 | ],
97 | "metadata": {
98 | "id": "CF_PCLIslcJl"
99 | },
100 | "execution_count": 4,
101 | "outputs": []
102 | },
103 | {
104 | "cell_type": "code",
105 | "source": [
106 | "optimizer_lib = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n",
107 | "optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n",
108 | "\n",
109 | "for _ in range(3):\n",
110 | " model_lib.train()\n",
111 | " model_torch.train()\n",
112 | " x = torch.rand(2, 5)\n",
113 | "\n",
114 | " loss1 = model_lib(x).sum()\n",
115 | " optimizer_lib.zero_grad()\n",
116 | " loss1.backward()\n",
117 | " optimizer_lib.step()\n",
118 | "\n",
119 | " loss2 = model_torch(x).sum()\n",
120 | " optimizer_torch.zero_grad()\n",
121 | " loss2.backward()\n",
122 | " optimizer_torch.step()\n",
123 | "\n",
124 | " # for k, v in model_lib.named_parameters():\n",
125 | " # print(k, v.grad)\n",
126 | "\n",
127 | " # for k, v in model_torch.named_parameters():\n",
128 | " # print(k, v.grad)\n",
129 | " print(torch.isclose(model_lib.lora_A.grad, model_torch.w_lora_A.grad))\n",
130 | " print(torch.isclose(model_lib.lora_B.grad, model_torch.w_lora_B.grad))\n",
131 | "\n",
132 | " x_test = torch.rand(3, 5)\n",
133 | " model_lib.eval()\n",
134 | " model_torch.eval()\n",
135 | " print(torch.isclose(model_lib(x_test), model_torch(x_test)))"
136 | ],
137 | "metadata": {
138 | "colab": {
139 | "base_uri": "https://localhost:8080/"
140 | },
141 | "id": "GdmSBAenkvP1",
142 | "outputId": "7e186013-f7a2-4fc6-90c3-854f1f1fc2fb"
143 | },
144 | "execution_count": 5,
145 | "outputs": [
146 | {
147 | "output_type": "stream",
148 | "name": "stdout",
149 | "text": [
150 | "tensor([[True, True, True, True, True],\n",
151 | " [True, True, True, True, True],\n",
152 | " [True, True, True, True, True],\n",
153 | " [True, True, True, True, True]])\n",
154 | "tensor([[True, True, True, True],\n",
155 | " [True, True, True, True],\n",
156 | " [True, True, True, True],\n",
157 | " [True, True, True, True],\n",
158 | " [True, True, True, True],\n",
159 | " [True, True, True, True]])\n",
160 | "tensor([[True, True, True, True, True, True],\n",
161 | " [True, True, True, True, True, True],\n",
162 | " [True, True, True, True, True, True]])\n",
163 | "tensor([[True, True, True, True, True],\n",
164 | " [True, True, True, True, True],\n",
165 | " [True, True, True, True, True],\n",
166 | " [True, True, True, True, True]])\n",
167 | "tensor([[True, True, True, True],\n",
168 | " [True, True, True, True],\n",
169 | " [True, True, True, True],\n",
170 | " [True, True, True, True],\n",
171 | " [True, True, True, True],\n",
172 | " [True, True, True, True]])\n",
173 | "tensor([[True, True, True, True, True, True],\n",
174 | " [True, True, True, True, True, True],\n",
175 | " [True, True, True, True, True, True]])\n",
176 | "tensor([[True, True, True, True, True],\n",
177 | " [True, True, True, True, True],\n",
178 | " [True, True, True, True, True],\n",
179 | " [True, True, True, True, True]])\n",
180 | "tensor([[True, True, True, True],\n",
181 | " [True, True, True, True],\n",
182 | " [True, True, True, True],\n",
183 | " [True, True, True, True],\n",
184 | " [True, True, True, True],\n",
185 | " [True, True, True, True]])\n",
186 | "tensor([[True, True, True, True, True, True],\n",
187 | " [True, True, True, True, True, True],\n",
188 | " [True, True, True, True, True, True]])\n"
189 | ]
190 | }
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "source": [],
196 | "metadata": {
197 | "id": "Ul2JBtuQmIuG"
198 | },
199 | "execution_count": 5,
200 | "outputs": []
201 | }
202 | ]
203 | }
--------------------------------------------------------------------------------
/examples/mergedlinear.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "authorship_tag": "ABX9TyOaHR81EoRiuDjhTvY9PZ+l",
8 | "include_colab_link": true
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {
33 | "colab": {
34 | "base_uri": "https://localhost:8080/"
35 | },
36 | "id": "3YYQD-Le8iaB",
37 | "outputId": "1d5241cf-ce37-4abd-e282-170cba75f654"
38 | },
39 | "outputs": [
40 | {
41 | "output_type": "stream",
42 | "name": "stdout",
43 | "text": [
44 | "Collecting git+https://github.com/microsoft/LoRA\n",
45 | " Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-b6kloy6q\n",
46 | " Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-b6kloy6q\n",
47 | " Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n",
48 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
49 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
50 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-pdln00a3\n",
51 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-pdln00a3\n",
52 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit f2dccb93ac60313d15251fee12e8b7be5649471a\n",
53 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "!pip install git+https://github.com/microsoft/LoRA\n",
59 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "source": [
65 | "import torch, loralib, loratorch, copy\n",
66 | "import torch.nn as nn"
67 | ],
68 | "metadata": {
69 | "id": "t-KzlZJT81XD"
70 | },
71 | "execution_count": 2,
72 | "outputs": []
73 | },
74 | {
75 | "cell_type": "code",
76 | "source": [
77 | "model_lib = loralib.MergedLinear(10, 3*10, bias=False, r=4, enable_lora=[True, False, True], fan_in_fan_out=False)\n",
78 | "loralib.mark_only_lora_as_trainable(model_lib)\n",
79 | "optimizer = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n",
80 | "loss_fn = nn.MSELoss()\n",
81 | "\n",
82 | "for _ in range(3):\n",
83 | " model_lib.train()\n",
84 | " x = torch.rand(5, 10)\n",
85 | " y = torch.rand(5, 3*10)\n",
86 | " loss = loss_fn(model_lib(x), y)\n",
87 | " optimizer.zero_grad()\n",
88 | " loss.backward()\n",
89 | " optimizer.step()\n",
90 | "\n",
91 | " model_lib.eval()\n",
92 | " print(model_lib(x).size())"
93 | ],
94 | "metadata": {
95 | "colab": {
96 | "base_uri": "https://localhost:8080/",
97 | "height": 366
98 | },
99 | "id": "LAyOBzAm82dD",
100 | "outputId": "bf2541a6-3632-41fe-cb18-e5adc63eb2fb"
101 | },
102 | "execution_count": 3,
103 | "outputs": [
104 | {
105 | "output_type": "error",
106 | "ename": "RuntimeError",
107 | "evalue": "ignored",
108 | "traceback": [
109 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
110 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
111 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mmodel_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_lib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
112 | "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36meval\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 2305\u001b[0m \u001b[0mModule\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2306\u001b[0m \"\"\"\n\u001b[0;32m-> 2307\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2309\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrequires_grad_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
113 | "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/loralib/layers.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0mgroups\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m ).squeeze(0)\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_pad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdelta_w\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscaling\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmerged\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
114 | "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/loralib/layers.py\u001b[0m in \u001b[0;36mzero_pad\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_zeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m result[:, self.lora_ind] = x.reshape(\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m )\n",
115 | "\u001b[0;31mRuntimeError\u001b[0m: shape mismatch: value tensor of shape [10, 20] cannot be broadcast to indexing result of shape [20, 20]"
116 | ]
117 | }
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "source": [
123 | "model_torch = loratorch.MergedLinear(10, 3*10, bias=False, r=4, enable_lora=[True, False, True], fan_in_fan_out=False)\n",
124 | "loralib.mark_only_lora_as_trainable(model_torch)\n",
125 | "optimizer = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n",
126 | "loss_fn = nn.MSELoss()\n",
127 | "\n",
128 | "for _ in range(3):\n",
129 | " model_torch.train()\n",
130 | " x = torch.rand(5, 10)\n",
131 | " y = torch.rand(5, 3*10)\n",
132 | " loss = loss_fn(model_torch(x), y)\n",
133 | " optimizer.zero_grad()\n",
134 | " loss.backward()\n",
135 | " optimizer.step()\n",
136 | "\n",
137 | " model_torch.eval()\n",
138 | " print(model_torch(x).size())"
139 | ],
140 | "metadata": {
141 | "colab": {
142 | "base_uri": "https://localhost:8080/"
143 | },
144 | "id": "GRPPQ-Wk9evZ",
145 | "outputId": "92684b76-3dc3-47a5-8f8a-a38508850053"
146 | },
147 | "execution_count": 4,
148 | "outputs": [
149 | {
150 | "output_type": "stream",
151 | "name": "stdout",
152 | "text": [
153 | "torch.Size([5, 30])\n",
154 | "torch.Size([5, 30])\n",
155 | "torch.Size([5, 30])\n"
156 | ]
157 | }
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "source": [],
163 | "metadata": {
164 | "id": "jB55l3PU9veg"
165 | },
166 | "execution_count": null,
167 | "outputs": []
168 | }
169 | ]
170 | }
--------------------------------------------------------------------------------
/loratorch/layers.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------
2 | # This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin.
3 | # ------------------------------------------------------------------------------------------
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import math
9 | from typing import Optional, List
10 |
def set_param(curr_mod, name, param=None, mode='update', with_nn=False):
    r"""Get or replace the attribute addressed by a dotted path on a module.

    Refer to https://github.com/Baijiong-Lin/MOML/blob/main/MTL/utils.py

    Args:
        curr_mod: root module the lookup starts from.
        name: attribute path such as ``'weight'`` or ``'out_proj.weight'``;
            every component except the last must name a direct child module.
        param: replacement value, only used when ``mode == 'update'``.
        mode: ``'update'`` deletes the attribute and sets ``param`` in its place;
            ``'get'`` returns the attribute, or ``None`` if it does not exist.
        with_nn: in ``'update'`` mode, wrap ``param`` in ``nn.Parameter`` first.

    Returns:
        The attribute for ``mode='get'``; ``None`` in every other case,
        including when an intermediate child module cannot be found.
    """
    *path, leaf = name.split('.')
    owner = curr_mod
    for child_name in path:
        # Only direct submodules are searched at each level.
        owner = dict(owner.named_children()).get(child_name)
        if owner is None:
            return None

    if mode == 'update':
        delattr(owner, leaf)
        setattr(owner, leaf, nn.Parameter(param) if with_nn else param)
    elif mode == 'get' and hasattr(owner, leaf):
        return getattr(owner, leaf)
    return None
31 |
class LoRALayer():
    r"""Mixin that adds LoRA (low-rank adaptation) bookkeeping to an ``nn.Module``.

    Subclasses fill ``self.params_with_lora`` with ``{param_name: lora_name}``
    entries (``param_name`` may be a dotted path such as ``'out_proj.weight'``).
    For each entry, trainable matrices ``{lora_name}_lora_A`` and
    ``{lora_name}_lora_B`` live on the module, and the effective parameter is
    ``param + scaling * B @ A`` with ``scaling = lora_alpha / r``.
    """
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        fan_in_fan_out: bool = False,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        if self.r > 0:
            # Only defined when LoRA is active (r > 0).
            self.scaling = self.lora_alpha / self.r
        # Mark the weight as unmerged
        self.merged = False
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        self.fan_in_fan_out = fan_in_fan_out
        # define params that require LoRA {'param_name': 'lora_name'}
        self.params_with_lora = {}

    def _lora_get(self, name: str):
        r"""Return the attribute at the (possibly dotted) path ``name``.

        Plain ``getattr`` traversal instead of ``eval``; a missing component
        raises ``AttributeError``.
        """
        obj = self
        for part in name.split('.'):
            obj = getattr(obj, part)
        return obj

    def register_weight_after_backward(self):
        r"""Re-wrap parameters replaced by plain tensors in ``merge_lora_param`` as ``nn.Parameter``.

        Entries that are still ``nn.Parameter`` objects (never merged, e.g. when
        ``r == 0``) are left untouched, so optimizer references to them stay valid.
        Re-wrapped parameters are kept frozen, as ``register_lora_param`` left them.
        """
        for param_name in self.params_with_lora:
            p = set_param(self, param_name, mode='get')
            if isinstance(p, nn.Parameter):
                continue
            set_param(self, param_name, param=p, mode='update', with_nn=True)
            # nn.Parameter defaults to requires_grad=True; the base weight is frozen.
            set_param(self, param_name, mode='get').requires_grad_(False)

    def register_lora_param(self):
        r"""Register zero-initialised LoRA matrices for every 2-D entry of ``params_with_lora``.

        For a parameter of shape (rows, cols), A is (r, cols) and B is (rows, r),
        so ``B @ A`` has the parameter's shape. The original parameter is frozen.
        """
        for param_name, lora_name in self.params_with_lora.items():
            p = self._lora_get(param_name)
            assert p.dim() == 2
            rows, cols = p.size()
            self.register_parameter(f'{lora_name}_lora_A',
                nn.Parameter(p.new_zeros((self.r, cols)))
            )
            self.register_parameter(f'{lora_name}_lora_B',
                nn.Parameter(p.new_zeros((rows, self.r)))
            )
            # Only A and B are trained.
            p.requires_grad = False

    def init_lora_param(self):
        r"""Initialise A the same way as the default for nn.Linear and B to zero."""
        for lora_name in self.params_with_lora.values():
            if hasattr(self, f'{lora_name}_lora_A'):
                nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
                nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

    def transpose(self, w: torch.Tensor):
        r"""Swap the first two dims of ``w`` when ``fan_in_fan_out`` is set."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_BA(self, param_name: str):
        r"""Return ``B @ A`` for ``param_name`` in the layout and shape of the stored parameter."""
        lora_name = self.params_with_lora[param_name]
        delta = getattr(self, f'{lora_name}_lora_B') @ getattr(self, f'{lora_name}_lora_A')
        # Transpose *before* reshaping. With fan_in_fan_out the parameter is stored as
        # (fan_in, fan_out); reshaping the (fan_out, fan_in) product into that shape
        # first would scramble the entries of non-square weights. When the stored
        # parameter has more dims (conv kernels), the 2-D product is reshaped to it.
        return self.transpose(delta).reshape(self._lora_get(param_name).shape)

    def merge_lora_param(self):
        r"""p_new = p + scaling * B @ A and keep differentiable to A and B"""
        for param_name in self.params_with_lora:
            p = set_param(self, param_name, mode='get')
            # detach() is very important here: gradients must reach A and B only,
            # never the frozen base parameter.
            p_new = p.detach() + self.merge_BA(param_name) * self.scaling
            set_param(self, param_name, param=p_new, mode='update')

    def add_lora_data(self):
        r"""Add ``scaling * B @ A`` to the parameters' data in place. NOT differentiable."""
        for param_name in self.params_with_lora:
            self._lora_get(param_name).data += self.merge_BA(param_name) * self.scaling

    def sub_lora_data(self):
        r"""Subtract ``scaling * B @ A`` from the parameters' data in place. NOT differentiable."""
        for param_name in self.params_with_lora:
            self._lora_get(param_name).data -= self.merge_BA(param_name) * self.scaling

    def lora_train(self, mode: bool = True):
        r"""Unmerge LoRA from the data when training, merge it in for evaluation."""
        if mode:
            if self.merged and self.r > 0:
                # Make sure that the weights are not merged
                self.sub_lora_data()
                self.merged = False
        else:
            if not self.merged and self.r > 0:
                # Merge the weights and mark it
                self.add_lora_data()
                self.merged = True
111 |
112 |
class Embedding(nn.Embedding, LoRALayer):
    r"""LoRA implemented in an Embedding layer.

    Args:
        num_embeddings, embedding_dim, **kwargs: forwarded to ``nn.Embedding``.
        r: LoRA rank; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.
    """
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Embedding.reset_parameters(self)
        self.init_lora_param()

    def init_lora_param(self):
        r"""Zero A and draw B from a standard normal, so the initial ``B @ A`` is zero."""
        if hasattr(self, 'w_lora_A'):
            # Mirrors loralib's Embedding: the roles are swapped relative to the
            # default LoRALayer init (A zero, B random rather than the reverse).
            nn.init.zeros_(self.w_lora_A)
            nn.init.normal_(self.w_lora_B)

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.Embedding.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA update into weight, then
            # subtract it from the weight's data again after the lookup.
            self.merge_lora_param()
            result = nn.Embedding.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.Embedding.forward(self, x, **kwargs)
151 |
class Linear(nn.Linear, LoRALayer):
    r"""LoRA implemented in a Linear layer.

    Args:
        in_features, out_features, **kwargs: forwarded to ``nn.Linear``.
        r: LoRA rank; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.
        fan_in_fan_out: store ``weight`` as (in_features, out_features); it is
            transposed back to (out_features, in_features) in ``forward``.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)

        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        # LoRA matrices were sized from the (out, in) weight above; only the base
        # weight switches to (in, out) storage when fan_in_fan_out is set.
        self.weight.data = self.transpose(self.weight.data)

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.Linear.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def _linear(self, x: torch.Tensor, **kwargs):
        r"""Apply the affine map, undoing the (in, out) storage when ``fan_in_fan_out`` is set."""
        if self.fan_in_fan_out:
            # nn.Linear.forward expects an (out, in) weight, so pass the transposed view.
            return F.linear(x, self.transpose(self.weight), self.bias, **kwargs)
        return nn.Linear.forward(self, x, **kwargs)

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA update into weight, then
            # subtract it from the weight's data again after the matmul.
            self.merge_lora_param()
            result = self._linear(x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return self._linear(x, **kwargs)
187 |
class Conv1d(nn.Conv1d, LoRALayer):
    r"""LoRA implemented in a Conv1d layer.

    Args:
        in_channels, out_channels, kernel_size, **kwargs: forwarded to ``nn.Conv1d``;
            ``kernel_size`` must be a plain ``int``.
        r: LoRA rank; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.

    ``w_lora_B @ w_lora_A`` has shape (out_channels // groups, in_channels * kernel_size)
    and is reshaped to ``weight.shape`` when merged.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels//self.groups, r*kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv1d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.Conv1d.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA update into weight, then
            # subtract it from the weight's data again after the convolution.
            self.merge_lora_param()
            result = nn.Conv1d.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.Conv1d.forward(self, x, **kwargs)
230 |
class Conv2d(nn.Conv2d, LoRALayer):
    r"""LoRA implemented in a Conv2d layer.

    Args:
        in_channels, out_channels, kernel_size, **kwargs: forwarded to ``nn.Conv2d``;
            ``kernel_size`` must be a plain ``int``.
        r: LoRA rank; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.

    ``w_lora_B @ w_lora_A`` has shape
    (out_channels // groups * kernel_size, in_channels * kernel_size)
    and is reshaped to ``weight.shape`` when merged.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv2d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.Conv2d.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA update into weight, then
            # subtract it from the weight's data again after the convolution.
            self.merge_lora_param()
            result = nn.Conv2d.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.Conv2d.forward(self, x, **kwargs)
273 |
class Conv3d(nn.Conv3d, LoRALayer):
    r"""LoRA implemented in a Conv3d layer.

    Args:
        in_channels, out_channels, kernel_size, **kwargs: forwarded to ``nn.Conv3d``;
            ``kernel_size`` must be a plain ``int``.
        r: LoRA rank; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.

    ``w_lora_B @ w_lora_A`` has shape
    (out_channels // groups * kernel_size, in_channels * kernel_size * kernel_size)
    and is reshaped to ``weight.shape`` when merged.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size*kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
            )
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        nn.Conv3d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.Conv3d.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA update into weight, then
            # subtract it from the weight's data again after the convolution.
            self.merge_lora_param()
            result = nn.Conv3d.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.Conv3d.forward(self, x, **kwargs)
316 |
class MultiheadAttention(nn.MultiheadAttention, LoRALayer):
    r"""LoRA implemented in a MultiheadAttention layer.

    Args:
        embed_dim, num_heads, **kwargs: forwarded to ``nn.MultiheadAttention``.
        enable_lora: projections that receive a LoRA update; any subset of
            ``'q'``, ``'k'``, ``'v'`` (input projections) and ``'o'`` (``out_proj``).
        r: LoRA rank; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.

    When q/k/v share the packed ``in_proj_weight`` (``_qkv_same_embed_dim``), one
    LoRA pair named after the enabled projections (e.g. ``qv_lora_A``/``qv_lora_B``)
    covers the packed matrix, and ``merge_BA`` zeroes the row blocks of disabled
    projections so they never receive an update.
    """
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        enable_lora: list = ['q', 'k', 'v', 'o'],
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.MultiheadAttention.__init__(self, embed_dim, num_heads, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        if self.r <= 0:
            return

        # Actual trainable parameters
        if 'o' in enable_lora:
            self.params_with_lora.update({'out_proj.weight': 'o'})

        if not self._qkv_same_embed_dim:
            for n in ['q', 'k', 'v']:
                if n in enable_lora:
                    self.params_with_lora.update({f'{n}_proj_weight': n})
            self.register_lora_param()
            nn.MultiheadAttention._reset_parameters(self)
            self.init_lora_param()
        else:
            enable_lora_bool = [n in enable_lora for n in ['q', 'k', 'v']]
            # Per-projection row mask for the packed in_proj_weight, read by merge_BA.
            self.enable_lora_qkv = enable_lora_bool
            if any(enable_lora_bool):
                lora_name = ''.join(n for n, on in zip('qkv', enable_lora_bool) if on)
                self.params_with_lora.update({'in_proj_weight': lora_name})
            self.register_lora_param()
            nn.MultiheadAttention._reset_parameters(self)
            if 'o' in enable_lora:
                self.init_lora_param_o()
            if any(enable_lora_bool):
                self.init_lora_param_qkv(enable_lora_bool)

    def init_lora_param_o(self):
        r"""Kaiming-uniform init for ``o_lora_A`` and zeros for ``o_lora_B`` (if registered)."""
        lora_name = 'o'
        if hasattr(self, f'{lora_name}_lora_A'):
            nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
            nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

    def init_lora_param_qkv(self, enable_lora_bool):
        r"""Initialise the packed in_proj LoRA pair: Kaiming-uniform A, zero B.

        ``enable_lora_bool`` holds one flag per q/k/v and is recorded as the row
        mask applied in ``merge_BA``. q/k/v occupy row blocks of in_proj_weight
        (i.e. rows of B), so restricting the update cannot be done by slicing A's
        columns; A is therefore initialised over its full width, exactly as
        ``init_lora_param`` does in the separate ``{q,k,v}_proj_weight`` case.
        """
        self.enable_lora_qkv = list(enable_lora_bool)
        lora_name = self.params_with_lora['in_proj_weight']
        nn.init.kaiming_uniform_(getattr(self, f'{lora_name}_lora_A'), a=math.sqrt(5))
        nn.init.zeros_(getattr(self, f'{lora_name}_lora_B'))

    def merge_BA(self, param_name: str):
        r"""``LoRALayer.merge_BA`` with disabled q/k/v rows of ``in_proj_weight`` masked out."""
        delta = LoRALayer.merge_BA(self, param_name)
        if param_name == 'in_proj_weight' and not all(self.enable_lora_qkv):
            # in_proj_weight stacks the q, k and v projections along dim 0 in three
            # equal row blocks; zero the blocks whose LoRA is disabled (this also
            # keeps their gradients at zero).
            rows_per_proj = delta.size(0) // 3
            row_mask = torch.tensor(self.enable_lora_qkv, dtype=delta.dtype, device=delta.device)
            delta = delta * row_mask.repeat_interleave(rows_per_proj).unsqueeze(1)
        return delta

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.MultiheadAttention.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA updates into the projection
            # weights, then subtract them from the weights' data afterwards.
            self.merge_lora_param()
            result = nn.MultiheadAttention.forward(self, query, key, value, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.MultiheadAttention.forward(self, query, key, value, **kwargs)
391 |
class MergedLinear(nn.Linear, LoRALayer):
    r"""LoRA implemented in a dense layer whose output is split into groups.

    The ``out_features`` outputs are split into ``len(enable_lora)`` equal groups,
    and every group flagged ``True`` gets its own rank-``r`` LoRA update
    (e.g. a packed qkv projection with ``enable_lora=[True, False, True]``).

    Args:
        in_features, out_features, **kwargs: forwarded to ``nn.Linear``.
        r: LoRA rank per enabled group; ``r == 0`` disables LoRA.
        lora_alpha: the update is scaled by ``lora_alpha / r``.
        enable_lora: one flag per output group; its length must divide ``out_features``.
        fan_in_fan_out: store ``weight`` as (in_features, out_features); it is
            transposed back to (out_features, in_features) in ``forward``.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        # fan_in_fan_out must reach LoRALayer; otherwise self.transpose() is a
        # no-op and the flag is silently ignored.
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)

        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0 and any(enable_lora):
            n_enabled = sum(enable_lora)
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * n_enabled, in_features)))
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * n_enabled, r))
            ) # weights for Conv1D with groups=sum(enable_lora)
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over output rows: True for rows of enabled groups
            lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            lora_ind[enable_lora, :] = True
            self.lora_ind = lora_ind.view(-1)
        elif r > 0:
            # No group is enabled, so no LoRA matrices exist: there is nothing to
            # merge (forward()/eval() would otherwise look up a missing w_lora_A).
            self.params_with_lora = {}
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        # LoRA matrices are sized from in/out_features, so only the base weight
        # switches to (in, out) storage when fan_in_fan_out is set.
        self.weight.data = self.transpose(self.weight.data)

    def zero_pad(self, x):
        r"""Scatter the rows of ``x`` (enabled groups only) into a zero (out_features, ...) tensor."""
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_BA(self, param_name: str):
        r"""Return the zero-padded grouped update ``B @ A`` in the stored weight layout."""
        lora_name = self.params_with_lora[param_name]
        # A grouped 1x1 conv computes B_g @ A_g for every enabled group g at once,
        # giving (rows of enabled groups, in_features).
        delta_w = F.conv1d(
            getattr(self, f'{lora_name}_lora_A').unsqueeze(0),
            getattr(self, f'{lora_name}_lora_B').unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return self.transpose(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        r"""Switch mode, unmerging LoRA for training and merging it for eval; returns ``self``."""
        nn.Linear.train(self, mode)
        self.lora_train(mode)
        # nn.Module.train() (and therefore .eval()) is documented to return self.
        return self

    def _linear(self, x: torch.Tensor, **kwargs):
        r"""Apply the affine map, undoing the (in, out) storage when ``fan_in_fan_out`` is set."""
        if self.fan_in_fan_out:
            # nn.Linear.forward expects an (out, in) weight, so pass the transposed view.
            return F.linear(x, self.transpose(self.weight), self.bias, **kwargs)
        return nn.Linear.forward(self, x, **kwargs)

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Temporarily fold the differentiable LoRA update into weight, then
            # subtract it from the weight's data again after the matmul.
            self.merge_lora_param()
            result = self._linear(x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return self._linear(x, **kwargs)
457 |
--------------------------------------------------------------------------------
/examples/Finetune_open_clip_with_LoRA_Torch_on_CIFAR10.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "metadata": {
3 | "kernelspec": {
4 | "name": "python3",
5 | "display_name": "Python 3",
6 | "language": "python"
7 | },
8 | "language_info": {
9 | "name": "python",
10 | "version": "3.11.11",
11 | "mimetype": "text/x-python",
12 | "codemirror_mode": {
13 | "name": "ipython",
14 | "version": 3
15 | },
16 | "pygments_lexer": "ipython3",
17 | "nbconvert_exporter": "python",
18 | "file_extension": ".py"
19 | },
20 | "colab": {
21 | "provenance": [],
22 | "gpuType": "T4",
23 | "include_colab_link": true
24 | },
25 | "accelerator": "GPU",
26 | "kaggle": {
27 | "accelerator": "nvidiaTeslaT4",
28 | "dataSources": [],
29 | "dockerImageVersionId": 31041,
30 | "isInternetEnabled": true,
31 | "language": "python",
32 | "sourceType": "notebook",
33 | "isGpuEnabled": true
34 | }
35 | },
36 | "nbformat_minor": 0,
37 | "nbformat": 4,
38 | "cells": [
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "id": "view-in-github",
43 | "colab_type": "text"
44 | },
45 | "source": [
46 | " "
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "source": [
52 | "### This example demonstrates how to apply LoRA-Torch to ``nn.MultiheadAttention`` in OpenCLIP. We greatly appreciate [Viet Q. Vo](https://vietvo89.github.io/)'s valuable contribution."
53 | ],
54 | "metadata": {
55 | "id": "Yhxyu7vivWl2"
56 | }
57 | },
58 | {
59 | "cell_type": "code",
60 | "source": [
61 | "!pip install open-clip-torch\n",
62 | "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
63 | ],
64 | "metadata": {
65 | "id": "753R63XXhzqE",
66 | "outputId": "ad3892df-aeb6-4ebe-b7ae-2400535ec3ab",
67 | "trusted": true,
68 | "execution": {
69 | "iopub.status.busy": "2025-06-09T01:09:26.555991Z",
70 | "iopub.execute_input": "2025-06-09T01:09:26.556177Z",
71 | "iopub.status.idle": "2025-06-09T01:10:48.370674Z",
72 | "shell.execute_reply.started": "2025-06-09T01:09:26.556160Z",
73 | "shell.execute_reply": "2025-06-09T01:10:48.369896Z"
74 | },
75 | "scrolled": true,
76 | "colab": {
77 | "base_uri": "https://localhost:8080/"
78 | }
79 | },
80 | "outputs": [
81 | {
82 | "output_type": "stream",
83 | "name": "stdout",
84 | "text": [
85 | "Requirement already satisfied: open-clip-torch in /usr/local/lib/python3.11/dist-packages (2.32.0)\n",
86 | "Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (2.6.0+cu124)\n",
87 | "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.21.0+cu124)\n",
88 | "Requirement already satisfied: regex in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (2024.11.6)\n",
89 | "Requirement already satisfied: ftfy in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (6.3.1)\n",
90 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (4.67.1)\n",
91 | "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.32.4)\n",
92 | "Requirement already satisfied: safetensors in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (0.5.3)\n",
93 | "Requirement already satisfied: timm in /usr/local/lib/python3.11/dist-packages (from open-clip-torch) (1.0.15)\n",
94 | "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.18.0)\n",
95 | "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (4.14.0)\n",
96 | "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.5)\n",
97 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.1.6)\n",
98 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (2025.3.2)\n",
99 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
100 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
101 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
102 | "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (9.1.0.70)\n",
103 | "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.5.8)\n",
104 | "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (11.2.1.3)\n",
105 | "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (10.3.5.147)\n",
106 | "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (11.6.1.9)\n",
107 | "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.3.1.170)\n",
108 | "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (0.6.2)\n",
109 | "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (2.21.5)\n",
110 | "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
111 | "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (12.4.127)\n",
112 | "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (3.2.0)\n",
113 | "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.9.0->open-clip-torch) (1.13.1)\n",
114 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=1.9.0->open-clip-torch) (1.3.0)\n",
115 | "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from ftfy->open-clip-torch) (0.2.13)\n",
116 | "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (24.2)\n",
117 | "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (6.0.2)\n",
118 | "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (2.32.3)\n",
119 | "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub->open-clip-torch) (1.1.2)\n",
120 | "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision->open-clip-torch) (2.0.2)\n",
121 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision->open-clip-torch) (11.2.1)\n",
122 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.9.0->open-clip-torch) (3.0.2)\n",
123 | "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (3.4.2)\n",
124 | "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (3.10)\n",
125 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (2.4.0)\n",
126 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub->open-clip-torch) (2025.4.26)\n",
127 | "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
128 | " Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-1a5zm6vx\n",
129 | " Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-1a5zm6vx\n",
130 | " Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 3b6f10a3bdebfb0da1abeb4c265f914ed06759e4\n",
131 | " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
132 | ]
133 | }
134 | ],
135 | "execution_count": 1
136 | },
137 | {
138 | "cell_type": "code",
139 | "source": [
140 | "import torch\n",
141 | "import torch.nn as nn\n",
142 | "from torchvision import transforms\n",
143 | "from torch.utils.data import DataLoader\n",
144 | "import open_clip\n",
145 | "import loratorch as lora\n",
146 | "from tqdm import tqdm"
147 | ],
148 | "metadata": {
149 | "id": "T9hDV9jQiU0L",
150 | "trusted": true,
151 | "execution": {
152 | "iopub.status.busy": "2025-06-09T01:10:48.376803Z",
153 | "iopub.execute_input": "2025-06-09T01:10:48.376979Z",
154 | "iopub.status.idle": "2025-06-09T01:11:15.903375Z",
155 | "shell.execute_reply.started": "2025-06-09T01:10:48.376955Z",
156 | "shell.execute_reply": "2025-06-09T01:11:15.902530Z"
157 | }
158 | },
159 | "outputs": [],
160 | "execution_count": 2
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "source": [
165 | "### A. Load Pre-trained Model"
166 | ],
167 | "metadata": {
168 | "id": "kDBpUYgYjrGw"
169 | }
170 | },
171 | {
172 | "cell_type": "code",
173 | "source": [
174 | "model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')\n",
175 | "model = model.cuda()\n",
176 | "tokenizer = open_clip.get_tokenizer('ViT-B-32')"
177 | ],
178 | "metadata": {
179 | "trusted": true,
180 | "execution": {
181 | "iopub.status.busy": "2025-06-09T01:11:15.904141Z",
182 | "iopub.execute_input": "2025-06-09T01:11:15.904827Z",
183 | "iopub.status.idle": "2025-06-09T01:11:20.771415Z",
184 | "shell.execute_reply.started": "2025-06-09T01:11:15.904798Z",
185 | "shell.execute_reply": "2025-06-09T01:11:20.770632Z"
186 | },
187 | "id": "1bqndlUex2HS",
188 | "colab": {
189 | "base_uri": "https://localhost:8080/"
190 | },
191 | "outputId": "9bf32ef6-6a5f-45c0-a8ea-64fb9381a6dc"
192 | },
193 | "outputs": [
194 | {
195 | "output_type": "stream",
196 | "name": "stderr",
197 | "text": [
198 | "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
199 | "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
200 | "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
201 | "You will be able to reuse this secret in all of your notebooks.\n",
202 | "Please note that authentication is recommended but still optional to access public models or datasets.\n",
203 | " warnings.warn(\n",
204 | "/usr/local/lib/python3.11/dist-packages/open_clip/factory.py:388: UserWarning: These pretrained weights were trained with QuickGELU activation but the model config does not have that enabled. Consider using a model config with a \"-quickgelu\" suffix or enable with a flag.\n",
205 | " warnings.warn(\n"
206 | ]
207 | }
208 | ],
209 | "execution_count": 3
210 | },
211 | {
212 | "cell_type": "code",
213 | "source": [
214 | "# prompt: count trainable parameters of model?\n",
215 | "\n",
216 | "def count_parameters(model):\n",
217 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
218 | " vision_params = sum(p.numel() for p in model.visual.transformer.parameters() if p.requires_grad)\n",
219 | " text_params = sum(p.numel() for p in model.transformer.parameters() if p.requires_grad)\n",
220 | " embed_params = sum(p.numel() for p in model.token_embedding.parameters() if p.requires_grad)\n",
221 | " total_params = sum(p.numel() for p in model.parameters())\n",
222 | " print(f\"Total parameters: {total_params:,}\")\n",
223 | " print(f\"Trainable parameters - Full model: {trainable_params:,}\")\n",
224 | " print(f\"Trainable parameters - Vision: {vision_params:,}\")\n",
225 | " print(f\"Trainable parameters - Text: {text_params:,}\")\n",
226 | " print(f\"Trainable parameters - embedding: {embed_params:,}\")"
227 | ],
228 | "metadata": {
229 | "id": "VQoMyCMio4Kg",
230 | "trusted": true,
231 | "execution": {
232 | "iopub.status.busy": "2025-06-09T01:21:06.763946Z",
233 | "iopub.execute_input": "2025-06-09T01:21:06.764449Z",
234 | "iopub.status.idle": "2025-06-09T01:21:06.769661Z",
235 | "shell.execute_reply.started": "2025-06-09T01:21:06.764428Z",
236 | "shell.execute_reply": "2025-06-09T01:21:06.768894Z"
237 | }
238 | },
239 | "outputs": [],
240 | "execution_count": 4
241 | },
242 | {
243 | "cell_type": "code",
244 | "source": [
245 | "print('Original model before adding lora')\n",
246 | "count_parameters(model)"
247 | ],
248 | "metadata": {
249 | "trusted": true,
250 | "colab": {
251 | "base_uri": "https://localhost:8080/"
252 | },
253 | "id": "eJXm5sjmx2HW",
254 | "outputId": "77426425-634d-468e-e379-21db4d9312ec"
255 | },
256 | "outputs": [
257 | {
258 | "output_type": "stream",
259 | "name": "stdout",
260 | "text": [
261 | "Original model before adding lora\n",
262 | "Total parameters: 151,277,313\n",
263 | "Trainable parameters - Full model: 151,277,313\n",
264 | "Trainable parameters - Vision: 85,054,464\n",
265 | "Trainable parameters - Text: 37,828,608\n",
266 | "Trainable parameters - embedding: 25,296,896\n"
267 | ]
268 | }
269 | ],
270 | "execution_count": 5
271 | },
272 | {
273 | "cell_type": "markdown",
274 | "source": [
275 | "### B. Load CIFAR-10"
276 | ],
277 | "metadata": {
278 | "id": "OMHX0z87i5_9"
279 | }
280 | },
281 | {
282 | "cell_type": "code",
283 | "source": [
284 | "# prompt: load cifar10 dataset\n",
285 | "\n",
286 | "from torchvision.datasets import CIFAR10\n",
287 | "\n",
288 | "train_dataset = CIFAR10(\n",
289 | " root=\"./data\", train=True, download=True,\n",
290 | " transform=preprocess\n",
291 | ")\n",
292 | "test_dataset = CIFAR10(\n",
293 | " root=\"./data\", train=False, download=True,\n",
294 | " transform=preprocess\n",
295 | ")\n",
296 | "\n",
297 | "batch_size_train = 256\n",
298 | "batch_size_test = 256\n",
299 | "train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=4)\n",
300 | "test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=4)\n"
301 | ],
302 | "metadata": {
303 | "id": "mA1W_9IIiA_P",
304 | "outputId": "5d786a0d-d462-4034-ad85-3ce88c030ca8",
305 | "trusted": true,
306 | "execution": {
307 | "iopub.status.busy": "2025-06-09T01:11:20.772338Z",
308 | "iopub.execute_input": "2025-06-09T01:11:20.772628Z",
309 | "iopub.status.idle": "2025-06-09T01:11:26.195815Z",
310 | "shell.execute_reply.started": "2025-06-09T01:11:20.772603Z",
311 | "shell.execute_reply": "2025-06-09T01:11:26.195198Z"
312 | },
313 | "colab": {
314 | "base_uri": "https://localhost:8080/"
315 | }
316 | },
317 | "outputs": [
318 | {
319 | "output_type": "stream",
320 | "name": "stderr",
321 | "text": [
322 | "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py:624: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
323 | " warnings.warn(\n"
324 | ]
325 | }
326 | ],
327 | "execution_count": 6
328 | },
329 | {
330 | "cell_type": "markdown",
331 | "source": [
332 | "### C. Fine-tune OpenCLIP with LoRA\n",
333 | "\n",
334 | "_Note:_\n",
335 | "\n",
336 | "Please make sure `loratorch.MultiheadAttention` is constructed with the same arguments as the [`nn.MultiheadAttention`](https://docs.pytorch.org/docs/2.6/generated/torch.nn.MultiheadAttention.html#multiheadattention) layer it replaces.\n",
337 | "\n",
338 | "For example, the default value of `batch_first` in `nn.MultiheadAttention` is `False`, but `open_clip` sets it to `True` in some `attn` layers. See the discussion [here](https://github.com/Baijiong-Lin/LoRA-Torch/issues/6#issuecomment-2954122864).\n",
339 | "\n",
340 | "The recommended way to build a `loratorch.MultiheadAttention` is to copy every constructor argument from the original `nn.MultiheadAttention`:\n",
341 | "```python\n",
342 | "lora_multihead = lora.MultiheadAttention(r=r,\n",
343 | " lora_alpha=lora_alpha,\n",
344 | " enable_lora=enable_lora,\n",
345 | " embed_dim=multihead.embed_dim,\n",
346 | " num_heads=multihead.num_heads,\n",
347 | " dropout=multihead.dropout,\n",
348 | " bias=True if hasattr(multihead, \"in_proj_bias\") else False,\n",
349 | " add_bias_kv=False if multihead.bias_k==None else True,\n",
350 | " add_zero_attn=multihead.add_zero_attn,\n",
351 | " kdim=multihead.kdim,\n",
352 | " vdim=multihead.vdim,\n",
353 | " batch_first=multihead.batch_first)\n",
354 | "```"
355 | ],
356 | "metadata": {
357 | "id": "Ahjx5_MPlbvU"
358 | }
359 | },
360 | {
361 | "cell_type": "markdown",
362 | "source": [
363 | "#### Apply LoRA to `attn` and `mlp`"
364 | ],
365 | "metadata": {
366 | "id": "Y1sLkX5tDZGz"
367 | }
368 | },
369 | {
370 | "cell_type": "code",
371 | "source": [
372 | "def apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True):\n",
373 | " if encoder_type == 'visual':\n",
374 | " encoder = model.visual.transformer\n",
375 | " elif encoder_type == 'text':\n",
376 | " encoder = model.transformer\n",
377 | " else:\n",
378 | " raise ValueError(\"Invalid encoder_type. Choose 'visual' or 'text'.\")\n",
379 | "\n",
380 | " enable_lora=['q', 'k', 'v', 'o']\n",
381 | " for i, resblock in enumerate(encoder.resblocks):\n",
382 | " if hasattr(resblock, 'attn') and attn:\n",
383 | " multihead = resblock.attn\n",
384 | " lora_multihead = lora.MultiheadAttention(r=rank,\n",
385 | " lora_alpha=lora_alpha,\n",
386 | " enable_lora=enable_lora,\n",
387 | " embed_dim=multihead.embed_dim,\n",
388 | " num_heads=multihead.num_heads,\n",
389 | " dropout=multihead.dropout,\n",
390 | " bias=True if hasattr(multihead, \"in_proj_bias\") else False,\n",
391 | " add_bias_kv=False if multihead.bias_k==None else True,\n",
392 | " add_zero_attn=multihead.add_zero_attn,\n",
393 | " kdim=multihead.kdim,\n",
394 | " vdim=multihead.vdim,\n",
395 | " batch_first=multihead.batch_first)\n",
396 | " lora_multihead.load_state_dict(multihead.state_dict(), strict=False)\n",
397 | " resblock.attn = lora_multihead\n",
398 | "\n",
399 | " if hasattr(resblock, 'mlp') and mlp:\n",
400 | " old_mlp_fc=resblock.mlp.c_fc\n",
401 | " old_mlp_proj=resblock.mlp.c_proj\n",
402 | " new_mlp_fc = lora.Linear(\n",
403 | " old_mlp_fc.in_features,\n",
404 | " old_mlp_fc.out_features,\n",
405 | " bias=True if hasattr(old_mlp_fc, \"bias\") else False,\n",
406 | " r=rank,\n",
407 | " lora_alpha=lora_alpha,\n",
408 | " )\n",
409 | " new_mlp_proj = lora.Linear(\n",
410 | " old_mlp_proj.in_features,\n",
411 | " old_mlp_proj.out_features,\n",
412 | " bias=True if hasattr(old_mlp_proj, \"bias\") else False,\n",
413 | " r=rank,\n",
414 | " lora_alpha=lora_alpha,\n",
415 | " )\n",
416 | " new_mlp_fc.load_state_dict(old_mlp_fc.state_dict(),strict=False)\n",
417 | " new_mlp_proj.load_state_dict(old_mlp_proj.state_dict(),strict=False)\n",
418 | " resblock.mlp.c_fc = new_mlp_fc\n",
419 | " resblock.mlp.c_proj = new_mlp_proj\n",
420 | "\n",
421 | " lora.mark_only_lora_as_trainable(model)\n",
422 | " return model"
423 | ],
424 | "metadata": {
425 | "trusted": true,
426 | "execution": {
427 | "iopub.status.busy": "2025-06-09T01:20:16.374517Z",
428 | "iopub.execute_input": "2025-06-09T01:20:16.375239Z",
429 | "iopub.status.idle": "2025-06-09T01:20:16.382721Z",
430 | "shell.execute_reply.started": "2025-06-09T01:20:16.375208Z",
431 | "shell.execute_reply": "2025-06-09T01:20:16.382180Z"
432 | },
433 | "id": "UQ8c81GOx2HZ"
434 | },
435 | "outputs": [],
436 | "execution_count": 7
437 | },
438 | {
439 | "cell_type": "code",
440 | "source": [
441 | "apply_lora_attn_mlp(model, encoder_type='visual', rank=16, lora_alpha=32, mlp=True, attn=True)\n",
442 | "tokenizer = open_clip.get_tokenizer('ViT-B-32')"
443 | ],
444 | "metadata": {
445 | "trusted": true,
446 | "execution": {
447 | "iopub.status.busy": "2025-06-09T01:20:17.384221Z",
448 | "iopub.execute_input": "2025-06-09T01:20:17.384480Z",
449 | "iopub.status.idle": "2025-06-09T01:20:18.691620Z",
450 | "shell.execute_reply.started": "2025-06-09T01:20:17.384462Z",
451 | "shell.execute_reply": "2025-06-09T01:20:18.690864Z"
452 | },
453 | "id": "WThYJ3zzx2HZ"
454 | },
455 | "outputs": [],
456 | "execution_count": 8
457 | },
458 | {
459 | "cell_type": "code",
460 | "source": [
461 | "for name, param in model.visual.transformer.resblocks[0].named_parameters():\n",
462 | " print(name, param.requires_grad)\n",
463 | "\n",
464 | "# after adding lora\n",
465 | "print(\"\\nAfter adding LoRA to Attn+MLP:\")\n",
466 | "count_parameters(model)"
467 | ],
468 | "metadata": {
469 | "trusted": true,
470 | "execution": {
471 | "iopub.status.busy": "2025-06-09T01:21:17.260662Z",
472 | "iopub.execute_input": "2025-06-09T01:21:17.260951Z",
473 | "iopub.status.idle": "2025-06-09T01:21:17.267985Z",
474 | "shell.execute_reply.started": "2025-06-09T01:21:17.260934Z",
475 | "shell.execute_reply": "2025-06-09T01:21:17.267184Z"
476 | },
477 | "colab": {
478 | "base_uri": "https://localhost:8080/"
479 | },
480 | "id": "DT_NoSaTx2HZ",
481 | "outputId": "2427de82-3257-4b98-fe21-64c9330ee75c"
482 | },
483 | "outputs": [
484 | {
485 | "output_type": "stream",
486 | "name": "stdout",
487 | "text": [
488 | "ln_1.weight False\n",
489 | "ln_1.bias False\n",
490 | "attn.in_proj_weight False\n",
491 | "attn.in_proj_bias False\n",
492 | "attn.o_lora_A True\n",
493 | "attn.o_lora_B True\n",
494 | "attn.qkv_lora_A True\n",
495 | "attn.qkv_lora_B True\n",
496 | "attn.out_proj.weight False\n",
497 | "attn.out_proj.bias False\n",
498 | "ln_2.weight False\n",
499 | "ln_2.bias False\n",
500 | "mlp.c_fc.weight False\n",
501 | "mlp.c_fc.bias False\n",
502 | "mlp.c_fc.w_lora_A True\n",
503 | "mlp.c_fc.w_lora_B True\n",
504 | "mlp.c_proj.weight False\n",
505 | "mlp.c_proj.bias False\n",
506 | "mlp.c_proj.w_lora_A True\n",
507 | "mlp.c_proj.w_lora_B True\n",
508 | "\n",
509 | "After adding LoRA to Attn+MLP:\n",
510 | "Total parameters: 153,636,609\n",
511 | "Trainable parameters - Full model: 2,359,296\n",
512 | "Trainable parameters - Vision: 2,359,296\n",
513 | "Trainable parameters - Text: 0\n",
514 | "Trainable parameters - embedding: 0\n"
515 | ]
516 | }
517 | ],
518 | "execution_count": 9
519 | },
520 | {
521 | "cell_type": "code",
522 | "source": [
523 | "# Tokenizer and text embeddings\n",
524 | "model.cuda()\n",
525 | "\n",
526 | "tokenizer = open_clip.get_tokenizer(\"ViT-B-32\")\n",
527 | "classnames = train_dataset.classes\n",
528 | "text_inputs = tokenizer([f\"a photo of a {label}\" for label in classnames]).cuda()\n",
529 | "with torch.no_grad():\n",
530 | " text_features = model.encode_text(text_inputs)\n",
531 | " text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n",
532 | "\n",
533 | "# Optimizer\n",
534 | "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)"
535 | ],
536 | "metadata": {
537 | "trusted": true,
538 | "execution": {
539 | "iopub.status.busy": "2025-06-09T01:21:34.494110Z",
540 | "iopub.execute_input": "2025-06-09T01:21:34.494363Z",
541 | "iopub.status.idle": "2025-06-09T01:21:35.033731Z",
542 | "shell.execute_reply.started": "2025-06-09T01:21:34.494344Z",
543 | "shell.execute_reply": "2025-06-09T01:21:35.032959Z"
544 | },
545 | "id": "RdVjI-pvx2HZ"
546 | },
547 | "outputs": [],
548 | "execution_count": 10
549 | },
550 | {
551 | "cell_type": "code",
552 | "source": [
553 | "# Train loop\n",
554 | "model.train()\n",
555 | "for epoch in range(3):\n",
556 | " total_loss = 0\n",
557 | " correct = 0\n",
558 | " total = 0\n",
559 | " for images, labels in tqdm(train_loader):\n",
560 | " images, labels = images.cuda(), labels.cuda()\n",
561 | " image_features = model.encode_image(images)\n",
562 | " image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n",
563 | " logits = image_features @ text_features.t()\n",
564 | " loss = nn.CrossEntropyLoss()(logits, labels)\n",
565 | "\n",
566 | " optimizer.zero_grad()\n",
567 | " loss.backward()\n",
568 | " optimizer.step()\n",
569 | " total_loss += loss.item()\n",
570 | "\n",
571 | " preds = logits.argmax(dim=1)\n",
572 | " correct += (preds == labels).sum().item()\n",
573 | " total += labels.size(0)\n",
574 | "        # (!!!) Re-register the model parameters so they appear in model.state_dict() and model.parameters().\n",
575 | "        # (!!!) Without this line, performance is unaffected, but some weights will be missing from model.state_dict() and model.parameters().\n",
576 | " lora.register_model_param_after_backward(model)\n",
577 | "\n",
578 | " acc = correct / total\n",
579 | " print(f\"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.4f}\")"
580 | ],
581 | "metadata": {
582 | "outputId": "153083d4-3f77-4b1d-bc7e-a9798cfb1139",
583 | "trusted": true,
584 | "execution": {
585 | "iopub.status.busy": "2025-06-09T01:21:35.369183Z",
586 | "iopub.execute_input": "2025-06-09T01:21:35.369507Z",
587 | "iopub.status.idle": "2025-06-09T01:54:15.576408Z",
588 | "shell.execute_reply.started": "2025-06-09T01:21:35.369485Z",
589 | "shell.execute_reply": "2025-06-09T01:54:15.575561Z"
590 | },
591 | "id": "UCgTKj1qx2HZ",
592 | "colab": {
593 | "base_uri": "https://localhost:8080/"
594 | }
595 | },
596 | "outputs": [
597 | {
598 | "output_type": "stream",
599 | "name": "stderr",
600 | "text": [
601 | "100%|██████████| 196/196 [06:43<00:00, 2.06s/it]\n"
602 | ]
603 | },
604 | {
605 | "output_type": "stream",
606 | "name": "stdout",
607 | "text": [
608 | "Epoch 1: Loss=397.7533, Accuracy=0.9529\n"
609 | ]
610 | },
611 | {
612 | "output_type": "stream",
613 | "name": "stderr",
614 | "text": [
615 | "100%|██████████| 196/196 [06:40<00:00, 2.04s/it]\n"
616 | ]
617 | },
618 | {
619 | "output_type": "stream",
620 | "name": "stdout",
621 | "text": [
622 | "Epoch 2: Loss=387.0708, Accuracy=0.9796\n"
623 | ]
624 | },
625 | {
626 | "output_type": "stream",
627 | "name": "stderr",
628 | "text": [
629 | "100%|██████████| 196/196 [06:39<00:00, 2.04s/it]"
630 | ]
631 | },
632 | {
633 | "output_type": "stream",
634 | "name": "stdout",
635 | "text": [
636 | "Epoch 3: Loss=386.2361, Accuracy=0.9901\n"
637 | ]
638 | },
639 | {
640 | "output_type": "stream",
641 | "name": "stderr",
642 | "text": [
643 | "\n"
644 | ]
645 | }
646 | ],
647 | "execution_count": 11
648 | },
649 | {
650 | "cell_type": "code",
651 | "source": [],
652 | "metadata": {
653 | "trusted": true,
654 | "id": "pMS8k5sXx2Ha"
655 | },
656 | "outputs": [],
657 | "execution_count": 11
658 | }
659 | ]
660 | }
--------------------------------------------------------------------------------
|