├── .gitignore ├── README.md ├── examples └── test.py ├── lte ├── __init__.py ├── base.py ├── config.py ├── ddp │ ├── ddp_lte.py │ └── linear.py ├── dmp │ ├── dmp_lte.py │ └── linear.py ├── mhlora │ └── linear.py ├── misc │ ├── __init__.py │ ├── attention.py │ ├── common.py │ ├── distributed.py │ ├── merge.py │ └── position.py ├── prepare.py └── replica.py ├── setup.py └── unittest └── run_tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | wandb/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 | 
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 | 
122 | # SageMath parsed files
123 | *.sage.py
124 | 
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 | 
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 | 
138 | # Rope project settings
139 | .ropeproject
140 | 
141 | # mkdocs documentation
142 | /site
143 | 
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 | 
149 | # Pyre type checker
150 | .pyre/
151 | 
152 | # pytype static type analyzer
153 | .pytype/
154 | 
155 | # Cython debug symbols
156 | cython_debug/
157 | 
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LTE: LoRA-the-Explorer
2 | 
3 | LoRA-the-Explorer (LTE) is a framework for fine-tuning and pre-training models without directly optimizing the main weights.
4 | This is a minimal re-implementation of the codebase with tools for small- to mid-scale research development.
5 | 
6 | 
7 | ## Installation
8 | Tested on Ubuntu with Python 3.11 and PyTorch 2.1/2.2. Older torch versions may not support some operators used in the codebase.
9 | 
10 | ```bash
11 | git clone https://github.com/minyoungg/LTE
12 | cd LTE
13 | pip install -e .
14 | ```
15 | 
16 | ## Example usage
17 | By default, this codebase uses the reset-less version (see the Appendix of the paper).
18 | 
19 | ```python
20 | import lte
21 | 
22 | # your neural network
23 | model = MyModel()
24 | 
25 | # converts into an LTE model
26 | model = lte.prepare_model_for_lte(
27 |     model.cuda(),
28 |     lte.LTEConfig.default(
29 |         lora_r=32,
30 |         lora_alpha=4096,
31 |         num_heads=32,
32 |     ),
33 | )
34 | ```
35 | 
36 | Given a mini-batch, LTE will automatically chunk it along the batch dimension and parallelize the chunks across the LoRA heads.
37 | 
38 | ```python
39 | x = get_data()
40 | assert x.size(0) % 32 == 0, 'make sure batch-size is divisible by num_heads'
41 | model(x)
42 | ```
43 | 
44 | To merge the model, you can use a merge scheduler `lte.misc.merge.MergeCondition`, or you can implement your own. For example:
45 | 
46 | ```python
47 | for n, m in model.named_modules():
48 |     if isinstance(m, (lte.LTELayer, lte.ReplicaLayer)):
49 |         m.merge_parameters()
50 | ```
51 | 
52 | If you have layers that are not supported, you can pass them as replica layers, which will replicate those layers across all devices. These parameters are averaged when merged. Unfortunately, replica layers will likely require a separate learning rate from the LoRA parameters.
53 | 
54 | ```python
55 | model = lte.prepare_model_for_lte(
56 |     model.cuda(),
57 |     lte.LTEConfig.default(
58 |         lora_r=32,
59 |         lora_alpha=4096,
60 |         num_heads=32,
61 |     ),
62 |     replica_layers=[model.ignore_this_layer]
63 | )
64 | ```
65 | 
66 | We also include a few helper functions that may be useful:
67 | 
68 | ```python
69 | # replace Conv2D projection layers in ViT with their linear counterparts
70 | lte.misc.replace_conv_proj_with_linear(model)
71 | 
72 | # disables affine parameters in LayerNorm
73 | # from my experience, disabling them results in better performance for both LTE and standard training
74 | lte.misc.disable_norm_affine_parameters(model)
75 | ```
76 | 
77 | The current codebase was mainly used for artificially emulating distributed training pipelines.
78 | We currently provide `DistributedDataParallel` (DDP) and `DistributedModelParallel` (DMP). (A better name could have been chosen).
79 | 
80 | Here is a quick TLDR of how they are implemented.
81 | Assume we have `H` virtual devices on `N` GPU devices. In DDP mode, we create `1` main weight and `H` LoRA parameters.
82 | The LoRA heads share the forward pass of the main weight across all `H` virtual devices, since repeating it would be redundant.
83 | Using `torchrun` will chunk the data across devices and also across virtual devices.
84 | This keeps memory and compute costs low for development purposes.
85 | 
86 | DMP mode is more faithful to how LTE would be implemented in practice. DMP creates `H` copies of the main weights and `H` LoRA parameters distributed across `N` devices.
87 | Most PyTorch CUDA operations should be non-blocking, but DMP will still run much slower than DDP, as the virtual devices do not share any computation.
88 | 
89 | You can switch between these modes via a flag.
90 | ```python
91 | model = lte.prepare_model_for_lte(
92 |     ...
93 |     mode='ddp' # or 'dmp' ('ddp' by default)
94 | )
95 | ```
96 | 
97 | DMP will automatically distribute across all visible CUDA devices without using `torchrun`, so make sure you set the device visibility correctly.
98 | ```bash
99 | # will automatically distribute across 4 devices
100 | CUDA_VISIBLE_DEVICES=1,2,3,4 python lte_dmp_train_script.sh
101 | ```
102 | 
103 | DDP should be used with `torchrun`.
104 | 
105 | ### Helpful guidelines
106 | - First, test whether the mhlora parameterization of the model converges to the same test loss. We added `mode="mhlora"` to help you with this.
107 | - Note that different alpha values might result in the same training loss but vastly different test loss. Alpha values of (1024, 2048, 4096, 8192) are a good range to search over.
108 | - If mhlora matches the pre-training performance, LTE with `merge_iter=1` can recover the same performance.
109 | - LTE will require more training iterations to converge since the mini-batch is sharded across the heads. Using a larger batch size may help.
110 | - Next, increase `merge_iter` to get the asynchronous benefits.
111 | 
112 | 
113 | ### MORE CODE COMING SOON
114 | - [ ] 4-bit quantization support
115 | - [ ] LayerNorm and Conv2d support
116 | - [ ] Full training example
117 | 
118 | Note: we do not support standalone parameters, so wrap them in a module to have them replicated.
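For example, a standalone `nn.Parameter` (such as a learnable token or bias) can be wrapped in a small module and passed through `replica_layers`. The sketch below is illustrative and untested; `TokenBias` and `model.token_bias` are hypothetical names, not part of the library:

```python
import torch
import torch.nn as nn
import lte

class TokenBias(nn.Module):
    """Wraps a standalone parameter so LTE can replicate it and average it on merge."""
    def __init__(self, dim):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return x + self.bias

# register the wrapped parameter on your model and use it in the model's forward pass
model.token_bias = TokenBias(dim=384)

model = lte.prepare_model_for_lte(
    model.cuda(),
    lte.LTEConfig.default(lora_r=32, lora_alpha=4096, num_heads=32),
    replica_layers=[model.token_bias],
)
```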
119 | 120 | ### Citation 121 | If you found the library useful, please consider citing 122 | ```bibtex 123 | @article{huh2024lte, 124 | title={Training Neural Networks from Scratch with Parallel Low-Rank Adapters}, 125 | author={Huh, Minyoung and Cheung, Brian and Bernstein, Jeremy and Isola, Phillip and Agrawal, Pulkit}, 126 | journal={arXiv preprint arXiv:2402.16828}, 127 | year={2024} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import lte 3 | import torchvision.models as tvm 4 | 5 | 6 | vit_s_32_config = { 7 | "patch_size": 32, 8 | "num_layers": 12, 9 | "num_heads": 6, 10 | "hidden_dim": 384, 11 | "mlp_dim": 1536, 12 | "dropout": 0.0, 13 | "attention_dropout": 0.0, 14 | } 15 | 16 | model = tvm.VisionTransformer( 17 | image_size=224, 18 | num_classes=1000, 19 | **vit_s_32_config, 20 | ) 21 | 22 | # NOTE: future revision will support Conv2d and LayerNorm for LoRA 23 | only_linear = True 24 | mode = "ddp" 25 | 26 | # Using custom-attention because of NonDynamicallyQuantizableLinear is not LoRA compatible 27 | lte.misc.use_custom_attention(model) 28 | 29 | # Parameters are ignored in LTE so for exact behavior use fixed position embedding 30 | # or modifiy the existing codebase to ensure gradients does not flow-between each other. 31 | # We added sinusoidal embedding in lte.misc.position for vision models. 32 | model.encoder.pos_embedding.requires_grad = False 33 | model.class_token.requires_grad = False 34 | 35 | if only_linear: 36 | ### OPTION:1 37 | # Converting Conv2d to Linear for simplicity (although LoRA supports Conv2d as well) 38 | lte.misc.replace_conv_proj_with_linear(model) 39 | 40 | # Disabling layer normalization affine parameters (it usually performs worse with affine parameters) 41 | lte.misc.disable_norm_affine_parameters(model) 42 | 43 | model = lte.prepare_model_for_lte( 44 | model.cuda(), 45 | lte.LTEConfig.default( 46 | lora_r=32, 47 | lora_alpha=4096, 48 | num_heads=32, 49 | ), 50 | mode=mode, 51 | strict=True, 52 | ) 53 | print(model) 54 | 55 | else: 56 | ### OPTION:2 57 | replica_layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.LayerNorm))] 58 | replica_layers.append(model.heads) 59 | 60 | model = lte.prepare_model_for_lte( 61 | model.cuda(), 62 | lte.LTEConfig.default( 63 | lora_r=32, 64 | lora_alpha=4096, 65 | num_heads=32, 66 | ), 67 | replica_layers=replica_layers, 68 | mode=mode, 69 | strict=True, 70 | ) 71 | print(model) 72 | -------------------------------------------------------------------------------- /lte/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import LTEConfig 2 | from lte import misc 3 | from lte.base import LTELayer, ReplicaLayer 4 | from lte.replica import MultiheadReplicaLayer 5 | from lte.prepare import prepare_model_for_lte -------------------------------------------------------------------------------- /lte/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class LTELayer(nn.Module): 4 | """ 5 | The base class for LTE layers. Used to universally identify LTE layers. 
6 | """ 7 | def __init__(self): 8 | super().__init__() 9 | self.num_heads = None 10 | self._repr_A = None 11 | self._repr_B = None 12 | return 13 | 14 | def __repr__(self): 15 | repr_str = \ 16 | f'MultiheadLoraLayer( {self.num_heads} x ' + \ 17 | '{\n' + \ 18 | ' ' * 4 + 'lora_A_weight: ' + self._repr_A + '\n' + \ 19 | ' ' * 4 + 'lora_B_weight: ' + self._repr_B + '\n' + \ 20 | '})' 21 | return repr_str 22 | 23 | 24 | class ReplicaLayer(nn.Module): 25 | """ 26 | The base class for Replica layers. Used to universally identify Replica layers. 27 | """ 28 | def __init__(self): 29 | super().__init__() 30 | self.num_heads = None 31 | self._repr = None 32 | return 33 | 34 | def __repr__(self): 35 | self._repr = self.replicas[0].__repr__() 36 | return f"Replica( {self.num_heads} x {self._repr} )" 37 | -------------------------------------------------------------------------------- /lte/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | class LTEConfig(): 5 | """ 6 | LTE configuration. 7 | 8 | Example:: 9 | lte_config = LTEConfig.default(lora_r=16) 10 | 11 | model = lte.deprecated.prepare_model_for_lte( 12 | model, 13 | lte_config, 14 | ) 15 | """ 16 | 17 | @staticmethod 18 | def default(**kwargs): 19 | cfg = CN() 20 | 21 | cfg.lora = CN() 22 | cfg.lora.lora_r = 8 23 | cfg.lora.lora_alpha = 16 24 | cfg.lora.lora_bias = False 25 | cfg.lora.num_heads = 1 26 | 27 | # If you want to eventually add custom layers to the model, you can add them here 28 | # Anything below will only be applied to Linear layers. If you want to use different 29 | # LTE parameterization for differen layer, you can customize it here. 30 | cfg.lora.linear = CN() 31 | # cfg.lora.linear.lora_r = 32 32 | 33 | # Override any default values with the kwargs 34 | LTEConfig.override_kwargs(cfg, kwargs) 35 | return cfg 36 | 37 | 38 | @staticmethod 39 | def override_kwargs(cfg, kwargs): 40 | for k, v in kwargs.items(): 41 | if k not in cfg.lora.keys(): 42 | raise ValueError(f"Invalid lora config {k}") 43 | cfg.lora[k] = v 44 | return cfg 45 | -------------------------------------------------------------------------------- /lte/ddp/ddp_lte.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lte import LTELayer 4 | 5 | 6 | class DistributedDataParallelLTE(LTELayer): 7 | """ 8 | Virtual DDP wrapper for LTE. 9 | The main weight is shared across all LoRA devices. However, this will require 10 | two forward passed on the main weight in the reset-less version. 11 | This will reduce overall computation cost and memory usage. 12 | Each GPU device will compute the forward pass on all the LoRA parameters. 13 | This layer is meant to be used with torchrun. 14 | 15 | Use DMP for more faithful implementation. 16 | A more efficient implementation would be to mix both DDP and DMP. 
17 | 18 | Args: 19 | num_heads (int): number of LoRA heads 20 | lora_bias (bool): whether to use bias for LoRA 21 | lora_alpha (int): the LoRA scaling factor 22 | lora_r (int): the rank of LoRA 23 | """ 24 | def __init__( 25 | self, 26 | num_heads, 27 | lora_bias=False, 28 | lora_alpha=4096, 29 | lora_r=32, 30 | ): 31 | 32 | self.num_heads = num_heads 33 | self.lora_alpha = lora_alpha 34 | self.lora_r = lora_r 35 | self.lora_bias = lora_bias 36 | 37 | self._orig_param_names = [n for n, _ in self.named_parameters()] 38 | [p.requires_grad_(False) for p in self.parameters()] 39 | 40 | self.lora_A = nn.ModuleList() 41 | self.lora_B = nn.ModuleList() 42 | 43 | self.register_buffer("lora_initialized", torch.zeros(1)) 44 | self.register_buffer('merged', torch.zeros(1)) 45 | return 46 | 47 | def create_parallelized_params(self): 48 | """ creates vmap forward pass to evaluate all lora paths at the same time """ 49 | 50 | for param in self.parameters(): 51 | param.requires_grad_(False) 52 | 53 | A_params, _ = torch.func.stack_module_state(self.lora_A) 54 | self.lora_A_weight = nn.Parameter(A_params['weight']) 55 | 56 | B_params, _ = torch.func.stack_module_state(self.lora_B) 57 | self.lora_B_weight = nn.Parameter(B_params['weight']) 58 | 59 | if self.lora_bias: 60 | self.lora_A_bias = nn.Parameter(A_params["bias"]) 61 | self.lora_B_bias = nn.Parameter(B_params['bias']) 62 | else: 63 | self.lora_A_bias, self.lora_B_bias = None, None 64 | 65 | # skeleton template for forward pass 66 | self.lora_A = self.lora_A[0] 67 | self.lora_B = self.lora_B[0] 68 | 69 | # delete old parameters 70 | for m in [self.lora_A, self.lora_B]: 71 | for p in m._parameters.values(): 72 | del p 73 | 74 | self.register_buffer("prev_lora_A_weight", torch.zeros_like(self.lora_A_weight)) 75 | self.register_buffer("prev_lora_B_weight", torch.zeros_like(self.lora_B_weight)) 76 | if self.lora_bias: 77 | self.register_buffer("prev_lora_A_bias", torch.zeros_like(self.lora_A_bias)) 78 | self.register_buffer("prev_lora_B_bias", torch.zeros_like(self.lora_B_bias)) 79 | else: 80 | self.prev_lora_A_bias, self.prev_lora_B_bias = None, None 81 | return 82 | 83 | def convert_to_lte_module(self): 84 | """ converts module into LTE module """ 85 | self.scaling = self.lora_alpha / self.lora_r 86 | 87 | # store representation 88 | self._repr_A = list(self.lora_A)[0].__repr__() 89 | self._repr_B = list(self.lora_B)[0].__repr__() 90 | 91 | self.create_parallelized_params() 92 | self.reset_lora_parameters() 93 | self.lora_initialized.data[0] = 1 94 | return 95 | 96 | def get_lora_params(self): 97 | """ retrieves lora paramters """ 98 | A = self.lora_A_weight 99 | B = self.lora_B_weight 100 | 101 | b_A, b_B = None, None 102 | if self.lora_bias: 103 | b_A = self.lora_A_bias 104 | b_B = self.lora_B_bias 105 | return (A, B), (b_A, b_B) 106 | 107 | 108 | @torch.no_grad() 109 | def merge_parameters(self): 110 | """ merges all lora parameters into the main module """ 111 | 112 | def average_merging(delta_weights, delta_biases=None): 113 | if delta_biases is None: 114 | return delta_weights.mean(0), delta_biases 115 | return delta_weights.mean(0), delta_biases.mean(0) 116 | 117 | lora_delta_weights, lora_delta_biases = self.compute_delta() 118 | 119 | if self.merged: 120 | # subtracting previous delta to compute correct update 121 | prev_delta_weights, prev_delta_biases = \ 122 | self.lora_to_delta( 123 | self.prev_lora_A_weight, 124 | self.prev_lora_B_weight, 125 | self.prev_lora_A_bias, 126 | self.prev_lora_B_bias, 127 | self.scaling 128 | ) 129 | 130 | 
lora_delta_weights -= prev_delta_weights 131 | if self.lora_bias: 132 | lora_delta_biases -= prev_delta_biases 133 | 134 | self.prev_lora_A_weight.data = self.lora_A_weight.data.detach().clone() 135 | self.prev_lora_B_weight.data = self.lora_B_weight.data.detach().clone() 136 | 137 | if self.lora_bias: 138 | self.prev_lora_A_bias.data = self.lora_A_bias.data.detach().clone() 139 | self.prev_lora_B_bias.data = self.lora_B_bias.data.detach().clone() 140 | 141 | delta_weight, delta_bias = average_merging(lora_delta_weights, lora_delta_biases) 142 | 143 | self.weight.data += delta_weight.data.clone().detach().to(self.weight.dtype) 144 | if self.lora_bias: 145 | self.bias.data += delta_bias.data.clone().detach().to(self.bias.dtype) 146 | 147 | self.merged.data[0] = 1 148 | return 149 | 150 | def parallel_lora_forward(self, inputs): 151 | """ 152 | Chunks the inputs and applies the parallel forward pass across all LoRA layers 153 | For example given a batch x = [x1, x2] with 2 LoRA heads lora1 and lora2 154 | the output is y = [lora1(x1), lora2(x2)]. If you want same data to be processed 155 | across all lora heads, replicate the mini-batch by the number of heads in the 156 | main optimization loop. 157 | 158 | Args: 159 | inputs (torch.Tensor): the input tensor of shape 160 | Returns: 161 | torch.Tensor: the output tensor of shape 162 | """ 163 | input_shape = inputs.shape 164 | 165 | # reshapes tensor into [num_heads x batch_size x ... x features] 166 | inputs = inputs.unflatten(0, (self.num_heads, -1)) 167 | if isinstance(self, nn.Linear): 168 | # convert to 3D tensor [num_heads x batch_size x features] 169 | inputs = inputs.reshape(self.num_heads, -1, input_shape[-1]) 170 | in_dims = 0 171 | 172 | # higher chunk will be slower but can save memory 173 | chunk_size = 1 # math.ceil(self.num_heads // 16) 174 | 175 | x = parallel_lora_forward( 176 | inputs, 177 | self.lora_A, 178 | self.lora_B, 179 | self.lora_A_weight, 180 | self.lora_B_weight, 181 | self.lora_A_bias, 182 | self.lora_B_bias, 183 | use_baddbmm_linear=isinstance(self, nn.Linear), 184 | in_dims=in_dims, 185 | chunk_size=chunk_size, 186 | ) 187 | 188 | # subtract contribution of itself from previous synchronization 189 | x -= parallel_lora_forward( 190 | inputs, 191 | self.lora_A, 192 | self.lora_B, 193 | self.prev_lora_A_weight, 194 | self.prev_lora_B_weight, 195 | self.prev_lora_A_bias, 196 | self.prev_lora_B_bias, 197 | use_baddbmm_linear=isinstance(self, nn.Linear), 198 | in_dims=in_dims, 199 | chunk_size=chunk_size, 200 | ) 201 | 202 | # scaling parameter as a tensor 203 | x *= self.scaling 204 | 205 | if isinstance(self, nn.Linear): 206 | x = x.reshape(*input_shape[:-1], x.shape[-1]) 207 | else: 208 | x = x.flatten(0, 1) 209 | return x 210 | 211 | 212 | def baddbmm_linear(x, lora_A_weight, lora_A_bias, lora_B_weight, lora_B_bias): 213 | """ 214 | Batched matmul using BLAS and LAPACK operations. 215 | Faster than vmapping using. 
216 | 217 | Args: 218 | x (torch.Tensor): input tensor of shape [num_heads x batch_size x features] 219 | loar_A_weight (torch.Tensor): first LoRA parameters of shape [num_heads x in_features x r] 220 | loar_A_bias (torch.Tensor): first LoRA bias of shape [num_heads x r] 221 | loar_B_weight (torch.Tensor): second LoRA parameters of shape [num_heads x r x out_features] 222 | loar_B_bias (torch.Tensor): second LoRA bias of shape [num_heads x out_features] 223 | Returns: 224 | torch.Tensor: output tensor of shape [num_heads x batch_size x out_features] 225 | 226 | NOTE: always assumes sequence dimension is flattened with the batch dimension. 227 | """ 228 | assert len(x.shape) == 3, f'Expected 3D tensor got {x.shape}' 229 | 230 | if lora_A_bias is not None: 231 | x = torch.baddbmm(lora_A_bias.unsqueeze(1), x, lora_A_weight.permute(0, 2, 1)) 232 | x = torch.baddbmm(lora_B_bias.unsqueeze(1), x, lora_B_weight.permute(0, 2, 1)) 233 | else: 234 | x = torch.bmm(x, lora_A_weight.permute(0, 2, 1)) 235 | x = torch.bmm(x, lora_B_weight.permute(0, 2, 1)) 236 | return x 237 | 238 | 239 | def mhlora_baddbmm_linear(x, lora_A_weight, lora_A_bias, lora_B_weight, lora_B_bias): 240 | """ 241 | Special case of baddbmm_linear to reduce memory usage. 242 | The same input is used for all heads. 243 | Good for MHLoRA or evaluation while still enabling LTE. 244 | 245 | Args: 246 | x (torch.Tensor): input tensor of shape [batch_size x features] 247 | loar_A_weight (torch.Tensor): first LoRA parameters of shape [num_heads x in_features x r] 248 | loar_A_bias (torch.Tensor): first LoRA bias of shape [num_heads x r] 249 | loar_B_weight (torch.Tensor): second LoRA parameters of shape [num_heads x r x out_features] 250 | loar_B_bias (torch.Tensor): second LoRA bias of shape [num_heads x out_features] 251 | Returns: 252 | torch.Tensor: output tensor of shape [num_heads x batch_size x out_features] 253 | 254 | NOTE: for cleanliness of the code, we disabled force paralleization option and 255 | this function not used in the current codebase, however we leave it for those who might want to use it. 256 | """ 257 | assert len(x.shape) == 2, f'Expected 2D tensor got {x.shape}' 258 | 259 | (b, f_in), (h, f_out) = x.shape, lora_B_weight.shape[:2] 260 | 261 | # more memory efficient than tiling it first and using baddbmm 262 | x = (x @ lora_A_weight.view(-1, f_in).T).unflatten(-1, (h, -1)) 263 | if lora_A_bias is not None: 264 | x += lora_A_bias.unsqueeze(0) 265 | 266 | x = x.swapaxes(0, 1) 267 | if lora_B_bias is not None: 268 | x = torch.baddbmm(lora_B_bias.unsqueeze(1), x, lora_B_weight.transpose(1, 2)) 269 | else: 270 | x = torch.bmm(x, lora_B_weight.transpose(1, 2)) 271 | return x.view(h, b, f_out) 272 | 273 | 274 | def parallel_lora_forward( 275 | x, 276 | lora_A_fn, 277 | lora_B_fn, 278 | lora_A_weight, 279 | lora_B_weight, 280 | lora_A_bias=None, 281 | lora_B_bias=None, 282 | use_baddbmm_linear=False, 283 | in_dims=(0, 0, None), 284 | chunk_size=1 285 | ): 286 | """ 287 | Applies the forward pass of the parallel Lora network. 288 | 289 | Args: 290 | x (torch.Tensor): Input tensor. 291 | lora_A_fn (callable): Function that applies the forward pass of the first Lora network. 292 | lora_B_fn (callable): Function that applies the forward pass of the second Lora network. 293 | lora_A_weight (torch.Tensor): Weight tensor for the first Lora network. 294 | lora_B_weight (torch.Tensor): Weight tensor for the second Lora network. 295 | lora_A_bias (torch.Tensor, optional): Bias tensor for the first Lora network. Defaults to None. 
296 | lora_B_bias (torch.Tensor, optional): Bias tensor for the second Lora network. Defaults to None. 297 | use_baddbmm_linear (bool, optional): If True, uses baddbmm for faster computation. Defaults to False. 298 | in_dims (tuple, optional): Tuple of input dimensions for torch.vmap. Defaults to (0, 0, None). 299 | chunk_size (int, optional): Chunk size for torch.vmap. Defaults to 1. 300 | 301 | Returns: 302 | torch.Tensor: Output tensor after applying the forward pass of the parallel Lora network. 303 | """ 304 | if use_baddbmm_linear: 305 | if len(x.shape) == 2: 306 | baddbmm_function = mhlora_baddbmm_linear 307 | elif len(x.shape) == 3: 308 | baddbmm_function = baddbmm_linear 309 | else: 310 | raise RuntimeError(f'Unsupported input shape {x.shape}') 311 | 312 | x = baddbmm_function( 313 | x, 314 | lora_A_weight, 315 | lora_A_bias, 316 | lora_B_weight, 317 | lora_B_bias, 318 | ) 319 | else: 320 | 321 | def parallelize_A(params, buffers, data): 322 | return torch.func.functional_call(lora_A_fn, (params, buffers), (data,)) 323 | 324 | def parallelize_B(params, buffers, data): 325 | return torch.func.functional_call(lora_B_fn, (params, buffers), (data,)) 326 | 327 | vmap_A = torch.vmap(parallelize_A, in_dims=in_dims, chunk_size=chunk_size) 328 | vmap_B = torch.vmap(parallelize_B, chunk_size=chunk_size) 329 | 330 | if lora_A_bias is not None: 331 | x = vmap_A({'weight': lora_A_weight, 'bias': lora_A_bias}, {}, x) 332 | else: 333 | x = vmap_A({'weight': lora_A_weight}, {}, x) 334 | 335 | if lora_B_bias is not None: 336 | x = vmap_B({'weight': lora_B_weight, 'bias': lora_B_bias}, {}, x) 337 | else: 338 | x = vmap_B({'weight': lora_B_weight}, {}, x) 339 | return x 340 | -------------------------------------------------------------------------------- /lte/ddp/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from lte.ddp.ddp_lte import DistributedDataParallelLTE 5 | 6 | 7 | class MultiheadLoRALinear( 8 | DistributedDataParallelLTE, 9 | nn.Linear, 10 | ): 11 | """ 12 | Multihead Linear layer with LoRA [distributed-data parallel version] 13 | 14 | Args: 15 | in_features (int): the number of input features 16 | out_features (int): the number of output features 17 | bias (bool): whether to use bias 18 | num_heads (int): the number of heads 19 | lora_r (int): the rank of LoRA 20 | lora_alpha (int): the alpha value for LoRA 21 | lora_bias (bool): whether to use bias for LoRA 22 | """ 23 | def __init__( 24 | self, 25 | in_features: int, 26 | out_features: int, 27 | bias=True, 28 | num_heads: int = 2, 29 | lora_r: int = 1, 30 | lora_alpha: int = 1, 31 | lora_bias: bool = False, 32 | ): 33 | 34 | nn.Linear.__init__(self, in_features, out_features, bias) 35 | DistributedDataParallelLTE.__init__( 36 | self, 37 | num_heads=num_heads, 38 | lora_bias=lora_bias, 39 | lora_alpha=lora_alpha, 40 | lora_r=lora_r, 41 | ) 42 | 43 | for _ in range(num_heads): 44 | self.lora_A.append(nn.Linear(in_features, lora_r, bias=lora_bias)) 45 | self.lora_B.append(nn.Linear(lora_r, out_features, bias=lora_bias)) 46 | 47 | self.convert_to_lte_module() 48 | return 49 | 50 | def forward(self, x): 51 | """ 52 | Args: 53 | x (torch.Tensor): the input tensor 54 | Returns: 55 | outputs (torch.Tensor): the output tensor 56 | 57 | if not self.training then the forward pass is the same as the original Linear layer 58 | and uses the latest merged weights and biases. 
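        A minimal usage sketch (added for illustration; not from the original codebase,
        values are arbitrary, and torch.func / PyTorch >= 2.0 is assumed)::

            layer = MultiheadLoRALinear(16, 16, num_heads=4, lora_r=2, lora_alpha=8)
            x = torch.randn(8, 16)    # batch of 8 is divisible by num_heads=4
            y = layer.train()(x)      # training: per-head LoRA paths applied to batch chunks
            y = layer.eval()(x)       # eval: plain nn.Linear forward using the merged weights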
59 | """ 60 | outputs = super().forward(x) 61 | if self.training: 62 | if x.size(0) % self.num_heads != 0: 63 | raise ValueError("During training input size must be divisible by num_heads") 64 | outputs = outputs + self.parallel_lora_forward(x) 65 | return outputs 66 | 67 | @torch.no_grad() 68 | def compute_delta(self): 69 | """ computes the delta weight and bias for lora """ 70 | (A, B), (b_A, b_B) = self.get_lora_params() 71 | return self.lora_to_delta(A, B, b_A, b_B, self.scaling) 72 | 73 | @torch.no_grad() 74 | def lora_to_delta(self, A, B, b_A, b_B, scaling): 75 | """ computes the delta weight and bias for lora """ 76 | delta_weight = self.scaling * B @ A 77 | delta_bias = None 78 | 79 | if self.lora_bias: 80 | delta_bias = self.scaling * \ 81 | (B @ b_A.unsqueeze(2) + b_B.unsqueeze(2)).squeeze(2) 82 | return delta_weight, delta_bias 83 | 84 | @torch.no_grad() 85 | def reset_lora_parameters(self): 86 | """ resets lora parameters. default is orthogonal initialization """ 87 | 88 | def init_param(params): 89 | for p in params: 90 | nn.init.orthogonal_(p) 91 | p.data *= math.sqrt(p.shape[1] / p.shape[0]) 92 | return 93 | 94 | init_param(self.lora_A_weight.data) 95 | if self.lora_bias: 96 | nn.init.zeros_(self.lora_A_bias) 97 | 98 | nn.init.zeros_(self.lora_B_weight.data) 99 | if self.lora_bias: 100 | nn.init.zeros_(self.lora_B_bias) 101 | return 102 | -------------------------------------------------------------------------------- /lte/dmp/dmp_lte.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lte import LTELayer 4 | 5 | 6 | class DistributedModelParallelLTE(LTELayer): 7 | """ 8 | Virtual DMP wrapper for LTE. 9 | Unlike the DDP version, for N devices we create N copies of the main weight and 10 | also the N LoRA parameters. Only computes forward pass once as the previous 11 | weights can be directly merged into the main weight. 12 | This layer will automatically assign the weights to the correct device and does 13 | not requires torchrun. 14 | 15 | This implementation is more faithful however requires more compute and memory. 16 | If developing on a single node, it is recommended to proto-type on DMP 17 | and re-implement the layer into DDP. 
18 | 19 | Args: 20 | num_heads (int): number of LoRA heads 21 | lora_bias (bool): whether to use bias for LoRA 22 | lora_alpha (int): the LoRA scaling factor 23 | lora_r (int): the rank of LoRA 24 | """ 25 | 26 | def __init__( 27 | self, 28 | num_heads, 29 | lora_bias=False, 30 | lora_alpha=4096, 31 | lora_r=32, 32 | ): 33 | 34 | self.num_heads = num_heads 35 | self.lora_alpha = lora_alpha 36 | self.lora_r = lora_r 37 | self.lora_bias = lora_bias 38 | self.main_device = "cuda:0" 39 | 40 | self.lora_A = nn.ModuleList() 41 | self.lora_B = nn.ModuleList() 42 | self.layers = nn.ModuleList() 43 | 44 | self.register_buffer('merged', torch.zeros(1)) 45 | self.register_buffer("lora_initialized", torch.zeros(1)) 46 | return 47 | 48 | def convert_to_lte_module(self): 49 | self.scaling = self.lora_alpha / self.lora_r 50 | 51 | # store representation 52 | self._repr_A = list(self.lora_A)[0].__repr__() 53 | self._repr_B = list(self.lora_B)[0].__repr__() 54 | 55 | self.weight.requires_grad_(False) 56 | self.bias.requires_grad_(False) 57 | 58 | for p in self.layers.parameters(): 59 | p.requires_grad_(False) 60 | 61 | self.reset_lora_parameters() 62 | self.lora_initialized.data[0] = 1 63 | return 64 | 65 | def __repr__(self): 66 | repr_str = \ 67 | f'MultiheadLoraLayer( {self.num_heads} x ' + \ 68 | '{\n' + \ 69 | ' ' * 4 + 'lora_A_weight: ' + self._repr_A + '\n' + \ 70 | ' ' * 4 + 'lora_B_weight: ' + self._repr_B + '\n' + \ 71 | '})' 72 | return repr_str 73 | 74 | def get_lora_params(self): 75 | """ retrieves lora paramters """ 76 | A = torch.stack([m.weight.to(device=self.main_device) for m in self.lora_A]) 77 | B = torch.stack([m.weight.to(device=self.main_device) for m in self.lora_B]) 78 | 79 | b_A, b_B = None, None 80 | if self.lora_bias: 81 | b_A = torch.stack([m.bias.to(device=self.main_device) for m in self.lora_A]) 82 | b_B = torch.stack([m.bias.to(device=self.main_device) for m in self.lora_B]) 83 | return (A, B), (b_A, b_B) 84 | 85 | @torch.no_grad() 86 | def merge_parameters(self): 87 | """ merges all lora parameters into the main module """ 88 | 89 | def average_merging(delta_weights, delta_biases=None): 90 | if delta_biases is None: 91 | return delta_weights.mean(0), delta_biases 92 | return delta_weights.mean(0), delta_biases.mean(0) 93 | 94 | lora_delta_weights, lora_delta_biases = self.compute_delta() 95 | 96 | if not self.merged: 97 | # register for the first time 98 | self.register_buffer('prev_delta_weights', torch.zeros_like(lora_delta_weights.data.clone())) 99 | 100 | if self.lora_bias: 101 | self.register_buffer('prev_delta_biases', torch.zeros_like(lora_delta_biases.data.clone())) 102 | 103 | delta_weight, delta_bias = \ 104 | average_merging( 105 | lora_delta_weights - self.prev_delta_weights, 106 | lora_delta_biases if (not self.lora_bias) else \ 107 | lora_delta_biases - self.prev_delta_biases 108 | ) 109 | 110 | self.prev_delta_weights.data = lora_delta_weights.data.clone() 111 | if self.lora_bias: 112 | self.prev_delta_biases.data = lora_delta_biases.data.clone() 113 | 114 | self.weight.data += delta_weight.data.clone().to(device=self.weight.device) 115 | if self.lora_bias: 116 | self.bias.data += delta_bias.data.clone().to(device=self.bias.device) 117 | 118 | for i in range(len(self.layers)): 119 | device = self.layers[i].weight.device 120 | 121 | self.layers[i].weight.data = self.weight.data.clone().to(device) 122 | self.layers[i].weight.data -= lora_delta_weights[i].data.clone().to(device) 123 | 124 | if self.lora_bias: 125 | self.layers[i].bias.data = 
self.bias.data.clone().to(device)
126 |                 self.layers[i].bias.data -= lora_delta_biases[i].data.clone().to(device)
127 | 
128 |         self.merged.data[0] = 1
129 |         return delta_weight, delta_bias
130 | 
131 |     def parallel_lora_forward(self, inputs):
132 |         """
133 |         Applies the LoRA forward pass in parallel.
134 | 
135 |         Args:
136 |             inputs (Tensor): the input tensor
137 |         Returns:
138 |             outputs (Tensor): the output tensor
139 |         """
140 |         xs = inputs.chunk(self.num_heads)
141 |         outputs = []
142 | 
143 |         for x, lora_A, lora_B in zip(xs, self.lora_A, self.lora_B):
144 |             outputs.append(lora_B(lora_A(x.to(lora_A.weight.device))))
145 | 
146 |         outputs = torch.cat([x.to(device=inputs.device) for x in outputs])
147 |         return outputs
148 | 
--------------------------------------------------------------------------------
/lte/dmp/linear.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from lte.dmp.dmp_lte import DistributedModelParallelLTE
6 | import lte.misc.distributed as D
7 | 
8 | 
9 | class MultiheadLoRALinear(
10 |     DistributedModelParallelLTE,
11 |     nn.Linear,
12 | ):
13 |     """
14 |     Multihead Linear layer with LoRA [distributed-model parallel version]
15 | 
16 |     Args:
17 |         in_features (int): the number of input features
18 |         out_features (int): the number of output features
19 |         bias (bool): whether to use bias
20 |         num_heads (int): the number of heads
21 |         lora_r (int): the rank of LoRA
22 |         lora_alpha (int): the alpha value for LoRA
23 |         lora_bias (bool): whether to use bias for LoRA
24 |     """
25 | 
26 |     def __init__(
27 |         self,
28 |         in_features: int,
29 |         out_features: int,
30 |         bias=True,
31 |         num_heads: int = 2,
32 |         lora_r: int = 1,
33 |         lora_alpha: int = 1,
34 |         lora_bias: bool = False,
35 |     ):
36 | 
37 |         nn.Linear.__init__(self, in_features, out_features, bias)
38 |         DistributedModelParallelLTE.__init__(
39 |             self,
40 |             num_heads=num_heads,
41 |             lora_bias=lora_bias,
42 |             lora_alpha=lora_alpha,
43 |             lora_r=lora_r,
44 |         )
45 | 
46 |         for i in range(num_heads):
47 |             device_id = f"cuda:{i % D.num_visible_devices()}"
48 |             self.layers.append(nn.Linear(in_features, out_features, bias=bias, device=device_id))
49 |             self.lora_A.append(nn.Linear(in_features, lora_r, bias=lora_bias, device=device_id))
50 |             self.lora_B.append(nn.Linear(lora_r, out_features, bias=lora_bias, device=device_id))
51 | 
52 |         self.convert_to_lte_module()
53 |         return
54 | 
55 |     def forward(self, inputs):
56 |         """
57 |         Args:
58 |             inputs (torch.Tensor): the input tensor
59 |         Returns:
60 |             outputs (torch.Tensor): the output tensor
61 | 
62 |         If not self.training, the forward pass is the same as the original Linear layer
63 |         and uses the latest merged weights and biases.
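        A minimal usage sketch (added for illustration; not from the original codebase,
        values are arbitrary, and at least one visible CUDA device is assumed since the
        per-head layers are placed on `cuda:i` at construction)::

            layer = MultiheadLoRALinear(16, 16, num_heads=2, lora_r=4, lora_alpha=16)
            x = torch.randn(8, 16, device="cuda:0")   # batch of 8 is divisible by num_heads=2
            y = layer.train()(x)        # each head processes its own chunk on its own device
            layer.merge_parameters()    # average per-head LoRA deltas into the main weight
            y = layer.eval()(x)         # plain linear forward using the merged main weight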
64 | """ 65 | if not self.training: 66 | outputs = F.linear(inputs.to(device=self.weight.device), self.weight, self.bias) 67 | else: 68 | if inputs.size(0) % self.num_heads != 0: 69 | raise ValueError("During training input size must be divisible by num_heads") 70 | xs = inputs.chunk(self.num_heads) 71 | outputs = [] 72 | 73 | for x, layer, lora_A, lora_B in zip(xs, self.layers, self.lora_A, self.lora_B): 74 | x = x.to(device=lora_A.weight.device) 75 | s = self.scaling 76 | outputs.append(s * lora_B(lora_A(x)) + layer(x)) 77 | 78 | outputs = torch.cat([x.to(device=inputs.device) for x in outputs]) 79 | return outputs 80 | 81 | @torch.no_grad() 82 | def compute_delta(self): 83 | """ computes the delta weight and bias for lora """ 84 | (A, B), (b_A, b_B) = self.get_lora_params() 85 | 86 | delta_weight = self.scaling * B @ A 87 | delta_bias = None 88 | 89 | if self.lora_bias: 90 | delta_bias = self.scaling * (B @ b_A.unsqueeze(2) + b_B.unsqueeze(2)).squeeze(2) 91 | return delta_weight, delta_bias 92 | 93 | @torch.no_grad() 94 | def reset_lora_parameters(self): 95 | """ resets lora parameters. default is orthogonal initialization """ 96 | 97 | def init_param(p): 98 | nn.init.orthogonal_(p) 99 | p.data *= math.sqrt(p.shape[1] / p.shape[0]) 100 | return 101 | 102 | for lora_A, lora_B in zip(self.lora_A, self.lora_B): 103 | init_param(lora_A.weight.data) 104 | if self.lora_bias: 105 | nn.init.zeros_(lora_A.bias.data) 106 | 107 | nn.init.zeros_(lora_B.weight.data) 108 | if self.lora_bias: 109 | nn.init.zeros_(lora_B.bias.data) 110 | return 111 | -------------------------------------------------------------------------------- /lte/mhlora/linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from lte import LTELayer 5 | 6 | 7 | class MultiheadLoRALinear(nn.Linear, LTELayer): 8 | """ 9 | Args: 10 | in_features (int): the number of input features 11 | out_features (int): the number of output features 12 | bias (bool): whether to use bias 13 | num_heads (int): the number of heads 14 | lora_r (int): the rank of LoRA 15 | lora_alpha (int): the alpha value for LoRA 16 | lora_bias (bool): whether to use bias for LoRA 17 | """ 18 | 19 | def __init__( 20 | self, 21 | in_features: int, 22 | out_features: int, 23 | bias=True, 24 | num_heads: int = 2, 25 | lora_r: int = 1, 26 | lora_alpha: int = 1, 27 | lora_bias: bool = False, 28 | ): 29 | 30 | nn.Linear.__init__(self, in_features, out_features, bias) 31 | self.lora_alpha = lora_alpha 32 | self.lora_r = lora_r 33 | self.lora_bias = lora_bias 34 | self.scaling = self.lora_alpha / self.lora_r 35 | 36 | self.lora_A, self.lora_B = [], [] 37 | 38 | for _ in range(num_heads): 39 | self.lora_A.append(nn.Linear(in_features, lora_r, bias=lora_bias)) 40 | self.lora_B.append(nn.Linear(lora_r, out_features, bias=lora_bias)) 41 | 42 | self.lora_A = nn.ModuleList(self.lora_A) 43 | self.lora_B = nn.ModuleList(self.lora_B) 44 | 45 | # store representation 46 | self._repr_A = list(self.lora_A)[0].__repr__() 47 | self._repr_B = list(self.lora_B)[0].__repr__() 48 | self.reset_lora_parameters() 49 | 50 | # disable training of original parameters 51 | self.weight.requires_grad = False 52 | if self.bias is not None: 53 | self.bias.requires_grad = False 54 | return 55 | 56 | def forward(self, x): 57 | """ 58 | Args: 59 | x (torch.Tensor): the input tensor 60 | Returns: 61 | outputs (torch.Tensor): the output tensor 62 | """ 63 | outputs = super().forward(x) 64 | for A, B in 
zip(self.lora_A, self.lora_B): 65 | outputs += self.scaling * B(A(x)) 66 | return outputs 67 | 68 | @torch.no_grad() 69 | def reset_lora_parameters(self): 70 | """ resets lora parameters. default is orthogonal initialization """ 71 | 72 | def init_param(p): 73 | nn.init.orthogonal_(p) 74 | p.data *= math.sqrt(p.shape[1] / p.shape[0]) 75 | return 76 | 77 | for lora_A, lora_B in zip(self.lora_A, self.lora_B): 78 | init_param(lora_A.weight.data) 79 | if self.lora_bias: 80 | nn.init.zeros_(lora_A.bias.data) 81 | 82 | nn.init.zeros_(lora_B.weight.data) 83 | if self.lora_bias: 84 | nn.init.zeros_(lora_B.bias.data) 85 | return 86 | -------------------------------------------------------------------------------- /lte/misc/__init__.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import torch.nn as nn 4 | import einops 5 | 6 | from lte.misc.attention import MultiheadAttention 7 | from lte.misc import merge, position, common 8 | 9 | 10 | def _get_submodules(model, key): 11 | parent = model.get_submodule(".".join(key.split(".")[:-1])) 12 | target_name = key.split(".")[-1] 13 | target = model.get_submodule(key) 14 | return parent, target, target_name 15 | 16 | 17 | def use_custom_attention(model, split_qkv=False): 18 | """ 19 | Replaces torch.MultiheadAttention with custom MultiheadAttention. 20 | LTE looks for nn.Linear which is not used in the torch.MultiheadAttention. 21 | Updates model in place but returns the model for convenience. 22 | 23 | Args: 24 | model (nn.Module): the model to convert 25 | Returns: 26 | model (nn.Module): the model with custom MultiheadAttention modules 27 | 28 | Example:: 29 | model = lte.misc.use_custom_attention(model) 30 | """ 31 | 32 | key_list = [key for key, _ in model.named_modules()] 33 | for key in key_list: 34 | parent_module, old_module, target_name = _get_submodules(model, key) 35 | 36 | if isinstance(old_module, nn.MultiheadAttention): 37 | new_module = MultiheadAttention( 38 | embed_dim=old_module.embed_dim, 39 | num_heads=old_module.num_heads, 40 | dropout=old_module.dropout, 41 | bias=(old_module.in_proj_bias is not None), 42 | split_qkv=split_qkv, 43 | ) 44 | 45 | if not new_module.split_qkv: 46 | new_module.in_proj.weight.data = old_module.in_proj_weight.data 47 | new_module.in_proj.bias.data = old_module.in_proj_bias.data 48 | new_module.out_proj.weight.data = old_module.out_proj.weight.data 49 | new_module.out_proj.bias.data = old_module.out_proj.bias.data 50 | 51 | setattr(parent_module, target_name, new_module) 52 | del old_module 53 | 54 | torch.cuda.empty_cache() 55 | gc.collect() 56 | return model 57 | 58 | 59 | class LinearProjection(nn.Module): 60 | """ 61 | Linear projection layer 62 | 63 | Args: 64 | hidden_dim (int): the hidden dimension of the linear projection 65 | patch_size (int): the patch size of the input image 66 | """ 67 | def __init__(self, hidden_dim, patch_size): 68 | super().__init__() 69 | self.patch_size = patch_size 70 | self.linear_proj = nn.Linear( 71 | 3 * patch_size * patch_size, hidden_dim, bias=True 72 | ) 73 | 74 | def forward(self, x): 75 | """ 76 | Args: 77 | x (torch.Tensor): the input tensor of shape (b, c, h, w) 78 | Returns: 79 | x (torch.Tensor): the output tensor of shape (b, embed, h//patch_size, w//patch_size) 80 | """ 81 | _, _, h, w = x.shape 82 | # patchify 83 | x = einops.rearrange( 84 | x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 85 | p1=self.patch_size, 86 | p2=self.patch_size, 87 | ) 88 | x = self.linear_proj(x) 89 | 90 | # reshape 
to the desired output shape 91 | x = einops.rearrange( 92 | x, 'b (h w) c -> b c h w', 93 | h=h//self.patch_size, 94 | w=w//self.patch_size, 95 | ) 96 | return x 97 | 98 | 99 | def replace_conv_proj_with_linear(model): 100 | """ 101 | Replaces all Conv2d modules with kernel_size == stride with LinearProjection. 102 | This is useful for replacing the first layer of a vision transformer with a linear projection. 103 | Helpful for using LoRA on ViT Conv2D projection. 104 | Updates model in place but returns the model for convenience. 105 | 106 | Args: 107 | model (nn.Module): the model to convert 108 | Returns: 109 | model (nn.Module): the model with replaced conv2d modules 110 | 111 | Example:: 112 | model = lte.misc.replace_conv_proj_with_linear(model) 113 | """ 114 | for k, m in model.named_modules(): 115 | parent_module, old_module, target_name = _get_submodules(model, k) 116 | 117 | # replace all conv2d that have same kernel_size and stride with linear projection 118 | if isinstance(old_module, nn.Conv2d) and old_module.kernel_size == old_module.stride: 119 | new_module = LinearProjection(old_module.out_channels, old_module.kernel_size[0]) 120 | 121 | new_module.linear_proj.weight.data.copy_( 122 | old_module.weight.data.moveaxis(1, -1).reshape(old_module.out_channels, -1).clone() 123 | ) 124 | new_module.linear_proj.bias.data.copy_( 125 | old_module.bias.data.clone() 126 | ) 127 | 128 | setattr(parent_module, target_name, new_module) 129 | del old_module 130 | return model 131 | 132 | 133 | def disable_norm_affine_parameters(model): 134 | """ 135 | Disables the affine parameters of all LayerNorm and BatchNorm2d modules in the model. 136 | Updates model in place but returns the model for convenience. 137 | 138 | Args: 139 | model (nn.Module): the model to disable the affine parameters of 140 | Returns: 141 | model (nn.Module): the model with disabled affine parameters 142 | 143 | Example:: 144 | model = lte.misc.disable_norm_affine_parameters(model) 145 | 146 | NOTE: Feel free to add other normalization layers as needed. 
147 | """ 148 | for n, m in model.named_modules(): 149 | if isinstance(m, nn.LayerNorm): 150 | m.weight = None 151 | m.bias = None 152 | m.elementwise_affine = False 153 | elif isinstance(m, nn.BatchNorm2d): 154 | m.weight = None 155 | m.bias = None 156 | m.affine = False 157 | return model 158 | -------------------------------------------------------------------------------- /lte/misc/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from lte.misc import common 6 | 7 | 8 | class MultiheadAttention(nn.Module): 9 | """ MultiHead Attention using PyTorch's scaled_dot_product_attention """ 10 | 11 | def __init__( 12 | self, 13 | embed_dim, 14 | num_heads=8, 15 | dropout=0.0, 16 | bias=True, 17 | split_qkv=True, 18 | ): 19 | super().__init__() 20 | self.bias = bias 21 | self.heads = num_heads 22 | self.dropout = dropout 23 | self.split_qkv = split_qkv 24 | 25 | if self.split_qkv: 26 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 27 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 28 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 29 | else: 30 | self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias) 31 | 32 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 33 | self.init_weights() 34 | return 35 | 36 | def init_weights(self): 37 | """ 38 | Using same initialization protocol for PyTorch's MultiheadAttention 39 | https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/activation.py#L1041 40 | """ 41 | if self.split_qkv: 42 | for m in [self.q_proj, self.k_proj, self.v_proj]: 43 | torch.nn.init.xavier_uniform_(m.weight) 44 | if self.bias: 45 | torch.nn.init.constant_(m.bias, 0.0) 46 | else: 47 | torch.nn.init.xavier_uniform_(self.in_proj.weight) 48 | if self.bias: 49 | torch.nn.init.constant_(self.in_proj.bias, 0.0) 50 | 51 | if self.bias: 52 | torch.nn.init.constant_(self.out_proj.bias, 0.0) 53 | return 54 | 55 | def in_projection(self, q, k, v): 56 | """ 57 | Args: 58 | q, k, v: torch.Tensor of shape (B, S, D) 59 | Returns: 60 | q, k, v: torch.Tensor of shape (B, H, S, D_head) 61 | """ 62 | if self.split_qkv: 63 | q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v) 64 | else: 65 | q, k, v = self.in_proj(q).chunk(3, dim=-1) 66 | 67 | q, k, v = ( 68 | q.unflatten(-1, (self.heads, -1)).swapaxes(1, 2), 69 | k.unflatten(-1, (self.heads, -1)).swapaxes(1, 2), 70 | v.unflatten(-1, (self.heads, -1)).swapaxes(1, 2), 71 | ) 72 | return q, k, v 73 | 74 | def forward(self, q, k, v, need_weights=False): 75 | q, k, v = self.in_projection(q, k, v) 76 | assert need_weights == False, "need_weights is not supported in this version" 77 | 78 | out = F.scaled_dot_product_attention( 79 | q, k, v, dropout_p=self.dropout, 80 | ).permute(0, 2, 1, 3).flatten(-2, -1) 81 | return self.out_proj(out), None 82 | 83 | 84 | class DeprecatedMultiheadAttention(nn.Module): 85 | """ 86 | This version is deprecated and will be removed in the future. Please use MultiheadAttention instead. 
87 | PyTorch 2.2 now natively supports flash-attention-2 88 | """ 89 | 90 | def __init__( 91 | self, 92 | embed_dim, 93 | num_heads=8, 94 | dropout=0.0, 95 | bias=True, 96 | split_qkv=True, 97 | ): 98 | print("This version of MultiheadAttention is deprecated and will be removed in the future.") 99 | super().__init__() 100 | self.bias = bias 101 | self.heads = num_heads 102 | self.scale = (embed_dim // num_heads) ** -0.5 103 | self.split_qkv = split_qkv 104 | self.flash_available, self.flash_attn = self.check_flash_available() 105 | 106 | self.attend = nn.Softmax(dim=-1) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | if self.split_qkv: 110 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 111 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 112 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 113 | else: 114 | self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias) 115 | 116 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 117 | self.init_weights() 118 | return 119 | 120 | def check_flash_available(self): 121 | try: 122 | if common.flash_ready_device(): 123 | from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func 124 | 125 | flash_available = True 126 | else: 127 | flash_available = False 128 | print("The current device does not support flash-attention.") 129 | except ImportError: 130 | flash_available = False 131 | flash_attn_qkvpacked_func = None 132 | print( 133 | "Flash-attention not available. " 134 | + "Please install it from https://github.com/Dao-AILab/flash-attention" 135 | ) 136 | return flash_available, flash_attn_qkvpacked_func 137 | 138 | 139 | def init_weights(self): 140 | """ 141 | Using same initialization protocol as pytorch 142 | https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/activation.py#L1041 143 | """ 144 | if self.split_qkv: 145 | for m in [self.q_proj, self.k_proj, self.v_proj]: 146 | torch.nn.init.xavier_uniform_(m.weight) 147 | if self.bias: 148 | torch.nn.init.constant_(m.bias, 0.0) 149 | else: 150 | torch.nn.init.xavier_uniform_(self.in_proj.weight) 151 | if self.bias: 152 | torch.nn.init.constant_(self.in_proj.bias, 0.0) 153 | 154 | if self.bias: 155 | torch.nn.init.constant_(self.out_proj.bias, 0.0) 156 | return 157 | 158 | def in_projection(self, q, k, v): 159 | if self.split_qkv: 160 | q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v) 161 | q, k, v = ( 162 | q.unflatten(-1, (self.heads, -1)), 163 | k.unflatten(-1, (self.heads, -1)), 164 | v.unflatten(-1, (self.heads, -1)), 165 | ) 166 | if not self.flash_available: 167 | q, k, v = map(lambda t: rearrange(t, "b n h d -> b h n d"), [q, k, v]) 168 | else: 169 | q, k, v = self.in_proj(q).chunk(3, dim=-1) 170 | if not self.flash_available: 171 | q, k, v = map( 172 | lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), 173 | [q, k, v], 174 | ) 175 | return q, k, v 176 | 177 | def forward(self, q, k, v, need_weights=True): 178 | q, k, v = self.in_projection(q, k, v) 179 | 180 | if self.flash_available: 181 | # qkv in float16 or bfloat16 182 | qkv = torch.stack([q, k, v], dim=2) 183 | qkv = common.autocast_vars(qkv) 184 | 185 | out = self.flash_attn( 186 | qkv=qkv, 187 | dropout_p=0.0, 188 | softmax_scale=self.scale, 189 | return_attn_probs=True, 190 | ) 191 | out, attn, _ = out 192 | out = out.flatten(-2, -1) 193 | else: 194 | q, k, v = common.autocast_vars(q, k, v) 195 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 196 | 197 | attn = self.attend(dots) 198 | attn = self.dropout(attn) 199 | 200 | out = 
torch.matmul(attn, v) 201 | out = rearrange(out, "b h n d -> b n (h d)") 202 | 203 | if not need_weights: 204 | attn = None 205 | 206 | return self.dropout(self.out_proj(out)), attn 207 | -------------------------------------------------------------------------------- /lte/misc/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import nullcontext 3 | 4 | 5 | def determine_compute_context(): 6 | """ 7 | Automatically determines the compute dtype and context manager. 8 | Override this function if you want to use a different dtype or context manager. 9 | 10 | Args: 11 | None 12 | Returns: 13 | dtype: torch.dtype for the model parameters not AMP dtype 14 | context: torch.cuda.amp.Gradscaler or contextlib.contextmanager 15 | scaler: torch.cuda.amp.GradScaler 16 | 17 | NOTE: it seems like even if we use bfloat16, using high-precision is important for 18 | numerical stability of the merge. Without using AMP sequential merigng hurts performance. 19 | In practice one would keep a high-precision copy regardless, which was removed for simplicity. 20 | Until the feature for high-precsision parameters with quantization is re-implemented 21 | We will use AMP with bfloat16 also. 22 | """ 23 | 24 | if torch.cuda.is_bf16_supported(): 25 | # dtype = torch.bfloat16 26 | dtype = torch.float32 27 | # context = nullcontext() 28 | # scaler = None 29 | context = torch.cuda.amp.autocast(enabled=True, dtype=torch.float16) 30 | scaler = torch.cuda.amp.GradScaler() 31 | else: 32 | dtype = torch.float32 33 | context = torch.cuda.amp.autocast(enabled=True, dtype=torch.float16) 34 | scaler = torch.cuda.amp.GradScaler() 35 | 36 | return dtype, context, scaler 37 | 38 | 39 | def auto_dtype(): 40 | """ Determines the compute dtype depending on the device. """ 41 | return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 42 | 43 | 44 | def autocast_vars(*vars): 45 | """ 46 | This function was needed when using AMP with Dao-AILab/flash-attention. 47 | As it does not auto-cast if some of the layers are float32 (e.g. LN). 48 | Since the native FA2 support from PyTorch 2.2 this function is no longer needed. 49 | """ 50 | def _autocast(x): 51 | if torch.is_autocast_enabled() and x.dtype != auto_dtype(): 52 | print('forced autocasting') 53 | return x.to(dtype=auto_dtype()) 54 | return x 55 | 56 | if len(vars) == 1: 57 | return _autocast(vars[0]) 58 | return (*[_autocast(x) for x in vars],) 59 | 60 | 61 | def flash_ready_device(): 62 | """ Check if the device has support flash-attention. 
""" 63 | if not torch.cuda.is_available(): 64 | return False 65 | 66 | device = torch.device("cuda:0") 67 | major, minor = torch.cuda.get_device_capability(device) 68 | return major >= 8 69 | -------------------------------------------------------------------------------- /lte/misc/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | def num_visible_devices(): 7 | return torch.cuda.device_count() 8 | 9 | 10 | def reduce(tensor, reduction): 11 | """DDP reduction across devices""" 12 | if not is_distributed(): 13 | return tensor 14 | tensor = torch.tensor(tensor, device="cuda") 15 | dist.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) 16 | dist.barrier() 17 | if reduction == "mean": 18 | tensor /= dist.get_world_size() 19 | elif reduction == "sum": 20 | pass 21 | else: 22 | raise ValueError(f"Invalid reduction: {reduction}") 23 | return tensor 24 | 25 | 26 | def is_distributed(): 27 | return dist.is_initialized() and dist.is_available() 28 | 29 | 30 | def local_rank(): 31 | if not is_distributed(): 32 | return 0 33 | return dist.get_rank() 34 | 35 | 36 | def world_size(): 37 | if not is_distributed(): 38 | return 1 39 | return dist.get_world_size() 40 | 41 | 42 | def device(): 43 | return torch.device("cuda", dist.get_rank()) 44 | 45 | 46 | def is_main_process(): 47 | return local_rank() == 0 48 | 49 | 50 | def init_distributed_mode(args): 51 | """Initialize distributed mode""" 52 | 53 | def setup_for_distributed(is_master): 54 | """Disables printing when not in master process""" 55 | import builtins as __builtin__ 56 | 57 | builtin_print = __builtin__.print 58 | 59 | def print(*args, **kwargs): 60 | force = kwargs.pop("force", False) 61 | if is_master or force: 62 | builtin_print(*args, **kwargs) 63 | 64 | __builtin__.print = print 65 | 66 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 67 | args.rank = int(os.environ["RANK"]) 68 | args.world_size = int(os.environ["WORLD_SIZE"]) 69 | args.gpu = int(os.environ["LOCAL_RANK"]) 70 | elif "SLURM_PROCID" in os.environ: 71 | args.rank = int(os.environ["SLURM_PROCID"]) 72 | args.gpu = args.rank % torch.cuda.device_count() 73 | elif hasattr(args, "rank"): 74 | pass 75 | else: 76 | print("Not using distributed mode") 77 | args.distributed = False 78 | return 79 | 80 | args.distributed = True 81 | 82 | torch.cuda.set_device(args.gpu) 83 | args.dist_backend = "nccl" 84 | print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) 85 | dist.init_process_group( 86 | backend=args.dist_backend, 87 | init_method=args.dist_url, 88 | world_size=args.world_size, 89 | rank=args.rank, 90 | ) 91 | dist.barrier() 92 | setup_for_distributed(args.rank == 0) 93 | -------------------------------------------------------------------------------- /lte/misc/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lte 3 | 4 | 5 | class MergeCondition(): 6 | """ 7 | LTE Merge scheduler 8 | 9 | Args: 10 | model (nn.Module): the model to merge 11 | merge_steps (int): the number of steps before merging 12 | method (str): the method to use for merging (default: 'step') 13 | 14 | Example:: 15 | merge_scheduler = lte.misc.merge.MergeCondition(model, merge_steps=10, method='step') 16 | 17 | for ... in range(dataloader): 18 | # optimize model 19 | ... 
20 | 21 | # step the merge scheduler (every 10 steps it will merge the model) 22 | merge_scheduler.step() 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model, 28 | merge_steps=1, 29 | method='step', 30 | ): 31 | 32 | self.model = model 33 | self.clock = 0 34 | self.merge_steps = merge_steps 35 | self.method = method 36 | self.reset_opt = True 37 | 38 | method_fn_map = { 39 | 'step': self.step_condition, 40 | } 41 | self.cond_fn = method_fn_map[method] 42 | return 43 | 44 | def peek(self): 45 | """ peeks whether model is planning to merge """ 46 | return (self.clock + 1) % self.merge_steps == 0 47 | 48 | 49 | def step(self): 50 | """ increments step count """ 51 | if self.model is None: 52 | raise RuntimeError('this merge condition is not registered to a model.') 53 | 54 | self.clock += 1 55 | merged = self.cond_fn() 56 | return merged 57 | 58 | @torch.no_grad() 59 | def merge(self): 60 | """ merges LTE and Replica layers """ 61 | for m in self.model.modules(): 62 | 63 | if isinstance(m, lte.LTELayer): 64 | m.merge_parameters() 65 | 66 | if isinstance(m, lte.ReplicaLayer): 67 | m.merge_parameters() 68 | return 69 | 70 | def step_condition(self): 71 | """ simple step condition """ 72 | if self.clock % int(self.merge_steps) == 0: 73 | self.merge() 74 | self.clock = 0 75 | return True 76 | return False 77 | 78 | def register_model(self, model): 79 | """ registers model """ 80 | self.model = model 81 | -------------------------------------------------------------------------------- /lte/misc/position.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from torch import Tensor 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | """ slight modification from https://pytorch.org/tutorials/beginner/transformer_tutorial.html """ 9 | def __init__(self, d_model: int, max_len: int = 5000): 10 | super().__init__() 11 | self.register_buffer('pe', sinusoidal(d_model, max_len)) 12 | 13 | def forward(self, x): 14 | """ 15 | Args: 16 | x (Tensor): the input tensor of shape (seq_len, batch, embed) 17 | Returns: 18 | x (Tensor): the output tensor of shape (seq_len, batch, embed) 19 | """ 20 | return x + self.pe[:x.size(0)] 21 | 22 | 23 | def sinusoidal(d_model, max_len=5000): 24 | """ 25 | Create a sinusoidal positional encoding. 26 | 27 | Args: 28 | d_model (int): the model dimension 29 | max_len (int): the maximum length of the sequence 30 | Returns: 31 | pe (Tensor): the positional encoding tensor of shape (1, max_len, d_model) 32 | """ 33 | position = torch.arange(max_len).unsqueeze(1) 34 | div_term = torch.exp(torch.arange(0, d_model, 2) 35 | * (-math.log(10000.0) / d_model)) 36 | pe = torch.zeros(max_len, 1, d_model) 37 | pe[:, 0, 0::2] = torch.sin(position * div_term) 38 | pe[:, 0, 1::2] = torch.cos(position * div_term) 39 | return pe.permute(1, 0, 2) 40 | 41 | 42 | def sinusoidal_2d(d_model, height, width, normalization_constant=None): 43 | """ 44 | Create a 2D sinusoidal positional encoding. 
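As a usage note for the two helpers above: the following is a minimal training-loop sketch (an editorial example, not a file in this repository) combining `determine_compute_context()` from `lte/misc/common.py` with the `MergeCondition` scheduler from `lte/misc/merge.py`. The names `model`, `opt`, and `dataloader` are assumed to already exist, with `model` converted by `lte.prepare_model_for_lte`.

```python
import torch.nn.functional as F
from lte.misc.common import determine_compute_context
from lte.misc.merge import MergeCondition

# assumed to exist: `model` (LTE-converted, on CUDA), `opt` (its optimizer),
# `dataloader` yielding (x, y) mini-batches with batch size divisible by num_heads
dtype, context, scaler = determine_compute_context()
model = model.to(dtype=dtype)

# fold the LoRA/Replica heads back into the main weights every 10 optimizer steps
merge_scheduler = MergeCondition(model, merge_steps=10, method='step')

for x, y in dataloader:
    with context:                      # autocast context returned by the helper
        loss = F.mse_loss(model(x.cuda()), y.cuda())

    scaler.scale(loss).backward()      # fp16-safe scaled backward
    scaler.step(opt)                   # unscales gradients, then calls opt.step()
    scaler.update()
    opt.zero_grad(set_to_none=True)

    merge_scheduler.step()             # calls merge_parameters() every `merge_steps` steps
```

`merge_scheduler.peek()` can be used to check one step ahead whether the next `step()` will trigger a merge, e.g. to log or checkpoint right before it happens.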
45 | 46 | Args: 47 | d_model (int): the model dimension 48 | height (int): the height of the 2D grid 49 | width (int): the width of the 2D grid 50 | Returns: 51 | pe (Tensor): the positional encoding tensor of shape (1, height*width, d_model) 52 | """ 53 | if normalization_constant is None: 54 | normalization_constant = height * width 55 | 56 | # calculate div_term for both dimensions 57 | # this controls for the frequency cos(w / div_term) 58 | # not sure what the choice of using exp(log()) is 59 | div_term = torch.exp(torch.arange(0, d_model, 2) * 60 | (-math.log(normalization_constant) / d_model)) 61 | 62 | # create position encoding for height 63 | pe_y = torch.zeros(height, d_model) 64 | pe_y[:, 0::2] = torch.sin(torch.arange(height).unsqueeze(1) * div_term) 65 | pe_y[:, 1::2] = torch.cos(torch.arange(height).unsqueeze(1) * div_term) 66 | pe_y = pe_y.unsqueeze(1).repeat(1, width, 1) # Repeat for each width 67 | 68 | # create position encoding for width 69 | pe_x = torch.zeros(width, d_model) 70 | pe_x[:, 0::2] = torch.sin(torch.arange(width).unsqueeze(1) * div_term) 71 | pe_x[:, 1::2] = torch.cos(torch.arange(width).unsqueeze(1) * div_term) 72 | pe_x = pe_x.unsqueeze(0).repeat(height, 1, 1) # Repeat for each height 73 | 74 | # combine the encodings by adding 75 | pe = pe_y + pe_x 76 | 77 | # reshape to match the expected output shape 78 | return pe.view(1, height * width, d_model) 79 | -------------------------------------------------------------------------------- /lte/prepare.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import warnings 4 | from lte.replica import MultiheadReplicaLayer 5 | 6 | 7 | def prepare_model_for_lte( 8 | model, 9 | lora_config, 10 | copy_weights=True, 11 | strict=True, 12 | replica_layers=[], 13 | mode="ddp", 14 | ): 15 | """ 16 | Convert a model into an LTE model. 17 | 18 | Args: 19 | model (nn.Module): the model to convert 20 | lora_config (LTEConfigs): the lora config to use 21 | copy_weights (bool): if True, copy weights from the original model to the new model 22 | strict (bool): if True, raise an error if not all parameters are converted 23 | replica_layers (list): list of modules to convert for standard local-step averaging. 24 | mode (str): the mode to use. 
Options are "ddp" and "dmp" 25 | 26 | Returns: 27 | model (nn.Module): LTE model 28 | 29 | Example:: 30 | model = \ 31 | lte.prepare_model_for_lte( 32 | model, 33 | lte.misc.LTEConfig.default( 34 | lora_r=16, 35 | num_heads=4, 36 | ) 37 | ) 38 | """ 39 | 40 | if mode == "ddp": 41 | from lte.ddp.linear import MultiheadLoRALinear 42 | elif mode == "dmp": 43 | from lte.dmp.linear import MultiheadLoRALinear 44 | assert next(model.parameters()).is_cuda, "dmp expects model to be in cuda" 45 | elif mode == 'mhlora': 46 | from lte.mhlora.linear import MultiheadLoRALinear 47 | else: 48 | raise ValueError(f"mode {mode} not recognized") 49 | 50 | 51 | lora_kwargs = lora_config.lora 52 | linear_lora_kwargs = patch_kwargs(lora_kwargs, lora_kwargs.linear) 53 | orig_linear_lora_alpha = linear_lora_kwargs['lora_alpha'] 54 | 55 | # replace pytorch attention with custom attention since we look for nn.Linear 56 | converted_parameter_count = 0 57 | trainable_parameter_count = sum([p.numel() for p in model.parameters() if p.requires_grad]) 58 | 59 | supported_modules = ( 60 | nn.Linear, 61 | # nn.LayerNorm, 62 | # nn.Conv2d, 63 | # nn.Embedding, 64 | ) 65 | 66 | for n, m in model.named_modules(): 67 | 68 | if not isinstance(m, supported_modules) and (not (m in replica_layers)): 69 | continue 70 | 71 | if np.any([is_submodule(rm, m) for rm in replica_layers]): 72 | continue 73 | 74 | parent_module, old_module, target_name = _get_submodules(model, n) 75 | converted_parameter_count += sum([p.numel() for p in old_module.parameters() if p.requires_grad]) 76 | 77 | dtype = next(old_module.parameters()).dtype 78 | 79 | if m in replica_layers: 80 | new_module = MultiheadReplicaLayer( 81 | old_module, 82 | num_heads=lora_kwargs.num_heads, 83 | mode=mode, 84 | ).to(dtype=dtype) 85 | else: 86 | if isinstance(m, nn.Linear): 87 | device = next(old_module.parameters()).device 88 | 89 | new_module = MultiheadLoRALinear( 90 | old_module.in_features, 91 | old_module.out_features, 92 | bias=(old_module.bias is not None), 93 | **linear_lora_kwargs, 94 | ).to(device=device, dtype=dtype) 95 | 96 | if copy_weights: 97 | if mode == 'ddp': 98 | new_module.weight.data = old_module.weight.data 99 | if old_module.bias is not None: 100 | new_module.bias.data = old_module.bias.data 101 | else: 102 | new_module.weight.data = old_module.weight.data.clone().to(new_module.weight.device) 103 | if old_module.bias is not None: 104 | new_module.bias.data = old_module.bias.data.clone().to(new_module.bias.device) 105 | 106 | if mode == 'dmp': 107 | # dmp creates N copies of the original weight so it requires further copying 108 | for l in new_module.layers: 109 | l.weight.data = old_module.weight.data.clone().to(l.weight.device) 110 | 111 | if old_module.bias is not None: 112 | l.bias.data = old_module.bias.data.clone().to(l.bias.device) 113 | 114 | else: 115 | print("module replacement rule not found") 116 | 117 | setattr(parent_module, target_name, new_module) 118 | 119 | if converted_parameter_count != trainable_parameter_count: 120 | diff = trainable_parameter_count - converted_parameter_count 121 | e_msg = f"Converted parameter count {converted_parameter_count} " + \ 122 | f"does not match trainable parameter count {trainable_parameter_count} [diff: {diff}]" 123 | if strict: 124 | raise RuntimeError(e_msg) 125 | else: 126 | warnings.warn(e_msg) 127 | return model 128 | 129 | def _get_submodules(model, key): 130 | parent = model.get_submodule(".".join(key.split(".")[:-1])) 131 | target_name = key.split(".")[-1] 132 | target = 
model.get_submodule(key) 133 | return parent, target, target_name 134 | 135 | def patch_kwargs(kwargs, new_kwargs): 136 | kwargs = kwargs.copy() 137 | for key, value in new_kwargs.items(): 138 | kwargs[key] = value 139 | 140 | for k, v in list(kwargs.items()): 141 | if isinstance(v, dict): 142 | del kwargs[k] 143 | return kwargs 144 | 145 | def is_submodule(parent_module, submodule): 146 | return np.any([mod is submodule for mod in parent_module.modules()]) and (parent_module is not submodule) 147 | -------------------------------------------------------------------------------- /lte/replica.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import lte.misc.distributed as D 5 | from lte import ReplicaLayer 6 | 7 | 8 | class MultiheadReplicaLayer(ReplicaLayer): 9 | """ 10 | Replica layer. Creates N replicas of the module. 11 | 12 | Args: 13 | target_module (nn.Module): the module to replicate 14 | num_heads (int): the number of replicas to create 15 | 16 | NOTE: does not support module that outputs tuples 17 | """ 18 | 19 | def __init__(self, target_module, num_heads, mode="ddp"): 20 | super().__init__() 21 | self.num_heads = num_heads 22 | self.replicas = [] 23 | for i in range(num_heads): 24 | if mode == "ddp": 25 | self.replicas.append(copy.deepcopy(target_module)) 26 | else: 27 | device_id = f"cuda:{i % D.num_visible_devices()}" 28 | self.replicas.append( 29 | copy.deepcopy(target_module).to(device=device_id) 30 | ) 31 | self.replicas = nn.ModuleList(self.replicas) 32 | return 33 | 34 | def forward(self, inputs): 35 | """ 36 | Args: 37 | inputs (torch.Tensor): the input tensor 38 | Returns: 39 | outputs (torch.Tensor): the output tensor 40 | """ 41 | if not self.training: 42 | replica_device = next(self.replicas[0].parameters()).device 43 | outputs = self.replicas[0](inputs.to(device=replica_device)) 44 | else: 45 | xs = inputs.chunk(self.num_heads) 46 | outputs = [] 47 | for x, replica in zip(xs, self.replicas): 48 | replica_device = next(replica.parameters()).device 49 | outputs.append(replica(x.to(device=replica_device))) 50 | outputs = torch.cat([x.to(device=inputs.device) for x in outputs]) 51 | return outputs 52 | 53 | @torch.no_grad() 54 | def merge_parameters(self): 55 | """ compute average across N devices and then assign to all copies in replica """ 56 | 57 | # compute average of the parameter 58 | avg_params = [torch.zeros_like(p) for p in self.replicas[0].parameters()] 59 | for replica in self.replicas: 60 | for p, avg_p in zip(replica.parameters(), avg_params): 61 | avg_p += p 62 | 63 | avg_params = [p / self.num_heads for p in avg_params] 64 | 65 | # assign to all replicas (clone and assign to correct device) 66 | for replica in self.replicas: 67 | for p, avg_p in zip(replica.parameters(), avg_params): 68 | p.data = avg_p.clone().to(device=p.device) 69 | return 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import setuptools 3 | 4 | 5 | setuptools.setup( 6 | name="lte", 7 | version="0.0.1", 8 | author="Minyoung Huh", 9 | author_email="minhuh@mit.edu", 10 | description=f"PyTorch LTE", 11 | url="git@github.com:minyoungg/lte.git", 12 | packages=setuptools.find_packages(), 13 | classifiers=[ 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.9", 16 | "Programming Language :: Python :: 3.10", 17 | 
"Programming Language :: Python :: 3.11", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.9', 22 | ) 23 | -------------------------------------------------------------------------------- /unittest/run_tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch.nn as nn 4 | 5 | from lte import LTEConfig, prepare_model_for_lte, MultiheadReplicaLayer 6 | 7 | 8 | class TestModelParameters(unittest.TestCase): 9 | rtol = 1e-3 10 | atol = 1e-6 11 | lora_r = 8 12 | lora_alpha = 16 13 | num_heads = 4 14 | in_dim = 16 15 | out_dim = 32 16 | bias = True 17 | 18 | 19 | def setUp(self): 20 | torch.manual_seed(0) 21 | self.dmp_model = nn.Sequential(nn.Linear(self.in_dim, self.out_dim, bias=self.bias)).cuda() 22 | self.dmp_model = prepare_model_for_lte( 23 | self.dmp_model, 24 | LTEConfig.default( 25 | lora_r=self.lora_r, 26 | lora_alpha=self.lora_alpha, 27 | num_heads=self.num_heads, 28 | ), 29 | mode='dmp', 30 | strict=True, 31 | ).double() 32 | 33 | torch.manual_seed(0) 34 | self.ddp_model = nn.Sequential(nn.Linear(self.in_dim, self.out_dim, bias=self.bias)).cuda() 35 | self.ddp_model = prepare_model_for_lte( 36 | self.ddp_model, 37 | LTEConfig.default( 38 | lora_r=self.lora_r, 39 | lora_alpha=self.lora_alpha, 40 | num_heads=self.num_heads, 41 | ), 42 | mode='ddp', 43 | strict=True, 44 | ).double() 45 | 46 | for i in range(4): 47 | self.dmp_model[0].lora_A[i].weight.data = \ 48 | self.ddp_model[0].lora_A_weight[i].data.clone().to( 49 | device=self.dmp_model[0].lora_A[i].weight.device 50 | ) 51 | 52 | if self.dmp_model[0].lora_bias: 53 | self.dmp_model[0].lora_A[i].bias.data = \ 54 | self.ddp_model[0].lora_A_bias[i].data.clone().to( 55 | device=self.dmp_model[0].lora_A[i].bias.device 56 | ) 57 | 58 | def test_parameter_count_difference(self): 59 | # test if there is no parameter being tied 60 | ddp_param_count = sum(p.numel() for p in self.ddp_model.parameters()) 61 | dmp_param_count = sum(p.numel() for p in self.dmp_model.parameters()) 62 | self.assertNotEqual(ddp_param_count, dmp_param_count) 63 | 64 | def test_lora_parameter_synchronization(self): 65 | for i in range(4): 66 | dmp_w = self.dmp_model[0].layers[i].weight 67 | dmp_a = self.dmp_model[0].lora_A[i].weight 68 | dmp_b = self.dmp_model[0].lora_B[i].weight 69 | 70 | ddp_w = self.ddp_model[0].weight 71 | ddp_a = self.ddp_model[0].lora_A_weight[i] 72 | ddp_b = self.ddp_model[0].lora_B_weight[i] 73 | 74 | self.assertTrue(torch.allclose(ddp_w.cpu(), dmp_w.cpu())) 75 | self.assertTrue(torch.allclose(ddp_a.cpu(), dmp_a.cpu())) 76 | self.assertTrue(torch.allclose(ddp_b.cpu(), dmp_b.cpu())) 77 | 78 | def test_training_behavior(self): 79 | dmp_opt = torch.optim.AdamW(self.dmp_model.parameters(), lr=0.1) 80 | ddp_opt = torch.optim.AdamW(self.ddp_model.parameters(), lr=0.1) 81 | 82 | for _ in range(100): 83 | x = torch.randn(8, 128, self.in_dim).double().cuda() 84 | dmp_out = self.dmp_model(x) 85 | ddp_out = self.ddp_model(x) 86 | 87 | dmp_out.sum().backward() 88 | ddp_out.sum().backward() 89 | 90 | dmp_opt.step() 91 | ddp_opt.step() 92 | 93 | dmp_opt.zero_grad() 94 | ddp_opt.zero_grad() 95 | 96 | self.assertTrue(torch.allclose(dmp_out, ddp_out, rtol=self.rtol, atol=self.atol)) 97 | return self.dmp_model, self.ddp_model 98 | 99 | 100 | def test_training_behavior_and_merging_v1(self): 101 | self.test_training_behavior() 102 | 103 | # Test merging behavior 104 | self.dmp_model[0].merge_parameters() 
105 | self.dmp_model[0].reset_lora_parameters() 106 | self.ddp_model[0].merge_parameters() 107 | self.ddp_model[0].reset_lora_parameters() 108 | 109 | x = torch.randn(8, 128, self.in_dim).double().cuda() 110 | 111 | dmp_out = self.dmp_model(x) 112 | ddp_out = self.ddp_model(x) 113 | 114 | self.assertTrue(torch.allclose(dmp_out, ddp_out, rtol=self.rtol, atol=self.atol)) 115 | 116 | def test_training_behavior_and_merging_v2(self): 117 | self.test_training_behavior() 118 | 119 | # Test merging behavior (no resets) 120 | self.dmp_model[0].merge_parameters() 121 | # self.dmp_model[0].reset_lora_parameters() 122 | self.ddp_model[0].merge_parameters() 123 | # self.ddp_model[0].reset_lora_parameters() 124 | 125 | x = torch.randn(8, 128, self.in_dim).double().cuda() 126 | 127 | dmp_out = self.dmp_model(x) 128 | ddp_out = self.ddp_model(x) 129 | 130 | self.assertTrue(torch.allclose(dmp_out, ddp_out, rtol=self.rtol, atol=self.atol)) 131 | 132 | 133 | def test_training_behavior_and_merging_without_resetting(self): 134 | self.test_training_behavior() 135 | 136 | # Test merging behavior 137 | self.dmp_model[0].merge_parameters() 138 | self.ddp_model[0].merge_parameters() 139 | 140 | x = torch.randn(8, 128, self.in_dim).double().cuda() 141 | 142 | dmp_out = self.dmp_model(x) 143 | ddp_out = self.ddp_model(x) 144 | self.assertTrue(torch.allclose(dmp_out, ddp_out, rtol=self.rtol, atol=self.atol)) 145 | 146 | 147 | 148 | def test_dmp_ddp_equivalence(self): 149 | 150 | ddp_opt = torch.optim.SGD(self.ddp_model.parameters(), lr=0.01) 151 | dmp_opt = torch.optim.SGD(self.dmp_model.parameters(), lr=0.01) 152 | 153 | # merging every iteration should have equivalent behavior 154 | for _ in range(10): 155 | x = torch.randn(16, 128, self.in_dim).double().cuda() 156 | y = torch.randn(16, 128, self.out_dim).double().cuda() 157 | 158 | dmp_out = self.dmp_model(x) 159 | ddp_out = self.ddp_model(x) 160 | 161 | (dmp_out - y).mean().backward() 162 | (ddp_out - y).mean().backward() 163 | 164 | with torch.no_grad(): 165 | self.assertTrue(torch.allclose(dmp_out, ddp_out, rtol=self.rtol, atol=self.atol)) 166 | 167 | # for each lora param check identical gradient 168 | for i in range(self.num_heads): 169 | ddp_grad_A = self.ddp_model[0].lora_A_weight.grad[i] 170 | ddp_grad_B = self.ddp_model[0].lora_B_weight.grad[i] 171 | dmp_grad_A = self.dmp_model[0].lora_A[i].weight.grad 172 | dmp_grad_B = self.dmp_model[0].lora_B[i].weight.grad 173 | 174 | self.assertTrue(torch.allclose(ddp_grad_A, dmp_grad_A, rtol=self.rtol, atol=self.atol)) 175 | self.assertTrue(torch.allclose(ddp_grad_B, dmp_grad_B, rtol=self.rtol, atol=self.atol)) 176 | 177 | if self.ddp_model[0].lora_bias: 178 | ddp_grad_A = self.ddp_model[0].lora_A_bias.grad[i] 179 | ddp_grad_B = self.ddp_model[0].lora_B_bias.grad[i] 180 | dmp_grad_A = self.dmp_model[0].lora_A[i].bias.grad 181 | dmp_grad_B = self.dmp_model[0].lora_B[i].bias.grad 182 | 183 | self.assertTrue(torch.allclose(ddp_grad_A, dmp_grad_A, rtol=self.rtol, atol=self.atol)) 184 | self.assertTrue(torch.allclose(ddp_grad_B, dmp_grad_B, rtol=self.rtol, atol=self.atol)) 185 | 186 | ddp_opt.step() 187 | dmp_opt.step() 188 | 189 | ddp_opt.zero_grad() 190 | dmp_opt.zero_grad() 191 | 192 | self.ddp_model[0].merge_parameters() 193 | self.dmp_model[0].merge_parameters() 194 | 195 | 196 | def test_mhlora_dmp_lte_equivalence(self): 197 | 198 | mhlora = MultiheadLoRA( 199 | self.in_dim, 200 | self.out_dim, 201 | self.num_heads, 202 | self.lora_r, 203 | self.lora_alpha, 204 | bias=self.bias 205 | ).cuda().double() 206 | 207 | 
self.dmp_model = self.dmp_model.double() 208 | 209 | # synchronize parameters 210 | for i in range(4): 211 | mhlora.linear.weight.data = self.dmp_model[0].layers[i].weight.data.clone() 212 | mhlora.lora_A[i].weight.data = self.dmp_model[0].lora_A[i].weight.data.clone() 213 | mhlora.lora_B[i].weight.data = self.dmp_model[0].lora_B[i].weight.data.clone() 214 | 215 | if self.dmp_model[0].bias is not None: 216 | mhlora.linear.bias.data = self.dmp_model[0].layers[i].bias.data.clone() 217 | 218 | if self.dmp_model[0].lora_bias: 219 | mhlora.lora_A[i].bias.data = self.dmp_model[0].lora_A[i].bias.data.clone() 220 | mhlora.lora_B[i].bias.data = self.dmp_model[0].lora_B[i].bias.data.clone() 221 | 222 | mhlora_opt = torch.optim.SGD(mhlora.parameters(), lr=0.01) 223 | lte_opt = torch.optim.SGD(self.dmp_model.parameters(), lr=0.01) 224 | 225 | # merging every iteration should have equivalent behavior 226 | for _ in range(10): 227 | x1 = torch.randn(2, 128, self.in_dim).double().cuda() 228 | y1 = torch.randn(2, 128, self.out_dim).double().cuda() 229 | mhlora_out = mhlora(x1) 230 | 231 | x2 = x1.repeat(self.num_heads, 1, 1) 232 | y2 = y1.repeat(self.num_heads, 1, 1).unflatten(0, (self.num_heads, 2)) 233 | 234 | lte_out = self.dmp_model(x2) 235 | lte_out = lte_out.unflatten(0, (self.num_heads, 2)) 236 | 237 | (mhlora_out - y1).mean().backward() 238 | (lte_out - y2).mean().backward() 239 | 240 | with torch.no_grad(): 241 | self.assertTrue(torch.allclose(mhlora_out, lte_out.mean(0), rtol=self.rtol, atol=self.atol)) 242 | 243 | # for each lora param check identical gradient 244 | for i in range(self.num_heads): 245 | mhlora_A_grad = mhlora.lora_A[i].weight.grad 246 | mhlora_B_grad = mhlora.lora_B[i].weight.grad 247 | lte_A_grad = self.dmp_model[0].lora_A[i].weight.grad 248 | lte_B_grad = self.dmp_model[0].lora_B[i].weight.grad 249 | 250 | self.assertTrue(torch.allclose(mhlora_A_grad, lte_A_grad, rtol=self.rtol, atol=self.atol)) 251 | self.assertTrue(torch.allclose(mhlora_B_grad, lte_B_grad, rtol=self.rtol, atol=self.atol)) 252 | 253 | if self.dmp_model[0].lora_bias: 254 | mhlora_A_grad = mhlora.lora_A[i].bias.grad 255 | mhlora_B_grad = mhlora.lora_B[i].bias.grad 256 | lte_A_grad = self.dmp_model[0].lora_A[i].bias.grad 257 | lte_B_grad = self.dmp_model[0].lora_B[i].bias.grad 258 | 259 | self.assertTrue(torch.allclose(mhlora_A_grad, lte_A_grad, rtol=self.rtol, atol=self.atol)) 260 | self.assertTrue(torch.allclose(mhlora_B_grad, lte_B_grad, rtol=self.rtol, atol=self.atol)) 261 | 262 | mhlora_opt.step() 263 | lte_opt.step() 264 | 265 | mhlora_opt.zero_grad() 266 | lte_opt.zero_grad() 267 | 268 | self.dmp_model[0].merge_parameters() 269 | # self.dmp_model[0].reset_lora_parameters() 270 | 271 | 272 | def test_mhlora_ddp_lte_equivalence(self): 273 | 274 | mhlora = MultiheadLoRA( 275 | self.in_dim, 276 | self.out_dim, 277 | self.num_heads, 278 | self.lora_r, 279 | self.lora_alpha, 280 | bias=self.bias 281 | ).cuda().double() 282 | 283 | self.ddp_model = self.ddp_model.double() 284 | 285 | # synchronize parameters 286 | for i in range(4): 287 | mhlora.linear.weight.data = self.ddp_model[0].weight.data.clone() 288 | mhlora.lora_A[i].weight.data = self.ddp_model[0].lora_A_weight[i].data.clone() 289 | mhlora.lora_B[i].weight.data = self.ddp_model[0].lora_B_weight[i].data.clone() 290 | 291 | if self.ddp_model[0].bias is not None: 292 | mhlora.linear.bias.data = self.ddp_model[0].bias.data.clone() 293 | 294 | if self.ddp_model[0].lora_bias: 295 | mhlora.lora_A[i].bias.data = self.ddp_model[0].lora_A_bias[i].data.clone() 
296 | mhlora.lora_B[i].bias.data = self.ddp_model[0].lora_B_bias[i].data.clone() 297 | 298 | mhlora_opt = torch.optim.SGD(mhlora.parameters(), lr=0.01) 299 | lte_opt = torch.optim.SGD(self.ddp_model.parameters(), lr=0.01) 300 | 301 | # merging every iteration should have equivalent behavior 302 | for _ in range(10): 303 | x1 = torch.randn(2, 128, self.in_dim).double().cuda() 304 | y1 = torch.randn(2, 128, self.out_dim).double().cuda() 305 | mhlora_out = mhlora(x1) 306 | 307 | x2 = x1.repeat(self.num_heads, 1, 1) 308 | y2 = y1.repeat(self.num_heads, 1, 1).unflatten(0, (self.num_heads, 2)) 309 | 310 | lte_out = self.ddp_model(x2) 311 | lte_out = lte_out.unflatten(0, (self.num_heads, 2)) 312 | 313 | (mhlora_out - y1).mean().backward() 314 | (lte_out - y2).mean().backward() 315 | 316 | with torch.no_grad(): 317 | # print(mhlora_out, lte_out.mean(0)) 318 | # import ipdb; ipdb.set_trace() 319 | self.assertTrue(torch.allclose(mhlora_out, lte_out.mean(0), rtol=self.rtol, atol=self.atol)) 320 | 321 | # for each lora param check identical gradient 322 | for i in range(self.num_heads): 323 | mhlora_A_grad = mhlora.lora_A[i].weight.grad 324 | mhlora_B_grad = mhlora.lora_B[i].weight.grad 325 | lte_A_grad = self.ddp_model[0].lora_A_weight.grad[i] 326 | lte_B_grad = self.ddp_model[0].lora_B_weight.grad[i] 327 | 328 | self.assertTrue(torch.allclose(mhlora_A_grad, lte_A_grad, rtol=self.rtol, atol=self.atol)) 329 | self.assertTrue(torch.allclose(mhlora_B_grad, lte_B_grad, rtol=self.rtol, atol=self.atol)) 330 | 331 | if self.ddp_model[0].lora_bias: 332 | mhlora_A_grad = mhlora.lora_A[i].bias.grad 333 | mhlora_B_grad = mhlora.lora_B[i].bias.grad 334 | lte_A_grad = self.ddp_model[0].lora_A_bias.grad[i] 335 | lte_B_grad = self.ddp_model[0].lora_B_bias.grad[i] 336 | 337 | self.assertTrue(torch.allclose(mhlora_A_grad, lte_A_grad, rtol=self.rtol, atol=self.atol)) 338 | self.assertTrue(torch.allclose(mhlora_B_grad, lte_B_grad, rtol=self.rtol, atol=self.atol)) 339 | 340 | mhlora_opt.step() 341 | lte_opt.step() 342 | 343 | mhlora_opt.zero_grad() 344 | lte_opt.zero_grad() 345 | 346 | self.ddp_model[0].merge_parameters() 347 | # self.ddp_model[0].reset_lora_parameters() 348 | 349 | def test_mhlora_ddp_lte_equivalence_v2(self): 350 | 351 | torch.manual_seed(0) 352 | lte_model = nn.Sequential(nn.Linear(self.in_dim, self.out_dim, bias=self.bias)).cuda() 353 | lte_model = prepare_model_for_lte( 354 | lte_model, 355 | LTEConfig.default( 356 | lora_r=self.lora_r, 357 | lora_alpha=self.lora_alpha, 358 | num_heads=1, 359 | ), 360 | mode='ddp', 361 | strict=True, 362 | ).double() 363 | 364 | torch.manual_seed(0) 365 | mhlora_model = nn.Sequential(nn.Linear(self.in_dim, self.out_dim, bias=self.bias)).cuda() 366 | mhlora_model = prepare_model_for_lte( 367 | mhlora_model, 368 | LTEConfig.default( 369 | lora_r=self.lora_r, 370 | lora_alpha=self.lora_alpha, 371 | num_heads=1, 372 | ), 373 | mode='mhlora', 374 | strict=True, 375 | ).double() 376 | 377 | # check if parameters have same dynamics throughout training and merging 378 | # LTE with merge=1 should be equivalent to mhlora without merge 379 | 380 | # lte_opt = torch.optim.SGD(lte_model.parameters(), lr=0.01) 381 | # mhlora_opt = torch.optim.SGD(mhlora_model.parameters(), lr=0.01) 382 | 383 | lte_opt = torch.optim.Adam(lte_model.parameters(), lr=0.01) 384 | mhlora_opt = torch.optim.Adam(mhlora_model.parameters(), lr=0.01) 385 | 386 | for i in range(10): 387 | # optimize both with the same input 388 | x = torch.randn(16, 128, self.in_dim).double().cuda() 389 | y = 
torch.randn(16, 128, self.out_dim).double().cuda() 390 | 391 | lte_out = lte_model(x) 392 | mhlora_out = mhlora_model(x) 393 | 394 | lte_loss = (lte_out - y).mean() 395 | mhlora_loss = (mhlora_out - y).mean() 396 | self.assertTrue(torch.allclose(lte_out, mhlora_out, rtol=self.rtol, atol=self.atol)) 397 | 398 | lte_loss.backward() 399 | mhlora_loss.backward() 400 | 401 | lte_opt.step() 402 | mhlora_opt.step() 403 | 404 | lte_opt.zero_grad() 405 | mhlora_opt.zero_grad() 406 | 407 | lte_model[0].merge_parameters() 408 | return 409 | 410 | 411 | def test_replica_behavior(self): 412 | 413 | model = nn.Sequential(nn.Linear(self.in_dim, self.out_dim, bias=self.bias)).double().cuda() 414 | model = MultiheadReplicaLayer(model, self.num_heads) 415 | 416 | x = torch.randn(1, self.in_dim).double().cuda() 417 | x = x.repeat(self.num_heads, 1) 418 | ys = model(x).chunk(self.num_heads) 419 | 420 | # check all equal 421 | for y in ys: 422 | self.assertTrue(torch.allclose(ys[0], y, rtol=self.rtol, atol=self.atol)) 423 | 424 | # check if the parameters are tied (updated 1 layer and see if other layers are updated) 425 | model.replicas[0][0].weight.data += 1 426 | ys = model(x).chunk(self.num_heads) 427 | 428 | for y in ys[1:]: 429 | self.assertFalse(torch.allclose(ys[0], y, rtol=self.rtol, atol=self.atol)) 430 | 431 | # make sure merge parameters works 432 | model = nn.Sequential(nn.Linear(self.in_dim, self.out_dim, bias=self.bias)).double().cuda() 433 | model = MultiheadReplicaLayer(model, self.num_heads) 434 | model.merge_parameters() 435 | ys = model(x).chunk(self.num_heads) 436 | 437 | # nothing should have changed 438 | for y in ys: 439 | self.assertTrue(torch.allclose(ys[0], y, rtol=self.rtol, atol=self.atol)) 440 | 441 | 442 | 443 | class MultiheadLoRA(nn.Module): 444 | def __init__(self, in_dim, out_dim, num_heads, lora_r, lora_alpha, bias=False, lora_bias=False): 445 | super().__init__() 446 | self.linear = nn.Linear(in_dim, out_dim, bias=bias) 447 | self.lora_A = nn.ModuleList( 448 | [nn.Linear(in_dim, lora_r, bias=lora_bias) for _ in range(num_heads)]) 449 | self.lora_B = nn.ModuleList( 450 | [nn.Linear(lora_r, out_dim, bias=lora_bias) for _ in range(num_heads)]) 451 | self.s = lora_alpha / lora_r 452 | self.num_heads = num_heads 453 | 454 | for p in self.linear.parameters(): 455 | p.requires_grad_(False) 456 | return 457 | 458 | def forward(self, x): 459 | out = self.linear(x) 460 | for a, b in zip(self.lora_A, self.lora_B): 461 | out += (self.s / self.num_heads * b(a(x))) 462 | return out 463 | 464 | 465 | if __name__ == '__main__': 466 | unittest.main() 467 | --------------------------------------------------------------------------------
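The `MultiheadLoRA` reference module above makes the reduction that the equivalence tests rely on explicit: with scaling `s = lora_alpha / lora_r`, the multihead forward is `out = W x + (s / num_heads) * sum_i B_i(A_i x)`, which the LTE layers reproduce by chunking the repeated batch across heads and averaging the per-head outputs. Below is a small self-contained check of the corresponding merged-weight identity; it is an editorial sketch that only uses the reference module defined above (with the default `lora_bias=False`) and runs on CPU.

```python
import torch
import torch.nn.functional as F

# assumes the MultiheadLoRA reference module defined above is in scope
mh = MultiheadLoRA(in_dim=16, out_dim=32, num_heads=4, lora_r=8, lora_alpha=16, bias=True)
x = torch.randn(2, 16)

# merged weight: W + (lora_alpha / lora_r) / num_heads * sum_i B_i A_i
delta = sum(b.weight @ a.weight for a, b in zip(mh.lora_A, mh.lora_B))
w_merged = mh.linear.weight + (mh.s / mh.num_heads) * delta

# the multihead forward equals a single linear layer with the merged weight
assert torch.allclose(mh(x), F.linear(x, w_merged, mh.linear.bias), atol=1e-5)
```

This is the same identity the merge step relies on when folding the LoRA heads back into the main weight.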