├── docs ├── rtd │ ├── .gitignore │ ├── assets │ │ ├── logo.pdf │ │ ├── logo.png │ │ ├── torch.pdf │ │ ├── hessian.pdf │ │ ├── .gitignore │ │ ├── convert.sh │ │ └── logo.tex │ ├── usage.rst │ ├── internals.rst │ ├── index.rst │ ├── Makefile │ ├── make.bat │ ├── linops.rst │ └── conf.py └── examples │ └── basic_usage │ ├── .gitignore │ ├── README.rst │ ├── benchmark │ ├── peakmem_KFAC_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_Hessian_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_Hessian_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── time_KFAC_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_KFAC_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── time_KFAC_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── peakmem_Hessian_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── peakmem_KFAC_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── peakmem_KFAC_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── time_Hessian_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── time_KFAC-inverse_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_KFAC_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── peakmem_Hessian_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── peakmem_Hessian_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── peakmem_Hessian_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── peakmem_KFAC-inverse_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── peakmem_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── peakmem_KFAC_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── time_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── time_Hessian_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── time_Hessian_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── time_KFAC-inverse_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── time_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── time_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── time_KFAC_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── time_KFAC_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── time_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── peakmem_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── peakmem_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── peakmem_Hessian_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_KFAC-inverse_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── peakmem_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── peakmem_KFAC_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_KFAC_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── peakmem_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── time_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── time_Hessian_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── time_KFAC-inverse_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── peakmem_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── 
peakmem_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_matvec.json │ ├── peakmem_Hessian_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── peakmem_KFAC-inverse_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── peakmem_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── time_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── time_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── time_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── time_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── time_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── time_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── time_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── peakmem_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── peakmem_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_matvec.json │ ├── peakmem_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── peakmem_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── time_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── time_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── time_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── time_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── peakmem_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_precompute.json │ ├── peakmem_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_precompute.json │ ├── peakmem_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── time_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── time_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── peakmem_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json │ ├── peakmem_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json │ ├── time_synthetic_imagenet_resnet50_cuda.pdf │ ├── time_synthetic_shakespeare_nanogpt_cuda.pdf │ ├── peakmem_synthetic_imagenet_resnet50_cuda.pdf │ └── peakmem_synthetic_shakespeare_nanogpt_cuda.pdf │ ├── toy_spectrum.pdf │ ├── toy_log_spectrum.pdf │ ├── trace_estimation.pdf │ ├── diagonal_estimation.pdf │ ├── curvature_matrices_log_abs.pdf │ ├── benchmark_utils.py │ ├── example_huggingface.py │ ├── memory_benchmark.py │ ├── example_visual_tour.py │ ├── example_matrix_vector_products.py │ ├── example_model_merging.py │ ├── example_submatrices.py │ └── example_eigenvalues.py ├── .envrc ├── test ├── norm │ ├── __init__.py │ └── test_hutchinson.py ├── __init__.py ├── experimental │ ├── __init__.py │ └── test_activation_hessian.py ├── papyan2020traces │ ├── __init__.py │ └── test_spectrum.py ├── trace │ ├── __init__.py │ ├── test_hutchinson.py │ ├── test_meyer2020hutch.py │ └── test_epperly2024xtrace.py ├── diagonal │ ├── __init__.py │ ├── test_hutchinson.py │ └── test_epperly2024xtrace.py ├── test_utils.py ├── test_ggn.py ├── 
test_gradient_moments.py ├── test_fisher.py ├── test_jacobian.py ├── test_hessian.py ├── test_submatrix.py └── conftest.py ├── curvlinops ├── trace │ ├── __init__.py │ ├── hutchinson.py │ ├── epperly2024xtrace.py │ └── meyer2020hutch.py ├── norm │ ├── __init__.py │ └── hutchinson.py ├── diagonal │ ├── __init__.py │ ├── hutchinson.py │ └── epperly2024xtrace.py ├── experimental │ └── __init__.py ├── papyan2020traces │ └── __init__.py ├── sampling.py ├── __init__.py ├── submatrix.py ├── examples │ └── __init__.py ├── ggn.py ├── gradient_moments.py └── hessian.py ├── .gitignore ├── .conda_env.yml ├── setup.cfg ├── pytest.ini ├── .readthedocs.yaml ├── .github └── workflows │ ├── lint-ruff.yaml │ ├── lint-darglint.yaml │ ├── lint-format.yaml │ ├── lint-pydocstyle.yaml │ ├── test.yaml │ └── python-publish.yml ├── LICENSE ├── makefile ├── README.md └── pyproject.toml /docs/rtd/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | basic_usage -------------------------------------------------------------------------------- /.envrc: -------------------------------------------------------------------------------- 1 | source ~/anaconda3/bin/activate curvlinops 2 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/.gitignore: -------------------------------------------------------------------------------- 1 | nanogpt_model.py -------------------------------------------------------------------------------- /test/norm/__init__.py: -------------------------------------------------------------------------------- 1 | """Test ``curvlinops.norm``.""" 2 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``curvlinops`` library.""" 2 | -------------------------------------------------------------------------------- /curvlinops/trace/__init__.py: -------------------------------------------------------------------------------- 1 | """Trace estimation techniques.""" 2 | -------------------------------------------------------------------------------- /curvlinops/norm/__init__.py: -------------------------------------------------------------------------------- 1 | """Matrix norm estimation methods.""" 2 | -------------------------------------------------------------------------------- /curvlinops/diagonal/__init__.py: -------------------------------------------------------------------------------- 1 | """Matrix diagonal estimation methods.""" 2 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/README.rst: -------------------------------------------------------------------------------- 1 | Code samples 2 | ================== 3 | -------------------------------------------------------------------------------- /test/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains test of ``curvlinops.experimental``.""" 2 | -------------------------------------------------------------------------------- /test/papyan2020traces/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains tests of ``curvlinops/papyan2020``""" 2 | -------------------------------------------------------------------------------- /docs/rtd/assets/logo.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/rtd/assets/logo.pdf -------------------------------------------------------------------------------- /docs/rtd/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/rtd/assets/logo.png -------------------------------------------------------------------------------- /docs/rtd/assets/torch.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/rtd/assets/torch.pdf -------------------------------------------------------------------------------- /docs/rtd/assets/hessian.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/rtd/assets/hessian.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.881683826446533} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Hessian_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 1.1318046376109123} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Hessian_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 2.9113959968090057} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.03497970104217529} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 1.027599535882473} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.12887317687273026} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | src/ 2 | **.egg-info 3 | **.mypy_cache 4 | **__pycache__ 5 | .vscode/* 6 | .coverage 7 | .eggs 8 | **.DS_Store 9 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Hessian_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 12.749358177185059} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC_synthetic_imagenet_resnet50_cuda_precompute.json: 
-------------------------------------------------------------------------------- 1 | {"peakmem": 5.785980701446533} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.78189754486084} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Hessian_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.815659046173096e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC-inverse_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.03249528259038925} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 1.2078865841031075} -------------------------------------------------------------------------------- /docs/rtd/assets/.gitignore: -------------------------------------------------------------------------------- 1 | **.pdf-view-restore 2 | **.el 3 | **.aux 4 | **.fdb_latexmk 5 | **.fls 6 | **.log 7 | **.synctex.gz -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Hessian_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.09543752670288086} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Hessian_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 17.485933303833008} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Hessian_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.6088504791259766} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC-inverse_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.881683826446533} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.78189754486084} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.465182304382324} -------------------------------------------------------------------------------- 
/docs/examples/basic_usage/benchmark/time_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.6661899238824844} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 2.092774197459221} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Hessian_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.26514657586812973} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Hessian_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.644295692443848e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC-inverse_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 1.0400259047746658} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.12013363093137741} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 1.2172305434942245} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.2477797344326973} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.7269761338829994} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.710419662296772} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 2.253086730837822} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.163455963134766} 
-------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 10.986472129821777} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Hessian_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.408483028411865} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC-inverse_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.785980701446533} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.465182304382324} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.4073052406311035} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.280280590057373} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.163455963134766} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.793307304382324e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 0.6687018722295761} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Hessian_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.7823684439063072} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC-inverse_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.24632802605628967} -------------------------------------------------------------------------------- 
/docs/examples/basic_usage/benchmark/peakmem_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.09543752670288086} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.6088504791259766} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.162723541259766} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Hessian_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.572415828704834} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC-inverse_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.4073052406311035} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.09543752670288086} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 10.986472129821777} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.26463960111141205} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.7859276458621025} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.830560207366943e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"time": 2.0252698212862015} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: 
-------------------------------------------------------------------------------- 1 | {"time": 0.7273677289485931} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.763504981994629e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.778406143188477e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/toy_spectrum.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/toy_spectrum.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Empirical-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.408483028411865} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Empirical-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.572415828704834} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_matvec.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 10.978537559509277} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_KFAC-inverse_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.280280590057373} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.408483028411865} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.6088504791259766} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 7.614493370056152e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"time": 
7.711350917816162e-06} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Monte-Carlo-Fisher_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.28098224848508835} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.833616565912962} -------------------------------------------------------------------------------- /docs/examples/basic_usage/toy_log_spectrum.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/toy_log_spectrum.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/trace_estimation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/trace_estimation.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.09543752670288086} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_precompute.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 0.6088504791259766} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Monte-Carlo-Fisher_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.572415828704834} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.2654357925057411} -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"time": 0.7798505648970604} -------------------------------------------------------------------------------- /docs/examples/basic_usage/diagonal_estimation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/diagonal_estimation.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Generalized-Gauss-Newton_synthetic_imagenet_resnet50_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 5.408483028411865} 
-------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_Generalized-Gauss-Newton_synthetic_shakespeare_nanogpt_cuda_gradient_and_loss.json: -------------------------------------------------------------------------------- 1 | {"peakmem": 6.572415828704834} -------------------------------------------------------------------------------- /docs/examples/basic_usage/curvature_matrices_log_abs.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/curvature_matrices_log_abs.pdf -------------------------------------------------------------------------------- /.conda_env.yml: -------------------------------------------------------------------------------- 1 | name: curvlinops 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.9.16 7 | - pip=23.1.2 8 | - pip: 9 | - -e .[lint,test,docs] 10 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_synthetic_imagenet_resnet50_cuda.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/benchmark/time_synthetic_imagenet_resnet50_cuda.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/time_synthetic_shakespeare_nanogpt_cuda.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/benchmark/time_synthetic_shakespeare_nanogpt_cuda.pdf -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Note: These tools do not yet support `pyproject.toml`, but these options 2 | # should be moved there once support is added. 
3 | 4 | [darglint] 5 | docstring_style = google 6 | strictness = full 7 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_synthetic_imagenet_resnet50_cuda.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/benchmark/peakmem_synthetic_imagenet_resnet50_cuda.pdf -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark/peakmem_synthetic_shakespeare_nanogpt_cuda.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/curvlinops/HEAD/docs/examples/basic_usage/benchmark/peakmem_synthetic_shakespeare_nanogpt_cuda.pdf -------------------------------------------------------------------------------- /curvlinops/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains experimental features.""" 2 | 3 | from curvlinops.experimental.activation_hessian import ActivationHessianLinearOperator 4 | 5 | __all__ = ["ActivationHessianLinearOperator"] 6 | -------------------------------------------------------------------------------- /docs/rtd/usage.rst: -------------------------------------------------------------------------------- 1 | How to use Curvlinops 2 | ===================== 3 | 4 | To get started, check out the :ref:`basic example`. 5 | 6 | Advanced use cases are illustrated :ref:`here`. 7 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # NOTE The documentation recommends to **not** configure pytest with setup.cfg 2 | # (https://docs.pytest.org/en/6.2.x/customize.html#setup-cfg) 3 | [pytest] 4 | optional_tests: 5 | montecarlo: slow tests using Monte-Carlo sampling -------------------------------------------------------------------------------- /docs/rtd/assets/convert.sh: -------------------------------------------------------------------------------- 1 | # convert pdf generated by LaTeX into svg 2 | echo "Automatic conversion won't work (scipy/torch logos will have artifacts)." 
3 | echo "Use inkscape and follow the steps in https://graphicdesign.stackexchange.com/a/144670" 4 | -------------------------------------------------------------------------------- /test/trace/__init__.py: -------------------------------------------------------------------------------- 1 | """Test ``curvlinops.trace``.""" 2 | 3 | DISTRIBUTIONS = ["rademacher", "normal"] 4 | DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS] 5 | 6 | NUM_MATVECS = [3, 6] 7 | NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS] 8 | -------------------------------------------------------------------------------- /test/diagonal/__init__.py: -------------------------------------------------------------------------------- 1 | """Test ``curvlinops.diagonal``.""" 2 | 3 | DISTRIBUTIONS = ["rademacher", "normal"] 4 | DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS] 5 | 6 | NUM_MATVECS = [4, 10] 7 | NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS] 8 | -------------------------------------------------------------------------------- /docs/rtd/internals.rst: -------------------------------------------------------------------------------- 1 | Internals 2 | ============ 3 | 4 | This section is for internal purposes only and serves to inform developers about 5 | details; because rendered LaTeX is easier to read than source code. 6 | 7 | 8 | KFAC-related 9 | ------------- 10 | 11 | .. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt 12 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | sphinx: 7 | configuration: docs/rtd/conf.py 8 | 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.9" 13 | 14 | python: 15 | install: 16 | - method: pip 17 | path: . 18 | extra_requirements: 19 | - docs 20 | -------------------------------------------------------------------------------- /curvlinops/papyan2020traces/__init__.py: -------------------------------------------------------------------------------- 1 | """Methods for the analysis of large-scale deep learning matrices via linear operators. 2 | 3 | Implements spectral analysis methods and linear operators from the GGN's 4 | hierarchical decomposition presented in 5 | 6 | - Papyan, V. (2020). Traces of class/cross-class structure pervade deep learning 7 | spectra. Journal of Machine Learning Research (JMLR), 8 | https://jmlr.org/papers/v21/20-933.html 9 | """ 10 | -------------------------------------------------------------------------------- /docs/rtd/index.rst: -------------------------------------------------------------------------------- 1 | Curvlinops 2 | ================================= 3 | 4 | Installation 5 | ------------ 6 | 7 | .. code:: bash 8 | 9 | pip install curvlinops-for-pytorch 10 | 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | :caption: Getting started 15 | 16 | usage 17 | 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | :caption: Curvlinops 22 | 23 | linops 24 | basic_usage/index 25 | 26 | .. 
toctree:: 27 | :caption: Internals 28 | 29 | internals 30 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | """Test general utility functions.""" 2 | 3 | from pytest import raises 4 | 5 | from curvlinops.utils import split_list 6 | 7 | 8 | def test_split_list(): 9 | """Test list splitting utility function.""" 10 | assert split_list(["a", "b", "c", "d"], [1, 3]) == [["a"], ["b", "c", "d"]] 11 | assert split_list(["a", "b", "c"], [3]) == [["a", "b", "c"]] 12 | 13 | with raises(ValueError): 14 | split_list(["a", "b", "c"], [1, 3]) 15 | -------------------------------------------------------------------------------- /.github/workflows/lint-ruff.yaml: -------------------------------------------------------------------------------- 1 | name: Lint-ruff 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | ruff: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.9 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: 3.9 14 | cache: pip 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install --upgrade pip 18 | make install-lint 19 | - name: Run ruff 20 | run: | 21 | make ruff-check -------------------------------------------------------------------------------- /.github/workflows/lint-darglint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint-darglint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | darglint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.9 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: 3.9 14 | cache: pip 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install --upgrade pip 18 | make install-lint 19 | - name: Run darglint 20 | run: | 21 | make darglint-check 22 | -------------------------------------------------------------------------------- /.github/workflows/lint-format.yaml: -------------------------------------------------------------------------------- 1 | name: Lint-format 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | ruff-format: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.9 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: 3.9 14 | cache: pip 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install --upgrade pip 18 | make install-lint 19 | - name: Run ruff format 20 | run: | 21 | make ruff-format-check 22 | -------------------------------------------------------------------------------- /.github/workflows/lint-pydocstyle.yaml: -------------------------------------------------------------------------------- 1 | name: Lint-pydocstyle 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | pydocstyle: 7 | # NOTE: Uncomment the following line to disable 8 | if: false 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 3.9 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: 3.9 16 | cache: pip 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | make install-lint 21 | - name: Run pydocstyle 22 | run: | 23 | make pydocstyle-check 24 | -------------------------------------------------------------------------------- /docs/rtd/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for 
Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/rtd/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | tests: 7 | name: "Python ${{ matrix.python-version }}" 8 | runs-on: ubuntu-latest 9 | env: 10 | USING_COVERAGE: '3.9' 11 | strategy: 12 | matrix: 13 | python-version: ["3.9"] 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: "${{ matrix.python-version }}" 19 | cache: pip 20 | - name: Install Dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | make install-test 24 | - name: Run test 25 | run: | 26 | make test 27 | 28 | - name: Test coveralls - python ${{ matrix.python-version }} 29 | run: coveralls --service=github 30 | env: 31 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 32 | flag-name: run-${{ matrix.python-version }} 33 | parallel: true 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.x" 20 | cache: pip 21 | - name: Install 
dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         python -m pip install --upgrade twine build
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: __token__
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
29 |       run: |
30 |         python -m build
31 |         twine upload dist/*
32 | 
--------------------------------------------------------------------------------
/test/test_ggn.py:
--------------------------------------------------------------------------------
1 | """Contains tests for ``curvlinops/ggn``."""
2 | 
3 | from curvlinops import GGNLinearOperator
4 | from curvlinops.examples.functorch import functorch_ggn
5 | from test.utils import compare_consecutive_matmats, compare_matmat
6 | 
7 | 
8 | def test_GGNLinearOperator_matvec(case, adjoint: bool, is_vec: bool):
9 |     """Test matrix-matrix multiplication with the GGN.
10 | 
11 |     Args:
12 |         case: Tuple of model, loss function, parameters, data, and batch size getter.
13 |         adjoint: Whether to test the adjoint operator.
14 |         is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
15 |     """
16 |     model_func, loss_func, params, data, batch_size_fn = case
17 | 
18 |     G = GGNLinearOperator(
19 |         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
20 |     )
21 |     G_mat = functorch_ggn(model_func, loss_func, params, data, input_key="x")
22 | 
23 |     compare_consecutive_matmats(G, adjoint, is_vec)
24 |     compare_matmat(G, G_mat, adjoint, is_vec, atol=1e-7, rtol=1e-4)
25 | 
--------------------------------------------------------------------------------
/test/trace/test_hutchinson.py:
--------------------------------------------------------------------------------
1 | """Test ``curvlinops.trace.hutchinson``."""
2 | 
3 | from functools import partial
4 | 
5 | from pytest import mark
6 | from torch import manual_seed, rand
7 | 
8 | from curvlinops import hutchinson_trace
9 | from test.trace import DISTRIBUTION_IDS, DISTRIBUTIONS, NUM_MATVEC_IDS, NUM_MATVECS
10 | from test.utils import check_estimator_convergence
11 | 
12 | 
13 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
14 | @mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
15 | def test_hutchinson_trace(num_matvecs: int, distribution: str):
16 |     """Test whether Hutchinson's trace estimator converges to the true trace.
17 | 
18 |     Args:
19 |         num_matvecs: Number of matrix-vector products used per estimate.
20 |         distribution: Distribution of the random vectors used for the trace estimation.
21 |     """
22 |     manual_seed(0)
23 |     A = rand(10, 10)
24 |     estimator = partial(
25 |         hutchinson_trace, A=A, num_matvecs=num_matvecs, distribution=distribution
26 |     )
27 |     check_estimator_convergence(estimator, num_matvecs, A.trace())
28 | 
--------------------------------------------------------------------------------
/test/test_gradient_moments.py:
--------------------------------------------------------------------------------
1 | """Contains tests for ``curvlinops/gradient_moments.py``."""
2 | 
3 | from curvlinops import EFLinearOperator
4 | from curvlinops.examples.functorch import functorch_empirical_fisher
5 | from test.utils import compare_consecutive_matmats, compare_matmat
6 | 
7 | 
8 | def test_EFLinearOperator(case, adjoint: bool, is_vec: bool):
9 |     """Test matrix-matrix multiplication with the empirical Fisher.
10 | 
11 |     Args:
12 |         case: Tuple of model, loss function, parameters, data, and batch size getter.
13 |         adjoint: Whether to test the adjoint operator.
14 |         is_vec: Whether to test matrix-vector or matrix-matrix multiplication.
15 |     """
16 |     model_func, loss_func, params, data, batch_size_fn = case
17 | 
18 |     E = EFLinearOperator(
19 |         model_func, loss_func, params, data, batch_size_fn=batch_size_fn
20 |     )
21 |     E_mat = functorch_empirical_fisher(
22 |         model_func, loss_func, params, data, input_key="x"
23 |     )
24 | 
25 |     compare_consecutive_matmats(E, adjoint, is_vec)
26 |     compare_matmat(E, E_mat, adjoint, is_vec, rtol=5e-4, atol=5e-6)
27 | 
--------------------------------------------------------------------------------
/test/diagonal/test_hutchinson.py:
--------------------------------------------------------------------------------
1 | """Test ``curvlinops.diagonal.hutchinson``."""
2 | 
3 | from functools import partial
4 | 
5 | from pytest import mark
6 | from torch import manual_seed, rand
7 | 
8 | from curvlinops import hutchinson_diag
9 | from test.diagonal import DISTRIBUTION_IDS, DISTRIBUTIONS, NUM_MATVEC_IDS, NUM_MATVECS
10 | from test.utils import check_estimator_convergence
11 | 
12 | 
13 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
14 | @mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
15 | def test_hutchinson_diag(num_matvecs: int, distribution: str):
16 |     """Test whether Hutchinson's diagonal estimator converges to the true diagonal.
17 | 
18 |     Args:
19 |         num_matvecs: Number of matrix-vector products used per estimate.
20 |         distribution: Distribution of the random vectors used for the diagonal estimation.
21 |     """
22 |     manual_seed(1)
23 |     A = rand(30, 30)
24 |     estimator = partial(
25 |         hutchinson_diag, A=A, num_matvecs=num_matvecs, distribution=distribution
26 |     )
27 |     check_estimator_convergence(estimator, num_matvecs, A.diag(), target_rel_error=2e-2)
28 | 
--------------------------------------------------------------------------------
/test/trace/test_meyer2020hutch.py:
--------------------------------------------------------------------------------
1 | """Test ``curvlinops.trace.meyer2020hutch``."""
2 | 
3 | from functools import partial
4 | 
5 | from pytest import mark
6 | from torch import manual_seed, rand
7 | 
8 | from curvlinops import hutchpp_trace
9 | from test.trace import DISTRIBUTION_IDS, DISTRIBUTIONS, NUM_MATVEC_IDS, NUM_MATVECS
10 | from test.utils import check_estimator_convergence
11 | 
12 | 
13 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
14 | @mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
15 | def test_hutchpp_trace(num_matvecs: int, distribution: str):
16 |     """Test whether Hutch++'s trace estimator converges to the true trace.
17 | 
18 |     Args:
19 |         num_matvecs: Number of matrix-vector products used per estimate.
20 |         distribution: Distribution of the random vectors used for the trace estimation.
21 | """ 22 | manual_seed(0) 23 | A = rand(10, 10) 24 | estimator = partial( 25 | hutchpp_trace, A=A, num_matvecs=num_matvecs, distribution=distribution 26 | ) 27 | check_estimator_convergence( 28 | estimator, num_matvecs, A.trace(), target_rel_error=1e-3 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Felix Dangel, Runa Eschenhagen, Lukas Tatzel & Philipp Hennig 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test/norm/test_hutchinson.py: -------------------------------------------------------------------------------- 1 | """Test ``curvlinops.norm.hutchinson``.""" 2 | 3 | from functools import partial 4 | 5 | from pytest import mark 6 | from torch import manual_seed, rand 7 | 8 | from curvlinops import hutchinson_squared_fro 9 | from test.utils import check_estimator_convergence 10 | 11 | DISTRIBUTIONS = ["rademacher", "normal"] 12 | DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS] 13 | 14 | NUM_MATVECS = [2, 8] 15 | NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS] 16 | 17 | 18 | @mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) 19 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) 20 | def test_hutchinson_squared_fro(num_matvecs: int, distribution: str): 21 | """Test whether Hutchinson's squared Frobenius norm estimator converges. 22 | 23 | Args: 24 | num_matvecs: Number of matrix-vector products used per estimate. 25 | distribution: Distribution of the random vectors used for the estimation. 
26 | """ 27 | manual_seed(0) 28 | A = rand(10, 15) 29 | estimator = partial( 30 | hutchinson_squared_fro, A=A, num_matvecs=num_matvecs, distribution=distribution 31 | ) 32 | check_estimator_convergence(estimator, num_matvecs, (A**2).sum()) 33 | -------------------------------------------------------------------------------- /test/test_fisher.py: -------------------------------------------------------------------------------- 1 | """Contains tests for ``curvlinops/fisher.py``.""" 2 | 3 | from pytest import mark 4 | 5 | from curvlinops import FisherMCLinearOperator 6 | from curvlinops.examples.functorch import functorch_ggn 7 | from test.utils import compare_consecutive_matmats, compare_matmat_expectation 8 | 9 | MAX_REPEATS_MC_SAMPLES = [(10_000, 1), (100, 100)] 10 | MAX_REPEATS_MC_SAMPLES_IDS = [ 11 | f"max_repeats={n}-mc_samples={m}" for (n, m) in MAX_REPEATS_MC_SAMPLES 12 | ] 13 | CHECK_EVERY = 100 14 | 15 | 16 | @mark.montecarlo 17 | @mark.parametrize( 18 | "max_repeats,mc_samples", MAX_REPEATS_MC_SAMPLES, ids=MAX_REPEATS_MC_SAMPLES_IDS 19 | ) 20 | def test_FisherMCLinearOperator_expectation( 21 | case, adjoint: bool, is_vec: bool, max_repeats: int, mc_samples: int 22 | ): 23 | """Test matrix-matrix multiplication with the Monte-Carlo Fisher. 24 | 25 | Args: 26 | case: Tuple of model, loss function, parameters, data, and batch size getter. 27 | adjoint: Whether to test the adjoint operator. 28 | is_vec: Whether to test matrix-vector or matrix-matrix multiplication. 29 | """ 30 | model_func, loss_func, params, data, batch_size_fn = case 31 | 32 | F = FisherMCLinearOperator( 33 | model_func, 34 | loss_func, 35 | params, 36 | data, 37 | batch_size_fn=batch_size_fn, 38 | mc_samples=mc_samples, 39 | ) 40 | G_mat = functorch_ggn(model_func, loss_func, params, data, input_key="x") 41 | 42 | compare_consecutive_matmats(F, adjoint, is_vec) 43 | compare_matmat_expectation( 44 | F, G_mat, adjoint, is_vec, max_repeats, CHECK_EVERY, rtol=2e-1, atol=5e-3 45 | ) 46 | -------------------------------------------------------------------------------- /test/test_jacobian.py: -------------------------------------------------------------------------------- 1 | """Contains tests for ``curvlinops/jacobian``.""" 2 | 3 | from curvlinops import JacobianLinearOperator, TransposedJacobianLinearOperator 4 | from curvlinops.examples.functorch import functorch_jacobian 5 | from test.utils import compare_consecutive_matmats, compare_matmat 6 | 7 | 8 | def test_JacobianLinearOperator(case, adjoint: bool, is_vec: bool): 9 | """Test matrix-matrix multiplication with the Jacobian. 10 | 11 | Args: 12 | case: Tuple of model, loss function, parameters, data, and batch size getter. 13 | adjoint: Whether to test the adjoint operator. 14 | is_vec: Whether to test matrix-vector or matrix-matrix multiplication. 15 | """ 16 | model_func, _, params, data, batch_size_fn = case 17 | 18 | J = JacobianLinearOperator(model_func, params, data, batch_size_fn=batch_size_fn) 19 | J_mat = functorch_jacobian(model_func, params, data, input_key="x") 20 | 21 | compare_consecutive_matmats(J, adjoint, is_vec) 22 | compare_matmat(J, J_mat, adjoint, is_vec, rtol=1e-4, atol=1e-7) 23 | 24 | 25 | def test_TransposedJacobianLinearOperator(case, adjoint: bool, is_vec: bool): 26 | """Test matrix-matrix multiplication with the transpose Jacobian. 27 | 28 | Args: 29 | case: Tuple of model, loss function, parameters, data, and batch size getter. 30 | adjoint: Whether to test the adjoint operator. 
31 | is_vec: Whether to test matrix-vector or matrix-matrix multiplication. 32 | """ 33 | model_func, _, params, data, batch_size_fn = case 34 | 35 | JT = TransposedJacobianLinearOperator( 36 | model_func, params, data, batch_size_fn=batch_size_fn 37 | ) 38 | JT_mat = functorch_jacobian(model_func, params, data, input_key="x").T 39 | 40 | compare_consecutive_matmats(JT, adjoint, is_vec) 41 | compare_matmat(JT, JT_mat, adjoint, is_vec, rtol=1e-4, atol=1e-7) 42 | -------------------------------------------------------------------------------- /test/test_hessian.py: -------------------------------------------------------------------------------- 1 | """Contains tests for ``curvlinops/hessian``.""" 2 | 3 | from typing import Callable, List, Optional 4 | 5 | from torch import block_diag 6 | from torch.nn import Parameter 7 | 8 | from curvlinops import HessianLinearOperator 9 | from curvlinops.examples.functorch import functorch_hessian 10 | from curvlinops.utils import split_list 11 | from test.utils import compare_consecutive_matmats, compare_matmat 12 | 13 | 14 | def test_HessianLinearOperator( 15 | case, 16 | adjoint: bool, 17 | is_vec: bool, 18 | block_sizes_fn: Callable[[List[Parameter]], Optional[List[int]]], 19 | ): 20 | """Test matrix-matrix multiplication with the Hessian. 21 | 22 | Args: 23 | case: Tuple of model, loss function, parameters, data, and batch size getter. 24 | adjoint: Whether to test the adjoint operator. 25 | is_vec: Whether to test matrix-vector or matrix-matrix multiplication. 26 | block_sizes_fn: The function that generates the block sizes used to define 27 | block diagonal approximations from the parameters. 28 | """ 29 | model_func, loss_func, params, data, batch_size_fn = case 30 | block_sizes = block_sizes_fn(params) 31 | 32 | H = HessianLinearOperator( 33 | model_func, 34 | loss_func, 35 | params, 36 | data, 37 | batch_size_fn=batch_size_fn, 38 | block_sizes=block_sizes, 39 | ) 40 | 41 | # compute the blocks with functorch and build the block diagonal matrix 42 | H_blocks = [ 43 | functorch_hessian(model_func, loss_func, params_block, data, input_key="x") 44 | for params_block in split_list( 45 | params, [len(params)] if block_sizes is None else block_sizes 46 | ) 47 | ] 48 | H_mat = block_diag(*H_blocks) 49 | 50 | compare_consecutive_matmats(H, adjoint, is_vec) 51 | compare_matmat(H, H_mat, adjoint, is_vec, rtol=1e-4, atol=1e-6) 52 | -------------------------------------------------------------------------------- /curvlinops/sampling.py: -------------------------------------------------------------------------------- 1 | """Sampling methods for random vectors.""" 2 | 3 | from torch import Tensor, device, dtype, empty, randn 4 | 5 | 6 | def rademacher(dim: int, device: device, dtype: dtype) -> Tensor: 7 | """Draw a vector with i.i.d. Rademacher elements. 8 | 9 | Args: 10 | dim: Dimension of the vector. 11 | device: Device on which the vector is allocated. 12 | dtype: Data type of the vector. 13 | 14 | Returns: 15 | Vector with i.i.d. Rademacher elements and specified dimension. 16 | """ 17 | p_success = 0.5 18 | return empty(dim, device=device, dtype=dtype).bernoulli_(p_success).mul_(2).sub_(1) 19 | 20 | 21 | def normal(dim: int, device: device, dtype: dtype) -> Tensor: 22 | """Draw a vector with i.i.d. standard normal elements. 23 | 24 | Args: 25 | dim: Dimension of the vector. 26 | device: Device on which the vector is allocated. 27 | dtype: Data type of the vector. 28 | 29 | Returns: 30 | Vector with i.i.d. standard normal elements and specified dimension. 
31 | """ 32 | return randn(dim, device=device, dtype=dtype) 33 | 34 | 35 | def random_vector(dim: int, distribution: str, device: device, dtype: dtype) -> Tensor: 36 | """Draw a vector with i.i.d. elements. 37 | 38 | Args: 39 | dim: Dimension of the vector. 40 | distribution: Distribution of the vector's elements. Either ``'rademacher'`` or 41 | ``'normal'``. 42 | device: Device on which the vector is allocated. 43 | dtype: Data type of the vector. 44 | 45 | Returns: 46 | Vector with i.i.d. elements and specified dimension. 47 | 48 | Raises: 49 | ValueError: If the distribution is unknown. 50 | """ 51 | if distribution == "rademacher": 52 | return rademacher(dim, device, dtype) 53 | elif distribution == "normal": 54 | return normal(dim, device, dtype) 55 | else: 56 | raise ValueError(f"Unknown distribution {distribution!r}.") 57 | -------------------------------------------------------------------------------- /curvlinops/__init__.py: -------------------------------------------------------------------------------- 1 | """``curvlinops`` library API.""" 2 | 3 | from curvlinops.diagonal.epperly2024xtrace import xdiag 4 | from curvlinops.diagonal.hutchinson import hutchinson_diag 5 | from curvlinops.ekfac import EKFACLinearOperator 6 | from curvlinops.fisher import FisherMCLinearOperator 7 | from curvlinops.ggn import GGNLinearOperator 8 | from curvlinops.gradient_moments import EFLinearOperator 9 | from curvlinops.hessian import HessianLinearOperator 10 | from curvlinops.inverse import ( 11 | CGInverseLinearOperator, 12 | KFACInverseLinearOperator, 13 | LSMRInverseLinearOperator, 14 | NeumannInverseLinearOperator, 15 | ) 16 | from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator 17 | from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType 18 | from curvlinops.norm.hutchinson import hutchinson_squared_fro 19 | from curvlinops.papyan2020traces.spectrum import ( 20 | LanczosApproximateLogSpectrumCached, 21 | LanczosApproximateSpectrumCached, 22 | lanczos_approximate_log_spectrum, 23 | lanczos_approximate_spectrum, 24 | ) 25 | from curvlinops.submatrix import SubmatrixLinearOperator 26 | from curvlinops.trace.epperly2024xtrace import xtrace 27 | from curvlinops.trace.hutchinson import hutchinson_trace 28 | from curvlinops.trace.meyer2020hutch import hutchpp_trace 29 | 30 | __all__ = [ 31 | # linear operators 32 | "HessianLinearOperator", 33 | "GGNLinearOperator", 34 | "EFLinearOperator", 35 | "FisherMCLinearOperator", 36 | "KFACLinearOperator", 37 | "EKFACLinearOperator", 38 | "JacobianLinearOperator", 39 | "TransposedJacobianLinearOperator", 40 | # Enums 41 | "FisherType", 42 | "KFACType", 43 | # inversion 44 | "CGInverseLinearOperator", 45 | "LSMRInverseLinearOperator", 46 | "NeumannInverseLinearOperator", 47 | "KFACInverseLinearOperator", 48 | # slicing 49 | "SubmatrixLinearOperator", 50 | # spectral properties 51 | "lanczos_approximate_spectrum", 52 | "lanczos_approximate_log_spectrum", 53 | "LanczosApproximateSpectrumCached", 54 | "LanczosApproximateLogSpectrumCached", 55 | # trace estimation 56 | "hutchinson_trace", 57 | "hutchpp_trace", 58 | "xtrace", 59 | # diagonal estimation 60 | "hutchinson_diag", 61 | "xdiag", 62 | # norm estimation 63 | "hutchinson_squared_fro", 64 | ] 65 | -------------------------------------------------------------------------------- /docs/rtd/assets/logo.tex: -------------------------------------------------------------------------------- 1 | \documentclass{standalone} 2 | 3 | \usepackage{tikz} 4 | 5 | 
\usetikzlibrary{shapes, 6 | shadings, 7 | calc, 8 | arrows, 9 | backgrounds, 10 | colorbrewer, 11 | shadows.blur} 12 | 13 | % bold math smbols 14 | \usepackage{bm} 15 | 16 | % colors 17 | \usepackage{xcolor} 18 | \definecolor{TUanthrazit}{RGB}{50, 65, 75} 19 | 20 | 21 | % from https://tex.stackexchange.com/a/216159 22 | \newcommand{\cube}[1]{ 23 | \begin{tikzpicture} 24 | % Settings 25 | \coordinate (CenterPoint) at (0,0); 26 | \def\width{0.7cm}; 27 | \def\height{0.7cm}; 28 | \def\textborder{0.1cm}; 29 | \def\xslant{0.25cm}; 30 | \def\yslant{0.15cm}; 31 | \def\rounding{0.2pt}; 32 | % Drawing 33 | \node[draw=white, draw, 34 | minimum height = \height, 35 | minimum width = \width, 36 | text width = {\width-2*\textborder}, 37 | align = center, 38 | fill opacity = 0.9, 39 | fill = TUanthrazit!50, 40 | rounded corners = \rounding] 41 | (front) 42 | at (CenterPoint) {#1}; 43 | % overlay front with Hessian silhouette 44 | \begin{pgfonlayer}{background} 45 | \node [inner sep=0pt] at (front) {\includegraphics[width=\width]{hessian.pdf}}; 46 | \end{pgfonlayer} 47 | % "3D" top 48 | \draw [rounded corners = \rounding, draw=white, fill=TUanthrazit!70] % 49 | ($(CenterPoint) + (-\width/2. - 2*\rounding, \height/2.)$) -- % 50 | ($(CenterPoint) + (-\width/2. + \xslant - 2*\rounding, \height/2. + \yslant)$) -- % 51 | ($(CenterPoint) + (\width/2. + \xslant + 2*\rounding, \height/2. + \yslant)$) -- % 52 | ($(CenterPoint) + (\width/2. + 2*\rounding, \height/2.)$) -- % 53 | cycle; 54 | % "3D" side 55 | \draw [rounded corners = \rounding, draw=white, fill=TUanthrazit!90] % 56 | ($(CenterPoint) + (\width/2. + \xslant + 2*\rounding, \height/2. + \yslant)$) -- % 57 | ($(CenterPoint) + (\width/2. + 2*\rounding, \height/2.)$) -- % 58 | ($(CenterPoint) + (\width/2. + 2*\rounding, -\height/2.)$) -- % 59 | ($(CenterPoint) + (\width/2. + \xslant + 2*\rounding, -\height/2. + \yslant)$) -- % 60 | cycle; 61 | \end{tikzpicture} 62 | } 63 | 64 | \begin{document} 65 | 66 | \begin{tikzpicture} 67 | \node [inner sep=1pt] (A) {\cube{$\bm{A}$}}; 68 | 69 | \node [anchor=south east, xshift=1.05cm, yshift=0.0cm] (v) at (A.north west) 70 | {\textbf{cur}$\bm{v}_{\includegraphics[height=0.15cm]{torch.pdf}}$\textbf{linops}}; 71 | 72 | \node (Av) [anchor=north west, xshift=-1ex, yshift=0.2cm] at (A.south east) 73 | {$\bm{Av}_{\includegraphics[height=0.15cm]{torch.pdf}}$}; 74 | 75 | \draw [->, >=stealth, thick] ($(v.south)+(-2ex,0)$) |- (A.west); 76 | 77 | \draw [->, >=stealth, thick] (A.east) -| (Av.north); 78 | \end{tikzpicture} 79 | 80 | \end{document} 81 | -------------------------------------------------------------------------------- /docs/rtd/linops.rst: -------------------------------------------------------------------------------- 1 | Linear operators 2 | ================ 3 | 4 | 5 | Hessian 6 | ------- 7 | 8 | .. autoclass:: curvlinops.HessianLinearOperator 9 | :members: __init__ 10 | 11 | Generalized Gauss-Newton 12 | ------------------------ 13 | 14 | .. autoclass:: curvlinops.GGNLinearOperator 15 | :members: __init__ 16 | 17 | Fisher (approximate) 18 | -------------------- 19 | 20 | .. autoclass:: curvlinops.FisherMCLinearOperator 21 | :members: __init__ 22 | 23 | .. autoclass:: curvlinops.KFACLinearOperator 24 | :members: __init__, trace, det, logdet, frobenius_norm, state_dict, load_state_dict, from_state_dict 25 | 26 | .. autoclass:: curvlinops.EKFACLinearOperator 27 | :members: __init__, trace, det, logdet, frobenius_norm, state_dict, load_state_dict, from_state_dict 28 | 29 | .. 
autoclass:: curvlinops.FisherType 30 | 31 | .. autoclass:: curvlinops.KFACType 32 | 33 | Uncentered gradient covariance (empirical Fisher) 34 | ------------------------------------------------- 35 | 36 | .. autoclass:: curvlinops.EFLinearOperator 37 | :members: __init__ 38 | 39 | Jacobians 40 | --------- 41 | 42 | .. autoclass:: curvlinops.JacobianLinearOperator 43 | :members: __init__ 44 | 45 | .. autoclass:: curvlinops.TransposedJacobianLinearOperator 46 | :members: __init__ 47 | 48 | Inverses 49 | -------- 50 | 51 | .. autoclass:: curvlinops.CGInverseLinearOperator 52 | :members: __init__ 53 | 54 | .. autoclass:: curvlinops.LSMRInverseLinearOperator 55 | :members: __init__ 56 | 57 | .. autoclass:: curvlinops.NeumannInverseLinearOperator 58 | :members: __init__ 59 | 60 | .. autoclass:: curvlinops.KFACInverseLinearOperator 61 | :members: __init__ 62 | 63 | Sub-matrices 64 | ------------ 65 | 66 | .. autoclass:: curvlinops.SubmatrixLinearOperator 67 | :members: __init__, set_submatrix 68 | 69 | Spectral density approximation 70 | ============================== 71 | 72 | .. note:: 73 | This functionality currently expects SciPy ``LinearOperator``s. 74 | 75 | .. autofunction:: curvlinops.lanczos_approximate_spectrum 76 | 77 | .. autofunction:: curvlinops.lanczos_approximate_log_spectrum 78 | 79 | .. autoclass:: curvlinops.LanczosApproximateSpectrumCached 80 | :members: __init__, approximate_spectrum 81 | 82 | Trace approximation 83 | =================== 84 | 85 | .. autofunction:: curvlinops.hutchinson_trace 86 | 87 | .. autofunction:: curvlinops.hutchpp_trace 88 | 89 | .. autofunction:: curvlinops.xtrace 90 | 91 | Diagonal approximation 92 | ====================== 93 | 94 | .. autofunction:: curvlinops.hutchinson_diag 95 | 96 | .. autofunction:: curvlinops.xdiag 97 | 98 | Frobenius norm approximation 99 | ============================ 100 | 101 | .. autoclass:: curvlinops.hutchinson_squared_fro 102 | 103 | Experimental 104 | ============ 105 | 106 | Experimental features may be subject to changes or become deprecated. 107 | 108 | .. 
autoclass:: curvlinops.experimental.ActivationHessianLinearOperator 109 | :members: __init__ 110 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT: help 2 | 3 | help: 4 | @echo "install" 5 | @echo " Install curvlinops and dependencies" 6 | @echo "uninstall" 7 | @echo " Uninstall curvlinops" 8 | @echo "lint" 9 | @echo " Run all linting actions" 10 | @echo "docs" 11 | @echo " Build the documentation" 12 | @echo "install-dev" 13 | @echo " Install curvlinops and development tools" 14 | @echo "install-docs" 15 | @echo " Install curvlinops and documentation tools" 16 | @echo "install-test" 17 | @echo " Install curvlinops and testing tools" 18 | @echo "test-light" 19 | @echo " Run pytest on the light part of project and report coverage" 20 | @echo "test" 21 | @echo " Run pytest on test and report coverage" 22 | @echo "install-lint" 23 | @echo " Install curvlinops and the linter tools" 24 | @echo "ruff-format" 25 | @echo " Run ruff format on the project" 26 | @echo "ruff-format-check" 27 | @echo " Check if ruff format would change files" 28 | @echo "ruff" 29 | @echo " Run ruff on the project and fix errors" 30 | @echo "ruff-check" 31 | @echo " Run ruff check on the project without fixing errors" 32 | @echo "conda-env" 33 | @echo " Create conda environment 'curvlinops' with dev setup" 34 | @echo "darglint-check" 35 | @echo " Run darglint (docstring check) on the project" 36 | @echo "pydocstyle-check" 37 | @echo " Run pydocstyle (docstring check) on the project" 38 | 39 | .PHONY: install 40 | 41 | install: 42 | @pip install -e . 43 | 44 | .PHONY: uninstall 45 | 46 | uninstall: 47 | @pip uninstall curvlinops-for-pytorch 48 | 49 | .PHONY: docs 50 | 51 | docs: 52 | @cd docs/rtd && make html 53 | @echo "\nOpen docs/rtd/index.html to see the result." 54 | 55 | .PHONY: install-dev 56 | 57 | install-dev: install-lint install-test install-docs 58 | 59 | .PHONY: install-docs 60 | 61 | install-docs: 62 | @pip install -e .[docs] 63 | 64 | .PHONY: install-test 65 | 66 | install-test: 67 | @pip install -e .[test] 68 | 69 | .PHONY: test test-light 70 | 71 | test: 72 | @pytest -vx --run-optional-tests=montecarlo --cov=curvlinops --doctest-modules curvlinops test 73 | 74 | test-light: 75 | @pytest -vx --cov=curvlinops --doctest-modules curvlinops test 76 | 77 | .PHONY: install-lint 78 | 79 | install-lint: 80 | @pip install -e .[lint] 81 | 82 | .PHONY: ruff-format ruff-format-check 83 | 84 | ruff-format: 85 | @ruff format . 86 | 87 | ruff-format-check: 88 | @ruff format --check . 89 | 90 | .PHONY: ruff-check 91 | 92 | ruff: 93 | @ruff check . --fix 94 | 95 | ruff-check: 96 | @ruff check . 97 | 98 | .PHONY: darglint-check 99 | 100 | darglint-check: 101 | @darglint --verbosity 2 curvlinops 102 | 103 | .PHONY: pydocstyle-check 104 | 105 | pydocstyle-check: 106 | @pydocstyle --count . 
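# Example development workflow (illustrative only; it simply chains targets that
# are defined in this makefile):
#
#	make install-dev     # editable install with the lint, test, and docs extras
#	make lint            # ruff format check, ruff check, darglint, pydocstyle
#	make test-light      # run pytest without the optional Monte-Carlo tests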
107 | 108 | .PHONY: conda-env 109 | 110 | conda-env: 111 | @conda env create --file .conda_env.yml 112 | 113 | .PHONY: lint 114 | 115 | lint: 116 | make ruff-format-check 117 | make ruff-check 118 | make darglint-check 119 | make pydocstyle-check 120 | -------------------------------------------------------------------------------- /test/papyan2020traces/test_spectrum.py: -------------------------------------------------------------------------------- 1 | """Contains tests of ``curvlinops/papyan2020traces/spectrum``.""" 2 | 3 | from math import isclose 4 | from pathlib import Path 5 | from subprocess import CalledProcessError, check_output 6 | 7 | from torch import tensor 8 | 9 | from curvlinops.examples import TensorLinearOperator 10 | from curvlinops.papyan2020traces.spectrum import ( 11 | approximate_boundaries, 12 | approximate_boundaries_abs, 13 | ) 14 | 15 | PROJECT_ROOT = Path(__file__).parent.parent.parent 16 | BASIC_USAGE = PROJECT_ROOT / "docs" / "examples" / "basic_usage" 17 | 18 | 19 | def test_example_verification_spectral_density(): 20 | """Integration test to check that the verification example is working. 21 | 22 | It is hard to test the spectral density estimation techniques. This test 23 | uses the verification example from the documentation. 24 | """ 25 | EXAMPLE_VERIFICATION_SPECTRAL_DENSITY = ( 26 | BASIC_USAGE / "example_verification_spectral_density.py" 27 | ) 28 | 29 | try: 30 | check_output(f"python {EXAMPLE_VERIFICATION_SPECTRAL_DENSITY}", shell=True) 31 | except CalledProcessError as e: 32 | print(e.output) 33 | raise e 34 | 35 | 36 | def test_approximate_boundaries(): 37 | """Test spectrum boundary approximation with partially supplied boundaries.""" 38 | A_diag = tensor([1.0, 2.0, 3.0, 4.0, 5.0]).double() 39 | A_matrix = A_diag.diag() 40 | A = TensorLinearOperator(A_matrix) 41 | lambda_min, lambda_max = A_diag.min().item(), A_diag.max().item() 42 | 43 | cases = [ 44 | [(0.0, 10.0), (0.0, 10.0)], 45 | [(1.5, None), (1.5, lambda_max)], 46 | [(None, 2.5), (lambda_min, 2.5)], 47 | [(None, None), (lambda_min, lambda_max)], 48 | [None, (lambda_min, lambda_max)], 49 | ] 50 | 51 | for inputs, results in cases: 52 | output = approximate_boundaries(A, boundaries=inputs) 53 | assert len(output) == 2 54 | assert isinstance(output[0], float) 55 | assert isinstance(output[1], float) 56 | assert isclose(output[0], results[0]) and isclose(output[1], results[1]) 57 | 58 | 59 | def test_approximate_boundaries_abs(): 60 | """Test abs spectrum boundary approximation with partially supplied boundaries.""" 61 | A_diag = tensor([-2.0, -1.0, 3.0, 4.0, 5.0]).double() 62 | A_matrix = A_diag.diag() 63 | A = TensorLinearOperator(A_matrix) 64 | lambda_abs_min, lambda_abs_max = ( 65 | A_diag.abs().min().item(), 66 | A_diag.abs().max().item(), 67 | ) 68 | 69 | cases = [ 70 | [(0.0, 10.0), (0.0, 10.0)], 71 | [(1.5, None), (1.5, lambda_abs_max)], 72 | [(None, 2.5), (lambda_abs_min, 2.5)], 73 | [(None, None), (lambda_abs_min, lambda_abs_max)], 74 | [None, (lambda_abs_min, lambda_abs_max)], 75 | ] 76 | 77 | for inputs, results in cases: 78 | output = approximate_boundaries_abs(A, boundaries=inputs) 79 | assert len(output) == 2 80 | assert isinstance(output[0], float) 81 | assert isinstance(output[1], float) 82 | assert isclose(output[0], results[0]) and isclose(output[1], results[1]) 83 | -------------------------------------------------------------------------------- /curvlinops/norm/hutchinson.py: -------------------------------------------------------------------------------- 1 | 
"""Hutchinson-style matrix norm estimation.""" 2 | 3 | from typing import Union 4 | 5 | from torch import Tensor, column_stack 6 | 7 | from curvlinops._torch_base import PyTorchLinearOperator 8 | from curvlinops.sampling import random_vector 9 | 10 | 11 | def hutchinson_squared_fro( 12 | A: Union[Tensor, PyTorchLinearOperator], 13 | num_matvecs: int, 14 | distribution: str = "rademacher", 15 | ) -> Tensor: 16 | r"""Estimate the squared Frobenius norm of a matrix using Hutchinson's method. 17 | 18 | Let :math:`\mathbf{A} \in \mathbb{R}^{M \times N}` be some matrix. It's Frobenius 19 | norm :math:`\lVert\mathbf{A}\rVert_\text{F}` is defined via: 20 | 21 | .. math:: 22 | \lVert\mathbf{A}\rVert_\text{F}^2 23 | = 24 | \sum_{m=1}^M \sum_{n=1}^N \mathbf{A}_{n,m}^2 25 | = 26 | \text{Tr}(\mathbf{A}^\top \mathbf{A}). 27 | 28 | Due to the last equality, we can use Hutchinson-style trace estimation to estimate 29 | the squared Frobenius norm. 30 | 31 | Args: 32 | A: A matrix whose squared Frobenius norm is estimated. 33 | num_matvecs: Total number of matrix-vector products to use. Must be smaller 34 | than the minimum dimension of the matrix. 35 | distribution: Distribution of the random vectors used for the trace estimation. 36 | Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``. 37 | 38 | Returns: 39 | The estimated squared Frobenius norm of the matrix. 40 | 41 | Raises: 42 | ValueError: If the matrix is not two-dimensional or if the number of matrix- 43 | vector products is greater than the minimum dimension of the matrix 44 | (because then you can evaluate the true squared Frobenius norm directly 45 | atthe same cost). 46 | 47 | Example: 48 | >>> from torch.linalg import matrix_norm 49 | >>> from torch import rand, manual_seed 50 | >>> _ = manual_seed(0) # make deterministic 51 | >>> A = rand(40, 40) 52 | >>> fro2_A = matrix_norm(A).item()**2 # reference: exact squared Frobenius norm 53 | >>> # one- and multi-sample approximations 54 | >>> fro2_A_low_prec = hutchinson_squared_fro(A, num_matvecs=1).item() 55 | >>> fro2_A_high_prec = hutchinson_squared_fro(A, num_matvecs=30).item() 56 | >>> assert abs(fro2_A - fro2_A_low_prec) > abs(fro2_A - fro2_A_high_prec) 57 | >>> round(fro2_A, 1), round(fro2_A_low_prec, 1), round(fro2_A_high_prec, 1) 58 | (530.9, 156.7, 628.9) 59 | """ 60 | if len(A.shape) != 2: 61 | raise ValueError(f"A must be a matrix. Got shape {A.shape}.") 62 | dim = min(A.shape) 63 | if num_matvecs >= dim: 64 | raise ValueError( 65 | f"num_matvecs ({num_matvecs}) must be less than the minimum dimension of A." 
66 | ) 67 | # Instead of AT @ A, use A @ AT if the matrix is wider than tall 68 | if A.shape[1] > A.shape[0]: 69 | A = A.T 70 | 71 | G = column_stack( 72 | [ 73 | random_vector(dim, distribution, A.device, A.dtype) 74 | for _ in range(num_matvecs) 75 | ] 76 | ) 77 | AG = A @ G 78 | return (AG**2 / num_matvecs).sum() 79 | -------------------------------------------------------------------------------- /curvlinops/diagonal/hutchinson.py: -------------------------------------------------------------------------------- 1 | """Hutchinson-style matrix diagonal estimation.""" 2 | 3 | from typing import Union 4 | 5 | from torch import Tensor, column_stack, einsum 6 | 7 | from curvlinops._torch_base import PyTorchLinearOperator 8 | from curvlinops.sampling import random_vector 9 | from curvlinops.utils import assert_is_square, assert_matvecs_subseed_dim 10 | 11 | 12 | def hutchinson_diag( 13 | A: Union[PyTorchLinearOperator, Tensor], 14 | num_matvecs: int, 15 | distribution: str = "rademacher", 16 | ) -> Tensor: 17 | r"""Estimate a linear operator's diagonal using Hutchinson's method. 18 | 19 | For details, see 20 | 21 | - Bekas, C., Kokiopoulou, E., & Saad, Y. (2007). An estimator for the diagonal 22 | of a matrix. Applied Numerical Mathematics. 23 | 24 | Let :math:`\mathbf{A}` be a square linear operator. We can approximate its diagonal 25 | :math:`\mathrm{diag}(\mathbf{A})` by drawing random vectors :math:`N` 26 | :math:`\mathbf{v}_n \sim \mathbf{v}` from a distribution :math:`\mathbf{v}` that 27 | satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}`, and compute 28 | the estimator 29 | 30 | .. math:: 31 | \mathbf{a} 32 | := \frac{1}{N} \sum_{n=1}^N \mathbf{v}_n \odot \mathbf{A} \mathbf{v}_n 33 | \approx \mathrm{diag}(\mathbf{A})\,. 34 | 35 | This estimator is unbiased, 36 | 37 | .. math:: 38 | \mathbb{E}[a_i] 39 | = \sum_j \mathbb{E}[v_i A_{i,j} v_j] 40 | = \sum_j A_{i,j} \mathbb{E}[v_i v_j] 41 | = \sum_j A_{i,j} \delta_{i, j} 42 | = A_{i,i}\,. 43 | 44 | Args: 45 | A: A square linear operator whose diagonal is estimated. 46 | num_matvecs: Total number of matrix-vector products to use. Must be smaller 47 | than the dimension of the linear operator (because otherwise one can 48 | evaluate the true diagonal directly at the same cost). 49 | distribution: Distribution of the random vectors used for the diagonal 50 | estimation. Can be either ``'rademacher'`` or ``'normal'``. 51 | Default: ``'rademacher'``. 52 | 53 | Returns: 54 | The estimated diagonal of the linear operator. 
55 | 56 | Example: 57 | >>> from torch import manual_seed, rand 58 | >>> from torch.linalg import vector_norm 59 | >>> _ = manual_seed(0) # make deterministic 60 | >>> A = rand(40, 40) 61 | >>> diag_A = A.diag() # exact diagonal as reference 62 | >>> # one- and multi-sample approximations 63 | >>> diag_A_low_precision = hutchinson_diag(A, num_matvecs=1) 64 | >>> diag_A_high_precision = hutchinson_diag(A, num_matvecs=30) 65 | >>> # compute residual norms 66 | >>> error_low_precision = (vector_norm(diag_A - diag_A_low_precision) / vector_norm(diag_A)).item() 67 | >>> error_high_precision = (vector_norm(diag_A - diag_A_high_precision) / vector_norm(diag_A)).item() 68 | >>> assert error_low_precision > error_high_precision 69 | >>> round(error_low_precision, 4), round(error_high_precision, 4) 70 | (3.2648, 0.9253) 71 | """ 72 | dim = assert_is_square(A) 73 | assert_matvecs_subseed_dim(A, num_matvecs) 74 | G = column_stack( 75 | [ 76 | random_vector(dim, distribution, A.device, A.dtype) 77 | for _ in range(num_matvecs) 78 | ] 79 | ) 80 | 81 | return einsum("ij,ij->i", G, A @ G) / num_matvecs 82 | -------------------------------------------------------------------------------- /curvlinops/diagonal/epperly2024xtrace.py: -------------------------------------------------------------------------------- 1 | """Implements the XDiag algorithm from Epperly 2024.""" 2 | 3 | from typing import Union 4 | 5 | from torch import Tensor, column_stack, dot, einsum 6 | from torch.linalg import inv, qr 7 | 8 | from curvlinops._torch_base import PyTorchLinearOperator 9 | from curvlinops.sampling import random_vector 10 | from curvlinops.utils import ( 11 | assert_divisible_by, 12 | assert_is_square, 13 | assert_matvecs_subseed_dim, 14 | ) 15 | 16 | 17 | def xdiag(A: Union[PyTorchLinearOperator, Tensor], num_matvecs: int) -> Tensor: 18 | """Estimate a linear operator's diagonal using the XDiag algorithm. 19 | 20 | The method is presented in `this paper `_: 21 | 22 | - Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). Xtrace: making the most 23 | of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis 24 | and Applications (SIMAX). 25 | 26 | It combines the variance reduction from Diag++ with the exchangeability principle. 27 | 28 | Args: 29 | A: A square linear operator. 30 | num_matvecs: Total number of matrix-vector products to use. Must be even and 31 | less than the dimension of the linear operator (because otherwise one can 32 | evaluate the true diagonal directly at the same cost). 33 | 34 | Returns: 35 | The estimated diagonal of the linear operator. 36 | """ 37 | dim = assert_is_square(A) 38 | assert_matvecs_subseed_dim(A, num_matvecs) 39 | assert_divisible_by(num_matvecs, 2, "num_matvecs") 40 | 41 | # draw random vectors and compute their matrix-vector products 42 | num_vecs = num_matvecs // 2 43 | W = column_stack( 44 | [random_vector(dim, "rademacher", A.device, A.dtype) for _ in range(num_vecs)] 45 | ) 46 | A_W = A @ W 47 | 48 | # compute the orthogonal basis for all test vectors, and its associated diagonal 49 | Q, R = qr(A_W) 50 | QT_A = (A.adjoint() @ Q).T 51 | diag_Q_QT_A = einsum("ij,ji->i", Q, QT_A) 52 | 53 | # Compute and average the diagonals in the bases {Q_i} that would result had we left 54 | # out the i-th test vector in the QR decomposition. 
This follows by considering 55 | # diag(Q_i QT_i A) and using the relation Q_i QT_i = Q (I - s_i sT_i) QT, where the 56 | # s_i are given by: 57 | RT_inv = inv(R.T) 58 | D = 1 / (RT_inv**2).sum(0) ** 0.5 59 | S = einsum("ij,j->ij", RT_inv, D) 60 | # Further simplification then leads to 61 | diagonal = diag_Q_QT_A - einsum("ij,jk,lk,li->i", Q, S, S, QT_A) / num_vecs 62 | 63 | def deflate(v: Tensor, s: Tensor) -> Tensor: 64 | """Apply (I - s sT) to a vector. 65 | 66 | Args: 67 | v: Vector to deflate. 68 | s: Deflation vector. 69 | 70 | Returns: 71 | Deflated vector. 72 | """ 73 | return v - dot(s, v) * s 74 | 75 | # estimate the diagonal on the complement of Q_i with vanilla Hutchinson using the 76 | # i-th test vector 77 | for i in range(num_vecs): 78 | w_i = W[:, i] 79 | s_i = S[:, i] 80 | A_w_i = A_W[:, i] 81 | 82 | # Compute (I - Q_i QT_i) A w_i 83 | # = A w_i - (I - Q_i QT_i) A w_i 84 | # ( using that Q_i QT_i = Q (I - s_i sT_i) QT ) 85 | # = A w_i - Q (I - s_i sT_i) QT A w_i 86 | A_comp_w_i = A_w_i - Q @ deflate(QT_A @ w_i, s_i) 87 | 88 | diag_w_i = w_i * A_comp_w_i / w_i**2 89 | diagonal += diag_w_i / num_vecs 90 | 91 | return diagonal 92 | -------------------------------------------------------------------------------- /curvlinops/trace/hutchinson.py: -------------------------------------------------------------------------------- 1 | """Vanilla Hutchinson trace estimation.""" 2 | 3 | from typing import Union 4 | 5 | from torch import Tensor, column_stack, einsum 6 | 7 | from curvlinops._torch_base import PyTorchLinearOperator 8 | from curvlinops.sampling import random_vector 9 | from curvlinops.utils import assert_is_square, assert_matvecs_subseed_dim 10 | 11 | 12 | def hutchinson_trace( 13 | A: Union[Tensor, PyTorchLinearOperator], 14 | num_matvecs: int, 15 | distribution: str = "rademacher", 16 | ) -> Tensor: 17 | r"""Estimate a linear operator's trace using the Girard-Hutchinson method. 18 | 19 | For details, see 20 | 21 | - Girard, D. A. (1989). A fast 'monte-carlo cross-validation' procedure for 22 | large least squares problems with noisy data. Numerische Mathematik. 23 | - Hutchinson, M. (1989). A stochastic estimator of the trace of the influence 24 | matrix for laplacian smoothing splines. Communication in Statistics---Simulation 25 | and Computation. 26 | 27 | Let :math:`\mathbf{A}` be a square linear operator. We can approximate its trace 28 | :math:`\mathrm{Tr}(\mathbf{A})` by drawing :math:`N` random vectors 29 | :math:`\mathbf{v}_n \sim \mathbf{v}` from a distribution that satisfies 30 | :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and compute 31 | 32 | .. math:: 33 | a := \frac{1}{N} \sum_{n=1}^N \mathbf{v}_n^\top \mathbf{A} \mathbf{v}_n 34 | \approx \mathrm{Tr}(\mathbf{A})\,. 35 | 36 | This estimator is unbiased, 37 | 38 | .. math:: 39 | \mathbb{E}[a] 40 | = \mathrm{Tr}(\mathbb{E}[\mathbf{v}^\top\mathbf{A} \mathbf{v}]) 41 | = \mathrm{Tr}(\mathbf{A} \mathbb{E}[\mathbf{v} \mathbf{v}^\top]) 42 | = \mathrm{Tr}(\mathbf{A} \mathbf{I}) 43 | = \mathrm{Tr}(\mathbf{A})\,. 44 | 45 | Args: 46 | A: A square linear operator whose trace is estimated. 47 | num_matvecs: Total number of matrix-vector products to use. Must be smaller 48 | than the dimension of the linear operator (because otherwise one can 49 | evaluate the true trace directly at the same cost). 50 | distribution: Distribution of the random vectors used for the trace estimation. 51 | Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``. 
52 | 53 | Returns: 54 | The estimated trace of the linear operator. 55 | 56 | Example: 57 | >>> from torch import manual_seed, rand 58 | >>> _ = manual_seed(0) # make deterministic 59 | >>> A = rand(50, 50) 60 | >>> tr_A = A.trace().item() # exact trace as reference 61 | >>> # one- and multi-sample approximations 62 | >>> tr_A_low_precision = hutchinson_trace(A, num_matvecs=1).item() 63 | >>> tr_A_high_precision = hutchinson_trace(A, num_matvecs=40).item() 64 | >>> # compute the relative errors 65 | >>> rel_error_low_precision = abs(tr_A - tr_A_low_precision) / abs(tr_A) 66 | >>> rel_error_high_precision = abs(tr_A - tr_A_high_precision) / abs(tr_A) 67 | >>> assert rel_error_low_precision > rel_error_high_precision 68 | >>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4) 69 | (23.7836, -10.0279, 20.8427) 70 | """ 71 | dim = assert_is_square(A) 72 | assert_matvecs_subseed_dim(A, num_matvecs) 73 | G = column_stack( 74 | [ 75 | random_vector(dim, distribution, A.device, A.dtype) 76 | for _ in range(num_matvecs) 77 | ] 78 | ) 79 | 80 | return einsum("ij,ij", G, A @ G) / num_matvecs 81 | -------------------------------------------------------------------------------- /curvlinops/submatrix.py: -------------------------------------------------------------------------------- 1 | """Implements slices of linear operators.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import List 6 | 7 | from torch import Tensor, device, dtype, zeros 8 | 9 | from curvlinops._torch_base import PyTorchLinearOperator 10 | 11 | 12 | class SubmatrixLinearOperator(PyTorchLinearOperator): 13 | """Class for sub-matrices of linear operators.""" 14 | 15 | def __init__( 16 | self, A: PyTorchLinearOperator, row_idxs: List[int], col_idxs: List[int] 17 | ): 18 | """Store the linear operator and indices of its sub-matrix. 19 | 20 | Represents the sub-matrix ``A[row_idxs, :][:, col_idxs]``. 21 | 22 | Args: 23 | A: A linear operator. 24 | row_idxs: The sub-matrix's row indices. 25 | col_idxs: The sub-matrix's column indices. 26 | """ 27 | self._A = A 28 | self.set_submatrix(row_idxs, col_idxs) 29 | 30 | @property 31 | def dtype(self) -> dtype: 32 | """Determine the linear operator's data type. 33 | 34 | Returns: 35 | The linear operator's dtype. 36 | """ 37 | return self._A.dtype 38 | 39 | @property 40 | def device(self) -> device: 41 | """Determine the device the linear operator is defined on. 42 | 43 | Returns: 44 | The linear operator's device. 45 | """ 46 | return self._A.device 47 | 48 | def set_submatrix(self, row_idxs: List[int], col_idxs: List[int]): 49 | """Define the sub-matrix. 50 | 51 | Internally sets the linear operator's shape. 52 | 53 | Args: 54 | row_idxs: The sub-matrix's row indices. 55 | col_idxs: The sub-matrix's column indices. 56 | 57 | Raises: 58 | ValueError: If the index lists contain duplicate values, non-integers, 59 | or out-of-bounds indices. 
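Example:
    An illustrative toy usage; ``A`` is a plain ``4 x 5`` tensor here, which
    suffices because only ``A``'s shape is needed to define the slice:

    >>> from torch import rand
    >>> A = rand(4, 5)
    >>> # represent the sub-matrix A[[0, 2], :][:, [1, 3]]
    >>> A_sub = SubmatrixLinearOperator(A, [0, 2], [1, 3])
    >>> # later, re-slice the same operator to A[[1, 3], :][:, [0, 2, 4]]
    >>> A_sub.set_submatrix([1, 3], [0, 2, 4])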
60 | """ 61 | shape = [] 62 | 63 | for ax_idx, idxs in enumerate([row_idxs, col_idxs]): 64 | if any(not isinstance(i, int) for i in idxs): 65 | raise ValueError("Index lists must contain integers.") 66 | if len(idxs) != len(set(idxs)): 67 | raise ValueError("Index lists cannot contain duplicates.") 68 | if any(i < 0 or i >= self._A.shape[ax_idx] for i in idxs): 69 | raise ValueError("Index lists contain out-of-bounds indices.") 70 | shape.append(len(idxs)) 71 | 72 | in_shape, out_shape = [(shape[1],)], [(shape[0],)] 73 | super().__init__(in_shape, out_shape) 74 | self._row_idxs = row_idxs 75 | self._col_idxs = col_idxs 76 | 77 | def _matmat(self, X: List[Tensor]) -> List[Tensor]: 78 | """Matrix-matrix multiplication. 79 | 80 | Args: 81 | X: A list that contains a single tensor, which is the input tensor. 82 | 83 | Returns: 84 | A list that contains a single tensor, which is the output tensor. 85 | """ 86 | (M,) = X 87 | V = zeros(self._A.shape[1], M.shape[-1], dtype=self.dtype, device=self.device) 88 | V[self._col_idxs] = M 89 | AV = self._A @ V 90 | return [AV[self._row_idxs]] 91 | 92 | def _adjoint(self) -> SubmatrixLinearOperator: 93 | """Return the adjoint of the sub-matrix. 94 | 95 | For that, we need to take the adjoint operator, and swap row and column indices. 96 | 97 | Returns: 98 | The linear operator for the adjoint sub-matrix. 99 | """ 100 | return type(self)(self._A.adjoint(), self._col_idxs, self._row_idxs) 101 | -------------------------------------------------------------------------------- /docs/rtd/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "Curvlinops" 21 | copyright = "2022, F. Dangel, L. Tatzel, R. Eschenhagen" 22 | author = "F. Dangel, L. Tatzel, R. Eschenhagen" 23 | 24 | # -- General configuration --------------------------------------------------- 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be 27 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 28 | # ones. 
29 | extensions = [ 30 | "sphinx.ext.napoleon", 31 | "sphinx.ext.autodoc", 32 | "sphinx.ext.coverage", 33 | "sphinx.ext.autosectionlabel", 34 | "sphinx.ext.intersphinx", 35 | "sphinx_gallery.gen_gallery", 36 | "sphinx.ext.viewcode", # show source code links 37 | ] 38 | 39 | # -- Autodoc configuration -------------------------------------------------- 40 | 41 | # Set maximum line length for signatures before wrapping to multiple lines 42 | maximum_signature_line_length = 200 43 | 44 | # -- Intersphinx config ----------------------------------------------------- 45 | 46 | intersphinx_mapping = { 47 | "torch": ("https://docs.pytorch.org/docs/stable/", None), 48 | "backpack": ("https://docs.backpack.pt/en/master", None), 49 | "scipy": ("https://docs.scipy.org/doc/scipy", None), 50 | "numpy": ("https://numpy.org/doc/stable/", None), 51 | "matplotlib": ("https://matplotlib.org/stable/", None), 52 | "curvlinops": ("https://curvlinops.readthedocs.io/en/latest/", None), 53 | } 54 | 55 | # -- Sphinx Gallery config --------------------------------------------------- 56 | 57 | sphinx_gallery_conf = { 58 | "examples_dirs": [ 59 | "../examples/basic_usage", 60 | # "../examples/use_cases", 61 | ], # path to your example scripts 62 | "gallery_dirs": [ 63 | "basic_usage", 64 | # "use_cases", 65 | ], # path to where to save gallery generated output 66 | "default_thumb_file": "assets/logo.png", 67 | "filename_pattern": "example", 68 | "matplotlib_animations": True, 69 | "ignore_pattern": "memory_benchmark.py|benchmark_utils.py|nanogpt_model.py", 70 | } 71 | # Add any paths that contain templates here, relative to this directory. 72 | templates_path = ["_templates"] 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # This pattern also affects html_static_path and html_extra_path. 77 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 78 | 79 | 80 | # -- Options for HTML output ------------------------------------------------- 81 | 82 | # The theme to use for HTML and HTML Help pages. See the documentation for 83 | # a list of builtin themes. 84 | # 85 | html_theme = "sphinx_rtd_theme" 86 | html_logo = "assets/logo.svg" 87 | 88 | # Add any paths that contain custom static files (such as style sheets) here, 89 | # relative to this directory. They are copied after the builtin static files, 90 | # so a file named "default.css" will overwrite the builtin "default.css". 
91 | # html_static_path = ["_static"] 92 | -------------------------------------------------------------------------------- /test/test_submatrix.py: -------------------------------------------------------------------------------- 1 | """Contains tests for ``curvlinops/submatrix`` on curvature linear operators.""" 2 | 3 | from typing import List 4 | 5 | from pytest import mark, raises 6 | 7 | from curvlinops import EFLinearOperator, GGNLinearOperator, HessianLinearOperator 8 | from curvlinops.examples.functorch import ( 9 | functorch_empirical_fisher, 10 | functorch_ggn, 11 | functorch_hessian, 12 | ) 13 | from curvlinops.submatrix import SubmatrixLinearOperator 14 | from test.utils import compare_consecutive_matmats, compare_matmat 15 | 16 | CURVATURE_IN_FUNCTORCH = { 17 | HessianLinearOperator: functorch_hessian, 18 | GGNLinearOperator: functorch_ggn, 19 | EFLinearOperator: functorch_empirical_fisher, 20 | } 21 | CURVATURE_CASES = CURVATURE_IN_FUNCTORCH.keys() 22 | 23 | 24 | def even_idxs(dim: int) -> List[int]: 25 | return list(range(0, dim, 2)) 26 | 27 | 28 | def odd_idxs(dim: int) -> List[int]: 29 | return list(range(1, dim, 2)) 30 | 31 | 32 | def every_third_idxs(dim: int): 33 | return list(range(0, dim, 3)) 34 | 35 | 36 | SUBMATRIX_CASES = [ 37 | # same indices for rows and columns (square matrix) 38 | { 39 | "row_idx_fn": even_idxs, 40 | "col_idx_fn": even_idxs, 41 | }, 42 | # different indices for rows and columns (square matrix if dim is even) 43 | { 44 | "row_idx_fn": even_idxs, 45 | "col_idx_fn": odd_idxs, 46 | }, 47 | # different indices for rows and columns (rectangular matrix if dim>5) 48 | { 49 | "row_idx_fn": odd_idxs, 50 | "col_idx_fn": every_third_idxs, 51 | }, 52 | ] 53 | SUBMATRIX_IDS = [ 54 | f"({case['row_idx_fn'].__name__},{case['col_idx_fn'].__name__})" 55 | for case in SUBMATRIX_CASES 56 | ] 57 | 58 | 59 | def setup_submatrix_linear_operator(case, operator_case, submatrix_case): 60 | model_func, loss_func, params, data, batch_size_fn = case 61 | dim = sum(p.numel() for p in params) 62 | row_idxs = submatrix_case["row_idx_fn"](dim) 63 | col_idxs = submatrix_case["col_idx_fn"](dim) 64 | 65 | A = operator_case(model_func, loss_func, params, data, batch_size_fn=batch_size_fn) 66 | A_sub = SubmatrixLinearOperator(A, row_idxs, col_idxs) 67 | 68 | A_functorch = CURVATURE_IN_FUNCTORCH[operator_case]( 69 | model_func, loss_func, params, data, "x" 70 | ) 71 | A_sub_functorch = A_functorch[row_idxs, :][:, col_idxs] 72 | 73 | return A_sub, A_sub_functorch, row_idxs, col_idxs 74 | 75 | 76 | @mark.parametrize("operator_case", CURVATURE_CASES) 77 | @mark.parametrize("submatrix_case", SUBMATRIX_CASES) 78 | def test_SubmatrixLinearOperator_on_curvatures( 79 | case, 80 | operator_case, 81 | submatrix_case, 82 | adjoint: bool, 83 | is_vec: bool, 84 | ): 85 | A_sub, A_sub_functorch, row_idxs, col_idxs = setup_submatrix_linear_operator( 86 | case, operator_case, submatrix_case 87 | ) 88 | assert A_sub.shape == (len(row_idxs), len(col_idxs)) 89 | compare_consecutive_matmats(A_sub, adjoint, is_vec) 90 | compare_matmat(A_sub, A_sub_functorch, adjoint, is_vec, atol=1e-6, rtol=1e-4) 91 | 92 | # try specifying the sub-matrix using invalid indices 93 | invalid_idxs = [ 94 | [[0.0], [0]], # wrong type in row_idxs 95 | [[0], [0.0]], # wrong type in col_idxs 96 | [[2, 1, 2], [3]], # duplicate entries in row_idxs 97 | [[3], [2, 1, 2]], # duplicate entries in col_idxs 98 | [[1_000_000_000_000, 5], [2]], # out-of-bounds in row_idxs 99 | [[6, 5], [-1]], # out-of-bounds in col_idxs 100 | ] 101 | for 
row_idxs, col_idxs in invalid_idxs: 102 | with raises(ValueError): 103 | A_sub.set_submatrix(row_idxs, col_idxs) 104 | -------------------------------------------------------------------------------- /test/diagonal/test_epperly2024xtrace.py: -------------------------------------------------------------------------------- 1 | """Test ``curvlinops.diagonal.epperly2024xtrace``.""" 2 | 3 | from functools import partial 4 | from typing import Union 5 | 6 | from pytest import mark 7 | from torch import Tensor, column_stack, manual_seed, rand 8 | from torch.linalg import qr 9 | 10 | from curvlinops import xdiag 11 | from curvlinops._torch_base import PyTorchLinearOperator 12 | from curvlinops.sampling import random_vector 13 | from test.diagonal import NUM_MATVEC_IDS, NUM_MATVECS 14 | from test.utils import check_estimator_convergence 15 | 16 | 17 | def xdiag_naive(A: Union[PyTorchLinearOperator, Tensor], num_matvecs: int) -> Tensor: 18 | """Naive reference implementation of XDiag. 19 | 20 | See Section 2.4 in https://arxiv.org/pdf/2301.07825. 21 | 22 | Args: 23 | A: A square linear operator. 24 | num_matvecs: Total number of matrix-vector products to use. Must be even and 25 | less than the dimension of the linear operator. 26 | 27 | Returns: 28 | The estimated diagonal of the linear operator. 29 | 30 | Raises: 31 | ValueError: If the linear operator is not square or if the number of matrix- 32 | vector products is not even or is greater than the dimension of the linear 33 | operator. 34 | """ 35 | if len(A.shape) != 2 or A.shape[0] != A.shape[1]: 36 | raise ValueError(f"A must be square. Got shape {A.shape}.") 37 | dim = A.shape[1] 38 | if num_matvecs % 2 != 0 or num_matvecs >= dim: 39 | raise ValueError( 40 | "num_matvecs must be even and less than the dimension of A.", 41 | f" Got {num_matvecs}.", 42 | ) 43 | num_vecs = num_matvecs // 2 44 | 45 | W = column_stack( 46 | [random_vector(dim, "rademacher", A.device, A.dtype) for _ in range(num_vecs)] 47 | ) 48 | A_W = A @ W 49 | 50 | diagonals = [] 51 | 52 | for i in range(num_vecs): 53 | # compute the exact diagonal of the projection onto the basis spanned by 54 | # the sketch matrix without test vector i 55 | not_i = [j for j in range(num_vecs) if j != i] 56 | Q_i, _ = qr(A_W[:, not_i]) 57 | QT_i_A = Q_i.T @ A 58 | diag_Q_i_QT_i_A = (Q_i @ QT_i_A).diag() 59 | 60 | # apply vanilla Hutchinson in the complement, using test vector i 61 | w_i = W[:, i] 62 | A_w_i = A_W[:, i] 63 | diag_w_i = w_i * (A_w_i - Q_i @ (Q_i.T @ A_w_i)) / w_i**2 64 | diagonals.append(diag_Q_i_QT_i_A + diag_w_i) 65 | 66 | return sum(diagonals) / len(diagonals) 67 | 68 | 69 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) 70 | def test_xdiag(num_matvecs: int): 71 | """Test whether the XDiag estimator converges to the true diagonal. 72 | 73 | Args: 74 | num_matvecs: Number of matrix-vector multiplications used by one estimator. 75 | """ 76 | manual_seed(0) 77 | A = rand(30, 30) 78 | 79 | estimator = partial(xdiag, A=A, num_matvecs=num_matvecs) 80 | check_estimator_convergence(estimator, num_matvecs, A.diag(), target_rel_error=3e-2) 81 | 82 | 83 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) 84 | def test_xdiag_matches_naive(num_matvecs: int, num_seeds: int = 5): 85 | """Test whether the efficient implementation of XDiag matches the naive. 86 | 87 | Args: 88 | num_matvecs: Number of matrix-vector multiplications used by one estimator. 89 | num_seeds: Number of different seeds to test the estimators with. 90 | Default: ``5``. 
91 | """ 92 | manual_seed(0) 93 | A = rand(30, 30).double() 94 | 95 | # check for different seeds 96 | for i in range(num_seeds): 97 | manual_seed(i) 98 | efficient = xdiag(A, num_matvecs) 99 | manual_seed(i) 100 | naive = xdiag_naive(A, num_matvecs) 101 | assert efficient.allclose(naive) 102 | -------------------------------------------------------------------------------- /test/experimental/test_activation_hessian.py: -------------------------------------------------------------------------------- 1 | """Contains tests for ``curvlinops.activation_hessian``.""" 2 | 3 | from pytest import mark, raises 4 | from torch import allclose, block_diag, device, einsum, eye, manual_seed, rand 5 | from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential, Sigmoid 6 | 7 | from curvlinops.experimental.activation_hessian import ( 8 | ActivationHessianLinearOperator, 9 | store_activation, 10 | ) 11 | from test.cases import DEVICES, DEVICES_IDS 12 | from test.utils import classification_targets, eye_like 13 | 14 | 15 | @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) 16 | def test_store_activation(dev: device): 17 | """Test context that stores the input/output of a layer. 18 | 19 | Args: 20 | dev: Device on which to run the tests. 21 | """ 22 | manual_seed(0) 23 | layers = [ 24 | Linear(10, 8), 25 | ReLU(), 26 | Linear(8, 6), 27 | Sigmoid(), 28 | Linear(6, 4), 29 | ] 30 | layers = [layer.to(dev) for layer in layers] 31 | model = Sequential(*layers).to(dev) 32 | X = rand(5, 10, device=dev) 33 | 34 | # compare with manual forward pass 35 | activation = ("3", "input", 0) # input to the sigmoid layer 36 | activation_storage = [] 37 | with store_activation(model, *activation, activation_storage): 38 | model(X) 39 | act = activation_storage.pop() 40 | assert act.shape == (5, 6) 41 | assert not activation_storage 42 | truth = layers[2](layers[1](layers[0](X))) 43 | assert allclose(act, truth) 44 | # make sure the hooks were removed when computing ``truth`` 45 | assert not activation_storage 46 | 47 | # check failure scenarios 48 | # layer name does not exist 49 | invalid_activation = ("foo", "input", 0) 50 | with raises(ValueError): 51 | store_activation(model, *invalid_activation, []) 52 | 53 | # io type does not exist 54 | invalid_activation = ("3", "foo", 0) 55 | with raises(ValueError): 56 | store_activation(model, *invalid_activation, []) 57 | 58 | # destination not empty 59 | invalid_activation = ("3", "foo", 0) 60 | with raises(ValueError): 61 | store_activation(model, *invalid_activation, [42]) 62 | 63 | 64 | @mark.parametrize("dev", DEVICES, ids=DEVICES_IDS) 65 | def test_ActivationHessianLinearOperator(dev: device): 66 | """Check the Hessian w.r.t. an activation. 67 | 68 | Verifies the Hessian of a function ``l(id(X), y)`` w.r.t. ``X`` where ``id`` is the 69 | identity. 70 | 71 | Args: 72 | dev: Device on which to run the tests. 
73 | """ 74 | manual_seed(0) 75 | batch_size, num_classes = 2, 10 76 | 77 | # model does nothing to the input but needs parameters so the linear 78 | # operator can infer the device 79 | model = Linear(num_classes, num_classes, bias=False).to(dev) 80 | model.weight.data = eye_like(model.weight.data) 81 | 82 | loss_func = CrossEntropyLoss(reduction="sum").to(dev) 83 | X = rand(batch_size, num_classes, requires_grad=True, device=dev) 84 | y = classification_targets((batch_size,), num_classes).to(dev) 85 | data = [(X, y)] 86 | activation = ("", "input", 0) 87 | 88 | # compute the Hessian matrix representation 89 | H = ActivationHessianLinearOperator(model, loss_func, activation, data) 90 | H_mat = H @ eye(H.shape[1], dtype=X.dtype, device=dev) 91 | 92 | # we know that the Hessian of softmax CE loss is ``diag(p(x)) - p(x) p(x)ᵀ`` 93 | # where ``p(x)`` is the softmax probability on a single datum ``x``. On a batch, 94 | # the Hessian is the block diagonal stack of these per-sample Hessians 95 | p = X.softmax(dim=1).detach() 96 | blocks = [] 97 | for n in range(batch_size): 98 | p_n = p[n] 99 | H_n = p_n.diag() - einsum("i,j->ij", p_n, p_n) 100 | blocks.append(H_n) 101 | truth = block_diag(*blocks) 102 | 103 | assert allclose(H_mat, truth) 104 | -------------------------------------------------------------------------------- /curvlinops/trace/epperly2024xtrace.py: -------------------------------------------------------------------------------- 1 | """Implements the XTrace algorithm from Epperly 2024.""" 2 | 3 | from typing import Union 4 | 5 | from torch import Tensor, column_stack, dot, einsum, mean 6 | from torch.linalg import inv, qr 7 | 8 | from curvlinops._torch_base import PyTorchLinearOperator 9 | from curvlinops.sampling import random_vector 10 | from curvlinops.utils import ( 11 | assert_divisible_by, 12 | assert_is_square, 13 | assert_matvecs_subseed_dim, 14 | ) 15 | 16 | 17 | def xtrace( 18 | A: Union[PyTorchLinearOperator, Tensor], 19 | num_matvecs: int, 20 | distribution: str = "rademacher", 21 | ) -> Tensor: 22 | """Estimate a linear operator's trace using the XTrace algorithm. 23 | 24 | The method is presented in `this paper `_: 25 | 26 | - Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). Xtrace: making the most 27 | of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis 28 | and Applications (SIMAX). 29 | 30 | It combines the variance reduction from Hutch++ with the exchangeability principle. 31 | 32 | Args: 33 | A: A square linear operator. 34 | num_matvecs: Total number of matrix-vector products to use. Must be even and 35 | less than the dimension of the linear operator (because otherwise one can 36 | evaluate the true trace directly at the same cost). 37 | distribution: Distribution of the random vectors used for the trace estimation. 38 | Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``. 39 | 40 | Returns: 41 | The estimated trace of the linear operator. 
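Example:
    An illustrative sketch (the estimator is stochastic, so no exact output
    values are shown here):

    >>> from torch import manual_seed, rand
    >>> _ = manual_seed(0)  # make deterministic
    >>> A = rand(50, 50)
    >>> # estimate the trace from 20 matrix-vector products with A
    >>> tr_A_estimate = xtrace(A, num_matvecs=20)
    >>> tr_A_exact = A.trace()  # exact value for comparison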
42 | """ 43 | dim = assert_is_square(A) 44 | assert_matvecs_subseed_dim(A, num_matvecs) 45 | assert_divisible_by(num_matvecs, 2, "num_matvecs") 46 | 47 | # draw random vectors and compute their matrix-vector products 48 | num_vecs = num_matvecs // 2 49 | W = column_stack( 50 | [random_vector(dim, distribution, A.device, A.dtype) for _ in range(num_vecs)] 51 | ) 52 | A_W = A @ W 53 | 54 | # compute the orthogonal basis for all test vectors, and its associated trace 55 | Q, R = qr(A_W) 56 | A_Q = A @ Q 57 | tr_QT_A_Q = einsum("ij,ij->", Q, A_Q) 58 | 59 | # compute the traces in the bases that would result had we left out the i-th 60 | # test vector in the QR decomposition 61 | RT_inv = inv(R.T) 62 | D = 1 / (RT_inv**2).sum(0) ** 0.5 63 | S = einsum("ij,j->ij", RT_inv, D) 64 | tr_QT_i_A_Q_i = einsum("ij,ki,kl,lj->j", S, Q, A_Q, S) 65 | 66 | # Traces in the bases {Q_i}. This follows by writing Tr(QT_i A Q_i) = Tr(A Q_i QT_i) 67 | # then using the relation that Q_i QT_i = Q (I - s_i sT_i) QT. Further 68 | # simplification then leads to 69 | traces = tr_QT_A_Q - tr_QT_i_A_Q_i 70 | 71 | def deflate(v: Tensor, s: Tensor) -> Tensor: 72 | """Apply (I - s sT) to a vector. 73 | 74 | Args: 75 | v: Vector to deflate. 76 | s: Deflation vector. 77 | 78 | Returns: 79 | Deflated vector. 80 | """ 81 | return v - dot(s, v) * s 82 | 83 | # estimate the trace on the complement of Q_i with vanilla Hutchinson using the 84 | # i-th test vector 85 | for i in range(num_vecs): 86 | w_i = W[:, i] 87 | s_i = S[:, i] 88 | A_w_i = A_W[:, i] 89 | 90 | # Compute (I - Q_i QT_i) A (I - Q_i QT_i) w_i 91 | # = (I - Q_i QT_i) (Aw - AQ_i QT_i w_i) 92 | # ( using that Q_i QT_i = Q (I - s_i sT_i) QT ) 93 | # = (I - Q_i QT_i) (Aw - AQ (I - s_i sT_i) QT w) 94 | # = (I - Q (I - s_i sT_i) QT) (Aw - AQ (I - s_i sT_i) QT w) 95 | # |--------- A_p_w_i ---------| 96 | # |-------------------- PT_A_P_w_i----------------------| 97 | A_P_w_i = A_w_i - A_Q @ deflate(Q.T @ w_i, s_i) 98 | PT_A_P_w_i = A_P_w_i - Q @ deflate(Q.T @ A_P_w_i, s_i) 99 | 100 | tr_w_i = dot(w_i, PT_A_P_w_i) 101 | traces[i] += tr_w_i 102 | 103 | return mean(traces) 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Logo Linear Operators for Curvature Matrices in PyTorch 2 | 3 | [![Python 4 | 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/) 5 | ![tests](https://github.com/f-dangel/curvature-linear-operators/actions/workflows/test.yaml/badge.svg) 6 | [![Coveralls](https://coveralls.io/repos/github/f-dangel/curvlinops/badge.svg?branch=main)](https://coveralls.io/github/f-dangel/curvlinops) 7 | 8 | This library provides **lin**ear **op**erator**s**---a unified interface for matrix-free computation---for deep learning **curv**ature matrices in PyTorch. 9 | `curvlinops` is inspired by SciPy's [`sparse.linalg.LinearOperator`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.LinearOperator.html) interface and implements a PyTorch version. 10 | 11 | You can read our [position paper](https://arxiv.org/abs/2501.19183) to know more about why combining linear operators with curvature matrices might be a good idea. 
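The snippet below is a minimal sketch of what this interface looks like in practice; it assumes that a `model`, a `loss_func`, a list of `params`, and an iterable `data` of batches are already defined (see the usage examples linked below for complete scripts).

```python
from torch import rand

from curvlinops import HessianLinearOperator

# H acts like the (num_params x num_params) Hessian matrix,
# but the matrix itself is never formed explicitly
H = HessianLinearOperator(model, loss_func, params, data)

# multiply the Hessian onto a block of three random vectors
num_params = H.shape[1]
V = rand(num_params, 3, device=H.device, dtype=H.dtype)
HV = H @ V  # has shape (num_params, 3)
```

Because every operator shares this interface, the same code works for the other curvature matrices listed below.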
12 | 13 | Main features: 14 | 15 | - **Broad support of curvature matrices.** `curvlinops` supports many common curvature matrices and approximations thereof, such as the Hessian, Fisher, generalized Gauss-Newton, and K-FAC ([overview](https://curvlinops.readthedocs.io/en/latest/linops.html#linear-operators), [visual tour](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_visual_tour.html#visualization)). 16 | 17 | - **Unified interface.** All linear operators share the same interface, making it easy to switch between curvature matrices. 18 | 19 | - **Purely PyTorch.** All computations can run on a GPU. 20 | 21 | - **SciPy export.** You can export a `curvlinops` linear operator to a SciPy `LinearOperator` with `.to_scipy()`. 22 | This allows plugging it into `scipy`, while carrying out the heavy lifting (matrix-vector multiplies) in PyTorch on GPU. 23 | My favorite example is 24 | [`scipy.sparse.linalg.eigsh`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.eigsh.html) that lets you compute a subset of eigen-pairs ([example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_eigenvalues.html)). 25 | 26 | - **Randomized estimation algorithms.** `curvlinops` offers functionality to estimate properties of the matrix represented by a linear operator, like its spectral density ([example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_verification_spectral_density.html)), 27 | inverse ([example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_inverses.html)), 28 | and trace & diagonal ([example](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_trace_diagonal_estimation.html)). 29 | 30 | ## Installation 31 | 32 | ```bash 33 | pip install curvlinops-for-pytorch 34 | ``` 35 | 36 | ## Useful Links 37 | 38 | - [Basic 39 | usage](https://curvlinops.readthedocs.io/en/latest/basic_usage/example_matrix_vector_products.html) 40 | 41 | - [Advanced 42 | examples](https://curvlinops.readthedocs.io/en/latest/basic_usage/index.html) 43 | 44 | - **Documentation:** https://curvlinops.readthedocs.io/en/latest/ 45 | 46 | - **Bug reports & feature requests:** 47 | https://github.com/f-dangel/curvlinops/issues 48 | 49 | ## Citation 50 | 51 | If you find `curvlinops` useful for your work, consider citing our [position paper](https://arxiv.org/abs/2501.19183): 52 | 53 | ```bibtex 54 | 55 | @article{dangel2025position, 56 | title = {Position: Curvature Matrices Should Be Democratized via Linear 57 | Operators}, 58 | author = {Dangel, Felix and Eschenhagen, Runa and Ormaniec, Weronika and 59 | Fernandez, Andres and Tatzel, Lukas and Kristiadi, Agustinus}, 60 | journal = {arXiv 2501.19183}, 61 | year = 2025, 62 | } 63 | 64 | ``` 65 | 66 | ## Future ideas 67 | 68 | - Refactor the back-end for curvature-matrix multiplication into pure functions 69 | to improve reusability and ease the use of `torch.compile`. 70 | 71 | - Multi-GPU support. 72 | 73 | - Include more curvature matrices 74 | - E.g.
the [GGN's hierarchical decomposition](https://arxiv.org/abs/2008.11865) 75 | 76 | ###### Logo image credits 77 | - PyTorch logo: https://github.com/soumith, [CC BY-SA 78 | 4.0](https://creativecommons.org/licenses/by-sa/4.0), via Wikimedia Commons 79 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for setting up nanoGPT.""" 2 | 3 | import inspect 4 | from os import path 5 | from typing import List, Tuple 6 | 7 | import requests 8 | from torch import Tensor, rand, randint, stack, zeros_like 9 | from torch.nn import CrossEntropyLoss, Module, Parameter 10 | from torchvision.models import ResNet50_Weights, resnet50 11 | 12 | # In the execution with sphinx-gallery, __file__ is not defined and we need 13 | # to set it manually using the trick from https://stackoverflow.com/a/53293924 14 | if "__file__" not in globals(): 15 | __file__ = inspect.getfile(lambda: None) 16 | 17 | HEREDIR = path.dirname(path.abspath(__file__)) 18 | 19 | 20 | def maybe_download_nanogpt(): 21 | """Download the nanoGPT model definition.""" 22 | commit = "f08abb45bd2285627d17da16daea14dda7e7253e" 23 | repo = "https://raw.githubusercontent.com/karpathy/nanoGPT/" 24 | 25 | # download the model definition as 'nanogpt_model.py' 26 | model_url = f"{repo}{commit}/model.py" 27 | model_path = path.join(HEREDIR, "nanogpt_model.py") 28 | if not path.exists(model_path): 29 | url = requests.get(model_url) 30 | with open(model_path, "w") as f: 31 | f.write(url.content.decode("utf-8")) 32 | 33 | 34 | class GPTWrapper(Module): 35 | """Wraps Karpathy's nanoGPT model so that it produces the flattened logits.""" 36 | 37 | def __init__(self, gpt: Module): 38 | """Store the wrapped nanoGPT model. 39 | 40 | Args: 41 | gpt: The nanoGPT model. 42 | """ 43 | super().__init__() 44 | self.gpt = gpt 45 | 46 | def forward(self, x: Tensor) -> Tensor: 47 | """Forward pass of the nanoGPT model. 48 | 49 | Args: 50 | x: The input tensor. Has shape ``(batch_size, sequence_length)``. 51 | 52 | Returns: 53 | The flattened logits. 54 | Has shape ``(batch_size * sequence_length, vocab_size)``. 55 | """ 56 | y_dummy = zeros_like(x) 57 | logits, _ = self.gpt(x, y_dummy) 58 | return logits.view(-1, logits.size(-1)) 59 | 60 | 61 | def setup_synthetic_shakespeare_nanogpt( 62 | batch_size: int = 4, 63 | ) -> Tuple[GPTWrapper, CrossEntropyLoss, List[Tuple[Tensor, Tensor]]]: 64 | """Set up the nanoGPT model and synthetic Shakespeare dataset for the benchmark. 65 | 66 | Args: 67 | batch_size: The batch size to use. Default is ``4``. 68 | 69 | Returns: 70 | A tuple containing the nanoGPT model, the loss function, and the data.
71 | """ 72 | # download nanogpt_model and import GPT and GPTConfig from it 73 | maybe_download_nanogpt() 74 | from nanogpt_model import GPT, GPTConfig 75 | 76 | config = GPTConfig() 77 | block_size = config.block_size 78 | 79 | base = GPT(config) 80 | # Remove weight tying as this will break the parameter-to-layer detection 81 | base.transformer.wte.weight = Parameter( 82 | data=base.transformer.wte.weight.data.detach().clone() 83 | ) 84 | 85 | model = GPTWrapper(base).eval() 86 | loss_function = CrossEntropyLoss(ignore_index=-1) 87 | 88 | # generate a synthetic Shakespeare and load one batch 89 | vocab_size = config.vocab_size 90 | train_data = randint(0, vocab_size, (5 * block_size,)).long() 91 | ix = randint(train_data.numel() - block_size, (batch_size,)) 92 | X = stack([train_data[i : i + block_size] for i in ix]) 93 | y = stack([train_data[i + 1 : i + 1 + block_size] for i in ix]) 94 | # flatten the target because the GPT wrapper flattens the logits 95 | data = [(X, y.flatten())] 96 | 97 | return model, loss_function, data 98 | 99 | 100 | def setup_synthetic_imagenet_resnet50( 101 | batch_size: int = 64, 102 | ) -> Tuple[Module, CrossEntropyLoss, List[Tuple[Tensor, Tensor]]]: 103 | """Set up ResNet50 on synthetic ImageNet for the benchmark. 104 | 105 | Args: 106 | batch_size: The batch size to use. Default is ``64``. 107 | 108 | Returns: 109 | A tuple containing the ResNet50 model, the loss function 110 | and the data. 111 | """ 112 | X = rand(batch_size, 3, 224, 224) 113 | y = randint(0, 1000, (batch_size,)) 114 | data = [(X, y)] 115 | model = resnet50(weights=ResNet50_Weights.DEFAULT) 116 | loss_function = CrossEntropyLoss() 117 | 118 | return model, loss_function, data 119 | -------------------------------------------------------------------------------- /test/trace/test_epperly2024xtrace.py: -------------------------------------------------------------------------------- 1 | """Test ``curvlinops.trace.epperli2024xtrace.""" 2 | 3 | from functools import partial 4 | from typing import Union 5 | 6 | from pytest import mark 7 | from torch import Tensor, column_stack, dot, isclose, manual_seed, rand, trace 8 | from torch.linalg import qr 9 | 10 | from curvlinops import xtrace 11 | from curvlinops._torch_base import PyTorchLinearOperator 12 | from curvlinops.sampling import random_vector 13 | from test.trace import DISTRIBUTION_IDS, DISTRIBUTIONS 14 | from test.utils import check_estimator_convergence 15 | 16 | NUM_MATVECS = [6, 8] 17 | NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS] 18 | 19 | 20 | def xtrace_naive( 21 | A: Union[PyTorchLinearOperator, Tensor], 22 | num_matvecs: int, 23 | distribution: str = "rademacher", 24 | ) -> Tensor: 25 | """Naive reference implementation of XTrace. 26 | 27 | See Algorithm 1.2 in https://arxiv.org/pdf/2301.07825. 28 | 29 | Args: 30 | A: A square linear operator. 31 | num_matvecs: Total number of matrix-vector products to use. Must be even and 32 | less than the dimension of the linear operator. 33 | distribution: Distribution of the random vectors used for the trace estimation. 34 | Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``. 35 | 36 | Returns: 37 | The estimated trace of the linear operator. 38 | 39 | Raises: 40 | ValueError: If the linear operator is not square or if the number of matrix- 41 | vector products is not even or is greater than the dimension of the linear 42 | operator. 
43 | """ 44 | if len(A.shape) != 2 or A.shape[0] != A.shape[1]: 45 | raise ValueError(f"A must be square. Got shape {A.shape}.") 46 | dim = A.shape[1] 47 | if num_matvecs % 2 != 0 or num_matvecs >= dim: 48 | raise ValueError( 49 | "num_matvecs must be even and less than the dimension of A.", 50 | f" Got {num_matvecs}.", 51 | ) 52 | num_vecs = num_matvecs // 2 53 | 54 | W = column_stack( 55 | [random_vector(dim, distribution, A.device, A.dtype) for _ in range(num_vecs)] 56 | ) 57 | A_W = A @ W 58 | 59 | traces = [] 60 | 61 | for i in range(num_vecs): 62 | # compute the exact trace in the basis spanned by the sketch matrix without 63 | # test vector i 64 | not_i = [j for j in range(num_vecs) if j != i] 65 | Q_i, _ = qr(A_W[:, not_i]) 66 | A_Q_i = A @ Q_i 67 | tr_QT_i_A_Q_i = trace(Q_i.T @ A_Q_i) 68 | 69 | # apply vanilla Hutchinson in the complement, using test vector i 70 | w_i = W[:, i] 71 | A_w_i = A_W[:, i] 72 | A_P_w_i = A_w_i - A_Q_i @ (Q_i.T @ w_i) 73 | PT_A_P_w_i = A_P_w_i - Q_i @ (Q_i.T @ A_P_w_i) 74 | tr_w_i = dot(w_i, PT_A_P_w_i) 75 | 76 | traces.append(tr_QT_i_A_Q_i + tr_w_i) 77 | 78 | return sum(traces) / len(traces) 79 | 80 | 81 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) 82 | @mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) 83 | def test_xtrace(distribution: str, num_matvecs: int): 84 | """Test whether the XTrace estimator converges to the true trace. 85 | 86 | Args: 87 | distribution: Distribution of the random vectors used for the trace estimation. 88 | num_matvecs: Number of matrix-vector multiplications used by one estimator. 89 | """ 90 | manual_seed(0) 91 | A = rand(15, 15) 92 | estimator = partial(xtrace, A=A, num_matvecs=num_matvecs, distribution=distribution) 93 | check_estimator_convergence( 94 | estimator, 95 | num_matvecs, 96 | A.trace(), 97 | # use half the target tolerance as vanilla Hutchinson 98 | target_rel_error=5e-4, 99 | ) 100 | 101 | 102 | @mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) 103 | @mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) 104 | def test_xtrace_matches_naive(num_matvecs: int, distribution: str, num_seeds: int = 5): 105 | """Test whether the efficient implementation of XTrace matches the naive. 106 | 107 | Args: 108 | num_matvecs: Number of matrix-vector multiplications used by one estimator. 109 | distribution: Distribution of the random vectors used for the trace estimation. 110 | num_seeds: Number of different seeds to test the estimators with. 111 | Default: ``5``. 112 | """ 113 | manual_seed(0) 114 | A = rand(50, 50) 115 | 116 | # check for different seeds 117 | for i in range(num_seeds): 118 | manual_seed(i) 119 | efficient = xtrace(A, num_matvecs, distribution=distribution) 120 | manual_seed(i) 121 | naive = xtrace_naive(A, num_matvecs, distribution=distribution) 122 | assert isclose(efficient, naive) 123 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # This file is used to configure the project. 
2 | # Read more about the various options under: 3 | # https://packaging.python.org/en/latest/guides/writing-pyproject-toml 4 | # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html 5 | 6 | [build-system] 7 | requires = ["setuptools >= 61.0", "setuptools_scm"] 8 | build-backend = "setuptools.build_meta" 9 | 10 | ############################################################################### 11 | # Main library # 12 | ############################################################################### 13 | 14 | [project] 15 | name = "curvlinops-for-pytorch" 16 | authors = [ 17 | { name = "Felix Dangel" }, 18 | { name = "Runa Eschenhagen" }, 19 | { name = "Lukas Tatzel" }, 20 | ] 21 | urls = { Repository = "https://github.com/f-dangel/curvlinops" } 22 | description = "scipy Linear operators for curvature matrices in PyTorch" 23 | readme = { file = "README.md", content-type = "text/markdown; charset=UTF-8; variant=GFM" } 24 | license = { text = "MIT" } 25 | # Add all kinds of additional classifiers as defined under 26 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers 27 | classifiers = [ 28 | "Development Status :: 4 - Beta", 29 | "License :: OSI Approved :: MIT License", 30 | "Operating System :: OS Independent", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10", 33 | "Programming Language :: Python :: 3.11", 34 | "Programming Language :: Python :: 3.12", 35 | ] 36 | dynamic = ["version"] 37 | # Dependencies of the project: 38 | dependencies = [ 39 | "backpack-for-pytorch>=1.6.0", 40 | "torch>=2.0", # introduces torch.func 41 | "scipy>=1.0.0", # introduces LSMR 42 | "numpy>=1.21.0", # compatible with modern Python versions 43 | "tqdm>=4.0.0", # compatible with modern Python versions 44 | "einops>=0.3.0", # introduces einsum 45 | "einconv>=0.1.0", # first release 46 | "linear_operator>=0.2.0", # matvec-based CG implementation in pure PyTorch 47 | ] 48 | # Require a specific Python version, e.g. Python 2.7 or >= 3.4 49 | requires-python = ">=3.9" 50 | 51 | ############################################################################### 52 | # Development dependencies # 53 | ############################################################################### 54 | 55 | [project.optional-dependencies] 56 | # Dependencies needed to run the tests. 57 | test = [ 58 | "matplotlib", 59 | "tueplots", 60 | "coveralls", 61 | "pytest", 62 | "pytest-cov", 63 | "pytest-optional-tests", 64 | ] 65 | 66 | # Dependencies needed for linting. 67 | lint = [ 68 | "ruff", 69 | "darglint", 70 | "pydocstyle", 71 | ] 72 | 73 | # Dependencies needed to build/view the documentation. 
74 | docs = [ 75 | "memory_profiler", 76 | "transformers", 77 | "tiktoken", 78 | "datasets", 79 | "matplotlib", 80 | "sphinx-gallery", 81 | "sphinx-rtd-theme", 82 | "tueplots" 83 | ] 84 | 85 | ############################################################################### 86 | # Development tool configurations # 87 | ############################################################################### 88 | [tool.setuptools_scm] 89 | 90 | [tool.pydocstyle] 91 | convention = "google" 92 | match = '.*\.py' 93 | match_dir = '^(?!(test|.git)).*' 94 | 95 | [tool.ruff] 96 | line-length = 88 97 | 98 | [tool.ruff.lint] 99 | # Enable all rules from flake8 (E, F), plus additional ones including isort (I) 100 | select = ["E", "F", "B", "C", "W", "B9", "PLE", "PLW", "PLR", "I"] 101 | ignore = [ 102 | # E501 max-line-length (replaced by B950 (max-line-length + 10%)) 103 | "E501", 104 | # C408 use {} instead of dict() (ignored because pytorch uses dict) 105 | "C408", 106 | # E203 whitespace before : 107 | "E203", 108 | # E231 missing whitespace after ',' 109 | "E231", 110 | # W291 trailing whitespace 111 | "W291", 112 | # E203 line break before binary operator (replaces W503) 113 | "E203", 114 | # Line break occurred after a binary operator (replaces W504) 115 | "E226", 116 | # B905 `zip()` without an explicit `strict=` parameter 117 | "B905", 118 | # Too many arguments in function definition (9 > 5) 119 | "PLR0913", 120 | # Magic value comparison 121 | "PLR2004", 122 | # Loop variable overwritten by assignment target 123 | "PLW2901", 124 | ] 125 | 126 | [tool.ruff.format] 127 | quote-style = "double" 128 | indent-style = "space" 129 | skip-magic-trailing-comma = false 130 | line-ending = "auto" 131 | exclude = [ 132 | ".eggs", 133 | ".git", 134 | ".pytest_cache", 135 | "docs/rtd", 136 | "build", 137 | "dist", 138 | ] 139 | 140 | [tool.ruff.lint.per-file-ignores] 141 | # Add any per-file ignores here if needed 142 | 143 | [tool.ruff.lint.flake8-bugbear] 144 | extend-immutable-calls = ["pytest.raises", "pytest.warns", "pytest.mark.skip"] -------------------------------------------------------------------------------- /curvlinops/trace/meyer2020hutch.py: -------------------------------------------------------------------------------- 1 | """Implementation of Hutch++ trace estimation from Meyer et al.""" 2 | 3 | from typing import Union 4 | 5 | from torch import Tensor, column_stack, einsum 6 | from torch.linalg import qr 7 | 8 | from curvlinops._torch_base import PyTorchLinearOperator 9 | from curvlinops.sampling import random_vector 10 | from curvlinops.utils import ( 11 | assert_divisible_by, 12 | assert_is_square, 13 | assert_matvecs_subseed_dim, 14 | ) 15 | 16 | 17 | def hutchpp_trace( 18 | A: Union[PyTorchLinearOperator, Tensor], 19 | num_matvecs: int, 20 | distribution: str = "rademacher", 21 | ) -> Tensor: 22 | r"""Estimate a linear operator's trace using the Hutch++ method. 23 | 24 | In contrast to vanilla Hutchinson, Hutch++ has lower variance, but requires more 25 | memory. The method is presented in 26 | 27 | - Meyer, R. A., Musco, C., Musco, C., & Woodruff, D. P. (2020). Hutch++: 28 | optimal stochastic trace estimation. 29 | 30 | Let :math:`\mathbf{A}` be a square linear operator whose trace we want to 31 | approximate. First, using one third of the available matrix-vector products, 32 | we compute an orthonormal basis :math:`\mathbf{Q}` of a sub-space spanned by 33 | :math:`\mathbf{A} \mathbf{S}` where :math:`\mathbf{S}` is a tall random matrix 34 | with i.i.d. elements. 
Then, using one third of the available matrix-vector 35 | products, we compute the trace in the sub-space. Finally, we apply Hutchinson's 36 | estimator in the remaining space spanned by 37 | :math:`\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top`. Let :math:`3N` denote the 38 | total number of matrix-vector products. We can draw :math:`2N` random vectors 39 | :math:`\mathbf{v}_n \sim \mathbf{v}` from a distribution which satisfies 40 | :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}`, compute 41 | :math:`\mathbf{Q}` from the first :math:`N` vectors, and use the remaining 42 | to compute the estimator 43 | 44 | .. math:: 45 | a 46 | := \mathrm{Tr}(\mathbf{Q}^\top \mathbf{A} \mathbf{Q}) 47 | + \frac{1}{N} \sum_{n = N+1}^{2N} \mathbf{v}_n^\top 48 | (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top)^\top 49 | \mathbf{A} (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top) \mathbf{v}_n 50 | \approx \mathrm{Tr}(\mathbf{A})\,. 51 | 52 | This estimator is unbiased, :math:`\mathbb{E}[a] = \mathrm{Tr}(\mathbf{A})`, as the 53 | first term is the exact trace in the space spanned by :math:`\mathbf{Q}`, and the 54 | second part is Hutchinson's unbiased estimator in the complementary space. 55 | 56 | Args: 57 | A: A square linear operator whose trace is estimated. 58 | num_matvecs: Total number of matrix-vector products to use. Must be smaller 59 | than the dimension of the linear operator (because otherwise one can 60 | evaluate the true trace directly at the same cost), and divisible by 3. 61 | distribution: Distribution of the random vectors used for the trace estimation. 62 | Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``. 63 | 64 | Returns: 65 | The estimated trace of the linear operator. 66 | 67 | Example: 68 | >>> from torch import manual_seed, rand 69 | >>> _ = manual_seed(0) # make deterministic 70 | >>> A = rand(50, 50) 71 | >>> tr_A = A.trace().item() # exact trace as reference 72 | >>> # one- and multi-sample approximations 73 | >>> tr_A_low_precision = hutchpp_trace(A, num_matvecs=3).item() 74 | >>> tr_A_high_precision = hutchpp_trace(A, num_matvecs=30).item() 75 | >>> # compute the relative errors 76 | >>> rel_error_low_precision = abs(tr_A - tr_A_low_precision) / abs(tr_A) 77 | >>> rel_error_high_precision = abs(tr_A - tr_A_high_precision) / abs(tr_A) 78 | >>> assert rel_error_low_precision > rel_error_high_precision 79 | >>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4) 80 | (23.7836, 15.7879, 19.6381) 81 | """ 82 | dim = assert_is_square(A) 83 | assert_matvecs_subseed_dim(A, num_matvecs) 84 | assert_divisible_by(num_matvecs, 3, "num_matvecs") 85 | N = num_matvecs // 3 86 | dev, dt = (A.device, A.dtype) 87 | 88 | # compute the orthogonal basis for the subspace spanned by AS, and evaluate the 89 | # exact trace using 2/3 of the available matrix-vector products 90 | AS = A @ column_stack([random_vector(dim, distribution, dev, dt) for _ in range(N)]) 91 | Q, _ = qr(AS) 92 | tr_QT_A_Q = einsum("ji,ji", Q, A @ Q) 93 | 94 | # compute the trace in the complementary space using the remaining 1/3 of the 95 | # matrix-vector products 96 | G = column_stack([random_vector(dim, distribution, dev, dt) for _ in range(N)]) 97 | 98 | # project out subspace 99 | A_proj_G = A @ (G - Q @ (Q.T @ G)) 100 | A_proj_G -= Q @ (Q.T @ A_proj_G) 101 | # compute trace with vanilla Hutchinson 102 | tr_A_proj = einsum("ij,ij", G, A_proj_G) / N 103 | 104 | return tr_QT_A_Q + tr_A_proj 105 | -------------------------------------------------------------------------------- 
/docs/examples/basic_usage/example_huggingface.py: -------------------------------------------------------------------------------- 1 | r"""Usage with Huggingface LLMs 2 | =============================== 3 | 4 | This example demonstrates how to work with Huggingface (HF) language models. 5 | 6 | As always, let's first import the required functionality. 7 | Remember to run :code:`pip install -U transformers datasets` 8 | """ 9 | 10 | from collections import UserDict 11 | from collections.abc import MutableMapping 12 | 13 | import torch.utils.data as data_utils 14 | from datasets import Dataset 15 | from torch import Tensor, bfloat16, eye, manual_seed, no_grad 16 | from torch.nn import CrossEntropyLoss, Module 17 | from transformers import ( 18 | DataCollatorWithPadding, 19 | GPT2Config, 20 | GPT2ForSequenceClassification, 21 | GPT2Tokenizer, 22 | PreTrainedTokenizer, 23 | ) 24 | 25 | from curvlinops import GGNLinearOperator 26 | 27 | # make deterministic 28 | manual_seed(0) 29 | 30 | # %% 31 | # 32 | # Data 33 | # ---- 34 | # 35 | # We will use synthetic data for simplicity. But obviously this can 36 | # be replaced with any HF dataloader. 37 | 38 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 39 | tokenizer.pad_token_id = tokenizer.eos_token_id 40 | 41 | data = [ 42 | {"text": "Today is hot, but I will manage!!!!", "label": 1}, 43 | {"text": "Tomorrow is cold", "label": 0}, 44 | {"text": "Carpe diem", "label": 1}, 45 | {"text": "Tempus fugit", "label": 1}, 46 | ] 47 | dataset = Dataset.from_list(data) 48 | 49 | 50 | def tokenize(row): 51 | return tokenizer(row["text"]) 52 | 53 | 54 | dataset = dataset.map(tokenize, remove_columns=["text"]) 55 | dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"]) 56 | dataloader = data_utils.DataLoader( 57 | dataset, batch_size=100, collate_fn=DataCollatorWithPadding(tokenizer) 58 | ) 59 | 60 | # %% 61 | # 62 | # Let's check the batch emitted by HF. We will see that it is a :code:`UserDict`, 63 | # containing the input and label tensors. Note that :code:`UserDict` is 64 | # :code:`MutableMapping`, so it is compatible with :code:`curvlinops`. 65 | 66 | data = next(iter(dataloader)) 67 | print(f"Is the data a UserDict? {isinstance(data, UserDict)}") 68 | for k, v in data.items(): 69 | print(k, v.shape) 70 | 71 | 72 | # %% 73 | # 74 | # Model 75 | # ----- 76 | # 77 | # Curvlinops supports general :code:`UserDict` inputs. However, everything must 78 | # be handled inside the :code:`forward` function of the model. This gives 79 | # the users the most flexibility, without much overhead. 80 | # 81 | # Let's wrap the HF model to conform this requirement then. 82 | 83 | 84 | class MyGPT2(Module): 85 | """ 86 | Huggingface LLM wrapper. 87 | 88 | Args: 89 | tokenizer: The tokenizer used for preprocessing the text data. Needed 90 | since the model needs to know the padding token id. 91 | """ 92 | 93 | def __init__(self, tokenizer: PreTrainedTokenizer) -> None: 94 | super().__init__() 95 | config = GPT2Config.from_pretrained("gpt2") 96 | config.pad_token_id = tokenizer.pad_token_id 97 | config.num_labels = 2 98 | self.hf_model = GPT2ForSequenceClassification.from_pretrained( 99 | "gpt2", config=config 100 | ) 101 | 102 | # For simplicity, only enable grad for the last layer 103 | for p in self.hf_model.parameters(): 104 | p.requires_grad = False 105 | 106 | for p in self.hf_model.score.parameters(): 107 | p.requires_grad = True 108 | 109 | def forward(self, data: MutableMapping) -> Tensor: 110 | """ 111 | Custom forward function. 
Handles things like moving the 112 | input tensor to the correct device inside. 113 | 114 | Args: 115 | data: A dict-like data structure with `input_ids` inside. 116 | This is the default data structure assumed by Huggingface 117 | dataloaders. 118 | 119 | Returns: 120 | logits: An `(batch_size, n_classes)`-sized tensor of logits. 121 | """ 122 | device = next(self.parameters()).device 123 | input_ids = data["input_ids"].to(device) 124 | output_dict = self.hf_model(input_ids) 125 | return output_dict.logits 126 | 127 | 128 | model = MyGPT2(tokenizer).to(bfloat16) 129 | 130 | with no_grad(): 131 | logits = model(data) 132 | print(f"Logits shape: {logits.shape}") 133 | 134 | 135 | # %% 136 | # 137 | # Curvlinops 138 | # ---------- 139 | # 140 | # We are now ready to compute the curvature of this HF model using Curvlinops. 141 | # For this, we need to define a function to tell Curvlinops how to get the 142 | # batch size of the :code:`UserDict` input batch. Everything else is unchanged 143 | # from the standard usage of Curvlinops! 144 | 145 | 146 | def batch_size_fn(x: MutableMapping): 147 | return x["input_ids"].shape[0] 148 | 149 | 150 | params = [p for p in model.parameters() if p.requires_grad] 151 | 152 | ggn = GGNLinearOperator( 153 | model, 154 | CrossEntropyLoss(), 155 | params, 156 | [(data, data["labels"])], # We still need to input a list of "(X, y)" pairs! 157 | check_deterministic=False, 158 | batch_size_fn=batch_size_fn, # Remember to specify this! 159 | ) 160 | 161 | G = ggn @ eye(ggn.shape[0], device=params[0].device) 162 | 163 | print(f"GGN shape: {G.shape}") 164 | 165 | 166 | # %% 167 | # 168 | # Conclusion 169 | # ---------- 170 | # 171 | # This :code:`UserDict` (or any other dict-like data structure) specification 172 | # is very flexible. This doesn't stop at HF models. You can leverage this 173 | # for any custom models! 174 | -------------------------------------------------------------------------------- /curvlinops/examples/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains functionality for examples in the documentation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import List, Tuple 6 | 7 | from torch import Tensor, device, dtype, einsum 8 | 9 | from curvlinops._torch_base import PyTorchLinearOperator 10 | 11 | 12 | class TensorLinearOperator(PyTorchLinearOperator): 13 | """Linear operator wrapping a single tensor as a linear operator.""" 14 | 15 | def __init__(self, A: Tensor): 16 | """Initialize linear operator from a 2D tensor. 17 | 18 | Args: 19 | A: A 2D tensor representing the matrix. 20 | 21 | Raises: 22 | ValueError: If ``A`` is not a 2D tensor. 23 | """ 24 | if A.ndim != 2: 25 | raise ValueError(f"Input tensor must be 2D. Got {A.ndim}D.") 26 | super().__init__([(A.shape[1],)], [(A.shape[0],)]) 27 | self._A = A 28 | self.SELF_ADJOINT = A.shape == A.T.shape and A.allclose(A.T) 29 | 30 | @property 31 | def device(self) -> device: 32 | """Infer the linear operator's device. 33 | 34 | Returns: 35 | The linear operator's device. 36 | """ 37 | return self._A.device 38 | 39 | @property 40 | def dtype(self) -> dtype: 41 | """Infer the linear operator's data type. 42 | 43 | Returns: 44 | The linear operator's data type. 45 | """ 46 | return self._A.dtype 47 | 48 | def _adjoint(self) -> TensorLinearOperator: 49 | """Return a linear operator representing the adjoint. 50 | 51 | Returns: 52 | The adjoint linear operator. 
53 | """ 54 | return TensorLinearOperator(self._A.conj().T) 55 | 56 | def _matmat(self, M: List[Tensor]) -> List[Tensor]: 57 | """Multiply the linear operator onto a matrix in list format. 58 | 59 | Args: 60 | M: Matrix for multiplication in list format. 61 | 62 | Returns: 63 | Result of the matrix-matrix multiplication in list format. 64 | """ 65 | (M0,) = M 66 | return [self._A @ M0] 67 | 68 | 69 | class OuterProductLinearOperator(PyTorchLinearOperator): 70 | """Linear operator for low-rank matrices of the form ``∑ᵢ cᵢ aᵢ aᵢᵀ``. 71 | 72 | ``cᵢ`` is the coefficient for the vector ``aᵢ``. 73 | """ 74 | 75 | SELF_ADJOINT = True 76 | 77 | def __init__(self, c: Tensor, A: Tensor): 78 | """Store coefficients and vectors for low-rank representation. 79 | 80 | Args: 81 | c: Coefficients ``cᵢ``. Has shape ``[K]`` where ``K`` is the rank. 82 | A: Matrix of shape ``[D, K]``, where ``D`` is the linear operators 83 | dimension, that stores the low-rank vectors columnwise, i.e. ``aᵢ`` 84 | is stored in ``A[:,i]``. 85 | """ 86 | shape = [(A.shape[0],)] 87 | super().__init__(shape, shape) 88 | self._A = A 89 | self._c = c 90 | 91 | def _matmat(self, M: List[Tensor]) -> List[Tensor]: 92 | """Apply the linear operator to a matrix in list format. 93 | 94 | Args: 95 | M: The matrix to multiply onto in list format. 96 | 97 | Returns: 98 | The result of the multiplication in list format. 99 | """ 100 | (M0,) = M 101 | # Compute ∑ᵢ cᵢ aᵢ aᵢᵀ @ X 102 | return [einsum("ik,k,jk,jl->il", self._A, self._c, self._A, M0)] 103 | 104 | def _adjoint(self) -> OuterProductLinearOperator: 105 | """Return the linear operator representing the adjoint. 106 | 107 | An outer product is self-adjoint. 108 | 109 | Returns: 110 | Self. 111 | """ 112 | return self 113 | 114 | @property 115 | def dtype(self) -> dtype: 116 | """Return the data type of the linear operator. 117 | 118 | Returns: 119 | The data type of the linear operator. 120 | """ 121 | return self._A.dtype 122 | 123 | @property 124 | def device(self) -> device: 125 | """Return the linear operator's device. 126 | 127 | Returns: 128 | The device on which the linear operator is defined. 129 | """ 130 | return self._A.device 131 | 132 | 133 | class IdentityLinearOperator(PyTorchLinearOperator): 134 | """Linear operator representing the identity matrix.""" 135 | 136 | SELF_ADJOINT = True 137 | 138 | def __init__(self, shape: List[Tuple[int, ...]], device: device, dtype: dtype): 139 | """Store the linear operator's input and output space dimensions. 140 | 141 | Args: 142 | shape: A list of shapes specifying the identity's input and output space. 143 | device: The device on which the identity operator is defined. 144 | dtype: The data type of the identity operator. 145 | """ 146 | super().__init__(shape, shape) 147 | self._device = device 148 | self._dtype = dtype 149 | 150 | def _matmat(self, M: List[Tensor]) -> List[Tensor]: 151 | """Apply the linear operator to a matrix in list format. 152 | 153 | Args: 154 | M: The matrix to multiply onto in list format. 155 | 156 | Returns: 157 | The result of the matrix multiplication in list format. 158 | """ 159 | return M 160 | 161 | @property 162 | def dtype(self) -> dtype: 163 | """Return the data type of the linear operator. 164 | 165 | Returns: 166 | The data type of the linear operator. 167 | """ 168 | return self._dtype 169 | 170 | @property 171 | def device(self) -> device: 172 | """Return the linear operator's device. 173 | 174 | Returns: 175 | The device on which the linear operator is defined. 
176 | """ 177 | return self._device 178 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/memory_benchmark.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from contextlib import nullcontext 4 | from os import path 5 | 6 | from benchmark_utils import GPTWrapper 7 | from example_benchmark import ( 8 | HAS_JVP, 9 | LINOP_STRS, 10 | OP_STRS, 11 | PROBLEM_STRS, 12 | SKIP_EXISTING, 13 | benchpath, 14 | setup_linop, 15 | setup_problem, 16 | ) 17 | from memory_profiler import memory_usage 18 | from torch import cuda, device, manual_seed, rand 19 | from torch.nn.attention import SDPBackend, sdpa_kernel 20 | 21 | from curvlinops import KFACInverseLinearOperator, KFACLinearOperator 22 | 23 | 24 | def run_peakmem_benchmark( # noqa: C901, PLR0915 25 | linop_str: str, problem_str: str, device_str: str, op_str: str 26 | ): 27 | """Execute the memory benchmark for a given linear operator class and save results. 28 | 29 | Args: 30 | linop_str: The linear operator. 31 | problem_str: The problem. 32 | device_str: The device. 33 | op_str: The operation that is benchmarked. 34 | """ 35 | savepath = benchpath(linop_str, problem_str, device_str, op_str, metric="peakmem") 36 | if SKIP_EXISTING and path.exists(savepath): 37 | print( 38 | f"[Memory] Skipping {linop_str} on {problem_str} and {device_str} for " 39 | + f"{op_str}" 40 | ) 41 | return 42 | 43 | dev = device(device_str) 44 | is_cuda = "cuda" in str(dev) 45 | 46 | def f_gradient_and_loss(): 47 | manual_seed(0) # make deterministic 48 | 49 | model, loss_function, params, data = setup_problem(problem_str, linop_str, dev) 50 | # NOTE Disable deterministic check as it will otherwise compute matvecs 51 | linop = setup_linop( 52 | linop_str, model, loss_function, params, data, check_deterministic=False 53 | ) 54 | 55 | if isinstance(linop, KFACInverseLinearOperator): 56 | _ = linop._A.gradient_and_loss() 57 | else: 58 | _ = linop.gradient_and_loss() 59 | 60 | if is_cuda: 61 | cuda.synchronize() 62 | 63 | def f_precompute(): 64 | manual_seed(0) # make deterministic 65 | 66 | model, loss_function, params, data = setup_problem(problem_str, linop_str, dev) 67 | # NOTE Disable deterministic check as it will otherwise compute matvecs 68 | linop = setup_linop( 69 | linop_str, model, loss_function, params, data, check_deterministic=False 70 | ) 71 | 72 | if isinstance(linop, KFACLinearOperator): 73 | linop.compute_kronecker_factors() 74 | 75 | if isinstance(linop, KFACInverseLinearOperator): 76 | linop._A.compute_kronecker_factors() 77 | # damp and invert the Kronecker matrices 78 | for mod_name in linop._A._mapping: 79 | linop._compute_or_get_cached_inverse(mod_name) 80 | 81 | if is_cuda: 82 | cuda.synchronize() 83 | 84 | def f_matvec(): 85 | manual_seed(0) # make deterministic 86 | 87 | model, loss_function, params, data = setup_problem(problem_str, linop_str, dev) 88 | # NOTE Disable deterministic check as it will otherwise compute matvecs 89 | linop = setup_linop( 90 | linop_str, model, loss_function, params, data, check_deterministic=False 91 | ) 92 | v = rand(linop.shape[1], device=dev) 93 | 94 | if isinstance(linop, KFACLinearOperator): 95 | linop.compute_kronecker_factors() 96 | 97 | if isinstance(linop, KFACInverseLinearOperator): 98 | linop._A.compute_kronecker_factors() 99 | # damp and invert the Kronecker matrices 100 | for mod_name in linop._A._mapping: 101 | 
linop._compute_or_get_cached_inverse(mod_name) 102 | 103 | # Double-backward through efficient attention is unsupported, disable fused kernels 104 | # (https://github.com/pytorch/pytorch/issues/116350#issuecomment-1954667011) 105 | attention_double_backward = isinstance(linop, HAS_JVP) and isinstance( 106 | model, GPTWrapper 107 | ) 108 | with ( 109 | sdpa_kernel(SDPBackend.MATH) if attention_double_backward else nullcontext() 110 | ): 111 | _ = linop @ v 112 | 113 | if is_cuda: 114 | cuda.synchronize() 115 | 116 | func = { 117 | "gradient_and_loss": f_gradient_and_loss, 118 | "precompute": f_precompute, 119 | "matvec": f_matvec, 120 | }[op_str] 121 | 122 | if is_cuda: 123 | func() 124 | cuda.synchronize() 125 | peakmem_bytes = cuda.max_memory_allocated() 126 | cuda.reset_peak_memory_stats() 127 | else: 128 | peakmem_bytes = memory_usage(func, interval=1e-4, max_usage=True) * 2**20 129 | 130 | peakmem_gib = peakmem_bytes / 2**30 131 | print( 132 | f"[Memory] {linop_str}'s {op_str} on {problem_str} and {device_str}:" 133 | + f" {peakmem_gib:.2f} GiB" 134 | ) 135 | 136 | with open(savepath, "w") as f: 137 | json.dump({"peakmem": peakmem_gib}, f) 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = ArgumentParser( 142 | description="Run memory benchmark for a given linear operator." 143 | ) 144 | parser.add_argument( 145 | "--linop", 146 | type=str, 147 | help="The linear operator class to benchmark.", 148 | choices=LINOP_STRS, 149 | ) 150 | parser.add_argument( 151 | "--problem", 152 | type=str, 153 | help="The problem to benchmark.", 154 | choices=PROBLEM_STRS, 155 | ) 156 | parser.add_argument( 157 | "--device", 158 | type=str, 159 | help="The device to benchmark.", 160 | ) 161 | parser.add_argument( 162 | "--op", 163 | type=str, 164 | help="The operation to benchmark.", 165 | choices=OP_STRS, 166 | ) 167 | 168 | args = parser.parse_args() 169 | run_peakmem_benchmark(args.linop, args.problem, args.device, args.op) 170 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/example_visual_tour.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Visual tour of curvature matrices 3 | ================================= 4 | 5 | This tutorial visualizes different curvature matrices for a model with 6 | sufficiently small parameter space. 7 | 8 | First, the imports. 9 | """ 10 | 11 | from typing import Callable, Tuple 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.axes import Axes 15 | from matplotlib.figure import Figure 16 | from numpy import cumsum 17 | from torch import Tensor, cuda, device, eye, manual_seed, rand, randint 18 | from torch.nn import ( 19 | Conv2d, 20 | CrossEntropyLoss, 21 | Flatten, 22 | Linear, 23 | ReLU, 24 | Sequential, 25 | Sigmoid, 26 | ) 27 | from torch.utils.data import DataLoader, TensorDataset 28 | from tueplots import bundles 29 | 30 | from curvlinops import ( 31 | EFLinearOperator, 32 | FisherMCLinearOperator, 33 | GGNLinearOperator, 34 | HessianLinearOperator, 35 | KFACLinearOperator, 36 | ) 37 | 38 | # make deterministic 39 | manual_seed(0) 40 | 41 | DEVICE = device("cuda" if cuda.is_available() else "cpu") 42 | 43 | # %% 44 | # Setup 45 | # ----- 46 | # 47 | # We will create a synthetic classification task, a small CNN, and use 48 | # cross-entropy error as loss function. 
49 | 50 | num_data = 50 51 | batch_size = 20 52 | in_channels = 3 53 | in_features_shape = (in_channels, 10, 10) 54 | num_classes = 5 55 | 56 | # dataset 57 | dataset = TensorDataset( 58 | rand(num_data, *in_features_shape), # X 59 | randint(size=(num_data,), low=0, high=num_classes), # y 60 | ) 61 | dataloader = DataLoader(dataset, batch_size=batch_size) 62 | 63 | # model 64 | model = Sequential( 65 | Conv2d(in_channels, 4, 3, padding=1), 66 | ReLU(), 67 | Conv2d(4, 4, 5, padding=2, stride=2), 68 | Sigmoid(), 69 | Conv2d(4, 1, 3, padding=1), 70 | Flatten(), 71 | Linear(25, num_classes), 72 | ).to(DEVICE) 73 | 74 | params = [p for p in model.parameters() if p.requires_grad] 75 | num_params = sum(p.numel() for p in params) 76 | num_params_layer = [ 77 | sum(p.numel() for p in child.parameters()) for child in model.children() 78 | ] 79 | num_tensors_layer = [len(list(child.parameters())) for child in model.children()] 80 | 81 | loss_function = CrossEntropyLoss(reduction="mean").to(DEVICE) 82 | 83 | print(f"Total parameters: {num_params}") 84 | print(f"Layer parameters: {num_params_layer}") 85 | 86 | # %% 87 | # Computation 88 | # ----------- 89 | # 90 | # We can now set up linear operators for the curvature matrices we want to 91 | # visualize, and compute them by multiplying the linear operator onto the 92 | # identity matrix. 93 | # 94 | # First, create the linear operators: 95 | 96 | Hessian_linop = HessianLinearOperator(model, loss_function, params, dataloader) 97 | GGN_linop = GGNLinearOperator(model, loss_function, params, dataloader) 98 | EF_linop = EFLinearOperator(model, loss_function, params, dataloader) 99 | Hessian_blocked_linop = HessianLinearOperator( 100 | model, 101 | loss_function, 102 | params, 103 | dataloader, 104 | block_sizes=[s for s in num_tensors_layer if s != 0], 105 | ) 106 | F_linop = FisherMCLinearOperator(model, loss_function, params, dataloader) 107 | KFAC_linop = KFACLinearOperator( 108 | model, loss_function, params, dataloader, separate_weight_and_bias=False 109 | ) 110 | 111 | # %% 112 | # 113 | # Then, compute the matrices 114 | 115 | identity = eye(num_params, device=DEVICE) 116 | 117 | Hessian_mat = Hessian_linop @ identity 118 | GGN_mat = GGN_linop @ identity 119 | EF_mat = EF_linop @ identity 120 | Hessian_blocked_mat = Hessian_blocked_linop @ identity 121 | F_mat = F_linop @ identity 122 | KFAC_mat = KFAC_linop @ identity 123 | 124 | # %% 125 | # Visualization 126 | # ------------- 127 | # 128 | # We will show the matrix entries on a shared domain for better comparability. 129 | 130 | matrices = [Hessian_mat, GGN_mat, EF_mat, Hessian_blocked_mat, F_mat, KFAC_mat] 131 | titles = [ 132 | "Hessian", 133 | "Generalized Gauss-Newton", 134 | "Empirical Fisher", 135 | "Block-diagonal Hessian", 136 | "Monte-Carlo Fisher", 137 | "KFAC", 138 | ] 139 | 140 | rows, columns = 2, 3 141 | 142 | 143 | def plot( 144 | transform: Callable[[Tensor], Tensor], transform_title: str = None 145 | ) -> Tuple[Figure, Axes]: 146 | """Visualize transformed curvature matrices using a shared domain. 147 | 148 | Args: 149 | transform: A transformation that will be applied to the matrices. Must 150 | accept a matrix and return a matrix of the same shape. 151 | transform_title: An optional string describing the transformation. 152 | Default: `None` (empty). 153 | 154 | Returns: 155 | Figure and axes of the created subplot. 
156 | """ 157 | min_value = min(transform(mat).min() for mat in matrices) 158 | max_value = max(transform(mat).max() for mat in matrices) 159 | 160 | fig, axes = plt.subplots(nrows=rows, ncols=columns, sharex=True, sharey=True) 161 | fig.supxlabel("Layer") 162 | fig.supylabel("Layer") 163 | 164 | for idx, (ax, mat, title) in enumerate(zip(axes.flat, matrices, titles)): 165 | ax.set_title(title) 166 | img = ax.imshow(transform(mat), vmin=min_value, vmax=max_value) 167 | 168 | # layer blocks 169 | boundaries = [0] + cumsum(num_params_layer).tolist() 170 | for pos in boundaries: 171 | if pos not in [0, num_params]: 172 | style = {"color": "w", "lw": 0.5, "ls": "-"} 173 | ax.axhline(y=pos - 1, xmin=0, xmax=num_params - 1, **style) 174 | ax.axvline(x=pos - 1, ymin=0, ymax=num_params - 1, **style) 175 | 176 | # label positions 177 | label_positions = [ 178 | (boundaries[layer_idx] + boundaries[layer_idx + 1]) / 2 179 | for layer_idx in range(len(boundaries) - 1) 180 | if boundaries[layer_idx] != boundaries[layer_idx + 1] 181 | ] 182 | labels = [str(i + 1) for i in range(len(label_positions))] 183 | ax.set_xticks(label_positions) 184 | ax.set_xticklabels(labels) 185 | ax.set_yticks(label_positions) 186 | ax.set_yticklabels(labels) 187 | 188 | # colorbar 189 | last = idx == len(matrices) - 1 190 | if last: 191 | fig.colorbar( 192 | img, ax=axes.ravel().tolist(), label=transform_title, shrink=0.8 193 | ) 194 | 195 | return fig, axes 196 | 197 | 198 | # use `tueplots` to make the plot look pretty 199 | plot_config = bundles.icml2024(column="full", nrows=1.5 * rows, ncols=columns) 200 | 201 | # %% 202 | # 203 | # We will show their logarithmic absolute value: 204 | 205 | 206 | def logabs(mat: Tensor, epsilon: float = 1e-6) -> Tensor: 207 | return mat.abs().clamp(min=epsilon).log10() 208 | 209 | 210 | with plt.rc_context(plot_config): 211 | plot(logabs, transform_title="Logarithmic absolute entries") 212 | plt.savefig("curvature_matrices_log_abs.pdf", bbox_inches="tight") 213 | 214 | # %% 215 | # 216 | # That's because it is hard to recognize structure in the unaltered entries: 217 | 218 | 219 | def unchanged(mat): 220 | return mat 221 | 222 | 223 | with plt.rc_context(plot_config): 224 | plot(unchanged, transform_title="Unaltered matrix entries") 225 | 226 | # %% 227 | # 228 | # That's all for now. 
229 | 230 | plt.close("all") 231 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | """Contains pytest fixtures that are visible by other files.""" 2 | 3 | from collections.abc import MutableMapping 4 | from typing import Callable, Dict, Iterable, List, Optional, Tuple 5 | 6 | from numpy import random 7 | from pytest import fixture 8 | from torch import Tensor, manual_seed 9 | from torch.nn import Module, MSELoss, Parameter 10 | 11 | import test.utils 12 | from test.cases import ( 13 | ADJOINT_CASES, 14 | ADJOINT_IDS, 15 | BLOCK_SIZES_FNS, 16 | CASES, 17 | CNN_CASES, 18 | INV_CASES, 19 | IS_VEC_IDS, 20 | IS_VECS, 21 | NON_DETERMINISTIC_CASES, 22 | ) 23 | from test.kfac_cases import ( 24 | KFAC_EXACT_CASES, 25 | KFAC_EXACT_ONE_DATUM_CASES, 26 | KFAC_WEIGHT_SHARING_EXACT_CASES, 27 | SINGLE_LAYER_CASES, 28 | SINGLE_LAYER_WEIGHT_SHARING_CASES, 29 | ) 30 | 31 | 32 | def initialize_case( 33 | case: Dict, 34 | ) -> Tuple[ 35 | Callable[[Tensor], Tensor], 36 | Callable[[Tensor, Tensor], Tensor], 37 | List[Tensor], 38 | Iterable[Tuple[Tensor, Tensor]], 39 | Optional[Callable[[MutableMapping], int]], 40 | ]: 41 | random.seed(case["seed"]) 42 | manual_seed(case["seed"]) 43 | 44 | model_func = case["model_func"]().to(case["device"]) 45 | loss_func = case["loss_func"]().to(case["device"]) 46 | params = [p for p in model_func.parameters() if p.requires_grad] 47 | data = case["data"]() 48 | 49 | # In some KFAC cases, 50 | # ``data = {KFACType.EXPAND: [(X, y), ...], KFACType.REDUCE: [(X, y), ...]}`` 51 | # unlike the standard ``data = [(X: Tensor | MutableMapping, y), ...]``. 52 | # We ignore the former since the latter is included in KFAC cases, and thus the 53 | # feature of ``MutableMapping`` inputs is sufficiently covered already. 
54 | if not isinstance(data, dict) and isinstance(next(iter(data))[0], MutableMapping): 55 | batch_size_fn = test.utils.batch_size_fn 56 | else: 57 | batch_size_fn = None 58 | 59 | return model_func, loss_func, params, data, batch_size_fn 60 | 61 | 62 | @fixture(params=CASES) 63 | def case( 64 | request, 65 | ) -> Tuple[ 66 | Callable[[Tensor], Tensor], 67 | Callable[[Tensor, Tensor], Tensor], 68 | List[Tensor], 69 | Iterable[Tuple[Tensor, Tensor]], 70 | Optional[Callable[[MutableMapping], int]], 71 | ]: 72 | case = request.param 73 | yield initialize_case(case) 74 | 75 | 76 | @fixture(params=INV_CASES) 77 | def inv_case( 78 | request, 79 | ) -> Tuple[ 80 | Callable[[Tensor], Tensor], 81 | Callable[[Tensor, Tensor], Tensor], 82 | List[Tensor], 83 | Iterable[Tuple[Tensor, Tensor]], 84 | Optional[Callable[[MutableMapping], int]], 85 | ]: 86 | case = request.param 87 | yield initialize_case(case) 88 | 89 | 90 | @fixture(params=CNN_CASES) 91 | def cnn_case( 92 | request, 93 | ) -> Tuple[ 94 | Callable[[Tensor], Tensor], 95 | Callable[[Tensor, Tensor], Tensor], 96 | List[Tensor], 97 | Iterable[Tuple[Tensor, Tensor]], 98 | Optional[Callable[[MutableMapping], int]], 99 | ]: 100 | cnn_case = request.param 101 | yield initialize_case(cnn_case) 102 | 103 | 104 | @fixture(params=NON_DETERMINISTIC_CASES) 105 | def non_deterministic_case( 106 | request, 107 | ) -> Tuple[ 108 | Callable[[Tensor], Tensor], 109 | Callable[[Tensor, Tensor], Tensor], 110 | List[Tensor], 111 | Iterable[Tuple[Tensor, Tensor]], 112 | Optional[Callable[[MutableMapping], int]], 113 | ]: 114 | case = request.param 115 | yield initialize_case(case) 116 | 117 | 118 | @fixture(params=ADJOINT_CASES, ids=ADJOINT_IDS) 119 | def adjoint(request) -> bool: 120 | return request.param 121 | 122 | 123 | @fixture(params=IS_VECS, ids=IS_VEC_IDS) 124 | def is_vec(request) -> bool: 125 | """Whether to test matrix-vector or matrix-matrix multiplication. 126 | 127 | Args: 128 | request: Pytest request object. 129 | 130 | Returns: 131 | ``True`` if the test is for matrix-vector multiplication, ``False`` otherwise. 132 | """ 133 | return request.param 134 | 135 | 136 | @fixture(params=BLOCK_SIZES_FNS.values(), ids=BLOCK_SIZES_FNS.keys()) 137 | def block_sizes_fn(request) -> Callable[[List[Parameter]], Optional[List[int]]]: 138 | """Function to generate the ``block_sizes`` argument for a linear operator. 139 | 140 | Args: 141 | request: Pytest request object. 142 | 143 | Returns: 144 | A function that generates the block sizes for a linear operator from the 145 | parameters. 146 | """ 147 | return request.param 148 | 149 | 150 | @fixture(params=KFAC_EXACT_CASES) 151 | def kfac_exact_case( 152 | request, 153 | ) -> Tuple[ 154 | Module, 155 | MSELoss, 156 | List[Tensor], 157 | Iterable[Tuple[Tensor, Tensor]], 158 | Optional[Callable[[MutableMapping], int]], 159 | ]: 160 | """Prepare a test case for which KFAC equals the GGN. 161 | 162 | Yields: 163 | A neural network, the mean-squared error function, a list of parameters, and 164 | a data set. 165 | """ 166 | case = request.param 167 | yield initialize_case(case) 168 | 169 | 170 | @fixture(params=KFAC_WEIGHT_SHARING_EXACT_CASES) 171 | def kfac_weight_sharing_exact_case( 172 | request, 173 | ) -> Tuple[ 174 | Module, 175 | MSELoss, 176 | List[Tensor], 177 | Iterable[Tuple[Tensor, Tensor]], 178 | Optional[Callable[[MutableMapping], int]], 179 | ]: 180 | """Prepare a test case with weight-sharing for which KFAC equals the GGN. 
181 | 182 | Yields: 183 | A neural network, the mean-squared error function, a list of parameters, and 184 | a data set. 185 | """ 186 | case = request.param 187 | yield initialize_case(case) 188 | 189 | 190 | @fixture(params=KFAC_EXACT_ONE_DATUM_CASES) 191 | def kfac_exact_one_datum_case( 192 | request, 193 | ) -> Tuple[ 194 | Module, 195 | Module, 196 | List[Tensor], 197 | Iterable[Tuple[Tensor, Tensor]], 198 | Optional[Callable[[MutableMapping], int]], 199 | ]: 200 | """Prepare a test case for which KFAC equals the GGN and one datum is used. 201 | 202 | Yields: 203 | A neural network, loss function, a list of parameters, and 204 | a data set with a single datum. 205 | """ 206 | case = request.param 207 | yield initialize_case(case) 208 | 209 | 210 | @fixture(params=SINGLE_LAYER_CASES) 211 | def single_layer_case( 212 | request, 213 | ) -> Tuple[ 214 | Module, 215 | Module, 216 | List[Tensor], 217 | Iterable[Tuple[Tensor, Tensor]], 218 | Optional[Callable[[MutableMapping], int]], 219 | ]: 220 | """Prepare a test case with a single-layer model for which FOOF is exact. 221 | 222 | Yields: 223 | A neural network, loss function, a list of parameters, and 224 | a data set with a single datum. 225 | """ 226 | case = request.param 227 | yield initialize_case(case) 228 | 229 | 230 | @fixture(params=SINGLE_LAYER_WEIGHT_SHARING_CASES) 231 | def single_layer_weight_sharing_case( 232 | request, 233 | ) -> Tuple[ 234 | Module, 235 | Module, 236 | List[Tensor], 237 | Iterable[Tuple[Tensor, Tensor]], 238 | Optional[Callable[[MutableMapping], int]], 239 | ]: 240 | """Test case with a single-layer model with weight-sharing for which FOOF is exact. 241 | 242 | Yields: 243 | A neural network, loss function, a list of parameters, and 244 | a data set with a single datum. 245 | """ 246 | case = request.param 247 | yield initialize_case(case) 248 | -------------------------------------------------------------------------------- /curvlinops/ggn.py: -------------------------------------------------------------------------------- 1 | """Contains LinearOperator implementation of the GGN.""" 2 | 3 | from collections.abc import MutableMapping 4 | from functools import cached_property, partial 5 | from typing import Callable, List, Tuple, Union 6 | 7 | from torch import Tensor, no_grad, vmap 8 | from torch.func import jacrev, jvp, vjp 9 | from torch.nn import Module, Parameter 10 | 11 | from curvlinops._torch_base import CurvatureLinearOperator 12 | from curvlinops.utils import make_functional_model_and_loss 13 | 14 | 15 | def make_ggn_vector_product( 16 | f: Callable[..., Tensor], c: Callable[..., Tensor], num_c_extra_args: int = 0 17 | ) -> Callable[..., Tuple[Tensor, ...]]: 18 | """Create a function that computes GGN-vector products for given f and c functions. 19 | 20 | Args: 21 | f: Function that takes parameters and input, returns prediction. 22 | Signature: (*params, X) -> prediction 23 | c: Function that takes prediction, target, and optional additional args. 24 | Signature: (prediction, y, *args) -> loss 25 | num_c_extra_args: Number of additional arguments that the loss function c expects 26 | beyond prediction and target. Used to correctly split the input arguments 27 | between the vector to multiply and the additional loss function arguments. 28 | 29 | Returns: 30 | A function that computes GGN-vector products. 31 | Signature: (params, X, y, *c_args, *v) -> GGN @ v 32 | where c_args are additional arguments passed to the loss function c. 
33 | """ 34 | 35 | @no_grad() 36 | def ggn_vector_product( 37 | params: Tuple[Tensor, ...], 38 | X: Tensor, 39 | y: Tensor, 40 | *args_and_v: Tuple[Tensor, ...], 41 | ) -> Tuple[Tensor, ...]: 42 | """Multiply the GGN on a vector in list format. 43 | 44 | Args: 45 | params: Parameters of the model. 46 | X: Input to the DNN. 47 | y: Ground truth. 48 | *args_and_v: Additional arguments for the loss function c, 49 | followed by vector to be multiplied with in tensor list format. 50 | 51 | Returns: 52 | Result of GGN multiplication in list format. Has the same shape as 53 | the vector part of args_and_v. 54 | """ 55 | # Split args_and_v into additional loss function arguments and vector v 56 | c_args, v = args_and_v[:num_c_extra_args], args_and_v[num_c_extra_args:] 57 | 58 | # Apply the Jacobian of f onto v: v → Jv 59 | f_val, f_jvp = jvp(lambda *params_inner: f(*params_inner, X), params, v) 60 | 61 | # Apply the criterion's Hessian onto Jv: Jv → HJv 62 | c_grad_func = jacrev(lambda pred: c(pred, y, *c_args)) 63 | _, c_hvp = jvp(c_grad_func, (f_val,), (f_jvp,)) 64 | 65 | # Apply the transposed Jacobian of f onto HJv: HJv → JᵀHJv 66 | # NOTE This re-evaluates the net's forward pass. [Unverified] It should be op- 67 | # timized away by common sub-expression elimination if you compile the function. 68 | _, f_vjp_func = vjp(lambda *params_inner: f(*params_inner, X), *params) 69 | return f_vjp_func(c_hvp) 70 | 71 | return ggn_vector_product 72 | 73 | 74 | def make_batch_ggn_matrix_product( 75 | model_func: Module, loss_func: Module, params: Tuple[Parameter, ...] 76 | ) -> Callable[ 77 | [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...] 78 | ]: 79 | r"""Set up function that multiplies the mini-batch GGN onto a matrix in list format. 80 | 81 | Args: 82 | model_func: The neural network :math:`f_{\mathbf{\theta}}`. 83 | loss_func: The loss function :math:`\ell`. 84 | params: A tuple of parameters w.r.t. which the GGN is computed. 85 | All parameters must be part of ``model_func.parameters()``. 86 | 87 | Returns: 88 | A function that takes inputs ``X``, ``y``, and a matrix ``M`` in list 89 | format, and returns the mini-batch GGN applied to ``M`` in list format. 90 | """ 91 | # Create functional versions of the model (f: *params, X -> prediction) and 92 | # criterion function (c: prediction, y -> loss) 93 | f, c = make_functional_model_and_loss(model_func, loss_func, params) 94 | 95 | # Create the functional GGN-vector product 96 | ggn_vp = make_ggn_vector_product(f, c) # params, X, y, *v -> *Gv 97 | 98 | # Fix the parameters 99 | ggnvp = partial(ggn_vp, params) # X, y, *c_args, *v -> *Gv 100 | 101 | # Parallelize over vectors to multiply onto a matrix in list format 102 | list_format_vmap_dims = tuple(p.ndim for p in params) # last axis 103 | return vmap( 104 | ggnvp, 105 | # No vmap in X, y, last-axis vmap over vector in list format 106 | in_dims=(None, None, *list_format_vmap_dims), 107 | # Vmapped output axis is last 108 | out_dims=list_format_vmap_dims, 109 | # We want each vector to be multiplied with the same mini-batch GGN 110 | randomness="same", 111 | ) 112 | 113 | 114 | class GGNLinearOperator(CurvatureLinearOperator): 115 | r"""Linear operator for the generalized Gauss-Newton matrix of an empirical risk. 116 | 117 | Consider the empirical risk 118 | 119 | .. 
math:: 120 | \mathcal{L}(\mathbf{\theta}) 121 | = 122 | c \sum_{n=1}^{N} 123 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) 124 | 125 | with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for 126 | ``reduction='sum'``. The GGN matrix is 127 | 128 | .. math:: 129 | c \sum_{n=1}^{N} 130 | \left( 131 | \mathbf{J}_{\mathbf{\theta}} 132 | f_{\mathbf{\theta}}(\mathbf{x}_n) 133 | \right)^\top 134 | \left( 135 | \nabla_{f_\mathbf{\theta}(\mathbf{x}_n)}^2 136 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) 137 | \right) 138 | \left( 139 | \mathbf{J}_{\mathbf{\theta}} 140 | f_{\mathbf{\theta}}(\mathbf{x}_n) 141 | \right)\,. 142 | 143 | Attributes: 144 | SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for GGNs. 145 | """ 146 | 147 | SELF_ADJOINT: bool = True 148 | 149 | @cached_property 150 | def _mp( 151 | self, 152 | ) -> Callable[ 153 | [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...] 154 | ]: 155 | """Lazy initialization of batch-GGN matrix product function. 156 | 157 | Returns: 158 | Function that computes mini-batch GGN-vector products, given inputs ``X``, 159 | labels ``y``, and the entries ``v1, v2, ...`` of the vector in list format. 160 | Produces a list of tensors with the same shape as the input vector that re- 161 | presents the result of the batch-GGN multiplication. 162 | """ 163 | return make_batch_ggn_matrix_product( 164 | self._model_func, self._loss_func, tuple(self._params) 165 | ) 166 | 167 | def _matmat_batch( 168 | self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor] 169 | ) -> List[Tensor]: 170 | """Apply the mini-batch GGN to a matrix. 171 | 172 | Args: 173 | X: Input to the DNN. 174 | y: Ground truth. 175 | M: Matrix to be multiplied with in tensor list format. 176 | Tensors have same shape as trainable model parameters, and an 177 | additional trailing axis for the matrix columns. 178 | 179 | Returns: 180 | Result of GGN multiplication in list format. Has the same shape as 181 | ``M``, i.e. each tensor in the list has the shape of a parameter and a 182 | trailing dimension of matrix columns. 183 | """ 184 | return list(self._mp(X, y, *M)) 185 | -------------------------------------------------------------------------------- /curvlinops/gradient_moments.py: -------------------------------------------------------------------------------- 1 | """Contains linear operator implementation of gradient moment matrices.""" 2 | 3 | from collections.abc import MutableMapping 4 | from functools import cached_property, partial 5 | from typing import Callable, List, Tuple, Union 6 | 7 | from einops import einsum 8 | from torch import Tensor, vmap 9 | from torch.func import grad 10 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss, Parameter 11 | 12 | from curvlinops._torch_base import CurvatureLinearOperator 13 | from curvlinops.ggn import make_ggn_vector_product 14 | from curvlinops.utils import make_functional_flattened_model_and_loss 15 | 16 | 17 | def make_batch_ef_matrix_product( 18 | model_func: Module, loss_func: Module, params: Tuple[Parameter, ...] 19 | ) -> Callable[ 20 | [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...] 21 | ]: 22 | r"""Set up function that multiplies the mini-batch empirical Fisher onto a matrix. 23 | 24 | The empirical Fisher is computed as the GGN of a pseudo-loss that is quadratic 25 | in the gradients of the original loss. 
Specifically, for loss gradients 26 | :math:`g_n = \nabla_f \ell(f_n, y_n)`, the pseudo-loss is: 27 | 28 | .. math:: 29 | L'(\mathbf{\theta}) = \frac{1}{2c} \sum_{n=1}^{N} \langle f_n, g_n \rangle^2 30 | 31 | where :math:`c` is the reduction factor and :math:`f_n = f_{\mathbf{\theta}}(x_n)`. 32 | The GGN of this pseudo-loss equals the empirical Fisher of the original loss. 33 | 34 | Args: 35 | model_func: The neural network :math:`f_{\mathbf{\theta}}`. 36 | loss_func: The loss function :math:`\ell`. 37 | params: A tuple of parameters w.r.t. which the empirical Fisher is computed. 38 | All parameters must be part of ``model_func.parameters()``. 39 | 40 | Returns: 41 | A function that takes inputs ``X``, ``y``, and a matrix ``M`` in list 42 | format, and returns the mini-batch empirical Fisher applied to ``M`` in 43 | list format. 44 | """ 45 | f_flat, c_flat = make_functional_flattened_model_and_loss( 46 | model_func, loss_func, params 47 | ) 48 | # function that computes gradients of the loss w.r.t. the flattened outputs 49 | c_flat_grad = grad(c_flat, argnums=0) 50 | 51 | def c_pseudo_flat(output_flat: Tensor, y: Tensor) -> Tensor: 52 | """Compute pseudo-loss: L' = 0.5 / c * sum_n ^2. 53 | 54 | This pseudo-loss L' := 0.5 / c ∑ₙ fₙᵀ (gₙ gₙᵀ) fₙ where gₙ = ∂ℓₙ/∂fₙ 55 | (detached). The GGN of L' linearized at fₙ is the empirical Fisher. 56 | We can thus multiply with the EF by computing the GGN-vector products of L'. 57 | 58 | The reduction factor adjusts the scale depending on the loss reduction used. 59 | 60 | Args: 61 | output_flat: Flattened model outputs for the mini-batch. 62 | y: Un-flattened labels for the mini-batch. 63 | 64 | Returns: 65 | The pseudo-loss whose GGN is the empirical Fisher on the batch. 66 | """ 67 | # Compute ∂ℓₙ/∂fₙ without reduction factor of L (detached) 68 | grad_output_flat = c_flat_grad(output_flat.detach(), y) 69 | 70 | # Adjust the scale depending on the loss reduction used 71 | num_loss_terms, C = output_flat.shape 72 | reduction_factor = { 73 | "mean": ( 74 | num_loss_terms 75 | if isinstance(loss_func, CrossEntropyLoss) 76 | else num_loss_terms * C 77 | ), 78 | "sum": 1.0, 79 | }[loss_func.reduction] 80 | 81 | # compute the pseudo-loss 82 | grad_output_flat = grad_output_flat * reduction_factor 83 | inner_products = einsum(output_flat, grad_output_flat, "n ..., n ... -> n") 84 | return 0.5 / reduction_factor * (inner_products**2).sum() 85 | 86 | # Create the functional EF-vector product using GGN of pseudo-loss 87 | ef_vp = make_ggn_vector_product(f_flat, c_pseudo_flat) 88 | 89 | # Freeze parameter values 90 | efvp = partial(ef_vp, params) # X, y, *v -> *EFv 91 | 92 | # Parallelize over vectors to multiply onto a matrix in list format 93 | list_format_vmap_dims = tuple(p.ndim for p in params) # last axis 94 | return vmap( 95 | efvp, 96 | # No vmap in X, y, assume last axis is vmapped in the matrix list 97 | in_dims=(None, None, *list_format_vmap_dims), 98 | # Vmapped output axis is last 99 | out_dims=list_format_vmap_dims, 100 | # We want each vector to be multiplied with the same mini-batch EF 101 | randomness="same", 102 | ) 103 | 104 | 105 | class EFLinearOperator(CurvatureLinearOperator): 106 | r"""Uncentered gradient covariance as PyTorch linear operator. 107 | 108 | The uncentered gradient covariance is often called 'empirical Fisher' (EF). 109 | 110 | Consider the empirical risk 111 | 112 | .. 
math:: 113 | \mathcal{L}(\mathbf{\theta}) 114 | = 115 | c \sum_{n=1}^{N} 116 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) 117 | 118 | with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for 119 | ``reduction='sum'``. The uncentered gradient covariance matrix is 120 | 121 | .. math:: 122 | c \sum_{n=1}^{N} 123 | \left( 124 | \nabla_{\mathbf{\theta}} 125 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) 126 | \right) 127 | \left( 128 | \nabla_{\mathbf{\theta}} 129 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) 130 | \right)^\top\,. 131 | 132 | Attributes: 133 | SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for 134 | empirical Fisher. 135 | """ 136 | 137 | SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss) 138 | SELF_ADJOINT: bool = True 139 | 140 | @cached_property 141 | def _mp( 142 | self, 143 | ) -> Callable[ 144 | [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...] 145 | ]: 146 | """Lazy initialization of the batch empirical Fisher matrix product function. 147 | 148 | Returns: 149 | Function that computes mini-batch EF-vector products, given inputs ``X``, 150 | labels ``y``, and the entries ``v1, v2, ...`` of the vector in list format. 151 | Produces a list of tensors with the same shape as the input vector that re- 152 | presents the result of the batch-EF multiplication. 153 | 154 | Raises: 155 | NotImplementedError: If the loss function is not supported. 156 | """ 157 | if not isinstance(self._loss_func, self.SUPPORTED_LOSSES): 158 | raise NotImplementedError( 159 | f"Loss must be one of {self.SUPPORTED_LOSSES}. Got: {self._loss_func}." 160 | ) 161 | return make_batch_ef_matrix_product( 162 | self._model_func, self._loss_func, tuple(self._params) 163 | ) 164 | 165 | def _matmat_batch( 166 | self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor] 167 | ) -> List[Tensor]: 168 | """Apply the mini-batch empirical Fisher to a matrix in tensor list format. 169 | 170 | Args: 171 | X: Input to the DNN. 172 | y: Ground truth. 173 | M: Matrix to be multiplied with in tensor list format. 174 | Tensors have same shape as trainable model parameters, and an 175 | additional trailing axis for the matrix columns. 176 | 177 | Returns: 178 | Result of EF multiplication in tensor list format. Has the same shape as 179 | ``M``, i.e. each tensor in the list has the shape of a parameter and a 180 | trailing dimension of matrix columns. 181 | """ 182 | return list(self._mp(X, y, *M)) 183 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/example_matrix_vector_products.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Matrix-vector products 3 | ====================== 4 | 5 | This tutorial contains a basic demonstration how to set up ``LinearOperators`` 6 | for the Hessian and the GGN and how to multiply them to a vector. 7 | 8 | First, the imports. 9 | """ 10 | 11 | import matplotlib.pyplot as plt 12 | from torch import cat, cuda, device, eye, manual_seed, nn, rand 13 | 14 | from curvlinops import GGNLinearOperator, HessianLinearOperator 15 | from curvlinops.examples.functorch import functorch_ggn, functorch_hessian 16 | from curvlinops.utils import allclose_report 17 | 18 | # make deterministic 19 | manual_seed(0) 20 | 21 | # %% 22 | # Setup 23 | # ----- 24 | # Let's create some toy data, a small MLP, and use mean-squared error as loss function. 
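# We keep all dimensions small on purpose: later in this tutorial, we build the dense
# matrix representations of the Hessian and GGN by multiplying onto the identity,
# which is only feasible for small networks.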
25 | 26 | N = 4 27 | D_in = 7 28 | D_hidden = 5 29 | D_out = 3 30 | 31 | DEVICE = device("cuda" if cuda.is_available() else "cpu") 32 | 33 | X = rand(N, D_in, device=DEVICE) 34 | y = rand(N, D_out, device=DEVICE) 35 | 36 | model = nn.Sequential( 37 | nn.Linear(D_in, D_hidden), 38 | nn.ReLU(), 39 | nn.Linear(D_hidden, D_hidden), 40 | nn.Sigmoid(), 41 | nn.Linear(D_hidden, D_out), 42 | ).to(DEVICE) 43 | params = [p for p in model.parameters() if p.requires_grad] 44 | 45 | loss_function = nn.MSELoss(reduction="mean").to(DEVICE) 46 | 47 | 48 | # %% 49 | # Hessian-vector products 50 | # ----------------------- 51 | # 52 | # Setting up a linear operator for the Hessian is straightforward. 53 | 54 | data = [(X, y)] 55 | H = HessianLinearOperator(model, loss_function, params, data) 56 | 57 | # %% 58 | # 59 | # We can now multiply the Hessian onto a vector. 60 | 61 | D = H.shape[0] 62 | v = rand(D, device=DEVICE) 63 | 64 | Hv = H @ v 65 | 66 | # %% 67 | # 68 | # To verify the result, we compute the Hessian using ``functorch``, using a 69 | # utility function from ``curvlinops.examples``: 70 | 71 | H_mat = functorch_hessian(model, loss_function, params, data).detach() 72 | 73 | # %% 74 | # 75 | # Let's check that the multiplication onto ``v`` leads to the same result: 76 | 77 | Hv_functorch = H_mat @ v 78 | 79 | print("Comparing Hessian-vector product with functorch's Hessian-vector product.") 80 | assert allclose_report(Hv, Hv_functorch) 81 | 82 | 83 | # %% 84 | # Hessian-matrix products 85 | # ----------------------- 86 | # 87 | # We can also compute the Hessian's matrix representation with the linear 88 | # operator, simply by multiplying it onto the identity matrix. (Of course, this 89 | # only works if the Hessian is small enough.) 90 | H_mat_from_linop = H @ eye(D, device=DEVICE) 91 | 92 | # %% 93 | # 94 | # This should yield the same matrix as with :code:`functorch`. 95 | 96 | print("Comparing Hessian with functorch's Hessian.") 97 | assert allclose_report(H_mat, H_mat_from_linop) 98 | 99 | # %% 100 | # 101 | # Last, here's a visualization of the Hessian. 102 | 103 | plt.figure() 104 | plt.title("Hessian") 105 | plt.imshow(H_mat) 106 | plt.colorbar() 107 | 108 | # %% 109 | # Accepted vector/matrix formats 110 | # ------------------------------ 111 | # 112 | # Curvature matrices are usually defined w.r.t. parameters of a neural net. In PyTorch, 113 | # these parameters are split into multiple tensors (e.g. per layer). It is often more 114 | # convenient to think and work with vectors/matrices defined in this list format, rather 115 | # than in the flattened-and-concatenated parameter space. 116 | # 117 | # So far, we have only used vectors/matrices in the flattened-and-concatenated format. 118 | # To account for the often more convenient list format, all linear operators in 119 | # ``curvlinops`` can also handle vectors/matrices specified in tensor list format. 120 | # In that format, a matrix is a list of tensors, each of which has the same shape as 121 | # its corresponding parameter plus an additional trailing dimension for the matrix's 122 | # column dimension. 123 | # 124 | # ``curvlinops`` preserves the format when performing matrix multiplies: If the input 125 | # lived in the flattened-and-concatenated parameter space, the result will be as well. 126 | # If the input lived in the tensor list parameter space, the result will be a tensor 127 | # list as well. 128 | # 129 | # Let's make this concrete. 
First, set up the same matrix in flattened and list format: 130 | 131 | num_columns = 3 132 | 133 | print(f"Total network parameters: {D}") 134 | print(f"Parameter shapes: {[p.shape for p in params]}") 135 | print(f"Number of columns: {num_columns}") 136 | 137 | # Matrix in tensor list format 138 | M_list = [rand(*p.shape, num_columns, device=DEVICE) for p in params] 139 | print(f"[Tensor list format] Matrix: {[m.shape for m in M_list]}") 140 | 141 | # Matrix in flattened format (what we have been using before) 142 | M_flat = cat([m.flatten(end_dim=-2) for m in M_list]) 143 | print(f"[Flat format] Matrix: {M_flat.shape}") 144 | 145 | # %% 146 | # 147 | # Next, let's carry out the Hessian-matrix product and inspect the result's format: 148 | 149 | HM_list = H @ M_list 150 | print(f"[Tensor list format] Hessian-matrix product: {[hm.shape for hm in HM_list]}") 151 | 152 | HM_flat = H @ M_flat 153 | print(f"[Flat format] Hessian-matrix product: {HM_flat.shape}") 154 | 155 | # %% 156 | # 157 | # As expected, this produces the same result: 158 | 159 | HM_list_flattened = cat([hm.flatten(end_dim=-2) for hm in HM_list]) 160 | 161 | print("Comparing Hessian-matrix products across formats.") 162 | assert allclose_report(HM_flat, HM_list_flattened) 163 | 164 | # %% 165 | # 166 | # **Note:** Like in the early part of the tutorial, the column dimension is not 167 | # necessary if we just want to multiply the Hessian onto a single vector. 168 | # 169 | # GGN-vector products 170 | # ------------------- 171 | # 172 | # Setting up a linear operator for the Fisher/GGN is identical to the Hessian. 173 | 174 | GGN = GGNLinearOperator(model, loss_function, params, data) 175 | 176 | # %% 177 | # 178 | # This is one of ``curvlinops``'s design features: All linear operators share the same 179 | # interface, making it easy to switch between curvature matrices. 180 | # 181 | # Let's compute a GGN-vector product. 182 | 183 | D = H.shape[0] 184 | v = rand(D, device=DEVICE) 185 | 186 | GGNv = GGN @ v 187 | 188 | # %% 189 | # 190 | # To verify the result, we will use ``functorch`` to compute the GGN. For that, 191 | # we use that the GGN corresponds to the Hessian if we replace the neural 192 | # network by its linearization. This is implemented in a utility function of 193 | # :code:`curvlinops.examples`: 194 | 195 | GGN_mat = functorch_ggn(model, loss_function, params, data).detach() 196 | 197 | GGNv_functorch = GGN_mat @ v 198 | 199 | print("Comparing GGN-vector product with functorch's GGN-vector product.") 200 | assert allclose_report(GGNv, GGNv_functorch) 201 | 202 | # %% 203 | # GGN-matrix products 204 | # ------------------- 205 | # 206 | # We can also compute the GGN matrix representation with the linear operator, 207 | # simply by multiplying it onto the identity matrix. (Of course, this only 208 | # works if the GGN is small enough.) 209 | GGN_mat_from_linop = GGN @ eye(D, device=DEVICE) 210 | 211 | # %% 212 | # 213 | # This should yield the same matrix as with :code:`functorch`. 214 | 215 | print("Comparing GGN with functorch's GGN.") 216 | assert allclose_report(GGN_mat, GGN_mat_from_linop) 217 | 218 | # %% 219 | # 220 | # Last, here's a visualization of the GGN. 
221 | 222 | plt.figure() 223 | plt.title("GGN") 224 | plt.imshow(GGN_mat) 225 | plt.colorbar() 226 | 227 | # %% 228 | # Visual comparison: Hessian and GGN 229 | # ---------------------------------- 230 | # 231 | # To conclude, let's plot both the Hessian and GGN using the same limits 232 | 233 | min_value = min(GGN_mat.min(), H_mat.min()) 234 | max_value = max(GGN_mat.max(), H_mat.max()) 235 | 236 | fig, ax = plt.subplots(ncols=2) 237 | ax[0].set_title("Hessian") 238 | ax[0].imshow(H_mat, vmin=min_value, vmax=max_value) 239 | ax[1].set_title("GGN") 240 | ax[1].imshow(GGN_mat, vmin=min_value, vmax=max_value) 241 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/example_model_merging.py: -------------------------------------------------------------------------------- 1 | r"""Fisher-weighted Model Averaging 2 | =================================== 3 | 4 | In this example we implement Fisher-weighted model averaging, a technique 5 | described in this `NeurIPS 2022 paper `_. 6 | It requires Fisher-vector products, and multiplication with the inverse of a 7 | sum of Fisher matrices. The paper uses a diagonal approximation of the Fisher 8 | matrices. Instead, we will use the exact Fisher matrices and rely on 9 | matrix-free methods for applying the inverse. 10 | 11 | .. note:: 12 | In our setup, the Fisher equals the generalized Gauss-Newton matrix. 13 | Hence, we work with :py:class:`curvlinops.GGNLinearOperator`. 14 | 15 | **Description:** We are given a set of :math:`T` tasks (represented by data 16 | sets :math:`\mathcal{D}_t`), and train a model :math:`f_\mathbf{\theta}` on 17 | each task independently using the same criterion function. This yields 18 | :math:`T` parameters :math:`\mathbf{\theta}_1^\star, \dots, 19 | \mathbf{\theta}_T^\star`, and we would like to combine them into a single model 20 | :math:`f_\mathbf{\theta^\star}`. To do that, we use the Fisher information 21 | matrices :math:`\mathbf{F}_t` of each task (given by the data set 22 | :math:`\mathcal{D}_t` and the trained model parameters 23 | :math:`\mathbf{\theta}_t^\star`). The merged parameters are given by 24 | 25 | .. math:: 26 | \mathbf{\theta}^\star = \left(\lambda \mathbf{I} 27 | + \sum_{t=1}^T \mathbf{F}_t \right)^{-1} 28 | \left( \sum_{t=1}^T \mathbf{F}_t \mathbf{\theta}_t^\star\right)\,. 29 | 30 | This requires multiplying with the inverse of the sum of Fisher matrices 31 | (extended with a damping term). We will use 32 | :py:class:`curvlinops.CGInverseLinearOperator` for that. 33 | 34 | Let's start with the imports. 35 | """ 36 | 37 | from backpack.utils.convert_parameters import vector_to_parameter_list 38 | from torch import cuda, device, manual_seed, rand 39 | from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid 40 | from torch.nn.utils import parameters_to_vector 41 | from torch.optim import SGD 42 | from torch.utils.data import DataLoader, TensorDataset 43 | 44 | from curvlinops import CGInverseLinearOperator, GGNLinearOperator 45 | from curvlinops.examples import IdentityLinearOperator 46 | 47 | # make deterministic 48 | manual_seed(0) 49 | 50 | DEVICE = device("cuda" if cuda.is_available() else "cpu") 51 | 52 | # %% 53 | # 54 | # Setup 55 | # ----- 56 | # 57 | # First, we will create a bunch of synthetic regression tasks (i.e. data sets) 58 | # and an untrained model for each of them. 
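# All tasks share the same architecture, so their trained parameter vectors live in
# the same space and can later be combined into a single set of parameters.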
59 | 60 | T = 3 # number of tasks 61 | D_in = 7 # input dimension of each task 62 | D_hidden = 5 # hidden dimension of the architecture we will use 63 | D_out = 3 # output dimension of each task 64 | N = 20 # number of data per task 65 | batch_size = 7 66 | 67 | 68 | def make_architecture() -> Sequential: 69 | """Create a neural network. 70 | 71 | Returns: 72 | A neural network. 73 | """ 74 | return Sequential( 75 | Linear(D_in, D_hidden), 76 | ReLU(), 77 | Linear(D_hidden, D_hidden), 78 | Sigmoid(), 79 | Linear(D_hidden, D_out), 80 | ) 81 | 82 | 83 | def make_dataset() -> TensorDataset: 84 | """Create a synthetic regression data set. 85 | 86 | Returns: 87 | A synthetic regression data set. 88 | """ 89 | X, y = rand(N, D_in), rand(N, D_out) 90 | return TensorDataset(X, y) 91 | 92 | 93 | models = [make_architecture().to(DEVICE) for _ in range(T)] 94 | data_loaders = [DataLoader(make_dataset(), batch_size=batch_size) for _ in range(T)] 95 | loss_functions = [MSELoss(reduction="mean").to(DEVICE) for _ in range(T)] 96 | 97 | # %% 98 | # 99 | # Training 100 | # -------- 101 | # 102 | # Here, we train each model for a small number of epochs. 103 | 104 | num_epochs = 10 105 | log_epochs = [0, num_epochs - 1] 106 | 107 | for task_idx in range(T): 108 | model = models[task_idx] 109 | data_loader = data_loaders[task_idx] 110 | loss_function = loss_functions[task_idx] 111 | optimizer = SGD(model.parameters(), lr=1e-2) 112 | 113 | for epoch in range(num_epochs): 114 | for batch_idx, (X, y) in enumerate(data_loader): 115 | optimizer.zero_grad() 116 | X, y = X.to(DEVICE), y.to(DEVICE) 117 | loss = loss_function(model(X), y) 118 | loss.backward() 119 | optimizer.step() 120 | 121 | if epoch in log_epochs and batch_idx == 0: 122 | print(f"Task {task_idx} batch loss at epoch {epoch}: {loss.item():.3f}") 123 | 124 | # %% 125 | # 126 | # Linear operators 127 | # ---------------- 128 | # 129 | # We are now ready to set up the linear operators for the per-task Fishers: 130 | 131 | fishers = [ 132 | GGNLinearOperator( 133 | model, 134 | loss_function, 135 | [p for p in model.parameters() if p.requires_grad], 136 | data_loader, 137 | ) 138 | for model, loss_function, data_loader in zip(models, loss_functions, data_loaders) 139 | ] 140 | 141 | # %% 142 | # 143 | # Fisher-weighted Averaging 144 | # ------------------------- 145 | # 146 | # Next, we also need the trained parameters as vectors: 147 | 148 | # flatten and concatenate 149 | thetas = [ 150 | parameters_to_vector((p for p in model.parameters() if p.requires_grad)).detach() 151 | for model in models 152 | ] 153 | 154 | # %% 155 | # 156 | # We are ready to compute the sum of Fisher-weighted parameters (the right-hand 157 | # side in the above equation): 158 | 159 | rhs = sum(fisher @ theta for fisher, theta in zip(fishers, thetas)) 160 | 161 | # %% 162 | # 163 | # In the last step we need to normalize by multiplying with the inverse of the 164 | # summed Fishers. Let's first create the linear operator and add a damping 165 | # term: 166 | 167 | dim = fishers[0].shape[0] 168 | param_shapes = [p.shape for p in models[0].parameters() if p.requires_grad] 169 | identity = IdentityLinearOperator(param_shapes, DEVICE, rhs.dtype) 170 | damping = 1e-3 171 | 172 | fisher_sum = damping * identity 173 | 174 | for fisher in fishers: 175 | fisher_sum += fisher 176 | 177 | # %% 178 | # 179 | # Finally, we define a linear operator for the inverse of the damped Fisher sum: 180 | 181 | fisher_sum_inv = CGInverseLinearOperator(fisher_sum) 182 | 183 | # %% 184 | # 185 | # .. 
note:: 186 | # You may want to tweak the convergence criterion of CG using 187 | # :py:func:`curvlinops.CGInverseLinearOperator.set_cg_hyperparameters`. before 188 | # applying the matrix-vector product. 189 | 190 | fisher_weighted_params = fisher_sum_inv @ rhs 191 | 192 | # %% 193 | # 194 | # Comparison 195 | # ---------- 196 | # 197 | # Let's compare the performance of the Fisher-averaged parameters with a naive 198 | # average. 199 | 200 | average_params = sum(thetas) / len(thetas) 201 | 202 | # %% 203 | # 204 | # We initialize two neural networks with those parameters 205 | 206 | fisher_model = make_architecture() 207 | 208 | params = [p for p in fisher_model.parameters() if p.requires_grad] 209 | theta_fisher = vector_to_parameter_list(fisher_weighted_params, params) 210 | for theta, param in zip(theta_fisher, params): 211 | param.data = theta.to(param.device, param.dtype).data 212 | 213 | # same for the average-weighted parameters 214 | average_model = make_architecture() 215 | 216 | params = [p for p in average_model.parameters() if p.requires_grad] 217 | theta_average = vector_to_parameter_list(average_params, params) 218 | for theta, param in zip(theta_average, params): 219 | param.data = theta.to(param.device, param.dtype).data 220 | 221 | # %% 222 | # 223 | # and probe them on one batch of each task: 224 | 225 | for task_idx in range(T): 226 | data_loader = data_loaders[task_idx] 227 | loss_function = loss_functions[task_idx] 228 | 229 | X, y = next(iter(data_loader)) 230 | X, y = X.to(DEVICE), y.to(DEVICE) 231 | 232 | fisher_loss = loss_function(fisher_model(X), y) 233 | average_loss = loss_function(average_model(X), y) 234 | assert fisher_loss < average_loss 235 | 236 | print(f"Task {task_idx} batch loss with Fisher averaging: {fisher_loss.item():.3f}") 237 | print(f"Task {task_idx} batch loss with naive averaging: {average_loss.item():.3f}") 238 | 239 | # %% 240 | # 241 | # The Fisher-averaged parameters perform better than the naively averaged 242 | # parameters; at least on the training data. 243 | # 244 | # That's all for now. 245 | -------------------------------------------------------------------------------- /curvlinops/hessian.py: -------------------------------------------------------------------------------- 1 | """Contains a linear operator implementation of the Hessian.""" 2 | 3 | from collections.abc import MutableMapping 4 | from functools import cached_property, partial 5 | from typing import Callable, List, Optional, Tuple, Union 6 | 7 | from torch import Tensor, no_grad, vmap 8 | from torch.func import jacrev, jvp 9 | from torch.nn import Module, Parameter 10 | 11 | from curvlinops._torch_base import CurvatureLinearOperator 12 | from curvlinops.utils import make_functional_model_and_loss, split_list 13 | 14 | 15 | def make_batch_hessian_matrix_product( 16 | model_func: Module, 17 | loss_func: Module, 18 | params: Tuple[Parameter, ...], 19 | block_sizes: Optional[List[int]] = None, 20 | ) -> Callable[[Tensor, Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...]]: 21 | r"""Set up function that multiplies the mini-batch Hessian onto a matrix in list format. 22 | 23 | Args: 24 | model_func: The neural network :math:`f_{\mathbf{\theta}}`. 25 | loss_func: The loss function :math:`\ell`. 26 | params: A tuple of parameters w.r.t. which the Hessian is computed. 27 | All parameters must be part of ``model_func.parameters()``. 28 | block_sizes: Sizes of parameter blocks for block-diagonal approximation. 29 | If ``None``, the full Hessian is used. 
30 | 31 | Returns: 32 | A function that takes inputs ``X``, ``y``, and a matrix ``M`` in list 33 | format, and returns the mini-batch Hessian applied to ``M`` in list format. 34 | """ 35 | # Determine block structure 36 | block_sizes = [len(params)] if block_sizes is None else block_sizes 37 | 38 | # Create block-specific functional calls: *block_params, X -> prediction 39 | block_params = split_list(list(params), block_sizes) 40 | block_functionals = [] 41 | 42 | for block in block_params: 43 | # criterion functional c is the same for all blocks 44 | f_block, c = make_functional_model_and_loss(model_func, loss_func, tuple(block)) 45 | block_functionals.append(f_block) 46 | 47 | @no_grad() 48 | def hessian_vector_product( 49 | X: Tensor, y: Tensor, *v: Tuple[Tensor, ...] 50 | ) -> Tuple[Tensor, ...]: 51 | """Multiply the mini-batch Hessian on a vector in list format. 52 | 53 | Args: 54 | X: Input to the DNN. 55 | y: Ground truth. 56 | *v: Vector to be multiplied with in tensor list format. 57 | 58 | Returns: 59 | Result of Hessian multiplication in list format. Has the same shape as 60 | ``v``, i.e. each tensor in the list has the shape of a parameter. 61 | """ 62 | # Split input vectors by blocks 63 | v_blocks = split_list(list(v), block_sizes) 64 | 65 | # Set up loss functions for each block 66 | block_grad_fns = [] 67 | 68 | def loss_fn( 69 | f: Callable[[Tuple[Tensor, ...], Union[Tensor, MutableMapping]], Tensor], 70 | *params: Tuple[Tensor, ...], 71 | ) -> Tensor: 72 | """Compute the mini-batch loss given the neural net and its parameters. 73 | 74 | Args: 75 | f: Functional model with signature (*params, X) -> prediction 76 | *params: Parameters for the functional model. 77 | 78 | Returns: 79 | Mini-batch loss. 80 | """ 81 | return c(f(*params, X), y) 82 | 83 | for f_block, ps in zip(block_functionals, block_params): 84 | # Define the loss function composition for this block 85 | block_loss_fn = partial(loss_fn, f_block) 86 | block_grad_fn = jacrev(block_loss_fn, argnums=tuple(range(len(ps)))) 87 | block_grad_fns.append(block_grad_fn) 88 | 89 | # Compute the HVPs per block and concatenate the results 90 | hvps = [] 91 | for grad_fn, ps, vs in zip(block_grad_fns, block_params, v_blocks): 92 | _, hvp_block = jvp(grad_fn, tuple(ps), tuple(vs)) 93 | hvps.extend(hvp_block) 94 | 95 | return tuple(hvps) 96 | 97 | # Parallelize over vectors to multiply onto a matrix in list format 98 | list_format_vmap_dims = tuple(p.ndim for p in params) # last axis 99 | return vmap( 100 | hessian_vector_product, 101 | # No vmap in X, y, last-axis vmap over vector in list format 102 | in_dims=(None, None, *list_format_vmap_dims), 103 | # Vmapped output axis is last 104 | out_dims=list_format_vmap_dims, 105 | # We want each vector to be multiplied with the same mini-batch Hessian 106 | randomness="same", 107 | ) 108 | 109 | 110 | class HessianLinearOperator(CurvatureLinearOperator): 111 | r"""Linear operator for the Hessian of an empirical risk. 112 | 113 | Consider the empirical risk 114 | 115 | .. math:: 116 | \mathcal{L}(\mathbf{\theta}) 117 | = 118 | c \sum_{n=1}^{N} 119 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n) 120 | 121 | with :math:`c = \frac{1}{N}` for ``reduction='mean'`` and :math:`c=1` for 122 | ``reduction='sum'``. The Hessian matrix is 123 | 124 | .. math:: 125 | \nabla^2_{\mathbf{\theta}} \mathcal{L} 126 | = 127 | c \sum_{n=1}^{N} 128 | \nabla_{\mathbf{\theta}}^2 129 | \ell(f_{\mathbf{\theta}}(\mathbf{x}_n), \mathbf{y}_n)\,. 
130 | 131 | Example: 132 | >>> from torch import rand, eye, allclose, kron, manual_seed 133 | >>> from torch.nn import Linear, MSELoss 134 | >>> from curvlinops import HessianLinearOperator 135 | >>> 136 | >>> # Create a simple linear model without bias 137 | >>> _ = manual_seed(0) # make deterministic 138 | >>> D_in, D_out = 4, 2 139 | >>> num_data, num_batches = 10, 3 140 | >>> model = Linear(D_in, D_out, bias=False) 141 | >>> params = list(model.parameters()) 142 | >>> loss_func = MSELoss(reduction='sum') 143 | >>> 144 | >>> # Generate synthetic dataset and chunk into batches 145 | >>> X, y = rand(num_data, D_in), rand(num_data, D_out) 146 | >>> data = list(zip(X.split(num_batches), y.split(num_batches))) 147 | >>> 148 | >>> # Create Hessian linear operator 149 | >>> H_op = HessianLinearOperator(model, loss_func, params, data) 150 | >>> 151 | >>> # Compare with the known Hessian matrix 2 I ⊗ Xᵀ X 152 | >>> H_mat = 2 * kron(eye(D_out), X.T @ X) 153 | >>> P = sum(p.numel() for p in params) 154 | >>> v = rand(P) # generate a random vector 155 | >>> (H_mat @ v).allclose(H_op @ v) 156 | True 157 | 158 | Attributes: 159 | SUPPORTS_BLOCKS: Whether the linear operator supports block operations. 160 | Default is ``True``. 161 | SELF_ADJOINT: Whether the linear operator is self-adjoint (``True`` for 162 | Hessians). 163 | """ 164 | 165 | SELF_ADJOINT: bool = True 166 | SUPPORTS_BLOCKS: bool = True 167 | 168 | @cached_property 169 | def _mp( 170 | self, 171 | ) -> Callable[ 172 | [Union[Tensor, MutableMapping], Tensor, Tuple[Tensor, ...]], Tuple[Tensor, ...] 173 | ]: 174 | """Lazy initialization of batch-Hessian matrix product function. 175 | 176 | Returns: 177 | Function that computes mini-batch Hessian-vector products, given inputs 178 | ``X``, labels ``y``, and the entries ``v1, v2, ...`` of the vector in list 179 | format. Produces a list of tensors with the same shape as the input vector 180 | that represents the result of the batch-Hessian multiplication. 181 | """ 182 | return make_batch_hessian_matrix_product( 183 | self._model_func, self._loss_func, tuple(self._params), self._block_sizes 184 | ) 185 | 186 | def _matmat_batch( 187 | self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor] 188 | ) -> List[Tensor]: 189 | """Apply the mini-batch Hessian to a matrix. 190 | 191 | Args: 192 | X: Input to the DNN. 193 | y: Ground truth. 194 | M: Matrix to be multiplied with in tensor list format. 195 | Tensors have same shape as trainable model parameters, and an 196 | additional trailing axis for the matrix columns. 197 | 198 | Returns: 199 | Result of Hessian multiplication in list format. Has the same shape as 200 | ``M``, i.e. each tensor in the list has the shape of a parameter and a 201 | trailing dimension of matrix columns. 202 | """ 203 | return list(self._mp(X, y, *M)) 204 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/example_submatrices.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Sub-matrices of linear operators 3 | ================================ 4 | 5 | This tutorial explains how to create linear operators that correspond to a sub-matrix 6 | of another linear operator. 
7 | 8 | Specifically, given the linear operator :code:`A`, we are 9 | interested in constructing the linear operator that corresponds to its sub-matrix 10 | :code:`A[row_idxs, :][:, col_idxs]`, where :code:`row_idxs` contains the sub-matrix's 11 | row indices, and :code:`col_idxs` contains the sub-matrix's column indices. 12 | 13 | First, the imports. 14 | """ 15 | 16 | from time import time 17 | from typing import List 18 | 19 | from torch import Tensor, cuda, device, eye, manual_seed, rand 20 | from torch.nn import Linear, MSELoss, ReLU, Sequential, Sigmoid 21 | 22 | from curvlinops import HessianLinearOperator 23 | from curvlinops.examples.functorch import functorch_hessian 24 | from curvlinops.submatrix import SubmatrixLinearOperator 25 | from curvlinops.utils import allclose_report 26 | 27 | # make deterministic 28 | manual_seed(0) 29 | 30 | # %% 31 | # 32 | # Setup 33 | # ----- 34 | # 35 | # Let's create some toy data, a small MLP, and use mean-squared error as loss function. 36 | 37 | N = 4 38 | D_in = 7 39 | D_hidden = 5 40 | D_out = 3 41 | 42 | DEVICE = device("cuda" if cuda.is_available() else "cpu") 43 | 44 | X1, y1 = rand(N, D_in, device=DEVICE), rand(N, D_out, device=DEVICE) 45 | X2, y2 = rand(N, D_in, device=DEVICE), rand(N, D_out, device=DEVICE) 46 | data = [(X1, y1), (X2, y2)] 47 | 48 | model = Sequential( 49 | Linear(D_in, D_hidden), 50 | ReLU(), 51 | Linear(D_hidden, D_hidden), 52 | Sigmoid(), 53 | Linear(D_hidden, D_out), 54 | ).to(DEVICE) 55 | params = [p for p in model.parameters() if p.requires_grad] 56 | 57 | loss_function = MSELoss(reduction="mean").to(DEVICE) 58 | 59 | # %% 60 | # 61 | # We will investigate the Hessian. To make sure our results are correct, let's keep 62 | # a Hessian matrix computed via :mod:`functorch` around. 63 | 64 | H_functorch = functorch_hessian(model, loss_function, params, data) 65 | 66 | # %% 67 | # 68 | # Here is the corresponding linear operator and a quick check that builds up 69 | # its matrix representation through multiplication with the identity matrix, 70 | # followed by comparison to the Hessian matrix computed via :mod:`functorch`. 71 | 72 | H = HessianLinearOperator(model, loss_function, params, data) 73 | 74 | num_params = sum(p.numel() for p in params) 75 | identity = eye(num_params, device=DEVICE) 76 | assert allclose_report(H_functorch, H @ identity) 77 | 78 | # %% 79 | # 80 | # Diagonal blocks 81 | # --------------- 82 | # 83 | # The Hessian consists of blocks :code:`(i, j)` that contain the second-order 84 | # derivatives of the loss w.r.t. the parameters in :code:`(params[i], params[j])`. 85 | # 86 | # Let's define a function to extract these blocks from the Hessian: 87 | 88 | 89 | def extract_block(mat: Tensor, params: List[Tensor], i: int, j: int) -> Tensor: 90 | """Extract the Hessian block from parameters ``i`` and ``j``. 91 | 92 | Args: 93 | mat: The matrix with block structure. 94 | params: The parameters defining the blocks. 95 | i: Row index of the block to be extracted. 96 | j: Column index of the block to be extracted. 97 | 98 | Returns: 99 | Block ``(i, j)``. Has shape ``[params[i].numel(), params[j].numel()]``. 100 | """ 101 | param_dims = [p.numel() for p in params] 102 | row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1]) 103 | col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1]) 104 | 105 | return mat[row_start:row_end, :][:, col_start:col_end] 106 | 107 | 108 | # %% 109 | # 110 | # As an example, let's extract the block that corresponds to the Hessian w.r.t. 
111 | # the first layer's weights in our model. 112 | 113 | i, j = 0, 0 114 | H_param0_functorch = extract_block(H_functorch, params, i, j) 115 | 116 | # %% 117 | # 118 | # We can build a linear operator for this sub-Hessian by only providing the 119 | # first layer's weight as parameter: 120 | 121 | H_param0 = HessianLinearOperator(model, loss_function, [params[i]], data) 122 | 123 | # %% 124 | # 125 | # Like this we can get blocks from the diagonal. 126 | # 127 | # Let's check that this linear operator works as expected by multiplying it 128 | # onto the identity matrix and comparing the result to the block we extracted 129 | # from our ground truth: 130 | 131 | assert allclose_report( 132 | H_param0_functorch, H_param0 @ eye(params[i].numel(), device=DEVICE) 133 | ) 134 | 135 | # %% 136 | # 137 | # Now you might be wondering if we can also build up linear operators for 138 | # off-diagonal blocks. These blocks contain mixed second-order derivatives and 139 | # are not Hessians anymore. For instance, such a block is rectangular in 140 | # general, and thus non-symmetric. Since we are not asking for a Hessian 141 | # anymore, we cannot use the interface of :class:`HessianLinearOperator`. 142 | # 143 | # Luckily, there is a different way to achieve this. 144 | # 145 | # Off-diagonal blocks 146 | # ------------------- 147 | # 148 | # As an example, let's try to extract the Hessian block from the first and 149 | # second parameters in our network (i.e. the weights and biases in the first 150 | # layer). For that we need to slice the Hessian differently along its rows and 151 | # columns. We can use the :class:`curvlinops.SubmatrixLinearOperator` class for 152 | # that: 153 | 154 | param_dims = [p.numel() for p in params] 155 | i, j = 0, 1 156 | row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1]) 157 | col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1]) 158 | 159 | row_idxs = list(range(row_start, row_end)) # keep the following row indices 160 | col_idxs = list(range(col_start, col_end)) # keep the following column indices 161 | 162 | H_param0_param1 = SubmatrixLinearOperator(H, row_idxs, col_idxs) 163 | 164 | # %% 165 | # 166 | # As the following test shows, this linear operator indeed represents the 167 | # desired rectangular Hessian block: 168 | 169 | H_param0_param1_functorch = extract_block(H_functorch, params, i, j) 170 | 171 | assert allclose_report( 172 | H_param0_param1_functorch, 173 | H_param0_param1 @ eye(param_dims[j], device=DEVICE), 174 | ) 175 | 176 | # %% 177 | # 178 | # Arbitrary sub-matrices 179 | # ---------------------- 180 | # 181 | # So far, we were constrained to blocks spanned by parameter tensors rather 182 | # than arbitrary elements. As the name :class:`SubmatrixLinearOperator` 183 | # suggests, we can use it to create arbitrary sub-matrices. 184 | # 185 | # As an example, let's say we want to keep rows :code:`[0, 13, 42]` of the 186 | # Hessian, and columns :code:`[1, 2, 3]`. This works as follows: 187 | 188 | row_idxs = [0, 13, 42] # keep the following row indices 189 | col_idxs = [1, 2, 3] # keep the following column indices 190 | 191 | H_sub = SubmatrixLinearOperator(H, row_idxs, col_idxs) 192 | H_sub_functorch = H_functorch[row_idxs, :][:, col_idxs] 193 | 194 | # %% 195 | # 196 | # Quick check to see if it worked: 197 | 198 | assert allclose_report(H_sub_functorch, H_sub @ eye(len(col_idxs), device=DEVICE)) 199 | 200 | # %% 201 | # 202 | # Looks good.
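# %%
#
# As a small additional sketch along the same lines, we can pass the *same* index
# list for rows and columns to obtain a principal sub-matrix. Since the Hessian is
# symmetric, this block is symmetric as well, and it matches the corresponding
# slice of the functorch Hessian:

idxs = [0, 13, 42]  # use identical row and column indices
H_principal = SubmatrixLinearOperator(H, idxs, idxs)
# build its dense representation by multiplying onto the identity, as before
H_principal_mat = H_principal @ eye(len(idxs), device=DEVICE)

assert allclose_report(H_principal_mat, H_principal_mat.T)
assert allclose_report(H_functorch[idxs, :][:, idxs], H_principal_mat)

# %%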
203 | # 204 | # Performance remarks 205 | # ---------------------- 206 | # 207 | # By the way, using this interface, we could have also constructed the first 208 | # parameter's Hessian as follows: 209 | 210 | i, j = 0, 0 211 | row_start, row_end = sum(param_dims[:i]), sum(param_dims[: i + 1]) 212 | col_start, col_end = sum(param_dims[:j]), sum(param_dims[: j + 1]) 213 | 214 | row_idxs = list(range(row_start, row_end)) 215 | col_idxs = list(range(col_start, col_end)) 216 | 217 | H_param0_alternative = SubmatrixLinearOperator(H, row_idxs, col_idxs) 218 | 219 | assert allclose_report( 220 | H_param0_functorch, H_param0_alternative @ eye(param_dims[0], device=DEVICE) 221 | ) 222 | 223 | # %% 224 | # 225 | # In general though, it is a good idea to first reduce the linear operator's 226 | # size as much as possible (in our case, by restricting the parameters to the 227 | # necessary ones using the :code:`params` argument in 228 | # :class:`HessianLinearOperator`) and apply slicing afterwards to save 229 | # computations. 230 | # 231 | # In our example, the matrix-vector product of :code:`H_param0` should 232 | # therefore be faster than that of :code:`H_param0_alternative`: 233 | 234 | x = rand(param_dims[0], device=DEVICE) 235 | 236 | # less computations 237 | start = time() 238 | _ = H_param0 @ x 239 | end = time() 240 | print(f"H_param0.matvec: {end - start:.2e} s") 241 | 242 | # more computations 243 | start = time() 244 | _ = H_param0_alternative @ x 245 | end = time() 246 | print(f"H_param0_alternative.matvec: {end - start:.2e} s") 247 | 248 | # %% 249 | # 250 | # That's all for now. 251 | -------------------------------------------------------------------------------- /docs/examples/basic_usage/example_eigenvalues.py: -------------------------------------------------------------------------------- 1 | r"""Eigenvalues 2 | =============== 3 | 4 | This example demonstrates how to compute a subset of eigenvalues of a linear 5 | operator, using :func:`scipy.sparse.linalg.eigsh`. Concretely, we will compute 6 | leading eigenvalues of the Hessian. 7 | 8 | As always, imports go first. 9 | """ 10 | 11 | from contextlib import redirect_stderr 12 | from io import StringIO 13 | from typing import List, Tuple 14 | 15 | import numpy 16 | import scipy 17 | import torch 18 | from torch import nn 19 | 20 | from curvlinops import HessianLinearOperator 21 | from curvlinops.examples.functorch import functorch_hessian 22 | from curvlinops.utils import allclose_report 23 | 24 | # make deterministic 25 | torch.manual_seed(0) 26 | numpy.random.seed(0) 27 | 28 | # %% 29 | # 30 | # Setup 31 | # ----- 32 | # 33 | # We will use synthetic data, consisting of two mini-batches, a small MLP, and 34 | # mean-squared error as loss function. 35 | 36 | N = 20 37 | D_in = 7 38 | D_hidden = 5 39 | D_out = 3 40 | 41 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | 43 | X1, y1 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE) 44 | X2, y2 = torch.rand(N, D_in).to(DEVICE), torch.rand(N, D_out).to(DEVICE) 45 | 46 | model = nn.Sequential( 47 | nn.Linear(D_in, D_hidden), 48 | nn.ReLU(), 49 | nn.Linear(D_hidden, D_hidden), 50 | nn.Sigmoid(), 51 | nn.Linear(D_hidden, D_out), 52 | ).to(DEVICE) 53 | params = [p for p in model.parameters() if p.requires_grad] 54 | 55 | loss_function = nn.MSELoss(reduction="mean").to(DEVICE) 56 | 57 | 58 | # %% 59 | # 60 | # Linear operator 61 | # ------------------ 62 | # 63 | # We are ready to setup the linear operator. 
In this example, we will use the Hessian. 64 | 65 | data = [(X1, y1), (X2, y2)] 66 | H = HessianLinearOperator(model, loss_function, params, data).to_scipy() 67 | 68 | # %% 69 | # 70 | # Leading eigenvalues 71 | # ------------------- 72 | # 73 | # Through :func:`scipy.sparse.linalg.eigsh`, we can obtain the leading 74 | # :math:`k=3` eigenvalues. 75 | 76 | k = 3 77 | which = "LM" # largest magnitude 78 | top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which) 79 | 80 | print(f"Leading {k} Hessian eigenvalues: {top_k_evals}") 81 | 82 | # %% 83 | # 84 | # Verifying results 85 | # ----------------- 86 | # 87 | # To double-check this result, let's compute the Hessian with 88 | # :code:`functorch`, compute all its eigenvalues with 89 | # :func:`torch.linalg.eigh`, then extract the top :math:`k`. 90 | 91 | H_functorch = functorch_hessian(model, loss_function, params, data).detach() 92 | evals_functorch, _ = torch.linalg.eigh(H_functorch) 93 | top_k_evals_functorch = evals_functorch[-k:] 94 | 95 | print(f"Leading {k} Hessian eigenvalues (functorch): {top_k_evals_functorch}") 96 | 97 | # %% 98 | # 99 | # Both results should match. 100 | 101 | print(f"Comparing leading {k} Hessian eigenvalues (linear operator vs. functorch).") 102 | assert allclose_report(top_k_evals, top_k_evals_functorch.double()) 103 | 104 | # %% 105 | # 106 | # :func:`scipy.sparse.linalg.eigsh` can also compute other subsets of 107 | # eigenvalues, and also their associated eigenvectors. Check out its 108 | # documentation for more! 109 | 110 | 111 | # %% 112 | # 113 | # Power iteration versus ``eigsh`` 114 | # -------------------------------- 115 | # 116 | # Here, we compare the query efficiency of :func:`scipy.sparse.linalg.eigsh` with the 117 | # `power iteration `_ method, a simple 118 | # method to compute the leading eigenvalues (in terms of magnitude). We re-use the 119 | # implementation from the `PyHessian library `_ 120 | # and adapt it to work with NumPy arrays rather than PyTorch tensors: 121 | 122 | 123 | def power_method( 124 | A: scipy.sparse.linalg.LinearOperator, 125 | max_iterations: int = 100, 126 | tol: float = 1e-3, 127 | k: int = 1, 128 | ) -> Tuple[numpy.ndarray, numpy.ndarray]: 129 | """Compute the top-k eigenpairs of a linear operator using power iteration. 130 | 131 | Code modified from PyHessian, see 132 | https://github.com/amirgholami/PyHessian/blob/72e5f0a0d06142387fccdab2226b4c6bae088202/pyhessian/hessian.py#L111-L156 133 | 134 | Args: 135 | A: Linear operator of dimension ``D`` whose top eigenpairs will be computed. 136 | max_iterations: Maximum number of iterations. Defaults to ``100``. 137 | tol: Relative tolerance between two consecutive iterations that has to be 138 | reached for convergence. Defaults to ``1e-3``. 139 | k: Number of eigenpairs to compute. Defaults to ``1``. 140 | 141 | Returns: 142 | The eigenvalues as array of shape ``[k]`` in ascending order, and their 143 | corresponding eigenvectors as array of shape ``[k, D]``.
144 | """ 145 | eigenvalues = [] 146 | eigenvectors = [] 147 | 148 | def normalize(v: numpy.ndarray) -> numpy.ndarray: 149 | return v / numpy.linalg.norm(v) 150 | 151 | def orthonormalize(v: numpy.ndarray, basis: List[numpy.ndarray]) -> numpy.ndarray: 152 | for basis_vector in basis: 153 | v -= numpy.dot(v, basis_vector) * basis_vector 154 | return normalize(v) 155 | 156 | computed_dim = 0 157 | while computed_dim < k: 158 | eigenvalue = None 159 | v = normalize(numpy.random.randn(A.shape[0])) 160 | 161 | for _ in range(max_iterations): 162 | v = orthonormalize(v, eigenvectors) 163 | Av = A @ v 164 | 165 | tmp_eigenvalue = v.dot(Av) 166 | v = normalize(Av) 167 | 168 | if eigenvalue is None: 169 | eigenvalue = tmp_eigenvalue 170 | elif abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol: 171 | break 172 | else: 173 | eigenvalue = tmp_eigenvalue 174 | 175 | eigenvalues.append(eigenvalue) 176 | eigenvectors.append(v) 177 | computed_dim += 1 178 | 179 | # sort in ascending order and convert into arrays 180 | eigenvalues = numpy.array(eigenvalues[::-1]) 181 | eigenvectors = numpy.array(eigenvectors[::-1]) 182 | 183 | return eigenvalues, eigenvectors 184 | 185 | 186 | # %% 187 | # 188 | # Let's compute the top-3 eigenvalues via power iteration and verify they roughly match. 189 | # Note that we are using a smaller :code:`tol` value than the PyHessian default value 190 | # here to get better convergence, and we have to use relatively large tolerances for the 191 | # comparison (which we didn't do when comparing :code:`eigsh` with :code:`eigh`). 192 | 193 | top_k_evals_power, _ = power_method(H, tol=1e-4, k=k) 194 | print(f"Comparing leading {k} Hessian eigenvalues (eigsh vs. power).") 195 | assert allclose_report( 196 | top_k_evals_functorch.double(), top_k_evals_power, rtol=2e-2, atol=1e-6 197 | ) 198 | 199 | # %% 200 | # 201 | # This indicates that the power method achieves poorer accuracy than :code:`eigsh`. But 202 | # does it therefore require fewer matrix-vector products? To answer this, let's turn on 203 | # the linear operator's progress bar, which allows us to count the number of 204 | # matrix-vector products invoked by both eigen-solvers: 205 | 206 | H = HessianLinearOperator( 207 | model, loss_function, params, data, progressbar=True 208 | ).to_scipy() 209 | 210 | # determine number of matrix-vector products used by `eigsh` 211 | with StringIO() as buf, redirect_stderr(buf): 212 | top_k_evals, _ = scipy.sparse.linalg.eigsh(H, k=k, which=which) 213 | # The tqdm progressbar will print "matmat" for each batch in a matrix-vector 214 | # product. Therefore, we need to divide by the number of batches 215 | queries_eigsh = buf.getvalue().count("matmat") // len(data) 216 | print(f"eigsh used {queries_eigsh} matrix-vector products.") 217 | 218 | # determine number of matrix-vector products used by power iteration 219 | with StringIO() as buf, redirect_stderr(buf): 220 | top_k_evals_power, _ = power_method(H, k=k, tol=1e-4) 221 | # The tqdm progressbar will print "matmat" for each batch in a matrix-vector 222 | # product. Therefore, we need to divide by the number of batches 223 | queries_power = buf.getvalue().count("matmat") // len(data) 224 | print(f"Power iteration used {queries_power} matrix-vector products.") 225 | 226 | assert queries_power > queries_eigsh 227 | 228 | # %% 229 | # 230 | # Sadly, the power iteration also does not offer computational benefits, consuming 231 | # more matrix-vector products than :code:`eigsh`. 
While it is elegant and simple, 232 | # it cannot compete with :code:`eigsh`, at least in the comparison provided here. 233 | # 234 | # Therefore, we recommend using :code:`eigsh` for computing eigenvalues. This method 235 | # becomes accessible because :code:`curvlinops` interfaces with SciPy's linear 236 | # operators. 237 | --------------------------------------------------------------------------------