├── qax ├── common │ ├── __init__.py │ ├── type_utils.py │ └── utils.py ├── implicit │ ├── __init__.py │ ├── implicit_utils.py │ └── implicit_array.py ├── utils.py ├── __init__.py ├── primitives.py ├── constants.py └── symbols.py ├── pyproject.toml ├── examples ├── nesting.py ├── combining.py ├── nullable_array.py ├── zero.py ├── identity.py ├── const.py └── How_to_Qax.ipynb ├── LICENSE ├── tests ├── grad.py ├── nested.py ├── utils.py ├── symbols.py ├── scan.py ├── transform.py └── primitives.py └── README.md /qax/common/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /qax/implicit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qax/utils.py: -------------------------------------------------------------------------------- 1 | from .common.utils import * 2 | from .implicit.implicit_utils import * 3 | -------------------------------------------------------------------------------- /qax/__init__.py: -------------------------------------------------------------------------------- 1 | from . import symbols 2 | from .common import type_utils 3 | from .implicit import implicit_array 4 | from .implicit.implicit_array import ( 5 | ImplicitArray, 6 | UninitializedAval, 7 | aux_field, 8 | use_implicit_args, 9 | ) 10 | from .primitives import ArrayValue, default_handler, primitive_handler 11 | from .utils import EmptyNode, freeze_keys, materialize_nested 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 'qax' 3 | version = '0.4.1' 4 | description = 'A JAX transform for writing things which pretend to be tensors' 5 | authors = ['Davis Yoshida '] 6 | license = 'MIT' 7 | readme = 'README.md' 8 | packages = [ 9 | {include = 'qax'} 10 | ] 11 | 12 | [tool.poetry.dependencies] 13 | python = '^3.10' 14 | jax = '^0.4.24' 15 | jaxlib = '^0.4.24' 16 | plum-dispatch = '^2.1.0' 17 | optax = '^0.1.5' 18 | 19 | [tool.poetry.dev-dependencies] 20 | pytest = '^7.3.1' 21 | 22 | [build-system] 23 | requires = ['poetry-core>=1.0.0'] 24 | build-backend = 'poetry.core.masonry.api' 25 | -------------------------------------------------------------------------------- /examples/nesting.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from examples.nullable_array import NullableArray 5 | from examples.zero import ImplicitZeros 6 | from qax import use_implicit_args 7 | 8 | 9 | @jax.jit 10 | @use_implicit_args 11 | def f(x, y): 12 | return (x * y) + 3 13 | 14 | 15 | shape = (2, 3) 16 | x = NullableArray( 17 | val=ImplicitZeros(shape=shape), 18 | mask=jnp.asarray( 19 | [[True, False, True], [False, True, True]], 20 | ), 21 | ) 22 | 23 | y = NullableArray( 24 | val=jnp.ones(shape), mask=jnp.asarray([[True, True, True], [False, False, True]]) 25 | ) 26 | 27 | result = f(x, y) 28 | print(f"Result:\n{result.val}") 29 | print(f"Mask:\n{result.mask}") 30 | -------------------------------------------------------------------------------- /examples/combining.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from 
examples.const import ImplicitConst 7 | from examples.zero import ImplicitZeros 8 | from qax import primitive_handler, use_implicit_args 9 | 10 | 11 | @use_implicit_args 12 | def f(x, y): 13 | return x * y 14 | 15 | 16 | def main(): 17 | shape = (2, 3) 18 | zeros = ImplicitZeros(shape=shape, dtype=jnp.float32) 19 | const = ImplicitConst(1.0, shape=shape) 20 | 21 | assert isinstance(f(const, zeros), jax.Array) 22 | 23 | @primitive_handler("mul") 24 | def heterogenous_handler( 25 | primitive, 26 | x: Union[ImplicitZeros, ImplicitConst], 27 | y: Union[ImplicitZeros, ImplicitConst], 28 | ): 29 | out_shape = jnp.broadcast_shapes(x.shape, y.shape) 30 | out_dtype = jnp.result_type(x.dtype, y.dtype) 31 | return ImplicitZeros(shape=out_shape, dtype=out_dtype) 32 | 33 | assert isinstance(f(const, zeros), ImplicitZeros) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /qax/common/type_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Tuple 2 | 3 | from beartype.vale import IsInstance 4 | from plum import dispatch, parametric 5 | 6 | 7 | class _ComplementMeta(type): 8 | def __instancecheck__(self, x): 9 | a, b = self.type_parameter 10 | return a is None or (isinstance(x, a) and not isinstance(x, b)) 11 | 12 | 13 | @parametric 14 | class Complement(metaclass=_ComplementMeta): 15 | """ 16 | Relative complement 17 | I.e. Complement[A, B] = A - B 18 | """ 19 | 20 | @classmethod 21 | @dispatch 22 | def __init_type_parameter__( 23 | cls, 24 | a: Optional[Any], 25 | b: Optional[Any], 26 | ): 27 | return a, b 28 | 29 | @classmethod 30 | @dispatch 31 | def __le_type_parameter__( 32 | cls, 33 | left: Tuple[Optional[Any], Optional[Any]], 34 | right: Tuple[Optional[Any], Optional[Any]], 35 | ): 36 | a_left, b_left = left 37 | a_right, b_right = right 38 | 39 | return issubclass(a_left, a_right) and issubclass(b_right, b_left) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Davis Yoshida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/grad.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from qax import ArrayValue, ImplicitArray, primitive_handler, use_implicit_args 7 | 8 | 9 | @dataclass 10 | class TwoMatricesInATrenchcoat(ImplicitArray): 11 | a: ArrayValue 12 | b: ArrayValue 13 | 14 | def materialize(self): 15 | return self.a + self.b 16 | 17 | 18 | @primitive_handler("mul") 19 | def handler(primitive, x: TwoMatricesInATrenchcoat, y: jax.Array): 20 | return TwoMatricesInATrenchcoat(x.a * y, x.b * y) 21 | 22 | 23 | def test_grad(): 24 | shape = (5, 7) 25 | k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) 26 | x = TwoMatricesInATrenchcoat( 27 | jax.random.normal(k1, shape), 28 | jax.random.normal(k2, shape), 29 | ) 30 | y = jax.random.normal(k3, shape) 31 | 32 | @use_implicit_args 33 | def f(x, y): 34 | return jnp.sum(x * y) 35 | 36 | def explicit_f(a, b, y): 37 | return jnp.sum((a * y) + (b * y)) 38 | 39 | x_grad = jax.grad(f)(x, y) 40 | y_grad = jax.grad(f, 1)(x, y) 41 | 42 | a_grad_expected = jax.grad(explicit_f)(x.a, x.b, y) 43 | b_grad_expected = jax.grad(explicit_f, 1)(x.b, x.a, y) 44 | y_grad_expected = jax.grad(explicit_f, 2)(x.a, x.b, y) 45 | 46 | assert jnp.allclose(x_grad.a, a_grad_expected) 47 | assert jnp.allclose(x_grad.b, b_grad_expected) 48 | assert jnp.allclose(y_grad, y_grad_expected) 49 | -------------------------------------------------------------------------------- /tests/nested.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from qax import ArrayValue, ImplicitArray, primitive_handler, use_implicit_args 8 | 9 | 10 | @dataclass 11 | class Outer(ImplicitArray): 12 | x: ArrayValue 13 | 14 | def materialize(self): 15 | return 2 * (self.x**1) 16 | 17 | 18 | @primitive_handler(jax.lax.mul_p) 19 | def mul(primitive, arg: Outer, other: jax.Array): 20 | return Outer(arg.x * other) 21 | 22 | 23 | @dataclass 24 | class Inner(ImplicitArray): 25 | value: ArrayValue 26 | 27 | def materialize(self): 28 | return jnp.full(self.shape, self.value, dtype=self.dtype) 29 | 30 | 31 | @primitive_handler(jax.lax.integer_pow_p) 32 | def pow(primitive, arg: Inner, *, y): 33 | new_value = arg.value**y 34 | return Inner(new_value, shape=arg.shape, dtype=arg.dtype) 35 | 36 | 37 | def test_nested(): 38 | @use_implicit_args 39 | def f(x): 40 | return jnp.sum(x) 41 | 42 | inner = Inner(3, shape=(2, 3), dtype=jnp.float32) 43 | nested = Outer(inner) 44 | result = f(nested) 45 | assert result == 36 46 | 47 | 48 | def test_nested_with_operation(): 49 | @use_implicit_args 50 | def f(x): 51 | return jnp.sum(x * jnp.ones(x.shape)) 52 | 53 | inner = Inner(3, shape=(2, 3), dtype=jnp.float32) 54 | nested = Outer(inner) 55 | with warnings.catch_warnings(): 56 | warnings.filterwarnings( 57 | "error", message="Primitive mul was not handled by class Outer" 58 | ) 59 | result = f(nested) 60 | assert result == 36 61 | -------------------------------------------------------------------------------- /qax/primitives.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from itertools import count 3 | 4 | import jax 5 | from plum import Dispatcher, Function 6 | 7 | 8 | class ArrayValue(ABC): 9 | pass 10 | 11 | 12 | 
ArrayValue.register(jax.Array) 13 | 14 | _dispatch = Dispatcher() 15 | 16 | _primitive_ids = count() 17 | 18 | 19 | def get_lax_primitive_by_name(name): 20 | return getattr(jax.lax, f"{name}_p") 21 | 22 | 23 | def get_primitive_handler(primitive): 24 | if isinstance(primitive, str): 25 | primitive = get_lax_primitive_by_name(primitive) 26 | handler = _dispatch.functions.get(primitive) 27 | if handler is None: 28 | 29 | def _not_impl_handler(primitive: jax.core.Primitive, *args, **kwargs): 30 | return NotImplemented 31 | 32 | _not_impl_handler.__doc__ = "Default handler for {primitive.name}" 33 | handler = Function(_not_impl_handler) 34 | handler.register(_not_impl_handler, precedence=-1e9) 35 | handler.__name__ = f"{primitive.name}_{next(_primitive_ids)}" 36 | _dispatch.functions[primitive] = handler 37 | return handler 38 | 39 | 40 | def primitive_handler(primitives, precedence=0): 41 | if isinstance(primitives, (str, jax.core.Primitive)): 42 | primitives = [primitives] 43 | 44 | def decorator(fn): 45 | for primitive in primitives: 46 | handler = get_primitive_handler(primitive) 47 | handler.register(fn, precedence=precedence) 48 | 49 | return decorator 50 | 51 | 52 | def default_handler(primitive, *args, **params): 53 | subfuns, bind_params = primitive.get_bind_params(params) 54 | return primitive.bind(*subfuns, *args, **bind_params) 55 | -------------------------------------------------------------------------------- /examples/nullable_array.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of an ImplicitArray which represents an array + a boolean mask representing the validity of each entry. 3 | This is a proof of concept and is not optimized for performance. 4 | """ 5 | 6 | from dataclasses import dataclass 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | from qax import ( 12 | ArrayValue, 13 | ImplicitArray, 14 | default_handler, 15 | primitive_handler, 16 | use_implicit_args, 17 | ) 18 | from qax.constants import ELEMENTWISE_BINOPS, ELEMENTWISE_UNOPS, REDUCTION_OPS 19 | 20 | 21 | @dataclass 22 | class NullableArray(ImplicitArray): 23 | val: ArrayValue 24 | mask: ArrayValue 25 | 26 | def materialize(self): 27 | return self.val 28 | 29 | 30 | @primitive_handler(ELEMENTWISE_UNOPS) 31 | def handle_unop(primitive, nullable_val: NullableArray, **params): 32 | val = default_handler(primitive, nullable_val.val, **params) 33 | return NullableArray(val, nullable_val.mask) 34 | 35 | 36 | @primitive_handler(ELEMENTWISE_BINOPS) 37 | def handle_binop(primitive, lhs: ArrayValue, rhs: ArrayValue, **params): 38 | lhs_is_nullable = isinstance(lhs, NullableArray) 39 | rhs_is_nullable = isinstance(rhs, NullableArray) 40 | mask = lhs.mask if lhs_is_nullable else None 41 | 42 | if lhs_is_nullable: 43 | lhs = lhs.val 44 | 45 | if rhs_is_nullable: 46 | mask = rhs.mask if mask is None else mask & rhs.mask 47 | rhs = rhs.val 48 | 49 | out_val = default_handler(primitive, lhs, rhs, **params) 50 | return NullableArray(out_val, mask) 51 | 52 | 53 | @primitive_handler(REDUCTION_OPS) 54 | def handle_reduction(primitive, null_arr: NullableArray, **params): 55 | new_val = default_handler(primitive, null_arr.val, **params) 56 | new_mask = default_handler(jax.lax.reduce_and_p, null_arr.mask, **params) 57 | return NullableArray(new_val, new_mask) 58 | 59 | 60 | @jax.jit 61 | @use_implicit_args 62 | def f(x, y): 63 | return jnp.sum(-x * y, axis=0) 64 | 65 | 66 | if __name__ == "__main__": 67 | x = NullableArray( 68 | val=jnp.ones((2, 3)), 69 | 
mask=jnp.asarray([[True, False, True], [False, True, True]]), 70 | ) 71 | 72 | y = NullableArray( 73 | val=jnp.full((2, 3), 3), 74 | mask=jnp.asarray([[False, True, True], [True, True, True]]), 75 | ) 76 | 77 | output = f(x, y) 78 | print(f"Result: {output.val}") 79 | print(f"Mask: {output.mask}") 80 | -------------------------------------------------------------------------------- /examples/zero.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from qax import ImplicitArray, primitive_handler, use_implicit_args 7 | 8 | 9 | class ImplicitZeros(ImplicitArray): 10 | default_dtype = jnp.float32 11 | 12 | def materialize(self): 13 | return jnp.zeros(self.shape, dtype=self.dtype) 14 | 15 | 16 | def binop_shape_dtype(x, y): 17 | return { 18 | "shape": jnp.broadcast_shapes(x.shape, y.shape), 19 | "dtype": jnp.result_type(x.dtype, y.dtype), 20 | } 21 | 22 | 23 | @primitive_handler(jax.lax.mul_p) 24 | def do_mul(primitive, x: ImplicitZeros, y: jax.Array): 25 | print("Invoked do_mul") 26 | return ImplicitZeros(**binop_shape_dtype(x, y)) 27 | 28 | 29 | @primitive_handler([jax.lax.add_p, jax.lax.mul_p]) 30 | def handle_both_implicit(primitive, x: ImplicitZeros, y: ImplicitZeros): 31 | print("Invoked handle_both_implicit") 32 | return ImplicitZeros(**binop_shape_dtype(x, y)) 33 | 34 | 35 | @primitive_handler(jax.lax.add_p) 36 | def handle_add_general(primitive, x: ImplicitZeros, y: jax.Array): 37 | print("Invoked handle_add") 38 | shape_dtype = binop_shape_dtype(x, y) 39 | return jnp.broadcast_to(y, shape_dtype["shape"]).astype(shape_dtype["dtype"]) 40 | 41 | 42 | def main(): 43 | @jax.jit 44 | @use_implicit_args 45 | def f(x, y): 46 | # If x and y are both of type ImplicitZeros, the result will be: 47 | 48 | z = x + y # z: ImplicitZeros output: Invoked handle_both_implicit 49 | w = z * jnp.ones_like(z) # w: ImplicitZeros output: Invoked do_mul 50 | a = jnp.sum( 51 | w 52 | ) # a: jax.Array output: UserWarning: Primitive reduce_sum was not 53 | # handled by class ImplicitZeros, so implicit 54 | # args will be materialized 55 | b = w + a # b: jax.Array output: Invoked handle_add 56 | return b 57 | 58 | zeros = ImplicitZeros(shape=(2, 3)) 59 | 60 | result = jax.jit(f)(zeros, zeros) 61 | 62 | assert isinstance(result, jax.Array) 63 | assert result.shape == zeros.shape 64 | assert jnp.all(result == 0) 65 | 66 | # The decorated f will also work with mixed arguments or non-implicit arguments 67 | jnp_ones = jnp.ones(zeros.shape) 68 | f(zeros, jnp_ones) 69 | f(jnp_ones, zeros) 70 | f(jnp_ones, jnp_ones) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /examples/identity.py: -------------------------------------------------------------------------------- 1 | from dataclasses import InitVar, dataclass 2 | from functools import partial 3 | from typing import Union 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import qax 9 | from examples.minimal_lora import LoraMatrix 10 | 11 | 12 | @dataclass 13 | class Eye(qax.ImplicitArray): 14 | dim: InitVar[int] 15 | 16 | def __init__(self, dim, dtype=jnp.float32): 17 | super().__init__(shape=(dim, dim), dtype=dtype) 18 | 19 | def materialize(self): 20 | return jnp.eye(self.shape[0], dtype=self.dtype) 21 | 22 | 23 | @qax.primitive_handler(jax.lax.dot_general_p) 24 | def dot_handler( 25 | primitive, 26 | lhs: Union[Eye, jax.Array], 27 | rhs: Union[Eye, 
jax.Array], 28 | *, 29 | dimension_numbers, 30 | **kwargs 31 | ): 32 | lhs_aval = jax.core.get_aval(lhs) 33 | rhs_aval = jax.core.get_aval(rhs) 34 | 35 | out_aval = jax.eval_shape( 36 | partial( 37 | qax.default_handler, 38 | primitive, 39 | dimension_numbers=dimension_numbers, 40 | **kwargs 41 | ), 42 | lhs_aval, 43 | rhs_aval, 44 | ) 45 | 46 | lhs_is_eye = isinstance(lhs, Eye) 47 | rhs_is_eye = isinstance(rhs, Eye) 48 | 49 | (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers 50 | # It's 1 AM and I can only conceptualize dot_generals during the PM hours 51 | # so I will only be implementing 1-2D x 2D matmuls 52 | if not ( 53 | 1 <= lhs.aval.ndim <= 2 54 | and rhs.aval.ndim <= 2 55 | and len(lhs_contract) == len(rhs_contract) == 1 56 | and lhs_batch == rhs_batch == () 57 | ): 58 | return NotImplemented 59 | 60 | if lhs_is_eye and rhs_is_eye: 61 | return Eye(out_aval.shape[0], dtype=out_aval.dtype) 62 | 63 | result = rhs if lhs_is_eye else lhs 64 | return result.astype(out_aval.dtype) 65 | 66 | 67 | def main(): 68 | @qax.use_implicit_args 69 | def f(a, b): 70 | return a @ b 71 | 72 | w = Eye(3) 73 | x = jnp.arange(39, dtype=jnp.float32).reshape(3, 13) 74 | 75 | print(f(w, x)) 76 | 77 | dim = 128 78 | rank = 16 79 | eye_plus_low_rank = LoraMatrix( 80 | W=Eye(dim), 81 | A=jax.random.normal(jax.random.PRNGKey(0), (dim, rank)), 82 | B=jnp.zeros((dim, rank)), 83 | ) 84 | 85 | x = jax.random.normal(jax.random.PRNGKey(1), (73, dim)) 86 | print(jnp.sum(f(x, eye_plus_low_rank))) 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /qax/constants.py: -------------------------------------------------------------------------------- 1 | # WARNING: This file is obviously super incomplete, and is 2 | # currently just for convenience in testing. 
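#
# A usage sketch: these are frozensets of primitive *names*, and qax.primitive_handler
# accepts a name, a jax.core.Primitive, or an iterable of either, so a single handler
# can cover a whole family of ops (as examples/nullable_array.py does). `MyArray` below
# is a hypothetical ImplicitArray subclass used only for illustration:
#
#     from qax import default_handler, primitive_handler
#     from qax.constants import ELEMENTWISE_UNOPS
#
#     @primitive_handler(ELEMENTWISE_UNOPS)
#     def handle_unop(primitive, arg: MyArray, **params):
#         return MyArray(default_handler(primitive, arg.val, **params))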
3 | 4 | COMMUTATIVE_OPS = frozenset( 5 | [ 6 | "add", 7 | "bitwise_and", 8 | "bitwise_or", 9 | "bitwise_xor", 10 | "eq", 11 | "max", 12 | "min", 13 | "mul", 14 | "ne", 15 | ] 16 | ) 17 | 18 | ELEMENTWISE_UNOPS = frozenset( 19 | [ 20 | "abs", 21 | "acos", 22 | "acosh", 23 | "asin", 24 | "asinh", 25 | "atan", 26 | "atanh", 27 | "bessel_i0e", 28 | "bessel_i1e", 29 | "cbrt", 30 | "ceil", 31 | "clz", 32 | "conj", 33 | "convert_element_type", 34 | "copy", 35 | "cos", 36 | "cosh", 37 | "digamma", 38 | "erf_inv", 39 | "erf", 40 | "erfc", 41 | "exp", 42 | "expm1", 43 | "floor", 44 | "imag", 45 | "integer_pow", 46 | "is_finite", 47 | "lgamma", 48 | "log1p", 49 | "log", 50 | "logistic", 51 | "neg", 52 | "not", 53 | "population_count", 54 | "real", 55 | "reduce_precision", 56 | "round", 57 | "rsqrt", 58 | "sign", 59 | "sin", 60 | "sinh", 61 | "sqrt", 62 | "tan", 63 | "tanh", 64 | ] 65 | ) 66 | 67 | ELEMENTWISE_BINOPS = frozenset( 68 | [ 69 | "add", 70 | "and", 71 | "atan2", 72 | "complex", 73 | "div", 74 | "eq", 75 | "ge", 76 | "gt", 77 | "igamma_grad_a", 78 | "igamma", 79 | "igammac", 80 | "le", 81 | "lt", 82 | "max", 83 | "min", 84 | "mul", 85 | "ne", 86 | "nextafter", 87 | "or", 88 | "pow", 89 | "random_gamma_grad", 90 | "rem", 91 | "shift_left", 92 | "shift_right_arithmetic", 93 | "shift_right_logical", 94 | "sub", 95 | "xor", 96 | ] 97 | ) 98 | 99 | REDUCTION_OPS = frozenset( 100 | [ 101 | "argmax", 102 | "argmin", 103 | "reduce_and", 104 | "reduce_max", 105 | "reduce_min", 106 | "reduce_or", 107 | "reduce_prod", 108 | "reduce_sum", 109 | "reduce_xor", 110 | ] 111 | ) 112 | 113 | CUMULATIVE_REDUCTION_OPS = frozenset( 114 | [ 115 | "cumlogsumexp", 116 | "cummax", 117 | "cummin", 118 | "cumprod", 119 | "cumsum", 120 | ] 121 | ) 122 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import optax 6 | import pytest 7 | from jax.tree_util import tree_structure 8 | 9 | from qax import ArrayValue, ImplicitArray, utils 10 | 11 | 12 | @dataclass 13 | class Container(ImplicitArray): 14 | a: ArrayValue 15 | b: ArrayValue 16 | 17 | def materialize(self): 18 | return self.a 19 | 20 | 21 | @pytest.fixture(scope="module", params=[0, 1, 2, 3]) 22 | def container_with_depth(request): 23 | a = jnp.arange(10) 24 | for d in range(request.param): 25 | a = Container(a, jnp.zeros(d)) 26 | 27 | return a, request.param 28 | 29 | 30 | def test_count_depth(container_with_depth): 31 | container, depth = container_with_depth 32 | assert utils.implicit_depth(container) == depth 33 | 34 | 35 | def test_flatten_one_layer(container_with_depth): 36 | container, depth = container_with_depth 37 | pytree = [{"x": container}, {"y": container}] 38 | flat, struct = utils.flatten_one_implicit_layer(pytree) 39 | 40 | unflattened = jax.tree_util.tree_unflatten(struct, flat) 41 | assert jax.tree_util.tree_structure(unflattened) == jax.tree_util.tree_structure( 42 | pytree 43 | ) 44 | assert utils.implicit_depth(flat) == max(depth - 1, 0) 45 | 46 | 47 | def _get_prefix(*containers): 48 | return [ 49 | transform(c) 50 | for c, transform in zip( 51 | containers, utils.get_common_prefix_transforms(containers) 52 | ) 53 | ] 54 | 55 | 56 | def test_prefix(): 57 | c1 = Container( 58 | a=Container(jnp.zeros(10), jnp.zeros(10)), 59 | b=Container(jnp.zeros(3), jnp.zeros(13)), 60 | ) 61 | c2 = Container(a=Container(jnp.zeros(10), 
jnp.zeros(10)), b=jnp.zeros(3)) 62 | 63 | full_materialized_c1, _ = _get_prefix(c1, jnp.ones(10)) 64 | assert isinstance(full_materialized_c1, jax.Array) 65 | assert jnp.all(full_materialized_c1 == jnp.zeros(10)) 66 | 67 | c3 = Container( 68 | a=Container(jnp.zeros(10), jnp.zeros(3)), 69 | b=Container(jnp.zeros(3), jnp.zeros(13)), 70 | ) 71 | 72 | prefix_c1, prefix_c3 = _get_prefix(c1, c3) 73 | expected = Container(a=jnp.zeros(10), b=Container(jnp.zeros(3), jnp.zeros(13))) 74 | assert ( 75 | tree_structure(prefix_c1) 76 | == tree_structure(prefix_c3) 77 | == tree_structure(expected) 78 | ) 79 | 80 | c4 = Container( 81 | a=Container(a=Container(jnp.ones(10), jnp.zeros(3)), b=jnp.zeros(3)), 82 | b=jnp.zeros(10), 83 | ) 84 | 85 | c5 = Container( 86 | a=jnp.zeros(10), 87 | b=Container( 88 | Container(jnp.zeros(10), jnp.zeros(3)), 89 | Container(jnp.zeros(3), jnp.zeros(13)), 90 | ), 91 | ) 92 | 93 | prefix_c4, prefix_c5 = _get_prefix(c4, c5) 94 | expected = Container(a=jnp.zeros(10), b=jnp.zeros(10)) 95 | assert ( 96 | tree_structure(prefix_c4) 97 | == tree_structure(prefix_c5) 98 | == tree_structure(expected) 99 | ) 100 | -------------------------------------------------------------------------------- /tests/symbols.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import operator as fn 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import pytest 8 | 9 | import qax 10 | from qax.symbols import SymbolicConstant 11 | 12 | add = qax.use_implicit_args(fn.add) 13 | mul = qax.use_implicit_args(fn.mul) 14 | 15 | default_shape = (3, 5) 16 | 17 | 18 | @pytest.fixture 19 | def arr(): 20 | return jax.random.normal(jax.random.PRNGKey(0), default_shape) 21 | 22 | 23 | @pytest.fixture 24 | def zeros(): 25 | return SymbolicConstant(0, shape=default_shape, dtype=jnp.float32) 26 | 27 | 28 | @pytest.fixture 29 | def ones(): 30 | return SymbolicConstant(1, shape=default_shape, dtype=jnp.float32) 31 | 32 | 33 | @pytest.fixture 34 | def pis(): 35 | return SymbolicConstant(jnp.pi, shape=default_shape, dtype=jnp.float32) 36 | 37 | 38 | def test_add(zeros, arr, pis): 39 | z_plus_z = add(zeros, zeros) 40 | 41 | assert isinstance(z_plus_z, SymbolicConstant) 42 | assert z_plus_z.value == 0 43 | 44 | assert jnp.allclose(add(zeros, arr), arr) 45 | 46 | pi_plus_pi = add(pis, pis) 47 | assert isinstance(pi_plus_pi, SymbolicConstant) 48 | assert jnp.isclose(pi_plus_pi.value, 2 * jnp.pi) 49 | 50 | pi_plus_arr = add(pis, arr) 51 | assert isinstance(pi_plus_arr, jnp.ndarray) 52 | assert jnp.allclose(pi_plus_arr, arr + jnp.pi) 53 | 54 | 55 | def test_mul(zeros, ones, arr, pis): 56 | zero_times_one = mul(zeros, ones) 57 | 58 | assert isinstance(zero_times_one, SymbolicConstant) 59 | assert zero_times_one.value == 0 60 | 61 | assert jnp.all(mul(ones, arr) == arr) 62 | 63 | pi_times_pi = mul(pis, pis) 64 | assert isinstance(pi_times_pi, SymbolicConstant) 65 | assert jnp.isclose(pi_times_pi.value, jnp.pi**2) 66 | 67 | assert mul(pis, ones).value == jnp.pi 68 | 69 | 70 | _names = ["zeros", "ones", "pis"] 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "fn,lhs,rhs", 75 | itertools.product( 76 | [jax.lax.add, jax.lax.mul, jax.lax.sub, jax.lax.atan2, jax.lax.max], 77 | _names, 78 | _names, 79 | ), 80 | ) 81 | def test_binop(fn, lhs, rhs, request): 82 | lhs = request.getfixturevalue(lhs) 83 | rhs = request.getfixturevalue(rhs) 84 | expected = fn(lhs.materialize(), rhs.materialize()) 85 | result = qax.use_implicit_args(fn)(lhs, rhs) 86 | assert 
isinstance(result, SymbolicConstant) 87 | assert jnp.allclose(result.value, expected) 88 | assert result.shape == expected.shape 89 | assert result.dtype == expected.dtype 90 | 91 | 92 | @pytest.mark.parametrize( 93 | "fn,arg", 94 | itertools.product( 95 | [jnp.sum, jnp.prod, jnp.all, jnp.any, jnp.sin, jnp.isfinite], _names 96 | ), 97 | ) 98 | def test_unop(fn, arg, request): 99 | value = request.getfixturevalue(arg) 100 | expected = fn(value.materialize()) 101 | result = qax.use_implicit_args(fn)(value) 102 | assert isinstance(result, SymbolicConstant) 103 | assert jnp.allclose(result.value, expected) 104 | assert result.shape == expected.shape 105 | assert result.dtype == expected.dtype 106 | 107 | 108 | def test_select_n(zeros, ones): 109 | @qax.use_implicit_args 110 | def f(c, x, y): 111 | return jax.lax.select_n(c, x, y) 112 | 113 | assert isinstance(f(True, zeros, ones), jnp.ndarray) 114 | assert isinstance(f(False, zeros, zeros), SymbolicConstant) 115 | -------------------------------------------------------------------------------- /tests/scan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | from functools import partial 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import qax 9 | from qax.symbols import SymbolicConstant 10 | 11 | 12 | def test_scan_empty_tree(): 13 | a = SymbolicConstant(1, shape=(3, 5), dtype=jnp.float32) 14 | 15 | def body(carry, x): 16 | return [carry[0] + jnp.sum(x), carry[1]], x[1] 17 | 18 | qax_scan = partial(qax.use_implicit_args(jax.lax.scan), body, [0, jnp.ones(11)]) 19 | 20 | materialized_output = qax_scan(a.materialize()) 21 | 22 | with warnings.catch_warnings(): 23 | warnings.filterwarnings("error", message="Primitive scan was not") 24 | output = qax_scan(a) 25 | 26 | assert jax.tree_util.tree_all( 27 | jax.tree_map(lambda x, y: jnp.all(x == y), materialized_output, output) 28 | ) 29 | 30 | 31 | def test_scan_arr_with_data(): 32 | @dataclass 33 | class SumArray(qax.ImplicitArray): 34 | a: jax.Array 35 | b: jax.Array 36 | 37 | def materialize(self): 38 | return 2 * (self.a + self.b) 39 | 40 | @qax.primitive_handler("add") 41 | def add(primitive, lhs: SumArray, rhs: jax.Array): 42 | return 2 * lhs.a + 2 * lhs.b + rhs 43 | 44 | a = SumArray(jnp.ones((3, 5)), jnp.ones((3, 5))) 45 | 46 | def body(carry, x): 47 | return jnp.sum(x + carry), None 48 | 49 | qax_scan = partial(qax.use_implicit_args(jax.lax.scan), body, 0.0) 50 | 51 | expected_output = qax_scan(a.materialize()) 52 | output = qax_scan(a) 53 | 54 | assert expected_output == output 55 | 56 | 57 | def test_output_implicit(): 58 | @dataclass 59 | class Wrapper(qax.ImplicitArray): 60 | a: jax.Array 61 | 62 | def materialize(self): 63 | return self.a 64 | 65 | @qax.primitive_handler("add") 66 | def add(primitive, lhs: Wrapper, rhs: jax.Array): 67 | return Wrapper(a=lhs.a + rhs) 68 | 69 | @qax.primitive_handler("reduce_sum") 70 | def reduce_sum(primitive, x: Wrapper, **kwargs): 71 | return Wrapper(qax.default_handler(primitive, x.a, **kwargs)) 72 | 73 | def body(carry, x): 74 | return jnp.sum(x + carry), None 75 | 76 | qax_scan = partial(qax.use_implicit_args(jax.lax.scan), body, 0.0) 77 | 78 | a = Wrapper(jnp.arange(6).reshape((3, 2)).astype(jnp.float32)) 79 | 80 | expected_output, _ = qax_scan(a.materialize()) 81 | output, _ = qax_scan(a) 82 | 83 | assert isinstance(output, Wrapper) 84 | assert expected_output == output.materialize() 85 | 86 | 87 | def test_scan_closure(): 88 | a = 
SymbolicConstant(1, shape=(3, 5), dtype=jnp.float32) 89 | 90 | @qax.use_implicit_args 91 | def f(a): 92 | def body(carry, x): 93 | return carry + jnp.sum(a), None 94 | 95 | return jax.lax.scan(body, 0.0, jnp.arange(10)) 96 | 97 | expected = f(a.materialize()) 98 | output = f(a) 99 | 100 | assert expected == output 101 | 102 | 103 | def test_scan_closure_nonempty(): 104 | @dataclass 105 | class Stacker(qax.ImplicitArray): 106 | a: jax.Array 107 | b: jax.Array 108 | 109 | def materialize(self): 110 | return jnp.concatenate([self.a, self.b], axis=0) 111 | 112 | w = Stacker(jnp.ones((3, 6)), jnp.ones((3, 6))) 113 | w2 = Stacker(jnp.ones((3, 6)), jnp.ones((3, 6))) 114 | 115 | @qax.use_implicit_args 116 | def f(w, w2): 117 | def body(carry, x): 118 | return carry @ w @ w2, x 119 | 120 | return jax.lax.scan(body, jnp.ones(6), jnp.arange(10))[0] 121 | 122 | expected = f(w.materialize(), w2.materialize()) 123 | output = f(w, w2) 124 | 125 | assert jnp.allclose(expected, output) 126 | -------------------------------------------------------------------------------- /qax/common/utils.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import optax 6 | from jax import tree_util 7 | from jax.dtypes import float0 8 | 9 | from ..implicit.implicit_array import use_implicit_args 10 | from ..symbols import SymbolicConstant 11 | 12 | 13 | def vmap_all_but_one(f, axis, out_ndim=0): 14 | """ 15 | Repeatedly calls vmap to map over all axes except for `axis.` 16 | All args will be mapped on the same dimensions. 17 | """ 18 | 19 | @wraps(f) 20 | def inner(*args): 21 | n_dim = args[0].ndim 22 | if axis >= n_dim: 23 | raise ValueError( 24 | f"Axis {axis} is out of bounds for array of dimension {n_dim}" 25 | ) 26 | fn = f 27 | vmap_dim = 1 28 | out_dim = out_ndim 29 | for i in reversed(range(n_dim)): 30 | if i == axis: 31 | vmap_dim = 0 32 | out_dim = 0 33 | else: 34 | fn = jax.vmap(fn, vmap_dim, out_dim) 35 | return fn(*args) 36 | 37 | return inner 38 | 39 | 40 | def freeze_subtrees( 41 | optimizer: optax.GradientTransformation, label_fn, use_scalar_zeros=False 42 | ): 43 | """ 44 | Utility which wraps an optimizer such that subtrees specified by 45 | label_fn will receive zeros as updates. 
46 | Subtrees to be frozen should be labeled with "freeze" 47 | and all other subtrees should be labeled with "train" 48 | """ 49 | multi_transformed_optimizer = optax.multi_transform( 50 | { 51 | "freeze": set_to_zero_scalar() if use_scalar_zeros else optax.set_to_zero(), 52 | "train": optimizer, 53 | }, 54 | label_fn, 55 | ) 56 | 57 | def new_update(grads, opt_state, params): 58 | def map_float0(param, grad): 59 | if grad.dtype == float0: 60 | return ( 61 | jnp.zeros((), param.dtype) 62 | if use_scalar_zeros 63 | else jnp.zeros_like(param) 64 | ) 65 | return grad 66 | 67 | fixed_grads = jax.tree_map(map_float0, params, grads) 68 | return multi_transformed_optimizer.update(fixed_grads, opt_state, params) 69 | 70 | return optax.GradientTransformation(multi_transformed_optimizer.init, new_update) 71 | 72 | 73 | def freeze_keys( 74 | optimizer: optax.GradientTransformation, arr_type, keys, use_scalar_zeros=False 75 | ) -> optax.GradientTransformation: 76 | keys = set(keys) 77 | 78 | def label_leaf(leaf): 79 | if not isinstance(leaf, arr_type): 80 | return "train" 81 | 82 | children, aux_data = leaf.tree_flatten_with_keys() 83 | labels = ["freeze" if key in keys else "train" for key, _ in children] 84 | struct = leaf.tree_unflatten(aux_data, labels) 85 | return struct 86 | 87 | def label_fn(root): 88 | return jax.tree_map(label_leaf, root, is_leaf=lambda x: isinstance(x, arr_type)) 89 | 90 | return freeze_subtrees(optimizer, label_fn, use_scalar_zeros=use_scalar_zeros) 91 | 92 | 93 | def apply_updates(params: optax.Params, updates: optax.Updates) -> optax.Params: 94 | """ 95 | Like optax.apply_updates, but updates can be SymbolicConstant instances 96 | """ 97 | updates_flat, update_struct = tree_util.tree_flatten( 98 | updates, is_leaf=lambda x: isinstance(x, SymbolicConstant) 99 | ) 100 | semi_flat_params = update_struct.flatten_up_to(params) 101 | 102 | updated_flat = use_implicit_args(optax.apply_updates)( 103 | semi_flat_params, updates_flat 104 | ) 105 | updated = update_struct.unflatten(updated_flat) 106 | return updated 107 | 108 | 109 | def set_to_zero_scalar() -> optax.GradientTransformation: 110 | """ 111 | Returns a gradient transformation that sets all gradients to 0 in order to 112 | make downstream constant folding cheaper. 
113 | """ 114 | 115 | def init_fn(params): 116 | del params 117 | return optax.EmptyState() 118 | 119 | def update_fn(updates, state, params=None): 120 | return jax.tree_map(lambda x: jnp.zeros((), x.dtype), updates), state 121 | 122 | return optax.GradientTransformation(init_fn, update_fn) 123 | -------------------------------------------------------------------------------- /tests/transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | from functools import partial 4 | from typing import Any, Union 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import pytest 10 | from jax.core import Primitive 11 | from jax.experimental import pjit 12 | 13 | from qax import ImplicitArray, primitive_handler, use_implicit_args, utils 14 | 15 | WARN_PATTERN = ".*implicit args will be materialized" 16 | 17 | 18 | @dataclass 19 | class ImplicitConst(ImplicitArray): 20 | default_dtype = jnp.float32 21 | value: Any 22 | dummy_val: Any 23 | 24 | def materialize(self): 25 | return jnp.full(self.shape, self.value, dtype=self.dtype) 26 | 27 | 28 | @primitive_handler([jax.lax.mul_p, jax.lax.sub_p]) 29 | def mul_handler( 30 | primitive: Primitive, a: ImplicitConst, b: Union[ImplicitConst, jax.Array], **params 31 | ): 32 | def op(lhs, rhs): 33 | return lhs * rhs if primitive.name == "mul" else lhs - rhs 34 | 35 | assert not params 36 | if isinstance(b, ImplicitConst): 37 | return ImplicitConst( 38 | op(a.value, b.value), a.dummy_val, shape=a.shape, dtype=a.dtype 39 | ) 40 | if b.shape == (): 41 | new_value = op(a.value, b) 42 | return ImplicitConst(new_value, a.dummy_val, shape=a.shape, dtype=a.dtype) 43 | return op(a.value, b) 44 | 45 | 46 | @pytest.fixture 47 | def const(): 48 | shape = (2, 3) 49 | return ImplicitConst(2, -173, shape=shape) 50 | 51 | 52 | def test_transform(const): 53 | @use_implicit_args 54 | def f(x, y): 55 | return x * y 56 | 57 | print(f"Const: {const}") 58 | assert f(const, jnp.ones(const.shape))[0, 0] == const.value 59 | 60 | 61 | def test_pjit(const): 62 | @use_implicit_args 63 | @pjit.pjit 64 | def f(x, y): 65 | return x * y 66 | 67 | assert f(const, jnp.ones(const.shape))[0, 0] == const.value 68 | 69 | 70 | def test_remat(const): 71 | @use_implicit_args 72 | @jax.checkpoint 73 | def f(x, y): 74 | return x * y 75 | 76 | result = f(const, jnp.ones(const.shape)) 77 | assert result.shape == const.shape 78 | assert result[0, 0] == const.value 79 | 80 | 81 | def test_materialize(const): 82 | def f(x): 83 | return 3 + x 84 | 85 | with pytest.warns(UserWarning, match=WARN_PATTERN): 86 | use_implicit_args(f)(const) 87 | 88 | 89 | def test_suppress_materialize_warning(): 90 | class NoMaterializeWarning(ImplicitConst, warn_on_materialize=False): 91 | pass 92 | 93 | def f(x): 94 | return 3 + x 95 | 96 | with warnings.catch_warnings(): 97 | warnings.filterwarnings("error", message=WARN_PATTERN) 98 | use_implicit_args(f)(NoMaterializeWarning(2, -173, shape=())) 99 | 100 | 101 | def test_cond(const): 102 | @use_implicit_args 103 | def f(x, y): 104 | def true_fn(x): 105 | return x * jnp.ones(x.shape) 106 | 107 | def false_fn(x): 108 | return x * jnp.zeros(x.shape) + 5 109 | 110 | return jnp.sum(jax.lax.cond(y, true_fn, false_fn, x)) 111 | 112 | with warnings.catch_warnings(): 113 | warnings.filterwarnings("error", message=WARN_PATTERN) 114 | assert f(const, True) == const.value * np.prod(const.shape) 115 | assert f(const, False) == 5 * np.prod(const.shape) 116 | 117 | 118 | def 
test_cond_materialize_branch(const): 119 | @use_implicit_args 120 | def f(x, y): 121 | def true_fn(x): 122 | return x 123 | 124 | def false_fn(x): 125 | return jnp.ones(x.shape) 126 | 127 | return jax.lax.cond(y, true_fn, false_fn, x) 128 | 129 | result = f(const, True) 130 | assert isinstance(result, jax.Array) 131 | assert result.shape == const.shape 132 | assert jnp.all(result == const.value) 133 | 134 | 135 | def test_cond_partial_materialize_branch(): 136 | @use_implicit_args 137 | def f(x, y, z): 138 | def true_fn(x, y): 139 | return y * y 140 | 141 | def false_fn(x, y): 142 | return x * x 143 | 144 | return jax.lax.cond(z, true_fn, false_fn, x, y) 145 | 146 | shape = (2, 3) 147 | x = ImplicitConst(2.0, -173, shape=shape) 148 | y = ImplicitConst( 149 | value=ImplicitConst(1.0, -173, shape=()), dummy_val=-173, shape=shape 150 | ) 151 | # y._materialize() 152 | 153 | result = f(x, y, True) 154 | assert isinstance(result, ImplicitConst) 155 | assert isinstance(result.value, jax.Array) 156 | assert result.shape == (2, 3) 157 | assert jnp.all(result.value == 1) 158 | 159 | 160 | def test_switch(const): 161 | @use_implicit_args 162 | def f(x, i): 163 | branch_fn = lambda a, x: jnp.sum(a * x) 164 | branches = [partial(branch_fn, jnp.asarray(i)) for i in range(3)] 165 | 166 | return jax.lax.switch(i, branches, x) 167 | 168 | with warnings.catch_warnings(): 169 | warnings.filterwarnings("error", message=".*switch was not handled") 170 | assert f(const, 0) == 0 171 | 172 | 173 | def test_no_implicit_args(): 174 | def f(x): 175 | return jnp.sum(x**2) 176 | 177 | assert use_implicit_args(f)(jnp.ones((3, 3))) == 9 178 | 179 | 180 | def test_vmap(): 181 | def f(x, y): 182 | return jnp.sum(x * y) 183 | 184 | xs = ImplicitConst(jnp.arange(3), jnp.arange(-100, -97), shape=(7, 11)) 185 | ys = jax.random.normal(jax.random.PRNGKey(0), (7, 11)) 186 | 187 | x_value = jnp.tile(jnp.arange(3)[:, None, None], (1, 7, 11)) 188 | 189 | vmapped_f = jax.vmap(f, in_axes=(0, None)) 190 | implicit_f = jax.vmap(use_implicit_args(f), in_axes=(0, None)) 191 | 192 | result = implicit_f(xs, ys) 193 | expected_result = vmapped_f(x_value, ys) 194 | 195 | assert jnp.allclose(result, expected_result) 196 | 197 | 198 | def test_disable_commute(): 199 | class NoCommute(ImplicitArray, commute_ops=False): 200 | def materialize(self): 201 | return jnp.zeros(self.shape) 202 | 203 | @primitive_handler("add") 204 | def add(primitive: Primitive, x: NoCommute, y: Any): 205 | return y 206 | 207 | with pytest.warns(UserWarning, match=WARN_PATTERN): 208 | use_implicit_args(lambda x: 1 + x)(NoCommute(shape=())) 209 | -------------------------------------------------------------------------------- /examples/const.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from qax import ( 8 | ArrayValue, 9 | ImplicitArray, 10 | aux_field, 11 | default_handler, 12 | primitive_handler, 13 | use_implicit_args, 14 | ) 15 | 16 | # To define the behavior we want, we subclass qax.ImplicitArray 17 | # To define additional fields we also need to mark this class as 18 | # a dataclass 19 | 20 | 21 | @dataclass 22 | class ImplicitConst(ImplicitArray): 23 | # Dataclass attributes may be used to define the arrays which 24 | # determine the concrete array being represented 25 | # In this case it's a single JAX scalar 26 | value: ArrayValue 27 | 28 | # ImplicitArrays are pytrees, so all attributes are automatically 
29 |     # marked as pytree children. To instead mark one as auxiliary data
30 |     # use the qax.aux_field decorator
31 |     my_aux_value: str = aux_field(default="some_metadata")
32 | 
33 |     # There are several ways to control the shape and dtype of an ImplicitArray
34 |     # They are:
35 |     # 1. Pass shape/dtype kwargs to the constructor
36 |     # 2. Override the compute_shape/compute_dtype methods
37 |     # 3. Override the default_shape/default_dtype class attributes
38 |     # 4. Manually override __post_init__ and set the self.shape/self.dtype values yourself
39 |     # 5. Do none of the above, in which case materialize() will be abstractly evaluated
40 |     # in an attempt to derive the values. That won't work in this case since we need
41 |     # to know them in order to call jnp.full
42 | 
43 |     default_dtype = jnp.float32
44 | 
45 |     def compute_dtype(self):
46 |         # We're doing this instead of just self.value.dtype since we might get
47 |         # a python scalar
48 |         return jax.core.get_aval(self.value).dtype
49 | 
50 |     # The way we can guarantee that our ImplicitArrays will work
51 |     # with pre-existing code is that whenever we hit an op
52 |     # that we haven't written custom behavior for, the
53 |     # ImplicitArray instance will be materialized into a
54 |     # dense array and the default behavior will be used
55 |     def materialize(self):
56 |         return jnp.full(self.shape, self.value, dtype=self.dtype)
57 | 
58 | 
59 | # The way we define custom behavior is by writing a function
60 | # and decorating it with the primitive_handler decorator
61 | # The type annotations are used for multiple dispatch with
62 | # plum so make sure to get them right!
63 | #
64 | # For commutative ops, the ImplicitArray instance will always be made the
65 | # lhs, but this isn't true for non-commutative ops as we'll see below
66 | @primitive_handler("mul")
67 | def mul(primitive, a: ImplicitConst, b: jax.Array):
68 |     """
69 |     Arguments:
70 |         - primitive: A JAX primitive
71 |         - a: An argument guaranteed to be an ImplicitConst instance
72 |         - b: An argument which will either be an ImplicitConst or a JAX type
73 |     """
74 | 
75 |     # Get the output shape in case there's any broadcasting going on
76 |     out_shape = jnp.broadcast_shapes(a.shape, b.shape)
77 |     if b.size == 1:
78 |         # If we get multiplied by a scalar, we can
79 |         # output another ImplicitConst instance
80 |         # rather than materializing the dense array
81 |         return ImplicitConst(a.value * b.reshape(1)[0], shape=out_shape)
82 | 
83 |     # In the general case we just multiply our constant value by the other array
84 |     result = b * a.value
85 |     return jnp.broadcast_to(result, out_shape)
86 | 
87 | 
88 | # We can also define the case where both arguments are ImplicitConsts
89 | @primitive_handler("mul")
90 | def mul(primitive, a: ImplicitConst, b: ImplicitConst):
91 |     out_shape = jnp.broadcast_shapes(a.shape, b.shape)
92 |     return ImplicitConst(a.value * b.value, shape=out_shape)
93 | 
94 | 
95 | # You can use one handler for multiple primitives by passing an iterable to the decorator
96 | @primitive_handler(["sin", "cos", "exp"])
97 | def elementwise_unop(primitive, arg: ImplicitConst):
98 |     # In a lot of cases the logic doesn't have anything
99 |     # to do with the exact primitive being used so
100 |     # we can just use qax.default_handler to execute
101 |     result = default_handler(primitive, arg.value)
102 |     return ImplicitConst(result, shape=arg.shape)
103 | 
104 | 
105 | # If the primitive has any params (such as reduction axes) the handler will receive
106 | # them as a param kwarg
107 | #
108 | # The above handlers were registered using the primitive name, which
109 | # looks up the actual lax primitive under the hood. You can also use
110 | # the actual primitive, which is done here
111 | @primitive_handler(jax.lax.reduce_sum_p)
112 | def reduce_sum(primitive, a: ImplicitConst, *, axes):
113 |     sum_result = np.prod([a.shape[i] for i in axes]) * a.value
114 |     new_shape = tuple(d for i, d in enumerate(a.shape) if i not in axes)
115 |     return ImplicitConst(sum_result, shape=new_shape)
116 | 
117 | 
118 | # This decorator makes it so that `f` can handle inputs which are instances
119 | # of ImplicitArray subclasses (or pytrees containing such instances)
120 | # You can also still call it with ordinary JAX inputs
121 | @use_implicit_args
122 | def f(a, b):
123 |     c = a * b
124 |     d = jnp.sin(c)
125 |     return jnp.sum(d)
126 | 
127 | 
128 | def main():
129 |     shape = (5, 7)
130 | 
131 |     a_full = jnp.full(shape, 3.0)
132 |     a_implicit = ImplicitConst(3.0, shape=shape)
133 | 
134 |     b_full = jnp.full(shape, 2.0)
135 |     b_implicit = ImplicitConst(2.0, shape=shape)
136 | 
137 |     result = f(a_full, b_full)
138 | 
139 |     full_implicit_result = f(a_implicit, b_implicit)
140 |     mixed_result = f(a_implicit, b_full)
141 | 
142 |     # We get the same result each time (other than some floating point error)
143 |     # In the second case, we were able to avoid materializing the ImplicitConst
144 |     # so we get an ImplicitConst as an output
145 |     print(result) # -9.779543
146 |     print(full_implicit_result) # ImplicitConst(-9.779541969299316, ())
147 |     print(mixed_result) # -9.779543
148 | 
149 |     # We can also nest ImplicitArray instances (even if they're different subclasses)
150 |     nested_b = ImplicitConst(value=ImplicitConst(2.0, shape=()), shape=shape)
151 | 
152 |     nested_result = f(a_implicit, nested_b)
153 |     print(nested_result) # ImplicitConst(ImplicitConst(-9.779541969299316, ()), ())
154 | 
155 | 
156 | if __name__ == "__main__":
157 |     main()
158 | 
--------------------------------------------------------------------------------
/tests/primitives.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | import jax
4 | import jax.numpy as jnp
5 | import pytest
6 | 
7 | from qax import (
8 |     ArrayValue,
9 |     ImplicitArray,
10 |     default_handler,
11 |     primitive_handler,
12 |     use_implicit_args,
13 | )
14 | from qax.constants import (
15 |     CUMULATIVE_REDUCTION_OPS,
16 |     ELEMENTWISE_BINOPS,
17 |     ELEMENTWISE_UNOPS,
18 |     REDUCTION_OPS,
19 | )
20 | 
21 | primitive_example_params = {
22 |     "convert_element_type": {"new_dtype": jnp.float32, "weak_type": False},
23 |     "integer_pow": {"y": 3},
24 |     "reduce_precision": {"exponent_bits": 4, "mantissa_bits": 12},
25 |     "round": {"rounding_method": jax.lax.RoundingMethod.AWAY_FROM_ZERO},
26 | }
27 | 
28 | for op in REDUCTION_OPS:
29 |     primitive_example_params[op] = {"axes": (1,)}
30 | 
31 | for op in CUMULATIVE_REDUCTION_OPS:
32 |     primitive_example_params[op] = {"axis": 1, "reverse": False}
33 | 
34 | primitive_example_params["argmax"] = primitive_example_params["argmin"] = {
35 |     "axes": (1,),
36 |     "index_dtype": jnp.int32,
37 | }
38 | 
39 | input_dtypes = {
40 |     "clz": jnp.int32,
41 |     "not": jnp.uint8,
42 |     "population_count": jnp.int32,
43 |     "imag": jnp.complex64,
44 |     "real": jnp.complex64,
45 |     "shift_right_logical": (jnp.int32, jnp.int32),
46 |     "shift_right_arithmetic": (jnp.int32, jnp.int32),
47 |     "shift_left": (jnp.int32, jnp.int32),
48 |     "or": (jnp.int32, jnp.int32),
49 |     "and": (jnp.int32, jnp.int32),
50 |     "xor": (jnp.int32, jnp.int32),
51 | } 52 | 53 | 54 | def make_class_for_primitive(primitive): 55 | @dataclass 56 | class StackedArray(ImplicitArray): 57 | a: ArrayValue 58 | b: ArrayValue 59 | 60 | def materialize(self): 61 | return jnp.concatenate((self.a, self.b), axis=0) 62 | 63 | def __repr__(self): 64 | return f"StackedArray({self.a}, {self.b})" 65 | 66 | return StackedArray 67 | 68 | 69 | @pytest.mark.parametrize("primitive", ELEMENTWISE_UNOPS) 70 | def test_unop(primitive): 71 | StackedArray = make_class_for_primitive(primitive) 72 | 73 | @primitive_handler(primitive) 74 | def handler(primitive, arg: StackedArray, **kwargs): 75 | new_a = default_handler(primitive, arg.a, **kwargs) 76 | new_b = default_handler(primitive, arg.b, **kwargs) 77 | return StackedArray( 78 | new_a, 79 | new_b, 80 | ) 81 | 82 | lax_primitive = getattr(jax.lax, f"{primitive}_p") 83 | 84 | def f(x): 85 | params = primitive_example_params.get(primitive, {}) 86 | return lax_primitive.bind(x, **params) 87 | 88 | to_type = input_dtypes.get(primitive, jnp.float32) 89 | x = jax.random.normal(jax.random.PRNGKey(0), (3, 7)).astype(to_type) 90 | y = jax.random.normal(jax.random.PRNGKey(1), (9, 7)).astype(to_type) 91 | stacked = StackedArray(x, y) 92 | expected = f(stacked.materialize()) 93 | 94 | with_implicit = use_implicit_args(f)(stacked).materialize() 95 | 96 | close = jnp.isclose(with_implicit, expected) 97 | nan_agree = jnp.logical_and(jnp.isnan(with_implicit), jnp.isnan(expected)) 98 | assert jnp.all(close | nan_agree) 99 | 100 | 101 | @pytest.mark.parametrize("primitive", ELEMENTWISE_BINOPS) 102 | def test_binop(primitive): 103 | StackedArray = make_class_for_primitive(primitive) 104 | 105 | @primitive_handler(primitive) 106 | def handler(primitive, arg1: StackedArray, arg2: StackedArray, **kwargs): 107 | new_a = default_handler(primitive, arg1.a, arg2.a, **kwargs) 108 | new_b = default_handler(primitive, arg1.b, arg2.b, **kwargs) 109 | return StackedArray(new_a, new_b) 110 | 111 | lax_primitive = getattr(jax.lax, f"{primitive}_p") 112 | 113 | def f(x, y): 114 | params = primitive_example_params.get(primitive, {}) 115 | return lax_primitive.bind(x, y, **params) 116 | 117 | lhs_type, rhs_type = input_dtypes.get(primitive, (jnp.float32, jnp.float32)) 118 | x = jax.random.normal(jax.random.PRNGKey(0), (3, 7)).astype(lhs_type) 119 | y = jax.random.normal(jax.random.PRNGKey(1), (9, 7)).astype(lhs_type) 120 | stacked1 = StackedArray(x, y) 121 | 122 | z = jax.random.normal(jax.random.PRNGKey(2), (3, 7)).astype(rhs_type) 123 | w = jax.random.normal(jax.random.PRNGKey(3), (9, 7)).astype(rhs_type) 124 | stacked2 = StackedArray(z, w) 125 | 126 | expected = f(stacked1.materialize(), stacked2.materialize()) 127 | 128 | with_implicit = use_implicit_args(f)(stacked1, stacked2).materialize() 129 | 130 | close = jnp.isclose(with_implicit, expected) 131 | nan_agree = jnp.logical_and(jnp.isnan(with_implicit), jnp.isnan(expected)) 132 | assert jnp.all(close | nan_agree) 133 | 134 | 135 | @pytest.mark.parametrize("primitive", REDUCTION_OPS) 136 | def test_reduction(primitive): 137 | StackedArray = make_class_for_primitive(primitive) 138 | 139 | @primitive_handler(primitive) 140 | def handler(primitive, arg: StackedArray, *, axes, **params): 141 | if 0 in axes: 142 | raise ValueError("Tests should use axis 1") 143 | a_reduced = default_handler(primitive, arg.a, axes=axes, **params) 144 | b_reduced = default_handler(primitive, arg.b, axes=axes, **params) 145 | return StackedArray(a_reduced, b_reduced) 146 | 147 | lax_primitive = getattr(jax.lax, f"{primitive}_p") 
148 | 149 | def f(x): 150 | params = primitive_example_params.get(primitive, {}) 151 | return lax_primitive.bind(x, **params) 152 | 153 | to_type = input_dtypes.get(primitive, jnp.int32) 154 | x = jax.random.normal(jax.random.PRNGKey(0), (3, 7)).astype(to_type) 155 | y = jax.random.normal(jax.random.PRNGKey(1), (9, 7)).astype(to_type) 156 | stacked = StackedArray(x, y) 157 | expected = f(stacked.materialize()) 158 | 159 | with_implicit = use_implicit_args(f)(stacked).materialize() 160 | 161 | close = jnp.isclose(with_implicit, expected) 162 | nan_agree = jnp.logical_and(jnp.isnan(with_implicit), jnp.isnan(expected)) 163 | assert jnp.all(close | nan_agree) 164 | 165 | 166 | @pytest.mark.parametrize("primitive", CUMULATIVE_REDUCTION_OPS) 167 | def test_cumulative_reduction(primitive): 168 | StackedArray = make_class_for_primitive(primitive) 169 | 170 | @primitive_handler(primitive) 171 | def handler(primitive, arg: StackedArray, *, axis, **params): 172 | if axis != 1: 173 | raise ValueError("Tests should use axis 1") 174 | 175 | a_reduced = default_handler(primitive, arg.a, axis=axis, **params) 176 | b_reduced = default_handler(primitive, arg.b, axis=axis, **params) 177 | 178 | return StackedArray(a_reduced, b_reduced) 179 | 180 | lax_primitive = getattr(jax.lax, f"{primitive}_p") 181 | 182 | def f(x): 183 | params = primitive_example_params.get(primitive, {}) 184 | return lax_primitive.bind(x, **params) 185 | 186 | to_type = input_dtypes.get(primitive, jnp.float32) 187 | x = jax.random.normal(jax.random.PRNGKey(0), (3, 7)).astype(to_type) 188 | y = jax.random.normal(jax.random.PRNGKey(1), (9, 7)).astype(to_type) 189 | stacked = StackedArray(x, y) 190 | expected = f(stacked.materialize()) 191 | 192 | with_implicit = use_implicit_args(f)(stacked).materialize() 193 | 194 | close = jnp.isclose(with_implicit, expected) 195 | nan_agree = jnp.logical_and(jnp.isnan(with_implicit), jnp.isnan(expected)) 196 | assert jnp.all(close | nan_agree) 197 | -------------------------------------------------------------------------------- /qax/symbols.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of SymbolicConstant, an ImplicitArray subclass representing jnp.full(shape, value, dtype). 3 | The data is stored in pytree auxilary data, and all supported operations are run at compile time. 4 | This is useful for forcing constant folding in cases where the JIT would not know that an argument 5 | is constant, or for lowering the cost of constant folding on large constant arrays. 6 | These do not respect NaNs in certain ways, e.g. 
0 * NaN = 0, max(inf, NaN) = inf 7 | """ 8 | 9 | from dataclasses import dataclass 10 | from functools import partial 11 | from typing import Any 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import optax 17 | 18 | from .common.type_utils import Complement 19 | from .constants import ELEMENTWISE_BINOPS, ELEMENTWISE_UNOPS 20 | from .implicit import implicit_utils as iu 21 | from .implicit.implicit_array import ( 22 | ArrayValue, 23 | ImplicitArray, 24 | UninitializedAval, 25 | aux_field, 26 | use_implicit_args, 27 | ) 28 | from .primitives import default_handler, get_primitive_handler, primitive_handler 29 | 30 | _GENERAL = -2 31 | _SPECIALIZED = -1 32 | 33 | 34 | def _get_shape_dtype(x, shape, dtype): 35 | if shape is None: 36 | shape = np.shape(x) 37 | else: 38 | shape = jax.core.canonicalize_shape(shape) 39 | 40 | if dtype is None: 41 | jax.lax.dtype(x) 42 | return shape, dtype 43 | 44 | 45 | def _out_shape_dtype(primitive, *args, **kwargs): 46 | out_aval = jax.eval_shape( 47 | partial(default_handler, primitive, **kwargs), 48 | *(jax.core.get_aval(x) for x in args), 49 | ) 50 | return jax.tree_map(lambda x: (x.shape, x.dtype), out_aval) 51 | 52 | 53 | def symbolic_zero_like(x, shape=None, dtype=None): 54 | dtype = jax.lax.dtype(x) if dtype is None else dtype 55 | return symbolic_full_like(x, 0, shape=shape, dtype=dtype) 56 | 57 | 58 | def symbolic_full_like(x, fill_value, shape=None, dtype=None): 59 | shape, _ = _get_shape_dtype(x, shape, None) 60 | if dtype is None: 61 | dtype = jax.lax.dtype(fill_value) 62 | 63 | return SymbolicConstant(fill_value, shape=shape, dtype=dtype) 64 | 65 | 66 | @dataclass 67 | class SymbolicConstant(ImplicitArray): 68 | value: Any = aux_field() 69 | weak_type: bool = aux_field(default=False) 70 | 71 | def __post_init__(self): 72 | super().__post_init__() 73 | with jax.ensure_compile_time_eval(): 74 | self.value = jnp.asarray(self.value, dtype=self.dtype) 75 | 76 | def compute_dtype(self): 77 | return jax.lax.dtype(self.value) 78 | 79 | def materialize(self): 80 | return jnp.full(self.shape, self.value, dtype=self.dtype) 81 | 82 | def copy(self): 83 | return jax.tree_map(lambda x: x, self) 84 | 85 | 86 | @use_implicit_args 87 | def broadcast_to(val, shape): 88 | return jnp.broadcast_to(val, shape) 89 | 90 | 91 | @use_implicit_args 92 | def astype(val, dtype): 93 | return val.astype(dtype) 94 | 95 | 96 | @primitive_handler( 97 | [ 98 | "reshape", 99 | "broadcast_in_dim", 100 | "reduce_min", 101 | "reduce_max", 102 | "reduce_or", 103 | "reduce_and", 104 | ] 105 | ) 106 | def unchanged_value_op(primitive, sym: SymbolicConstant, **kwargs): 107 | out_shape, out_dtype = _out_shape_dtype(primitive, sym, **kwargs) 108 | return SymbolicConstant(sym.value, shape=out_shape, dtype=out_dtype) 109 | 110 | 111 | def _op_and_reshape(primitive, lhs, rhs, flip=False): 112 | """ 113 | Close over one arg so we can do math at tracing time, but let the other one get traced 114 | """ 115 | if flip: 116 | lhs, rhs = (rhs, lhs) 117 | 118 | @use_implicit_args 119 | def inner(arg): 120 | other = lhs 121 | if flip: 122 | arg, other = (other, arg) 123 | 124 | result = default_handler(primitive, arg, other) 125 | return result 126 | 127 | return inner(rhs) 128 | 129 | 130 | def special_case_binop(name, identity=None, annihilator=None, flip=False): 131 | lhs_type = SymbolicConstant 132 | rhs_type = Complement[ArrayValue, SymbolicConstant] 133 | if flip: 134 | lhs_type, rhs_type = rhs_type, lhs_type 135 | 136 | @primitive_handler(name, 
precedence=_SPECIALIZED) 137 | def handler(primitive, lhs: lhs_type, rhs: rhs_type, **kwargs): 138 | out_shape, out_dtype = _out_shape_dtype(primitive, lhs, rhs, **kwargs) 139 | with jax.ensure_compile_time_eval(): 140 | if lhs.value == identity: 141 | return broadcast_to(astype(rhs, out_dtype), out_shape) 142 | 143 | if lhs.value == annihilator: 144 | return SymbolicConstant(lhs.value, shape=out_shape, dtype=out_dtype) 145 | 146 | print(f"{primitive} {lhs.value} {rhs}") 147 | return _op_and_reshape(primitive, lhs.value, rhs) 148 | 149 | 150 | special_case_binop("add", identity=0) 151 | special_case_binop("mul", identity=1, annihilator=0) 152 | special_case_binop("and", annihilator=0) 153 | special_case_binop("or", identity=0) 154 | special_case_binop("xor", identity=0) 155 | 156 | special_case_binop("sub", identity=0, flip=True) 157 | special_case_binop("div", identity=1, flip=True) 158 | special_case_binop("exp", identity=1, flip=True) 159 | 160 | special_case_binop("min", identity=float("inf"), annihilator=float("-inf")) 161 | special_case_binop("max", identity=float("-inf"), annihilator=float("inf")) 162 | 163 | 164 | def eval_default_handler(primitive, *args, **kwargs): 165 | with jax.ensure_compile_time_eval(): 166 | result = primitive.bind(*args, **kwargs) 167 | return result 168 | 169 | 170 | @primitive_handler(ELEMENTWISE_UNOPS, precedence=_GENERAL) 171 | def handle_unop(primitive, sym: SymbolicConstant, **kwargs): 172 | print(f"Handling {primitive} with {sym}") 173 | new_val = eval_default_handler(primitive, sym.value, **kwargs) 174 | return symbolic_full_like(sym, new_val) 175 | 176 | 177 | @primitive_handler(ELEMENTWISE_BINOPS, precedence=_GENERAL) 178 | def handle_binop(primitive, lhs: SymbolicConstant, rhs: SymbolicConstant, **kwargs): 179 | out_shape, out_dtype = _out_shape_dtype(primitive, lhs, rhs, **kwargs) 180 | new_val = eval_default_handler(primitive, lhs.value, rhs.value, **kwargs) 181 | return symbolic_full_like(lhs, new_val, shape=out_shape, dtype=out_dtype) 182 | 183 | 184 | @primitive_handler(["reduce_sum", "reduce_prod"]) 185 | def reduce_sum(primitive, sym: SymbolicConstant, *, axes): 186 | out_shape, out_dtype = _out_shape_dtype(primitive, sym, axes=axes) 187 | with jax.ensure_compile_time_eval(): 188 | if sym.value == 0: 189 | return SymbolicConstant(0, shape=out_shape, dtype=out_dtype) 190 | 191 | orig_size = np.prod(sym.shape) 192 | new_size = np.prod(out_shape) 193 | 194 | n_combined = orig_size // new_size 195 | 196 | new_val = sym.value 197 | if primitive.name == "reduce_sum": 198 | new_val = new_val * n_combined 199 | else: 200 | new_val = new_val**n_combined 201 | 202 | return SymbolicConstant(new_val, shape=out_shape, dtype=out_dtype) 203 | 204 | 205 | @primitive_handler("select_n") 206 | def handle_select_n(primitive, cond_val, *arg_vals: SymbolicConstant): 207 | if len(set(val.value.item() for val in arg_vals)) != 1: 208 | return NotImplemented 209 | 210 | out_shape, out_dtype = _out_shape_dtype(primitive, cond_val, *arg_vals) 211 | return SymbolicConstant(arg_vals[0].value, shape=out_shape, dtype=out_dtype) 212 | -------------------------------------------------------------------------------- /qax/implicit/implicit_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial, wraps 2 | from itertools import chain 3 | 4 | import jax 5 | import jax.extend.linear_util as lu 6 | from jax import core, tree_util 7 | from jax.api_util import flatten_fun_nokwargs 8 | 9 | from . 
import implicit_array as ia 10 | 11 | 12 | class _EmptyNodeCls: 13 | _instance = None 14 | 15 | def __new__(cls): 16 | if cls._instance is None: 17 | cls._instance = super().__new__(cls) 18 | return cls._instance 19 | 20 | 21 | EmptyNode = _EmptyNodeCls() 22 | 23 | tree_util.register_pytree_node( 24 | _EmptyNodeCls, lambda node: ((), None), lambda _, __: EmptyNode 25 | ) 26 | 27 | 28 | def combine_leaf_predicate(base_fn, is_leaf): 29 | @wraps(base_fn) 30 | def new_fn(*args, new_is_leaf=None): 31 | if new_is_leaf is None: 32 | combined_is_leaf = is_leaf 33 | else: 34 | 35 | def combined_is_leaf(arg): 36 | return is_leaf(arg) or new_is_leaf(arg) 37 | 38 | return base_fn(*args, is_leaf=combined_is_leaf) 39 | 40 | return new_fn 41 | 42 | 43 | def leaf_predicate(x): 44 | return isinstance(x, (ia.ImplicitArray, _EmptyNodeCls)) 45 | 46 | 47 | tree_map_with_implicit = combine_leaf_predicate(jax.tree_map, leaf_predicate) 48 | tree_map_with_path_with_implicit = combine_leaf_predicate( 49 | tree_util.tree_map_with_path, leaf_predicate 50 | ) 51 | tree_flatten_with_implicit = combine_leaf_predicate( 52 | tree_util.tree_flatten, leaf_predicate 53 | ) 54 | tree_flatten_with_path_with_implicit = combine_leaf_predicate( 55 | tree_util.tree_flatten_with_path, leaf_predicate 56 | ) 57 | tree_leaves_with_implicit = combine_leaf_predicate( 58 | tree_util.tree_leaves, leaf_predicate 59 | ) 60 | tree_structure_with_implicit = combine_leaf_predicate( 61 | tree_util.tree_structure, leaf_predicate 62 | ) 63 | 64 | 65 | def flatten_one_implicit_layer(tree): 66 | def is_leaf_below_node(node, x): 67 | return isinstance(x, ia.ImplicitArray) and x is not node 68 | 69 | def replace_subtree_implicits(node): 70 | return tree_util.tree_map( 71 | lambda _: 1, node, is_leaf=partial(is_leaf_below_node, node) 72 | ) 73 | 74 | prototype = tree_map_with_implicit(replace_subtree_implicits, tree) 75 | struct = tree_util.tree_structure(prototype) 76 | 77 | leaves = tree_leaves_with_implicit(tree) 78 | leaves = list( 79 | chain.from_iterable( 80 | ( 81 | tree_util.tree_leaves(leaf, is_leaf=partial(is_leaf_below_node, leaf)) 82 | if isinstance(leaf, ia.ImplicitArray) 83 | else [leaf] 84 | ) 85 | for leaf in leaves 86 | ) 87 | ) 88 | return leaves, struct 89 | 90 | 91 | def implicit_depth(tree): 92 | leaves = tree_leaves_with_implicit(tree) 93 | depth = 0 94 | while True: 95 | next_leaves = [] 96 | any_implicit = False 97 | for leaf in leaves: 98 | if not isinstance(leaf, ia.ImplicitArray): 99 | continue 100 | any_implicit = True 101 | next_leaves.extend(flatten_one_implicit_layer(leaf)[0]) 102 | 103 | if not any_implicit: 104 | return depth 105 | 106 | depth += 1 107 | leaves = next_leaves 108 | 109 | 110 | def _map_leaves_with_implicit_path(f, leaves, is_leaf, path_prefix=()): 111 | mapped_leaves = [] 112 | for idx, leaf in enumerate(leaves): 113 | path = path_prefix + (idx,) 114 | if not isinstance(leaf, ia.ImplicitArray) or is_leaf(path, leaf): 115 | mapped_leaves.append(f(leaf)) 116 | continue 117 | 118 | subtree, substruct = flatten_one_implicit_layer(leaf) 119 | mapped_subtree = _map_leaves_with_implicit_path( 120 | f, subtree, is_leaf=is_leaf, path_prefix=path 121 | ) 122 | mapped_leaves.append(tree_util.tree_unflatten(substruct, mapped_subtree)) 123 | return mapped_leaves 124 | 125 | 126 | def _get_pruning_transform(tree, materialization_paths): 127 | if not materialization_paths: 128 | return lambda x: x 129 | 130 | def is_leaf(path, leaf): 131 | return path in materialization_paths 132 | 133 | def 
materialize_subtrees(tree): 134 | leaves, struct = tree_flatten_with_implicit(tree) 135 | 136 | mapped_leaves = _map_leaves_with_implicit_path( 137 | partial(materialize_nested, full=True), leaves, is_leaf 138 | ) 139 | return tree_util.tree_unflatten(struct, mapped_leaves) 140 | 141 | return materialize_subtrees 142 | 143 | 144 | def get_common_prefix_transforms(trees): 145 | """ 146 | Given an iterable of pytrees which have the same structure after all 147 | ImplicitArray instances are materialized, return a list of callables 148 | which will transform each tree into the largest common structure 149 | obtainable via materialization of ImplicitArrays. 150 | """ 151 | if len(trees) <= 1: 152 | return [lambda x: x for _ in trees] 153 | 154 | all_leaves, structures = zip(*(tree_flatten_with_implicit(tree) for tree in trees)) 155 | post_materialization_avals = [core.get_aval(leaf) for leaf in all_leaves[0]] 156 | for i, (leaves, structure) in enumerate(zip(all_leaves[1:], structures[1:]), 1): 157 | if structure != structures[0]: 158 | raise ValueError( 159 | "Trees do not have the same structure after materialization" 160 | ) 161 | 162 | for leaf, expected_aval in zip(leaves, post_materialization_avals): 163 | aval = core.get_aval(leaf) 164 | if not ( 165 | aval.shape == expected_aval.shape and aval.dtype == expected_aval.dtype 166 | ): 167 | raise ValueError( 168 | f"Trees do not have the same avals after materialization. Tree 0: {expected_aval}, Tree {i}: {aval}" 169 | ) 170 | 171 | # Stack will contain tuples of (path, nodes) 172 | # path = a sequence of integers specifying which child 173 | # was taken at each _flatten_one_implicit_layer call 174 | # or the first flatten_with_implicit call 175 | # nodes = one node from each tree 176 | stack = [] 177 | 178 | all_leaves = [] 179 | for tree in trees: 180 | all_leaves.append(tree_leaves_with_implicit(tree)) 181 | 182 | for i, nodes in enumerate(zip(*all_leaves)): 183 | stack.append(((i,), nodes)) 184 | 185 | materialization_paths = set() 186 | while stack: 187 | path_prefix, nodes = stack.pop() 188 | if not any(isinstance(node, ia.ImplicitArray) for node in nodes): 189 | continue 190 | 191 | all_leaves, all_structures = zip( 192 | *(flatten_one_implicit_layer(node) for node in nodes) 193 | ) 194 | node_structures = set(all_structures) 195 | if len(node_structures) > 1: 196 | materialization_paths.add(path_prefix) 197 | continue 198 | 199 | aval_diff = False 200 | for leaves in zip(*all_leaves): 201 | first_aval = core.get_aval(leaves[0]) 202 | shape = first_aval.shape 203 | dtype = first_aval.dtype 204 | for leaf in leaves[1:]: 205 | aval = core.get_aval(leaf) 206 | if not (aval.shape == shape and aval.dtype == dtype): 207 | materialization_paths.add(path_prefix) 208 | aval_diff = True 209 | if aval_diff: 210 | break 211 | 212 | if aval_diff: 213 | continue 214 | 215 | for i, leaf_group in enumerate(zip(*all_leaves)): 216 | stack.append((path_prefix + (i,), leaf_group)) 217 | 218 | return [_get_pruning_transform(tree, materialization_paths) for tree in trees] 219 | 220 | 221 | def materialize_nested(implicit_arr, full=False): 222 | """ 223 | Materialize an ImplicitArray instance, handling the case where implicit_arr.materialize() 224 | involves further ImplicitArray instances. 
225 | Arguments: 226 | implicit_arr: An ImplicitArray instance 227 | full: If True, repeatedly materialize until the result is a concrete array 228 | Returns: 229 | The materialized array 230 | """ 231 | while isinstance(implicit_arr, ia.ImplicitArray): 232 | wrapped = lu.wrap_init(type(implicit_arr).materialize) 233 | flat, in_tree = flatten_one_implicit_layer((implicit_arr,)) 234 | flat_fn, out_tree = flatten_fun_nokwargs(wrapped, in_tree) 235 | out_flat = ia.use_implicit_args(flat_fn.call_wrapped)(*flat) 236 | implicit_arr = jax.tree_util.tree_unflatten(out_tree(), out_flat) 237 | 238 | if not full: 239 | break 240 | 241 | return implicit_arr 242 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/davisyoshida/qax/blob/master/examples/How_to_Qax.ipynb) 2 | 3 | # Qax: If it quacks like a tensor... 4 | 🦆[Qax](https://github.com/davisyoshida/qax)🦆 is a tool for implementing types which represent tensors, but may or may not be instantiated as a single dense array on your GPU. Examples of this include: 5 | * Quantization: A 4-bit array of integers + a small number of scale values are used to represent a full 16/32-bit array 6 | * LoRA: An array $W$ is replaced by the array $(W + BA^T)$ so that $A$ and $B$ may be trained while leaving $W$ frozen 7 | * Symbolic zeros/constants: For arrays which will consist entirely of a single repeated value, simply store that single value and the shape of the array 8 | * Custom kernels: If you have a custom kernel and want to use it with existing models without modifying them, Qax is an easy way to do so 9 | * Hopefully many more things! 10 | 11 | The goal of Qax is to make implementing custom JAX behavior much easier, so that users won't need to deal with all the details of writing a full JAX transform. All you need to do to get custom representations is: 12 | 13 | 1. Define what data/metadata your datatype should contain 14 | 2. Optionally write any number of handlers which specify how your type behaves under JAX primitives such as multiplication 15 | 3. Write a function which constructs a dense array from your implicit representation 16 | 17 | Both of the above are written in pure JAX, so no need for custom gradients (unless you want to of course!). 18 | 19 | ## Installation 20 | ``` 21 | pip install qax 22 | ``` 23 | 24 | ## Example 1: A symbolic zero 25 | The way you specify custom behavior with Qax is to subclass the `qax.ImplicitArray` abstract class. One of the simplest things we could implement is a symbolic zero: A data type which represents an arbitrary tensor full of zeros without actually instantiating them on the GPU. 26 | 27 | 28 | ```python 29 | class Zeros(qax.ImplicitArray): 30 | default_dtype = jnp.float32 31 | 32 | def materialize(self): 33 | # self.shape and self.dtype will be 34 | # populated by the ImplicitArray constructor 35 | return jnp.zeros(self.shape, self.dtype) 36 | 37 | def __str__(self): 38 | return f'Zeros({self.shape}, {self.dtype})' 39 | ``` 40 | 41 | The only mandatory method to implement when subclassing `ImplicitArray` is `materialize()`. 42 | `materialize()` specifies how to turn our _implicitly_ represented array into an _explicit_ one, i.e. a single dense JAX array. In the case of `Zeros`, we can just call `jnp.zeros`. 
43 | 44 | Let's instantiate a `Zeros` instance to try it out: 45 | 46 | 47 | ```python 48 | z = Zeros(shape=(2, 3)) 49 | ``` 50 | 51 | ImplicitArrays are [dataclasses](https://docs.python.org/3/library/dataclasses.html), which by default have two keyword only attributes: `shape` and `dtype`. 52 | 53 | By default JAX won't know how to use our new type. In order to use it in functions, we apply the `@use_implicit_args` decorator: 54 | 55 | 56 | ```python 57 | @qax.use_implicit_args 58 | def f(x, y): 59 | return (x + y)[0, 0] 60 | ``` 61 | 62 | 63 | ```python 64 | with warnings.catch_warnings(): 65 | warnings.simplefilter('always') 66 | print(f(z, jnp.ones(3))) 67 | ``` 68 | 69 | /home/davis/src/qax/qax/implicit/implicit_array.py:303: UserWarning: Primitive add was not handled by class Zeros, so implicit args will be materialized. 70 | warnings.warn(f'Primitive {primitive.name} was not handled by class {vals[implicit_idx].__class__.__name__}, so implicit args will be materialized.') 71 | 72 | 73 | 1.0 74 | 75 | 76 | The cool thing is that `f` doesn't need to have any idea that it will be called with `ImplicitArray` instances, so we can use this with any pre-existing model. Right now this isn't much use, since all `z` is being materialized into a dense array as soon as it's needed for a JAX operation. 77 | 78 | To make our `Zeros` do something productive, let's implement the fact that $x + 0$ is always equal to $x$. We do this using the `@qax.primitive_handler` decorator: 79 | 80 | 81 | ```python 82 | def get_binop_result_shape_dtype(a, b): 83 | out_shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b)) 84 | out_dtype = jnp.result_type(a.dtype, b.dtype) 85 | return out_shape, out_dtype 86 | 87 | # primitive_handler() takes a string, JAX primitive, or a list of those types 88 | # strings are used to find the corresponding primitive from `jax.lax` 89 | @qax.primitive_handler('add') 90 | def my_add_handler(primitive, a : Zeros, b): 91 | # Handlers will receive as arguments: 92 | # - primitive: a jax.core.Primitive instance (often can be ignored if the handler is just for one op) 93 | # Any number of arguments which are either JAX values or ImplicitArrays 94 | # Keyword arguments specifying parameters of the operation (e.g. axes for reduction operations) 95 | 96 | out_shape, out_dtype = get_binop_result_shape_dtype(a, b) 97 | 98 | if isinstance(b, Zeros): 99 | # We can return further ImplicitArray instances if we want 100 | return Zeros(shape=out_shape, dtype=out_dtype) 101 | 102 | # Return b, possibly modifying its shape or dtype 103 | return jnp.broadcast_to(b, out_shape).astype(out_dtype) 104 | ``` 105 | 106 | The type annotation `a : Zeros` is actually important, Qax uses [Plum](https://github.com/beartype/plum) for multiple dispatch. You can even use this to define how different subclasses of ImplicitArray should interact with each other. 107 | 108 | (For convenience, commutative binary ops like $+$ and $\times$ will automatically get their argument order switched so that the `ImplicitArray` instance comes first.) 
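As a quick check of that argument swapping, a handler written for `(Zeros, b)` should also fire when the dense operand comes first. A minimal sketch (the function name is mine; it reuses the `z` and the `add` handler defined above):

```python
@qax.use_implicit_args
def f_flipped(x, y):
    # The dense operand comes first in this expression, but since 'add' is
    # commutative Qax reorders the arguments so the (Zeros, b) handler is used.
    return (y + x)[0, 0]

print(f_flipped(z, jnp.ones(3)))  # 1.0, with no materialization warning
```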
109 | 110 | Now when we call `f`, we no longer see the materialization log message, since our add handler is skipping over ever instantiating the array of zeros: 111 | 112 | 113 | ```python 114 | print(f(z, jnp.ones(3))) 115 | ``` 116 | 117 | 1.0 118 | 119 | 120 | Let's define a multiplication handler as well, since $x \cdot 0 = 0$ for all $x$: 121 | 122 | 123 | ```python 124 | @qax.primitive_handler('mul') 125 | def handle_mul(primitive, a : Zeros, b): 126 | out_shape, out_dtype = get_binop_result_shape_dtype(a, b) 127 | 128 | return Zeros(shape=out_shape, dtype=out_dtype) 129 | 130 | 131 | @jax.jit 132 | @qax.use_implicit_args 133 | def g(x, y): 134 | return (1 + x) * y 135 | 136 | print(g(z, z)) 137 | ``` 138 | 139 | Zeros((2, 3), float32) 140 | 141 | 142 | The output of `use_implicit_args` is a function which is compatible with all the usual JAX transformations such as `jit`, `vmap`, `grad`, etc. 143 | 144 | Even this simple implementation is enough to let us modify the behavior of models which were written without knowing about Qax. Let's try replacing all the biases in HuggingFace's GPT-2 with zeros: 145 | 146 | 147 | ```python 148 | @qax.primitive_handler('broadcast_in_dim') 149 | def broadcast(primitive, a : Zeros, *, shape, broadcast_dimensions): 150 | # The biases get broadcast in order to add them to the activations 151 | # so we need to handle that case 152 | # Sometimes the simplest thing to do is use jax.eval_shape 153 | # to figure out what shape to return 154 | result_shape = jax.eval_shape( 155 | partial(jax.lax.broadcast_in_dim, shape=shape, broadcast_dimensions=broadcast_dimensions), 156 | a.aval # ImplicitArray has an aval property which will get an abstract shape/dtype 157 | ).shape 158 | return Zeros(shape=result_shape, dtype=a.dtype) 159 | 160 | 161 | model, params = transformers.FlaxAutoModelForCausalLM.from_pretrained('gpt2', _do_init=False) 162 | 163 | inputs = jnp.arange(1, 10)[None] 164 | 165 | # Helper function to switch all the biases 166 | # in the params out for some other value 167 | def replace_biases(params, replacer): 168 | def maybe_replace_val(path, val): 169 | if val.ndim != 1: 170 | return val 171 | 172 | # Skip layernorms 173 | if any( 174 | isinstance(p, jax.tree_util.DictKey) and p.key.startswith('ln') 175 | for p in path 176 | ): 177 | return val 178 | return replacer(shape=val.shape, dtype=val.dtype) 179 | return jax.tree_util.tree_map_with_path(maybe_replace_val, params) 180 | 181 | 182 | # Replace the biases with dense zero arrays: 183 | params_with_zeros = replace_biases(params, jnp.zeros) 184 | print('New bias:', params['transformer']['h']['0']['attn']['c_attn']['bias']) 185 | 186 | output = model(inputs, params=params_with_zeros).logits 187 | print('Last logit average:', jnp.mean(output[0, -1])) 188 | ``` 189 | 190 | New bias: [ 0.48033914 -0.5254326 -0.42926455 ... 
0.01257301 -0.04987717 191 | 0.00324764] 192 | Last logit average: -105.25595 193 | 194 | 195 | Now let's try replacing them with our symbolic zeros instead: 196 | 197 | 198 | ```python 199 | params_with_zeros = replace_biases(params, Zeros) 200 | print('New bias:', params['transformer']['h']['0']['attn']['c_attn']['bias']) 201 | 202 | # In this case since we're calling the model directly, we need to 203 | # wrap it so we can pass params in a positional argument 204 | # This usually won't be an issue since the call to the model will 205 | # be inside a loss function or some other function 206 | 207 | output = qax.use_implicit_args(model)(inputs, params=params_with_zeros).logits 208 | print('Last logit average:', jnp.mean(output[0, -1])) 209 | ``` 210 | 211 | New bias: [ 0.48033914 -0.5254326 -0.42926455 ... 0.01257301 -0.04987717 212 | 0.00324764] 213 | Last logit average: -105.25595 214 | 215 | 216 | 217 | ```python 218 | del model 219 | del params 220 | ``` 221 | 222 | We got the same result, but using 0 FLOPs for adding the biases! If you really wanted to flesh out the behavior of `Zeros`, you could also add handlers defining its output for primitives such as `sin`, `cos`, etc. Let's move on to something more interesting though. 223 | 224 | ## Example 2: LoRA 225 | In this example we'll implement [LoRA](https://arxiv.org/abs/2106.09685) in just a few lines of code. Unlike the `Zeros` example from the previous section, our `ImplicitArray` subclass will actually contain data this time. As such we'll need to implement flattening/unflattening logic, since all `ImplicitArray` subclasses are pytrees. This also means you can use `tree_map` and friends to manipulate them. 226 | 227 | To add child pytrees to a subclass, we just add them as dataclass attributes. To add auxiliary data, you can wrap a field with `qax.aux_field` which is just a wrapper around `dataclasses.field`. 228 | 229 | LoRA replaces a matrix $W$ with the matrix $W_0 + AB^T$, so we'll have three arrays as new attributes. 
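As a rough cost sketch (my framing, not from the original text): if $W_0$ is $m \times n$ and $A$, $B$ have shapes $m \times r$ and $n \times r$ with $r \ll m, n$, the `dot_general` handler below computes $xW_0 + (xA)B^T$ for an input $x$. That adds only $O(r(m + n))$ work per row of $x$ on top of the ordinary $xW_0$, rather than spending $O(mnr)$ operations (and an extra $m \times n$ buffer) to materialize $W_0 + AB^T$ first.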
230 | 231 | 232 | ```python 233 | @dataclass 234 | class LoraMatrix(qax.ImplicitArray): 235 | """Represent W + A B^T""" 236 | w : qax.ArrayValue 237 | a : qax.ArrayValue 238 | b : qax.ArrayValue 239 | 240 | # auxiliary data example 241 | is_array_happy : bool = qax.aux_field(default=True) 242 | 243 | def __post_init__(self): 244 | # If you need to do any validation, you can override the __post_init__ method 245 | # This example is purely for error checking, but you can also 246 | # add manipulations of the attributes 247 | super().__post_init__() 248 | w_aval = jax.core.get_aval(self.w) 249 | a_aval = jax.core.get_aval(self.a) 250 | b_aval = jax.core.get_aval(self.b) 251 | assert w_aval.ndim == a_aval.ndim == b_aval.ndim == 2 252 | assert a_aval.shape[1] == b_aval.shape[1] 253 | assert a_aval.shape[0] == w_aval.shape[0] 254 | assert b_aval.shape[0] == w_aval.shape[1] 255 | assert a_aval.dtype == b_aval.dtype == w_aval.dtype 256 | 257 | def materialize(self): 258 | return self.w + self.a @ self.b.T 259 | 260 | @qax.primitive_handler('dot_general') 261 | def f(primitive, x : jax.Array, w : LoraMatrix, *, dimension_numbers, **kwargs): 262 | # For this example, we'll only handle the simple case of of x @ w, rather than 263 | # all possible dot_general invocations 264 | (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers 265 | 266 | # This check just makes sure that all that's happening is a simple matmul 267 | if not ( 268 | len(w.shape) == 2 269 | and lhs_contract == (x.ndim - 1,) 270 | and rhs_contract == (0,) 271 | and lhs_batch == () 272 | and rhs_batch == () 273 | ): 274 | # If we want to only partially handle a particular primitive, 275 | # we can fall back to the default logic by returning NotImplemented 276 | return NotImplemented 277 | 278 | kwargs = {**kwargs, 'dimension_numbers': dimension_numbers} 279 | # In order to defer to the default implementation of the primitive, 280 | # use the qax.default_handler helper: 281 | result = qax.default_handler( 282 | primitive, # pass the primitive 283 | x, w.w, # Any number of positional arguments, 284 | **kwargs # Then the primitive's keyword args 285 | ) 286 | 287 | xa = qax.default_handler(primitive, x, w.a, **kwargs) 288 | 289 | xab = qax.default_handler(primitive, xa, w.b.T, **kwargs) 290 | 291 | result += xab 292 | return result 293 | 294 | def lora_from_tree(tree, key, lora_dim=8): 295 | """ 296 | Helper function for replacing non-embedding weight 297 | matrices in T5 with LoraMatrix instances. 
298 | """ 299 | def iter_keys(key): 300 | while True: 301 | key, k2 = jax.random.split(key) 302 | yield k2 303 | 304 | key_it = iter_keys(key) 305 | def map_fn(path, val): 306 | if val.ndim != 2: 307 | return val 308 | 309 | # Skip embedding params 310 | if any(isinstance(p, jax.tree_util.DictKey) and p.key == 'embedding' for p in path): 311 | return val 312 | 313 | a = jax.random.normal(next(key_it), (val.shape[0], lora_dim), val.dtype) 314 | b = jnp.zeros((val.shape[1], lora_dim), val.dtype) 315 | return LoraMatrix(val, a, b) 316 | 317 | return jax.tree_util.tree_map_with_path(map_fn, tree) 318 | ``` 319 | 320 | Let's try it out on a T5 model: 321 | 322 | 323 | ```python 324 | t5, params = transformers.FlaxAutoModelForSeq2SeqLM.from_pretrained('t5-small', _do_init=False) 325 | tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small') 326 | encoder_inputs = jnp.asarray(tokenizer.encode('Some input'))[None] 327 | decoder_inputs = jnp.asarray([0] + tokenizer.encode('Some output'))[None] 328 | 329 | lora_params = lora_from_tree(params, jax.random.PRNGKey(1234)) 330 | ``` 331 | 332 | 333 | ```python 334 | orig_output = t5(input_ids=encoder_inputs, decoder_input_ids=decoder_inputs, params=params).logits 335 | ``` 336 | 337 | 338 | ```python 339 | lora_output = qax.use_implicit_args(t5)( 340 | input_ids=encoder_inputs, 341 | decoder_input_ids=decoder_inputs, 342 | params=lora_params 343 | ).logits 344 | print(jnp.max(jnp.abs(lora_output - orig_output))) 345 | ``` 346 | 347 | 0.0 348 | 349 | 350 | The LoRA result is identical to the execution of the unmodified network, and we didn't get any materialization warnings so we successfully made a LoRA forward pass without ever calculating $W + AB^T$! 351 | 352 | ## Training 353 | So far we haven't looked at how to train a model when using Qax. The main thing to understand is that you should apply `qax.use_implicit_args` first, _then_ differentiate the resulting function. `use_implicit_args` transforms the function into one which goes from pytrees to pytrees, so all the standard JAX autodiff machinery will work. 354 | 355 | If you need to update only a subset of the elements of an ImplicitArray instance (e.g. only `a` and `b` for LoRA), Qax provides `qax.utils.freeze_keys` to make this easier. 
Here's an end-to-end example training T5 to memorize the input/output pair from above: 356 | 357 | 358 | ```python 359 | optimizer = optax.adam(3e-4) 360 | # freeze_keys takes an optax optimizer, the ImplicitArray subclass to freeze for, 361 | # and an iterable of the keys to be frozen 362 | optimizer = qax.utils.freeze_keys(optimizer, LoraMatrix, ['w']) 363 | 364 | # We're only using a single example so we'll just close over the training data 365 | # There are no code changes from an ordinary training loop other than decorating 366 | # loss_fn with @use_implicit_args 367 | 368 | @qax.use_implicit_args 369 | def loss_fn(params): 370 | decoder_ids = decoder_inputs[:, :-1] 371 | targets = decoder_inputs[:, 1:] 372 | logits = t5( 373 | input_ids=encoder_inputs, 374 | decoder_input_ids=decoder_ids, 375 | params=params 376 | ).logits 377 | 378 | logprobs = jax.nn.log_softmax(logits) 379 | target_logprobs = jnp.take_along_axis(logprobs, targets[:, :, None], axis=-1) 380 | loss = -jnp.sum(target_logprobs) 381 | return loss 382 | 383 | grad_fn = jax.value_and_grad(loss_fn) 384 | 385 | @jax.jit 386 | def update(params, opt_state): 387 | loss, grads = grad_fn(params) 388 | updates, new_opt_state = optimizer.update(grads, opt_state, params=params) 389 | new_params = optax.apply_updates(params, updates) 390 | return loss, new_params, new_opt_state 391 | 392 | opt_state = optimizer.init(lora_params) 393 | for step in range(20): 394 | loss, lora_params, opt_state = update(lora_params, opt_state) 395 | print(f'{step}. {loss:.3f}') 396 | ``` 397 | 398 | 0. 8.882 399 | 1. 5.375 400 | 2. 3.787 401 | 3. 2.524 402 | 4. 1.491 403 | 5. 0.723 404 | 6. 0.242 405 | 7. 0.062 406 | 8. 0.022 407 | 9. 0.013 408 | 10. 0.011 409 | 11. 0.009 410 | 12. 0.008 411 | 13. 0.007 412 | 14. 0.007 413 | 15. 0.006 414 | 16. 0.005 415 | 17. 0.004 416 | 18. 0.003 417 | 19. 0.003 418 | 419 | 420 | That's all you need to know to get started using Qax! 421 | 422 | ## Example 3: Nesting 423 | Qax supports arbitrary nesting of `ImplicitArray` instances without any special handling. Here's a quick demo combining the previous two examples: 424 | 425 | 426 | ```python 427 | @qax.use_implicit_args 428 | def g(w, x): 429 | return jnp.sum(x @ w) 430 | 431 | w = jnp.ones((3, 5)) 432 | x = jnp.arange(3, dtype=jnp.float32) 433 | 434 | lora_with_symbolic_zero = LoraMatrix( 435 | w=w, 436 | a=Zeros(shape=(w.shape[0], 6)), 437 | b=Zeros(shape=(w.shape[1], 6)) 438 | ) 439 | print(f'Original: {g(w, x)}') 440 | with warnings.catch_warnings(): 441 | warnings.simplefilter('always') 442 | print(f'With lora: {g(lora_with_symbolic_zero, x)}') 443 | ``` 444 | 445 | Original: 15.0 446 | With lora: 15.0 447 | 448 | 449 | UserWarning: Primitive dot_general was not handled by class Zeros, so implicit args will be materialized. 450 | warnings.warn(f'Primitive {primitive.name} was not handled by class {vals[implicit_idx].__class__.__name__}, so implicit args will be materialized.') 451 | UserWarning: Primitive transpose was not handled by class Zeros, so implicit args will be materialized. 452 | warnings.warn(f'Primitive {primitive.name} was not handled by class {vals[implicit_idx].__class__.__name__}, so implicit args will be materialized.') 453 | 454 | 455 | If we wanted we could write a `dot_general` handler to avoid the materialization as well, but the main point is just to illustrate that it's easy to mix and match different `ImplicitArray` subclasses. 
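For reference, such a `dot_general` handler could look roughly like this (a sketch, not library code; it only covers the case where the `Zeros` is the right-hand operand, and the `transpose` materialization would still need its own handler):

```python
@qax.primitive_handler('dot_general')
def zeros_dot(primitive, lhs : jax.Array, rhs : Zeros, *, dimension_numbers, **kwargs):
    # Contracting anything against zeros gives zeros, so just work out the
    # output shape/dtype abstractly and return another symbolic Zeros.
    out_aval = jax.eval_shape(
        partial(qax.default_handler, primitive, dimension_numbers=dimension_numbers, **kwargs),
        jax.core.get_aval(lhs),
        rhs.aval,
    )
    return Zeros(shape=out_aval.shape, dtype=out_aval.dtype)
```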
A more useful example might be using a symbolic zero as the offset for a quantization datatypes which expects both an offset and a scale. 456 | 457 | ## Other examples 458 | [Here's](https://github.com/davisyoshida/abnormal-floats/blob/master/transform.py) an example of using Qax to implement a 4-bit quantized matrix representation. 459 | -------------------------------------------------------------------------------- /qax/implicit/implicit_array.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABC, abstractmethod 3 | from contextlib import contextmanager 4 | from contextvars import ContextVar 5 | from dataclasses import dataclass, field, fields, is_dataclass 6 | from functools import partial, wraps 7 | from itertools import chain 8 | from typing import ClassVar, Optional 9 | 10 | import jax 11 | import jax.extend.linear_util as lu 12 | import jax.interpreters.partial_eval as pe 13 | import jax.numpy as jnp 14 | from jax import core 15 | from jax._src.typing import DTypeLike, Shape 16 | from jax.api_util import flatten_fun, flatten_fun_nokwargs 17 | from jax.tree_util import register_pytree_with_keys_class 18 | 19 | from .. import constants 20 | from ..primitives import ArrayValue, get_primitive_handler 21 | from . import implicit_utils as iu 22 | 23 | 24 | def _with_implicit_flat(fun: lu.WrappedFun) -> lu.WrappedFun: 25 | # Splitting to avoid leaks based on https://github.com/google/jax/blob/0dffdf4645db7bf7a9fadd4bcfe9ec0368a8ecb9/jax/_src/interpreters/batching.py#L539 26 | f = _implicit_inner(fun) 27 | return _implicit_outer(f) 28 | 29 | 30 | @lu.transformation 31 | def _implicit_outer(*in_vals): 32 | with core.new_main(ImplicitArrayTrace) as main: 33 | outs = yield (main, *in_vals), {} 34 | del main 35 | yield outs 36 | 37 | 38 | @lu.transformation 39 | def _implicit_inner(main, *in_vals): 40 | trace = main.with_cur_sublevel() 41 | in_tracers = [ 42 | ImplicitArrayTracer(trace, val) if isinstance(val, ImplicitArray) else val 43 | for val in in_vals 44 | ] 45 | outs = yield in_tracers, {} 46 | out_vals = [trace.full_raise(t).value for t in outs] 47 | yield out_vals 48 | 49 | 50 | def use_implicit_args(f): 51 | """ 52 | Decorator which allows a function to accept arguments which subclass ImplicitArray, possibly 53 | including further ImplicitArray instances as children. 54 | Any number of arguments (including 0) may be ImplicitArrays. 55 | """ 56 | 57 | @wraps(f) 58 | def implicit_f(*args, **kwargs): 59 | flat_args, in_tree = iu.tree_flatten_with_implicit((args, kwargs)) 60 | f_flat, out_tree = flatten_fun(lu.wrap_init(f), in_tree) 61 | f_wrapped = _with_implicit_flat(f_flat) 62 | outs_flat = f_wrapped.call_wrapped(*flat_args) 63 | return out_tree().unflatten(outs_flat) 64 | 65 | return implicit_f 66 | 67 | 68 | def aux_field(metadata=None, **kwargs): 69 | metadata = dict(metadata) if metadata else {} 70 | metadata["implicit_array_aux"] = True 71 | return field(metadata=metadata, **kwargs) 72 | 73 | 74 | class UninitializedAval(Exception): 75 | def __init__(self, kind): 76 | super().__init__(_AVAL_ERROR_MESSAGE.format(kind)) 77 | 78 | 79 | # This descriptor and the below context manager support discovering the aval 80 | # of an ImplicitArray. 
We don't want to throw an error just because a shape 81 | # wasn't passed, since it may be possible to infer it via materialization 82 | class _AvalDescriptor: 83 | def __set_name__(self, owner, name): 84 | self._name = f"_{name}" 85 | 86 | def __get__(self, obj, owner=None): 87 | if obj is None: 88 | return None 89 | result = getattr(obj, self._name, None) 90 | if result is None: 91 | raise UninitializedAval(kind=self._name[1:]) 92 | return result 93 | 94 | def __set__(self, obj, value): 95 | setattr(obj, self._name, value) 96 | 97 | 98 | # Context manager used for disabling UninitializedAval errors 99 | # during tree flattening only 100 | _aval_discovery = ContextVar("aval_discovery", default=False) 101 | 102 | 103 | @contextmanager 104 | def _aval_discovery_context(): 105 | token = _aval_discovery.set(True) 106 | try: 107 | yield 108 | finally: 109 | _aval_discovery.reset(token) 110 | 111 | 112 | @dataclass 113 | class _ImplicitArrayBase(ArrayValue, ABC): 114 | commute_ops: ClassVar[bool] = True 115 | warn_on_materialize: ClassVar[bool] = True 116 | default_shape: ClassVar[Optional[Shape]] = None 117 | default_dtype: ClassVar[Optional[DTypeLike]] = None 118 | 119 | shape: Optional[Shape] = aux_field(kw_only=True, default=None) 120 | dtype: DTypeLike = aux_field(kw_only=True, default=None) 121 | 122 | 123 | @dataclass 124 | class ImplicitArray(_ImplicitArrayBase): 125 | """ 126 | Abstract class for representing an abstract array of a given shape/dtype without actually instantiating it. 127 | Subclasses must implement the materialize method, which defines the relationship between the implicit array 128 | and the value it represents. Subclasses are valid arguments to functions decorated with qax.use_implicit_args. 129 | 130 | All subclasses are automatically registered as pytrees using jax.tree_util.register_pytree_with_keys_class. 131 | Any dataclass attributes added will be included as children, unless they are decorated with qax.aux_field 132 | in which case they are passed as auxiliary data during flattening. 133 | 134 | The represented shape and dtype may be defined in any of the following ways: 135 | - Explicitly passing shape/dtype keyword arguments at initialization 136 | - Overriding the default_shape/default_dtype class variables 137 | - Overriding the compute_shape/compute_dtype methods, which are called during __post_init__ 138 | - Overriding __post_init__ and manually setting shape/dtype before calling super().__post_init__ 139 | - None of the above, in which case the shape/dtype will be inferred by running jax.eval_shape() 140 | on the subclass's materialize method. 
141 | """ 142 | 143 | shape = _AvalDescriptor() 144 | dtype = _AvalDescriptor() 145 | 146 | def __post_init__(self): 147 | try: 148 | aval = _get_materialization_aval(self) 149 | except UninitializedAval: 150 | # Materialization depends on currently uninitialized shape/dtype 151 | aval = None 152 | 153 | shape = None 154 | try: 155 | shape = self.shape 156 | except UninitializedAval as e: 157 | shape = self.shape = self.compute_shape() 158 | 159 | if aval is not None: 160 | if shape is None: 161 | self.shape = aval.shape 162 | elif shape != aval.shape: 163 | warnings.warn( 164 | f"ImplicitArray shape {shape} does not match materialization shape {aval.shape}" 165 | ) 166 | elif shape is None: 167 | raise UninitializedAval("shape") 168 | 169 | dtype = None 170 | try: 171 | dtype = self.dtype 172 | except UninitializedAval as e: 173 | dtype = self.dtype = self.compute_dtype() 174 | 175 | if dtype is None and aval is None: 176 | # We have a shape but not a dtype, try once again to infer the dtype 177 | aval = _get_materialization_aval(self) 178 | 179 | if aval is not None: 180 | if dtype is None: 181 | self.dtype = aval.dtype 182 | elif dtype != aval.dtype: 183 | warnings.warn( 184 | f"ImplicitArray dtype {dtype} does not match materialization dtype {aval.dtype}" 185 | ) 186 | elif dtype is None: 187 | raise UninitializedAval("dtype") 188 | 189 | def compute_shape(self): 190 | """ 191 | Override this method if the subclass instance's shape should be computed based on its other properties. 192 | Returns: shape 193 | """ 194 | return self.default_shape 195 | 196 | def compute_dtype(self): 197 | """ 198 | Override this method if the subclass instance's dtype should be computed based on its other properties. 199 | Returns: dtype 200 | """ 201 | return self.default_dtype 202 | 203 | @property 204 | def aval(self): 205 | return core.ShapedArray(self.shape, self.dtype) 206 | 207 | @classmethod 208 | def default_handler(cls, primitive, *args, params=None): 209 | if params is None: 210 | params = {} 211 | return materialize_handler(primitive, *args, params=params) 212 | 213 | @abstractmethod 214 | def materialize(self): 215 | pass 216 | 217 | def tree_flatten_with_keys(self): 218 | children = [] 219 | aux_data = [] 220 | for name, is_aux in _get_names_and_aux(self): 221 | try: 222 | value = getattr(self, name) 223 | except UninitializedAval: 224 | if not _aval_discovery.get(): 225 | raise 226 | value = None 227 | if is_aux: 228 | aux_data.append(value) 229 | else: 230 | children.append((jax.tree_util.GetAttrKey(name), value)) 231 | 232 | return children, aux_data 233 | 234 | @classmethod 235 | def tree_unflatten(cls, aux_data, children): 236 | child_it = iter(children) 237 | aux_it = iter(aux_data) 238 | obj = cls.__new__(cls) 239 | for name, is_aux in _get_names_and_aux(cls): 240 | value = next(aux_it if is_aux else child_it) 241 | setattr(obj, name, value) 242 | 243 | return obj 244 | 245 | def handle_primitive(self, primitive, *args, params): 246 | handler = lu.wrap_init(partial(get_primitive_handler(primitive), primitive)) 247 | use_params = params 248 | 249 | if len(args) == 2 and self.commute_ops: 250 | args, use_params = _maybe_swap_args(primitive.name, args, use_params) 251 | 252 | # maybe_kwargs = {'params': params} if params else {} 253 | flat_args, in_tree = iu.flatten_one_implicit_layer((args, params)) 254 | flat_handler, out_tree = flatten_fun(handler, in_tree) 255 | 256 | result = use_implicit_args(flat_handler.call_wrapped)(*flat_args) 257 | return 
jax.tree_util.tree_unflatten(out_tree(), result) 258 | 259 | def __init_subclass__(cls, commute_ops=True, warn_on_materialize=True, **kwargs): 260 | super().__init_subclass__(**kwargs) 261 | cls.commute_ops = commute_ops 262 | cls.warn_on_materialize = warn_on_materialize 263 | 264 | if not is_dataclass(cls): 265 | raise TypeError(f"{cls.__name__} must be a dataclass") 266 | core.pytype_aval_mappings[cls] = lambda x: x.aval 267 | register_pytree_with_keys_class(cls) 268 | return cls 269 | 270 | 271 | def _get_names_and_aux(obj): 272 | for val in fields(obj): 273 | yield val.name, bool(val.metadata.get("implicit_array_aux")) 274 | 275 | 276 | def _materialize_all(it): 277 | return [ 278 | iu.materialize_nested(val) if isinstance(val, ImplicitArray) else val 279 | for val in it 280 | ] 281 | 282 | 283 | def _maybe_swap_args(op_name, args, params): 284 | if isinstance(args[0], ImplicitArray): 285 | return args, params 286 | if op_name in constants.COMMUTATIVE_OPS: 287 | return args[::-1], params 288 | 289 | return args, params 290 | 291 | 292 | class ImplicitArrayTracer(core.Tracer): 293 | def __init__(self, trace, value): 294 | super().__init__(trace) 295 | self.value = value 296 | 297 | @property 298 | def aval(self): 299 | if isinstance(self.value, ImplicitArray): 300 | return self.value.aval 301 | return core.get_aval(self.value) 302 | 303 | def full_lower(self): 304 | if isinstance(self.value, ImplicitArray): 305 | return self 306 | 307 | return core.full_lower(self.value) 308 | 309 | 310 | class ImplicitArrayTrace(core.Trace): 311 | pure = lift = lambda self, val: ImplicitArrayTracer(self, val) 312 | 313 | def process_primitive(self, primitive, tracers, params): 314 | outs = NotImplemented 315 | vals = [t.value for t in tracers] 316 | implicit_idx = next( 317 | (i for i, v in enumerate(vals) if isinstance(v, ImplicitArray)), None 318 | ) 319 | 320 | if implicit_idx is None: 321 | # No tracers, so just do default evaluation: 322 | subfuns, bind_params = primitive.get_bind_params(params) 323 | result = primitive.bind(*subfuns, *vals, **bind_params) 324 | return result 325 | 326 | outs = vals[implicit_idx].handle_primitive(primitive, *vals, params=params) 327 | 328 | if outs is NotImplemented: 329 | # For higher order primitives most users won't implement custom 330 | # logic, so there shouldn't be a warning 331 | if primitive.name in _default_handlers: 332 | outs = _default_handlers[primitive.name]( 333 | primitive, *vals, params=params 334 | ) 335 | else: 336 | implicit_cls = vals[implicit_idx].__class__ 337 | if implicit_cls.warn_on_materialize: 338 | warnings.warn( 339 | f"Primitive {primitive.name} was not handled by class {implicit_cls.__name__}, so implicit args will be materialized." 
340 | ) 341 | 342 | if outs is NotImplemented: 343 | outs = vals[implicit_idx].default_handler(primitive, *vals, params=params) 344 | 345 | if primitive.multiple_results: 346 | return [ImplicitArrayTracer(self, out) for out in outs] 347 | return ImplicitArrayTracer(self, outs) 348 | 349 | 350 | def wrap_jaxpr(jaxpr, vals_with_implicits, return_closed=True): 351 | if isinstance(jaxpr, jax.core.ClosedJaxpr): 352 | literals = jaxpr.literals 353 | jaxpr = jaxpr.jaxpr 354 | else: 355 | literals = [] 356 | 357 | wrapped_fn = lu.wrap_init(use_implicit_args(partial(core.eval_jaxpr, jaxpr))) 358 | flat_args, in_tree = jax.tree_util.tree_flatten((literals, *vals_with_implicits)) 359 | flat_fn, out_tree = flatten_fun_nokwargs(wrapped_fn, in_tree) 360 | 361 | new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( 362 | flat_fn, [core.get_aval(v) for v in flat_args] 363 | ) 364 | 365 | ret = ( 366 | (jax.core.ClosedJaxpr(new_jaxpr, consts),) 367 | if return_closed 368 | else (new_jaxpr, consts) 369 | ) 370 | return *ret, flat_args, out_tree() 371 | 372 | 373 | def _transform_jaxpr_output(jaxpr, jaxpr_args, orig_out_struct, out_transform): 374 | def eval_fn(literals, *args): 375 | output = use_implicit_args(partial(core.eval_jaxpr, jaxpr.jaxpr))( 376 | literals, *args 377 | ) 378 | unflattened_output = orig_out_struct.unflatten(output) 379 | return out_transform(unflattened_output) 380 | 381 | wrapped = lu.wrap_init(eval_fn) 382 | 383 | flat_args, in_tree = jax.tree_util.tree_flatten((jaxpr.literals, *jaxpr_args)) 384 | flat_fn, out_tree = flatten_fun_nokwargs(wrapped, in_tree) 385 | new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( 386 | flat_fn, [core.get_aval(v) for v in flat_args] 387 | ) 388 | 389 | return jax.core.ClosedJaxpr(new_jaxpr, consts), out_tree() 390 | 391 | 392 | def _match_branches(branches, arg_vals): 393 | out_avals = [] 394 | new_jaxprs = [] 395 | flat_inputs = None 396 | branch_out_struct = None 397 | for branch in branches: 398 | new_jaxpr, flat_inputs, branch_out_struct = wrap_jaxpr(branch, arg_vals) 399 | new_jaxprs.append((new_jaxpr, branch_out_struct)) 400 | out_avals.append( 401 | branch_out_struct.unflatten( 402 | jax.eval_shape( 403 | partial(core.eval_jaxpr, new_jaxpr.jaxpr), 404 | new_jaxpr.literals, 405 | *flat_inputs, 406 | ) 407 | ) 408 | ) 409 | 410 | out_transforms = iu.get_common_prefix_transforms(out_avals) 411 | new_branches = [] 412 | out_struct = None 413 | for (new_jaxpr, orig_out_struct), transform in zip(new_jaxprs, out_transforms): 414 | new_jaxpr, out_struct = _transform_jaxpr_output( 415 | new_jaxpr, flat_inputs, orig_out_struct, transform 416 | ) 417 | new_branches.append(new_jaxpr) 418 | 419 | return tuple(new_branches), out_struct, flat_inputs 420 | 421 | 422 | def _handle_cond(primitive, *vals, params): 423 | cond_val, *arg_vals = vals 424 | subfuns, bind_params = primitive.get_bind_params(params) 425 | 426 | new_branches, out_struct, flat_inputs = _match_branches( 427 | params["branches"], arg_vals 428 | ) 429 | bind_params["branches"] = new_branches 430 | bind_params["linear"] = _broadcast_tuple(bind_params["linear"], arg_vals) 431 | 432 | outs = primitive.bind(*subfuns, cond_val, *flat_inputs, **bind_params) 433 | return jax.tree_util.tree_unflatten(out_struct, outs) 434 | 435 | 436 | def _handle_remat2(primitive, *vals, params): 437 | subfuns, bind_params = primitive.get_bind_params(params) 438 | new_jaxpr, consts, flat_inputs, out_tree = wrap_jaxpr( 439 | bind_params["jaxpr"], vals, return_closed=False 440 | ) 441 | new_jaxpr = 
pe.convert_constvars_jaxpr(new_jaxpr) 442 | bind_params["jaxpr"] = new_jaxpr 443 | outs = primitive.bind(*subfuns, *consts, *flat_inputs, **bind_params) 444 | return jax.tree_util.tree_unflatten(out_tree, outs) 445 | 446 | 447 | def _handle_pjit(primitive, *vals, params): 448 | new_jaxpr, flat_inputs, out_tree = wrap_jaxpr(params["jaxpr"], vals) 449 | donated_invars = _broadcast_tuple(params["donated_invars"], vals) 450 | in_shardings = _broadcast_tuple(params["in_shardings"], vals) 451 | out_shardings = _broadcast_tuple(params["out_shardings"], out_tree) 452 | 453 | subfuns, bind_params = primitive.get_bind_params(params) 454 | bind_params["jaxpr"] = new_jaxpr 455 | bind_params["donated_invars"] = donated_invars 456 | bind_params["in_shardings"] = in_shardings 457 | bind_params["out_shardings"] = out_shardings 458 | outs = primitive.bind(*subfuns, *flat_inputs, **bind_params) 459 | return jax.tree_util.tree_unflatten(out_tree, outs) 460 | 461 | 462 | def _handle_scan(primitive, *vals, params): 463 | n_consts = params["num_consts"] 464 | n_carry = params["num_carry"] 465 | 466 | consts = vals[:n_consts] 467 | real_n_consts = len(jax.tree_util.tree_leaves(consts)) 468 | 469 | carries = vals[n_consts : n_consts + n_carry] 470 | xs = vals[n_consts + n_carry :] 471 | 472 | if any(isinstance(c, ImplicitArray) for c in carries): 473 | warnings.warn( 474 | "ImplicitArray in scan carries are not yet supported." 475 | " If you need this feature please open an issue on the Qax repo:" 476 | " https://github.com/davisyoshida/qax/issues" 477 | ) 478 | carries = _materialize_all(carries) 479 | 480 | sliced_xs = jax.tree_map(partial(jax.eval_shape, lambda x: x[0]), xs) 481 | 482 | for x in sliced_xs: 483 | if isinstance(x, ImplicitArray): 484 | assert len(x._shape) > 0, "Attempted to scan over a scalar." 
485 | x._shape = x._shape[1:] 486 | 487 | jaxpr = params["jaxpr"] 488 | new_jaxpr, _, out_tree = wrap_jaxpr( 489 | jaxpr=jaxpr, 490 | vals_with_implicits=(*consts, *carries, *sliced_xs), 491 | return_closed=True, 492 | ) 493 | 494 | flat_inputs = jax.tree_util.tree_leaves((jaxpr.literals, *consts, *carries, *xs)) 495 | 496 | subfuns, bind_params = primitive.get_bind_params(params) 497 | bind_params["jaxpr"] = new_jaxpr 498 | bind_params["num_consts"] = real_n_consts 499 | bind_params["num_carry"] = len(carries) 500 | bind_params["linear"] = _broadcast_tuple(params["linear"], vals) 501 | 502 | outs = primitive.bind(*subfuns, *flat_inputs, **bind_params) 503 | return jax.tree_util.tree_unflatten(out_tree, outs) 504 | 505 | 506 | _default_handlers = { 507 | "cond": _handle_cond, 508 | "remat2": _handle_remat2, 509 | "pjit": _handle_pjit, 510 | "scan": _handle_scan, 511 | } 512 | 513 | 514 | def materialize_handler(primitive, *vals, params): 515 | vals = _materialize_all(vals) 516 | subfuns, bind_params = primitive.get_bind_params(params) 517 | result = use_implicit_args(primitive.bind)(*subfuns, *vals, **bind_params) 518 | return result 519 | 520 | 521 | def _broadcast_tuple(t, trees): 522 | if isinstance(trees, jax.tree_util.PyTreeDef): 523 | trees = jax.tree_util.tree_unflatten(trees, range(trees.num_leaves)) 524 | assert len(t) == len(trees) 525 | return tuple( 526 | chain.from_iterable( 527 | (tuple_val for _ in jax.tree_util.tree_leaves(tree)) 528 | for tuple_val, tree in zip(t, trees) 529 | ) 530 | ) 531 | 532 | 533 | def _get_materialization_aval(imp_arr): 534 | with _aval_discovery_context(), _filter_materialization_warnings(): 535 | result = jax.eval_shape(partial(iu.materialize_nested, full=True), imp_arr) 536 | return result 537 | 538 | 539 | @contextmanager 540 | def _filter_materialization_warnings(): 541 | with warnings.catch_warnings(): 542 | warnings.filterwarnings("ignore", message="Primitive.*was not handled") 543 | yield 544 | 545 | 546 | _AVAL_ERROR_MESSAGE = ( 547 | "{} was not set during initialization. Shape and dtype may be set by:" 548 | "\n\t1. Directly passing them as keyword arguments to ImplicitArray instances" 549 | "\n\t2. Overriding the default_shape/default_dtype class attributes" 550 | "\n\t3. Overriding the compute_shape/compute_dtype methods" 551 | "\n\t4. Overriding __post_init__ and setting their values there" 552 | "\n\t5. None of the above, in which case `materialize()` will be called in an attempt to infer them." 553 | " If their values are required in order to compute the materialization this will be unsuccessful." 
554 | ) 555 | -------------------------------------------------------------------------------- /examples/How_to_Qax.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "27a0df77", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install qax\n", 11 | "!pip install transformers" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "5c74d7d5", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from dataclasses import dataclass\n", 22 | "from functools import partial\n", 23 | "import warnings\n", 24 | "\n", 25 | "import jax\n", 26 | "import jax.numpy as jnp\n", 27 | "import optax\n", 28 | "import transformers\n", 29 | "import qax" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "dad94c8b", 35 | "metadata": {}, 36 | "source": [ 37 | "# Qax: If it quacks like a tensor...\n", 38 | "[Qax](https://github.com/davisyoshida/qax) is a tool for implementing types which represent tensors, but aren't actually instantiated as a single dense array on your GPU. Examples of this include:\n", 39 | "* Quantization: A 4-bit array of integers + a small number of scale values are used to represent a full 16/32-bit array\n", 40 | "* LoRA: An array $W$ is replaced by the array $(W + BA^T)$ so that $A$ and $B$ may be trained while leaving $W$ frozen\n", 41 | "* Symbolic zeros/constants: For arrays which will consist entirely of a single repeated value, simply store that single value and the shape of the array\n", 42 | "* Custom kernels: If you have a custom kernel and want to use it with existing models without modifying them, Qax is an easy way to do so\n", 43 | "* Hopefully many more things!\n", 44 | "\n", 45 | "The goal of Qax is to make implementing custom JAX behavior much easier, so that users won't need to deal with all the details of writing a full JAX transform. All you need to do to get custom representations is:\n", 46 | "\n", 47 | "1. Define what data/metadata your datatype should contain\n", 48 | "2. Optionally write any number of handlers which specify how your type behaves under JAX primitives such as multiplication\n", 49 | "3. Write a function which constructs a dense array from your implicit representation\n", 50 | "\n", 51 | "Both of the above are written in pure JAX, so no need for custom gradients (unless you want to of course!)." 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "63481ec3", 57 | "metadata": {}, 58 | "source": [ 59 | "## Installation\n", 60 | "```\n", 61 | "pip install qax\n", 62 | "```" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "7d2b7bd4", 68 | "metadata": {}, 69 | "source": [ 70 | "## Example 1: A symbolic zero\n", 71 | "The way you specify custom behavior with Qax is to subclass the `qax.ImplicitArray` abstract class. One of the simplest things we could implement is a symbolic zero: A data type which represents an arbitrary tensor full of zeros without actually instantiating them on the GPU." 
72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "a0549d40", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "class Zeros(qax.ImplicitArray):\n", 82 | " default_dtype = jnp.float32\n", 83 | "\n", 84 | " def materialize(self):\n", 85 | " # self.shape and self.dtype will be\n", 86 | " # populated by the ImplicitArray constructor\n", 87 | " return jnp.zeros(self.shape, self.dtype)\n", 88 | " \n", 89 | " def __str__(self):\n", 90 | " return f'Zeros({self.shape}, {self.dtype})'" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "id": "69327946", 96 | "metadata": {}, 97 | "source": [ 98 | "The only mandatory method to implement when subclassing `ImplicitArray` is `materialize()`.\n", 99 | "`materialize()` specifies how to turn our _implicitly_ represented array into an _explicit_ one, i.e. a single dense JAX array. In the case of `Zeros`, we can just call `jnp.zeros`.\n", 100 | "\n", 101 | "Let's instantiate a `Zeros` instance to try it out:" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "709e7de7", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "z = Zeros(shape=(2, 3))" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "2fe010e2", 117 | "metadata": {}, 118 | "source": [ 119 | "ImplicitArrays are [dataclasses](https://docs.python.org/3/library/dataclasses.html), which by default have two keyword only attributes: `shape` and `dtype`.\n", 120 | "\n", 121 | "By default JAX won't know how to use our new type. In order to use it in functions, we apply the `@use_implicit_args` decorator:" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "58255519", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "@qax.use_implicit_args\n", 132 | "def f(x, y):\n", 133 | " return (x + y)[0, 0]" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "a501fa3a", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "with warnings.catch_warnings():\n", 144 | " warnings.simplefilter('always')\n", 145 | " print(f(z, jnp.ones(3)))" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "id": "9b57c0fa", 151 | "metadata": {}, 152 | "source": [ 153 | "The cool thing is that `f` doesn't need to have any idea that it will be called with `ImplicitArray` instances, so we can use this with any pre-existing model. Right now this isn't much use, since all `z` is being materialized into a dense array as soon as it's needed for a JAX operation.\n", 154 | "\n", 155 | "To make our `Zeros` do something productive, let's implement the fact that $x + 0$ is always equal to $x$. 
We do this using the `@qax.primitive_handler` decorator:" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "4030910f", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "def get_binop_result_shape_dtype(a, b):\n", 166 | " out_shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b))\n", 167 | " out_dtype = jnp.result_type(a.dtype, b.dtype)\n", 168 | " return out_shape, out_dtype\n", 169 | "\n", 170 | "# primitive_handler() takes a string, JAX primitive, or a list of those types\n", 171 | "# strings are used to find the corresponding primitive from `jax.lax`\n", 172 | "@qax.primitive_handler('add')\n", 173 | "def my_add_handler(primitive, a : Zeros, b):\n", 174 | " # Handlers will receive as arguments:\n", 175 | " # - primitive: a jax.core.Primitive instance (often can be ignored if the handler is just for one op)\n", 176 | " # Any number of arguments which are either JAX values or ImplicitArrays\n", 177 | " # Keyword arguments specifying parameters of the operation (e.g. axes for reduction operations)\n", 178 | " \n", 179 | " out_shape, out_dtype = get_binop_result_shape_dtype(a, b)\n", 180 | " \n", 181 | " if isinstance(b, Zeros):\n", 182 | " # We can return further ImplicitArray instances if we want\n", 183 | " return Zeros(shape=out_shape, dtype=out_dtype)\n", 184 | " \n", 185 | " # Return b, possibly modifying its shape or dtype\n", 186 | " return jnp.broadcast_to(b, out_shape).astype(out_dtype)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "2e0d68e9", 192 | "metadata": {}, 193 | "source": [ 194 | "The type annotation `a : Zeros` is actually important, Qax uses [Plum](https://github.com/beartype/plum) for multiple dispatch. You can even use this to define how different subclasses of ImplicitArray should interact with each other.\n", 195 | "\n", 196 | "(For convenience, commutative binary ops like $+$ and $\\times$ will automatically get their argument order switched so that the `ImplicitArray` instance comes first.)\n", 197 | "\n", 198 | "Now when we call `f`, we no longer see the materialization log message, since our add handler is skipping over ever instantiating the array of zeros:" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "1e28b598", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "print(f(z, jnp.ones(3)))" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "0fe4a67d", 214 | "metadata": {}, 215 | "source": [ 216 | "Let's define a multiplication handler as well, since $x \\cdot 0 = 0$ for all $x$:" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "id": "a313cd62", 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "@qax.primitive_handler('mul')\n", 227 | "def handle_mul(primitive, a : Zeros, b):\n", 228 | " out_shape, out_dtype = get_binop_result_shape_dtype(a, b)\n", 229 | " \n", 230 | " return Zeros(shape=out_shape, dtype=out_dtype)\n", 231 | "\n", 232 | "\n", 233 | "@jax.jit\n", 234 | "@qax.use_implicit_args\n", 235 | "def g(x, y):\n", 236 | " return (1 + x) * y\n", 237 | "\n", 238 | "print(g(z, z))" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "id": "9ad2027d", 244 | "metadata": {}, 245 | "source": [ 246 | "The output of `use_implicit_args` is a function which is compatible with all the usual JAX transformations such as `jit`, `vmap`, `grad`, etc.\n", 247 | "\n", 248 | "Even this simple implementation is enough to let us modify the 
behavior of models which were written without knowing about Qax. Let's try replacing all the biases in HuggingFace's GPT-2 with zeros:" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "id": "1682fd45", 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "@qax.primitive_handler('broadcast_in_dim')\n", 259 | "def broadcast(primitive, a : Zeros, *, shape, broadcast_dimensions):\n", 260 | " # The biases get broadcast in order to add them to the activations\n", 261 | " # so we need to handle that case\n", 262 | " # Sometimes the simplest thing to do is use jax.eval_shape\n", 263 | " # to figure out what shape to return\n", 264 | " result_shape = jax.eval_shape(\n", 265 | " partial(jax.lax.broadcast_in_dim, shape=shape, broadcast_dimensions=broadcast_dimensions),\n", 266 | " a.aval # ImplicitArray has an aval property which will get an abstract shape/dtype\n", 267 | " ).shape\n", 268 | " return Zeros(shape=result_shape, dtype=a.dtype)\n", 269 | " \n", 270 | "\n", 271 | "model, params = transformers.FlaxAutoModelForCausalLM.from_pretrained('gpt2', _do_init=False)\n", 272 | "\n", 273 | "inputs = jnp.arange(1, 10)[None]\n", 274 | "\n", 275 | "# Helper function to switch all the biases\n", 276 | "# in the params out for some other value\n", 277 | "def replace_biases(params, replacer):\n", 278 | " def maybe_replace_val(path, val):\n", 279 | " if val.ndim != 1:\n", 280 | " return val\n", 281 | "\n", 282 | " # Skip layernorms\n", 283 | " if any(\n", 284 | " isinstance(p, jax.tree_util.DictKey) and p.key.startswith('ln')\n", 285 | " for p in path\n", 286 | " ):\n", 287 | " return val\n", 288 | " return replacer(shape=val.shape, dtype=val.dtype)\n", 289 | " return jax.tree_util.tree_map_with_path(maybe_replace_val, params)\n", 290 | "\n", 291 | "\n", 292 | "# Replace the biases with dense zero arrays:\n", 293 | "params_with_zeros = replace_biases(params, jnp.zeros)\n", 294 | "print('New bias:', params_with_zeros['transformer']['h']['0']['attn']['c_attn']['bias'])\n", 295 | "\n", 296 | "output = model(inputs, params=params_with_zeros).logits\n", 297 | "print('Last logit average:', jnp.mean(output[0, -1]))" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "id": "2fb9a410", 303 | "metadata": {}, 304 | "source": [ 305 | "Now let's try replacing them with our symbolic zeros instead:" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "id": "3cc5c534", 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "params_with_zeros = replace_biases(params, Zeros)\n", 316 | "print('New bias:', params_with_zeros['transformer']['h']['0']['attn']['c_attn']['bias'])\n", 317 | "\n", 318 | "# In this case since we're calling the model directly, we need to\n", 319 | "# wrap it so we can pass params in a positional argument\n", 320 | "# This usually won't be an issue since the call to the model will\n", 321 | "# be inside a loss function or some other function\n", 322 | "\n", 323 | "output = qax.use_implicit_args(model)(inputs, params=params_with_zeros).logits\n", 324 | "print('Last logit average:', jnp.mean(output[0, -1]))" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "id": "b219ec93", 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "del model\n", 335 | "del params" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "id": "fc84e060", 341 | "metadata": {}, 342 | "source": [ 343 | "We got the same result, but using 0 FLOPs for adding the biases! 
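\n\nIf you want to double check a claim like this yourself, one rough option is XLA's compiled cost analysis. This is just a sketch rather than part of Qax, and the availability and format of `cost_analysis()` output depend on your JAX version and backend:\n\n```python\n# Sketch: ask XLA for its FLOP estimate of the toy function `f` from earlier\n# (cost_analysis() may be empty or absent on some backends/versions)\ncompiled = jax.jit(f).lower(z, jnp.ones(3)).compile()\nprint(compiled.cost_analysis())  # look for a 'flops' entry if one is reported\n```\n\n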
If you really wanted to flesh out the behavior of `Zeros`, you could also add handlers defining its output for primitives such as `sin`, `cos`, etc. Let's move on to something more interesting though." 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "id": "47861972", 349 | "metadata": {}, 350 | "source": [ 351 | "## Example 2: LoRA\n", 352 | "In this example we'll implement [LoRA](https://arxiv.org/abs/2106.09685) in just a few lines of code. Unlike the `Zeros` example from the previous section, our `ImplicitArray` subclass will actually contain data this time. Even so, we won't need to implement any flattening/unflattening logic, since all `ImplicitArray` subclasses are automatically pytrees. This also means you can use `tree_map` and friends to manipulate them.\n", 353 | "\n", 354 | "To add child pytrees to a subclass, we just add them as dataclass attributes. To add auxiliary data, you can wrap a field with `qax.aux_field` which is just a wrapper around `dataclasses.field`.\n", 355 | "\n", 356 | "LoRA replaces a matrix $W$ with the matrix $W + AB^T$, so we'll have three arrays as new attributes." 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "id": "14caac69", 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "@dataclass\n", 367 | "class LoraMatrix(qax.ImplicitArray):\n", 368 | " \"\"\"Represent W + A B^T\"\"\"\n", 369 | " w : qax.ArrayValue\n", 370 | " a : qax.ArrayValue\n", 371 | " b : qax.ArrayValue\n", 372 | " \n", 373 | " # auxiliary data example\n", 374 | " is_array_happy : bool = qax.aux_field(default=True)\n", 375 | " \n", 376 | " def __post_init__(self):\n", 377 | " # If you need to do any validation, you can override the __post_init__ method\n", 378 | " # This example is purely for error checking, but you can also\n", 379 | " # add manipulations of the attributes\n", 380 | " super().__post_init__()\n", 381 | " w_aval = jax.core.get_aval(self.w)\n", 382 | " a_aval = jax.core.get_aval(self.a)\n", 383 | " b_aval = jax.core.get_aval(self.b)\n", 384 | " assert w_aval.ndim == a_aval.ndim == b_aval.ndim == 2\n", 385 | " assert a_aval.shape[1] == b_aval.shape[1]\n", 386 | " assert a_aval.shape[0] == w_aval.shape[0]\n", 387 | " assert b_aval.shape[0] == w_aval.shape[1]\n", 388 | " assert a_aval.dtype == b_aval.dtype == w_aval.dtype\n", 389 | "\n", 390 | " def materialize(self):\n", 391 | " return self.w + self.a @ self.b.T\n", 392 | "\n", 393 | "@qax.primitive_handler('dot_general')\n", 394 | "def f(primitive, x : jax.Array, w : LoraMatrix, *, dimension_numbers, **kwargs):\n", 395 | " # For this example, we'll only handle the simple case of x @ w, rather than\n", 396 | " # all possible dot_general invocations\n", 397 | " (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers\n", 398 | " \n", 399 | " # This check just makes sure that all that's happening is a simple matmul\n", 400 | " if not (\n", 401 | " len(w.shape) == 2\n", 402 | " and lhs_contract == (x.ndim - 1,)\n", 403 | " and rhs_contract == (0,)\n", 404 | " and lhs_batch == ()\n", 405 | " and rhs_batch == ()\n", 406 | " ):\n", 407 | " # If we want to only partially handle a particular primitive,\n", 408 | " # we can fall back to the default logic by returning NotImplemented\n", 409 | " return NotImplemented\n", 410 | "\n", 411 | " kwargs = {**kwargs, 'dimension_numbers': dimension_numbers}\n", 412 | " # In order to defer to the default implementation of the primitive,\n", 413 | " # use the qax.default_handler helper:\n", 414 | " result = 
qax.default_handler(\n", 415 | " primitive, # pass the primitive\n", 416 | " x, w.w, # Any number of positional arguments,\n", 417 | " **kwargs # Then the primitive's keyword args \n", 418 | " )\n", 419 | " \n", 420 | " xa = qax.default_handler(primitive, x, w.a, **kwargs)\n", 421 | " \n", 422 | " xab = qax.default_handler(primitive, xa, w.b.T, **kwargs)\n", 423 | "\n", 424 | " result += xab\n", 425 | " return result\n", 426 | "\n", 427 | "def lora_from_tree(tree, key, lora_dim=8):\n", 428 | " \"\"\"\n", 429 | " Helper function for replacing non-embedding weight\n", 430 | " matrices in T5 with LoraMatrix instances.\n", 431 | " \"\"\"\n", 432 | " def iter_keys(key):\n", 433 | " while True:\n", 434 | " key, k2 = jax.random.split(key)\n", 435 | " yield k2\n", 436 | " \n", 437 | " key_it = iter_keys(key)\n", 438 | " def map_fn(path, val):\n", 439 | " if val.ndim != 2:\n", 440 | " return val\n", 441 | " \n", 442 | " # Skip embedding params\n", 443 | " if any(isinstance(p, jax.tree_util.DictKey) and p.key == 'embedding' for p in path):\n", 444 | " return val\n", 445 | " \n", 446 | " a = jax.random.normal(next(key_it), (val.shape[0], lora_dim), val.dtype)\n", 447 | " b = jnp.zeros((val.shape[1], lora_dim), val.dtype)\n", 448 | " return LoraMatrix(val, a, b) \n", 449 | " \n", 450 | " return jax.tree_util.tree_map_with_path(map_fn, tree)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "id": "ce2a0958", 456 | "metadata": {}, 457 | "source": [ 458 | "Let's try it out on a T5 model:" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "id": "26e3891e", 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "t5, params = transformers.FlaxAutoModelForSeq2SeqLM.from_pretrained('t5-small', _do_init=False)\n", 469 | "tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small')\n", 470 | "encoder_inputs = jnp.asarray(tokenizer.encode('Some input'))[None]\n", 471 | "decoder_inputs = jnp.asarray([0] + tokenizer.encode('Some output'))[None]\n", 472 | "\n", 473 | "lora_params = lora_from_tree(params, jax.random.PRNGKey(1234))" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "id": "f6d38e5d", 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "orig_output = t5(input_ids=encoder_inputs, decoder_input_ids=decoder_inputs, params=params).logits" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "id": "afe08ddc", 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "lora_output = qax.use_implicit_args(t5)(\n", 494 | " input_ids=encoder_inputs,\n", 495 | " decoder_input_ids=decoder_inputs,\n", 496 | " params=lora_params\n", 497 | ").logits\n", 498 | "print(jnp.max(jnp.abs(lora_output - orig_output)))" 499 | ] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "id": "29649785", 504 | "metadata": {}, 505 | "source": [ 506 | "The LoRA result is identical to the execution of the unmodified network, and we didn't get any materialization warnings so we successfully made a LoRA forward pass without ever calculating $W + AB^T$!" 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "id": "eaf56c56", 512 | "metadata": {}, 513 | "source": [ 514 | "## Training\n", 515 | "So far we haven't looked at how to train a model when using Qax. The main thing to understand is that you should apply `qax.use_implicit_args` first, _then_ differentiate the resulting function. 
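In other words, you want `jax.grad(qax.use_implicit_args(loss_fn))` rather than `qax.use_implicit_args(jax.grad(loss_fn))`. 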
`use_implicit_args` transforms the function into one which goes from pytrees to pytrees, so all the standard JAX autodiff machinery will work.\n", 516 | "\n", 517 | "If you need to update only a subset of the elements of an ImplicitArray instance (e.g. only `a` and `b` for LoRA), Qax provides `qax.utils.freeze_keys` to make this easier. Here's an end-to-end example training T5 to memorize the input/output pair from above:" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "id": "caf64fd8", 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "optimizer = optax.adam(3e-4)\n", 528 | "# freeze_keys takes an optax optimizer, the ImplicitArray subclass to freeze for, \n", 529 | "# and an iterable of the keys to be frozen\n", 530 | "optimizer = qax.utils.freeze_keys(optimizer, LoraMatrix, ['w'])\n", 531 | "\n", 532 | "# We're only using a single example so we'll just close over the training data\n", 533 | "# There are no code changes from an ordinary training loop other than decorating\n", 534 | "# loss_fn with @use_implicit_args\n", 535 | "\n", 536 | "@qax.use_implicit_args\n", 537 | "def loss_fn(params):\n", 538 | " decoder_ids = decoder_inputs[:, :-1]\n", 539 | " targets = decoder_inputs[:, 1:]\n", 540 | " logits = t5(\n", 541 | " input_ids=encoder_inputs,\n", 542 | " decoder_input_ids=decoder_ids,\n", 543 | " params=params\n", 544 | " ).logits\n", 545 | " \n", 546 | " logprobs = jax.nn.log_softmax(logits)\n", 547 | " target_logprobs = jnp.take_along_axis(logprobs, targets[:, :, None], axis=-1)\n", 548 | " loss = -jnp.sum(target_logprobs)\n", 549 | " return loss\n", 550 | "\n", 551 | "grad_fn = jax.value_and_grad(loss_fn)\n", 552 | "\n", 553 | "@jax.jit\n", 554 | "def update(params, opt_state):\n", 555 | " loss, grads = grad_fn(params)\n", 556 | " updates, new_opt_state = optimizer.update(grads, opt_state, params=params)\n", 557 | " new_params = optax.apply_updates(params, updates)\n", 558 | " return loss, new_params, new_opt_state\n", 559 | "\n", 560 | "opt_state = optimizer.init(lora_params)\n", 561 | "for step in range(20):\n", 562 | " loss, lora_params, opt_state = update(lora_params, opt_state)\n", 563 | " print(f'{step}. {loss:.3f}')" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "id": "239aadd7", 569 | "metadata": {}, 570 | "source": [ 571 | "That's all you need to know to get started using Qax!" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "id": "3d3f95e5", 577 | "metadata": {}, 578 | "source": [ 579 | "## Example 3: Nesting\n", 580 | "Qax supports arbitrary nesting of `ImplicitArray` instances without any extra work. 
Here's a quick demo combining the previous two examples:" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": null, 586 | "id": "4471c323", 587 | "metadata": {}, 588 | "outputs": [], 589 | "source": [ 590 | "@qax.use_implicit_args\n", 591 | "def g(w, x):\n", 592 | " return jnp.sum(x @ w)\n", 593 | "\n", 594 | "w = jnp.ones((3, 5))\n", 595 | "x = jnp.arange(3, dtype=jnp.float32)\n", 596 | "\n", 597 | "lora_with_symbolic_zero = LoraMatrix(\n", 598 | " w=w,\n", 599 | " a=Zeros(shape=(w.shape[0], 6)),\n", 600 | " b=Zeros(shape=(w.shape[1], 6))\n", 601 | ")\n", 602 | "print(f'Original: {g(w, x)}')\n", 603 | "with warnings.catch_warnings():\n", 604 | " warnings.simplefilter('always')\n", 605 | " print(f'With lora: {g(lora_with_symbolic_zero, x)}')" 606 | ] 607 | }, 608 | { 609 | "cell_type": "markdown", 610 | "id": "4dc5044c", 611 | "metadata": {}, 612 | "source": [ 613 | "If we wanted to, we could write a `dot_general` handler for `Zeros` to avoid the materialization as well, but the main point is just to illustrate that it's easy to mix and match different `ImplicitArray` subclasses. A more useful example might be using a symbolic zero as the offset for a quantization datatype which expects both an offset and a scale." 614 | ] 615 | }, 616 | { 617 | "cell_type": "markdown", 618 | "id": "2038e6d2", 619 | "metadata": {}, 620 | "source": [ 621 | "## Other examples\n", 622 | "[Here's](https://github.com/davisyoshida/abnormal-floats/blob/master/transform.py) an example of using Qax to implement a 4-bit quantized matrix representation." 623 | ] 624 | } 625 | ], 626 | "metadata": { 627 | "kernelspec": { 628 | "display_name": "Python 3 (ipykernel)", 629 | "language": "python", 630 | "name": "python3" 631 | }, 632 | "language_info": { 633 | "codemirror_mode": { 634 | "name": "ipython", 635 | "version": 3 636 | }, 637 | "file_extension": ".py", 638 | "mimetype": "text/x-python", 639 | "name": "python", 640 | "nbconvert_exporter": "python", 641 | "pygments_lexer": "ipython3", 642 | "version": "3.10.10" 643 | } 644 | }, 645 | "nbformat": 4, 646 | "nbformat_minor": 5 647 | } 648 | --------------------------------------------------------------------------------