├── .github ├── FUNDING.yml └── workflows │ ├── create-release.yml │ ├── publish-package.yml │ └── run-tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── docs ├── blog.md ├── images │ └── stateful-transforms.png └── tiny_nnx.ipynb ├── examples ├── 00_demo.ipynb ├── 01_functional_api.py ├── 02_lifted_transforms.py ├── 03_train_state.py ├── 04_pure.py ├── 05_vae.py ├── 06_scan_over_layers.py ├── 07_transformer.py ├── 08_save_load_checkpoints.py └── 09_parameter_surgery.py ├── ideas ├── nnx_example.py ├── pure │ ├── __init__.py │ ├── full │ │ ├── partitioning_full.py │ │ └── state_full.py │ ├── module.py │ ├── partitioning.py │ ├── rngs.py │ └── state.py ├── pure_example.py ├── pure_nnx_example.py ├── pure_pytree │ ├── __init__.py │ ├── dataclass.py │ ├── full │ │ ├── partitioning_full.py │ │ └── state_full.py │ ├── module.py │ ├── partitioning.py │ └── rngs.py ├── pure_pytree_example.py └── shape_inference.py ├── nnx ├── __init__.py ├── containers.py ├── contextlib.py ├── dataclasses.py ├── errors.py ├── helpers.py ├── ids.py ├── module.py ├── nn │ ├── __init__.py │ ├── activations.py │ ├── dtypes.py │ ├── initializers.py │ ├── linear.py │ ├── normalization.py │ └── stochastic.py ├── nodes.py ├── partitioning.py ├── pytreelib.py ├── reprlib.py ├── spmd.py ├── state.py ├── tracers.py └── transforms.py ├── poetry.lock ├── pyproject.toml ├── scripts ├── deploy-docs.sh ├── run-all-examples.bash └── update_version.py └── tests ├── __init__.py ├── test_containers.py ├── test_context.py ├── test_helpers.py ├── test_ids.py ├── test_integration.py ├── test_module.py ├── test_partitioning.py ├── test_pytree.py ├── test_spmd.py ├── test_transforms.py └── test_variable.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.github/FUNDING.yml -------------------------------------------------------------------------------- /.github/workflows/create-release.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.github/workflows/create-release.yml -------------------------------------------------------------------------------- /.github/workflows/publish-package.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.github/workflows/publish-package.yml -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.github/workflows/run-tests.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.gitignore -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.pre-commit-config.yaml -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/.vscode/settings.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/README.md -------------------------------------------------------------------------------- /docs/blog.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/docs/blog.md -------------------------------------------------------------------------------- /docs/images/stateful-transforms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/docs/images/stateful-transforms.png -------------------------------------------------------------------------------- /docs/tiny_nnx.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/docs/tiny_nnx.ipynb -------------------------------------------------------------------------------- /examples/00_demo.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/00_demo.ipynb -------------------------------------------------------------------------------- /examples/01_functional_api.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/01_functional_api.py -------------------------------------------------------------------------------- /examples/02_lifted_transforms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/02_lifted_transforms.py -------------------------------------------------------------------------------- /examples/03_train_state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/03_train_state.py -------------------------------------------------------------------------------- /examples/04_pure.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/04_pure.py -------------------------------------------------------------------------------- /examples/05_vae.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/05_vae.py -------------------------------------------------------------------------------- /examples/06_scan_over_layers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/06_scan_over_layers.py -------------------------------------------------------------------------------- /examples/07_transformer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/07_transformer.py -------------------------------------------------------------------------------- /examples/08_save_load_checkpoints.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/08_save_load_checkpoints.py -------------------------------------------------------------------------------- /examples/09_parameter_surgery.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/examples/09_parameter_surgery.py -------------------------------------------------------------------------------- /ideas/nnx_example.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/nnx_example.py -------------------------------------------------------------------------------- /ideas/pure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/__init__.py -------------------------------------------------------------------------------- /ideas/pure/full/partitioning_full.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/full/partitioning_full.py -------------------------------------------------------------------------------- /ideas/pure/full/state_full.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/full/state_full.py -------------------------------------------------------------------------------- /ideas/pure/module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/module.py -------------------------------------------------------------------------------- /ideas/pure/partitioning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/partitioning.py -------------------------------------------------------------------------------- /ideas/pure/rngs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/rngs.py -------------------------------------------------------------------------------- /ideas/pure/state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure/state.py -------------------------------------------------------------------------------- /ideas/pure_example.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_example.py -------------------------------------------------------------------------------- /ideas/pure_nnx_example.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_nnx_example.py -------------------------------------------------------------------------------- /ideas/pure_pytree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/__init__.py -------------------------------------------------------------------------------- /ideas/pure_pytree/dataclass.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/dataclass.py -------------------------------------------------------------------------------- /ideas/pure_pytree/full/partitioning_full.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/full/partitioning_full.py -------------------------------------------------------------------------------- /ideas/pure_pytree/full/state_full.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/full/state_full.py -------------------------------------------------------------------------------- /ideas/pure_pytree/module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/module.py -------------------------------------------------------------------------------- /ideas/pure_pytree/partitioning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/partitioning.py -------------------------------------------------------------------------------- /ideas/pure_pytree/rngs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree/rngs.py -------------------------------------------------------------------------------- /ideas/pure_pytree_example.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/pure_pytree_example.py -------------------------------------------------------------------------------- /ideas/shape_inference.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/ideas/shape_inference.py -------------------------------------------------------------------------------- /nnx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/__init__.py -------------------------------------------------------------------------------- /nnx/containers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/containers.py -------------------------------------------------------------------------------- /nnx/contextlib.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/contextlib.py -------------------------------------------------------------------------------- /nnx/dataclasses.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/dataclasses.py -------------------------------------------------------------------------------- /nnx/errors.py: -------------------------------------------------------------------------------- 1 | class TraceContextError(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /nnx/helpers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/helpers.py -------------------------------------------------------------------------------- /nnx/ids.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/ids.py -------------------------------------------------------------------------------- /nnx/module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/module.py -------------------------------------------------------------------------------- /nnx/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nnx/nn/activations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nn/activations.py -------------------------------------------------------------------------------- /nnx/nn/dtypes.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nn/dtypes.py -------------------------------------------------------------------------------- /nnx/nn/initializers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nn/initializers.py -------------------------------------------------------------------------------- /nnx/nn/linear.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nn/linear.py -------------------------------------------------------------------------------- /nnx/nn/normalization.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nn/normalization.py -------------------------------------------------------------------------------- /nnx/nn/stochastic.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nn/stochastic.py -------------------------------------------------------------------------------- /nnx/nodes.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/nodes.py -------------------------------------------------------------------------------- /nnx/partitioning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/partitioning.py -------------------------------------------------------------------------------- /nnx/pytreelib.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/pytreelib.py -------------------------------------------------------------------------------- /nnx/reprlib.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/reprlib.py -------------------------------------------------------------------------------- /nnx/spmd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/spmd.py -------------------------------------------------------------------------------- /nnx/state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/state.py -------------------------------------------------------------------------------- /nnx/tracers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/tracers.py -------------------------------------------------------------------------------- /nnx/transforms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/nnx/transforms.py -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/poetry.lock -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/pyproject.toml -------------------------------------------------------------------------------- /scripts/deploy-docs.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/scripts/deploy-docs.sh -------------------------------------------------------------------------------- /scripts/run-all-examples.bash: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/scripts/run-all-examples.bash -------------------------------------------------------------------------------- /scripts/update_version.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/scripts/update_version.py -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_containers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_containers.py -------------------------------------------------------------------------------- /tests/test_context.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_context.py -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_helpers.py -------------------------------------------------------------------------------- /tests/test_ids.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_ids.py -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_integration.py -------------------------------------------------------------------------------- /tests/test_module.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_module.py -------------------------------------------------------------------------------- /tests/test_partitioning.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_partitioning.py -------------------------------------------------------------------------------- /tests/test_pytree.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_pytree.py -------------------------------------------------------------------------------- /tests/test_spmd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_spmd.py -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_transforms.py -------------------------------------------------------------------------------- /tests/test_variable.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cgarciae/nnx/HEAD/tests/test_variable.py --------------------------------------------------------------------------------