├── LICENSE
├── README.md
├── examples
│   ├── huggingface_gpt2.py
│   └── simple.py
├── lorax
│   ├── __init__.py
│   ├── constants.py
│   ├── helpers.py
│   └── transform.py
├── pyproject.toml
└── tests
    ├── conftest.py
    ├── test_utils.py
    └── tests.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 davisyoshida
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lorax: LoRA for JAX functions
2 | This is a JAX transform which implements [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). LoRA replaces operations like `Wx` with `(W + BA)x`, where `A` and `B` are skinny rectangular matrices. You can then train only `A` and `B`, and leave `W` frozen, which dramatically reduces the amount of memory needed for things like optimizer states.
3 | 
4 | Lorax should work on most JAX models. I did my testing on my own Haiku models, and you can find an example of applying it to a HuggingFace Flax model in the [examples directory](examples/).
5 | 
6 | ## Installation
7 | 
8 | ```bash
9 | pip install jax-lorax
10 | ```
11 | 
12 | ## Changelog
13 | 
14 | ### 0.2.0
15 | * Replaced backend with [Qax](https://github.com/davisyoshida/qax)
16 | * Overhauled API to simplify usage (No more need to separately handle frozen/tunable params)
17 | 
18 | ## Running tests
19 | Install dev dependencies:
20 | ```bash
21 | git clone https://github.com/davisyoshida/lorax.git
22 | cd lorax
23 | pip install poetry
24 | poetry install
25 | ```
26 | 
27 | Run tests:
28 | ```bash
29 | pytest tests/tests.py
30 | ```
31 | 
32 | ## Minimal example
33 | Lorax makes it so you can take model code which wasn't written with LoRA in mind, and transform it so that it supports LoRA!
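The underlying idea in one line (a minimal sketch of the math only — Lorax's real implementation intercepts JAX primitives via Qax rather than defining a helper like this, and `lora_matmul` is just an illustrative name):

```python
def lora_matmul(x, w, a, b, alpha=1.0):
    """Illustration: compute x @ (W + (alpha / k) * B @ A) without forming the full update.

    Shapes follow Lorax's convention: w is (M, N), b is (M, k), a is (k, N).
    """
    scale = alpha / b.shape[-1]  # the alpha / r scaling from the LoRA paper
    return x @ w + scale * ((x @ b) @ a)
```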
For example, consider the following MLP code:
34 | 
35 | ```python
36 | 
37 | import jax
38 | import jax.numpy as jnp
39 | 
40 | import optax
41 | 
42 | def model(params, x):
43 |     """My model, written in the dark ages before LoRA, using gratuitous amounts of VRAM when trained"""
44 |     for massive_w in params:
45 |         x = jax.nn.relu(x @ massive_w)
46 |     return jnp.sum(x)
47 | 
48 | dim = 5000
49 | 
50 | # Initialize about 3 GB of params
51 | params = [jax.random.normal(jax.random.PRNGKey(i), (dim, dim)) / (dim ** 0.5) for i in range(30)]
52 | optimizer = optax.adam(learning_rate=3e-4)
53 | 
54 | # OOM on 7GB GPU :(
55 | opt_state = optimizer.init(params)
56 | ```
57 | 
58 | The optimizer states are way too expensive, but applying Lorax lets you train just a `5000 x 64` and a `64 x 5000` matrix for each original weight.
59 | 
60 | First import lorax and transform your model:
61 | ```python
62 | import lorax
63 | 
64 | # Transform the model code
65 | lora_model = lorax.lora(model)
66 | ```
67 | 
68 | Next initialize the new LoRA parameters:
69 | ```python
70 | # Tell LoRA what to use as the small dimension of B and A
71 | rank_constraint = 64
72 | lora_spec = [rank_constraint for param in params]
73 | 
74 | # Initialize a set of LoRA factors for each parameter
75 | lora_params = lorax.init_lora(param_tree=params, spec=lora_spec, rng=jax.random.PRNGKey(0))
76 | 
77 | # The transformed model has the same call signature, but it can now handle parameters
78 | # of type lorax.LoraWeight
79 | lora_model(lora_params, jnp.ones((dim,)))
80 | 
81 | # Wrap the optimizer so it will freeze parameters not marked as trainable by the spec
82 | optimizer = lorax.wrap_optimizer(optimizer, lora_spec)
83 | 
84 | # Now the optimizer can be used just like normal
85 | opt_state = optimizer.init(lora_params)
86 | 
87 | ```
88 | 
89 | That's it for the Lorax-specific stuff. The wrapped `lora_model` function is just an ordinary
90 | JAX function, and the LoraWeight instances are pytrees.
91 | ```python
92 | # Normal update function:
93 | @jax.jit
94 | def update_fn(lora_params, opt_state, x):
95 |     grad_fn = jax.value_and_grad(lora_model)
96 |     loss, grad = grad_fn(lora_params, x)
97 | 
98 |     updates, new_opt_state = optimizer.update(grad, opt_state, params=lora_params)
99 |     updated_params = optax.apply_updates(lora_params, updates)
100 |     return loss, new_opt_state, updated_params
101 | ```
102 | 
103 | Now for some dummy data and the training loop:
104 | ```python
105 | x = jax.random.normal(jax.random.PRNGKey(0), (dim,))
106 | for i in range(10):
107 |     loss, opt_state, lora_params = update_fn(lora_params, opt_state, x)
108 |     print(f'Step: {i} loss: {loss:.4e}') # Number goes down!
109 | # Step: 0 loss: 6.6614e-02
110 | # Step: 1 loss: 4.4402e-02
111 | # Step: 2 loss: 3.0241e-02
112 | # Step: 3 loss: 1.8457e-02
113 | # Step: 4 loss: 1.2326e-02
114 | # Step: 5 loss: 8.8878e-03
115 | # Step: 6 loss: 6.0599e-03
116 | # Step: 7 loss: 4.3899e-03
117 | # Step: 8 loss: 3.0839e-03
118 | # Step: 9 loss: 2.2423e-03
119 | ```
120 | 
121 | Number goes down!
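As a quick sanity check (this snippet is an added illustration, not part of the original example), you can count how few parameters are actually being trained:

```python
# Compare the original weights with the LoRA factors the optimizer actually updates
n_full = sum(w.size for w in params)
n_lora = sum(p.a.size + p.b.size for p in lora_params)
print(f'Training {n_lora:,} of {n_full:,} parameters ({n_lora / n_full:.2%})')
# Training 19,200,000 of 750,000,000 parameters (2.56%)
```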
We can now merge the trained LoRA factors back into the frozen weights and use the result with the unmodified model:
122 | ```python
123 | lora_output = lora_model(lora_params, x)
124 | 
125 | # Now we merge the params to get params usable in the original model
126 | merged_params = lorax.merge_params(lora_params)
127 | orig_model_output = model(merged_params, x)
128 | 
129 | # Verify that the model outputs are the same
130 | print(f'Difference between split and merged outputs: {orig_model_output - lora_output:.3e}')
131 | # Difference between split and merged outputs: 1.164e-10
132 | ```
133 | 
134 | See [examples/huggingface_gpt2.py](examples/huggingface_gpt2.py) for an example applying Lorax to a realistic model.
135 | 
--------------------------------------------------------------------------------
/examples/huggingface_gpt2.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import jax
4 | import jax.numpy as jnp
5 | import optax
6 | from transformers import FlaxGPT2LMHeadModel
7 | 
8 | import lorax
9 | 
10 | def main():
11 |     model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')
12 | 
13 |     # This function defines a spec which tells lorax how each parameter should be handled
14 |     def decision_fn(path, param):
15 |         if any('embedding' in str(key) for key in path):
16 |             print(f'Fully finetuning param {path}')
17 |             return lorax.LORA_FULL
18 |         dim = 32
19 |         print(f'Using LoRA with dim={dim} for param {path}')
20 |         return dim
21 | 
22 |     # Create a pytree with the same shape as params indicating how each parameter should be handled
23 |     # Each leaf will be given one of the following values:
24 |     # - LORA_FULL: The parameter will be fully finetuned
25 |     # - LORA_FREEZE: The parameter will be frozen
26 |     # - k > 0: The parameter will be LoRA tuned with a rank k update
27 | 
28 |     # simple_spec is a helper to do this, but you can also create the label pytree yourself
29 |     lora_spec = lorax.simple_spec(model.params, decision_fn=decision_fn, tune_vectors=True)
30 | 
31 |     # Initialize a pair of LoRA matrices for each parameter which has a spec value other than
32 |     # LORA_FULL or LORA_FREEZE; all other parameters are passed through unchanged
33 |     lora_params = lorax.init_lora(model.params, lora_spec, jax.random.PRNGKey(0))
34 | 
35 |     optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
36 | 
37 |     # `wrap_optimizer` uses the spec to freeze the appropriate subset
38 |     # of parameters.
39 |     # The frozen parameters won't have optimizer states etc
40 |     # created for them
41 |     optimizer = lorax.wrap_optimizer(optimizer, lora_spec)
42 | 
43 |     opt_state = optimizer.init(lora_params)
44 | 
45 |     # lorax.lora wraps a callable so that the arguments can be lorax.LoraWeight
46 |     # instances.
(It's actually just an alias for qax.use_implicit_args, so 47 | # the wrapped function can handle other qax types as well) 48 | lora_model = lorax.lora(model) 49 | 50 | # No changes are necessary to the loss function apart from using the wrapped model 51 | def loss_fn(lora_params, batch): 52 | input_ids = batch[:, :-1] 53 | 54 | # The call signature of the wrapped model is unchanged from the original HuggingFace model 55 | logits = lora_model(input_ids, params=lora_params).logits 56 | 57 | logprobs = jax.nn.log_softmax(logits) 58 | target_logprobs = jnp.take_along_axis(logprobs, batch[:, 1:, None], axis=-1) 59 | return -jnp.mean(target_logprobs) 60 | 61 | # The update function also doesn't need to be modified other than 62 | # using the wrapped optimizer 63 | @jax.jit 64 | def update_fn(lora_params, opt_state, batch): 65 | loss, grads = jax.value_and_grad(loss_fn)(lora_params, batch) 66 | updates, new_opt_state = optimizer.update(grads, opt_state, params=lora_params) 67 | 68 | new_params = optax.apply_updates(lora_params, updates) 69 | return new_params, new_opt_state, loss 70 | 71 | # Train on a dummy batch to show that we can fit the model to stuff 72 | example_data = jax.random.randint(jax.random.PRNGKey(0), (4, 128), 0, 50257) 73 | for _ in range(100): 74 | lora_params, opt_state, loss = update_fn(lora_params, opt_state, example_data) 75 | print(loss) 76 | 77 | final_predictions = lora_model(example_data, params=lora_params).logits 78 | 79 | # Now let's merge the loras back into the original parameters to get 80 | # finetuned parameters we can use with no extra compute 81 | merged_params = lorax.merge_params(lora_params) 82 | 83 | orig_model_predictions = model(example_data, params=merged_params).logits 84 | 85 | gap = jnp.max(jnp.abs(final_predictions - orig_model_predictions)) 86 | print(f'Max prediction gap: {gap:.3e}') 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /examples/simple.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as jnp 4 | 5 | import optax 6 | 7 | def model(params, x): 8 | """My model, written in the dark ages before LoRA, using gratuitous amounts of VRAM when trained""" 9 | for massive_w in params: 10 | x = jax.nn.relu(x @ massive_w) 11 | return jnp.sum(x) 12 | 13 | dim = 5000 14 | 15 | # Initialize about 3 GB of params 16 | params = [jax.random.normal(jax.random.PRNGKey(i), (dim, dim)) / (dim ** 0.5) for i in range(30)] 17 | optimizer = optax.adam(learning_rate=3e-4) 18 | 19 | # OOM on 7GB GPU :( 20 | # opt_state = optimizer.init(params) 21 | 22 | import lorax 23 | 24 | # Transform the model code 25 | lora_model = lorax.lora(model) 26 | 27 | # Tell LoRA what to use as the small dimension of B and A 28 | rank_constraint = 64 29 | lora_spec = [rank_constraint for param in params] 30 | 31 | # Initialize a set of LoRA factors for each parameter 32 | lora_params = lorax.init_lora(param_tree=params, spec=lora_spec, rng=jax.random.PRNGKey(0)) 33 | 34 | # The transformed model has the same call signature, but it can now handle parameters 35 | # of type lorax.LoraWeight 36 | lora_model(lora_params, jnp.ones((dim,))) 37 | 38 | # Wrap the optimizer so it will freeze parameters not marked as trainable by the spec 39 | optimizer = lorax.wrap_optimizer(optimizer, lora_spec) 40 | 41 | # Now the optimizer can be used just like normal 42 | opt_state = optimizer.init(lora_params) 43 | 44 | 
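# Note (editorial comment, not in the original example): donate_argnums=(0, 1) below lets
# XLA reuse the buffers of lora_params and opt_state for the updated outputs, so the jitted
# step doesn't keep two copies of the parameters and optimizer state alive at once.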
@partial(jax.jit, donate_argnums=(0, 1))
45 | def update_fn(lora_params, opt_state, x):
46 |     # The transformed model function is compatible with all the normal JAX transforms
47 |     # It's just a function which maps pytrees to pytrees
48 |     grad_fn = jax.value_and_grad(lora_model)
49 |     loss, grad = grad_fn(lora_params, x)
50 | 
51 |     updates, new_opt_state = optimizer.update(grad, opt_state, params=lora_params)
52 |     updated_params = optax.apply_updates(lora_params, updates)
53 |     return loss, new_opt_state, updated_params
54 | 
55 | x = jax.random.normal(jax.random.PRNGKey(0), (dim,))
56 | for i in range(10):
57 |     loss, opt_state, lora_params = update_fn(lora_params, opt_state, x)
58 |     print(f'Step: {i} loss: {loss:.4e}') # Number goes down!
59 | 
60 | # Save the output to verify correctness
61 | lora_output = lora_model(lora_params, x)
62 | 
63 | # Now we merge the params to get params usable in the original model
64 | merged_params = lorax.merge_params(lora_params)
65 | orig_model_output = model(merged_params, x)
66 | 
67 | # Verify that the model outputs are the same
68 | print(f'Difference between split and merged outputs: {orig_model_output - lora_output:.3e}')
--------------------------------------------------------------------------------
/lorax/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import LORA_FREEZE, LORA_FULL
2 | from .helpers import init_lora, merge_params, simple_spec, split_lora_params, wrap_optimizer
3 | from .transform import LoraWeight, lora
4 | 
--------------------------------------------------------------------------------
/lorax/constants.py:
--------------------------------------------------------------------------------
1 | LORA_FREEZE = 0
2 | LORA_FULL = -1
3 | 
--------------------------------------------------------------------------------
/lorax/helpers.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax.tree_util import tree_map_with_path, DictKey, SequenceKey
4 | 
5 | import optax
6 | import qax
7 | 
8 | from .constants import LORA_FREEZE, LORA_FULL
9 | from .transform import LoraWeight
10 | 
11 | 
12 | def init_lora(param_tree, spec, rng, stddev=0.01, dtype=jnp.float32, alpha=1., is_leaf=None):
13 |     def iter_keys(key):
14 |         while True:
15 |             key, out_key = jax.random.split(key)
16 |             yield out_key
17 | 
18 |     key_it = iter_keys(rng)
19 | 
20 |     def get_param(path, param, spec_val):
21 |         if spec_val in (LORA_FREEZE, LORA_FULL):
22 |             return param
23 | 
24 |         if len(param.shape) == 1:
25 |             raise ValueError(f'Vectors must either be frozen or fully tuned, but got spec value {spec_val} for param with path {path}')
26 | 
27 |         if len(param.shape) == 2:
28 |             b_dim, a_dim = param.shape
29 | 
30 |             b = jnp.zeros((b_dim, spec_val), dtype=dtype)
31 |             a = jax.random.normal(next(key_it), (spec_val, a_dim), dtype=dtype) * stddev
32 |             return LoraWeight(w=param, a=a, b=b, alpha=alpha)
33 | 
34 |         # conv case
35 |         *window_shape, in_channels, out_channels = param.shape
36 | 
37 |         a = jnp.zeros((
38 |             *(1 for _ in range(len(window_shape))),
39 |             spec_val,
40 |             out_channels
41 |         ), dtype=param.dtype)
42 |         b = jax.random.normal(next(key_it), (*window_shape, in_channels, spec_val), dtype=param.dtype) * stddev
43 |         return LoraWeight(param, a, b, alpha=alpha)
44 | 
45 |     return jax.tree_util.tree_map_with_path(get_param, param_tree, spec, is_leaf=is_leaf)
46 | 
47 | def simple_spec(params, decision_fn=None, tune_vectors=False, is_leaf=None):
48 |     """
49 |     Create a
simple lora spec for a pytree
50 |     Args:
51 |         params: pytree of parameters
52 |         tune_vectors: If True, all arrays with fewer than 2 dimensions will be flagged for full tuning
53 |         decision_fn: A function which maps a JAX KeyPath and a parameter to a spec value
54 |     """
55 |     if decision_fn is None:
56 |         def decision_fn(*args):
57 |             return LORA_FREEZE
58 | 
59 |     def full_fn(path, arr):
60 |         if len(arr.shape) < 2:
61 |             return LORA_FULL if tune_vectors else LORA_FREEZE
62 | 
63 |         value = decision_fn(path, arr)
64 |         return value
65 | 
66 |     return tree_map_with_path(full_fn, params, is_leaf=is_leaf)
67 | 
68 | def merge_params(lora_params, destructive=True, use_scaling=True):
69 |     """
70 |     Re-merge LoRA parameters.
71 |     Arguments:
72 |         destructive: If True, the original buffers in lora_params may be freed to save memory.
73 |         use_scaling: Whether to multiply LoRA params by alpha/r
74 |     """
75 |     if not use_scaling:
76 |         raise ValueError('Scaling is now always enabled to match the original LoRA implementation.')
77 | 
78 |     def _ensure_delete(val):
79 |         if not isinstance(val, jax.Array) or val.is_deleted():
80 |             return
81 | 
82 |         val.device_buffer.delete()
83 | 
84 | 
85 |     materializer = jax.jit(qax.materialize_nested, donate_argnums=0 if destructive else ())
86 |     def map_fn(param):
87 |         if isinstance(param, LoraWeight):
88 |             result = materializer(param)
89 |             if destructive:
90 |                 jax.tree_map(_ensure_delete, param)
91 |             return result
92 |         return param
93 | 
94 |     return qax.utils.tree_map_with_implicit(map_fn, lora_params)
95 | 
96 | def split_lora_params(params, spec):
97 |     """
98 |     Map params to a pytree in which all `LoraWeight.w` values and all params marked with
99 |     LORA_FREEZE are replaced with qax.EmptyNode. This is useful for checkpointing just
100 |     the trainable params.
101 |     """
102 |     def node_mapper(node, spec_val):
103 |         if not isinstance(node, LoraWeight):
104 |             return node if spec_val != LORA_FREEZE else qax.EmptyNode
105 |         children, aux = node.tree_flatten_with_keys()
106 |         idx = next(i for i, (key, _) in enumerate(children) if key == 'w')
107 |         children[idx] = ('w', qax.EmptyNode)
108 | 
109 |         return LoraWeight.tree_unflatten(aux, [c for _, c in children])
110 | 
111 |     return qax.utils.tree_map_with_implicit(node_mapper, params, spec)
112 | 
113 | def wrap_optimizer(optimizer : optax.GradientTransformation, spec, scalar_frozen_grads=False):
114 |     full_freeze_labels = jax.tree_map(
115 |         lambda x: 'freeze' if x == LORA_FREEZE else 'train',
116 |         spec
117 |     )
118 |     optimizer_with_full_freeze = qax.utils.freeze_subtrees(
119 |         optimizer,
120 |         full_freeze_labels,
121 |         use_scalar_zeros=scalar_frozen_grads
122 |     )
123 | 
124 |     return qax.freeze_keys(optimizer_with_full_freeze, LoraWeight, 'w', use_scalar_zeros=scalar_frozen_grads)
125 | 
--------------------------------------------------------------------------------
/lorax/transform.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | import warnings
4 | 
5 | import jax
6 | import qax
7 | 
8 | def lora(f):
9 |     """
10 |     Alias for qax.use_implicit_args to reduce necessary modification to code
11 |     using an older version of Lorax
12 |     """
13 |     return qax.use_implicit_args(f)
14 | 
15 | @dataclass
16 | class LoraWeight(qax.ImplicitArray):
17 |     w : qax.ArrayValue # M x N
18 |     a : qax.ArrayValue # k x N
19 |     b : qax.ArrayValue # M x k
20 | 
21 |     alpha : float = qax.aux_field(default=1.)
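    # Editorial note: shapes follow w: (M, N), b: (M, k), a: (k, N); materialize() below
    # reconstructs w + (alpha / k) * b @ a, i.e. the alpha/r scaling from the LoRA paper,
    # with k = b.shape[-1] (see get_scale).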
22 | 
23 |     def __post_init__(self):
24 |         super().__post_init__()
25 |         assert self.a.shape[-2] == self.b.shape[-1]
26 |         assert self.w.shape[-2] == self.b.shape[-2]
27 |         assert self.w.shape[-1] == self.a.shape[-1]
28 | 
29 |     def materialize(self):
30 |         return (self.w + self.get_scale() * self.b @ self.a).astype(self.w.dtype)
31 | 
32 |     def get_scale(self):
33 |         return self.alpha / self.b.shape[-1]
34 | 
35 | def _check_dot_dimension_numbers(dimension_numbers):
36 |     (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
37 |     if lhs_batch or rhs_batch:
38 |         warnings.warn('Lorax does not support batched matmuls')
39 |         return False
40 |     if len(lhs_contract) != 1 or len(rhs_contract) != 1:
41 |         warnings.warn('Lorax only supports matmul')
42 |         return False
43 |     return True
44 | 
45 | @qax.primitive_handler('dot_general')
46 | def handle_dot_lhs(primitive, lora : LoraWeight, rhs: qax.ArrayValue, *, dimension_numbers, **kwargs):
47 |     if not _check_dot_dimension_numbers(dimension_numbers):
48 |         return NotImplemented
49 | 
50 |     if isinstance(rhs, LoraWeight):
51 |         rhs = rhs.materialize()
52 |         warnings.warn('Encountered product of two LoraWeights. Materializing the rhs')
53 | 
54 |     op = partial(jax.lax.dot_general, **kwargs)
55 | 
56 | 
57 |     lhs_contract, = dimension_numbers[0][0]
58 | 
59 |     first, second = (lora.a, lora.b) if lhs_contract == 1 else (lora.b, lora.a)
60 | 
61 |     first *= lora.get_scale()
62 | 
63 |     orig = op(lora.w, rhs, dimension_numbers=dimension_numbers)
64 |     lora_product = op(first, rhs, dimension_numbers=dimension_numbers)
65 | 
66 |     second_dimension_numbers = ((lhs_contract,), (0,)), dimension_numbers[1]
67 | 
68 |     lora_product = op(second, lora_product, dimension_numbers=second_dimension_numbers)
69 | 
70 |     return (orig + lora_product).astype(orig.dtype)
71 | 
72 | @qax.primitive_handler('dot_general')
73 | def handle_dot_rhs(primitive, lhs : jax.Array, lora: LoraWeight, *, dimension_numbers, **kwargs):
74 |     if not _check_dot_dimension_numbers(dimension_numbers):
75 |         return NotImplemented
76 |     op = partial(jax.lax.dot_general, **kwargs)
77 | 
78 |     rhs_contract, = dimension_numbers[0][1]
79 |     first, second = (lora.a, lora.b) if rhs_contract == 1 else (lora.b, lora.a)
80 | 
81 |     first *= lora.get_scale()
82 | 
83 |     orig = op(lhs, lora.w, dimension_numbers=dimension_numbers)
84 |     lora_product = op(lhs, first, dimension_numbers=dimension_numbers)
85 | 
86 |     second_dimension_numbers = ((lhs.ndim - 1,), (rhs_contract,)), dimension_numbers[1]
87 | 
88 |     lora_product = op(lora_product, second, dimension_numbers=second_dimension_numbers)
89 | 
90 |     return (orig + lora_product).astype(orig.dtype)
91 | 
92 | 
93 | @qax.primitive_handler('conv_general_dilated')
94 | def handle_conv(primitive, inp : qax.ArrayValue, lora : LoraWeight, *, dimension_numbers, **params):
95 |     if isinstance(inp, LoraWeight):
96 |         warnings.warn('Using a LoraWeight as input to a convolution is not supported, so it will be materialized.')
97 |         inp = inp.materialize()
98 | 
99 |     if dimension_numbers.rhs_spec[:2] != (
100 |         len(dimension_numbers.rhs_spec) - 1,
101 |         len(dimension_numbers.rhs_spec) - 2,
102 |     ):
103 |         raise ValueError('Lorax only supports convolutions with shape (..., in_features, out_features)')
104 | 
105 |     params = {**params, 'dimension_numbers': dimension_numbers}
106 |     op = partial(jax.lax.conv_general_dilated, **params)
107 |     orig = op(inp, lora.w)
108 | 
109 |     lora_product = op(inp, lora.b)
110 | 
111 |     params['window_strides'] = (1,) * (len(dimension_numbers.rhs_spec) - 2)
112 |     params['padding'] = 'VALID'
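    # Editorial note: lora.a was initialized as a 1x...x1 kernel mapping the rank-k channels
    # to out_channels, so this second, pointwise convolution uses stride 1 and VALID padding
    # to keep the spatial dims produced by the first convolution unchanged.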
113 |     lora_product = jax.lax.conv_general_dilated(
114 |         lora_product,
115 |         lora.a * lora.get_scale(),
116 |         **params
117 |     )
118 | 
119 |     return (orig + lora_product).astype(orig.dtype)
120 | 
121 | @qax.primitive_handler('gather')
122 | def handle_gather(primitive, lora : LoraWeight, indices : jax.Array, *, dimension_numbers, slice_sizes, **params):
123 |     if dimension_numbers.offset_dims != (len(indices.shape) - 1,):
124 |         return NotImplemented
125 | 
126 |     lora_dim = lora.b.shape[-1]
127 | 
128 |     if slice_sizes != (1, lora.a.shape[1]):
129 |         return NotImplemented
130 | 
131 |     params = {**params, 'dimension_numbers': dimension_numbers}
132 | 
133 |     orig = jax.lax.gather(lora.w, indices, slice_sizes=slice_sizes, **params)
134 | 
135 |     new_slice_sizes = (1, lora_dim)
136 | 
137 |     lora_product = jax.lax.gather(lora.b, indices, slice_sizes=new_slice_sizes, **params)
138 |     lora_product = lora_product @ (lora.a * lora.get_scale())
139 | 
140 |     return (orig + lora_product).astype(orig.dtype)
141 | 
142 | @qax.primitive_handler('transpose')
143 | def eval_lora_transpose(primitive, arg : LoraWeight, *, permutation):
144 |     if not (len(arg.shape) == 2 and permutation == (1, 0)):
145 |         return NotImplemented
146 | 
147 |     return LoraWeight(
148 |         w=arg.w.T,
149 |         a=arg.b.T,
150 |         b=arg.a.T,
151 |         alpha=arg.alpha,
152 |     )
153 | 
154 | @qax.primitive_handler('convert_element_type')
155 | def eval_lora_convert_element_type(primitive, arg : LoraWeight, **params):
156 |     result = jax.tree_map(
157 |         partial(qax.default_handler, primitive, **params),
158 |         arg
159 |     )
160 |     result.dtype = params['new_dtype']
161 |     return result
162 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = 'jax-lorax'
3 | version = '0.3.1'
4 | description = 'A JAX transform which applies LoRA to arbitrary JAX functions/models'
5 | authors = ['Davis Yoshida ']
6 | license = 'MIT'
7 | readme = 'README.md'
8 | packages = [
9 |     {include = 'lorax'}
10 | ]
11 | 
12 | [tool.poetry.dependencies]
13 | python = '^3.8'
14 | jax = '^0.4.6'
15 | jaxlib = '^0.4.6'
16 | qax = '>=0.3.1'
17 | 
18 | [tool.poetry.dev-dependencies]
19 | pytest = '^7.3.1'
20 | 
21 | [build-system]
22 | requires = ['poetry-core>=1.0.0']
23 | build-backend = 'poetry.core.masonry.api'
24 | 
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import jax
4 | import pytest
5 | 
6 | from lorax import LoraWeight
7 | 
8 | @pytest.fixture(scope='session')
9 | def simple_params():
10 |     m, rank_constraint, n = 11, 7, 19
11 |     x = jax.random.normal(jax.random.PRNGKey(0), (n, 10))
12 | 
13 |     w = jax.random.normal(jax.random.PRNGKey(1), (m, n))
14 |     a = jax.random.normal(jax.random.PRNGKey(2), (rank_constraint, n))
15 |     b = jax.random.normal(jax.random.PRNGKey(3), (m, rank_constraint))
16 | 
17 |     full = w + (b @ a) / b.shape[1]
18 | 
19 |     lora_params = LoraWeight(
20 |         w=w,
21 |         a=a,
22 |         b=b
23 |     )
24 | 
25 |     return full, x, lora_params
26 | 
27 | 
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import optax
4 | import qax
5 | 
6 | from lorax import LORA_FREEZE, LORA_FULL
7 | from lorax.helpers import
split_lora_params, merge_params, wrap_optimizer 8 | 9 | def test_split(simple_params): 10 | _, _, params = simple_params 11 | 12 | tree = {'x': params, 'y': [params, jnp.zeros(5)]} 13 | spec = {'x': 5, 'y': [5, LORA_FULL]} 14 | 15 | split = split_lora_params(tree, spec) 16 | 17 | orig_struct = qax.utils.tree_structure_with_implicit(tree) 18 | struct = qax.utils.tree_structure_with_implicit(split) 19 | 20 | assert orig_struct == struct 21 | 22 | leaves = qax.utils.tree_leaves_with_implicit(split) 23 | 24 | for lora_leaf in leaves[:2]: 25 | assert lora_leaf.w is qax.EmptyNode 26 | assert isinstance(lora_leaf.a, jax.Array) 27 | assert isinstance(lora_leaf.b, jax.Array) 28 | 29 | assert isinstance(leaves[2], jax.Array) 30 | 31 | def test_merge(simple_params): 32 | w, _, params = simple_params 33 | params = jax.tree_map(lambda x: jnp.copy(x), params) 34 | 35 | tree = {'x': params, 'y': jnp.zeros(5)} 36 | structure = qax.utils.tree_structure_with_implicit(tree) 37 | 38 | merged = merge_params(tree) 39 | assert qax.utils.tree_structure_with_implicit(merged) == structure 40 | 41 | def test_wrap_optimizer(simple_params): 42 | _, x, params = simple_params 43 | 44 | params = {'u': params, 'y': {'z': jnp.ones(2), 'w': jnp.zeros(3)}} 45 | spec = {'u': 1234, 'y': {'z': LORA_FREEZE, 'w': LORA_FULL}} 46 | 47 | @qax.use_implicit_args 48 | def f(params, x): 49 | return jnp.sum(params['u'] @ x) + jnp.sum(params['y']['z']) + jnp.sum(params['y']['w']) 50 | 51 | grad = jax.grad(f)(params, x) 52 | 53 | opt = wrap_optimizer(optax.sgd(1e-3), spec) 54 | state = opt.init(params) 55 | 56 | updates, state = opt.update(grad, state, params) 57 | 58 | new_params = optax.apply_updates(params, updates) 59 | 60 | u = params['u'] 61 | new_u = new_params['u'] 62 | assert jnp.all(u.w == new_u.w) 63 | assert jnp.all(u.a != new_u.a) 64 | assert jnp.all(u.b != new_u.b) 65 | 66 | assert jnp.all(params['y']['z'] == new_params['y']['z']) 67 | assert jnp.all(params['y']['w'] != new_params['y']['w']) 68 | -------------------------------------------------------------------------------- /tests/tests.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import warnings 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import pytest 7 | 8 | from lorax import lora, init_lora, LoraWeight, LORA_FULL, LORA_FREEZE 9 | from lorax.constants import LORA_FULL, LORA_FREEZE 10 | 11 | @pytest.fixture(autouse=True) 12 | def catch_materialization_warnings(recwarn): 13 | warnings.filterwarnings('error', message='Primitive.*not handled') 14 | 15 | def test_materialize(simple_params): 16 | full, _, lora = simple_params 17 | assert jnp.allclose(lora.materialize(), full) 18 | 19 | def test_prepare(): 20 | w_shape = 3, 4 21 | params = { 22 | 'W': jnp.zeros(w_shape), 23 | 'b': jnp.zeros((4,)), 24 | 'W2': jnp.zeros((4, 5)), 25 | } 26 | spec = { 27 | 'W': 2, 28 | 'b': LORA_FREEZE, 29 | 'W2': LORA_FULL 30 | } 31 | 32 | lora_params = init_lora(params, spec, jax.random.PRNGKey(0)) 33 | 34 | assert isinstance(lora_params['W'], LoraWeight) 35 | assert lora_params['W'].shape == params['W'].shape 36 | assert lora_params['b'] is params['b'] 37 | assert lora_params['W2'] is params ['W2'] 38 | 39 | def test_simple(): 40 | key, init_key = jax.random.split(jax.random.PRNGKey(17)) 41 | batch = 5 42 | time = 7 43 | hidden = 11 44 | output = 13 45 | x = jax.random.normal(key, (batch, time, hidden)) 46 | 47 | params = [ 48 | jax.random.normal(key, (hidden, output)), 49 | ] 50 | 51 | def f(params, x): 52 | 
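        # f itself is LoRA-unaware; lora(f) below is what lets it accept LoraWeight params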
return x @ params[0] 53 | 54 | orig_output = f(params, x) 55 | 56 | lora_params = init_lora(params, [2], rng=init_key) 57 | 58 | lora_f = lora(f) 59 | lora_output = lora_f(lora_params, x) 60 | 61 | assert jnp.allclose(orig_output, lora_output) 62 | 63 | lora_params[0].b = jax.random.normal(key, (hidden, 2)) * 10 64 | 65 | perturbed_lora = lora_f(lora_params, x) 66 | 67 | combined_params = [lora_params[0].materialize()] 68 | combined_output = f(combined_params, x) 69 | 70 | print(f'Gap: {jnp.abs(combined_output - perturbed_lora).max()}') 71 | assert jnp.allclose(perturbed_lora, combined_output, atol=1e-5) 72 | 73 | def test_right_matmul(simple_params): 74 | w, _, lora_params = simple_params 75 | x = jax.random.normal(jax.random.PRNGKey(3), (10, w.shape[0])) 76 | def f(w, x): 77 | return x @ w 78 | 79 | lora_f = lora(f) 80 | lora_result = lora_f(lora_params, x) 81 | 82 | orig_result = f(w, x) 83 | 84 | assert jnp.allclose(lora_result, orig_result, atol=1e-4) 85 | 86 | def test_conv(): 87 | key, a_key, b_key = jax.random.split(jax.random.PRNGKey(18), 3) 88 | batch = 7 89 | time = 11 90 | hidden = 13 91 | output = 17 92 | rank_constraint = 3 93 | window_size = 2 94 | x = jax.random.normal(key, (batch, time, hidden)) 95 | 96 | def fn(w, x): 97 | return jax.lax.conv_general_dilated( 98 | lhs=x, 99 | rhs=w, 100 | window_strides=(1,), 101 | dimension_numbers=jax.lax.ConvDimensionNumbers( 102 | (0, 2, 1), 103 | (2, 1, 0), 104 | (0, 2, 1) 105 | ), 106 | padding='VALID' 107 | ) 108 | 109 | a = jax.random.normal(b_key, (1, rank_constraint, output)) 110 | b = jax.random.normal(a_key, (window_size, hidden, rank_constraint)) 111 | 112 | lora_params = LoraWeight( 113 | w=jnp.zeros((window_size, hidden, output)), 114 | a=a, 115 | b=b 116 | ) 117 | 118 | w = lora_params.materialize() 119 | 120 | lora_fn = lora(fn) 121 | orig_result = fn(w, x) 122 | lora_result = lora_fn(lora_params, x) 123 | print(f'Orig: {orig_result[:3, :3, :3]}') 124 | print(f'Lora: {lora_result[:3, :3, :3]}') 125 | 126 | assert jnp.allclose(orig_result, lora_result, rtol=1e-3) 127 | 128 | def test_embedding(): 129 | key, a_key, b_key, emb_key = jax.random.split(jax.random.PRNGKey(19), 4) 130 | batch = 11 131 | time = 13 132 | vocab = 4321 133 | hidden = 100 134 | 135 | rank_constraint = 19 136 | 137 | ids = jax.random.randint(key, (batch, time), 0, vocab) 138 | 139 | a = jax.random.normal(b_key, (rank_constraint, hidden)) 140 | b = jax.random.normal(a_key, (vocab, rank_constraint)) 141 | 142 | def f(w, x): 143 | return jax.lax.gather( 144 | w, 145 | x[:, :, None], 146 | dimension_numbers=jax.lax.GatherDimensionNumbers( 147 | offset_dims=(2,), 148 | collapsed_slice_dims=(0,), 149 | start_index_map=(0,), 150 | ), 151 | mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS, 152 | slice_sizes=(1, hidden) 153 | ) 154 | 155 | lora_params = LoraWeight( 156 | w=jax.random.normal(emb_key, (vocab, hidden)), 157 | a=a, 158 | b=b 159 | ) 160 | w = lora_params.materialize() 161 | 162 | lora_f = lora(f) 163 | 164 | orig_result = f(w, ids) 165 | lora_result = lora_f(lora_params, ids) 166 | 167 | gap = jnp.max(jnp.abs(orig_result - lora_result)) 168 | print(f'Gap: {gap:.3e}') 169 | assert jnp.allclose(orig_result, lora_result, atol=1e-5) 170 | 171 | def test_einsum(simple_params): 172 | w, x, lora_params = simple_params 173 | 174 | def f(w, x): 175 | return jnp.einsum('ij,jk->ik', w, x) 176 | 177 | expected = f(w, x) 178 | 179 | lora_f = lora(f) 180 | result = lora_f(lora_params, x) 181 | assert jnp.allclose(expected, result, rtol=1e-4) 182 | 183 | def 
test_remat(simple_params): 184 | w, x, lora_params = simple_params 185 | 186 | h = jax.random.normal(jax.random.PRNGKey(0), (x.shape[1],)) 187 | def f(w, x): 188 | return w @ x + h 189 | 190 | f = jax.remat(f) 191 | lora_f = jax.jit(lora(f)) 192 | 193 | expected = f(w, x) 194 | res = lora_f(lora_params, x) 195 | assert jnp.allclose(expected, res, rtol=1e-4) 196 | 197 | def test_transpose(simple_params): 198 | w, x, lora_params = simple_params 199 | def f(w, x): 200 | return x.T @ w.T 201 | 202 | lora_f = jax.jit(lora(f)) 203 | 204 | expected = f(w, x) 205 | res = lora_f(lora_params, x) 206 | print(f'Gap: {jnp.max(jnp.abs(expected - res)):.3e}') 207 | assert jnp.allclose(expected, res, atol=1e-6) 208 | 209 | @pytest.mark.parametrize('lora_first,contract_lora,contract_x,x_ndim', [ 210 | (lf, cl, cx, nd) for lf, cl, cx, nd in 211 | product([True, False], [0, 1], [0, 1, 2], [2, 3]) 212 | if cx < nd 213 | ]) 214 | def test_dot_contraction(simple_params, lora_first, contract_lora, contract_x, x_ndim): 215 | w, _, lora_params = simple_params 216 | def f(w, x): 217 | lhs = w 218 | rhs = x 219 | lhs_contract = contract_lora 220 | rhs_contract = contract_x 221 | if not lora_first: 222 | lhs_contract, rhs_contract = rhs_contract, lhs_contract 223 | lhs, rhs = rhs, lhs 224 | 225 | return jax.lax.dot_general( 226 | lhs, 227 | rhs, 228 | (((lhs_contract,), (rhs_contract,)), ((), ())) 229 | ) 230 | 231 | x_shape = [23] 232 | if x_ndim == 3: 233 | x_shape.append(29) 234 | 235 | contract_size = w.shape[contract_lora] 236 | x_shape.insert(contract_x, contract_size) 237 | 238 | x = jax.random.normal(jax.random.PRNGKey(0), x_shape) 239 | 240 | expected = f(w, x) 241 | 242 | lora_f = lora(f) 243 | lora_result = lora_f(lora_params, x) 244 | 245 | print(f'Gap: {jnp.max(jnp.abs(expected - lora_result)):.3e}') 246 | assert jnp.allclose(expected, lora_result, atol=1e-5) 247 | 248 | def test_cast(simple_params): 249 | w, x, lora_params = simple_params 250 | def f(w, x): 251 | return w.astype(jnp.float16) @ x.astype(jnp.float16) 252 | 253 | lora_f = lora(f) 254 | 255 | expected = f(w, x) 256 | res = lora_f(lora_params, x) 257 | 258 | print(f'Gap: {jnp.max(jnp.abs(expected - res)):.3e}') 259 | assert jnp.allclose(expected, res, atol=1e-2) 260 | 261 | def test_warning(simple_params): 262 | _, _, lora_params = simple_params 263 | def f(w): 264 | return w[:10, 3:] 265 | 266 | lora_f = lora(f) 267 | 268 | with pytest.warns(UserWarning, match='materialized'): 269 | lora_f(lora_params) 270 | 271 | --------------------------------------------------------------------------------