├── tests ├── __init__.py ├── test_haiku.py ├── utils.py ├── test_objax.py ├── test_core_save.py ├── test_utils.py ├── conftest.py ├── test_flax.py └── test_core_load.py ├── src └── safejax │ ├── py.typed │ ├── core │ ├── __init__.py │ ├── load.py │ └── save.py │ ├── __init__.py │ ├── haiku.py │ ├── flax.py │ ├── typing.py │ ├── objax.py │ └── utils.py ├── .github ├── FUNDING.yml ├── pull_request_template.md └── workflows │ └── ci-cd.yaml ├── docs ├── api │ ├── utils.md │ ├── core_load.md │ └── core_save.md ├── requirements.md ├── installation.md ├── index.md ├── license.md ├── why_safejax.md ├── usage.md └── examples.md ├── benchmarks ├── resnet50.py ├── hyperfine │ ├── resnet50.py │ └── single_layer.py └── single_layer.py ├── examples ├── objax_ft_safejax.py ├── flax_ft_safejax.py ├── haiku_ft_safejax.py └── README.md ├── LICENSE ├── .pre-commit-config.yaml ├── mkdocs.yml ├── .gitignore ├── pyproject.toml ├── requirements ├── requirements.txt ├── requirements-test.txt ├── requirements-dev.txt └── requirements-docs.txt └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/safejax/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/safejax/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: alvarobartt 2 | -------------------------------------------------------------------------------- /docs/api/utils.md: -------------------------------------------------------------------------------- 1 | ::: safejax.utils 2 | handler: python 3 | 
-------------------------------------------------------------------------------- /docs/api/core_load.md: -------------------------------------------------------------------------------- 1 | ::: safejax.core.load 2 | handler: python 3 | -------------------------------------------------------------------------------- /docs/api/core_save.md: -------------------------------------------------------------------------------- 1 | ::: safejax.core.save 2 | handler: python 3 | -------------------------------------------------------------------------------- /docs/requirements.md: -------------------------------------------------------------------------------- 1 | # 🛠️ Requirements 2 | 3 | `safejax` requires Python 3.7 or above 4 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # ⬇️ Installation 2 | 3 | ```bash 4 | pip install safejax --upgrade 5 | ``` 6 | -------------------------------------------------------------------------------- /src/safejax/__init__.py: -------------------------------------------------------------------------------- 1 | """`safejax `: Serialize JAX, Flax, Haiku, or Objax model params with `safetensors`""" 2 | 3 | __author__ = "Alvaro Bartolome " 4 | __version__ = "0.5.0" 5 | 6 | from safejax.core.load import deserialize # noqa: F401 7 | from safejax.core.save import serialize # noqa: F401 8 | -------------------------------------------------------------------------------- /src/safejax/haiku.py: -------------------------------------------------------------------------------- 1 | from safejax.core.load import deserialize # noqa: F401 2 | from safejax.core.save import serialize # noqa: F401 3 | 4 | # Nothing here as `dm-haiku` works with the default behavior of both 5 | # `safejax.core.load.deserialize` and `safejax.core.save.serialize`. But 6 | # placing this here for consistency with the other frameworks. 
7 | -------------------------------------------------------------------------------- /src/safejax/flax.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from safejax.core.load import deserialize 4 | from safejax.core.save import serialize # noqa: F401 5 | 6 | # `flax` expects either a `Dict[str, Any]` or a `FrozenDict`, but for robustness we are 7 | # setting `freeze_dict` to `True` so that the params are restored as a `FrozenDict`, 8 | # which is immutable and therefore prevents any accidental mutation. 9 | deserialize = partial(deserialize, freeze_dict=True) 10 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## ✨ Features 2 | 3 | - List 4 | - implemented 5 | - features 6 | - here 7 | 8 | ## 🐛 Bug Fixes 9 | 10 | - Listed 11 | - fixed 12 | - bugs 13 | - here 14 | 15 | ## 🔗 Linked Issue/s 16 | 17 | Add here the reference(s) to the issue(s) linked to this PR 18 | 19 | ## 🧪 Tests 20 | 21 | - [ ] Did you implement unit tests, if needed? 22 | 23 | If the above checkbox is checked, could you describe how you unit-tested it?
24 | -------------------------------------------------------------------------------- /benchmarks/resnet50.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | 3 | import jax 4 | from flax.serialization import to_bytes 5 | from flaxmodels.resnet import ResNet50 6 | from jax import numpy as jnp 7 | 8 | from safejax import serialize 9 | 10 | resnet50 = ResNet50() 11 | params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3))) 12 | 13 | 14 | start_time = perf_counter() 15 | for _ in range(100): 16 | serialize(params) 17 | end_time = perf_counter() 18 | print(f"safejax (100 runs): {end_time - start_time:0.4f} s") 19 | 20 | start_time = perf_counter() 21 | for _ in range(100): 22 | to_bytes(params) 23 | end_time = perf_counter() 24 | print(f"flax (100 runs): {end_time - start_time:0.4f} s") 25 | -------------------------------------------------------------------------------- /src/safejax/typing.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, Union 3 | 4 | import numpy as np 5 | from flax.core.frozen_dict import FrozenDict 6 | from jax import numpy as jnp 7 | from objax.variable import BaseVar, StateVar, VarCollection 8 | 9 | PathLike = Union[str, Path] 10 | 11 | NumpyArrayDict = Dict[str, np.ndarray] 12 | JaxDeviceArrayDict = Dict[str, jnp.DeviceArray] 13 | HaikuParams = Dict[str, JaxDeviceArrayDict] 14 | ObjaxDict = Dict[str, Union[BaseVar, StateVar]] 15 | ObjaxParams = Union[VarCollection, ObjaxDict] 16 | FlaxParams = Union[Dict[str, Union[Dict, JaxDeviceArrayDict]], FrozenDict] 17 | 18 | ParamsDictLike = Union[JaxDeviceArrayDict, HaikuParams, ObjaxParams, FlaxParams] 19 | -------------------------------------------------------------------------------- /benchmarks/hyperfine/resnet50.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 
import jax 4 | from flax.serialization import to_bytes 5 | from flaxmodels.resnet import ResNet50 6 | from jax import numpy as jnp 7 | 8 | from safejax import serialize 9 | 10 | resnet50 = ResNet50() 11 | params = resnet50.init(jax.random.PRNGKey(42), jnp.ones((1, 224, 224, 3))) 12 | 13 | 14 | def serialization_safejax(): 15 | _ = serialize(params) 16 | 17 | 18 | def serialization_flax(): 19 | _ = to_bytes(params) 20 | 21 | 22 | if __name__ == "__main__": 23 | if len(sys.argv) < 2: 24 | raise ValueError("Please provide a function name to run as an argument") 25 | if sys.argv[1] not in globals(): 26 | raise ValueError(f"Function {sys.argv[1]} not found") 27 | globals()[sys.argv[1]]() 28 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # 🔐 Serialize JAX, Flax, Haiku, or Objax model params with `safetensors` 2 | 3 | `safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors` 4 | as the tensor storage format, instead of relying on `pickle`. For more details on why 5 | `safetensors` is safer than `pickle` please check [huggingface/safetensors](https://github.com/huggingface/safetensors). 6 | 7 | Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model 8 | parameters and has been tested with all those frameworks, but there may be some cases where it 9 | does not work as expected, as this is still in an early development phase, so please if you have 10 | any feedback or bug reports, open an issue at [safejax/issues](https://github.com/alvarobartt/safejax/issues). 
11 | -------------------------------------------------------------------------------- /examples/objax_ft_safejax.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | from objax.zoo.resnet_v2 import ResNet50 3 | 4 | from safejax import deserialize, serialize 5 | 6 | model = ResNet50(in_channels=3, num_classes=1000) 7 | 8 | encoded_bytes = serialize(params=model.vars()) 9 | assert isinstance(encoded_bytes, bytes) 10 | assert len(encoded_bytes) > 0 11 | 12 | decoded_params = deserialize( 13 | encoded_bytes, requires_unflattening=False, to_var_collection=True 14 | ) 15 | assert isinstance(decoded_params, dict) 16 | assert len(decoded_params) > 0 17 | assert decoded_params.keys() == model.vars().keys() 18 | 19 | for key, value in decoded_params.items(): 20 | if key not in model.vars(): 21 | print(f"Key {key} not in model.vars()! Skipping.") 22 | continue 23 | model.vars()[key].assign(value) 24 | 25 | x = jnp.ones((1, 3, 224, 224)) 26 | y = model(x, training=False) 27 | assert y.shape == (1, 1000) 28 | -------------------------------------------------------------------------------- /benchmarks/single_layer.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | 3 | import jax 4 | from flax import linen as nn 5 | from flax.serialization import to_bytes 6 | from jax import numpy as jnp 7 | 8 | from safejax import serialize 9 | 10 | 11 | class SingleLayerModel(nn.Module): 12 | features: int 13 | 14 | @nn.compact 15 | def __call__(self, x): 16 | x = nn.Dense(features=self.features)(x) 17 | return x 18 | 19 | 20 | model = SingleLayerModel(features=1) 21 | 22 | rng = jax.random.PRNGKey(0) 23 | params = model.init(rng, jnp.ones((1, 1))) 24 | 25 | 26 | start_time = perf_counter() 27 | for _ in range(100): 28 | serialize(params) 29 | end_time = perf_counter() 30 | print(f"safejax (100 runs): {end_time - start_time:0.4f} s") 31 | 32 | start_time = 
perf_counter() 33 | for _ in range(100): 34 | to_bytes(params) 35 | end_time = perf_counter() 36 | print(f"flax (100 runs): {end_time - start_time:0.4f} s") 37 | -------------------------------------------------------------------------------- /benchmarks/hyperfine/single_layer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from flax import linen as nn 6 | from flax.serialization import to_bytes 7 | 8 | from safejax import serialize 9 | 10 | 11 | class SingleLayerModel(nn.Module): 12 | features: int 13 | 14 | @nn.compact 15 | def __call__(self, x): 16 | x = nn.Dense(features=self.features)(x) 17 | return x 18 | 19 | 20 | model = SingleLayerModel(features=1) 21 | 22 | rng = jax.random.PRNGKey(0) 23 | params = model.init(rng, jnp.ones((1, 1))) 24 | 25 | 26 | def serialization_safejax(): 27 | _ = serialize(params) 28 | 29 | 30 | def serialization_flax(): 31 | _ = to_bytes(params) 32 | 33 | 34 | if __name__ == "__main__": 35 | if len(sys.argv) < 2: 36 | raise ValueError("Please provide a function name to run as an argument") 37 | if sys.argv[1] not in globals(): 38 | raise ValueError(f"Function {sys.argv[1]} not found") 39 | globals()[sys.argv[1]]() 40 | -------------------------------------------------------------------------------- /examples/flax_ft_safejax.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from flax import linen as nn 3 | from flax.core.frozen_dict import FrozenDict 4 | from jax import numpy as jnp 5 | 6 | from safejax import deserialize, serialize 7 | 8 | 9 | class SingleLayerModel(nn.Module): 10 | features: int 11 | 12 | @nn.compact 13 | def __call__(self, x): 14 | x = nn.Dense(features=self.features)(x) 15 | return x 16 | 17 | 18 | network = SingleLayerModel(features=1) 19 | 20 | rng_key = jax.random.PRNGKey(seed=0) 21 | initial_params = network.init(rng_key, jnp.ones((1, 1))) 22 | 23 | 
encoded_bytes = serialize(params=initial_params) 24 | assert isinstance(encoded_bytes, bytes) 25 | assert len(encoded_bytes) > 0 26 | 27 | decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True) 28 | assert isinstance(decoded_params, FrozenDict) 29 | assert len(decoded_params) > 0 30 | assert decoded_params.keys() == initial_params.keys() 31 | 32 | x = jnp.ones((1, 1)) 33 | y = network.apply(decoded_params, x) 34 | assert y.shape == (1, 1) 35 | -------------------------------------------------------------------------------- /examples/haiku_ft_safejax.py: -------------------------------------------------------------------------------- 1 | import haiku as hk 2 | import jax 3 | from jax import numpy as jnp 4 | 5 | from safejax import deserialize, serialize 6 | 7 | 8 | def resnet_fn(x: jnp.DeviceArray, is_training: bool): 9 | resnet = hk.nets.ResNet50(num_classes=10) 10 | return resnet(x, is_training=is_training) 11 | 12 | 13 | network = hk.without_apply_rng(hk.transform_with_state(resnet_fn)) 14 | 15 | rng_key = jax.random.PRNGKey(seed=0) 16 | initial_params, initial_state = network.init( 17 | rng_key, jnp.ones([1, 224, 224, 3]), is_training=True 18 | ) 19 | 20 | encoded_bytes = serialize(params=initial_params) 21 | assert isinstance(encoded_bytes, bytes) 22 | assert len(encoded_bytes) > 0 23 | 24 | decoded_params = deserialize(path_or_buf=encoded_bytes) 25 | assert isinstance(decoded_params, dict) 26 | assert len(decoded_params) > 0 27 | assert decoded_params.keys() == initial_params.keys() 28 | 29 | x = jnp.ones([1, 224, 224, 3]) 30 | y, _ = network.apply(decoded_params, initial_state, x, is_training=False) 31 | assert y.shape == (1, 10) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022-present Alvaro Bartolome 4 | 5 | Permission is hereby granted, free of charge, to any person 
obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | # 📝 License 2 | 3 | MIT License 4 | 5 | Copyright (c) 2022-present Alvaro Bartolome 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 
16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: "v4.3.0" 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-toml 7 | - id: check-yaml 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 22.10.0 11 | hooks: 12 | - id: black 13 | args: ["--preview"] 14 | language_version: python3 15 | 16 | - repo: https://github.com/charliermarsh/ruff-pre-commit 17 | rev: "v0.0.194" 18 | hooks: 19 | - id: ruff 20 | args: [--fix] 21 | 22 | - repo: https://github.com/jazzband/pip-tools 23 | rev: 6.12.0 24 | hooks: 25 | - id: pip-compile 26 | files: requirements/requirements.txt 27 | args: ["--output-file=requirements/requirements.txt", "pyproject.toml"] 28 | - id: pip-compile 29 | files: requirements/requirements-dev.txt 30 | args: 31 | [ 32 | "--extra=quality", 33 | "--output-file=requirements/requirements-dev.txt", 34 | "pyproject.toml", 35 | ] 36 | - id: pip-compile 37 | files: requirements/requirements-test.txt 38 | args: 39 | [ 40 | "--extra=test", 41 | "--output-file=requirements/requirements-test.txt", 42 | "pyproject.toml", 43 | ] 44 | - id: pip-compile 45 | files: requirements/requirements-docs.txt 46 | args: 47 | [ 48 | "--extra=docs", 49 | "--output-file=requirements/requirements-docs.txt", 50 | "pyproject.toml", 51 | ] 52 | 
-------------------------------------------------------------------------------- /tests/test_haiku.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from safejax.haiku import deserialize, serialize 6 | from safejax.typing import HaikuParams 7 | 8 | from .utils import assert_over_trees 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "params", 13 | [ 14 | pytest.lazy_fixture("haiku_resnet50_params"), 15 | ], 16 | ) 17 | def test_serialize_and_deserialize(params: HaikuParams) -> None: 18 | encoded_params = serialize(params=params) 19 | assert isinstance(encoded_params, bytes) 20 | assert len(encoded_params) > 0 21 | 22 | decoded_params = deserialize(path_or_buf=encoded_params) 23 | assert isinstance(decoded_params, dict) 24 | assert len(decoded_params) > 0 25 | assert id(decoded_params) != id(params) 26 | assert decoded_params.keys() == params.keys() 27 | 28 | assert_over_trees(params=params, decoded_params=decoded_params) 29 | 30 | 31 | @pytest.mark.parametrize( 32 | "params", 33 | [ 34 | pytest.lazy_fixture("haiku_resnet50_params"), 35 | ], 36 | ) 37 | @pytest.mark.usefixtures("safetensors_file") 38 | def test_serialize_and_deserialize_from_file( 39 | params: HaikuParams, safetensors_file: Path 40 | ) -> None: 41 | safetensors_file = serialize(params=params, filename=safetensors_file) 42 | assert isinstance(safetensors_file, Path) 43 | assert safetensors_file.exists() 44 | 45 | decoded_params = deserialize(path_or_buf=safetensors_file) 46 | assert isinstance(decoded_params, dict) 47 | assert len(decoded_params) > 0 48 | assert id(decoded_params) != id(params) 49 | assert decoded_params.keys() == params.keys() 50 | 51 | assert_over_trees(params=params, decoded_params=decoded_params) 52 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: 🔐 
safejax 2 | site_url: https://github.com/alvarobartt/safejax 3 | site_author: Alvaro Bartolome 4 | site_description: Serialize JAX, Flax, Haiku, or Objax model params with `safetensors` 5 | 6 | repo_name: alvarobartt/safejax 7 | repo_url: https://github.com/alvarobartt/safejax 8 | 9 | copyright: Copyright (c) 2022-present Alvaro Bartolome 10 | 11 | theme: 12 | name: material 13 | palette: 14 | - scheme: default 15 | toggle: 16 | icon: material/brightness-7 17 | name: Switch to dark mode 18 | - scheme: slate 19 | toggle: 20 | icon: material/brightness-4 21 | name: Switch to light mode 22 | font: 23 | text: Roboto 24 | code: Roboto Mono 25 | 26 | markdown_extensions: 27 | - pymdownx.highlight: 28 | anchor_linenums: true 29 | - pymdownx.superfences 30 | 31 | plugins: 32 | - search: 33 | - git-revision-date-localized: 34 | type: timeago 35 | enable_creation_date: true 36 | - mkdocstrings: 37 | 38 | extra: 39 | social: 40 | - icon: fontawesome/brands/python 41 | link: https://pypi.org/project/safejax/ 42 | - icon: fontawesome/brands/github 43 | link: https://github.com/alvarobartt 44 | - icon: fontawesome/brands/twitter 45 | link: https://twitter.com/alvarobartt 46 | - icon: fontawesome/brands/linkedin 47 | link: https://www.linkedin.com/in/alvarobartt/ 48 | 49 | nav: 50 | - Home: index.md 51 | - Requirements: requirements.md 52 | - Installation: installation.md 53 | - Usage: usage.md 54 | - Why safejax?: why_safejax.md 55 | - Examples: examples.md 56 | - Reference: 57 | - safejax.core.load: api/core_load.md 58 | - safejax.core.save: api/core_save.md 59 | - safejax.utils: api/utils.md 60 | - License: license.md 61 | -------------------------------------------------------------------------------- /docs/why_safejax.md: -------------------------------------------------------------------------------- 1 | # 🤔 Why `safejax`? 2 | 3 | `safetensors` defines an easy and fast (zero-copy) format to store tensors, 4 | while `pickle` has some known weaknesses and security issues. 
`safetensors` 5 | is also a storage format that is intended to be agnostic to the framework 6 | used to load the tensors. More in-depth information can be found at 7 | [huggingface/safetensors](https://github.com/huggingface/safetensors). 8 | 9 | `jax` uses `pytrees` to store the model parameters in memory, i.e. 10 | dictionary-like structures containing nested `jnp.DeviceArray` tensors. 11 | 12 | `dm-haiku` uses a custom dictionary formatted as `<level_1>/~/<level_2>`, where the 13 | levels are the ones that define the tree structure and `/~/` is the separator between them, 14 | e.g. `res_net50/~/initial_conv`, and that key does not contain a `jnp.DeviceArray`, but a 15 | dictionary with key-value pairs, e.g. weights as `w` and biases as `b`. 16 | 17 | `objax` defines a custom dictionary-like class named `VarCollection` that contains 18 | some variables inheriting from `BaseVar`, which is another custom `objax` type. 19 | 20 | `flax` defines a dictionary-like class named `FrozenDict` that is used to 21 | store the tensors in memory; it can be dumped either into `bytes` in `MessagePack` 22 | format or as a `state_dict`. 23 | 24 | There are no plans from HuggingFace to extend `safetensors` to support anything more than tensors 25 | e.g. `FrozenDict`s, see their response at [huggingface/safetensors/discussions/138](https://github.com/huggingface/safetensors/discussions/138). 26 | 27 | So the motivation to create `safejax` is to easily provide a way to serialize `FrozenDict`s 28 | using `safetensors` as the tensor storage format instead of `pickle`, as well as to provide 29 | a common and easy way to serialize and deserialize any JAX model params (Flax, Haiku, or Objax) 30 | using `safetensors` format.
31 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import chex 4 | import jax 5 | from flax.core.frozen_dict import FrozenDict, unfreeze 6 | from objax.variable import VarCollection 7 | 8 | from safejax.typing import ParamsDictLike 9 | 10 | 11 | def assert_over_trees(params: ParamsDictLike, decoded_params: ParamsDictLike) -> None: 12 | """Assertions using `chex` to compare two trees of parameters. 13 | 14 | Note: 15 | This function does not support `objax.variable.VarCollection` objects yet, 16 | so the assertions are just done over `jax`, `flax`, and `haiku` params. 17 | 18 | Args: 19 | params: a `ParamsDictLike` object with the original parameters. 20 | decoded_params: a `ParamsDictLike` object with the decoded parameters using `safejax`. 21 | 22 | Raises: 23 | AssertionError: if the two trees are not equal on dtype, shape, structure, and values. 24 | """ 25 | if isinstance(params, VarCollection) or isinstance(decoded_params, VarCollection): 26 | warnings.warn( 27 | "This function does not support `objax.variable.VarCollection` objects yet." 
28 | ) 29 | else: 30 | params = unfreeze(params) if isinstance(params, FrozenDict) else params 31 | decoded_params = ( 32 | unfreeze(decoded_params) 33 | if isinstance(decoded_params, FrozenDict) 34 | else decoded_params 35 | ) 36 | params_tree = jax.tree_util.tree_map(lambda x: x, params) 37 | decoded_params_tree = jax.tree_util.tree_map(lambda x: x, decoded_params) 38 | 39 | chex.assert_trees_all_close( 40 | params_tree, decoded_params_tree 41 | ) # static and jittable static 42 | chex.assert_trees_all_equal_dtypes(params_tree, decoded_params_tree) 43 | chex.assert_trees_all_equal_shapes(params_tree, decoded_params_tree) 44 | chex.assert_trees_all_equal_structs(params_tree, decoded_params_tree) 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm default idea folder 107 | .idea/ 108 | 109 | # VSCode Files 110 | .vscode/ 111 | 112 | # Hatch files 113 | .hatch/ 114 | 115 | # Ruff cache 116 | .ruff_cache/ 117 | -------------------------------------------------------------------------------- /tests/test_objax.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from jax import numpy as jnp 5 | from objax.variable import VarCollection 6 | 7 | from safejax.objax import deserialize, deserialize_with_assignment, serialize 8 | from safejax.typing import ObjaxParams 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "params", 13 | [ 14 | pytest.lazy_fixture("objax_single_layer_params"), 15 | pytest.lazy_fixture("objax_resnet50_params"), 16 | ], 17 | ) 18 | def test_serialize_and_deserialize(params: 
ObjaxParams) -> None: 19 | encoded_params = serialize(params=params) 20 | assert isinstance(encoded_params, bytes) 21 | assert len(encoded_params) > 0 22 | 23 | decoded_params = deserialize(path_or_buf=encoded_params) 24 | assert isinstance(decoded_params, VarCollection) 25 | assert len(decoded_params) > 0 26 | assert id(decoded_params) != id(params) 27 | assert decoded_params.keys() == params.keys() 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "params", 32 | [ 33 | pytest.lazy_fixture("objax_single_layer_params"), 34 | pytest.lazy_fixture("objax_resnet50_params"), 35 | ], 36 | ) 37 | @pytest.mark.usefixtures("safetensors_file") 38 | def test_serialize_and_deserialize_from_file( 39 | params: ObjaxParams, safetensors_file: Path 40 | ) -> None: 41 | safetensors_file = serialize(params=params, filename=safetensors_file) 42 | assert isinstance(safetensors_file, Path) 43 | assert safetensors_file.exists() 44 | 45 | decoded_params = deserialize(path_or_buf=safetensors_file) 46 | assert isinstance(decoded_params, VarCollection) 47 | assert len(decoded_params) > 0 48 | assert id(decoded_params) != id(params) 49 | assert decoded_params.keys() == params.keys() 50 | 51 | 52 | @pytest.mark.parametrize( 53 | "params", 54 | [ 55 | pytest.lazy_fixture("objax_single_layer_params"), 56 | pytest.lazy_fixture("objax_resnet50_params"), 57 | ], 58 | ) 59 | @pytest.mark.usefixtures("safetensors_file") 60 | def test_serialize_and_deserialize_with_assignment( 61 | params: ObjaxParams, safetensors_file: Path 62 | ) -> None: 63 | safetensors_file = serialize(params=params, filename=safetensors_file) 64 | assert isinstance(safetensors_file, Path) 65 | assert safetensors_file.exists() 66 | 67 | # Assign jnp.zeros to all params.tensors() to make sure the assignment is working 68 | # before we deserialize the params. 
69 | params.assign([jnp.zeros(x.shape, x.dtype) for x in params.tensors()]) 70 | assert all([jnp.all(x == 0) for x in params.tensors()]) 71 | 72 | deserialize_with_assignment(filename=safetensors_file, model_vars=params) 73 | assert isinstance(params, VarCollection) 74 | assert len(params) > 0 75 | assert not all([jnp.all(x != 0) for x in params.tensors()]) 76 | -------------------------------------------------------------------------------- /tests/test_core_save.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict 3 | 4 | import pytest 5 | from fsspec.spec import AbstractFileSystem 6 | 7 | from safejax.core.save import serialize 8 | from safejax.utils import ParamsDictLike 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "params", 13 | [ 14 | pytest.lazy_fixture("flax_single_layer_params"), 15 | pytest.lazy_fixture("flax_resnet50_params"), 16 | pytest.lazy_fixture("objax_resnet50_params"), 17 | pytest.lazy_fixture("haiku_resnet50_params"), 18 | ], 19 | ) 20 | def test_serialize(params: ParamsDictLike) -> None: 21 | encoded_params = serialize(params=params) 22 | assert isinstance(encoded_params, bytes) 23 | assert len(encoded_params) > 0 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "params", 28 | [ 29 | pytest.lazy_fixture("flax_single_layer_params"), 30 | pytest.lazy_fixture("flax_resnet50_params"), 31 | pytest.lazy_fixture("objax_resnet50_params"), 32 | pytest.lazy_fixture("haiku_resnet50_params"), 33 | ], 34 | ) 35 | @pytest.mark.usefixtures("metadata") 36 | def test_serialize_with_metadata( 37 | params: ParamsDictLike, metadata: Dict[str, str] 38 | ) -> None: 39 | encoded_params = serialize(params=params, metadata=metadata) 40 | assert isinstance(encoded_params, bytes) 41 | assert len(encoded_params) > 0 42 | assert encoded_params.__contains__(b"metadata") 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "params", 47 | [ 48 | pytest.lazy_fixture("flax_single_layer_params"), 49 | 
pytest.lazy_fixture("flax_resnet50_params"), 50 | pytest.lazy_fixture("objax_resnet50_params"), 51 | pytest.lazy_fixture("haiku_resnet50_params"), 52 | ], 53 | ) 54 | @pytest.mark.usefixtures("safetensors_file") 55 | def test_serialize_to_file(params: ParamsDictLike, safetensors_file: Path) -> None: 56 | safetensors_file = serialize(params=params, filename=safetensors_file) 57 | assert isinstance(safetensors_file, Path) 58 | assert safetensors_file.exists() 59 | 60 | 61 | @pytest.mark.parametrize( 62 | "params", 63 | [ 64 | pytest.lazy_fixture("flax_single_layer_params"), 65 | pytest.lazy_fixture("flax_resnet50_params"), 66 | pytest.lazy_fixture("objax_resnet50_params"), 67 | pytest.lazy_fixture("haiku_resnet50_params"), 68 | ], 69 | ) 70 | @pytest.mark.usefixtures("safetensors_file", "fs") 71 | def test_serialize_to_file_in_fs( 72 | params: ParamsDictLike, safetensors_file: Path, fs: AbstractFileSystem 73 | ) -> None: 74 | safetensors_file = serialize(params=params, filename=safetensors_file, fs=fs) 75 | assert isinstance(safetensors_file, Path) 76 | assert safetensors_file.exists() 77 | assert safetensors_file.as_posix() in fs.ls(safetensors_file.parent.as_posix()) 78 | -------------------------------------------------------------------------------- /.github/workflows/ci-cd.yaml: -------------------------------------------------------------------------------- 1 | name: ci-cd 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - main 8 | push: 9 | branches: 10 | - main 11 | paths: 12 | - .github/workflows/ci-cd.yaml 13 | - src/** 14 | - tests/** 15 | release: 16 | types: 17 | - published 18 | 19 | jobs: 20 | check-quality: 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - name: checkout 25 | uses: actions/checkout@v3 26 | 27 | - name: setup-python 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: 3.8 31 | 32 | - name: install-dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install ".[quality]" 36 | 
37 | - name: check-quality 38 | run: | 39 | ruff src tests benchmarks examples 40 | black --check --diff --preview src tests benchmarks examples 41 | 42 | run-tests: 43 | needs: check-quality 44 | 45 | runs-on: ubuntu-latest 46 | 47 | steps: 48 | - name: checkout 49 | uses: actions/checkout@v3 50 | 51 | - name: setup-python 52 | uses: actions/setup-python@v4 53 | with: 54 | python-version: 3.8 55 | 56 | - name: install-dependencies 57 | run: | 58 | python -m pip install --upgrade pip 59 | pip install ".[tests]" 60 | 61 | - name: run-tests 62 | run: pytest tests/ -s --durations 0 --disable-warnings 63 | 64 | deploy-docs: 65 | needs: run-tests 66 | 67 | runs-on: ubuntu-latest 68 | 69 | steps: 70 | - name: checkout 71 | uses: actions/checkout@v3 72 | 73 | - name: setup-python 74 | uses: actions/setup-python@v4 75 | with: 76 | python-version: 3.8 77 | 78 | - name: install-dependencies 79 | run: | 80 | python -m pip install --upgrade pip 81 | pip install -e ".[docs]" 82 | 83 | - name: deploy-to-gh-pages 84 | run: mkdocs gh-deploy --force 85 | 86 | publish-package: 87 | needs: deploy-docs 88 | if: github.event_name == 'release' 89 | 90 | runs-on: ubuntu-latest 91 | 92 | steps: 93 | - name: checkout 94 | uses: actions/checkout@v3 95 | 96 | - name: setup-python 97 | uses: actions/setup-python@v4 98 | with: 99 | python-version: 3.8 100 | 101 | - name: install-dependencies 102 | run: | 103 | python -m pip install --upgrade pip 104 | pip install hatch 105 | 106 | - name: build-package 107 | run: hatch build 108 | 109 | - name: publish-package 110 | run: hatch publish --user __token__ --auth $PYPI_TOKEN 111 | env: 112 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 113 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | authors = [{name = "Alvaro 
Bartolome", email = "alvarobartt@yahoo.com"}] 7 | classifiers = [ 8 | "Development Status :: 4 - Beta", 9 | "Programming Language :: Python", 10 | "Programming Language :: Python :: 3.7", 11 | "Programming Language :: Python :: 3.8", 12 | "Programming Language :: Python :: 3.9", 13 | "Programming Language :: Python :: 3.10", 14 | "Programming Language :: Python :: 3.11", 15 | "Intended Audience :: Developers", 16 | "Intended Audience :: Science/Research", 17 | "Topic :: Software Development :: Libraries", 18 | "Topic :: Software Development :: Libraries :: Python Modules", 19 | ] 20 | dependencies = [ 21 | "jaxlib~=0.3.25", 22 | "jax~=0.3.25", 23 | "objax~=1.6.0", 24 | "flax~=0.6.2", 25 | "dm-haiku~=0.0.9", 26 | "safetensors~=0.2.5", 27 | "fsspec~=2022.11.0", 28 | ] 29 | description = "Serialize JAX, Flax, Haiku, or Objax model params with `safetensors`" 30 | dynamic = ["version"] 31 | keywords = [] 32 | license = "MIT" 33 | name = "safejax" 34 | readme = "README.md" 35 | requires-python = ">=3.7" 36 | 37 | [project.urls] 38 | Documentation = "https://alvarobartt.github.io/safejax" 39 | Issues = "https://github.com/alvarobartt/safejax/issues" 40 | Source = "https://github.com/alvarobartt/safejax" 41 | 42 | [tool.hatch.version] 43 | path = "src/safejax/__init__.py" 44 | 45 | [project.optional-dependencies] 46 | docs = [ 47 | "mkdocs~=1.4.0", 48 | "mkdocs-material~=8.5.4", 49 | "mkdocs-git-revision-date-localized-plugin~=1.1.0", 50 | "mkdocstrings[python]~=0.19.0", 51 | ] 52 | quality = [ 53 | "black~=22.10.0", 54 | "ruff~=0.0.194", 55 | "pip-tools~=6.12.0", 56 | "pre-commit~=2.20.0", 57 | ] 58 | tests = [ 59 | "pytest~=7.1.2", 60 | "pytest-lazy-fixture~=0.6.3", 61 | "flaxmodels~=0.1.2", 62 | ] 63 | 64 | [tool.hatch.envs.quality] 65 | features = [ 66 | "quality", 67 | ] 68 | 69 | [tool.hatch.envs.quality.scripts] 70 | check = [ 71 | "ruff src tests benchmarks examples", 72 | "black --check --diff --preview src tests benchmarks examples", 73 | ] 74 | format = [ 75 | 
"ruff --fix src tests benchmarks examples", 76 | "black --preview src tests benchmarks examples", 77 | "check", 78 | ] 79 | 80 | [tool.isort] 81 | profile = "black" 82 | 83 | [tool.ruff] 84 | ignore = [ 85 | "E501", # line too long, handled by black 86 | "B008", # do not perform function calls in argument defaults 87 | "C901", # too complex 88 | ] 89 | select = [ 90 | "E", # pycodestyle errors 91 | "W", # pycodestyle warnings 92 | "F", # pyflakes 93 | "I", # isort 94 | "C", # flake8-comprehensions 95 | "B", # flake8-bugbear 96 | ] 97 | 98 | [tool.ruff.isort] 99 | known-first-party = ["safejax"] 100 | 101 | [tool.hatch.envs.test] 102 | features = [ 103 | "tests", 104 | ] 105 | 106 | [tool.hatch.envs.test.scripts] 107 | run = "pytest -s --durations 0 --disable-warnings" 108 | 109 | [[tool.hatch.envs.test.matrix]] 110 | python = ["37", "38", "39", "310"] 111 | 112 | [tool.hatch.envs.docs] 113 | features = [ 114 | "docs", 115 | ] 116 | 117 | [tool.hatch.envs.docs.scripts] 118 | build = [ 119 | "mkdocs build", 120 | ] 121 | serve = [ 122 | "mkdocs serve", 123 | ] 124 | 125 | [tool.hatch.build.targets.sdist] 126 | exclude = [ 127 | "/.github", 128 | "/docs", 129 | "/.pre-commit-config.yaml", 130 | "/.gitignore", 131 | "/tests", 132 | ] 133 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | from jax import numpy as jnp 5 | 6 | from safejax.utils import flatten_dict, unflatten_dict 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "input_dict, expected_output_dict", 11 | [ 12 | ( 13 | {"a": jnp.zeros(1), "b": jnp.zeros(1)}, 14 | {"a": jnp.zeros(1), "b": jnp.zeros(1)}, 15 | ), 16 | ( 17 | {"a.b": jnp.zeros(1), "b": jnp.zeros(1)}, 18 | {"a": {"b": jnp.zeros(1)}, "b": jnp.zeros(1)}, 19 | ), 20 | ( 21 | {"a.b": jnp.zeros(1), "a.c": jnp.zeros(1), "b": jnp.zeros(1)}, 22 | {"a": {"b": 
jnp.zeros(1), "c": jnp.zeros(1)}, "b": jnp.zeros(1)}, 23 | ), 24 | ( 25 | { 26 | "a.b.c": jnp.zeros(1), 27 | "a.b.d": jnp.zeros(1), 28 | "a.e": jnp.zeros(1), 29 | "b": jnp.zeros(1), 30 | }, 31 | { 32 | "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, 33 | "b": jnp.zeros(1), 34 | }, 35 | ), 36 | ( 37 | { 38 | "a.b.c": jnp.zeros(1), 39 | "a.b.d": jnp.zeros(1), 40 | "a.e": jnp.zeros(1), 41 | "b": jnp.zeros(1), 42 | "c": jnp.zeros(1), 43 | }, 44 | { 45 | "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, 46 | "b": jnp.zeros(1), 47 | "c": jnp.zeros(1), 48 | }, 49 | ), 50 | ], 51 | ) 52 | def test_unflatten_dict( 53 | input_dict: Dict[str, Any], expected_output_dict: Dict[str, Any] 54 | ) -> None: 55 | unflattened_dict = unflatten_dict(params=input_dict) 56 | assert unflattened_dict == expected_output_dict 57 | 58 | 59 | @pytest.mark.parametrize( 60 | "input_dict, expected_output_dict", 61 | [ 62 | ( 63 | {"a": {"b": jnp.zeros(1)}, "b": jnp.zeros(1)}, 64 | {"a.b": jnp.zeros(1), "b": jnp.zeros(1)}, 65 | ), 66 | ( 67 | {"a": {"b": jnp.zeros(1), "c": jnp.zeros(1)}, "b": jnp.zeros(1)}, 68 | {"a.b": jnp.zeros(1), "a.c": jnp.zeros(1), "b": jnp.zeros(1)}, 69 | ), 70 | ( 71 | { 72 | "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, 73 | "b": jnp.zeros(1), 74 | }, 75 | { 76 | "a.b.c": jnp.zeros(1), 77 | "a.b.d": jnp.zeros(1), 78 | "a.e": jnp.zeros(1), 79 | "b": jnp.zeros(1), 80 | }, 81 | ), 82 | ( 83 | { 84 | "a": {"b": {"c": jnp.zeros(1), "d": jnp.zeros(1)}, "e": jnp.zeros(1)}, 85 | "b": jnp.zeros(1), 86 | "c": jnp.zeros(1), 87 | }, 88 | { 89 | "a.b.c": jnp.zeros(1), 90 | "a.b.d": jnp.zeros(1), 91 | "a.e": jnp.zeros(1), 92 | "b": jnp.zeros(1), 93 | "c": jnp.zeros(1), 94 | }, 95 | ), 96 | ], 97 | ) 98 | def test_flatten_dict( 99 | input_dict: Dict[str, Any], expected_output_dict: Dict[str, Any] 100 | ) -> None: 101 | flattened_dict = flatten_dict(params=input_dict) 102 | assert flattened_dict == expected_output_dict 
103 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict 3 | 4 | import fsspec 5 | import haiku as hk 6 | import jax 7 | import jax.numpy as jnp 8 | import objax 9 | import pytest 10 | from flax import linen as nn 11 | from flax.core.frozen_dict import FrozenDict 12 | from flaxmodels.resnet import ResNet50 as FlaxResNet50 13 | from fsspec.spec import AbstractFileSystem 14 | from objax.variable import VarCollection 15 | from objax.zoo.resnet_v2 import ResNet50 as ObjaxResNet50 16 | 17 | 18 | @pytest.fixture 19 | def flax_single_layer() -> nn.Module: 20 | class SingleLayer(nn.Module): 21 | @nn.compact 22 | def __call__(self, x): 23 | x = nn.Dense(features=1)(x) 24 | return x 25 | 26 | return SingleLayer() 27 | 28 | 29 | @pytest.fixture 30 | def flax_single_layer_params(flax_single_layer: nn.Module) -> FrozenDict: 31 | # https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#fixtures-can-request-other-fixtures 32 | rng = jax.random.PRNGKey(0) 33 | params = flax_single_layer.init(rng, jnp.ones((1, 1))) 34 | return params 35 | 36 | 37 | @pytest.fixture 38 | def objax_single_layer() -> objax.nn.Sequential: 39 | return objax.nn.Sequential( 40 | [ 41 | objax.nn.Linear(1, 1), 42 | ] 43 | ) 44 | 45 | 46 | @pytest.fixture 47 | def objax_single_layer_params(objax_single_layer: objax.nn.Sequential) -> VarCollection: 48 | # https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#fixtures-can-request-other-fixtures 49 | return objax_single_layer.vars() 50 | 51 | 52 | @pytest.fixture 53 | def flax_resnet() -> nn.Module: 54 | return FlaxResNet50() 55 | 56 | 57 | @pytest.fixture 58 | def flax_resnet50_params(flax_resnet: nn.Module) -> FrozenDict: 59 | rng = jax.random.PRNGKey(0) 60 | params = flax_resnet.init(rng, jnp.ones((1, 224, 224, 3))) 61 | return params 62 | 63 | 64 | @pytest.fixture 65 | def
haiku_resnet50() -> hk.TransformedWithState: 66 | def resnet_fn(x: jnp.DeviceArray, is_training: bool) -> jnp.ndarray: # returns the network output array, not a module 67 | resnet = hk.nets.ResNet50(num_classes=10) 68 | return resnet(x, is_training=is_training) 69 | 70 | return hk.without_apply_rng(hk.transform_with_state(resnet_fn)) 71 | 72 | 73 | @pytest.fixture 74 | def haiku_resnet50_params(haiku_resnet50: hk.TransformedWithState) -> hk.Params: 75 | rng = jax.random.PRNGKey(0) 76 | params, _ = haiku_resnet50.init(rng, jnp.ones((1, 224, 224, 3)), is_training=True) # `init` yields a (params, state) pair here; the state is discarded 77 | return params 78 | 79 | 80 | @pytest.fixture 81 | def objax_resnet50() -> objax.nn.Sequential: 82 | return ObjaxResNet50(in_channels=3, num_classes=1000) 83 | 84 | 85 | @pytest.fixture 86 | def objax_resnet50_params(objax_resnet50: objax.nn.Sequential) -> VarCollection: 87 | return objax_resnet50.vars() 88 | 89 | 90 | @pytest.fixture 91 | def metadata() -> Dict[str, str]: 92 | return { 93 | "test": "test", 94 | } 95 | 96 | 97 | @pytest.fixture(scope="session") # session-scoped: every test requesting it receives the same path 98 | def safetensors_file(tmp_path_factory: pytest.TempPathFactory) -> Path: 99 | # https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html#the-tmp-path-factory-fixture 100 | return Path(tmp_path_factory.mktemp("data") / "params.safetensors") 101 | 102 | 103 | @pytest.fixture(scope="session") 104 | def msgpack_file(tmp_path_factory: pytest.TempPathFactory) -> Path: 105 | return Path(tmp_path_factory.mktemp("data") / "params.msgpack") 106 | 107 | 108 | @pytest.fixture 109 | def fs() -> AbstractFileSystem: 110 | return fsspec.filesystem("file") 111 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # 💻 Usage 2 | 3 | ## `flax` 4 | 5 | * Convert `params` to `bytes` in memory 6 | 7 | ```python 8 | from safejax.flax import serialize, deserialize 9 | 10 | params = model.init(...) 11 | 12 | encoded_bytes = serialize(params) 13 | decoded_params = deserialize(encoded_bytes) 14 | 15 | model.apply(decoded_params, ...)
16 | ``` 17 | 18 | * Convert `params` to `bytes` in `params.safetensors` file 19 | 20 | ```python 21 | from safejax.flax import serialize, deserialize 22 | 23 | params = model.init(...) 24 | 25 | encoded_bytes = serialize(params, filename="./params.safetensors") 26 | decoded_params = deserialize("./params.safetensors") 27 | 28 | model.apply(decoded_params, ...) 29 | ``` 30 | 31 | --- 32 | 33 | ## `dm-haiku` 34 | 35 | * Just contains `params` 36 | 37 | ```python 38 | from safejax.haiku import serialize, deserialize 39 | 40 | params = model.init(...) 41 | 42 | encoded_bytes = serialize(params) 43 | decoded_params = deserialize(encoded_bytes) 44 | 45 | model.apply(decoded_params, ...) 46 | ``` 47 | 48 | * If it contains `params` and `state` e.g. ExponentialMovingAverage in BatchNorm 49 | 50 | ```python 51 | from safejax.haiku import serialize, deserialize 52 | 53 | params, state = model.init(...) 54 | params_state = {"params": params, "state": state} 55 | 56 | encoded_bytes = serialize(params_state) 57 | decoded_params_state = deserialize(encoded_bytes) # .keys() contains `params` and `state` 58 | 59 | model.apply(decoded_params_state["params"], decoded_params_state["state"], ...) 60 | ``` 61 | 62 | * If it contains `params` and `state`, but we want to serialize those individually 63 | 64 | ```python 65 | from safejax.haiku import serialize, deserialize 66 | 67 | params, state = model.init(...) 68 | 69 | encoded_bytes = serialize(params) 70 | decoded_params = deserialize(encoded_bytes) 71 | 72 | encoded_bytes = serialize(state) 73 | decoded_state = deserialize(encoded_bytes) 74 | 75 | model.apply(decoded_params, decoded_state, ...) 
76 | ``` 77 | 78 | --- 79 | 80 | ## `objax` 81 | 82 | * Convert `params` to `bytes` in memory, and convert back to `VarCollection` 83 | 84 | ```python 85 | from safejax.objax import serialize, deserialize 86 | 87 | params = model.vars() 88 | 89 | encoded_bytes = serialize(params=params) 90 | decoded_params = deserialize(encoded_bytes) 91 | 92 | for key, value in decoded_params.items(): 93 | if key in model.vars(): 94 | model.vars()[key].assign(value.value) 95 | 96 | model(...) 97 | ``` 98 | 99 | * Convert `params` to `bytes` in `params.safetensors` file 100 | 101 | ```python 102 | from safejax.objax import serialize, deserialize 103 | 104 | params = model.vars() 105 | 106 | encoded_bytes = serialize(params=params, filename="./params.safetensors") 107 | decoded_params = deserialize("./params.safetensors") 108 | 109 | for key, value in decoded_params.items(): 110 | if key in model.vars(): 111 | model.vars()[key].assign(value.value) 112 | 113 | model(...) 114 | ``` 115 | 116 | * Convert `params` to `bytes` in `params.safetensors` and assign during deserialization 117 | 118 | ```python 119 | from safejax.objax import serialize, deserialize_with_assignment 120 | 121 | params = model.vars() 122 | 123 | encoded_bytes = serialize(params=params, filename="./params.safetensors") 124 | deserialize_with_assignment(filename="./params.safetensors", model_vars=params) 125 | 126 | model(...) 127 | ``` 128 | 129 | --- 130 | 131 | More detailed examples can be found in [`examples/`](https://github.com/alvarobartt/safejax/tree/main/examples) 132 | for `flax`, `dm-haiku`, and `objax`.
133 | -------------------------------------------------------------------------------- /src/safejax/objax.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | 4 | from objax.variable import VarCollection 5 | from safetensors import safe_open 6 | 7 | from safejax.core.load import deserialize 8 | from safejax.core.save import serialize # noqa: F401 9 | from safejax.typing import PathLike 10 | from safejax.utils import OBJAX_VARIABLE_SEPARATOR 11 | 12 | # `objax` params are usually defined as a `VarCollection`, and that's basically a dictionary with 13 | # key-value pairs where the value is either a `BaseVar` or a `StateVar`. The problem is that when 14 | # serializing those variables by default we just keep the value which is a `jnp.DeviceArray`, so we 15 | # need to provide `include_objax_variables=True` to store the variable type names as part of the key 16 | # using `::` as the separator. This is useful when deserializing the params, as we can restore a 17 | # `VarCollection` object instead of a `Dict[str, jnp.DeviceArray]`. 18 | serialize = partial(serialize, include_objax_variables=True) 19 | 20 | # `objax` expects either a `Dict[str, jnp.DeviceArray]` or a `VarCollection` as model params 21 | # which means any other type of `Dict` will not work. The only difference is that `VarCollection` can 22 | # be assigned directly to `.vars()` while `Dict[str, jnp.DeviceArray]` needs to be manually assigned 23 | # when looping over `.vars()`. Ideally, we want to restore the params from a `VarCollection`, that's why 24 | # we've set the `to_var_collection` parameter to `True` by default. 
25 | deserialize = partial(deserialize, requires_unflattening=False, to_var_collection=True) 26 | 27 | 28 | # When calling `deserialize` over an `objax.variable.VarCollection` object, those variables cannot 29 | # be used directly for the inference, as the forward pass in `objax` is done through `__call__`, which 30 | # implies that the class instance must contain the model params loaded in `.vars` attribute. So this 31 | # function has been created in order to ease the parameter loading for `objax`, since as opposed to 32 | # `flax` and `haiku`, the model params are not provided on every forward pass. 33 | def deserialize_with_assignment(filename: PathLike, model_vars: VarCollection) -> None: 34 | """Deserialize a `VarCollection` from a file and assign it to a `VarCollection` object. 35 | 36 | Note: 37 | This function avoid some known issues related to the variable deserialization with `objax`, 38 | since the params are stored in a `VarCollection` object, which contains some `objax.variable` 39 | variables instead of key-tensor pais. So this way we avoid having to restore the `objax.variable` 40 | type per each value. 41 | 42 | Args: 43 | filename: Path to the file containing the serialized `VarCollection` as a `Dict[str, jnp.DeviceArray]`. 44 | model_vars: `VarCollection` object to which the deserialized tensors will be assigned. 45 | 46 | Returns: 47 | `None`, as the deserialized tensors are assigned to the `model_vars` object. So you 48 | just need to access `model_vars`, or the actual `model.vars()` attribute, since the 49 | assignment is done over a class attribute named `vars`. 50 | """ 51 | if not isinstance(filename, (str, Path)): 52 | raise ValueError( 53 | "`filename` must be a `str` or a `pathlib.Path` object, not" 54 | f" {type(filename)}." 
55 | ) 56 | filename = filename if isinstance(filename, Path) else Path(filename) 57 | if not filename.exists or not filename.is_file: 58 | raise ValueError(f"`filename` must be a valid file path, not {filename}.") 59 | with safe_open(filename.as_posix(), framework="jax") as f: 60 | for key in f.keys(): 61 | just_key = ( 62 | key.split(OBJAX_VARIABLE_SEPARATOR)[0] 63 | if OBJAX_VARIABLE_SEPARATOR in key 64 | else key 65 | ) 66 | if just_key not in model_vars.keys(): 67 | raise ValueError(f"Variable with name {key} not found in model_vars.") 68 | model_vars[just_key].assign(f.get_tensor(key)) 69 | -------------------------------------------------------------------------------- /src/safejax/core/load.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import Dict, Tuple, Union 4 | 5 | from flax.core.frozen_dict import freeze 6 | from fsspec import AbstractFileSystem 7 | from objax.variable import VarCollection 8 | from safetensors import safe_open 9 | from safetensors.flax import load 10 | 11 | from safejax.typing import ParamsDictLike, PathLike 12 | from safejax.utils import cast_objax_variables, unflatten_dict 13 | 14 | 15 | def deserialize( 16 | path_or_buf: Union[PathLike, bytes], 17 | fs: Union[AbstractFileSystem, None] = None, 18 | freeze_dict: bool = False, 19 | requires_unflattening: bool = True, 20 | to_var_collection: bool = False, 21 | ) -> Union[ParamsDictLike, Tuple[ParamsDictLike, Dict[str, str]]]: 22 | """ 23 | Deserialize JAX, Flax, Haiku, or Objax model params from either a `bytes` object or a file path, 24 | stored using `safetensors.flax.save_file` or directly saved using `safejax.save.serialize` with 25 | the `filename` parameter. 26 | 27 | Note: 28 | The default behavior of this function is to restore a `Dict[str, jnp.DeviceArray]` from 29 | a `bytes` object or a file path. 
If you are using `objax`, you should set `requires_unflattening` 30 | to `False` and `to_var_collection` to `True` to restore a `VarCollection`. If you're using `flax` you 31 | should set `freeze_dict` to `True` to restore a `FrozenDict`. Those are just tips on how to use it 32 | but all those frameworks are compatible with the default behavior. 33 | 34 | Args: 35 | path_or_buf: 36 | A `bytes` object or a file path containing the serialized model params. 37 | fs: The filesystem to use to load the model params. Defaults to `None`. 38 | freeze_dict: 39 | Whether to freeze the output `Dict` to be a `FrozenDict` or not. Defaults to `False`. 40 | requires_unflattening: 41 | Whether the model params require unflattening or not. Defaults to `True`. 42 | to_var_collection: 43 | Whether to convert the output `Dict` to a `VarCollection` or not. Defaults to `False`. 44 | 45 | Returns: 46 | A `Dict[str, jnp.DeviceArray]`, `FrozenDict`, or `VarCollection` containing the model params, 47 | or in case `path_or_buf` is a filename and `metadata` is not empty, a tuple containing the 48 | model params and the metadata (in that order). 49 | """ 50 | metadata = {} 51 | if isinstance(path_or_buf, bytes): 52 | decoded_params = load(data=path_or_buf) 53 | elif isinstance(path_or_buf, (str, Path)): 54 | if fs and fs.protocol != "file": 55 | if not isinstance(fs, AbstractFileSystem): 56 | raise ValueError( 57 | "`fs` must be a `fsspec.AbstractFileSystem` object or `None`," 58 | f" not {type(fs)}." 59 | ) 60 | with fs.open(path_or_buf, "rb") as f: 61 | decoded_params = load(data=f.read()) 62 | else: 63 | if fs and fs.protocol == "file": 64 | filename = Path(fs._strip_protocol(path_or_buf)) 65 | else: 66 | filename = ( 67 | path_or_buf if isinstance(path_or_buf, Path) else Path(path_or_buf) 68 | ) 69 | if not filename.exists or not filename.is_file: 70 | raise ValueError( 71 | f"`path_or_buf` must be a valid file path, not {path_or_buf}." 
72 | ) 73 | decoded_params = {} 74 | with safe_open(filename.as_posix(), framework="jax") as f: 75 | metadata = f.metadata() 76 | for k in f.keys(): 77 | decoded_params[k] = f.get_tensor(k) 78 | else: 79 | raise ValueError( 80 | "`path_or_buf` must be a `bytes` object or a file path (`str` or" 81 | f" `pathlib.Path` object), not {type(path_or_buf)}." 82 | ) 83 | if to_var_collection: 84 | try: 85 | return VarCollection(cast_objax_variables(params=decoded_params)) 86 | except ValueError as e: 87 | warnings.warn(e) 88 | return decoded_params 89 | if requires_unflattening: 90 | decoded_params = unflatten_dict(params=decoded_params) 91 | if freeze_dict: 92 | return freeze(decoded_params) 93 | if metadata and len(metadata) > 0: 94 | return decoded_params, metadata 95 | return decoded_params 96 | -------------------------------------------------------------------------------- /src/safejax/core/save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import warnings 4 | from pathlib import Path 5 | from typing import Dict, Union 6 | 7 | from fsspec import AbstractFileSystem 8 | from safetensors.flax import save, save_file 9 | 10 | from safejax.typing import ParamsDictLike, PathLike 11 | from safejax.utils import flatten_dict 12 | 13 | 14 | def serialize( 15 | params: ParamsDictLike, 16 | metadata: Union[None, Dict[str, str]] = None, 17 | include_objax_variables: bool = False, 18 | filename: Union[PathLike, None] = None, 19 | fs: Union[AbstractFileSystem, None] = None, 20 | ) -> Union[bytes, PathLike]: 21 | """ 22 | Serialize JAX, Flax, Haiku, or Objax model params from either `FrozenDict`, `Dict`, or `VarCollection`. 23 | 24 | If `filename` is not provided, the serialized model is returned as a `bytes` object, 25 | otherwise the model is saved to the provided `filename` and the `filename` is returned. 
26 | 27 | Args: 28 | params: A `FrozenDict`, a `Dict` or a `VarCollection` containing the model params. 29 | metadata: A `Dict` containing the metadata to be saved along with the model params. 30 | include_objax_variables: Whether to include `objax.Variable` objects in the serialized model params. 31 | filename: The path to the file where the model params will be saved. 32 | fs: The filesystem to use to save the model params. Defaults to `None`. 33 | 34 | Returns: 35 | The serialized model params as a `bytes` object or the path to the file where the model params were saved. 36 | """ 37 | params = flatten_dict( 38 | params=params, include_objax_variables=include_objax_variables 39 | ) 40 | 41 | if metadata: 42 | if any( 43 | not isinstance(key, str) or not isinstance(value, str) 44 | for key, value in metadata.items() 45 | ): 46 | raise ValueError( 47 | "If `metadata` is provided (not `None`), it must be a `Dict[str, str]`" 48 | " object. From the `safetensors` documentation: 'Optional text only" 49 | " metadata you might want to save in your header. For instance it can" 50 | " be useful to specify more about the underlying tensors. This is" 51 | " purely informative and does not affect tensor loading.'" 52 | ) 53 | if not filename: 54 | warnings.warn( 55 | "`metadata` param will be ignored when trying to `deserialize` from" 56 | " bytes, if you want to save the `metadata` to be loaded later, you can" 57 | " set the `filename` param to dump the `metadata` along with the model" 58 | " params in a file, either to be loaded back using `deserialize` from" 59 | " `path_or_buf` or using `safetensors.safe_open`. More information at" 60 | " https://github.com/huggingface/safetensors/issues/147." 61 | ) 62 | if filename: 63 | if not isinstance(filename, (str, Path)): 64 | raise ValueError( 65 | "If `filename` is provided (not `None`), it must be a `str` or a" 66 | f" `pathlib.Path` object, not {type(filename)}." 
67 | ) 68 | if fs and fs.protocol != "file": 69 | if not isinstance(fs, AbstractFileSystem): 70 | raise ValueError( 71 | "`fs` must be a `fsspec.AbstractFileSystem` object or `None`," 72 | f" not {type(fs)}." 73 | ) 74 | temp_filename = tempfile.NamedTemporaryFile( 75 | mode="wb", suffix=".safetensors", delete=False 76 | ) 77 | try: 78 | temp_filename.write(save(tensors=params, metadata=metadata)) 79 | finally: 80 | temp_filename.close() 81 | fs.put_file(lpath=temp_filename.name, rpath=filename) 82 | os.remove(temp_filename.name) 83 | else: 84 | if fs and fs.protocol == "file": 85 | filename = Path(fs._strip_protocol(filename)) 86 | else: 87 | filename = filename if isinstance(filename, Path) else Path(filename) 88 | if not filename.exists or not filename.is_file: 89 | raise ValueError( 90 | f"`filename` must be a valid file path, not {filename}." 91 | ) 92 | save_file(tensors=params, filename=filename.as_posix(), metadata=metadata) 93 | return filename 94 | 95 | return save(tensors=params, metadata=metadata) 96 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.9 3 | # by the following command: 4 | # 5 | # pip-compile --output-file=requirements/requirements.txt pyproject.toml 6 | # 7 | absl-py==1.3.0 8 | # via 9 | # chex 10 | # dm-haiku 11 | # optax 12 | # orbax 13 | # tensorboard 14 | attrs==22.2.0 15 | # via pytest 16 | cached-property==1.5.2 17 | # via orbax 18 | cachetools==5.2.0 19 | # via google-auth 20 | certifi==2022.12.7 21 | # via requests 22 | charset-normalizer==2.1.1 23 | # via requests 24 | chex==0.1.5 25 | # via optax 26 | commonmark==0.9.1 27 | # via rich 28 | contourpy==1.0.6 29 | # via matplotlib 30 | cycler==0.11.0 31 | # via matplotlib 32 | dm-haiku==0.0.9 33 | # via safejax (pyproject.toml) 34 | dm-tree==0.1.8 35 | # via chex 36 | 
etils==0.9.0 37 | # via orbax 38 | exceptiongroup==1.1.0 39 | # via pytest 40 | flax==0.6.3 41 | # via 42 | # orbax 43 | # safejax (pyproject.toml) 44 | fonttools==4.38.0 45 | # via matplotlib 46 | google-auth==2.15.0 47 | # via 48 | # google-auth-oauthlib 49 | # tensorboard 50 | google-auth-oauthlib==0.4.6 51 | # via tensorboard 52 | grpcio==1.51.1 53 | # via tensorboard 54 | idna==3.4 55 | # via requests 56 | importlib-metadata==5.2.0 57 | # via markdown 58 | importlib-resources==5.10.1 59 | # via orbax 60 | iniconfig==1.1.1 61 | # via pytest 62 | jax==0.3.25 63 | # via 64 | # chex 65 | # flax 66 | # objax 67 | # optax 68 | # orbax 69 | # safejax (pyproject.toml) 70 | jaxlib==0.3.25 71 | # via 72 | # chex 73 | # objax 74 | # optax 75 | # orbax 76 | # safejax (pyproject.toml) 77 | jmp==0.0.2 78 | # via dm-haiku 79 | kiwisolver==1.4.4 80 | # via matplotlib 81 | markdown==3.4.1 82 | # via tensorboard 83 | markupsafe==2.1.1 84 | # via werkzeug 85 | matplotlib==3.6.2 86 | # via flax 87 | msgpack==1.0.4 88 | # via flax 89 | numpy==1.24.1 90 | # via 91 | # chex 92 | # contourpy 93 | # dm-haiku 94 | # flax 95 | # jax 96 | # jaxlib 97 | # jmp 98 | # matplotlib 99 | # objax 100 | # opt-einsum 101 | # optax 102 | # orbax 103 | # scipy 104 | # tensorboard 105 | # tensorstore 106 | oauthlib==3.2.2 107 | # via requests-oauthlib 108 | objax==1.6.0 109 | # via safejax (pyproject.toml) 110 | opt-einsum==3.3.0 111 | # via jax 112 | optax==0.1.4 113 | # via flax 114 | orbax==0.0.23 115 | # via flax 116 | packaging==22.0 117 | # via 118 | # matplotlib 119 | # pytest 120 | parameterized==0.8.1 121 | # via objax 122 | pillow==9.3.0 123 | # via 124 | # matplotlib 125 | # objax 126 | pluggy==1.0.0 127 | # via pytest 128 | protobuf==3.20.3 129 | # via tensorboard 130 | pyasn1==0.4.8 131 | # via 132 | # pyasn1-modules 133 | # rsa 134 | pyasn1-modules==0.2.8 135 | # via google-auth 136 | pygments==2.13.0 137 | # via rich 138 | pyparsing==3.0.9 139 | # via matplotlib 140 | pytest==7.2.0 141 
| # via orbax 142 | python-dateutil==2.8.2 143 | # via matplotlib 144 | pyyaml==6.0 145 | # via 146 | # flax 147 | # orbax 148 | requests==2.28.1 149 | # via 150 | # requests-oauthlib 151 | # tensorboard 152 | requests-oauthlib==1.3.1 153 | # via google-auth-oauthlib 154 | rich==12.6.0 155 | # via flax 156 | rsa==4.9 157 | # via google-auth 158 | safetensors==0.2.6 159 | # via safejax (pyproject.toml) 160 | scipy==1.9.3 161 | # via 162 | # jax 163 | # jaxlib 164 | # objax 165 | six==1.16.0 166 | # via 167 | # google-auth 168 | # python-dateutil 169 | tabulate==0.9.0 170 | # via dm-haiku 171 | tensorboard==2.11.0 172 | # via objax 173 | tensorboard-data-server==0.6.1 174 | # via tensorboard 175 | tensorboard-plugin-wit==1.8.1 176 | # via tensorboard 177 | tensorstore==0.1.28 178 | # via 179 | # flax 180 | # orbax 181 | tomli==2.0.1 182 | # via pytest 183 | toolz==0.12.0 184 | # via chex 185 | typing-extensions==4.4.0 186 | # via 187 | # flax 188 | # jax 189 | # optax 190 | urllib3==1.26.13 191 | # via requests 192 | werkzeug==2.2.2 193 | # via tensorboard 194 | wheel==0.38.4 195 | # via tensorboard 196 | zipp==3.11.0 197 | # via 198 | # importlib-metadata 199 | # importlib-resources 200 | 201 | # The following packages are considered to be unsafe in a requirements file: 202 | # setuptools 203 | -------------------------------------------------------------------------------- /requirements/requirements-test.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.9 3 | # by the following command: 4 | # 5 | # pip-compile --extra=test --output-file=requirements/requirements-test.txt pyproject.toml 6 | # 7 | absl-py==1.3.0 8 | # via 9 | # chex 10 | # dm-haiku 11 | # optax 12 | # orbax 13 | # tensorboard 14 | attrs==22.2.0 15 | # via pytest 16 | cached-property==1.5.2 17 | # via orbax 18 | cachetools==5.2.0 19 | # via google-auth 20 | certifi==2022.12.7 21 | # via requests 22 
| charset-normalizer==2.1.1 23 | # via requests 24 | chex==0.1.5 25 | # via optax 26 | commonmark==0.9.1 27 | # via rich 28 | contourpy==1.0.6 29 | # via matplotlib 30 | cycler==0.11.0 31 | # via matplotlib 32 | dm-haiku==0.0.9 33 | # via safejax (pyproject.toml) 34 | dm-tree==0.1.8 35 | # via chex 36 | etils==0.9.0 37 | # via orbax 38 | exceptiongroup==1.1.0 39 | # via pytest 40 | flax==0.6.3 41 | # via 42 | # orbax 43 | # safejax (pyproject.toml) 44 | fonttools==4.38.0 45 | # via matplotlib 46 | google-auth==2.15.0 47 | # via 48 | # google-auth-oauthlib 49 | # tensorboard 50 | google-auth-oauthlib==0.4.6 51 | # via tensorboard 52 | grpcio==1.51.1 53 | # via tensorboard 54 | idna==3.4 55 | # via requests 56 | importlib-metadata==5.2.0 57 | # via markdown 58 | importlib-resources==5.10.1 59 | # via orbax 60 | iniconfig==1.1.1 61 | # via pytest 62 | jax==0.3.25 63 | # via 64 | # chex 65 | # flax 66 | # objax 67 | # optax 68 | # orbax 69 | # safejax (pyproject.toml) 70 | jaxlib==0.3.25 71 | # via 72 | # chex 73 | # objax 74 | # optax 75 | # orbax 76 | # safejax (pyproject.toml) 77 | jmp==0.0.2 78 | # via dm-haiku 79 | kiwisolver==1.4.4 80 | # via matplotlib 81 | markdown==3.4.1 82 | # via tensorboard 83 | markupsafe==2.1.1 84 | # via werkzeug 85 | matplotlib==3.6.2 86 | # via flax 87 | msgpack==1.0.4 88 | # via flax 89 | numpy==1.24.1 90 | # via 91 | # chex 92 | # contourpy 93 | # dm-haiku 94 | # flax 95 | # jax 96 | # jaxlib 97 | # jmp 98 | # matplotlib 99 | # objax 100 | # opt-einsum 101 | # optax 102 | # orbax 103 | # scipy 104 | # tensorboard 105 | # tensorstore 106 | oauthlib==3.2.2 107 | # via requests-oauthlib 108 | objax==1.6.0 109 | # via safejax (pyproject.toml) 110 | opt-einsum==3.3.0 111 | # via jax 112 | optax==0.1.4 113 | # via flax 114 | orbax==0.0.23 115 | # via flax 116 | packaging==22.0 117 | # via 118 | # matplotlib 119 | # pytest 120 | parameterized==0.8.1 121 | # via objax 122 | pillow==9.3.0 123 | # via 124 | # matplotlib 125 | # objax 126 | 
pluggy==1.0.0 127 | # via pytest 128 | protobuf==3.20.3 129 | # via tensorboard 130 | pyasn1==0.4.8 131 | # via 132 | # pyasn1-modules 133 | # rsa 134 | pyasn1-modules==0.2.8 135 | # via google-auth 136 | pygments==2.13.0 137 | # via rich 138 | pyparsing==3.0.9 139 | # via matplotlib 140 | pytest==7.2.0 141 | # via orbax 142 | python-dateutil==2.8.2 143 | # via matplotlib 144 | pyyaml==6.0 145 | # via 146 | # flax 147 | # orbax 148 | requests==2.28.1 149 | # via 150 | # requests-oauthlib 151 | # tensorboard 152 | requests-oauthlib==1.3.1 153 | # via google-auth-oauthlib 154 | rich==12.6.0 155 | # via flax 156 | rsa==4.9 157 | # via google-auth 158 | safetensors==0.2.6 159 | # via safejax (pyproject.toml) 160 | scipy==1.9.3 161 | # via 162 | # jax 163 | # jaxlib 164 | # objax 165 | six==1.16.0 166 | # via 167 | # google-auth 168 | # python-dateutil 169 | tabulate==0.9.0 170 | # via dm-haiku 171 | tensorboard==2.11.0 172 | # via objax 173 | tensorboard-data-server==0.6.1 174 | # via tensorboard 175 | tensorboard-plugin-wit==1.8.1 176 | # via tensorboard 177 | tensorstore==0.1.28 178 | # via 179 | # flax 180 | # orbax 181 | tomli==2.0.1 182 | # via pytest 183 | toolz==0.12.0 184 | # via chex 185 | typing-extensions==4.4.0 186 | # via 187 | # flax 188 | # jax 189 | # optax 190 | urllib3==1.26.13 191 | # via requests 192 | werkzeug==2.2.2 193 | # via tensorboard 194 | wheel==0.38.4 195 | # via tensorboard 196 | zipp==3.11.0 197 | # via 198 | # importlib-metadata 199 | # importlib-resources 200 | 201 | # The following packages are considered to be unsafe in a requirements file: 202 | # setuptools 203 | -------------------------------------------------------------------------------- /requirements/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.9 3 | # by the following command: 4 | # 5 | # pip-compile --extra=quality 
--output-file=requirements/requirements-dev.txt pyproject.toml 6 | # 7 | absl-py==1.3.0 8 | # via 9 | # chex 10 | # dm-haiku 11 | # optax 12 | # orbax 13 | # tensorboard 14 | attrs==22.2.0 15 | # via pytest 16 | black==22.10.0 17 | # via safejax (pyproject.toml) 18 | build==0.9.0 19 | # via pip-tools 20 | cached-property==1.5.2 21 | # via orbax 22 | cachetools==5.2.0 23 | # via google-auth 24 | certifi==2022.12.7 25 | # via requests 26 | cfgv==3.3.1 27 | # via pre-commit 28 | charset-normalizer==2.1.1 29 | # via requests 30 | chex==0.1.5 31 | # via optax 32 | click==8.1.3 33 | # via 34 | # black 35 | # pip-tools 36 | commonmark==0.9.1 37 | # via rich 38 | contourpy==1.0.6 39 | # via matplotlib 40 | cycler==0.11.0 41 | # via matplotlib 42 | distlib==0.3.6 43 | # via virtualenv 44 | dm-haiku==0.0.9 45 | # via safejax (pyproject.toml) 46 | dm-tree==0.1.8 47 | # via chex 48 | etils==0.9.0 49 | # via orbax 50 | exceptiongroup==1.1.0 51 | # via pytest 52 | filelock==3.8.2 53 | # via virtualenv 54 | flax==0.6.3 55 | # via 56 | # orbax 57 | # safejax (pyproject.toml) 58 | fonttools==4.38.0 59 | # via matplotlib 60 | google-auth==2.15.0 61 | # via 62 | # google-auth-oauthlib 63 | # tensorboard 64 | google-auth-oauthlib==0.4.6 65 | # via tensorboard 66 | grpcio==1.51.1 67 | # via tensorboard 68 | identify==2.5.11 69 | # via pre-commit 70 | idna==3.4 71 | # via requests 72 | importlib-metadata==5.2.0 73 | # via markdown 74 | importlib-resources==5.10.1 75 | # via orbax 76 | iniconfig==1.1.1 77 | # via pytest 78 | jax==0.3.25 79 | # via 80 | # chex 81 | # flax 82 | # objax 83 | # optax 84 | # orbax 85 | # safejax (pyproject.toml) 86 | jaxlib==0.3.25 87 | # via 88 | # chex 89 | # objax 90 | # optax 91 | # orbax 92 | # safejax (pyproject.toml) 93 | jmp==0.0.2 94 | # via dm-haiku 95 | kiwisolver==1.4.4 96 | # via matplotlib 97 | markdown==3.4.1 98 | # via tensorboard 99 | markupsafe==2.1.1 100 | # via werkzeug 101 | matplotlib==3.6.2 102 | # via flax 103 | msgpack==1.0.4 104 | # 
via flax 105 | mypy-extensions==0.4.3 106 | # via black 107 | nodeenv==1.7.0 108 | # via pre-commit 109 | numpy==1.24.1 110 | # via 111 | # chex 112 | # contourpy 113 | # dm-haiku 114 | # flax 115 | # jax 116 | # jaxlib 117 | # jmp 118 | # matplotlib 119 | # objax 120 | # opt-einsum 121 | # optax 122 | # orbax 123 | # scipy 124 | # tensorboard 125 | # tensorstore 126 | oauthlib==3.2.2 127 | # via requests-oauthlib 128 | objax==1.6.0 129 | # via safejax (pyproject.toml) 130 | opt-einsum==3.3.0 131 | # via jax 132 | optax==0.1.4 133 | # via flax 134 | orbax==0.0.23 135 | # via flax 136 | packaging==22.0 137 | # via 138 | # build 139 | # matplotlib 140 | # pytest 141 | parameterized==0.8.1 142 | # via objax 143 | pathspec==0.10.3 144 | # via black 145 | pep517==0.13.0 146 | # via build 147 | pillow==9.3.0 148 | # via 149 | # matplotlib 150 | # objax 151 | pip-tools==6.12.1 152 | # via safejax (pyproject.toml) 153 | platformdirs==2.6.0 154 | # via 155 | # black 156 | # virtualenv 157 | pluggy==1.0.0 158 | # via pytest 159 | pre-commit==2.20.0 160 | # via safejax (pyproject.toml) 161 | protobuf==3.20.3 162 | # via tensorboard 163 | pyasn1==0.4.8 164 | # via 165 | # pyasn1-modules 166 | # rsa 167 | pyasn1-modules==0.2.8 168 | # via google-auth 169 | pygments==2.13.0 170 | # via rich 171 | pyparsing==3.0.9 172 | # via matplotlib 173 | pytest==7.2.0 174 | # via orbax 175 | python-dateutil==2.8.2 176 | # via matplotlib 177 | pyyaml==6.0 178 | # via 179 | # flax 180 | # orbax 181 | # pre-commit 182 | requests==2.28.1 183 | # via 184 | # requests-oauthlib 185 | # tensorboard 186 | requests-oauthlib==1.3.1 187 | # via google-auth-oauthlib 188 | rich==12.6.0 189 | # via flax 190 | rsa==4.9 191 | # via google-auth 192 | ruff==0.0.195 193 | # via safejax (pyproject.toml) 194 | safetensors==0.2.6 195 | # via safejax (pyproject.toml) 196 | scipy==1.9.3 197 | # via 198 | # jax 199 | # jaxlib 200 | # objax 201 | six==1.16.0 202 | # via 203 | # google-auth 204 | # python-dateutil 205 
| tabulate==0.9.0 206 | # via dm-haiku 207 | tensorboard==2.11.0 208 | # via objax 209 | tensorboard-data-server==0.6.1 210 | # via tensorboard 211 | tensorboard-plugin-wit==1.8.1 212 | # via tensorboard 213 | tensorstore==0.1.28 214 | # via 215 | # flax 216 | # orbax 217 | toml==0.10.2 218 | # via pre-commit 219 | tomli==2.0.1 220 | # via 221 | # black 222 | # build 223 | # pep517 224 | # pytest 225 | toolz==0.12.0 226 | # via chex 227 | typing-extensions==4.4.0 228 | # via 229 | # black 230 | # flax 231 | # jax 232 | # optax 233 | urllib3==1.26.13 234 | # via requests 235 | virtualenv==20.17.1 236 | # via pre-commit 237 | werkzeug==2.2.2 238 | # via tensorboard 239 | wheel==0.38.4 240 | # via 241 | # pip-tools 242 | # tensorboard 243 | zipp==3.11.0 244 | # via 245 | # importlib-metadata 246 | # importlib-resources 247 | 248 | # The following packages are considered to be unsafe in a requirements file: 249 | # pip 250 | # setuptools 251 | -------------------------------------------------------------------------------- /tests/test_flax.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze 5 | from flax.serialization import ( 6 | from_bytes, 7 | from_state_dict, 8 | msgpack_restore, 9 | msgpack_serialize, 10 | to_bytes, 11 | to_state_dict, 12 | ) 13 | 14 | from safejax.flax import deserialize, serialize 15 | from safejax.typing import FlaxParams 16 | 17 | from .utils import assert_over_trees 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "params", 22 | [ 23 | pytest.lazy_fixture("flax_single_layer_params"), 24 | pytest.lazy_fixture("flax_resnet50_params"), 25 | ], 26 | ) 27 | def test_partial_deserialize(params: FlaxParams) -> None: 28 | encoded_params = serialize(params=params) 29 | assert isinstance(encoded_params, bytes) 30 | assert len(encoded_params) > 0 31 | 32 | decoded_params = 
deserialize(path_or_buf=encoded_params) 33 | assert isinstance(decoded_params, FrozenDict) 34 | assert len(decoded_params) > 0 35 | assert id(decoded_params) != id(params) 36 | assert decoded_params.keys() == params.keys() 37 | 38 | assert_over_trees(params=params, decoded_params=decoded_params) 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "params", 43 | [ 44 | pytest.lazy_fixture("flax_single_layer_params"), 45 | pytest.lazy_fixture("flax_resnet50_params"), 46 | ], 47 | ) 48 | @pytest.mark.usefixtures("safetensors_file") 49 | def test_partial_deserialize_from_file( 50 | params: FlaxParams, safetensors_file: Path 51 | ) -> None: 52 | safetensors_file = serialize(params=params, filename=safetensors_file) 53 | assert isinstance(safetensors_file, Path) 54 | assert safetensors_file.exists() 55 | 56 | decoded_params = deserialize(path_or_buf=safetensors_file) 57 | assert isinstance(decoded_params, FrozenDict) 58 | assert len(decoded_params) > 0 59 | assert id(decoded_params) != id(params) 60 | assert decoded_params.keys() == params.keys() 61 | 62 | assert_over_trees(params=params, decoded_params=decoded_params) 63 | 64 | 65 | @pytest.mark.parametrize( 66 | "params", 67 | [ 68 | pytest.lazy_fixture("flax_single_layer_params"), 69 | pytest.lazy_fixture("flax_resnet50_params"), 70 | ], 71 | ) 72 | @pytest.mark.usefixtures("safetensors_file", "msgpack_file") 73 | def test_safejax_and_msgpack( 74 | params: FlaxParams, safetensors_file: Path, msgpack_file: Path 75 | ) -> None: 76 | safetensors_file = serialize(params=params, filename=safetensors_file) 77 | assert isinstance(safetensors_file, Path) 78 | assert safetensors_file.exists() 79 | 80 | safetensors_decoded_params = deserialize(path_or_buf=safetensors_file) 81 | assert isinstance(safetensors_decoded_params, FrozenDict) 82 | assert len(safetensors_decoded_params) > 0 83 | assert id(safetensors_decoded_params) != id(params) 84 | assert safetensors_decoded_params.keys() == params.keys() 85 | 86 | with open(msgpack_file, 
mode="wb") as f: 87 | f.write(msgpack_serialize(unfreeze(params))) 88 | 89 | with open(msgpack_file, "rb") as f: 90 | msgpack_decoded_params = freeze(msgpack_restore(f.read())) 91 | 92 | assert isinstance(msgpack_decoded_params, type(params)) 93 | assert len(msgpack_decoded_params) > 0 94 | assert id(msgpack_decoded_params) != id(params) 95 | assert msgpack_decoded_params.keys() == params.keys() 96 | 97 | assert_over_trees(params=params, decoded_params=safetensors_decoded_params) 98 | assert_over_trees(params=params, decoded_params=msgpack_decoded_params) 99 | 100 | 101 | @pytest.mark.parametrize( 102 | "params", 103 | [ 104 | pytest.lazy_fixture("flax_single_layer_params"), 105 | pytest.lazy_fixture("flax_resnet50_params"), 106 | ], 107 | ) 108 | def test_safejax_and_msgpack_bytes(params: FlaxParams) -> None: 109 | encoded_params = serialize(params=params) 110 | assert isinstance(encoded_params, bytes) 111 | assert len(encoded_params) > 0 112 | 113 | safetensors_decoded_params = deserialize(path_or_buf=encoded_params) 114 | assert isinstance(safetensors_decoded_params, FrozenDict) 115 | assert len(safetensors_decoded_params) > 0 116 | assert id(safetensors_decoded_params) != id(params) 117 | assert safetensors_decoded_params.keys() == params.keys() 118 | 119 | msgpack_bytes_encoded_params = to_bytes(params) 120 | assert isinstance(msgpack_bytes_encoded_params, bytes) 121 | assert len(msgpack_bytes_encoded_params) > 0 122 | 123 | msgpack_bytes_decoded_params = freeze( 124 | from_bytes(params, msgpack_bytes_encoded_params) 125 | ) 126 | assert isinstance(msgpack_bytes_decoded_params, FrozenDict) 127 | assert len(msgpack_bytes_decoded_params) > 0 128 | assert id(msgpack_bytes_decoded_params) != id(params) 129 | assert msgpack_bytes_decoded_params.keys() == params.keys() 130 | 131 | assert_over_trees(params=params, decoded_params=safetensors_decoded_params) 132 | assert_over_trees(params=params, decoded_params=msgpack_bytes_decoded_params) 133 | 134 | 135 | 
@pytest.mark.parametrize( 136 | "params", 137 | [ 138 | pytest.lazy_fixture("flax_single_layer_params"), 139 | pytest.lazy_fixture("flax_resnet50_params"), 140 | ], 141 | ) 142 | def test_safejax_and_state_dict(params: FlaxParams) -> None: 143 | encoded_params = serialize(params=params) 144 | assert isinstance(encoded_params, bytes) 145 | assert len(encoded_params) > 0 146 | 147 | safetensors_decoded_params = deserialize(path_or_buf=encoded_params) 148 | assert isinstance(safetensors_decoded_params, FrozenDict) 149 | assert len(safetensors_decoded_params) > 0 150 | assert id(safetensors_decoded_params) != id(params) 151 | assert safetensors_decoded_params.keys() == params.keys() 152 | 153 | state_dict_encoded_params = to_state_dict(params) 154 | assert isinstance(state_dict_encoded_params, dict) 155 | assert len(state_dict_encoded_params) > 0 156 | 157 | state_dict_decoded_params = from_state_dict(params, state_dict_encoded_params) 158 | assert isinstance(state_dict_decoded_params, FrozenDict) 159 | assert len(state_dict_decoded_params) > 0 160 | assert id(state_dict_decoded_params) != id(params) 161 | assert state_dict_decoded_params.keys() == params.keys() 162 | 163 | assert_over_trees(params=params, decoded_params=safetensors_decoded_params) 164 | assert_over_trees(params=params, decoded_params=state_dict_decoded_params) 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔐 Serialize JAX, Flax, Haiku, or Objax model params with `safetensors` 2 | 3 | `safejax` is a Python package to serialize JAX, Flax, Haiku, or Objax model params using `safetensors` 4 | as the tensor storage format, instead of relying on `pickle`. For more details on why 5 | `safetensors` is safer than `pickle` please check [huggingface/safetensors](https://github.com/huggingface/safetensors). 
6 | 7 | Note that `safejax` supports the serialization of `jax`, `flax`, `dm-haiku`, and `objax` model 8 | parameters and has been tested with all those frameworks, but there may be some cases where it 9 | does not work as expected, as this is still in an early development phase, so please if you have 10 | any feedback or bug reports, open an issue at [safejax/issues](https://github.com/alvarobartt/safejax/issues). 11 | 12 | ## 🛠️ Requirements & Installation 13 | 14 | `safejax` requires Python 3.7 or above 15 | 16 | ```bash 17 | pip install safejax --upgrade 18 | ``` 19 | 20 | ## 💻 Usage 21 | 22 | ### `flax` 23 | 24 | * Convert `params` to `bytes` in memory 25 | 26 | ```python 27 | from safejax.flax import serialize, deserialize 28 | 29 | params = model.init(...) 30 | 31 | encoded_bytes = serialize(params) 32 | decoded_params = deserialize(encoded_bytes) 33 | 34 | model.apply(decoded_params, ...) 35 | ``` 36 | 37 | * Convert `params` to `bytes` in `params.safetensors` file 38 | 39 | ```python 40 | from safejax.flax import serialize, deserialize 41 | 42 | params = model.init(...) 43 | 44 | encoded_bytes = serialize(params, filename="./params.safetensors") 45 | decoded_params = deserialize("./params.safetensors") 46 | 47 | model.apply(decoded_params, ...) 48 | ``` 49 | 50 | --- 51 | 52 | ### `dm-haiku` 53 | 54 | * Just contains `params` 55 | 56 | ```python 57 | from safejax.haiku import serialize, deserialize 58 | 59 | params = model.init(...) 60 | 61 | encoded_bytes = serialize(params) 62 | decoded_params = deserialize(encoded_bytes) 63 | 64 | model.apply(decoded_params, ...) 65 | ``` 66 | 67 | * If it contains `params` and `state` e.g. ExponentialMovingAverage in BatchNorm 68 | 69 | ```python 70 | from safejax.haiku import serialize, deserialize 71 | 72 | params, state = model.init(...) 
73 | params_state = {"params": params, "state": state} 74 | 75 | encoded_bytes = serialize(params_state) 76 | decoded_params_state = deserialize(encoded_bytes) # .keys() contains `params` and `state` 77 | 78 | model.apply(decoded_params_state["params"], decoded_params_state["state"], ...) 79 | ``` 80 | 81 | * If it contains `params` and `state`, but we want to serialize those individually 82 | 83 | ```python 84 | from safejax.haiku import serialize, deserialize 85 | 86 | params, state = model.init(...) 87 | 88 | encoded_bytes = serialize(params) 89 | decoded_params = deserialize(encoded_bytes) 90 | 91 | encoded_bytes = serialize(state) 92 | decoded_state = deserialize(encoded_bytes) 93 | 94 | model.apply(decoded_params, decoded_state, ...) 95 | ``` 96 | 97 | --- 98 | 99 | ### `objax` 100 | 101 | * Convert `params` to `bytes` in memory, and convert back to `VarCollection` 102 | 103 | ```python 104 | from safejax.objax import serialize, deserialize 105 | 106 | params = model.vars() 107 | 108 | encoded_bytes = serialize(params=params) 109 | decoded_params = deserialize(encoded_bytes) 110 | 111 | for key, value in decoded_params.items(): 112 | if key in model.vars(): 113 | model.vars()[key].assign(value.value) 114 | 115 | model(...) 116 | ``` 117 | 118 | * Convert `params` to `bytes` in `params.safetensors` file 119 | 120 | ```python 121 | from safejax.objax import serialize, deserialize 122 | 123 | params = model.vars() 124 | 125 | encoded_bytes = serialize(params=params, filename="./params.safetensors") 126 | decoded_params = deserialize("./params.safetensors") 127 | 128 | for key, value in decoded_params.items(): 129 | if key in model.vars(): 130 | model.vars()[key].assign(value.value) 131 | 132 | model(...) 
133 | ``` 134 | 135 | * Convert `params` to `bytes` in `params.safetensors` and assign during deserialization 136 | 137 | ```python 138 | from safejax.objax import serialize, deserialize_with_assignment 139 | 140 | params = model.vars() 141 | 142 | encoded_bytes = serialize(params=params, filename="./params.safetensors") 143 | deserialize_with_assignment(filename="./params.safetensors", model_vars=params) 144 | 145 | model(...) 146 | ``` 147 | 148 | --- 149 | 150 | More detailed examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`. 151 | 152 | ## 🤔 Why `safejax`? 153 | 154 | `safetensors` defines an easy and fast (zero-copy) format to store tensors, 155 | while `pickle` has some known weaknesses and security issues. `safetensors` 156 | is also a storage format that is intended to be agnostic to the framework 157 | used to load the tensors. More in-depth information can be found at 158 | [huggingface/safetensors](https://github.com/huggingface/safetensors). 159 | 160 | `jax` uses `pytrees` to store the model parameters in memory, so 161 | they are held in a dictionary-like structure containing nested `jnp.DeviceArray` tensors. 162 | 163 | `dm-haiku` uses a custom dictionary whose keys are formatted as `<level_1>/~/<level_2>`, where the 164 | levels are the ones that define the tree structure and `/~/` is the separator between them, 165 | e.g. `res_net50/~/initial_conv`. The value under such a key is not a `jnp.DeviceArray`, but a 166 | dictionary with key-value pairs, e.g. `w` for the weights and `b` for the biases. 167 | 168 | `objax` defines a custom dictionary-like class named `VarCollection` that contains 169 | some variables inheriting from `BaseVar`, which is another custom `objax` type. 170 | 171 | `flax` defines a dictionary-like class named `FrozenDict` that is used to 172 | store the tensors in memory; it can be dumped either into `bytes` in `MessagePack` 173 | format or as a `state_dict`.
174 | 175 | There are no plans from HuggingFace to extend `safetensors` to support anything more than tensors 176 | e.g. `FrozenDict`s, see their response at [huggingface/safetensors/discussions/138](https://github.com/huggingface/safetensors/discussions/138). 177 | 178 | So the motivation to create `safejax` is to easily provide a way to serialize `FrozenDict`s 179 | using `safetensors` as the tensor storage format instead of `pickle`, as well as to provide 180 | a common and easy way to serialize and deserialize any JAX model params (Flax, Haiku, or Objax) 181 | using `safetensors` format. 182 | -------------------------------------------------------------------------------- /requirements/requirements-docs.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.9 3 | # by the following command: 4 | # 5 | # pip-compile --extra=docs --output-file=requirements/requirements-docs.txt pyproject.toml 6 | # 7 | absl-py==1.3.0 8 | # via 9 | # chex 10 | # dm-haiku 11 | # optax 12 | # orbax 13 | # tensorboard 14 | attrs==22.2.0 15 | # via pytest 16 | babel==2.11.0 17 | # via mkdocs-git-revision-date-localized-plugin 18 | cached-property==1.5.2 19 | # via orbax 20 | cachetools==5.2.0 21 | # via google-auth 22 | certifi==2022.12.7 23 | # via requests 24 | charset-normalizer==2.1.1 25 | # via requests 26 | chex==0.1.5 27 | # via optax 28 | click==8.1.3 29 | # via mkdocs 30 | colorama==0.4.6 31 | # via griffe 32 | commonmark==0.9.1 33 | # via rich 34 | contourpy==1.0.6 35 | # via matplotlib 36 | cycler==0.11.0 37 | # via matplotlib 38 | dm-haiku==0.0.9 39 | # via safejax (pyproject.toml) 40 | dm-tree==0.1.8 41 | # via chex 42 | etils==0.9.0 43 | # via orbax 44 | flax==0.6.3 45 | # via 46 | # orbax 47 | # safejax (pyproject.toml) 48 | fonttools==4.38.0 49 | # via matplotlib 50 | ghp-import==2.1.0 51 | # via mkdocs 52 | gitdb==4.0.10 53 | # via gitpython 54 | gitpython==3.1.29 55 | # 
via mkdocs-git-revision-date-localized-plugin 56 | google-auth==2.15.0 57 | # via 58 | # google-auth-oauthlib 59 | # tensorboard 60 | google-auth-oauthlib==0.4.6 61 | # via tensorboard 62 | griffe==0.25.2 63 | # via mkdocstrings-python 64 | grpcio==1.51.1 65 | # via tensorboard 66 | idna==3.4 67 | # via requests 68 | importlib-metadata==5.2.0 69 | # via 70 | # markdown 71 | # mkdocs 72 | importlib-resources==5.10.1 73 | # via orbax 74 | iniconfig==1.1.1 75 | # via pytest 76 | jax==0.3.25 77 | # via 78 | # chex 79 | # flax 80 | # objax 81 | # optax 82 | # orbax 83 | # safejax (pyproject.toml) 84 | jaxlib==0.3.25 85 | # via 86 | # chex 87 | # objax 88 | # optax 89 | # orbax 90 | # safejax (pyproject.toml) 91 | jinja2==3.1.2 92 | # via 93 | # mkdocs 94 | # mkdocs-material 95 | # mkdocstrings 96 | jmp==0.0.2 97 | # via dm-haiku 98 | kiwisolver==1.4.4 99 | # via matplotlib 100 | markdown==3.3.7 101 | # via 102 | # mkdocs 103 | # mkdocs-autorefs 104 | # mkdocs-material 105 | # mkdocstrings 106 | # pymdown-extensions 107 | # tensorboard 108 | markupsafe==2.1.1 109 | # via 110 | # jinja2 111 | # mkdocstrings 112 | # werkzeug 113 | matplotlib==3.6.2 114 | # via flax 115 | mergedeep==1.3.4 116 | # via mkdocs 117 | mkdocs==1.4.2 118 | # via 119 | # mkdocs-autorefs 120 | # mkdocs-git-revision-date-localized-plugin 121 | # mkdocs-material 122 | # mkdocstrings 123 | # safejax (pyproject.toml) 124 | mkdocs-autorefs==0.4.1 125 | # via mkdocstrings 126 | mkdocs-git-revision-date-localized-plugin==1.1.0 127 | # via safejax (pyproject.toml) 128 | mkdocs-material==8.5.11 129 | # via safejax (pyproject.toml) 130 | mkdocs-material-extensions==1.1.1 131 | # via mkdocs-material 132 | mkdocstrings[python]==0.19.1 133 | # via 134 | # mkdocstrings-python 135 | # safejax (pyproject.toml) 136 | mkdocstrings-python==0.8.2 137 | # via mkdocstrings 138 | msgpack==1.0.4 139 | # via flax 140 | numpy==1.24.1 141 | # via 142 | # chex 143 | # contourpy 144 | # dm-haiku 145 | # flax 146 | # jax 147 | # 
jaxlib 148 | # jmp 149 | # matplotlib 150 | # objax 151 | # opt-einsum 152 | # optax 153 | # orbax 154 | # scipy 155 | # tensorboard 156 | # tensorstore 157 | oauthlib==3.2.2 158 | # via requests-oauthlib 159 | objax==1.6.0 160 | # via safejax (pyproject.toml) 161 | opt-einsum==3.3.0 162 | # via jax 163 | optax==0.1.4 164 | # via flax 165 | orbax==0.0.23 166 | # via flax 167 | packaging==22.0 168 | # via 169 | # matplotlib 170 | # mkdocs 171 | # pytest 172 | parameterized==0.8.1 173 | # via objax 174 | pillow==9.3.0 175 | # via 176 | # matplotlib 177 | # objax 178 | pluggy==1.0.0 179 | # via pytest 180 | protobuf==3.20.3 181 | # via tensorboard 182 | py==1.11.0 183 | # via pytest 184 | pyasn1==0.4.8 185 | # via 186 | # pyasn1-modules 187 | # rsa 188 | pyasn1-modules==0.2.8 189 | # via google-auth 190 | pygments==2.13.0 191 | # via 192 | # mkdocs-material 193 | # rich 194 | pymdown-extensions==9.9 195 | # via 196 | # mkdocs-material 197 | # mkdocstrings 198 | pyparsing==3.0.9 199 | # via matplotlib 200 | pytest==7.1.3 201 | # via orbax 202 | python-dateutil==2.8.2 203 | # via 204 | # ghp-import 205 | # matplotlib 206 | pytz==2022.7 207 | # via babel 208 | pyyaml==6.0 209 | # via 210 | # flax 211 | # mkdocs 212 | # orbax 213 | # pyyaml-env-tag 214 | pyyaml-env-tag==0.1 215 | # via mkdocs 216 | requests==2.28.1 217 | # via 218 | # mkdocs-material 219 | # requests-oauthlib 220 | # tensorboard 221 | requests-oauthlib==1.3.1 222 | # via google-auth-oauthlib 223 | rich==12.6.0 224 | # via flax 225 | rsa==4.9 226 | # via google-auth 227 | safetensors==0.2.6 228 | # via safejax (pyproject.toml) 229 | scipy==1.9.3 230 | # via 231 | # jax 232 | # jaxlib 233 | # objax 234 | six==1.16.0 235 | # via 236 | # google-auth 237 | # python-dateutil 238 | smmap==5.0.0 239 | # via gitdb 240 | tabulate==0.9.0 241 | # via dm-haiku 242 | tensorboard==2.11.0 243 | # via objax 244 | tensorboard-data-server==0.6.1 245 | # via tensorboard 246 | tensorboard-plugin-wit==1.8.1 247 | # via 
tensorboard 248 | tensorstore==0.1.28 249 | # via 250 | # flax 251 | # orbax 252 | tomli==2.0.1 253 | # via pytest 254 | toolz==0.12.0 255 | # via chex 256 | typing-extensions==4.4.0 257 | # via 258 | # flax 259 | # jax 260 | # optax 261 | urllib3==1.26.13 262 | # via requests 263 | watchdog==2.2.0 264 | # via mkdocs 265 | werkzeug==2.2.2 266 | # via tensorboard 267 | wheel==0.38.4 268 | # via tensorboard 269 | zipp==3.11.0 270 | # via 271 | # importlib-metadata 272 | # importlib-resources 273 | 274 | # The following packages are considered to be unsafe in a requirements file: 275 | # setuptools 276 | -------------------------------------------------------------------------------- /src/safejax/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, Union 3 | 4 | import numpy as np 5 | from flax.core.frozen_dict import FrozenDict 6 | from jax import numpy as jnp 7 | from objax.variable import BaseState, BaseVar, RandomState, StateVar, TrainRef, TrainVar 8 | 9 | from safejax.typing import JaxDeviceArrayDict, NumpyArrayDict, ObjaxDict, ParamsDictLike 10 | 11 | OBJAX_VARIABLES = { 12 | "BaseVar": BaseVar, 13 | "BaseState": BaseState, 14 | "RandomState": RandomState, 15 | "TrainRef": TrainRef, 16 | "StateVar": StateVar, 17 | "TrainVar": TrainVar, 18 | } 19 | OBJAX_VARIABLE_SEPARATOR = "::" 20 | 21 | 22 | def flatten_dict( 23 | params: ParamsDictLike, 24 | key_prefix: Union[str, None] = None, 25 | include_objax_variables: bool = False, 26 | ) -> Union[NumpyArrayDict, JaxDeviceArrayDict]: 27 | """ 28 | Flatten a `Dict`, `FrozenDict`, or `VarCollection`, for more detailed information on 29 | the supported input types check `safejax.typing.ParamsDictLike`. 30 | 31 | Note: 32 | This function is recursive to explore all the nested dictionaries, 33 | and the keys are being flattened using the `.` character. 
So that the 34 | later de-nesting can be done using the `.` character as a separator. 35 | 36 | Reference at https://gist.github.com/Narsil/d5b0d747e5c8c299eb6d82709e480e3d 37 | 38 | Args: 39 | params: A `Dict`, `FrozenDict`, or `VarCollection` with the params to flatten. 40 | key_prefix: A prefix to prepend to the keys of the flattened dictionary. 41 | include_objax_variables: 42 | A boolean indicating whether to include the `objax.variable` types in 43 | the keys of the flattened dictionary. 44 | 45 | Returns: 46 | A `Dict` containing the flattened params as level-1 key-value pairs. 47 | """ 48 | flattened_params = {} 49 | for key, value in params.items(): 50 | key = f"{key_prefix}.{key}" if key_prefix else key 51 | if isinstance(value, (BaseVar, BaseState)): 52 | if include_objax_variables: 53 | key = f"{key}{OBJAX_VARIABLE_SEPARATOR}{type(value).__name__}" 54 | value = value.value 55 | if isinstance(value, (jnp.DeviceArray, np.ndarray)): 56 | flattened_params[key] = value 57 | continue 58 | if isinstance(value, (Dict, FrozenDict)): 59 | flattened_params.update( 60 | flatten_dict( 61 | params=value, 62 | key_prefix=key, 63 | include_objax_variables=include_objax_variables, 64 | ) 65 | ) 66 | return flattened_params 67 | 68 | 69 | def unflatten_dict(params: Union[NumpyArrayDict, JaxDeviceArrayDict]) -> Dict[str, Any]: 70 | """ 71 | Unflatten a `Dict` where the keys should be expanded using the `.` character 72 | as a separator. 73 | 74 | Note: 75 | If the params where serialized from a `VarCollection` object, then the 76 | `objax.variable` types are included in the keys, and since this function 77 | just unflattens the dictionary without `objax.variable` casting, those 78 | variables will be ignored and unflattened normally. Anyway, when deserializing 79 | `objax` models you should use `safejax.objax.deserialize` or just use the 80 | function params in `safejax.deserialize`: `requires_unflattening=False` and 81 | `to_var_collection=True`. 
82 | 83 | Reference at https://stackoverflow.com/a/63545677. 84 | 85 | Args: 86 | params: A `Dict` containing the params to unflatten by expanding the keys. 87 | 88 | Returns: 89 | An unflattened `Dict` where the keys are expanded using the `.` character. 90 | """ 91 | unflattened_params = {} 92 | warned_user = False 93 | for key, value in params.items(): 94 | unflattened_params_tmp = unflattened_params 95 | if not warned_user and OBJAX_VARIABLE_SEPARATOR in key: 96 | warnings.warn( 97 | "The params were serialized from a `VarCollection` object, " 98 | "so the `objax.variable` types are included in the keys, " 99 | "and since this function just unflattens the dictionary " 100 | "without `objax.variable` casting, those variables will be " 101 | "ignored and unflattened normally. Anyway, when deserializing " 102 | "`objax` models you should use `safejax.objax.deserialize` " 103 | "or just use the function params in `safejax.deserialize`: " 104 | "`requires_unflattening=False` and `to_var_collection=True`." 105 | ) 106 | warned_user = True 107 | key = ( 108 | key.split(OBJAX_VARIABLE_SEPARATOR)[0] 109 | if OBJAX_VARIABLE_SEPARATOR in key 110 | else key 111 | ) 112 | subkeys = key.split(".") 113 | for subkey in subkeys[:-1]: 114 | unflattened_params_tmp = unflattened_params_tmp.setdefault(subkey, {}) 115 | unflattened_params_tmp[subkeys[-1]] = value 116 | return unflattened_params 117 | 118 | 119 | def cast_objax_variables( 120 | params: JaxDeviceArrayDict, 121 | ) -> Union[JaxDeviceArrayDict, ObjaxDict]: 122 | """ 123 | Cast the `jnp.DeviceArray` to their corresponding `objax.variable` types. 124 | 125 | Note: 126 | This function may return the same `params` if no `objax.variable` types 127 | are found in the keys. 128 | 129 | Args: 130 | params: A `Dict` containing the params to cast. 131 | 132 | Raises: 133 | ValueError: If the params were not serialized from a `VarCollection` object. 
134 | 135 | Returns: 136 | A `Dict` containing the keys without the variable name, and the values 137 | with the `objax.variable` objects with `.value` assigned from the 138 | `jnp.DeviceArray`. 139 | """ 140 | casted_params = {} 141 | for key, value in params.items(): 142 | if OBJAX_VARIABLE_SEPARATOR not in key: 143 | raise ValueError( 144 | "The params were not serialized from a `VarCollection` object, since" 145 | " the type has not been included as part of the key using" 146 | f" `{OBJAX_VARIABLE_SEPARATOR}` as separator at the end of the key." 147 | " Returning the same params without casting the `jnp.DeviceArray` to" 148 | " `objax.variable` types." 149 | ) 150 | key, objax_var_type = key.split(OBJAX_VARIABLE_SEPARATOR) 151 | casted_params[key] = OBJAX_VARIABLES[objax_var_type](value) 152 | return casted_params 153 | -------------------------------------------------------------------------------- /tests/test_core_load.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, Union 3 | 4 | import pytest 5 | from flax.core.frozen_dict import FrozenDict 6 | from fsspec.spec import AbstractFileSystem 7 | from objax.variable import VarCollection 8 | 9 | from safejax.core.load import deserialize 10 | from safejax.core.save import serialize 11 | from safejax.typing import ParamsDictLike 12 | 13 | from .utils import assert_over_trees 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "params, serialize_kwargs, deserialize_kwargs, expected_output_type", 18 | [ 19 | ( 20 | pytest.lazy_fixture("flax_resnet50_params"), 21 | {}, 22 | {"freeze_dict": True}, 23 | FrozenDict, 24 | ), 25 | (pytest.lazy_fixture("flax_resnet50_params"), {}, {"freeze_dict": False}, dict), 26 | ( 27 | pytest.lazy_fixture("objax_resnet50_params"), 28 | {"include_objax_variables": True}, 29 | {"requires_unflattening": False, "to_var_collection": True}, 30 | VarCollection, 31 | ), 32 | ( 33 | 
pytest.lazy_fixture("objax_resnet50_params"), 34 | {"include_objax_variables": False}, 35 | {"requires_unflattening": False, "to_var_collection": False}, 36 | dict, 37 | ), 38 | (pytest.lazy_fixture("haiku_resnet50_params"), {}, {}, dict), 39 | ], 40 | ) 41 | def test_deserialize( 42 | params: ParamsDictLike, 43 | serialize_kwargs: Dict[str, Any], 44 | deserialize_kwargs: Dict[str, Any], 45 | expected_output_type: Union[dict, FrozenDict, VarCollection], 46 | ) -> None: 47 | encoded_params = serialize(params=params, **serialize_kwargs) 48 | decoded_params = deserialize(path_or_buf=encoded_params, **deserialize_kwargs) 49 | assert isinstance(decoded_params, expected_output_type) 50 | assert len(decoded_params) > 0 51 | assert id(decoded_params) != id(params) 52 | assert decoded_params.keys() == params.keys() 53 | 54 | assert_over_trees(params=params, decoded_params=decoded_params) 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "params, serialize_kwargs, deserialize_kwargs, expected_output_type", 59 | [ 60 | ( 61 | pytest.lazy_fixture("flax_resnet50_params"), 62 | {}, 63 | {"freeze_dict": True}, 64 | FrozenDict, 65 | ), 66 | (pytest.lazy_fixture("flax_resnet50_params"), {}, {"freeze_dict": False}, dict), 67 | ( 68 | pytest.lazy_fixture("objax_resnet50_params"), 69 | {"include_objax_variables": True}, 70 | {"requires_unflattening": False, "to_var_collection": True}, 71 | VarCollection, 72 | ), 73 | ( 74 | pytest.lazy_fixture("objax_resnet50_params"), 75 | {"include_objax_variables": False}, 76 | {"requires_unflattening": False, "to_var_collection": False}, 77 | dict, 78 | ), 79 | (pytest.lazy_fixture("haiku_resnet50_params"), {}, {}, dict), 80 | ], 81 | ) 82 | @pytest.mark.usefixtures("safetensors_file") 83 | def test_deserialize_from_file( 84 | params: ParamsDictLike, 85 | serialize_kwargs: Dict[str, Any], 86 | deserialize_kwargs: Dict[str, Any], 87 | expected_output_type: Union[dict, FrozenDict, VarCollection], 88 | safetensors_file: Path, 89 | ) -> None: 90 | 
safetensors_file = serialize( 91 | params=params, filename=safetensors_file, **serialize_kwargs 92 | ) 93 | decoded_params = deserialize(path_or_buf=safetensors_file, **deserialize_kwargs) 94 | assert isinstance(decoded_params, expected_output_type) 95 | assert len(decoded_params) > 0 96 | assert id(decoded_params) != id(params) 97 | assert decoded_params.keys() == params.keys() 98 | 99 | assert_over_trees(params=params, decoded_params=decoded_params) 100 | 101 | 102 | @pytest.mark.usefixtures("flax_resnet50_params", "safetensors_file", "metadata") 103 | def test_deserialize_from_file_with_metadata( 104 | flax_resnet50_params: ParamsDictLike, 105 | safetensors_file: Path, 106 | metadata: Dict[str, str], 107 | ) -> None: 108 | safetensors_file = serialize( 109 | params=flax_resnet50_params, filename=safetensors_file, metadata=metadata 110 | ) 111 | decoded_params, metadata = deserialize(path_or_buf=safetensors_file) 112 | assert isinstance(decoded_params, dict) 113 | assert len(decoded_params) > 0 114 | assert id(decoded_params) != id(flax_resnet50_params) 115 | assert decoded_params.keys() == flax_resnet50_params.keys() 116 | assert metadata is not None 117 | assert isinstance(metadata, dict) 118 | assert len(metadata) > 0 119 | 120 | assert_over_trees(params=flax_resnet50_params, decoded_params=decoded_params) 121 | 122 | 123 | @pytest.mark.parametrize( 124 | "params, serialize_kwargs, deserialize_kwargs, expected_output_type", 125 | [ 126 | ( 127 | pytest.lazy_fixture("flax_resnet50_params"), 128 | {}, 129 | {"freeze_dict": True}, 130 | FrozenDict, 131 | ), 132 | (pytest.lazy_fixture("flax_resnet50_params"), {}, {"freeze_dict": False}, dict), 133 | ( 134 | pytest.lazy_fixture("objax_resnet50_params"), 135 | {"include_objax_variables": True}, 136 | {"requires_unflattening": False, "to_var_collection": True}, 137 | VarCollection, 138 | ), 139 | ( 140 | pytest.lazy_fixture("objax_resnet50_params"), 141 | {"include_objax_variables": False}, 142 | 
{"requires_unflattening": False, "to_var_collection": False}, 143 | dict, 144 | ), 145 | (pytest.lazy_fixture("haiku_resnet50_params"), {}, {}, dict), 146 | ], 147 | ) 148 | @pytest.mark.usefixtures("safetensors_file", "fs") 149 | def test_deserialize_from_file_in_fs( 150 | params: ParamsDictLike, 151 | serialize_kwargs: Dict[str, Any], 152 | deserialize_kwargs: Dict[str, Any], 153 | expected_output_type: Union[dict, FrozenDict, VarCollection], 154 | safetensors_file: Path, 155 | fs: AbstractFileSystem, 156 | ) -> None: 157 | safetensors_file = serialize( 158 | params=params, filename=safetensors_file, fs=fs, **serialize_kwargs 159 | ) 160 | decoded_params = deserialize( 161 | path_or_buf=safetensors_file, fs=fs, **deserialize_kwargs 162 | ) 163 | assert isinstance(decoded_params, expected_output_type) 164 | assert len(decoded_params) > 0 165 | assert id(decoded_params) != id(params) 166 | assert decoded_params.keys() == params.keys() 167 | 168 | assert_over_trees(params=params, decoded_params=decoded_params) 169 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # 💻 Examples 2 | 3 | Here you will find some detailed examples of how to use `safejax` to serialize 4 | model parameters, in opposition to the default way to store those, which uses 5 | `pickle` as the format to store the tensors instead of `safetensors`. 6 | 7 | ## Flax - [`flax_ft_safejax`](./examples/flax_ft_safejax.py) 8 | 9 | To run this Python script you won't need to install anything else than 10 | `safejax`, as both `jax` and `flax` are installed as part of it. 11 | 12 | In this case, a single-layer model will be created, for now, `flax` 13 | doesn't have any pre-defined architecture such as ResNet, but you can use 14 | [`flaxmodels`](https://github.com/matthias-wright/flaxmodels) for that, as 15 | it defines some well-known architectures written in `flax`. 
16 | 17 | ```python 18 | import jax 19 | from flax import linen as nn 20 | 21 | class SingleLayerModel(nn.Module): 22 | features: int 23 | 24 | @nn.compact 25 | def __call__(self, x): 26 | x = nn.Dense(features=self.features)(x) 27 | return x 28 | ``` 29 | 30 | Once the network has been defined, we can instantiate and initialize it, 31 | to retrieve the `params` out of the forward pass performed during 32 | `.init`. 33 | 34 | ```python 35 | import jax 36 | from jax import numpy as jnp 37 | 38 | network = SingleLayerModel(features=1) 39 | 40 | rng_key = jax.random.PRNGKey(seed=0) 41 | initial_params = network.init(rng_key, jnp.ones((1, 1))) 42 | ``` 43 | 44 | Right after getting the `params` from the `.init` method's output, we can 45 | use `safejax.serialize` to encode them using `safetensors`, so that they can later 46 | be loaded back using `safejax.deserialize`. 47 | 48 | ```python 49 | from safejax import deserialize, serialize 50 | 51 | encoded_bytes = serialize(params=initial_params) 52 | decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True) 53 | ``` 54 | 55 | As seen in the code above, we're passing `freeze_dict=True` (its default 56 | value is `False`) because we want the `dict` with the params to be frozen before 57 | `safejax.deserialize` returns it, i.e. the `Dict` is transformed 58 | into a `FrozenDict`. 59 | 60 | Finally, we can use those `decoded_params` to run a forward pass 61 | with the previously defined single-layer network. 62 | 63 | ```python 64 | x = jnp.ones((1, 1)) 65 | y = network.apply(decoded_params, x) 66 | ``` 67 | 68 | 69 | ## Haiku - [`haiku_ft_safejax.py`](./haiku_ft_safejax.py) 70 | 71 | To run this Python script you'll need to have both `safejax` and [`dm-haiku`](https://github.com/deepmind/dm-haiku) 72 | installed.
73 | 74 | A ResNet50 architecture will be used from `haiku.nets.imagenet.resnet` and since 75 | the purpose of the example is to show the integration of both `dm-haiku` and 76 | `safejax`, we won't use pre-trained weights. 77 | 78 | If you're not familiar with `dm-haiku`, please visit [Haiku Basics](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html). 79 | 80 | First of all, let's create the network instance for the ResNet50 using `dm-haiku` 81 | with the following code: 82 | 83 | ```python 84 | import haiku as hk 85 | from jax import numpy as jnp 86 | 87 | def resnet_fn(x: jnp.DeviceArray, is_training: bool): 88 | resnet = hk.nets.ResNet50(num_classes=10) 89 | return resnet(x, is_training=is_training) 90 | 91 | network = hk.without_apply_rng(hk.transform_with_state(resnet_fn)) 92 | ``` 93 | 94 | Some notes on the code above: 95 | * `haiku.nets.ResNet50` requires `num_classes` as a mandatory parameter 96 | * `haiku.nets.ResNet50.__call__` requires `is_training` as a mandatory parameter 97 | * It needs to be initialized with `hk.transform_with_state` as we want to preserve 98 | the state e.g. ExponentialMovingAverage in BatchNorm. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#transform-with-state. 99 | * Using `hk.without_apply_rng` removes the `rng` arg in the `.apply` function. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng. 100 | 101 | Then we just initialize the network to retrieve both the `params` and the `state`, 102 | which again, are random. 
103 | 104 | ```python 105 | import jax 106 | 107 | rng_key = jax.random.PRNGKey(seed=0) 108 | initial_params, initial_state = network.init( 109 | rng_key, jnp.ones([1, 224, 224, 3]), is_training=True 110 | ) 111 | ``` 112 | 113 | Now once we have the `params`, we can import `safejax.serialize` to serialize the 114 | params using `safetensors` as the tensor storage format, so that they can later be loaded 115 | back using `safejax.deserialize` and used for the network's inference. 116 | 117 | ```python 118 | from safejax import deserialize, serialize 119 | 120 | encoded_bytes = serialize(params=initial_params) 121 | decoded_params = deserialize(path_or_buf=encoded_bytes) 122 | ``` 123 | 124 | Finally, let's just use those `decoded_params` to run the inference over the network 125 | using those weights. 126 | 127 | ```python 128 | x = jnp.ones([1, 224, 224, 3]) 129 | y, _ = network.apply(decoded_params, initial_state, x, is_training=False) 130 | ``` 131 | 132 | ## Objax - [`objax_ft_safejax.py`](./objax_ft_safejax.py) 133 | 134 | To run this Python script you won't need to install anything other than 135 | `safejax`, as both `jax` and `objax` are installed as part of it. 136 | 137 | In this case, we'll be using one of the architectures defined in the model zoo 138 | of `objax` at [`objax/zoo`](https://github.com/google/objax/tree/master/objax/zoo), 139 | which is ResNet50. So first of all, let's initialize it: 140 | 141 | ```python 142 | from objax.zoo.resnet_v2 import ResNet50 143 | 144 | model = ResNet50(in_channels=3, num_classes=1000) 145 | ``` 146 | 147 | Once initialized, we can already access the model params, which in `objax` are stored 148 | in `model.vars()` and are of type `VarCollection`, a dictionary-like class. So 149 | we can directly serialize those using `safejax.serialize` and the `safetensors` format 150 | instead of `pickle`, which is the currently recommended way in `objax`, see https://objax.readthedocs.io/en/latest/advanced/io.html.
151 | 152 | ```python 153 | from safejax import serialize 154 | 155 | encoded_bytes = serialize(params=model.vars(), include_objax_variables=True) 156 | ``` 157 | 158 | Then we can just deserialize those params back using `safejax.deserialize`, and 159 | we'll end up getting the same `VarCollection` dictionary back. Note that we need 160 | to disable the unflattening with `requires_unflattening=False` as it's not required 161 | due to the way it's stored, and set `to_var_collection=True` to get a `VarCollection` 162 | instead of a `Dict[str, jnp.DeviceArray]` (this requires serializing with `include_objax_variables=True`, as above), even though a standard 163 | dict works too. 164 | 165 | ```python 166 | from safejax import deserialize 167 | 168 | decoded_params = deserialize( 169 | encoded_bytes, requires_unflattening=False, to_var_collection=True 170 | ) 171 | ``` 172 | 173 | Now, once decoded with `safejax.deserialize`, we need to assign those key-value 174 | pairs back to the `VarCollection` of the ResNet50 via assignment, as `.update` in 175 | `objax` has been redefined, see https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311, 176 | and it's not consistent with the standard `dict.update` (already reported at 177 | https://github.com/google/objax/issues/254). So, instead, we need to loop over 178 | all the key-value pairs in the decoded params and assign those one by one to the 179 | `VarCollection` in `model.vars()`. 180 | 181 | ```python 182 | for key, value in decoded_params.items(): 183 | if key not in model.vars(): 184 | print(f"Key {key} not in model.vars()! Skipping.") 185 | continue 186 | model.vars()[key].assign(value) 187 | ``` 188 | 189 | And, finally, we can run the inference over the model via the `__call__` method 190 | as the `.vars()` are already copied from the params resulting from `safejax.deserialize`.
191 | 192 | ```python 193 | from jax import numpy as jnp 194 | 195 | x = jnp.ones((1, 3, 224, 224)) 196 | y = model(x, training=False) 197 | ``` 198 | 199 | Note that we're setting the `training` flag to `False`, which is the standard way 200 | of running the inference over a pre-trained model in `objax`. 201 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # 🤖 Examples 2 | 3 | Here you will find some detailed examples of how to use `safejax` to serialize 4 | model parameters, as opposed to the default way of storing them, which uses 5 | `pickle` as the format to store the tensors instead of `safetensors`. 6 | 7 | ## **Flax** 8 | 9 | Available at [`flax_ft_safejax.py`](https://github.com/alvarobartt/safejax/blob/main/examples/flax_ft_safejax.py). 10 | 11 | To run this Python script you won't need to install anything other than 12 | `safejax`, as both `jax` and `flax` are installed as part of it. 13 | 14 | In this case, a single-layer model will be created since, for now, `flax` 15 | doesn't have any pre-defined architecture such as ResNet, but you can use 16 | [`flaxmodels`](https://github.com/matthias-wright/flaxmodels) for that, as 17 | it defines some well-known architectures written in `flax`. 18 | 19 | ```python 20 | import jax 21 | from flax import linen as nn 22 | 23 | class SingleLayerModel(nn.Module): 24 | features: int 25 | 26 | @nn.compact 27 | def __call__(self, x): 28 | x = nn.Dense(features=self.features)(x) 29 | return x 30 | ``` 31 | 32 | Once the network has been defined, we can instantiate and initialize it, 33 | to retrieve the `params` out of the forward pass performed during 34 | `.init`.
35 | 36 | ```python 37 | import jax 38 | from jax import numpy as jnp 39 | 40 | network = SingleLayerModel(features=1) 41 | 42 | rng_key = jax.random.PRNGKey(seed=0) 43 | initial_params = network.init(rng_key, jnp.ones((1, 1))) 44 | ``` 45 | 46 | Right after getting the `params` from the `.init` method's output, we can 47 | use `safejax.serialize` to encode them using `safetensors`, so that they can later 48 | be loaded back using `safejax.deserialize`. 49 | 50 | ```python 51 | from safejax import deserialize, serialize 52 | 53 | encoded_bytes = serialize(params=initial_params) 54 | decoded_params = deserialize(path_or_buf=encoded_bytes, freeze_dict=True) 55 | ``` 56 | 57 | As seen in the code above, we're passing `freeze_dict=True` (its default 58 | value is `False`) because we want the `dict` with the params to be frozen before 59 | `safejax.deserialize` returns it, i.e. the `Dict` is transformed 60 | into a `FrozenDict`. 61 | 62 | Finally, we can use those `decoded_params` to run a forward pass 63 | with the previously defined single-layer network. 64 | 65 | ```python 66 | x = jnp.ones((1, 1)) 67 | y = network.apply(decoded_params, x) 68 | ``` 69 | 70 | 71 | ## **Haiku** 72 | 73 | Available at [`haiku_ft_safejax.py`](https://github.com/alvarobartt/safejax/blob/main/examples/haiku_ft_safejax.py). 74 | 75 | To run this Python script you'll need to have both `safejax` and [`dm-haiku`](https://github.com/deepmind/dm-haiku) 76 | installed. 77 | 78 | A ResNet50 architecture will be used from `haiku.nets.imagenet.resnet` and since 79 | the purpose of the example is to show the integration of both `dm-haiku` and 80 | `safejax`, we won't use pre-trained weights. 81 | 82 | If you're not familiar with `dm-haiku`, please visit [Haiku Basics](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html).
83 | 84 | First of all, let's create the network instance for the ResNet50 using `dm-haiku` 85 | with the following code: 86 | 87 | ```python 88 | import haiku as hk 89 | from jax import numpy as jnp 90 | 91 | def resnet_fn(x: jnp.DeviceArray, is_training: bool): 92 | resnet = hk.nets.ResNet50(num_classes=10) 93 | return resnet(x, is_training=is_training) 94 | 95 | network = hk.without_apply_rng(hk.transform_with_state(resnet_fn)) 96 | ``` 97 | 98 | Some notes on the code above: 99 | * `haiku.nets.ResNet50` requires `num_classes` as a mandatory parameter 100 | * `haiku.nets.ResNet50.__call__` requires `is_training` as a mandatory parameter 101 | * It needs to be initialized with `hk.transform_with_state` as we want to preserve 102 | the state e.g. ExponentialMovingAverage in BatchNorm. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#transform-with-state. 103 | * Using `hk.without_apply_rng` removes the `rng` arg in the `.apply` function. More information at https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng. 104 | 105 | Then we just initialize the network to retrieve both the `params` and the `state`, 106 | which again, are random. 107 | 108 | ```python 109 | import jax 110 | 111 | rng_key = jax.random.PRNGKey(seed=0) 112 | initial_params, initial_state = network.init( 113 | rng_key, jnp.ones([1, 224, 224, 3]), is_training=True 114 | ) 115 | ``` 116 | 117 | Now once we have the `params`, we can import `safejax.serialize` to serialize the 118 | params using `safetensors` as the tensor storage format, that later on can be loaded 119 | back using `safejax.deserialize` and used for the network's inference. 120 | 121 | ```python 122 | from safejax import deserialize, serialize 123 | 124 | encoded_bytes = serialize(params=initial_params) 125 | decoded_params = deserialize(path_or_buf=encoded_bytes) 126 | ``` 127 | 128 | Finally, let's just use those `decoded_params` to run the inference over the network 129 | using those weights. 
130 | 131 | ```python 132 | x = jnp.ones([1, 224, 224, 3]) 133 | y, _ = network.apply(decoded_params, initial_state, x, is_training=False) 134 | ``` 135 | 136 | ## **Objax** 137 | 138 | Available at [`objax_ft_safejax.py`](https://github.com/alvarobartt/safejax/blob/main/examples/objax_ft_safejax.py). 139 | 140 | To run this Python script you won't need to install anything other than 141 | `safejax`, as both `jax` and `objax` are installed as part of it. 142 | 143 | In this case, we'll be using one of the architectures defined in the model zoo 144 | of `objax` at [`objax/zoo`](https://github.com/google/objax/tree/master/objax/zoo), 145 | which is ResNet50. So first of all, let's initialize it: 146 | 147 | ```python 148 | from objax.zoo.resnet_v2 import ResNet50 149 | 150 | model = ResNet50(in_channels=3, num_classes=1000) 151 | ``` 152 | 153 | Once initialized, we can already access the model params, which in `objax` are stored 154 | in `model.vars()` and are of type `VarCollection`, a dictionary-like class. So 155 | we can directly serialize those using `safejax.serialize` and the `safetensors` format 156 | instead of `pickle`, which is the currently recommended way in `objax`, see https://objax.readthedocs.io/en/latest/advanced/io.html. 157 | 158 | ```python 159 | from safejax import serialize 160 | 161 | encoded_bytes = serialize(params=model.vars(), include_objax_variables=True) 162 | ``` 163 | 164 | Then we can just deserialize those params back using `safejax.deserialize`, and 165 | we'll end up getting the same `VarCollection` dictionary back. Note that we need 166 | to disable the unflattening with `requires_unflattening=False` as it's not required 167 | due to the way it's stored, and set `to_var_collection=True` to get a `VarCollection` 168 | instead of a `Dict[str, jnp.DeviceArray]` (this requires serializing with `include_objax_variables=True`, as above), even though a standard 169 | dict works too.
170 | 171 | ```python 172 | from safejax import deserialize 173 | 174 | decoded_params = deserialize( 175 | encoded_bytes, requires_unflattening=False, to_var_collection=True 176 | ) 177 | ``` 178 | 179 | Now, once decoded with `safejax.deserialize`, we need to assign those key-value 180 | pairs back to the `VarCollection` of the ResNet50 via assignment, as `.update` in 181 | `objax` has been redefined, see https://github.com/google/objax/blob/53b391bfa72dc59009c855d01b625049a35f5f1b/objax/variable.py#L311, 182 | and it's not consistent with the standard `dict.update` (already reported at 183 | https://github.com/google/objax/issues/254). So, instead, we need to loop over 184 | all the key-value pairs in the decoded params and assign those one by one to the 185 | `VarCollection` in `model.vars()`. 186 | 187 | ```python 188 | for key, value in decoded_params.items(): 189 | if key not in model.vars(): 190 | print(f"Key {key} not in model.vars()! Skipping.") 191 | continue 192 | model.vars()[key].assign(value) 193 | ``` 194 | 195 | And, finally, we can run the inference over the model via the `__call__` method 196 | as the `.vars()` are already copied from the params resulting from `safejax.deserialize`. 197 | 198 | ```python 199 | from jax import numpy as jnp 200 | 201 | x = jnp.ones((1, 3, 224, 224)) 202 | y = model(x, training=False) 203 | ``` 204 | 205 | Note that we're setting the `training` flag to `False`, which is the standard way 206 | of running the inference over a pre-trained model in `objax`. 207 | --------------------------------------------------------------------------------