├── .gitignore ├── LICENSE ├── README.md ├── example.py ├── jax_hypernetwork ├── __init__.py ├── embedding.py ├── hnet.py ├── pytree.py └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python ### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | .myenv/ 41 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) Simon Schug 2 | All rights reserved. 3 | 4 | MIT License 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
9 | 10 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jax-hypernetwork 2 | 3 | A simple hypernetwork implementation in [jax](https://github.com/google/jax/) using [haiku](https://github.com/deepmind/dm-haiku). 4 | 5 | ## Example 6 | 7 | In this little demo, we create a linear hypernetwork to parametrise the weights of a multilayer perceptron. For a more elaborate example see [example.py](https://github.com/smonsays/jax-hypernetwork/blob/main/example.py). 
8 | 9 | ```python 10 | import haiku as hk 11 | import jax 12 | 13 | from jax_hypernetwork import LinearHypernetwork 14 | 15 | rng = jax.random.PRNGKey(0) 16 | 17 | # 1) Create the target network 18 | target_network = hk.transform(lambda x: hk.nets.MLP([10, 10], with_bias=False)(x)) 19 | params_target = target_network.init(rng, jax.numpy.empty((28 * 28))) 20 | 21 | # 2) Create the hypernetwork 22 | hnet = hk.transform( 23 | lambda: LinearHypernetwork(params_target, chunk_shape=(1, 10), embedding_dim=7)() 24 | ) 25 | params_hnet = hnet.init(rng) 26 | 27 | # 3) Use hypernetwork to parametrise the target network 28 | params_target_generated = hnet.apply(params_hnet, rng) 29 | output = target_network.apply(params_target_generated, rng, jax.numpy.empty((28 * 28))) 30 | ``` 31 | 32 | ## Features 33 | 34 | The [Hypernetwork base class](https://github.com/smonsays/jax-hypernetwork/blob/main/jax_hypernetwork/hnet.py#L27) allows you to specify the 35 | 36 | - source model (weight generator) 37 | - target model (model to be parametrised) 38 | - embedding model 39 | - chunking strategy 40 | 41 | and is easy to extend, e.g. to add input dependence. Two provided examples are a [LinearHypernetwork](https://github.com/smonsays/jax-hypernetwork/blob/main/jax_hypernetwork/hnet.py#L95) and an [MLPHypernetwork](https://github.com/smonsays/jax-hypernetwork/blob/main/jax_hypernetwork/hnet.py#L101). 42 | 43 | ## Install 44 | 45 | Install `jax-hypernetwork` using pip: 46 | 47 | ```bash 48 | pip install git+https://github.com/smonsays/jax-hypernetwork 49 | ``` 50 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple example illustrating how to use a hypernetwork 3 | to partially parametrise a target network.
# 4 | """
import math

import haiku as hk
import jax
import jax.tree_util as jtu

from jax_hypernetwork import LinearHypernetwork
from jax_hypernetwork.utils import dict_filter, flatten_dict, unflatten_dict

# Define hyperparameters
input_dim = 128
output_dim = 10
hidden_dim = 50
hidden_layers = 2
embedding_dim = 10

# The hypernetwork generates weights chunk by chunk, so the chunk shape must
# divide every generated weight matrix without remainder. The weights below have
# shapes (input_dim, hidden_dim), (hidden_dim, hidden_dim) and
# (hidden_dim, output_dim); a single row of gcd(hidden_dim, output_dim) entries
# divides all of them.
chunk_shape = (1, math.gcd(hidden_dim, output_dim))


@hk.without_apply_rng
@hk.transform
def target_network(inputs):
    """
    MLP with `hidden_layers` hidden layers of width `hidden_dim`
    followed by an output layer of width `output_dim`.
    """
    return hk.nets.MLP(output_sizes=hidden_layers * [hidden_dim] + [output_dim])(inputs)


# Prepare randomness
rng = jax.random.PRNGKey(0)
rng_input, rng_hnet, rng_target = jax.random.split(rng, 3)
sample_input = jax.random.normal(rng_input, shape=(input_dim,))

# Split params into those to be generated by the hnet (the weights "w") and
# those to be directly optimized (everything else, i.e. the biases)
params_all = target_network.init(rng_target, sample_input)
params_target_bias = dict_filter(params_all, "w", all_but_key=True)
params_target_weights = dict_filter(params_all, "w")


@hk.without_apply_rng
@hk.transform
def hnet():
    """Hypernetwork generating all weight matrices of `target_network`."""
    return LinearHypernetwork(
        params_target=params_target_weights,
        chunk_shape=chunk_shape,
        embedding_dim=embedding_dim,
    )()


# Generate weights using the hypernetwork
params_hnet = hnet.init(rng_hnet)
params_target_weights_generated = hnet.apply(params_hnet)

# Check that the PyTreeDef and the leaf shapes of the generated weights match
# those of the original params
assert jtu.tree_structure(params_target_weights) == jtu.tree_structure(
    params_target_weights_generated
)
assert all(
    original.shape == generated.shape
    for original, generated in zip(
        jtu.tree_leaves(params_target_weights),
        jtu.tree_leaves(params_target_weights_generated),
    )
)

# To use with target_network, combine generated and non-generated params
params_target = unflatten_dict({
    **flatten_dict(params_target_bias),
    **flatten_dict(params_target_weights_generated),
})

sample_output = target_network.apply(params_target, sample_input)
assert sample_output.shape == (output_dim,)
# --------------------------------------------------------------------------------
# /jax_hypernetwork/__init__.py:
# --------------------------------------------------------------------------------
"""Public interface of the ``jax_hypernetwork`` package (re-exports from ``hnet``)."""
from .hnet import (
    Hypernetwork,
    LinearHypernetwork,
    MLPHypernetwork,
)

__all__ = ["Hypernetwork", "LinearHypernetwork", "MLPHypernetwork"]

# --------------------------------------------------------------------------------
# /jax_hypernetwork/embedding.py:
# --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (c) Simon Schug
# 3 | All rights reserved.
# 4 |
# 5 | MIT License
# 6 |
# 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# 8 |
# 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# 10 |
# 11 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# 12 | """
from typing import Any, Callable, Optional

import haiku as hk


class Embedding(hk.Module):
    """
    Learnable table holding `num_embeddings` vectors of size `embedding_dim`.

    Args:
        num_embeddings: number of embedding vectors (rows of the table).
        embedding_dim: size of each embedding vector.
        name: optional haiku module name.
        w_init: optional initializer for the table (keyword-only). Defaults to
            ``hk.initializers.VarianceScaling(1.0, "fan_in", "uniform")``
            (LeCun uniform), the previously hard-coded choice.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        name: Optional[str] = None,
        *,
        w_init: Optional[Callable[..., Any]] = None,
    ):
        super().__init__(name=name)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if w_init is None:
            w_init = hk.initializers.VarianceScaling(1.0, "fan_in", "uniform")  # lecun_uniform
        self.w_init = w_init

    @property
    def embeddings(self):
        """The ``[num_embeddings, embedding_dim]`` parameter named "embeddings"."""
        return hk.get_parameter(
            name="embeddings",
            shape=[self.num_embeddings, self.embedding_dim],
            init=self.w_init,
        )

    def __call__(self):
        """Return the full embedding table."""
        return self.embeddings


# --------------------------------------------------------------------------------
# /jax_hypernetwork/hnet.py:
# --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (c) Simon Schug
# 3 | All rights reserved.
# 4 |
# 5 | MIT License
# 6 |
# 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# 8 |
# 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# 10 |
# 11 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# 12 | """
import abc
import math
from operator import add, floordiv
from typing import Optional, Union

import haiku as hk
import jax.numpy as jnp
import jax.tree_util as jtu

from jax_hypernetwork.embedding import Embedding
from jax_hypernetwork.pytree import is_tuple_of_ints
from jax_hypernetwork.pytree import PytreeReshaper


class Hypernetwork(abc.ABC, hk.Module):
    """
    Abstract base class for all Hypernetworks.

    One learned embedding per chunk is mapped by `weight_generator` to a flat
    chunk of `dim_chunks` values; all chunks are flattened into a single vector
    and reshaped into the pytree structure of `params_target`.

    Args:
        params_target: dictionary with the params to be generated by the hypernetwork.
            Only the shapes of its leaves are used.
        chunk_shape: the shape of a single chunk the weight_generator generates.
            The chunk shape needs to have the same rank as, and divide without
            remainder, every target shape.
            For example if params_target have shapes=[(1, 50), (50, 50)],
            then valid chunk_shapes would be chunk_shape=(1, 50), chunk_shape=(1, 1)
            but chunk_shape != (3, 50) or chunk_shape != (50, 3) would be invalid.
            If set to `None`, the hypernetwork doesn't use chunking.
        embedding_dim: the embedding dimension
        name: optional haiku module name.

    Raises:
        ValueError: if `chunk_shape` is incompatible with one of the target shapes.
    """

    def __init__(
        self,
        params_target: dict,
        chunk_shape: Union[tuple, None],
        embedding_dim: int,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.params_shape = jtu.tree_map(jnp.shape, params_target)
        self.chunk_shape = chunk_shape
        self.embedding_dim = embedding_dim
        self.num_chunks, self.dim_chunks = self.get_chunk_sizes(
            self.params_shape, self.chunk_shape
        )
        self.reshaper = PytreeReshaper(self.params_shape)

    # `abc.abstractproperty` is deprecated; stacking property + abstractmethod is
    # the supported spelling.
    @property
    @abc.abstractmethod
    def weight_generator(self):
        """
        Source network to generate the params from a single embedding.
        The output dimension should equal `self.dim_chunks`
        """

    @property
    def embedding(self):
        """Embedding module with one `embedding_dim`-sized vector per chunk."""
        return Embedding(self.num_chunks, self.embedding_dim)

    @staticmethod
    def get_chunk_sizes(params_shape: dict, chunk_shape: Optional[tuple]):
        """
        Compute how the target params are split into chunks.

        Args:
            params_shape: pytree whose leaves are the (tuple) shapes of the target params.
            chunk_shape: shape of a single chunk, or `None` to generate all params at once.

        Returns:
            Tuple `(num_chunks, dim_chunks)`: the number of chunks to generate and the
            number of values in each chunk.

        Raises:
            ValueError: if `chunk_shape` does not have the same rank as, or does not
                divide without remainder, every target shape.
        """
        shapes = jtu.tree_leaves(params_shape, is_leaf=is_tuple_of_ints)

        if chunk_shape is None:
            # No chunking means all params are generated at once
            return 1, sum(math.prod(shape) for shape in shapes)

        chunk_shape = tuple(chunk_shape)
        for shape in shapes:
            # Without this check floordiv would silently truncate and the generated
            # vector would later fail to split/reshape into the target shapes.
            if len(shape) != len(chunk_shape) or any(
                chunk_dim <= 0 or dim % chunk_dim
                for dim, chunk_dim in zip(shape, chunk_shape)
            ):
                raise ValueError(
                    f"chunk_shape {chunk_shape} must have the same rank as, and evenly "
                    f"divide, every target shape; got target shape {shape}."
                )

        num_chunks = sum(math.prod(map(floordiv, shape, chunk_shape)) for shape in shapes)
        dim_chunks = math.prod(chunk_shape)

        return num_chunks, dim_chunks

    def __call__(self):
        """
        Generate the target params.

        Returns:
            Pytree with the same structure and leaf shapes as `params_target`.
        """
        # Apply the weight generator to every chunk embedding (one row per chunk)
        params_flat = hk.vmap(self.weight_generator, split_rng=False)(self.embedding())

        return self.reshaper(params_flat.reshape(-1))


class LinearHypernetwork(Hypernetwork):
    """Hypernetwork whose weight generator is a single linear layer."""

    @property
    def weight_generator(self):
        return hk.Linear(self.dim_chunks)


class MLPHypernetwork(Hypernetwork):
    """
    Hypernetwork whose weight generator is an MLP.

    Args:
        params_shape: despite its name, the params to be generated themselves; this
            is forwarded as `params_target` to `Hypernetwork`, which applies
            `jnp.shape` to each leaf.
        chunk_shape: see `Hypernetwork`.
        embedding_dim: see `Hypernetwork`.
        hidden_dims: sizes of the hidden layers (any sequence); an output layer of
            size `dim_chunks` is appended.
        name: optional haiku module name.
    """

    def __init__(self, params_shape, chunk_shape, embedding_dim, hidden_dims, name=None):
        super().__init__(params_shape, chunk_shape, embedding_dim, name)
        self.hidden_dims = hidden_dims

    @property
    def weight_generator(self):
        # Unpacking accepts tuples as well as lists for hidden_dims
        return hk.nets.MLP([*self.hidden_dims, self.dim_chunks])


# --------------------------------------------------------------------------------
# /jax_hypernetwork/pytree.py:
# --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (c) Simon Schug
# 3 | All rights reserved.
4 | 5 | MIT License 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# 12 | """
import math
from typing import Any

import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

PyTree = Any


class PytreeReshaper:
    """
    Split a flat array into a pytree of arrays with prescribed shapes.

    Args:
        tree_shapes: pytree whose leaves are shapes (tuples of ints). The output of
            `__call__` has the same structure as this pytree.

    Attributes:
        shapes: leaf shapes in flattening order.
        treedef: tree structure of `tree_shapes`.
        split_indeces: offsets at which the flat array is split.
        num_elements: total number of elements `__call__` expects.
    """

    def __init__(self, tree_shapes: PyTree):
        self.shapes, self.treedef = jtu.tree_flatten(tree_shapes, is_leaf=is_tuple_of_ints)
        sizes = [math.prod(shape) for shape in self.shapes]

        # Plain Python ints (rather than numpy scalars) as static split points
        self.split_indeces = [int(index) for index in np.cumsum(sizes)[:-1]]
        self.num_elements = sum(sizes)

    def __call__(self, array_flat: jnp.ndarray) -> PyTree:
        """
        Reshape `array_flat` into the pytree of shapes given at construction.

        Args:
            array_flat: array holding exactly `num_elements` entries; it is
                flattened (row-major) before splitting.

        Returns:
            Pytree with the structure of `tree_shapes` whose leaves are arrays of the
            corresponding shapes, filled consecutively in flattening order.

        Raises:
            ValueError: if `array_flat` does not contain exactly `num_elements` entries.
        """
        array_flat = jnp.ravel(array_flat)
        if array_flat.size != self.num_elements:
            raise ValueError(
                f"Expected an array with {self.num_elements} elements, "
                f"got {array_flat.size}."
            )

        arrays_split = jnp.split(array_flat, self.split_indeces)
        arrays_reshaped = [a.reshape(shape) for a, shape in zip(arrays_split, self.shapes)]

        return jtu.tree_unflatten(self.treedef, arrays_reshaped)


def is_tuple_of_ints(x: Any) -> bool:
    """Return True if `x` is a tuple whose elements are all ints (e.g. an array shape)."""
    return isinstance(x, tuple) and all(isinstance(v, int) for v in x)


# --------------------------------------------------------------------------------
# /jax_hypernetwork/utils.py:
# --------------------------------------------------------------------------------
# 1 | """
# 2 | Copyright (c) Simon Schug
# 3 | All rights reserved.
# 4 |
# 5 | MIT License
# 6 |
# 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# 8 |
# 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# 10 |
# 11 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# 12 | """
from typing import Optional, Tuple, Union


def dict_filter(d: dict, key: Union[str, Tuple[str, ...]], all_but_key: Optional[bool] = False):
    """
    Filter a nested dictionary, returning only the leaves whose key path matches `key`.
    Returns the complement if all_but_key=True.

    The dictionary is flattened with `flatten_dict`, every leaf's key path is
    matched against `key`, and the selected leaves are unflattened again.

    Args:
        d: nested dictionary to filter (e.g. haiku params).
        key: either a single key that has to equal one element of the leaf's key
            path, or a tuple of strings matched element-wise (by substring) against
            the leading elements of the key path.
        all_but_key: if True, keep the leaves that do NOT match `key` instead.

    Returns:
        A new nested dictionary containing the selected leaves.
    """

    def match_key_tuples(key1: Union[str, Tuple[str, ...]], key2: Tuple[str, ...]):
        """
        Check if key1 is contained in key2.
        """
        if isinstance(key1, str):
            return key1 in key2
        return all(k1 in k2 for k1, k2 in zip(key1, key2))

    # Keep matches when all_but_key is falsy, non-matches when it is truthy.
    keep_matches = not all_but_key
    d_flat_filtered = {
        k: v for k, v in flatten_dict(d).items() if match_key_tuples(key, k) == keep_matches
    }

    return unflatten_dict(d_flat_filtered)


def flatten_dict(d: dict, parent_key: Union[str, Tuple, None] = ''):
    """
    Flatten a nested dictionary, combining the keys along each path into a flat tuple.

    Every leaf (any value that is not a dict) is stored under a flat tuple key, e.g.
    ``{"a": {"b": {"c": 1}}, "d": 2}`` becomes ``{("a", "b", "c"): 1, ("d",): 2}``.

    Args:
        d: (nested) dictionary to flatten.
        parent_key: key path prefix prepended to every key. The default ``''`` (or
            ``None``) means no prefix; a tuple is used as-is and any other value is
            treated as a single-element path.

    Returns:
        Flat dictionary mapping key-path tuples to leaves.
    """
    if isinstance(parent_key, tuple):
        prefix = parent_key
    elif parent_key is None or parent_key == '':
        prefix = ()
    else:
        prefix = (parent_key,)

    items = {}
    for k, v in d.items():
        # Always extend a flat tuple so deeper nesting does not yield nested tuples
        path = prefix + (k,)
        if isinstance(v, dict):
            items.update(flatten_dict(v, path))
        else:
            items[path] = v

    return items


def unflatten_dict(d_flat: dict):
    """
    Unflatten a dictionary whose keys are key-path tuples (inverse of `flatten_dict`).

    Non-tuple keys are treated as single-element paths, i.e. top-level entries.

    Args:
        d_flat: flat dictionary mapping key paths to leaves.

    Returns:
        The nested dictionary.

    Raises:
        TypeError: if `d_flat` is not a dict.
    """
    if not isinstance(d_flat, dict):
        raise TypeError(f"unflatten_dict expects a dict, got {type(d_flat).__name__}.")

    result = dict()
    for path, value in d_flat.items():
        if not isinstance(path, tuple):
            # Without this, a string key would be walked character by character
            path = (path,)
        cursor = result
        for key in path[:-1]:
            cursor = cursor.setdefault(key, dict())

        cursor[path[-1]] = value

    return result


# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
if __name__ == "__main__":
    # Guarded so this section has no side effects when merely imported.
    # NOTE(review): `python setup.py ...` and setuptools' legacy build backend both
    # execute setup.py with __name__ == "__main__" - confirm for your build frontend.
    from setuptools import find_namespace_packages, setup

    setup(
        name="jax-hypernetwork",
        version="0.0.1",
        description="A simple hypernetwork implementation in jax using haiku.",
        author="smonsays",
        url="https://github.com/smonsays/jax-hypernetwork",
        license='MIT',
        # The package uses math.prod, which requires Python 3.8+
        python_requires=">=3.8",
        # jax is imported directly by the package, not only through haiku
        install_requires=["dm_haiku", "jax"],
        # Restrict discovery to the actual package; without `include`,
        # find_namespace_packages also collects stray directories such as
        # __pycache__ or build/ as namespace packages.
        packages=find_namespace_packages(include=["jax_hypernetwork", "jax_hypernetwork.*"]),
    )

# --------------------------------------------------------------------------------