├── .github └── workflows │ ├── ci.yml │ └── pypi-publish.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── advanced.rst ├── api.rst ├── conf.py ├── ext │ └── coverage_check.py ├── guides.rst ├── index.rst └── overview.rst ├── examples ├── README.md ├── autoencoder_mnist │ ├── experiment.py │ └── pipeline.py ├── classifier_mnist │ ├── experiment.py │ └── pipeline.py ├── datasets.py ├── losses.py ├── lrelunet101_imagenet │ ├── experiment.py │ └── pipeline.py ├── optax_wrapper.py ├── optimizers.py ├── resnet50_imagenet │ ├── experiment.py │ └── pipeline.py ├── schedules.py └── training.py ├── kfac_jax ├── __init__.py ├── _src │ ├── __init__.py │ ├── curvature_blocks │ │ ├── __init__.py │ │ ├── curvature_block.py │ │ ├── diagonal.py │ │ ├── full.py │ │ ├── kronecker_factored.py │ │ ├── tnt.py │ │ └── utils.py │ ├── curvature_estimator │ │ ├── __init__.py │ │ ├── block_diagonal.py │ │ ├── curvature_estimator.py │ │ ├── explicit_exact.py │ │ ├── implicit_exact.py │ │ └── optax_interface.py │ ├── layers_and_loss_tags.py │ ├── loss_functions.py │ ├── optimizer.py │ ├── patches_second_moment.py │ ├── tag_graph_matcher.py │ ├── tracer.py │ └── utils │ │ ├── __init__.py │ │ ├── accumulators.py │ │ ├── math.py │ │ ├── misc.py │ │ ├── parallel.py │ │ ├── staging.py │ │ └── types.py └── py.typed ├── pylintrc.toml ├── pyproject.toml ├── readthedocs.yml ├── setup.py ├── test.sh └── tests ├── estimator_test_utils.py ├── models.py ├── test_estimator.py ├── test_graph_matcher.py ├── test_optax_interface.py ├── test_patches_second_moment.py ├── test_tracer.py └── test_utils.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/.github/workflows/ci.yml -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/.github/workflows/pypi-publish.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/.gitignore -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/CONTRIBUTING.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/LICENSE -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include kfac_jax/py.typed 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/README.md -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/Makefile -------------------------------------------------------------------------------- /docs/advanced.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/advanced.rst -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/api.rst -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/conf.py -------------------------------------------------------------------------------- /docs/ext/coverage_check.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/ext/coverage_check.py -------------------------------------------------------------------------------- /docs/guides.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/guides.rst -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/index.rst -------------------------------------------------------------------------------- /docs/overview.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/docs/overview.rst -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/README.md -------------------------------------------------------------------------------- /examples/autoencoder_mnist/experiment.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/autoencoder_mnist/experiment.py -------------------------------------------------------------------------------- /examples/autoencoder_mnist/pipeline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/autoencoder_mnist/pipeline.py -------------------------------------------------------------------------------- /examples/classifier_mnist/experiment.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/classifier_mnist/experiment.py -------------------------------------------------------------------------------- /examples/classifier_mnist/pipeline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/classifier_mnist/pipeline.py -------------------------------------------------------------------------------- /examples/datasets.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/datasets.py -------------------------------------------------------------------------------- /examples/losses.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/losses.py -------------------------------------------------------------------------------- /examples/lrelunet101_imagenet/experiment.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/lrelunet101_imagenet/experiment.py -------------------------------------------------------------------------------- /examples/lrelunet101_imagenet/pipeline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/lrelunet101_imagenet/pipeline.py -------------------------------------------------------------------------------- /examples/optax_wrapper.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/optax_wrapper.py -------------------------------------------------------------------------------- /examples/optimizers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/optimizers.py -------------------------------------------------------------------------------- /examples/resnet50_imagenet/experiment.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/resnet50_imagenet/experiment.py -------------------------------------------------------------------------------- /examples/resnet50_imagenet/pipeline.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/resnet50_imagenet/pipeline.py -------------------------------------------------------------------------------- /examples/schedules.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/schedules.py -------------------------------------------------------------------------------- /examples/training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/examples/training.py -------------------------------------------------------------------------------- /kfac_jax/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/__init__.py -------------------------------------------------------------------------------- /kfac_jax/_src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/__init__.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/__init__.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/curvature_block.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/curvature_block.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/diagonal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/diagonal.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/full.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/full.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/kronecker_factored.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/kronecker_factored.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/tnt.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/tnt.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_blocks/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_blocks/utils.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_estimator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_estimator/__init__.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_estimator/block_diagonal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_estimator/block_diagonal.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_estimator/curvature_estimator.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_estimator/curvature_estimator.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_estimator/explicit_exact.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_estimator/explicit_exact.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_estimator/implicit_exact.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_estimator/implicit_exact.py -------------------------------------------------------------------------------- /kfac_jax/_src/curvature_estimator/optax_interface.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/curvature_estimator/optax_interface.py -------------------------------------------------------------------------------- /kfac_jax/_src/layers_and_loss_tags.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/layers_and_loss_tags.py -------------------------------------------------------------------------------- /kfac_jax/_src/loss_functions.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/loss_functions.py -------------------------------------------------------------------------------- /kfac_jax/_src/optimizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/optimizer.py -------------------------------------------------------------------------------- /kfac_jax/_src/patches_second_moment.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/patches_second_moment.py -------------------------------------------------------------------------------- /kfac_jax/_src/tag_graph_matcher.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/tag_graph_matcher.py -------------------------------------------------------------------------------- /kfac_jax/_src/tracer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/tracer.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/__init__.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/accumulators.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/accumulators.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/math.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/math.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/misc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/misc.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/parallel.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/staging.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/staging.py -------------------------------------------------------------------------------- /kfac_jax/_src/utils/types.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/kfac_jax/_src/utils/types.py -------------------------------------------------------------------------------- /kfac_jax/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pylintrc.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/pylintrc.toml -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/pyproject.toml -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/readthedocs.yml -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/setup.py -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/test.sh -------------------------------------------------------------------------------- /tests/estimator_test_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/estimator_test_utils.py -------------------------------------------------------------------------------- /tests/models.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/models.py -------------------------------------------------------------------------------- /tests/test_estimator.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/test_estimator.py -------------------------------------------------------------------------------- /tests/test_graph_matcher.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/test_graph_matcher.py -------------------------------------------------------------------------------- /tests/test_optax_interface.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/test_optax_interface.py -------------------------------------------------------------------------------- /tests/test_patches_second_moment.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/test_patches_second_moment.py -------------------------------------------------------------------------------- /tests/test_tracer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/test_tracer.py -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/kfac-jax/HEAD/tests/test_utils.py --------------------------------------------------------------------------------