├── .conda_env.yml
├── .coveragerc
├── .github
└── workflows
│   ├── lint.yaml
│   ├── python-publish.yml
│   └── test.yaml
├── .gitignore
├── .readthedocs.yml
├── LICENSE.txt
├── README-dev.md
├── README.md
├── backpack
├── __init__.py
├── context.py
├── core
│   ├── __init__.py
│   └── derivatives
│   │   ├── __init__.py
│   │   ├── adaptive_avg_pool_nd.py
│   │   ├── avgpool1d.py
│   │   ├── avgpool2d.py
│   │   ├── avgpool3d.py
│   │   ├── avgpoolnd.py
│   │   ├── basederivatives.py
│   │   ├── batchnorm_nd.py
│   │   ├── bcewithlogitsloss.py
│   │   ├── conv1d.py
│   │   ├── conv2d.py
│   │   ├── conv3d.py
│   │   ├── conv_transpose1d.py
│   │   ├── conv_transpose2d.py
│   │   ├── conv_transpose3d.py
│   │   ├── conv_transposend.py
│   │   ├── convnd.py
│   │   ├── crossentropyloss.py
│   │   ├── dropout.py
│   │   ├── elementwise.py
│   │   ├── elu.py
│   │   ├── embedding.py
│   │   ├── flatten.py
│   │   ├── leakyrelu.py
│   │   ├── linear.py
│   │   ├── logsigmoid.py
│   │   ├── lstm.py
│   │   ├── maxpool1d.py
│   │   ├── maxpool2d.py
│   │   ├── maxpool3d.py
│   │   ├── maxpoolnd.py
│   │   ├── mseloss.py
│   │   ├── nll_base.py
│   │   ├── pad.py
│   │   ├── permute.py
│   │   ├── relu.py
│   │   ├── rnn.py
│   │   ├── scale_module.py
│   │   ├── selu.py
│   │   ├── shape_check.py
│   │   ├── sigmoid.py
│   │   ├── slicing.py
│   │   ├── sum_module.py
│   │   ├── tanh.py
│   │   └── zeropad2d.py
├── custom_module
│   ├── __init__.py
│   ├── branching.py
│   ├── graph_utils.py
│   ├── pad.py
│   ├── permute.py
│   ├── reduce_tuple.py
│   ├── scale_module.py
│   └── slicing.py
├── extensions
│   ├── __init__.py
│   ├── backprop_extension.py
│   ├── curvature.py
│   ├── curvmatprod
│   │   ├── __init__.py
│   │   ├── ggnmp
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── batchnorm1d.py
│   │   │   ├── conv2d.py
│   │   │   ├── dropout.py
│   │   │   ├── flatten.py
│   │   │   ├── ggnmpbase.py
│   │   │   ├── linear.py
│   │   │   ├── losses.py
│   │   │   ├── padding.py
│   │   │   └── pooling.py
│   │   ├── hmp
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── batchnorm1d.py
│   │   │   ├── conv2d.py
│   │   │   ├── dropout.py
│   │   │   ├── flatten.py
│   │   │   ├── hmpbase.py
│   │   │   ├── linear.py
│   │   │   ├── losses.py
│   │   │   ├── padding.py
│   │   │   └── pooling.py
│   │   └── pchmp
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── conv2d.py
│   │   │   ├── dropout.py
│   │   │   ├── flatten.py
│   │   │   ├── linear.py
│   │   │   ├── losses.py
│   │   │   ├── padding.py
│   │   │   ├── pchmpbase.py
│   │   │   └── pooling.py
│   ├── firstorder
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── batch_grad
│   │   │   ├── __init__.py
│   │   │   ├── batch_grad_base.py
│   │   │   ├── batchnorm_nd.py
│   │   │   ├── conv1d.py
│   │   │   ├── conv2d.py
│   │   │   ├── conv3d.py
│   │   │   ├── conv_transpose1d.py
│   │   │   ├── conv_transpose2d.py
│   │   │   ├── conv_transpose3d.py
│   │   │   ├── embedding.py
│   │   │   ├── linear.py
│   │   │   └── rnn.py
│   │   ├── batch_l2_grad
│   │   │   ├── __init__.py
│   │   │   ├── batch_l2_base.py
│   │   │   ├── batchnorm_nd.py
│   │   │   ├── convnd.py
│   │   │   ├── convtransposend.py
│   │   │   ├── embedding.py
│   │   │   ├── linear.py
│   │   │   └── rnn.py
│   │   ├── gradient
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── batchnorm_nd.py
│   │   │   ├── conv1d.py
│   │   │   ├── conv2d.py
│   │   │   ├── conv3d.py
│   │   │   ├── convtranspose1d.py
│   │   │   ├── convtranspose2d.py
│   │   │   ├── convtranspose3d.py
│   │   │   ├── embedding.py
│   │   │   ├── linear.py
│   │   │   └── rnn.py
│   │   ├── sum_grad_squared
│   │   │   ├── __init__.py
│   │   │   ├── batchnorm_nd.py
│   │   │   ├── conv1d.py
│   │   │   ├── conv2d.py
│   │   │   ├── conv3d.py
│   │   │   ├── convtranspose1d.py
│   │   │   ├── convtranspose2d.py
│   │   │   ├── convtranspose3d.py
│   │   │   ├── embedding.py
│   │   │   ├── linear.py
│   │   │   ├── rnn.py
│   │   │   └── sgs_base.py
│   │   └── variance
│   │   │   ├── __init__.py
│   │   │   ├── batchnorm_nd.py
│   │   │   ├── conv1d.py
│   │   │   ├── conv2d.py
│   │   │   ├── conv3d.py
│   │   │   ├── convtranspose1d.py
│   │   │   ├── convtranspose2d.py
│   │   │   ├── convtranspose3d.py
│   │   │   ├── embedding.py
│   │   │   ├── linear.py
│   │   │   ├── rnn.py
│   │   │   └── variance_base.py
│   ├── mat_to_mat_jac_base.py
│   ├── module_extension.py
│   ├── saved_quantities.py
│   └── secondorder
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── diag_ggn
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── adaptive_avg_pool_nd.py
│   │   ├── batchnorm_nd.py
│   │   ├── conv1d.py
│   │   ├── conv2d.py
│   │   ├── conv3d.py
│   │   ├── convnd.py
│   │   ├── convtranspose1d.py
│   │   ├── convtranspose2d.py
│   │   ├── convtranspose3d.py
│   │   ├── convtransposend.py
│   │   ├── custom_module.py
│   │   ├── diag_ggn_base.py
│   │   ├── dropout.py
│   │   ├── embedding.py
│   │   ├── flatten.py
│   │   ├── linear.py
│   │   ├── losses.py
│   │   ├── pad.py
│   │   ├── padding.py
│   │   ├── permute.py
│   │   ├── pooling.py
│   │   ├── rnn.py
│   │   └── slicing.py
│   │   ├── diag_hessian
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── adaptive_avg_pool_nd.py
│   │   ├── conv1d.py
│   │   ├── conv2d.py
│   │   ├── conv3d.py
│   │   ├── convnd.py
│   │   ├── convtranspose1d.py
│   │   ├── convtranspose2d.py
│   │   ├── convtranspose3d.py
│   │   ├── convtransposend.py
│   │   ├── diag_h_base.py
│   │   ├── dropout.py
│   │   ├── flatten.py
│   │   ├── linear.py
│   │   ├── losses.py
│   │   ├── pad.py
│   │   ├── padding.py
│   │   ├── pooling.py
│   │   └── slicing.py
│   │   ├── hbp
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── conv1d.py
│   │   ├── conv2d.py
│   │   ├── conv3d.py
│   │   ├── conv_transpose1d.py
│   │   ├── conv_transpose2d.py
│   │   ├── conv_transpose3d.py
│   │   ├── conv_transposend.py
│   │   ├── convnd.py
│   │   ├── custom_module.py
│   │   ├── dropout.py
│   │   ├── flatten.py
│   │   ├── hbp_options.py
│   │   ├── hbpbase.py
│   │   ├── linear.py
│   │   ├── losses.py
│   │   ├── padding.py
│   │   └── pooling.py
│   │   └── sqrt_ggn
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── base.py
│   │   ├── batchnorm_nd.py
│   │   ├── convnd.py
│   │   ├── convtransposend.py
│   │   ├── custom_module.py
│   │   ├── dropout.py
│   │   ├── embedding.py
│   │   ├── flatten.py
│   │   ├── linear.py
│   │   ├── losses.py
│   │   ├── pad.py
│   │   ├── padding.py
│   │   ├── pooling.py
│   │   └── slicing.py
├── hessianfree
│   ├── __init__.py
│   ├── ggnvp.py
│   ├── hvp.py
│   ├── lop.py
│   └── rop.py
└── utils
│   ├── __init__.py
│   ├── conv.py
│   ├── conv_transpose.py
│   ├── convert_parameters.py
│   ├── errors.py
│   ├── examples.py
│   ├── hooks.py
│   ├── kroneckers.py
│   ├── linear.py
│   ├── module_classification.py
│   ├── subsampling.py
│   └── unsqueeze.py
├── black.toml
├── changelog.md
├── docs
├── .nojekyll
├── CNAME
├── assets
│   ├── css
│   │   └── style.css
│   ├── dangel2020backpack.bib
│   ├── fonts
│   │   ├── Noto-Sans-700
│   │   │   ├── Noto-Sans-700.eot
│   │   │   ├── Noto-Sans-700.svg
│   │   │   ├── Noto-Sans-700.ttf
│   │   │   ├── Noto-Sans-700.woff
│   │   │   └── Noto-Sans-700.woff2
│   │   ├── Noto-Sans-700italic
│   │   │   ├── Noto-Sans-700italic.eot
│   │   │   ├── Noto-Sans-700italic.svg
│   │   │   ├── Noto-Sans-700italic.ttf
│   │   │   ├── Noto-Sans-700italic.woff
│   │   │   └── Noto-Sans-700italic.woff2
│   │   ├── Noto-Sans-italic
│   │   │   ├── Noto-Sans-italic.eot
│   │   │   ├── Noto-Sans-italic.svg
│   │   │   ├── Noto-Sans-italic.ttf
│   │   │   ├── Noto-Sans-italic.woff
│   │   │   └── Noto-Sans-italic.woff2
│   │   └── Noto-Sans-regular
│   │   │   ├── Noto-Sans-regular.eot
│   │   │   ├── Noto-Sans-regular.svg
│   │   │   ├── Noto-Sans-regular.ttf
│   │   │   ├── Noto-Sans-regular.woff
│   │   │   └── Noto-Sans-regular.woff2
│   ├── img
│   │   ├── backpack_logo_torch.svg
│   │   ├── logo.png
│   │   └── updaterule.png
│   └── js
│   │   └── scale.fix.js
├── examples.html
├── index.html
├── jekyll-theme-minimal.gemspec
└── script
│   ├── bootstrap
│   ├── cibuild
│   ├── release
│   └── validate-html
├── docs_src
├── .gitignore
├── CNAME
├── README.md
├── buildweb.sh
├── examples
│   ├── basic_usage
│   │   ├── README.rst
│   │   └── example_all_in_one.py
│   ├── cheatsheet.pdf
│   └── use_cases
│   │   ├── README.rst
│   │   ├── example_batched_jacobians.py
│   │   ├── example_cg_newton.py
│   │   ├── example_custom_module.py
│   │   ├── example_diag_ggn_optimizer.py
│   │   ├── example_differential_privacy.py
│   │   ├── example_extension_hook.py
│   │   ├── example_first_order_resnet.py
│   │   ├── example_gradient_of_variance.py
│   │   ├── example_resnet_all_in_one.py
│   │   ├── example_retain_graph.py
│   │   ├── example_rnn.py
│   │   ├── example_save_memory_convolutions.py
│   │   ├── example_subsampling.py
│   │   └── example_trace_estimation.py
├── images
│   └── comp_graph.jpg
├── rtd
│   ├── .gitignore
│   ├── .nojekyll
│   ├── Makefile
│   ├── assets
│   │   ├── backpack_logo_torch.png
│   │   └── backpack_logo_torch.svg
│   ├── conf.py
│   ├── extensions.rst
│   ├── good-to-know.rst
│   ├── index.rst
│   ├── main-api.rst
│   ├── make.bat
│   ├── supported-layers.rst
│   └── torch.inventory
└── splash
│   ├── .gitignore
│   ├── Gemfile
│   ├── _config.yml
│   ├── _includes
│   ├── code-samples.html
│   └── dangel2020backpack.bib
│   ├── _layouts
│   ├── default.html
│   └── post.html
│   ├── _sass
│   ├── fonts.scss
│   ├── jekyll-theme-minimal.scss
│   └── rouge-github.scss
│   ├── assets
│   ├── css
│   │   └── style.css
│   ├── dangel2020backpack.bib
│   ├── fonts
│   │   ├── Noto-Sans-700
│   │   │   ├── Noto-Sans-700.eot
│   │   │   ├── Noto-Sans-700.svg
│   │   │   ├── Noto-Sans-700.ttf
│   │   │   ├── Noto-Sans-700.woff
│   │   │   └── Noto-Sans-700.woff2
│   │   ├── Noto-Sans-700italic
│   │   │   ├── Noto-Sans-700italic.eot
│   │   │   ├── Noto-Sans-700italic.svg
│   │   │   ├── Noto-Sans-700italic.ttf
│   │   │   ├── Noto-Sans-700italic.woff
│   │   │   └── Noto-Sans-700italic.woff2
│   │   ├── Noto-Sans-italic
│   │   │   ├── Noto-Sans-italic.eot
│   │   │   ├── Noto-Sans-italic.svg
│   │   │   ├── Noto-Sans-italic.ttf
│   │   │   ├── Noto-Sans-italic.woff
│   │   │   └── Noto-Sans-italic.woff2
│   │   └── Noto-Sans-regular
│   │   │   ├── Noto-Sans-regular.eot
│   │   │   ├── Noto-Sans-regular.svg
│   │   │   ├── Noto-Sans-regular.ttf
│   │   │   ├── Noto-Sans-regular.woff
│   │   │   └── Noto-Sans-regular.woff2
│   ├── img
│   │   ├── backpack_logo_torch.svg
│   │   ├── logo.png
│   │   └── updaterule.png
│   └── js
│   │   └── scale.fix.js
│   ├── examples.md
│   ├── index.md
│   ├── jekyll-theme-minimal.gemspec
│   └── script
│   ├── bootstrap
│   ├── cibuild
│   ├── release
│   └── validate-html
├── fully_documented.txt
├── logo
├── backpack_logo_no_torch.svg
└── backpack_logo_torch.svg
├── makefile
├── pyproject.toml
├── pytest.ini
├── setup.cfg
└── test
├── __init__.py
├── adaptive_avg_pool
├── __init__.py
├── problem.py
├── settings_adaptive_avg_pool_nd.py
└── test_adaptive_avg_pool_nd.py
├── automated_kfac_test.py
├── automated_test.py
├── benchmark
├── __init__.py
├── functionality.py
├── jvp.py
├── jvp_activations.py
├── jvp_avgpool2d.py
├── jvp_conv2d.py
├── jvp_linear.py
├── jvp_maxpool2d.py
└── jvp_zeropad2d.py
├── bugfixes_test.py
├── conv2d_test.py
├── converter
├── __init__.py
├── converter_cases.py
├── resnet_cases.py
└── test_converter.py
├── core
├── __init__.py
└── derivatives
│   ├── __init__.py
│   ├── activation_settings.py
│   ├── batch_norm_settings.py
│   ├── convolution_settings.py
│   ├── derivatives_test.py
│   ├── embedding_settings.py
│   ├── implementation
│   ├── autograd.py
│   ├── backpack.py
│   └── base.py
│   ├── linear_settings.py
│   ├── loss_settings.py
│   ├── lstm_settings.py
│   ├── padding_settings.py
│   ├── permute_settings.py
│   ├── pooling_adaptive_settings.py
│   ├── pooling_settings.py
│   ├── problem.py
│   ├── rnn_settings.py
│   ├── scale_module_settings.py
│   ├── settings.py
│   ├── slicing_settings.py
│   └── utils.py
├── custom_module
├── __init__.py
├── test_pad.py
└── test_slicing.py
├── extensions
├── __init__.py
├── automated_settings.py
├── firstorder
│   ├── __init__.py
│   ├── batch_grad
│   │   ├── __init__.py
│   │   ├── batch_grad_settings.py
│   │   └── test_batch_grad.py
│   ├── batch_l2_grad
│   │   ├── __init__.py
│   │   ├── batchl2grad_settings.py
│   │   └── test_batchl2grad.py
│   ├── firstorder_settings.py
│   ├── sum_grad_squared
│   │   ├── __init__.py
│   │   ├── sumgradsquared_settings.py
│   │   └── test_sumgradsquared.py
│   └── variance
│   │   ├── __init__.py
│   │   ├── test_variance.py
│   │   └── variance_settings.py
├── graph_clear_test.py
├── implementation
│   ├── __init__.py
│   ├── autograd.py
│   ├── backpack.py
│   ├── base.py
│   └── hooks.py
├── problem.py
├── secondorder
│   ├── __init__.py
│   ├── diag_ggn
│   │   ├── __init__.py
│   │   ├── diag_ggn_settings.py
│   │   ├── test_batch_diag_ggn.py
│   │   └── test_diag_ggn.py
│   ├── diag_hessian
│   │   ├── __init__.py
│   │   ├── diagh_settings.py
│   │   └── test_diag_hessian.py
│   ├── hbp
│   │   ├── __init__.py
│   │   ├── kfac_settings.py
│   │   ├── kflr_settings.py
│   │   ├── kfra_settings.py
│   │   ├── test_kfac.py
│   │   ├── test_kflr.py
│   │   └── test_kfra.py
│   ├── secondorder_settings.py
│   └── sqrt_ggn
│   │   ├── __init__.py
│   │   ├── sqrt_ggn_settings.py
│   │   └── test_sqrt_ggn.py
├── test_backprop_extension.py
├── test_hooks.py
└── utils.py
├── hessianfree
├── __init__.py
└── test_ggnvp.py
├── implementation
├── __init__.py
├── implementation.py
├── implementation_autograd.py
└── implementation_bpext.py
├── interface_test.py
├── layers.py
├── layers_test.py
├── linear_test.py
├── networks.py
├── problems.py
├── readme.md
├── test___init__.py
├── test_batch_first.py
├── test_problem.py
├── test_problems_activations.py
├── test_problems_bn.py
├── test_problems_convolutions.py
├── test_problems_kfacs.py
├── test_problems_linear.py
├── test_problems_padding.py
├── test_problems_pooling.py
├── test_retain_graph.py
├── test_second_order_warnings.py
├── utils
├── __init__.py
├── conv_transpose.py
├── evaluation_mode.py
├── skip_extension_test.py
├── skip_test.py
├── test_conv.py
├── test_conv_settings.py
├── test_conv_transpose.py
├── test_conv_transpose_settings.py
└── test_subsampling.py
└── utils_test.py

--------------------------------------------------------------------------------
/.conda_env.yml:
--------------------------------------------------------------------------------
name: backpack
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.9.16
  - pip=23.1.2
  - pip:
      - -e .[lint,test,docs]

--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
# https://coverage.readthedocs.io/en/v4.5.x/config.html#config
[report]
# Regexes for lines to exclude from consideration
exclude_lines =
    # Have to re-enable the standard pragma
    pragma: no cover

    # Don't complain if tests don't hit defensive assertion code:
    raise NotImplementedError
    raise AssertionError

    # TYPE_CHECKING block is never executed during pytest run
    if TYPE_CHECKING:

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: "3.x"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --upgrade twine build
      - name: Build and publish
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          python -m build
          twine upload dist/*

--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
name: Test

on:
  push:
    branches:
      - '*'
  pull_request:
    branches:
      - development
      - master
      - release

jobs:
  tests:
    name: "py${{ matrix.python-version }} torch${{ matrix.pytorch-version}}"
    runs-on: ubuntu-latest
    env:
      USING_COVERAGE: '3.9'

    strategy:
      matrix:
        python-version: [3.9]
        pytorch-version:
          - "==2.2.0"
          - "" # latest
    steps:
      - uses: actions/checkout@v1
      - uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          make install-test
          pip install torch${{ matrix.pytorch-version }} torchvision
      - name: Run test
        if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref)
        run: |
          make test
      - name: Run test-light
        if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref) != 1
        run: |
          make test-light

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.mypy_cache
*.egg-info/
**/*.pyc
.cache
examples/data
.idea
.coverage
dist/*
build/*

--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

sphinx:
  configuration: docs_src/rtd/conf.py

build:
  os: ubuntu-22.04
  tools:
    python: "3.9"

python:
  install:
    - method: pip
      path: .
      extra_requirements:
        - docs

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Felix Dangel, Frederik Künstner & Philipp Hennig

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/backpack/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/backpack/core/__init__.py

--------------------------------------------------------------------------------
/backpack/core/derivatives/__init__.py:
--------------------------------------------------------------------------------
"""Contains derivatives of all supported modules."""

--------------------------------------------------------------------------------
/backpack/core/derivatives/avgpool1d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives


class AvgPool1DDerivatives(AvgPoolNDDerivatives):
    def __init__(self):
        super().__init__(N=1)

--------------------------------------------------------------------------------
/backpack/core/derivatives/avgpool2d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives


class AvgPool2DDerivatives(AvgPoolNDDerivatives):
    def __init__(self):
        super().__init__(N=2)

--------------------------------------------------------------------------------
/backpack/core/derivatives/avgpool3d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives


class AvgPool3DDerivatives(AvgPoolNDDerivatives):
    def __init__(self):
        super().__init__(N=3)

--------------------------------------------------------------------------------
/backpack/core/derivatives/conv1d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.convnd import ConvNDDerivatives


class Conv1DDerivatives(ConvNDDerivatives):
    def __init__(self):
        super().__init__(N=1)
--------------------------------------------------------------------------------
/backpack/core/derivatives/conv2d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.convnd import ConvNDDerivatives


class Conv2DDerivatives(ConvNDDerivatives):
    def __init__(self):
        super().__init__(N=2)

--------------------------------------------------------------------------------
/backpack/core/derivatives/conv3d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.convnd import ConvNDDerivatives


class Conv3DDerivatives(ConvNDDerivatives):
    def __init__(self):
        super().__init__(N=3)

--------------------------------------------------------------------------------
/backpack/core/derivatives/conv_transpose1d.py:
--------------------------------------------------------------------------------
"""Partial derivatives for `torch.nn.ConvTranspose1d`."""

from backpack.core.derivatives.conv_transposend import ConvTransposeNDDerivatives


class ConvTranspose1DDerivatives(ConvTransposeNDDerivatives):
    def __init__(self):
        super().__init__(N=1)

--------------------------------------------------------------------------------
/backpack/core/derivatives/conv_transpose2d.py:
--------------------------------------------------------------------------------
"""Partial derivatives for `torch.nn.ConvTranspose2d`."""

from backpack.core.derivatives.conv_transposend import ConvTransposeNDDerivatives


class ConvTranspose2DDerivatives(ConvTransposeNDDerivatives):
    def __init__(self):
        super().__init__(N=2)

--------------------------------------------------------------------------------
/backpack/core/derivatives/conv_transpose3d.py:
--------------------------------------------------------------------------------
"""Partial derivatives for `torch.nn.ConvTranspose3d`."""

from backpack.core.derivatives.conv_transposend import ConvTransposeNDDerivatives


class ConvTranspose3DDerivatives(ConvTransposeNDDerivatives):
    def __init__(self):
        super().__init__(N=3)

--------------------------------------------------------------------------------
/backpack/core/derivatives/dropout.py:
--------------------------------------------------------------------------------
"""Partial derivatives for the dropout layer."""

from typing import List, Tuple

from torch import Tensor, eq, ones_like
from torch.nn import Dropout

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class DropoutDerivatives(ElementwiseDerivatives):
    """Derivatives for the Dropout module."""

    def hessian_is_zero(self, module: Dropout) -> bool:
        """``Dropout''(x) = 0``.

        Args:
            module: dropout module

        Returns:
            whether hessian is zero
        """
        return True

    def df(
        self,
        module: Dropout,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:  # noqa: D102
        output = subsample(module.output, subsampling=subsampling)
        if module.training:
            # Kept units are scaled by 1 / (1 - p) during training.
            scaling = 1 / (1 - module.p)
            # Recover the mask from the stored output: dropped units are exactly zero.
            mask = 1 - eq(output, 0.0).to(output.dtype)
            return mask * scaling
        else:
            # In evaluation mode, dropout acts as the identity.
            return ones_like(output)

--------------------------------------------------------------------------------
/backpack/core/derivatives/elu.py:
--------------------------------------------------------------------------------
"""Partial derivatives for the ELU activation function."""

from typing import List, Tuple

from torch import Tensor, exp, le, ones_like, zeros_like
from torch.nn import ELU

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class ELUDerivatives(ElementwiseDerivatives):
    """Implement first- and second-order partial derivatives of ELU."""

    def hessian_is_zero(self, module: ELU) -> bool:
        """`ELU''(x) ≠ 0`."""
        return False

    def df(
        self,
        module: ELU,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ):
        """First ELU derivative: `ELU'(x) = alpha * e^x if x <= 0 else 1`."""
        input0 = subsample(module.input0, subsampling=subsampling)
        non_pos = le(input0, 0)

        result = ones_like(input0)
        result[non_pos] = module.alpha * exp(input0[non_pos])

        return result

    def d2f(self, module, g_inp, g_out):
        """Second ELU derivative: `ELU''(x) = alpha * e^x if x <= 0 else 0`."""
        non_pos = le(module.input0, 0)

        result = zeros_like(module.input0)
        result[non_pos] = module.alpha * exp(module.input0[non_pos])

        return result

--------------------------------------------------------------------------------
/backpack/core/derivatives/flatten.py:
--------------------------------------------------------------------------------
"""Partial derivatives of the flatten layer."""

from typing import List, Tuple

from torch import Tensor
from torch.nn import Flatten

from backpack.core.derivatives.basederivatives import BaseDerivatives


class FlattenDerivatives(BaseDerivatives):
    def hessian_is_zero(self, module):
        return True

    def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
        return mat

    def _jac_t_mat_prod(
        self,
        module: Flatten,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        return self.reshape_like_input(mat, module, subsampling=subsampling)

    def _jac_mat_prod(
        self,
        module: Flatten,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
    ) -> Tensor:
        return self.reshape_like_output(mat, module)

--------------------------------------------------------------------------------
/backpack/core/derivatives/leakyrelu.py:
--------------------------------------------------------------------------------
"""Partial derivatives for the leaky ReLU layer."""

from typing import List, Tuple

from torch import Tensor, gt
from torch.nn import LeakyReLU
from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class LeakyReLUDerivatives(ElementwiseDerivatives):
    def hessian_is_zero(self, module: LeakyReLU) -> bool:
        """`LeakyReLU''(x) = 0`."""
        return True

    def df(
        self,
        module: LeakyReLU,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        """``LeakyReLU'(x) = negative_slope if x < 0 else 1``."""
        input0 = subsample(module.input0, subsampling=subsampling)
        # 1 where the input was positive, 0 elsewhere ...
        df_leakyrelu = gt(input0, 0).to(input0.dtype)
        # ... then replace the zeros by the negative slope.
        df_leakyrelu[df_leakyrelu == 0] = module.negative_slope
        return df_leakyrelu

--------------------------------------------------------------------------------
/backpack/core/derivatives/logsigmoid.py:
--------------------------------------------------------------------------------
"""Contains partial derivatives for the ``torch.nn.LogSigmoid`` layer."""

from typing import List, Tuple

from torch import Tensor, exp
from torch.nn import LogSigmoid

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class LogSigmoidDerivatives(ElementwiseDerivatives):
    def hessian_is_zero(self, module):
        """`logsigmoid''(x) ≠ 0`."""
        return False

    def df(
        self,
        module: LogSigmoid,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        """First Logsigmoid derivative: `logsigmoid'(x) = 1 / (e^x + 1)`."""
        input0 = subsample(module.input0, subsampling=subsampling)
        return 1 / (exp(input0) + 1)

    def d2f(self, module, g_inp, g_out):
        """Second Logsigmoid derivative: `logsigmoid''(x) = - e^x / (e^x + 1)^2`."""
        exp_input = exp(module.input0)
        return -(exp_input / (exp_input + 1) ** 2)

--------------------------------------------------------------------------------
/backpack/core/derivatives/maxpool1d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.maxpoolnd import MaxPoolNDDerivatives


class MaxPool1DDerivatives(MaxPoolNDDerivatives):
    def __init__(self):
        super().__init__(N=1)

--------------------------------------------------------------------------------
/backpack/core/derivatives/maxpool2d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.maxpoolnd import MaxPoolNDDerivatives


class MaxPool2DDerivatives(MaxPoolNDDerivatives):
    def __init__(self):
        super().__init__(N=2)

--------------------------------------------------------------------------------
/backpack/core/derivatives/maxpool3d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.maxpoolnd import MaxPoolNDDerivatives


class MaxPool3DDerivatives(MaxPoolNDDerivatives):
    def __init__(self):
        super().__init__(N=3)

--------------------------------------------------------------------------------
/backpack/core/derivatives/permute.py:
--------------------------------------------------------------------------------
"""Module containing derivatives of Permute."""

from typing import List, Tuple

from torch import Tensor, argsort
from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.custom_module.permute import Permute


class PermuteDerivatives(BaseDerivatives):
    """Derivatives of Permute."""

    def _jac_t_mat_prod(
        self,
        module: Permute,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        # The transposed Jacobian of a permutation is its inverse permutation,
        # obtained via argsort; axis 0 of mat stacks the vectors and stays put.
        return mat.permute(
            [0] + [element + 1 for element in argsort(Tensor(module.dims))]
        )

    def _jac_mat_prod(
        self, module: Permute, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
    ) -> Tensor:
        return mat.permute([0] + [element + 1 for element in module.dims])

--------------------------------------------------------------------------------
/backpack/core/derivatives/relu.py:
--------------------------------------------------------------------------------
"""Partial derivatives for the ReLU activation function."""

from typing import List, Tuple

from torch import Tensor, gt
from torch.nn import ReLU

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class ReLUDerivatives(ElementwiseDerivatives):
    def hessian_is_zero(self, module):
        """`ReLU''(x) = 0`."""
        return True

    def df(
        self,
        module: ReLU,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        """First ReLU derivative: `ReLU'(x) = 0 if x < 0 else 1`."""
        input0 = subsample(module.input0, subsampling=subsampling)
        return gt(input0, 0).to(input0.dtype)

--------------------------------------------------------------------------------
/backpack/core/derivatives/scale_module.py:
--------------------------------------------------------------------------------
"""Derivatives of ScaleModule (implies Identity)."""

from typing import List, Tuple, Union

from torch import Tensor
from torch.nn import Identity

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.custom_module.scale_module import ScaleModule


class ScaleModuleDerivatives(BaseDerivatives):
    """Derivatives of ScaleModule (implies Identity)."""

    def _jac_t_mat_prod(
        self,
        module: Union[ScaleModule, Identity],
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        if isinstance(module, Identity):
            return mat
        else:
            return mat * module.weight

--------------------------------------------------------------------------------
/backpack/core/derivatives/sigmoid.py:
--------------------------------------------------------------------------------
"""Partial derivatives for the Sigmoid activation function."""

from typing import List, Tuple

from torch import Tensor
from torch.nn import Sigmoid

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class SigmoidDerivatives(ElementwiseDerivatives):
    def hessian_is_zero(self, module):
        """`σ''(x) ≠ 0`."""
        return False

    def df(
        self,
        module: Sigmoid,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        """First sigmoid derivative: `σ'(x) = σ(x) (1 - σ(x))`."""
        output = subsample(module.output, subsampling=subsampling)
        return output * (1.0 - output)

    def d2f(self, module, g_inp, g_out):
        """Second sigmoid derivative: `σ''(x) = σ(x) (1 - σ(x)) (1 - 2 σ(x))`."""
        return module.output * (1 - module.output) * (1 - 2 * module.output)

--------------------------------------------------------------------------------
/backpack/core/derivatives/sum_module.py:
--------------------------------------------------------------------------------
"""Contains derivatives for SumModule."""

from typing import List, Tuple

from torch import Tensor

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.custom_module.branching import SumModule


class SumModuleDerivatives(BaseDerivatives):
    """Contains derivatives for SumModule."""

    def _jac_t_mat_prod(
        self,
        module: SumModule,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        mat: Tensor,
        subsampling: List[int] = None,
    ) -> Tensor:
        return mat

--------------------------------------------------------------------------------
/backpack/core/derivatives/tanh.py:
--------------------------------------------------------------------------------
"""Partial derivatives for the Tanh activation function."""

from typing import List, Tuple

from torch import Tensor
from torch.nn import Tanh

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.utils.subsampling import subsample


class TanhDerivatives(ElementwiseDerivatives):
    def hessian_is_zero(self, module):
        return False

    def df(
        self,
        module: Tanh,
        g_inp: Tuple[Tensor],
        g_out: Tuple[Tensor],
        subsampling: List[int] = None,
    ) -> Tensor:
        output = subsample(module.output, subsampling=subsampling)
        return 1.0 - output**2

    def d2f(self, module, g_inp, g_out):
        return -2.0 * module.output * (1.0 - module.output**2)

--------------------------------------------------------------------------------
/backpack/custom_module/__init__.py:
--------------------------------------------------------------------------------
"""This package adds torch.nn.Module type modules.

These are used as utilities.
"""

--------------------------------------------------------------------------------
/backpack/custom_module/pad.py:
--------------------------------------------------------------------------------
"""Module version of ``torch.nn.functional.pad``."""

from typing import Sequence

from torch import Tensor
from torch.nn import Module
from torch.nn.functional import pad


class Pad(Module):
    """Module version of ``torch.nn.functional.pad`` (N-dimensional padding)."""

    def __init__(self, pad: Sequence[int], mode: str = "constant", value: float = 0.0):
        """Store padding hyperparameters.

        See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html.

        Args:
            pad: Tuple of even length specifying the padding.
            mode: Padding mode. Default ``'constant'``.
            value: Fill value for constant padding. Default ``0.0``.
        """
        super().__init__()
        self.pad = pad
        self.mode = mode
        self.value = value

    def forward(self, input: Tensor) -> Tensor:
        """Pad the input tensor.

        Args:
            input: Input tensor.

        Returns:
            Padded input tensor.
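
        Example (an illustrative sketch; assumes ``ones`` is imported from
        ``torch``). ``pad=(1, 1)`` pads the last dimension by one entry on
        each side, so a ``(2, 3)`` tensor becomes ``(2, 5)``:

            >>> layer = Pad((1, 1))
            >>> layer(ones(2, 3)).shape
            torch.Size([2, 5])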
        """
        return pad(input, self.pad, mode=self.mode, value=self.value)

--------------------------------------------------------------------------------
/backpack/custom_module/reduce_tuple.py:
--------------------------------------------------------------------------------
"""Module containing ReduceTuple module."""

from typing import Union

from torch import Tensor
from torch.nn import Module


class ReduceTuple(Module):
    """Module reducing tuple input."""

    def __init__(self, index: int = 0):
        """Initialization.

        Args:
            index: which element to choose
        """
        super().__init__()
        self.index = index

    def forward(self, input: tuple) -> Union[tuple, Tensor]:
        """Reduces the tuple.

        Args:
            input: the tuple of data

        Returns:
            the selected element
        """
        return input[self.index]

--------------------------------------------------------------------------------
/backpack/custom_module/scale_module.py:
--------------------------------------------------------------------------------
"""Contains ScaleModule."""

from torch import Tensor
from torch.nn import Module


class ScaleModule(Module):
    """Scale Module scales the input by a constant."""

    def __init__(self, weight: float = 1.0):
        """Store scalar weight.

        Args:
            weight: Initial value for weight. Defaults to 1.0.

        Raises:
            ValueError: if weight is not a float
        """
        super().__init__()
        if not isinstance(weight, float):
            raise ValueError("Weight must be float.")
        self.weight: float = weight

    def forward(self, input: Tensor) -> Tensor:
        """Defines forward pass.

        Args:
            input: input

        Returns:
            product of input and weight
        """
        return input * self.weight

--------------------------------------------------------------------------------
/backpack/custom_module/slicing.py:
--------------------------------------------------------------------------------
"""Custom module to perform tensor slicing."""

from typing import Tuple, Union

from torch import Tensor
from torch.nn import Module


class Slicing(Module):
    """Module that slices a tensor."""

    def __init__(self, slice_info: Tuple[Union[slice, int]]):
        """Store the slicing object.

        Args:
            slice_info: Argument that is passed to the slicing operator in the
                forward pass.
        """
        super().__init__()
        self.slice_info = slice_info

    def forward(self, input: Tensor) -> Tensor:
        """Slice the input tensor.

        Args:
            input: Input tensor.

        Returns:
            Sliced input tensor.
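
        Example (an illustrative sketch; assumes ``rand`` is imported from
        ``torch``). ``Slicing((slice(None), 0))`` mimics ``x[:, 0]``:

            >>> layer = Slicing((slice(None), 0))
            >>> layer(rand(4, 3, 2)).shape
            torch.Size([4, 2])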
        """
        return input[self.slice_info]

--------------------------------------------------------------------------------
/backpack/extensions/__init__.py:
--------------------------------------------------------------------------------
"""BackPACK extensions that can be passed into a ``with backpack(...)`` context."""

from .curvmatprod import GGNMP, HMP, PCHMP
from .firstorder import BatchGrad, BatchL2Grad, SumGradSquared, Variance
from .secondorder import (
    HBP,
    KFAC,
    KFLR,
    KFRA,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
)

__all__ = [
    "PCHMP",
    "GGNMP",
    "HMP",
    "BatchL2Grad",
    "BatchGrad",
    "SumGradSquared",
    "Variance",
    "KFAC",
    "KFLR",
    "KFRA",
    "HBP",
    "DiagGGNExact",
    "BatchDiagGGNExact",
    "DiagGGNMC",
    "BatchDiagGGNMC",
    "DiagHessian",
    "BatchDiagHessian",
    "SqrtGGNExact",
    "SqrtGGNMC",
]

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/__init__.py:
--------------------------------------------------------------------------------
"""Block-diagonal curvature products
====================================

These extensions do not compute information directly, but give access to
functions to compute matrix-matrix products with block-diagonal approximations
of the Hessian.

Extensions propagate functions through the computation graph. In contrast to
standard gradient computation, the graph is retained during backpropagation
(this results in higher memory consumption). The cost of one matrix-vector
multiplication is on the order of one backward pass.

Implemented extensions are matrix-free curvature-matrix multiplication with
the block-diagonal of the Hessian, generalized Gauss-Newton (GGN)/Fisher, and
positive-curvature Hessian. They are formalized by the concept of Hessian
backpropagation, described in:

- `Modular Block-diagonal Curvature Approximations for Feedforward Architectures
  <https://arxiv.org/abs/1902.01813>`_
  by Felix Dangel, Stefan Harmeling, Philipp Hennig, 2020.
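
Example (a minimal usage sketch; the two-layer net and the random data are
placeholders, and the ``hmp`` savefield is assumed to accept stacks of
vectors of shape ``[V, *param.shape]``)::

    from torch import rand
    from torch.nn import Linear, MSELoss, ReLU, Sequential

    from backpack import backpack, extend
    from backpack.extensions import HMP

    model = extend(Sequential(Linear(10, 5), ReLU(), Linear(5, 2)))
    lossfunc = extend(MSELoss())
    loss = lossfunc(model(rand(8, 10)), rand(8, 2))

    with backpack(HMP()):
        loss.backward()

    for param in model.parameters():
        vecs = rand(3, *param.shape)  # three vectors, stacked along dim 0
        hmp_vecs = param.hmp(vecs)  # block Hessian-matrix product, same shape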
"""

from .ggnmp import GGNMP
from .hmp import HMP
from .pchmp import PCHMP

__all__ = [
    "GGNMP",
    "HMP",
    "PCHMP",
]

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/activations.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.relu import ReLUDerivatives
from backpack.core.derivatives.sigmoid import SigmoidDerivatives
from backpack.core.derivatives.tanh import TanhDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPReLU(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=ReLUDerivatives())


class GGNMPSigmoid(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=SigmoidDerivatives())


class GGNMPTanh(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=TanhDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/batchnorm1d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPBatchNorm1d(GGNMPBase):
    def __init__(self):
        super().__init__(
            derivatives=BatchNormNdDerivatives(), params=["weight", "bias"]
        )

    def weight(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def weight_ggnmp(mat):
            result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

            return result

        return weight_ggnmp

    def bias(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def bias_ggnmp(mat):
            result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

            return result

        return bias_ggnmp

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/conv2d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.conv2d import Conv2DDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPConv2d(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=Conv2DDerivatives(), params=["weight", "bias"])

    def weight(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def weight_ggnmp(mat):
            result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

            return result

        return weight_ggnmp

    def bias(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def bias_ggnmp(mat):
            result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

            return result

        return bias_ggnmp

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/dropout.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.dropout import DropoutDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPDropout(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=DropoutDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/flatten.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.flatten import FlattenDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPFlatten(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=FlattenDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/ggnmpbase.py:
--------------------------------------------------------------------------------
"""Block generalized Gauss-Newton matrix products"""

from backpack.extensions.module_extension import ModuleExtension


class GGNMPBase(ModuleExtension):
    def __init__(self, derivatives, params=None):
        super().__init__(params=params)
        self.derivatives = derivatives

    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Backpropagate Hessian multiplication routines.

        Given mat → ℋz(x) mat, backpropagate mat → ℋx mat.
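
        For a module z = f(x) with Jacobian J = ∂z/∂x, the GGN block w.r.t.
        the input follows from the block w.r.t. the output as ℋx = Jᵀ ℋz J,
        i.e. the composition mat → Jᵀ (ℋz (J mat)) implemented below.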
        """
        h_out_mat_prod = backproped

        def h_in_mat_prod(mat):
            """Multiplication with curvature matrix w.r.t. the module input.

            Parameters:
            -----------
            mat : torch.Tensor
                Matrix that will be multiplied.
            """
            # Multiply with the GGN term: mat → [𝒟z(x)ᵀ ℋz 𝒟z(x)] mat.
            result = self.derivatives.jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, result)

            return result

        return h_in_mat_prod

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/linear.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.linear import LinearDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPLinear(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"])

    def weight(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def weight_ggnmp(mat):
            result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

            return result

        return weight_ggnmp

    def bias(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def bias_ggnmp(mat):
            result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

            return result

        return bias_ggnmp

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/losses.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives
from backpack.core.derivatives.mseloss import MSELossDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPLoss(GGNMPBase):
    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        def h_in_mat_prod(mat):
            """Multiplication with curvature matrix w.r.t. the module input.

            Parameters:
            -----------
            mat : torch.Tensor
                Matrix that will be multiplied.
            """
            # Seed of the recursion: at the loss, multiply with the loss Hessian.
            return self.derivatives.make_hessian_mat_prod(module, g_inp, g_out)(mat)

        return h_in_mat_prod


class GGNMPMSELoss(GGNMPLoss):
    def __init__(self):
        super().__init__(derivatives=MSELossDerivatives())


class GGNMPCrossEntropyLoss(GGNMPLoss):
    def __init__(self):
        super().__init__(derivatives=CrossEntropyLossDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/padding.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPZeroPad2d(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=ZeroPad2dDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/ggnmp/pooling.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives
from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives
from backpack.extensions.curvmatprod.ggnmp.ggnmpbase import GGNMPBase


class GGNMPAvgPool2d(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=AvgPool2DDerivatives())


class GGNMPMaxpool2d(GGNMPBase):
    def __init__(self):
        super().__init__(derivatives=MaxPool2DDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/hmp/activations.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.relu import ReLUDerivatives
from backpack.core.derivatives.sigmoid import SigmoidDerivatives
from backpack.core.derivatives.tanh import TanhDerivatives
from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase


class HMPReLU(HMPBase):
    def __init__(self):
        super().__init__(derivatives=ReLUDerivatives())


class HMPSigmoid(HMPBase):
    def __init__(self):
        super().__init__(derivatives=SigmoidDerivatives())


class HMPTanh(HMPBase):
    def __init__(self):
        super().__init__(derivatives=TanhDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/hmp/batchnorm1d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives
from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase


class HMPBatchNorm1d(HMPBase):
    def __init__(self):
        super().__init__(
            derivatives=BatchNormNdDerivatives(), params=["weight", "bias"]
        )

    def weight(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def weight_hmp(mat):
            result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

            return result

        return weight_hmp

    def bias(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def bias_hmp(mat):
            result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

            return result

        return bias_hmp

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/hmp/conv2d.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.conv2d import Conv2DDerivatives
from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase


class HMPConv2d(HMPBase):
    def __init__(self):
        super().__init__(derivatives=Conv2DDerivatives(), params=["weight", "bias"])

    def weight(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def weight_hmp(mat):
            result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result)

            return result

        return weight_hmp

    def bias(self, ext, module, g_inp, g_out, backproped):
        h_out_mat_prod = backproped

        def bias_hmp(mat):
            result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat)
            result = h_out_mat_prod(result)
            result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result)

            return result

        return bias_hmp

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/hmp/dropout.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.dropout import DropoutDerivatives
from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase


class HMPDropout(HMPBase):
    def __init__(self):
        super().__init__(derivatives=DropoutDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/hmp/flatten.py:
--------------------------------------------------------------------------------
from backpack.core.derivatives.flatten import FlattenDerivatives
from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase


class HMPFlatten(HMPBase):
    def __init__(self):
        super().__init__(derivatives=FlattenDerivatives())

--------------------------------------------------------------------------------
/backpack/extensions/curvmatprod/hmp/hmpbase.py:
--------------------------------------------------------------------------------
"""Block Hessian-matrix products"""

from backpack.extensions.module_extension import ModuleExtension


class HMPBase(ModuleExtension):
    def __init__(self, derivatives, params=None):
        super().__init__(params=params)
        self.derivatives = derivatives

    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Backpropagate Hessian multiplication routines.

        Given mat → ℋz(x) mat, backpropagate mat → ℋx mat.
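
        Hessian backpropagation for z = f(x) with Jacobian J = ∂z/∂x reads
        ℋx = Jᵀ ℋz J + ∑ᵢ (∂²zᵢ/∂x²) δzᵢ: the GGN term from the chain rule,
        plus a residual term that vanishes whenever f has zero second
        derivative (handled by the two blocks below).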
27 | result = self.derivatives.jac_mat_prod(module, g_inp, g_out, mat) 28 | result = h_out_mat_prod(result) 29 | result = self.derivatives.jac_t_mat_prod(module, g_inp, g_out, result) 30 | 31 | # Multiply with the residual term: mat → [∑ᵢ Hzᵢ(x) δzᵢ] mat. 32 | if not self.derivatives.hessian_is_zero(module): 33 | result += self.derivatives.residual_mat_prod(module, g_inp, g_out, mat) 34 | 35 | return result 36 | 37 | return h_in_mat_prod 38 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/hmp/linear.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.linear import LinearDerivatives 2 | from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase 3 | 4 | 5 | class HMPLinear(HMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"]) 8 | 9 | def weight(self, ext, module, g_inp, g_out, backproped): 10 | h_out_mat_prod = backproped 11 | 12 | def weight_hmp(mat): 13 | result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) 14 | result = h_out_mat_prod(result) 15 | result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) 16 | 17 | return result 18 | 19 | return weight_hmp 20 | 21 | def bias(self, ext, module, g_inp, g_out, backproped): 22 | h_out_mat_prod = backproped 23 | 24 | def bias_hmp(mat): 25 | result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) 26 | result = h_out_mat_prod(result) 27 | result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) 28 | 29 | return result 30 | 31 | return bias_hmp 32 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/hmp/losses.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives 2 | from backpack.core.derivatives.mseloss import MSELossDerivatives 3 | from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase 4 | 5 | 6 | class HMPLoss(HMPBase): 7 | def backpropagate(self, ext, module, g_inp, g_out, backproped): 8 | def h_in_mat_prod(mat): 9 | """Multiplication with curvature matrix w.r.t. the module input. 10 | 11 | Parameters: 12 | ----------- 13 | mat : torch.Tensor 14 | Matrix that will be multiplied. 
15 | """ 16 | return self.derivatives.make_hessian_mat_prod(module, g_inp, g_out)(mat) 17 | 18 | return h_in_mat_prod 19 | 20 | 21 | class HMPMSELoss(HMPLoss): 22 | def __init__(self): 23 | super().__init__(derivatives=MSELossDerivatives()) 24 | 25 | 26 | class HMPCrossEntropyLoss(HMPLoss): 27 | def __init__(self): 28 | super().__init__(derivatives=CrossEntropyLossDerivatives()) 29 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/hmp/padding.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives 2 | from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase 3 | 4 | 5 | class HMPZeroPad2d(HMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=ZeroPad2dDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/hmp/pooling.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives 2 | from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives 3 | from backpack.extensions.curvmatprod.hmp.hmpbase import HMPBase 4 | 5 | 6 | class HMPAvgPool2d(HMPBase): 7 | def __init__(self): 8 | super().__init__(derivatives=AvgPool2DDerivatives()) 9 | 10 | 11 | class HMPMaxpool2d(HMPBase): 12 | def __init__(self): 13 | super().__init__(derivatives=MaxPool2DDerivatives()) 14 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/activations.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.relu import ReLUDerivatives 2 | from backpack.core.derivatives.sigmoid import SigmoidDerivatives 3 | from backpack.core.derivatives.tanh import TanhDerivatives 4 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 5 | 6 | 7 | class PCHMPReLU(PCHMPBase): 8 | def __init__(self): 9 | super().__init__(derivatives=ReLUDerivatives()) 10 | 11 | 12 | class PCHMPSigmoid(PCHMPBase): 13 | def __init__(self): 14 | super().__init__(derivatives=SigmoidDerivatives()) 15 | 16 | 17 | class PCHMPTanh(PCHMPBase): 18 | def __init__(self): 19 | super().__init__(derivatives=TanhDerivatives()) 20 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/conv2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv2d import Conv2DDerivatives 2 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 3 | 4 | 5 | class PCHMPConv2d(PCHMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=Conv2DDerivatives(), params=["weight", "bias"]) 8 | 9 | def weight(self, ext, module, g_inp, g_out, backproped): 10 | h_out_mat_prod = backproped 11 | 12 | def weight_pchmp(mat): 13 | result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) 14 | result = h_out_mat_prod(result) 15 | result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) 16 | 17 | return result 18 | 19 | return weight_pchmp 20 | 21 | def bias(self, ext, module, g_inp, g_out, backproped): 22 | h_out_mat_prod = backproped 23 | 24 | def bias_pchmp(mat): 25 | result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) 26 | result = h_out_mat_prod(result) 27 | result = 
self.derivatives.param_mjp("bias", module, g_inp, g_out, result) 28 | 29 | return result 30 | 31 | return bias_pchmp 32 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/dropout.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.dropout import DropoutDerivatives 2 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 3 | 4 | 5 | class PCHMPDropout(PCHMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=DropoutDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/flatten.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.flatten import FlattenDerivatives 2 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 3 | 4 | 5 | class PCHMPFlatten(PCHMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=FlattenDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/linear.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.linear import LinearDerivatives 2 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 3 | 4 | 5 | class PCHMPLinear(PCHMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=LinearDerivatives(), params=["weight", "bias"]) 8 | 9 | def weight(self, ext, module, g_inp, g_out, backproped): 10 | h_out_mat_prod = backproped 11 | 12 | def weight_pchmp(mat): 13 | result = self.derivatives.weight_jac_mat_prod(module, g_inp, g_out, mat) 14 | result = h_out_mat_prod(result) 15 | result = self.derivatives.param_mjp("weight", module, g_inp, g_out, result) 16 | 17 | return result 18 | 19 | return weight_pchmp 20 | 21 | def bias(self, ext, module, g_inp, g_out, backproped): 22 | h_out_mat_prod = backproped 23 | 24 | def bias_pchmp(mat): 25 | result = self.derivatives.bias_jac_mat_prod(module, g_inp, g_out, mat) 26 | result = h_out_mat_prod(result) 27 | result = self.derivatives.param_mjp("bias", module, g_inp, g_out, result) 28 | 29 | return result 30 | 31 | return bias_pchmp 32 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/losses.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives 2 | from backpack.core.derivatives.mseloss import MSELossDerivatives 3 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 4 | 5 | 6 | class PCHMPLoss(PCHMPBase): 7 | def backpropagate(self, ext, module, g_inp, g_out, backproped): 8 | if not self.derivatives.hessian_is_psd(): 9 | raise ValueError("Only convex losses supported.") 10 | 11 | def h_in_mat_prod(mat): 12 | """Multiplication with curvature matrix w.r.t. the module input. 13 | 14 | Parameters: 15 | ----------- 16 | mat : torch.Tensor 17 | Matrix that will be multiplied. 
18 | """ 19 | return self.derivatives.make_hessian_mat_prod(module, g_inp, g_out)(mat) 20 | 21 | return h_in_mat_prod 22 | 23 | 24 | class PCHMPMSELoss(PCHMPLoss): 25 | def __init__(self): 26 | super().__init__(derivatives=MSELossDerivatives()) 27 | 28 | 29 | class PCHMPCrossEntropyLoss(PCHMPLoss): 30 | def __init__(self): 31 | super().__init__(derivatives=CrossEntropyLossDerivatives()) 32 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/padding.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives 2 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 3 | 4 | 5 | class PCHMPZeroPad2d(PCHMPBase): 6 | def __init__(self): 7 | super().__init__(derivatives=ZeroPad2dDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/curvmatprod/pchmp/pooling.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives 2 | from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives 3 | from backpack.extensions.curvmatprod.pchmp.pchmpbase import PCHMPBase 4 | 5 | 6 | class PCHMPAvgPool2d(PCHMPBase): 7 | def __init__(self): 8 | super().__init__(derivatives=AvgPool2DDerivatives()) 9 | 10 | 11 | class PCHMPMaxpool2d(PCHMPBase): 12 | def __init__(self): 13 | super().__init__(derivatives=MaxPool2DDerivatives()) 14 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/__init__.py: -------------------------------------------------------------------------------- 1 | """First order extensions 2 | =================================== 3 | 4 | First-order extensions make it easier to extract information from the gradients 5 | being already backpropagated through the computational graph. 6 | They do not backpropagate additional information, and have small overhead. 
7 | The implemented extensions are 8 | 9 | - :func:`BatchGrad <backpack.extensions.BatchGrad>` 10 | The individual gradients, rather than the sum over the samples 11 | - :func:`SumGradSquared <backpack.extensions.SumGradSquared>` 12 | The second moment of the individual gradients 13 | - :func:`Variance <backpack.extensions.Variance>` 14 | The variance of the individual gradients 15 | - :func:`BatchL2Grad <backpack.extensions.BatchL2Grad>` 16 | The L2 norm of the individual gradients 17 | """ 18 | 19 | from .batch_grad import BatchGrad 20 | from .batch_l2_grad import BatchL2Grad 21 | from .sum_grad_squared import SumGradSquared 22 | from .variance import Variance 23 | 24 | __all__ = ["BatchL2Grad", "BatchGrad", "SumGradSquared", "Variance"] 25 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/base.py: -------------------------------------------------------------------------------- 1 | """Base class for first order extensions.""" 2 | 3 | from typing import Dict, List, Type 4 | 5 | from torch.nn import Module 6 | 7 | from backpack.extensions.backprop_extension import FAIL_WARN, BackpropExtension 8 | from backpack.extensions.module_extension import ModuleExtension 9 | 10 | 11 | class FirstOrderModuleExtension(ModuleExtension): 12 | """Base class for first order module extensions.""" 13 | 14 | 15 | class FirstOrderBackpropExtension(BackpropExtension): 16 | """Base backpropagation extension for first order.""" 17 | 18 | def __init__( 19 | self, 20 | savefield: str, 21 | module_exts: Dict[Type[Module], ModuleExtension], 22 | fail_mode: str = FAIL_WARN, 23 | subsampling: List[int] = None, 24 | ): # noqa: D107 25 | super().__init__( 26 | savefield, module_exts, fail_mode=fail_mode, subsampling=subsampling 27 | ) 28 | 29 | def expects_backpropagation_quantities(self) -> bool: # noqa: D102 30 | return False 31 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/batchnorm_nd.py: -------------------------------------------------------------------------------- 1 | """Contains grad_batch extension for BatchNorm.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from torch import Tensor 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives 9 | from backpack.extensions.backprop_extension import BackpropExtension 10 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 11 | from backpack.utils.errors import batch_norm_raise_error_if_train 12 | 13 | 14 | class BatchGradBatchNormNd(BatchGradBase): 15 | """BatchGrad extension for BatchNorm.""" 16 | 17 | def __init__(self): 18 | """Initialization.""" 19 | super().__init__( 20 | derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] 21 | ) 22 | 23 | def check_hyperparameters_module_extension( 24 | self, 25 | ext: BackpropExtension, 26 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], 27 | g_inp: Tuple[Tensor], 28 | g_out: Tuple[Tensor], 29 | ) -> None: # noqa: D102 30 | batch_norm_raise_error_if_train(module, raise_error=False) 31 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/conv1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv1d import Conv1DDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradConv1d(BatchGradBase): 6 | def __init__(self): 7 |
super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 8 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/conv2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv2d import Conv2DDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradConv2d(BatchGradBase): 6 | def __init__(self): 7 | super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 8 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/conv3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv3d import Conv3DDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradConv3d(BatchGradBase): 6 | def __init__(self): 7 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 8 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/conv_transpose1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradConvTranspose1d(BatchGradBase): 6 | def __init__(self): 7 | super().__init__( 8 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 9 | ) 10 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/conv_transpose2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradConvTranspose2d(BatchGradBase): 6 | def __init__(self): 7 | super().__init__( 8 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 9 | ) 10 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/conv_transpose3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradConvTranspose3d(BatchGradBase): 6 | def __init__(self): 7 | super().__init__( 8 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 9 | ) 10 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/embedding.py: -------------------------------------------------------------------------------- 1 | """BatchGrad extension for Embedding.""" 2 | 3 | from backpack.core.derivatives.embedding import EmbeddingDerivatives 4 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 5 | 6 | 7 | class BatchGradEmbedding(BatchGradBase): 8 | """BatchGrad extension for Embedding.""" 9 | 10 | def __init__(self): 11 | """Initialization.""" 12 | super().__init__(derivatives=EmbeddingDerivatives(), 
params=["weight"]) 13 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/linear.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.linear import LinearDerivatives 2 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 3 | 4 | 5 | class BatchGradLinear(BatchGradBase): 6 | def __init__(self): 7 | super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) 8 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_grad/rnn.py: -------------------------------------------------------------------------------- 1 | """Contains BatchGradRNN.""" 2 | 3 | from backpack.core.derivatives.lstm import LSTMDerivatives 4 | from backpack.core.derivatives.rnn import RNNDerivatives 5 | from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase 6 | 7 | 8 | class BatchGradRNN(BatchGradBase): 9 | """Extension for RNN calculating grad_batch.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__( 14 | derivatives=RNNDerivatives(), 15 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 16 | ) 17 | 18 | 19 | class BatchGradLSTM(BatchGradBase): 20 | """Extension for LSTM calculating grad_batch.""" 21 | 22 | def __init__(self): 23 | """Initialization.""" 24 | super().__init__( 25 | derivatives=LSTMDerivatives(), 26 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 27 | ) 28 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py: -------------------------------------------------------------------------------- 1 | """Contains batch_l2 extension for BatchNorm.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from torch import Tensor 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives 9 | from backpack.extensions.backprop_extension import BackpropExtension 10 | from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base 11 | from backpack.utils.errors import batch_norm_raise_error_if_train 12 | 13 | 14 | class BatchL2BatchNorm(BatchL2Base): 15 | """batch_l2 extension for BatchNorm.""" 16 | 17 | def __init__(self): 18 | """Initialization.""" 19 | super().__init__(["weight", "bias"], BatchNormNdDerivatives()) 20 | 21 | def check_hyperparameters_module_extension( 22 | self, 23 | ext: BackpropExtension, 24 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], 25 | g_inp: Tuple[Tensor], 26 | g_out: Tuple[Tensor], 27 | ) -> None: # noqa: D102 28 | batch_norm_raise_error_if_train(module) 29 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_l2_grad/embedding.py: -------------------------------------------------------------------------------- 1 | """BatchL2 extension for Embedding.""" 2 | 3 | from backpack.core.derivatives.embedding import EmbeddingDerivatives 4 | from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base 5 | 6 | 7 | class BatchL2Embedding(BatchL2Base): 8 | """BatchL2 extension for Embedding.""" 9 | 10 | def __init__(self): 11 | """Initialization.""" 12 | super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) 13 | 
-------------------------------------------------------------------------------- /backpack/extensions/firstorder/batch_l2_grad/rnn.py: -------------------------------------------------------------------------------- 1 | """Contains BatchL2RNN.""" 2 | 3 | from backpack.core.derivatives.lstm import LSTMDerivatives 4 | from backpack.core.derivatives.rnn import RNNDerivatives 5 | from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base 6 | 7 | 8 | class BatchL2RNN(BatchL2Base): 9 | """Extension for RNN, calculating batch_l2.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__( 14 | ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 15 | derivatives=RNNDerivatives(), 16 | ) 17 | 18 | 19 | class BatchL2LSTM(BatchL2Base): 20 | """Extension for LSTM, calculating batch_l2.""" 21 | 22 | def __init__(self): 23 | """Initialization.""" 24 | super().__init__( 25 | ["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 26 | derivatives=LSTMDerivatives(), 27 | ) 28 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains the gradient extension. 2 | 3 | It calculates the same result as torch backward(). 4 | """ 5 | 6 | # TODO: Rewrite variance to not need this extension 7 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/batchnorm_nd.py: -------------------------------------------------------------------------------- 1 | """Gradient extension for BatchNorm.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from torch import Tensor 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives 9 | from backpack.extensions.backprop_extension import BackpropExtension 10 | from backpack.utils.errors import batch_norm_raise_error_if_train 11 | 12 | from .base import GradBaseModule 13 | 14 | 15 | class GradBatchNormNd(GradBaseModule): 16 | """Gradient extension for BatchNorm.""" 17 | 18 | def __init__(self): 19 | """Initialization.""" 20 | super().__init__( 21 | derivatives=BatchNormNdDerivatives(), params=["bias", "weight"] 22 | ) 23 | 24 | def check_hyperparameters_module_extension( 25 | self, 26 | ext: BackpropExtension, 27 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], 28 | g_inp: Tuple[Tensor], 29 | g_out: Tuple[Tensor], 30 | ) -> None: # noqa: D102 31 | batch_norm_raise_error_if_train(module) 32 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/conv1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv1d import Conv1DDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradConv1d(GradBaseModule): 7 | def __init__(self): 8 | super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 9 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/conv2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv2d import Conv2DDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradConv2d(GradBaseModule): 7 | def __init__(self): 8 | 
super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 9 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/conv3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv3d import Conv3DDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradConv3d(GradBaseModule): 7 | def __init__(self): 8 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 9 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/convtranspose1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradConvTranspose1d(GradBaseModule): 7 | def __init__(self): 8 | super().__init__( 9 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 10 | ) 11 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/convtranspose2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradConvTranspose2d(GradBaseModule): 7 | def __init__(self): 8 | super().__init__( 9 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 10 | ) 11 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/convtranspose3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradConvTranspose3d(GradBaseModule): 7 | def __init__(self): 8 | super().__init__( 9 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 10 | ) 11 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/embedding.py: -------------------------------------------------------------------------------- 1 | """Gradient extension for Embedding.""" 2 | 3 | from backpack.core.derivatives.embedding import EmbeddingDerivatives 4 | from backpack.extensions.firstorder.gradient.base import GradBaseModule 5 | 6 | 7 | class GradEmbedding(GradBaseModule): 8 | """Gradient extension for Embedding.""" 9 | 10 | def __init__(self): 11 | """Initialization.""" 12 | super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) 13 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/linear.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.linear import LinearDerivatives 2 | 3 | from .base import GradBaseModule 4 | 5 | 6 | class GradLinear(GradBaseModule): 7 | def __init__(self): 8 | super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) 9 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/gradient/rnn.py: -------------------------------------------------------------------------------- 1 | """Contains GradRNN.""" 2 | 3 | from 
backpack.core.derivatives.lstm import LSTMDerivatives 4 | from backpack.core.derivatives.rnn import RNNDerivatives 5 | from backpack.extensions.firstorder.gradient.base import GradBaseModule 6 | 7 | 8 | class GradRNN(GradBaseModule): 9 | """Extension for RNN, calculating gradient.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__( 14 | derivatives=RNNDerivatives(), 15 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 16 | ) 17 | 18 | 19 | class GradLSTM(GradBaseModule): 20 | """Extension for LSTM, calculating gradient.""" 21 | 22 | def __init__(self): 23 | """Initialization.""" 24 | super().__init__( 25 | derivatives=LSTMDerivatives(), 26 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 27 | ) 28 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py: -------------------------------------------------------------------------------- 1 | """SGS extension for BatchNorm.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from torch import Tensor 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives 9 | from backpack.extensions.backprop_extension import BackpropExtension 10 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 11 | from backpack.utils.errors import batch_norm_raise_error_if_train 12 | 13 | 14 | class SGSBatchNormNd(SGSBase): 15 | """SGS extension for BatchNorm.""" 16 | 17 | def __init__(self): 18 | """Initialization.""" 19 | super().__init__(BatchNormNdDerivatives(), ["weight", "bias"]) 20 | 21 | def check_hyperparameters_module_extension( 22 | self, 23 | ext: BackpropExtension, 24 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], 25 | g_inp: Tuple[Tensor], 26 | g_out: Tuple[Tensor], 27 | ) -> None: # noqa: D102 28 | batch_norm_raise_error_if_train(module) 29 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/conv1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv1d import Conv1DDerivatives 2 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 3 | 4 | 5 | class SGSConv1d(SGSBase): 6 | def __init__(self): 7 | super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 8 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/conv2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv2d import Conv2DDerivatives 2 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 3 | 4 | 5 | class SGSConv2d(SGSBase): 6 | def __init__(self): 7 | super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 8 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/conv3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv3d import Conv3DDerivatives 2 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 3 | 4 | 5 | class SGSConv3d(SGSBase): 6 | def __init__(self): 7 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 8 | 
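To make the quantity computed by these ``SGS*`` modules concrete: ``sum_grad_squared`` is the elementwise square of each individual gradient, summed over the batch. The variance extension combines it with the regular gradient via the textbook identity var = E[g²] − E[g]². A sketch, not part of the repository, assuming the variance savefield holds the population variance of the N individual gradients:

```python
from torch import allclose, rand
from torch.nn import Linear, MSELoss

from backpack import backpack, extend
from backpack.extensions import SumGradSquared, Variance

N = 8
model = extend(Linear(10, 2))
lossfunc = extend(MSELoss())

loss = lossfunc(model(rand(N, 10)), rand(N, 2))
with backpack(SumGradSquared(), Variance()):
    loss.backward()

for p in model.parameters():
    # second moment minus squared mean of the N individual gradients
    expected = p.sum_grad_squared / N - (p.grad / N) ** 2
    assert allclose(p.variance, expected, atol=1e-7)
```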
-------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/convtranspose1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives 2 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 3 | 4 | 5 | class SGSConvTranspose1d(SGSBase): 6 | def __init__(self): 7 | super().__init__( 8 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 9 | ) 10 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/convtranspose2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives 2 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 3 | 4 | 5 | class SGSConvTranspose2d(SGSBase): 6 | def __init__(self): 7 | super().__init__( 8 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 9 | ) 10 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/convtranspose3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives 2 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 3 | 4 | 5 | class SGSConvTranspose3d(SGSBase): 6 | def __init__(self): 7 | super().__init__( 8 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 9 | ) 10 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/embedding.py: -------------------------------------------------------------------------------- 1 | """SGS extension for Embedding.""" 2 | 3 | from backpack.core.derivatives.embedding import EmbeddingDerivatives 4 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 5 | 6 | 7 | class SGSEmbedding(SGSBase): 8 | """SGS extension for Embedding.""" 9 | 10 | def __init__(self): 11 | """Initialization.""" 12 | super().__init__(derivatives=EmbeddingDerivatives(), params=["weight"]) 13 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/linear.py: -------------------------------------------------------------------------------- 1 | from torch import einsum 2 | 3 | from backpack.core.derivatives.linear import LinearDerivatives 4 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 5 | 6 | 7 | class SGSLinear(SGSBase): 8 | def __init__(self): 9 | super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) 10 | 11 | def weight(self, ext, module, g_inp, g_out, backproped): 12 | """Compute second moments without expanding individual gradients. 13 | 14 | Overwrites the base class implementation that computes the gradient second 15 | moments from individual gradients. This approach is more memory-efficient. 16 | 17 | Note: 18 | For details, see page 12 (paragraph about "second moment") of the 19 | paper (https://arxiv.org/pdf/1912.10985.pdf). 
20 | """ 21 | has_additional_axes = g_out[0].dim() > 2 22 | 23 | if has_additional_axes: 24 | # TODO Compare `torch.einsum`, `opt_einsum.contract` and the base class 25 | # implementation: https://github.com/fKunstner/backpack-discuss/issues/111 26 | dE_dY = g_out[0].flatten(start_dim=1, end_dim=-2) 27 | X = module.input0.flatten(start_dim=1, end_dim=-2) 28 | return einsum("nmi,nmj,nki,nkj->ij", dE_dY, X, dE_dY, X) 29 | else: 30 | return einsum("ni,nj->ij", g_out[0] ** 2, module.input0**2) 31 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/sum_grad_squared/rnn.py: -------------------------------------------------------------------------------- 1 | """Contains SGSRNN module.""" 2 | 3 | from backpack.core.derivatives.lstm import LSTMDerivatives 4 | from backpack.core.derivatives.rnn import RNNDerivatives 5 | from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase 6 | 7 | 8 | class SGSRNN(SGSBase): 9 | """Extension for RNN, calculating sum_gradient_squared.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__( 14 | derivatives=RNNDerivatives(), 15 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 16 | ) 17 | 18 | 19 | class SGSLSTM(SGSBase): 20 | """Extension for LSTM, calculating sum_gradient_squared.""" 21 | 22 | def __init__(self): 23 | """Initialization.""" 24 | super().__init__( 25 | derivatives=LSTMDerivatives(), 26 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 27 | ) 28 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/batchnorm_nd.py: -------------------------------------------------------------------------------- 1 | """Variance extension for BatchNorm.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from torch import Tensor 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | from backpack.extensions.backprop_extension import BackpropExtension 9 | from backpack.extensions.firstorder.gradient.batchnorm_nd import GradBatchNormNd 10 | from backpack.extensions.firstorder.sum_grad_squared.batchnorm_nd import SGSBatchNormNd 11 | from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule 12 | from backpack.utils.errors import batch_norm_raise_error_if_train 13 | 14 | 15 | class VarianceBatchNormNd(VarianceBaseModule): 16 | """Variance extension for BatchNorm.""" 17 | 18 | def __init__(self): 19 | """Initialization.""" 20 | super().__init__(["weight", "bias"], GradBatchNormNd(), SGSBatchNormNd()) 21 | 22 | def check_hyperparameters_module_extension( 23 | self, 24 | ext: BackpropExtension, 25 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], 26 | g_inp: Tuple[Tensor], 27 | g_out: Tuple[Tensor], 28 | ) -> None: # noqa: D102 29 | batch_norm_raise_error_if_train(module) 30 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/conv1d.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.conv1d import GradConv1d 2 | from backpack.extensions.firstorder.sum_grad_squared.conv1d import SGSConv1d 3 | 4 | from .variance_base import VarianceBaseModule 5 | 6 | 7 | class VarianceConv1d(VarianceBaseModule): 8 | def __init__(self): 9 | super().__init__( 10 | params=["bias", "weight"], 11 | grad_extension=GradConv1d(), 12 | sgs_extension=SGSConv1d(), 13 | ) 14 | 
-------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/conv2d.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.conv2d import GradConv2d 2 | from backpack.extensions.firstorder.sum_grad_squared.conv2d import SGSConv2d 3 | 4 | from .variance_base import VarianceBaseModule 5 | 6 | 7 | class VarianceConv2d(VarianceBaseModule): 8 | def __init__(self): 9 | super().__init__( 10 | params=["bias", "weight"], 11 | grad_extension=GradConv2d(), 12 | sgs_extension=SGSConv2d(), 13 | ) 14 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/conv3d.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.conv3d import GradConv3d 2 | from backpack.extensions.firstorder.sum_grad_squared.conv3d import SGSConv3d 3 | 4 | from .variance_base import VarianceBaseModule 5 | 6 | 7 | class VarianceConv3d(VarianceBaseModule): 8 | def __init__(self): 9 | super().__init__( 10 | params=["bias", "weight"], 11 | grad_extension=GradConv3d(), 12 | sgs_extension=SGSConv3d(), 13 | ) 14 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/convtranspose1d.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.convtranspose1d import GradConvTranspose1d 2 | from backpack.extensions.firstorder.sum_grad_squared.convtranspose1d import ( 3 | SGSConvTranspose1d, 4 | ) 5 | 6 | from .variance_base import VarianceBaseModule 7 | 8 | 9 | class VarianceConvTranspose1d(VarianceBaseModule): 10 | def __init__(self): 11 | super().__init__( 12 | params=["bias", "weight"], 13 | grad_extension=GradConvTranspose1d(), 14 | sgs_extension=SGSConvTranspose1d(), 15 | ) 16 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/convtranspose2d.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.convtranspose2d import GradConvTranspose2d 2 | from backpack.extensions.firstorder.sum_grad_squared.convtranspose2d import ( 3 | SGSConvTranspose2d, 4 | ) 5 | 6 | from .variance_base import VarianceBaseModule 7 | 8 | 9 | class VarianceConvTranspose2d(VarianceBaseModule): 10 | def __init__(self): 11 | super().__init__( 12 | params=["bias", "weight"], 13 | grad_extension=GradConvTranspose2d(), 14 | sgs_extension=SGSConvTranspose2d(), 15 | ) 16 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/convtranspose3d.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.convtranspose3d import GradConvTranspose3d 2 | from backpack.extensions.firstorder.sum_grad_squared.convtranspose3d import ( 3 | SGSConvTranspose3d, 4 | ) 5 | 6 | from .variance_base import VarianceBaseModule 7 | 8 | 9 | class VarianceConvTranspose3d(VarianceBaseModule): 10 | def __init__(self): 11 | super().__init__( 12 | params=["bias", "weight"], 13 | grad_extension=GradConvTranspose3d(), 14 | sgs_extension=SGSConvTranspose3d(), 15 | ) 16 | -------------------------------------------------------------------------------- 
/backpack/extensions/firstorder/variance/embedding.py: -------------------------------------------------------------------------------- 1 | """Variance extension for Embedding.""" 2 | 3 | from backpack.extensions.firstorder.gradient.embedding import GradEmbedding 4 | from backpack.extensions.firstorder.sum_grad_squared.embedding import SGSEmbedding 5 | from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule 6 | 7 | 8 | class VarianceEmbedding(VarianceBaseModule): 9 | """Variance extension for Embedding.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__( 14 | grad_extension=GradEmbedding(), 15 | sgs_extension=SGSEmbedding(), 16 | params=["weight"], 17 | ) 18 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/linear.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.firstorder.gradient.linear import GradLinear 2 | from backpack.extensions.firstorder.sum_grad_squared.linear import SGSLinear 3 | 4 | from .variance_base import VarianceBaseModule 5 | 6 | 7 | class VarianceLinear(VarianceBaseModule): 8 | def __init__(self): 9 | super().__init__( 10 | params=["bias", "weight"], 11 | grad_extension=GradLinear(), 12 | sgs_extension=SGSLinear(), 13 | ) 14 | -------------------------------------------------------------------------------- /backpack/extensions/firstorder/variance/rnn.py: -------------------------------------------------------------------------------- 1 | """Contains VarianceRNN.""" 2 | 3 | from backpack.extensions.firstorder.gradient.rnn import GradLSTM, GradRNN 4 | from backpack.extensions.firstorder.sum_grad_squared.rnn import SGSLSTM, SGSRNN 5 | from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule 6 | 7 | 8 | class VarianceRNN(VarianceBaseModule): 9 | """Extension for RNN, calculating variance.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__( 14 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 15 | grad_extension=GradRNN(), 16 | sgs_extension=SGSRNN(), 17 | ) 18 | 19 | 20 | class VarianceLSTM(VarianceBaseModule): 21 | """Extension for LSTM, calculating variance.""" 22 | 23 | def __init__(self): 24 | """Initialization.""" 25 | super().__init__( 26 | params=["bias_ih_l0", "bias_hh_l0", "weight_ih_l0", "weight_hh_l0"], 27 | grad_extension=GradLSTM(), 28 | sgs_extension=SGSLSTM(), 29 | ) 30 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/base.py: -------------------------------------------------------------------------------- 1 | """Contains base classes for second order extensions.""" 2 | 3 | from backpack.extensions.backprop_extension import BackpropExtension 4 | 5 | 6 | class SecondOrderBackpropExtension(BackpropExtension): 7 | """Base backpropagation extension for second order.""" 8 | 9 | def expects_backpropagation_quantities(self) -> bool: # noqa: D102 10 | return True 11 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py: -------------------------------------------------------------------------------- 1 | """DiagGGN extension for AdaptiveAvgPool.""" 2 | 3 | from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives 4 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 5 | 6 | 
7 | class DiagGGNAdaptiveAvgPoolNd(DiagGGNBaseModule): 8 | """DiagGGN extension for AdaptiveAvgPool.""" 9 | 10 | def __init__(self, N: int): 11 | """Initialization. 12 | 13 | Args: 14 | N: number of free dimensions, e.g. use N=1 for AdaptiveAvgPool1d 15 | """ 16 | super().__init__(derivatives=AdaptiveAvgPoolNDDerivatives(N=N)) 17 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/conv1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv1d import Conv1DDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.convnd import ( 3 | BatchDiagGGNConvND, 4 | DiagGGNConvND, 5 | ) 6 | 7 | 8 | class DiagGGNConv1d(DiagGGNConvND): 9 | def __init__(self): 10 | super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 11 | 12 | 13 | class BatchDiagGGNConv1d(BatchDiagGGNConvND): 14 | def __init__(self): 15 | super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 16 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/conv2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv2d import Conv2DDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.convnd import ( 3 | BatchDiagGGNConvND, 4 | DiagGGNConvND, 5 | ) 6 | 7 | 8 | class DiagGGNConv2d(DiagGGNConvND): 9 | def __init__(self): 10 | super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 11 | 12 | 13 | class BatchDiagGGNConv2d(BatchDiagGGNConvND): 14 | def __init__(self): 15 | super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 16 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/conv3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv3d import Conv3DDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.convnd import ( 3 | BatchDiagGGNConvND, 4 | DiagGGNConvND, 5 | ) 6 | 7 | 8 | class DiagGGNConv3d(DiagGGNConvND): 9 | def __init__(self): 10 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 11 | 12 | 13 | class BatchDiagGGNConv3d(BatchDiagGGNConvND): 14 | def __init__(self): 15 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 16 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/convnd.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 2 | from backpack.utils import conv as convUtils 3 | 4 | 5 | class DiagGGNConvND(DiagGGNBaseModule): 6 | def bias(self, ext, module, grad_inp, grad_out, backproped): 7 | sqrt_ggn = backproped 8 | return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=True) 9 | 10 | def weight(self, ext, module, grad_inp, grad_out, backproped): 11 | X = convUtils.unfold_input(module, module.input0) 12 | weight_diag = convUtils.extract_weight_diagonal( 13 | module, X, backproped, sum_batch=True 14 | ) 15 | return weight_diag 16 | 17 | 18 | class BatchDiagGGNConvND(DiagGGNBaseModule): 19 | def bias(self, ext, module, grad_inp, grad_out, backproped): 20 | sqrt_ggn = backproped 21 | return convUtils.extract_bias_diagonal(module, sqrt_ggn, 
sum_batch=False) 22 | 23 | def weight(self, ext, module, grad_inp, grad_out, backproped): 24 | X = convUtils.unfold_input(module, module.input0) 25 | weight_diag = convUtils.extract_weight_diagonal( 26 | module, X, backproped, sum_batch=False 27 | ) 28 | return weight_diag 29 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/convtranspose1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.convtransposend import ( 3 | BatchDiagGGNConvTransposeND, 4 | DiagGGNConvTransposeND, 5 | ) 6 | 7 | 8 | class DiagGGNConvTranspose1d(DiagGGNConvTransposeND): 9 | def __init__(self): 10 | super().__init__( 11 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 12 | ) 13 | 14 | 15 | class BatchDiagGGNConvTranspose1d(BatchDiagGGNConvTransposeND): 16 | def __init__(self): 17 | super().__init__( 18 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 19 | ) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/convtranspose2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.convtransposend import ( 3 | BatchDiagGGNConvTransposeND, 4 | DiagGGNConvTransposeND, 5 | ) 6 | 7 | 8 | class DiagGGNConvTranspose2d(DiagGGNConvTransposeND): 9 | def __init__(self): 10 | super().__init__( 11 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 12 | ) 13 | 14 | 15 | class BatchDiagGGNConvTranspose2d(BatchDiagGGNConvTransposeND): 16 | def __init__(self): 17 | super().__init__( 18 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 19 | ) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/convtranspose3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.convtransposend import ( 3 | BatchDiagGGNConvTransposeND, 4 | DiagGGNConvTransposeND, 5 | ) 6 | 7 | 8 | class DiagGGNConvTranspose3d(DiagGGNConvTransposeND): 9 | def __init__(self): 10 | super().__init__( 11 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 12 | ) 13 | 14 | 15 | class BatchDiagGGNConvTranspose3d(BatchDiagGGNConvTransposeND): 16 | def __init__(self): 17 | super().__init__( 18 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 19 | ) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/convtransposend.py: -------------------------------------------------------------------------------- 1 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 2 | from backpack.utils import conv_transpose as convUtils 3 | 4 | 5 | class DiagGGNConvTransposeND(DiagGGNBaseModule): 6 | def bias(self, ext, module, grad_inp, grad_out, backproped): 7 | sqrt_ggn = backproped 8 | return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=True) 9 | 10 | def weight(self, ext, module, grad_inp, grad_out, backproped): 11 | X = 
convUtils.unfold_by_conv_transpose(module.input0, module) 12 | return convUtils.extract_weight_diagonal(module, X, backproped, sum_batch=True) 13 | 14 | 15 | class BatchDiagGGNConvTransposeND(DiagGGNBaseModule): 16 | def bias(self, ext, module, grad_inp, grad_out, backproped): 17 | sqrt_ggn = backproped 18 | return convUtils.extract_bias_diagonal(module, sqrt_ggn, sum_batch=False) 19 | 20 | def weight(self, ext, module, grad_inp, grad_out, backproped): 21 | X = convUtils.unfold_by_conv_transpose(module.input0, module) 22 | return convUtils.extract_weight_diagonal(module, X, backproped, sum_batch=False) 23 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/custom_module.py: -------------------------------------------------------------------------------- 1 | """DiagGGN extensions for backpack's custom modules.""" 2 | 3 | from backpack.core.derivatives.scale_module import ScaleModuleDerivatives 4 | from backpack.core.derivatives.sum_module import SumModuleDerivatives 5 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 6 | 7 | 8 | class DiagGGNScaleModule(DiagGGNBaseModule): 9 | """DiagGGN extension for ScaleModule.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__(derivatives=ScaleModuleDerivatives()) 14 | 15 | 16 | class DiagGGNSumModule(DiagGGNBaseModule): 17 | """DiagGGN extension for SumModule.""" 18 | 19 | def __init__(self): 20 | """Initialization.""" 21 | super().__init__(derivatives=SumModuleDerivatives()) 22 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/dropout.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.dropout import DropoutDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 3 | 4 | 5 | class DiagGGNDropout(DiagGGNBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=DropoutDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/embedding.py: -------------------------------------------------------------------------------- 1 | """DiagGGN extension for Embedding.""" 2 | 3 | from backpack.core.derivatives.embedding import EmbeddingDerivatives 4 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 5 | 6 | 7 | class DiagGGNEmbedding(DiagGGNBaseModule): 8 | """DiagGGN extension of Embedding.""" 9 | 10 | def __init__(self): 11 | """Initialize.""" 12 | super().__init__( 13 | derivatives=EmbeddingDerivatives(), params=["weight"], sum_batch=True 14 | ) 15 | 16 | 17 | class BatchDiagGGNEmbedding(DiagGGNBaseModule): 18 | """DiagGGN extension of Embedding.""" 19 | 20 | def __init__(self): 21 | """Initialize.""" 22 | super().__init__( 23 | derivatives=EmbeddingDerivatives(), params=["weight"], sum_batch=False 24 | ) 25 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/flatten.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.flatten import FlattenDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 3 | 4 | 5 | class DiagGGNFlatten(DiagGGNBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=FlattenDerivatives()) 8 | 
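Each ``DiagGGN*`` module in this package handles one layer type during the backpropagation of the GGN's symmetric factorization (the ``sqrt_ggn`` quantity seen above); the user-facing entry points are ``DiagGGNExact`` and its Monte-Carlo approximation ``DiagGGNMC``. A minimal usage sketch, not part of the repository, assuming the documented savefields ``diag_ggn_exact`` and ``diag_ggn_mc``:

```python
from torch import rand, randint
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import DiagGGNExact, DiagGGNMC

model = extend(Sequential(Flatten(), Linear(28 * 28, 10)))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(rand(16, 1, 28, 28)), randint(10, (16,)))
with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
    loss.backward()

for p in model.parameters():
    # one GGN-diagonal entry per parameter entry
    print(p.diag_ggn_exact.shape, p.diag_ggn_mc.shape)  # both equal p.shape
```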
-------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/linear.py: -------------------------------------------------------------------------------- 1 | import backpack.utils.linear as LinUtils 2 | from backpack.core.derivatives.linear import LinearDerivatives 3 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 4 | 5 | 6 | class DiagGGNLinear(DiagGGNBaseModule): 7 | def __init__(self): 8 | super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) 9 | 10 | def bias(self, ext, module, grad_inp, grad_out, backproped): 11 | return LinUtils.extract_bias_diagonal(module, backproped, sum_batch=True) 12 | 13 | def weight(self, ext, module, grad_inp, grad_out, backproped): 14 | return LinUtils.extract_weight_diagonal(module, backproped, sum_batch=True) 15 | 16 | 17 | class BatchDiagGGNLinear(DiagGGNBaseModule): 18 | def __init__(self): 19 | super().__init__(derivatives=LinearDerivatives(), params=["bias", "weight"]) 20 | 21 | def bias(self, ext, module, grad_inp, grad_out, backproped): 22 | return LinUtils.extract_bias_diagonal(module, backproped, sum_batch=False) 23 | 24 | def weight(self, ext, module, grad_inp, grad_out, backproped): 25 | return LinUtils.extract_weight_diagonal(module, backproped, sum_batch=False) 26 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/pad.py: -------------------------------------------------------------------------------- 1 | """Contains ``DiagGGN{Exact, MC}`` extension for BackPACK's custom ``Pad`` module.""" 2 | 3 | from backpack.core.derivatives.pad import PadDerivatives 4 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 5 | 6 | 7 | class DiagGGNPad(DiagGGNBaseModule): 8 | """``DiagGGN{Exact, MC}`` extension for ``backpack.custom_modules.pad.Pad``.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``backpack.custom_modules.pad.Pad`` module.""" 12 | super().__init__(PadDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/padding.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives 2 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 3 | 4 | 5 | class DiagGGNZeroPad2d(DiagGGNBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=ZeroPad2dDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/permute.py: -------------------------------------------------------------------------------- 1 | """Module defining DiagGGNPermute.""" 2 | 3 | from backpack.core.derivatives.permute import PermuteDerivatives 4 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 5 | 6 | 7 | class DiagGGNPermute(DiagGGNBaseModule): 8 | """DiagGGN extension of Permute.""" 9 | 10 | def __init__(self): 11 | """Initialize.""" 12 | super().__init__(derivatives=PermuteDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/pooling.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives 2 | from 
backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives 3 | from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives 4 | from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives 5 | from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives 6 | from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives 7 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 8 | 9 | 10 | class DiagGGNMaxPool1d(DiagGGNBaseModule): 11 | def __init__(self): 12 | super().__init__(derivatives=MaxPool1DDerivatives()) 13 | 14 | 15 | class DiagGGNMaxPool2d(DiagGGNBaseModule): 16 | def __init__(self): 17 | super().__init__(derivatives=MaxPool2DDerivatives()) 18 | 19 | 20 | class DiagGGNAvgPool1d(DiagGGNBaseModule): 21 | def __init__(self): 22 | super().__init__(derivatives=AvgPool1DDerivatives()) 23 | 24 | 25 | class DiagGGNMaxPool3d(DiagGGNBaseModule): 26 | def __init__(self): 27 | super().__init__(derivatives=MaxPool3DDerivatives()) 28 | 29 | 30 | class DiagGGNAvgPool2d(DiagGGNBaseModule): 31 | def __init__(self): 32 | super().__init__(derivatives=AvgPool2DDerivatives()) 33 | 34 | 35 | class DiagGGNAvgPool3d(DiagGGNBaseModule): 36 | def __init__(self): 37 | super().__init__(derivatives=AvgPool3DDerivatives()) 38 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_ggn/slicing.py: -------------------------------------------------------------------------------- 1 | """Holds ``DiagGGN{Exact, MC}`` extension for BackPACK's custom ``Slicing`` module.""" 2 | 3 | from backpack.core.derivatives.slicing import SlicingDerivatives 4 | from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule 5 | 6 | 7 | class DiagGGNSlicing(DiagGGNBaseModule): 8 | """``DiagGGN{Exact, MC}`` for ``backpack.custom_modules.slicing.Slicing``.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``backpack.custom_modules.slicing.Slicing`` module.""" 12 | super().__init__(SlicingDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/adaptive_avg_pool_nd.py: -------------------------------------------------------------------------------- 1 | """DiagH extension for AdaptiveAvgPool.""" 2 | 3 | from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 5 | 6 | 7 | class DiagHAdaptiveAvgPoolNd(DiagHBaseModule): 8 | """DiagH extension for AdaptiveAvgPool.""" 9 | 10 | def __init__(self, N: int): 11 | """Initialization. 12 | 13 | Args: 14 | N: number of free dimensions, e.g.
use N=1 for AdaptiveAvgPool1d 15 | """ 16 | super().__init__(derivatives=AdaptiveAvgPoolNDDerivatives(N=N)) 17 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/conv1d.py: -------------------------------------------------------------------------------- 1 | """Module extensions for diagonal Hessian properties of ``torch.nn.Conv1d``.""" 2 | 3 | from backpack.core.derivatives.conv1d import Conv1DDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.convnd import ( 5 | BatchDiagHConvND, 6 | DiagHConvND, 7 | ) 8 | 9 | 10 | class DiagHConv1d(DiagHConvND): 11 | """Module extension for the Hessian diagonal of ``torch.nn.Conv1d``.""" 12 | 13 | def __init__(self): 14 | """Store parameter names and derivatives object.""" 15 | super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 16 | 17 | 18 | class BatchDiagHConv1d(BatchDiagHConvND): 19 | """Module extension for the per-sample Hessian diagonal of ``torch.nn.Conv1d``.""" 20 | 21 | def __init__(self): 22 | """Store parameter names and derivatives object.""" 23 | super().__init__(derivatives=Conv1DDerivatives(), params=["bias", "weight"]) 24 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/conv2d.py: -------------------------------------------------------------------------------- 1 | """Module extensions for diagonal Hessian properties of ``torch.nn.Conv2d``.""" 2 | 3 | from backpack.core.derivatives.conv2d import Conv2DDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.convnd import ( 5 | BatchDiagHConvND, 6 | DiagHConvND, 7 | ) 8 | 9 | 10 | class DiagHConv2d(DiagHConvND): 11 | """Module extension for the Hessian diagonal of ``torch.nn.Conv2d``.""" 12 | 13 | def __init__(self): 14 | """Store parameter names and derivatives object.""" 15 | super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 16 | 17 | 18 | class BatchDiagHConv2d(BatchDiagHConvND): 19 | """Module extension for the per-sample Hessian diagonal of ``torch.nn.Conv2d``.""" 20 | 21 | def __init__(self): 22 | """Store parameter names and derivatives object.""" 23 | super().__init__(derivatives=Conv2DDerivatives(), params=["bias", "weight"]) 24 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/conv3d.py: -------------------------------------------------------------------------------- 1 | """Module extensions for diagonal Hessian properties of ``torch.nn.Conv3d``.""" 2 | 3 | from backpack.core.derivatives.conv3d import Conv3DDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.convnd import ( 5 | BatchDiagHConvND, 6 | DiagHConvND, 7 | ) 8 | 9 | 10 | class DiagHConv3d(DiagHConvND): 11 | """Module extension for the Hessian diagonal of ``torch.nn.Conv3d``.""" 12 | 13 | def __init__(self): 14 | """Store parameter names and derivatives object.""" 15 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 16 | 17 | 18 | class BatchDiagHConv3d(BatchDiagHConvND): 19 | """Module extension for the per-sample Hessian diagonal of ``torch.nn.Conv3d``.""" 20 | 21 | def __init__(self): 22 | """Store parameter names and derivatives object.""" 23 | super().__init__(derivatives=Conv3DDerivatives(), params=["bias", "weight"]) 24 | -------------------------------------------------------------------------------- 
/backpack/extensions/secondorder/diag_hessian/convtranspose1d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives 2 | from backpack.extensions.secondorder.diag_hessian.convtransposend import ( 3 | BatchDiagHConvTransposeND, 4 | DiagHConvTransposeND, 5 | ) 6 | 7 | 8 | class DiagHConvTranspose1d(DiagHConvTransposeND): 9 | def __init__(self): 10 | super().__init__( 11 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 12 | ) 13 | 14 | 15 | class BatchDiagHConvTranspose1d(BatchDiagHConvTransposeND): 16 | def __init__(self): 17 | super().__init__( 18 | derivatives=ConvTranspose1DDerivatives(), params=["bias", "weight"] 19 | ) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/convtranspose2d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives 2 | from backpack.extensions.secondorder.diag_hessian.convtransposend import ( 3 | BatchDiagHConvTransposeND, 4 | DiagHConvTransposeND, 5 | ) 6 | 7 | 8 | class DiagHConvTranspose2d(DiagHConvTransposeND): 9 | def __init__(self): 10 | super().__init__( 11 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 12 | ) 13 | 14 | 15 | class BatchDiagHConvTranspose2d(BatchDiagHConvTransposeND): 16 | def __init__(self): 17 | super().__init__( 18 | derivatives=ConvTranspose2DDerivatives(), params=["bias", "weight"] 19 | ) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/convtranspose3d.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives 2 | from backpack.extensions.secondorder.diag_hessian.convtransposend import ( 3 | BatchDiagHConvTransposeND, 4 | DiagHConvTransposeND, 5 | ) 6 | 7 | 8 | class DiagHConvTranspose3d(DiagHConvTransposeND): 9 | def __init__(self): 10 | super().__init__( 11 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 12 | ) 13 | 14 | 15 | class BatchDiagHConvTranspose3d(BatchDiagHConvTransposeND): 16 | def __init__(self): 17 | super().__init__( 18 | derivatives=ConvTranspose3DDerivatives(), params=["bias", "weight"] 19 | ) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/dropout.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.dropout import DropoutDerivatives 2 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 3 | 4 | 5 | class DiagHDropout(DiagHBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=DropoutDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/flatten.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.flatten import FlattenDerivatives 2 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 3 | 4 | 5 | class DiagHFlatten(DiagHBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=FlattenDerivatives()) 8 | 
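The ``DiagH*`` and ``BatchDiagH*`` module extensions above feed the user-facing ``DiagHessian`` and ``BatchDiagHessian`` extensions, which store the Hessian diagonal, respectively its per-sample version, on each parameter after the backward pass. A minimal sketch with an arbitrarily chosen toy CNN (architecture and sizes are for illustration only; the savefields ``diag_h`` and ``diag_h_batch`` are BackPACK's documented ones):

import torch
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchDiagHessian, DiagHessian

# Toy CNN; 8x8 inputs with 3 channels, kernel size 3 -> 6x6 feature maps.
model = extend(Sequential(Conv2d(3, 4, 3), Flatten(), Linear(4 * 6 * 6, 2)))
lossfunc = extend(CrossEntropyLoss())
X = torch.randn(8, 3, 8, 8)
y = torch.randint(0, 2, (8,))

loss = lossfunc(model(X), y)
with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

for param in model.parameters():
    # Hessian diagonal, same shape as param.
    print(param.diag_h.shape)
    # Per-sample Hessian diagonals: (batch_size, *param.shape).
    print(param.diag_h_batch.shape)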
-------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/losses.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.bcewithlogitsloss import BCELossWithLogitsDerivatives 2 | from backpack.core.derivatives.crossentropyloss import CrossEntropyLossDerivatives 3 | from backpack.core.derivatives.mseloss import MSELossDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 5 | 6 | 7 | class DiagHLoss(DiagHBaseModule): 8 | def backpropagate(self, ext, module, g_inp, g_out, backproped): 9 | sqrt_H = self.derivatives.sqrt_hessian(module, g_inp, g_out) 10 | return {"matrices": [sqrt_H], "signs": [self.PLUS]} 11 | 12 | 13 | class DiagHMSELoss(DiagHLoss): 14 | def __init__(self): 15 | super().__init__(derivatives=MSELossDerivatives()) 16 | 17 | 18 | class DiagHCrossEntropyLoss(DiagHLoss): 19 | def __init__(self): 20 | super().__init__(derivatives=CrossEntropyLossDerivatives()) 21 | 22 | 23 | class DiagHBCEWithLogitsLoss(DiagHLoss): 24 | def __init__(self): 25 | super().__init__(derivatives=BCELossWithLogitsDerivatives()) 26 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/pad.py: -------------------------------------------------------------------------------- 1 | """Contains ``DiagH`` extension for BackPACK's custom ``Pad`` module.""" 2 | 3 | from backpack.core.derivatives.pad import PadDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 5 | 6 | 7 | class DiagHPad(DiagHBaseModule): 8 | """``DiagH`` extension for ``backpack.custom_modules.pad.Pad``.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``backpack.custom_modules.pad.Pad`` module.""" 12 | super().__init__(PadDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/padding.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives 2 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 3 | 4 | 5 | class DiagHZeroPad2d(DiagHBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=ZeroPad2dDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/pooling.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives 2 | from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives 3 | from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives 4 | from backpack.core.derivatives.maxpool1d import MaxPool1DDerivatives 5 | from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives 6 | from backpack.core.derivatives.maxpool3d import MaxPool3DDerivatives 7 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 8 | 9 | 10 | class DiagHAvgPool1d(DiagHBaseModule): 11 | def __init__(self): 12 | super().__init__(derivatives=AvgPool1DDerivatives()) 13 | 14 | 15 | class DiagHAvgPool2d(DiagHBaseModule): 16 | def __init__(self): 17 | super().__init__(derivatives=AvgPool2DDerivatives()) 18 | 19 | 20 | class DiagHAvgPool3d(DiagHBaseModule): 21 | def __init__(self): 22 | 
super().__init__(derivatives=AvgPool3DDerivatives()) 23 | 24 | 25 | class DiagHMaxPool1d(DiagHBaseModule): 26 | def __init__(self): 27 | super().__init__(derivatives=MaxPool1DDerivatives()) 28 | 29 | 30 | class DiagHMaxPool2d(DiagHBaseModule): 31 | def __init__(self): 32 | super().__init__(derivatives=MaxPool2DDerivatives()) 33 | 34 | 35 | class DiagHMaxPool3d(DiagHBaseModule): 36 | def __init__(self): 37 | super().__init__(derivatives=MaxPool3DDerivatives()) 38 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/diag_hessian/slicing.py: -------------------------------------------------------------------------------- 1 | """Contains ``DiagH`` extension for BackPACK's custom ``Slicing`` module.""" 2 | 3 | from backpack.core.derivatives.slicing import SlicingDerivatives 4 | from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule 5 | 6 | 7 | class DiagHSlicing(DiagHBaseModule): 8 | """``DiagH`` extension for ``backpack.custom_modules.slicing.Slicing``.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``backpack.custom_modules.slicing.Slicing`` module.""" 12 | super().__init__(SlicingDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/activations.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.relu import ReLUDerivatives 2 | from backpack.core.derivatives.sigmoid import SigmoidDerivatives 3 | from backpack.core.derivatives.tanh import TanhDerivatives 4 | from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule 5 | 6 | 7 | class HBPReLU(HBPBaseModule): 8 | def __init__(self): 9 | super().__init__(derivatives=ReLUDerivatives()) 10 | 11 | 12 | class HBPSigmoid(HBPBaseModule): 13 | def __init__(self): 14 | super().__init__(derivatives=SigmoidDerivatives()) 15 | 16 | 17 | class HBPTanh(HBPBaseModule): 18 | def __init__(self): 19 | super().__init__(derivatives=TanhDerivatives()) 20 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/conv1d.py: -------------------------------------------------------------------------------- 1 | """Kronecker approximations of the Hessian for 1d convolution layers.""" 2 | 3 | from backpack.extensions.secondorder.hbp.convnd import HBPConvNd 4 | 5 | 6 | class HBPConv1d(HBPConvNd): 7 | """Computes Kronecker-structured Hessian approximations for 1d convolutions.""" 8 | 9 | def __init__(self): 10 | """Instantiate base class with convolution dimension.""" 11 | super().__init__(N=1) 12 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/conv2d.py: -------------------------------------------------------------------------------- 1 | """Kronecker approximations of the Hessian for 2d convolution layers.""" 2 | 3 | from backpack.extensions.secondorder.hbp.convnd import HBPConvNd 4 | 5 | 6 | class HBPConv2d(HBPConvNd): 7 | """Compute Kronecker-structured Hessian approximations for 2d convolutions.""" 8 | 9 | def __init__(self): 10 | """Instantiate base class with convolution dimension.""" 11 | super().__init__(N=2) 12 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/conv3d.py: -------------------------------------------------------------------------------- 1 | """Kronecker approximations of the Hessian for 
3d convolution layers.""" 2 | 3 | from backpack.extensions.secondorder.hbp.convnd import HBPConvNd 4 | 5 | 6 | class HBPConv3d(HBPConvNd): 7 | """Computes Kronecker-structured Hessian approximations for 3d convolutions.""" 8 | 9 | def __init__(self): 10 | """Instantiate base class with convolution dimension.""" 11 | super().__init__(N=3) 12 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/conv_transpose1d.py: -------------------------------------------------------------------------------- 1 | """Kronecker approximations of the Hessian for 1d transpose convolution layers.""" 2 | 3 | from backpack.extensions.secondorder.hbp.conv_transposend import HBPConvTransposeNd 4 | 5 | 6 | class HBPConvTranspose1d(HBPConvTransposeNd): 7 | """Compute Kronecker-structured Hessian proxies for 1d transpose convolutions.""" 8 | 9 | def __init__(self): 10 | """Instantiate base class with convolution dimension.""" 11 | super().__init__(N=1) 12 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/conv_transpose2d.py: -------------------------------------------------------------------------------- 1 | """Kronecker approximations of the Hessian for 2d transpose convolution layers.""" 2 | 3 | from backpack.extensions.secondorder.hbp.conv_transposend import HBPConvTransposeNd 4 | 5 | 6 | class HBPConvTranspose2d(HBPConvTransposeNd): 7 | """Compute Kronecker-structured Hessian proxies for 2d transpose convolutions.""" 8 | 9 | def __init__(self): 10 | """Instantiate base class with convolution dimension.""" 11 | super().__init__(N=2) 12 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/conv_transpose3d.py: -------------------------------------------------------------------------------- 1 | """Kronecker approximations of the Hessian for 3d transpose convolution layers.""" 2 | 3 | from backpack.extensions.secondorder.hbp.conv_transposend import HBPConvTransposeNd 4 | 5 | 6 | class HBPConvTranspose3d(HBPConvTransposeNd): 7 | """Compute Kronecker-structured Hessian proxies for 3d transpose convolutions.""" 8 | 9 | def __init__(self): 10 | """Instantiate base class with convolution dimension.""" 11 | super().__init__(N=3) 12 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/custom_module.py: -------------------------------------------------------------------------------- 1 | """HBP extensions for BackPACK's custom modules.""" 2 | 3 | from backpack.core.derivatives.scale_module import ScaleModuleDerivatives 4 | from backpack.core.derivatives.sum_module import SumModuleDerivatives 5 | from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule 6 | 7 | 8 | class HBPScaleModule(HBPBaseModule): 9 | """HBP extension for ScaleModule.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__(derivatives=ScaleModuleDerivatives()) 14 | 15 | 16 | class HBPSumModule(HBPBaseModule): 17 | """HBP extension for SumModule.""" 18 | 19 | def __init__(self): 20 | """Initialization.""" 21 | super().__init__(derivatives=SumModuleDerivatives()) 22 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/dropout.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.dropout import DropoutDerivatives 2
| from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule 3 | 4 | 5 | class HBPDropout(HBPBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=DropoutDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/flatten.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.flatten import FlattenDerivatives 2 | from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule 3 | 4 | 5 | class HBPFlatten(HBPBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=FlattenDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/padding.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives 2 | from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule 3 | 4 | 5 | class HBPZeroPad2d(HBPBaseModule): 6 | def __init__(self): 7 | super().__init__(derivatives=ZeroPad2dDerivatives()) 8 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/hbp/pooling.py: -------------------------------------------------------------------------------- 1 | from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives 2 | from backpack.core.derivatives.maxpool2d import MaxPool2DDerivatives 3 | from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule 4 | 5 | 6 | class HBPAvgPool2d(HBPBaseModule): 7 | def __init__(self): 8 | super().__init__(derivatives=AvgPool2DDerivatives()) 9 | 10 | 11 | class HBPMaxpool2d(HBPBaseModule): 12 | def __init__(self): 13 | super().__init__(derivatives=MaxPool2DDerivatives()) 14 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/batchnorm_nd.py: -------------------------------------------------------------------------------- 1 | """``SqrtGGN{Exact, MC}`` extensions for ``BatchNormNd``.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | from torch import Tensor 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives 9 | from backpack.extensions.backprop_extension import BackpropExtension 10 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 11 | from backpack.utils.errors import batch_norm_raise_error_if_train 12 | 13 | 14 | class SqrtGGNBatchNormNd(SqrtGGNBaseModule): 15 | """``SqrtGGN{Exact, MC}`` extension for ``BatchNormNd``.""" 16 | 17 | def __init__(self): 18 | """Initialization.""" 19 | super().__init__(BatchNormNdDerivatives(), ["weight", "bias"]) 20 | 21 | def check_hyperparameters_module_extension( 22 | self, 23 | ext: BackpropExtension, 24 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], 25 | g_inp: Tuple[Tensor], 26 | g_out: Tuple[Tensor], 27 | ) -> None: # noqa: D102 28 | batch_norm_raise_error_if_train(module) 29 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/convnd.py: -------------------------------------------------------------------------------- 1 | """Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.conv1d import Conv1DDerivatives 4 | from backpack.core.derivatives.conv2d import 
Conv2DDerivatives 5 | from backpack.core.derivatives.conv3d import Conv3DDerivatives 6 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 7 | 8 | 9 | class SqrtGGNConv1d(SqrtGGNBaseModule): 10 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv1d`` module.""" 11 | 12 | def __init__(self): 13 | """Pass derivatives for ``torch.nn.Conv1d`` module.""" 14 | super().__init__(Conv1DDerivatives(), params=["bias", "weight"]) 15 | 16 | 17 | class SqrtGGNConv2d(SqrtGGNBaseModule): 18 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv2d`` module.""" 19 | 20 | def __init__(self): 21 | """Pass derivatives for ``torch.nn.Conv2d`` module.""" 22 | super().__init__(Conv2DDerivatives(), params=["bias", "weight"]) 23 | 24 | 25 | class SqrtGGNConv3d(SqrtGGNBaseModule): 26 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Conv3d`` module.""" 27 | 28 | def __init__(self): 29 | """Pass derivatives for ``torch.nn.Conv3d`` module.""" 30 | super().__init__(Conv3DDerivatives(), params=["bias", "weight"]) 31 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/convtransposend.py: -------------------------------------------------------------------------------- 1 | """Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives 4 | from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives 5 | from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives 6 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 7 | 8 | 9 | class SqrtGGNConvTranspose1d(SqrtGGNBaseModule): 10 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose1d`` module.""" 11 | 12 | def __init__(self): 13 | """Pass derivatives for ``torch.nn.ConvTranspose1d`` module.""" 14 | super().__init__(ConvTranspose1DDerivatives(), params=["bias", "weight"]) 15 | 16 | 17 | class SqrtGGNConvTranspose2d(SqrtGGNBaseModule): 18 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose2d`` module.""" 19 | 20 | def __init__(self): 21 | """Pass derivatives for ``torch.nn.ConvTranspose2d`` module.""" 22 | super().__init__(ConvTranspose2DDerivatives(), params=["bias", "weight"]) 23 | 24 | 25 | class SqrtGGNConvTranspose3d(SqrtGGNBaseModule): 26 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ConvTranspose3d`` module.""" 27 | 28 | def __init__(self): 29 | """Pass derivatives for ``torch.nn.ConvTranspose3d`` module.""" 30 | super().__init__(ConvTranspose3DDerivatives(), params=["bias", "weight"]) 31 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/custom_module.py: -------------------------------------------------------------------------------- 1 | """``SqrtGGN{Exact, MC}`` extensions for BackPACK's custom modules.""" 2 | 3 | from backpack.core.derivatives.scale_module import ScaleModuleDerivatives 4 | from backpack.core.derivatives.sum_module import SumModuleDerivatives 5 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 6 | 7 | 8 | class SqrtGGNScaleModule(SqrtGGNBaseModule): 9 | """``SqrtGGN{Exact, MC}`` extension for ``ScaleModule``.""" 10 | 11 | def __init__(self): 12 | """Initialization.""" 13 | super().__init__(derivatives=ScaleModuleDerivatives()) 14 | 15 | 16 | class SqrtGGNSumModule(SqrtGGNBaseModule): 17 | """``SqrtGGN{Exact, MC}`` extension for 
``SumModule``.""" 18 | 19 | def __init__(self): 20 | """Initialization.""" 21 | super().__init__(derivatives=SumModuleDerivatives()) 22 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/dropout.py: -------------------------------------------------------------------------------- 1 | """Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.dropout import DropoutDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class SqrtGGNDropout(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Dropout`` module.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``torch.nn.Dropout`` module.""" 12 | super().__init__(DropoutDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/embedding.py: -------------------------------------------------------------------------------- 1 | """Contains extension for the embedding layer used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.embedding import EmbeddingDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class SqrtGGNEmbedding(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Embedding`` module.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``torch.nn.Embedding`` module.""" 12 | super().__init__(EmbeddingDerivatives(), params=["weight"]) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/flatten.py: -------------------------------------------------------------------------------- 1 | """Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.flatten import FlattenDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class SqrtGGNFlatten(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Flatten`` module.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``torch.nn.Flatten`` module.""" 12 | super().__init__(FlattenDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/linear.py: -------------------------------------------------------------------------------- 1 | """Contains extension for the linear layer used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.linear import LinearDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class SqrtGGNLinear(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.Linear`` module.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``torch.nn.Linear`` module.""" 12 | super().__init__(LinearDerivatives(), params=["bias", "weight"]) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/pad.py: -------------------------------------------------------------------------------- 1 | """Contains ``SqrtGGN{Exact, MC}`` extension for BackPACK's custom ``Pad`` module.""" 2 | 3 | from backpack.core.derivatives.pad import PadDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class 
SqrtGGNPad(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` extension for ``backpack.custom_modules.pad.Pad``.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``backpack.custom_modules.pad.Pad`` module.""" 12 | super().__init__(PadDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/padding.py: -------------------------------------------------------------------------------- 1 | """Contains extensions for padding layers used by ``SqrtGGN{Exact, MC}``.""" 2 | 3 | from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class SqrtGGNZeroPad2d(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` extension for ``torch.nn.ZeroPad2d`` module.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``torch.nn.ZeroPad2d`` module.""" 12 | super().__init__(ZeroPad2dDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/extensions/secondorder/sqrt_ggn/slicing.py: -------------------------------------------------------------------------------- 1 | """Holds ``SqrtGGN{Exact, MC}`` extension for BackPACK's custom ``Slicing`` module.""" 2 | 3 | from backpack.core.derivatives.slicing import SlicingDerivatives 4 | from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule 5 | 6 | 7 | class SqrtGGNSlicing(SqrtGGNBaseModule): 8 | """``SqrtGGN{Exact, MC}`` for ``backpack.custom_modules.slicing.Slicing``.""" 9 | 10 | def __init__(self): 11 | """Pass derivatives for ``backpack.custom_modules.slicing.Slicing`` module.""" 12 | super().__init__(SlicingDerivatives()) 13 | -------------------------------------------------------------------------------- /backpack/hessianfree/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of helper functions for Hessian-free operations, 3 | (Jacobian, Jacobian Transpose, GGN, Hessian)-vector products, 4 | using PyTorch's autograd. 5 | """ 6 | -------------------------------------------------------------------------------- /backpack/hessianfree/lop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def L_op(ys, xs, ws, retain_graph=True, detach=True): 5 | """ 6 | Multiplies the vector `ws` with the transposed Jacobian of `ys` w.r.t. `xs`. 7 | """ 8 | vJ = torch.autograd.grad( 9 | ys, 10 | xs, 11 | grad_outputs=ws, 12 | create_graph=True, 13 | retain_graph=retain_graph, 14 | allow_unused=True, 15 | materialize_grads=True, 16 | ) 17 | return tuple(j.detach() for j in vJ) if detach else vJ 18 | 19 | 20 | def transposed_jacobian_vector_product(f, x, v, retain_graph=True, detach=True): 21 | """ 22 | Multiplies the vector `v` with the transposed Jacobian of `f` w.r.t. `x`. 23 | 24 | Corresponds to the application of the L-operator. 25 | 26 | Parameters: 27 | ----------- 28 | f: torch.Tensor or [torch.Tensor] 29 | Outputs of the differentiated function. 30 | x: torch.Tensor or [torch.Tensor] 31 | Inputs w.r.t. which the gradient will be returned. 32 | v: torch.Tensor or [torch.Tensor] 33 | The vector to be multiplied by the transposed Jacobian. 34 | retain_graph: Bool, optional 35 | If False, the graph used to compute the grad will be freed. 36 | (default: True) 37 | detach: Bool, optional 38 | If True, the transposed Jacobian-vector product will be detached.
39 | (default: True) 40 | """ 41 | return L_op(f, x, v, retain_graph=retain_graph, detach=detach) 42 | -------------------------------------------------------------------------------- /backpack/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains utility functions.""" 2 | -------------------------------------------------------------------------------- /backpack/utils/errors.py: -------------------------------------------------------------------------------- 1 | """Contains errors for BackPACK.""" 2 | 3 | from typing import Union 4 | from warnings import warn 5 | 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 7 | 8 | 9 | def batch_norm_raise_error_if_train( 10 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d], raise_error: bool = True 11 | ) -> None: 12 | """Check if BatchNorm module is in training mode. 13 | 14 | Args: 15 | module: BatchNorm module to check 16 | raise_error: whether to raise an error, alternatively warn. Default: True. 17 | 18 | Raises: 19 | NotImplementedError: if module is in training mode 20 | """ 21 | if module.training: 22 | message = ( 23 | "Encountered BatchNorm module in training mode. BackPACK's computation " 24 | "will pass, but results like individual gradients may not be meaningful, " 25 | "as BatchNorm mixes samples. Only proceed if you know what you are doing." 26 | ) 27 | if raise_error: 28 | raise NotImplementedError(message) 29 | else: 30 | warn(message) 31 | -------------------------------------------------------------------------------- /backpack/utils/hooks.py: -------------------------------------------------------------------------------- 1 | """Utility functions to handle the backpropagation.""" 2 | 3 | 4 | def no_op(*args, **kwargs): 5 | """Placeholder function that accepts arbitrary input and does nothing. 6 | 7 | Args: 8 | *args: anything 9 | **kwargs: anything 10 | """ 11 | pass 12 | -------------------------------------------------------------------------------- /backpack/utils/module_classification.py: -------------------------------------------------------------------------------- 1 | """Contains util function for classification of modules.""" 2 | 3 | from torch.fx import GraphModule 4 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss, Sequential 5 | from torch.nn.modules.loss import _Loss 6 | 7 | from backpack.custom_module.branching import Parallel, _Branch 8 | from backpack.custom_module.reduce_tuple import ReduceTuple 9 | 10 | 11 | def is_loss(module: Module) -> bool: 12 | """Return whether `module` is a `torch` loss function. 13 | 14 | Args: 15 | module: A PyTorch module. 16 | 17 | Returns: 18 | Whether `module` is a loss function. 19 | """ 20 | return isinstance(module, _Loss) 21 | 22 | 23 | def is_nll(module: Module) -> bool: 24 | """Return whether 'module' is an NLL loss function. 25 | 26 | Current NLL loss functions include MSE, CE and BCEWithLogits. 27 | 28 | Args: 29 | module: A PyTorch module. 30 | 31 | Returns: 32 | Whether 'module' is an NLL loss function 33 | """ 34 | return isinstance(module, (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)) 35 | 36 | 37 | def is_no_op(module: Module) -> bool: 38 | """Return whether the module does no operation in graph. 
39 | 40 | Args: 41 | module: module 42 | 43 | Returns: 44 | whether module is no operation 45 | """ 46 | no_op_modules = (Sequential, _Branch, Parallel, ReduceTuple, GraphModule) 47 | return isinstance(module, no_op_modules) 48 | -------------------------------------------------------------------------------- /backpack/utils/subsampling.py: -------------------------------------------------------------------------------- 1 | """Utility functions to enable mini-batch subsampling in extensions.""" 2 | 3 | from typing import List 4 | 5 | from torch import Tensor 6 | 7 | 8 | def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: 9 | """Select samples from a tensor along a dimension. 10 | 11 | Args: 12 | tensor: Tensor to select from. 13 | dim: Selection dimension. Defaults to ``0``. 14 | subsampling: Indices of samples that are sliced along the dimension. 15 | Defaults to ``None`` (use all samples). 16 | 17 | Returns: 18 | Tensor of same rank that is sub-sampled along the dimension. 19 | """ 20 | if subsampling is None: 21 | return tensor 22 | else: 23 | return tensor[(slice(None),) * dim + (subsampling,)] 24 | -------------------------------------------------------------------------------- /backpack/utils/unsqueeze.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | def kfacmp_unsqueeze_if_missing_dim(mat_dim): 5 | """ 6 | Allows Kronecker-factored matrix-matrix routines to do matrix-vector products. 7 | """ 8 | 9 | def kfacmp_wrapper(kfacmp): 10 | @functools.wraps(kfacmp) 11 | def wrapped_kfacmp_support_kfacvp(mat): 12 | is_vec = len(mat.shape) == mat_dim - 1 13 | mat_used = mat.unsqueeze(-1) if is_vec else mat 14 | result = kfacmp(mat_used) 15 | if is_vec: 16 | return result.squeeze(-1) 17 | else: 18 | return result 19 | 20 | return wrapped_kfacmp_support_kfacvp 21 | 22 | return kfacmp_wrapper 23 | -------------------------------------------------------------------------------- /black.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py35', 'py36', 'py37'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | ( 7 | /( 8 | \.eggs 9 | | \.git 10 | | \.pytest_cache 11 | | \.benchmarks 12 | | docs_src/rtd 13 | | docs_src/rtd_output 14 | | docs 15 | | build 16 | | dist 17 | )/ 18 | ) 19 | ''' 20 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/.nojekyll -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | backpack.pt -------------------------------------------------------------------------------- /docs/assets/dangel2020backpack.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{dangel2020backpack, 2 | title = {BackPACK: Packing more into Backprop}, 3 | author = {Felix Dangel and Frederik Kunstner and Philipp Hennig}, 4 | booktitle = {International Conference on Learning Representations}, 5 | year = {2020}, 6 | url = {https://openreview.net/forum?id=BJlrF24twB} 7 | } -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.eot: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.eot -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff2 -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.eot -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff2 -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.eot -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff2 -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.eot -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.ttf -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff -------------------------------------------------------------------------------- /docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff2 -------------------------------------------------------------------------------- /docs/assets/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/img/logo.png -------------------------------------------------------------------------------- /docs/assets/img/updaterule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs/assets/img/updaterule.png -------------------------------------------------------------------------------- /docs/assets/js/scale.fix.js: -------------------------------------------------------------------------------- 1 | (function(document) { 2 | var metas = document.getElementsByTagName('meta'), 3 | changeViewportContent = function(content) { 4 | for (var i = 0; i < metas.length; i++) { 5 | if (metas[i].name == "viewport") { 6 | metas[i].content = content; 7 | } 8 | } 9 | }, 10 | initialize = function() { 11 | changeViewportContent("width=device-width, 
minimum-scale=1.0, maximum-scale=1.0"); 12 | }, 13 | gestureStart = function() { 14 | changeViewportContent("width=device-width, minimum-scale=0.25, maximum-scale=1.6"); 15 | }, 16 | gestureEnd = function() { 17 | initialize(); 18 | }; 19 | 20 | 21 | if (navigator.userAgent.match(/iPhone/i)) { 22 | initialize(); 23 | 24 | document.addEventListener("touchstart", gestureStart, false); 25 | document.addEventListener("touchend", gestureEnd, false); 26 | } 27 | })(document); 28 | -------------------------------------------------------------------------------- /docs/jekyll-theme-minimal.gemspec: -------------------------------------------------------------------------------- 1 | # frozen_string_literal: true 2 | 3 | Gem::Specification.new do |s| 4 | s.name = 'jekyll-theme-minimal' 5 | s.version = '0.1.1' 6 | s.license = 'CC0-1.0' 7 | s.authors = ['Steve Smith', 'GitHub, Inc.'] 8 | s.email = ['opensource+jekyll-theme-minimal@github.com'] 9 | s.homepage = 'https://github.com/pages-themes/minimal' 10 | s.summary = 'Minimal is a Jekyll theme for GitHub Pages' 11 | 12 | s.files = `git ls-files -z`.split("\x0").select do |f| 13 | f.match(%r{^((_includes|_layouts|_sass|assets)/|(LICENSE|README)((\.(txt|md|markdown)|$)))}i) 14 | end 15 | 16 | s.platform = Gem::Platform::RUBY 17 | s.add_runtime_dependency 'jekyll', '> 3.5', '< 5.0' 18 | s.add_runtime_dependency 'jekyll-seo-tag', '~> 2.0' 19 | s.add_development_dependency 'html-proofer', '~> 3.0' 20 | s.add_development_dependency 'rubocop', '~> 0.50' 21 | s.add_development_dependency 'w3c_validators', '~> 1.3' 22 | end 23 | -------------------------------------------------------------------------------- /docs/script/bootstrap: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | gem install bundler 6 | bundle install 7 | -------------------------------------------------------------------------------- /docs/script/cibuild: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | bundle exec jekyll build 6 | bundle exec htmlproofer ./_site --check-html --check-sri 7 | bundle exec rubocop -D 8 | bundle exec script/validate-html 9 | gem build jekyll-theme-minimal.gemspec 10 | -------------------------------------------------------------------------------- /docs/script/release: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Tag and push a release. 3 | 4 | set -e 5 | 6 | # Make sure we're in the project root. 7 | 8 | cd $(dirname "$0")/.. 9 | 10 | # Make sure the darn thing works 11 | 12 | bundle update 13 | 14 | # Build a new gem archive. 15 | 16 | rm -rf jekyll-theme-minimal-*.gem 17 | gem build -q jekyll-theme-minimal.gemspec 18 | 19 | # Make sure we're on the master branch. 20 | 21 | (git branch | grep -q 'master') || { 22 | echo "Only release from the master branch." 23 | exit 1 24 | } 25 | 26 | # Figure out what version we're releasing. 27 | 28 | tag=v`ls jekyll-theme-minimal-*.gem | sed 's/^jekyll-theme-minimal-\(.*\)\.gem$/\1/'` 29 | 30 | # Make sure we haven't released this version before. 31 | 32 | git fetch -t origin 33 | 34 | (git tag -l | grep -q "$tag") && { 35 | echo "Whoops, there's already a '${tag}' tag." 36 | exit 1 37 | } 38 | 39 | # Tag it and bag it. 
40 | 41 | gem push jekyll-theme-minimal-*.gem && git tag "$tag" && 42 | git push origin master && git push origin "$tag" 43 | -------------------------------------------------------------------------------- /docs/script/validate-html: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | # frozen_string_literal: true 3 | 4 | require 'w3c_validators' 5 | 6 | def validator(file) 7 | extension = File.extname(file) 8 | if extension == '.html' 9 | W3CValidators::NuValidator.new 10 | elsif extension == '.css' 11 | W3CValidators::CSSValidator.new 12 | end 13 | end 14 | 15 | def validate(file) 16 | puts "Checking #{file}..." 17 | 18 | path = File.expand_path "../_site/#{file}", __dir__ 19 | results = validator(file).validate_file(path) 20 | 21 | return puts 'Valid!' if results.errors.empty? 22 | 23 | results.errors.each { |err| puts err.to_s } 24 | exit 1 25 | end 26 | 27 | validate 'index.html' 28 | validate File.join 'assets', 'css', 'style.css' 29 | -------------------------------------------------------------------------------- /docs_src/.gitignore: -------------------------------------------------------------------------------- 1 | rtd_output/* 2 | rtd/basic_usage/* 3 | rtd/use_cases/* 4 | examples/use_cases/data/* 5 | examples/basic_usage/data/* -------------------------------------------------------------------------------- /docs_src/CNAME: -------------------------------------------------------------------------------- 1 | backpack.pt -------------------------------------------------------------------------------- /docs_src/README.md: -------------------------------------------------------------------------------- 1 | **Building the web version** 2 | 3 | Requirements: [Jekyll](https://jekyllrb.com/docs/installation/) and [Sphinx](https://www.sphinx-doc.org/en/1.8/usage/installation.html), 4 | and the Jekyll dependencies installed (`bundle install` in `docs_src/splash`) 5 | 6 | - Full build to output results in `../docs` 7 | ``` 8 | bash buildweb.sh 9 | ``` 10 | 11 | - Local build of the Jekyll splash page 12 | ``` 13 | cd splash 14 | bundle exec jekyll server 15 | ``` 16 | and go to `localhost:4000/backpack` 17 | 18 | Note: The code examples on backpack.pt are defined with HTML tags in 19 | `splash/_includes/code-samples.html`. 20 | There are no Python source files to generate them. 21 | Test manually by copy-pasting from the resulting page. 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /docs_src/buildweb.sh: -------------------------------------------------------------------------------- 1 | cd splash 2 | bundle exec jekyll build -d "../../docs" 3 | cd ..
4 | touch ../docs/.nojekyll 5 | cp CNAME ../docs/CNAME -------------------------------------------------------------------------------- /docs_src/examples/basic_usage/README.rst: -------------------------------------------------------------------------------- 1 | Code samples 2 | ================== 3 | -------------------------------------------------------------------------------- /docs_src/examples/cheatsheet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/examples/cheatsheet.pdf -------------------------------------------------------------------------------- /docs_src/examples/use_cases/README.rst: -------------------------------------------------------------------------------- 1 | Code examples and use cases 2 | ============================ 3 | -------------------------------------------------------------------------------- /docs_src/examples/use_cases/example_first_order_resnet.py: -------------------------------------------------------------------------------- 1 | r"""First order extensions with a ResNet 2 | ======================================== 3 | """ 4 | 5 | # %% 6 | # This tutorial has moved. Click 7 | # `here `_ 8 | # to continue to its new location. 9 | -------------------------------------------------------------------------------- /docs_src/images/comp_graph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/images/comp_graph.jpg -------------------------------------------------------------------------------- /docs_src/rtd/.gitignore: -------------------------------------------------------------------------------- 1 | rtd_output/* 2 | examples/* 3 | -------------------------------------------------------------------------------- /docs_src/rtd/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/rtd/.nojekyll -------------------------------------------------------------------------------- /docs_src/rtd/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = ../rtd_output 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs_src/rtd/assets/backpack_logo_torch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/rtd/assets/backpack_logo_torch.png -------------------------------------------------------------------------------- /docs_src/rtd/extensions.rst: -------------------------------------------------------------------------------- 1 | Available Extensions 2 | ==================================== 3 | 4 | .. automodule:: backpack.extensions.firstorder 5 | 6 | ----- 7 | 8 | .. autofunction:: backpack.extensions.BatchGrad 9 | .. autofunction:: backpack.extensions.BatchL2Grad 10 | .. autofunction:: backpack.extensions.SumGradSquared 11 | .. autofunction:: backpack.extensions.Variance 12 | 13 | ----- 14 | 15 | .. automodule:: backpack.extensions.secondorder 16 | 17 | ----- 18 | 19 | .. autofunction:: backpack.extensions.DiagGGNMC 20 | .. autofunction:: backpack.extensions.DiagGGNExact 21 | .. autofunction:: backpack.extensions.BatchDiagGGNMC 22 | .. autofunction:: backpack.extensions.BatchDiagGGNExact 23 | .. autofunction:: backpack.extensions.KFAC 24 | .. autofunction:: backpack.extensions.KFLR 25 | .. autofunction:: backpack.extensions.KFRA 26 | .. autofunction:: backpack.extensions.DiagHessian 27 | .. autofunction:: backpack.extensions.BatchDiagHessian 28 | .. autofunction:: backpack.extensions.SqrtGGNExact 29 | .. autofunction:: backpack.extensions.SqrtGGNMC 30 | 31 | ----- 32 | 33 | .. automodule:: backpack.extensions.curvmatprod 34 | 35 | ----- 36 | 37 | .. autofunction:: backpack.extensions.HMP 38 | .. autofunction:: backpack.extensions.GGNMP 39 | .. autofunction:: backpack.extensions.PCHMP 40 | -------------------------------------------------------------------------------- /docs_src/rtd/index.rst: -------------------------------------------------------------------------------- 1 | BackPACK 2 | ==================================== 3 | 4 | BackPACK is a library built on top of PyTorch 5 | to extract more information from a backward pass. 6 | 7 | .. code:: bash 8 | 9 | pip install backpack-for-pytorch 10 | 11 | For a quick overview of the features, check `backpack.pt `_. 12 | The code lives on `Github `_. 13 | 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | :caption: Getting started 18 | 19 | main-api 20 | basic_usage/example_all_in_one 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | :caption: Backpack 25 | 26 | supported-layers 27 | extensions 28 | good-to-know 29 | use_cases/index 30 | 31 | 32 | -------------------------------------------------------------------------------- /docs_src/rtd/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=../rtd_output 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs_src/rtd/torch.inventory: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/rtd/torch.inventory -------------------------------------------------------------------------------- /docs_src/splash/.gitignore: -------------------------------------------------------------------------------- 1 | _site 2 | .sass-cache 3 | Gemfile.lock 4 | *.gem 5 | .jekyll-cache 6 | -------------------------------------------------------------------------------- /docs_src/splash/Gemfile: -------------------------------------------------------------------------------- 1 | # frozen_string_literal: true 2 | 3 | source 'https://rubygems.org' 4 | 5 | gemspec 6 | -------------------------------------------------------------------------------- /docs_src/splash/_config.yml: -------------------------------------------------------------------------------- 1 | title: BackPACK 2 | description: Get more out of your backward pass 3 | show_downloads: false 4 | theme: jekyll-theme-minimal 5 | baseurl: "/" -------------------------------------------------------------------------------- /docs_src/splash/_includes/dangel2020backpack.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{dangel2020backpack, 2 | title = {BackPACK: Packing more into Backprop}, 3 | author = {Felix Dangel and Frederik Kunstner and Philipp Hennig}, 4 | booktitle = {International Conference on Learning Representations}, 5 | year = {2020}, 6 | url = {https://openreview.net/forum?id=BJlrF24twB} 7 | } -------------------------------------------------------------------------------- /docs_src/splash/_layouts/post.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 | 5 | {{ page.date | date: "%-d %B %Y" }} 6 |

{{ page.title }} 7 | 8 | by {{ page.author | default: site.author }}
9 | 10 | {{content}} 11 | 12 | {% if page.tags %} 13 | tags: {{ page.tags | join: " - " }} 14 | {% endif %} 15 | -------------------------------------------------------------------------------- /docs_src/splash/assets/dangel2020backpack.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{dangel2020backpack, 2 | title = {BackPACK: Packing more into Backprop}, 3 | author = {Felix Dangel and Frederik Kunstner and Philipp Hennig}, 4 | booktitle = {International Conference on Learning Representations}, 5 | year = {2020}, 6 | url = {https://openreview.net/forum?id=BJlrF24twB} 7 | } -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.eot -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.ttf -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700/Noto-Sans-700.woff2 -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.eot -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.ttf -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-700italic/Noto-Sans-700italic.woff2 -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.eot -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.ttf -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-italic/Noto-Sans-italic.woff2 -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.eot -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.ttf -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff -------------------------------------------------------------------------------- /docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/fonts/Noto-Sans-regular/Noto-Sans-regular.woff2 -------------------------------------------------------------------------------- /docs_src/splash/assets/img/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/img/logo.png -------------------------------------------------------------------------------- /docs_src/splash/assets/img/updaterule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/docs_src/splash/assets/img/updaterule.png -------------------------------------------------------------------------------- /docs_src/splash/assets/js/scale.fix.js: -------------------------------------------------------------------------------- 1 | (function(document) { 2 | var metas = document.getElementsByTagName('meta'), 3 | changeViewportContent = function(content) { 4 | for (var i = 0; i < metas.length; i++) { 5 | if (metas[i].name == "viewport") { 6 | metas[i].content = content; 7 | } 8 | } 9 | }, 10 | initialize = function() { 11 | changeViewportContent("width=device-width, minimum-scale=1.0, maximum-scale=1.0"); 12 | }, 13 | gestureStart = function() { 14 | changeViewportContent("width=device-width, minimum-scale=0.25, maximum-scale=1.6"); 15 | }, 16 | gestureEnd = function() { 17 | initialize(); 18 | }; 19 | 20 | 21 | if (navigator.userAgent.match(/iPhone/i)) { 22 | initialize(); 23 | 24 | document.addEventListener("touchstart", gestureStart, false); 25 | document.addEventListener("touchend", gestureEnd, false); 26 | } 27 | })(document); 28 | -------------------------------------------------------------------------------- /docs_src/splash/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | --- 4 | 5 | BackPACK is a library built on top of [PyTorch](https://pytorch.org/) 6 | to make it easy to extract more information from a backward pass. 
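For instance, here is a minimal sketch of the workflow (the tiny model and the random data below are made up for illustration; `extend`, `backpack`, and the `BatchGrad` extension are part of BackPACK's API):

```
from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad

# toy data: 8 samples, 10 features, 3 classes (illustrative only)
X, y = rand(8, 10), randint(0, 3, (8,))

# extend the model and the loss so BackPACK can hook into the backward pass
model = extend(Sequential(Linear(10, 3)))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

for param in model.parameters():
    # .grad as usual, plus one gradient per sample from BackPACK
    print(param.grad.shape, param.grad_batch.shape)
```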
7 | Some of the things you can compute: 8 | {% include code-samples.html %} 9 | 10 | --- 11 | 12 | **Install with** 13 | ``` 14 | pip install backpack-for-pytorch 15 | ``` 16 | 17 | --- 18 | 19 | If you use BackPACK in your research, please cite it using the BibTeX entry below: 20 | 21 | ``` 22 | {% include dangel2020backpack.bib %} 23 | ``` 24 | 25 | -------------------------------------------------------------------------------- /docs_src/splash/jekyll-theme-minimal.gemspec: -------------------------------------------------------------------------------- 1 | # frozen_string_literal: true 2 | 3 | Gem::Specification.new do |s| 4 | s.name = 'jekyll-theme-minimal' 5 | s.version = '0.1.1' 6 | s.license = 'CC0-1.0' 7 | s.authors = ['Steve Smith', 'GitHub, Inc.'] 8 | s.email = ['opensource+jekyll-theme-minimal@github.com'] 9 | s.homepage = 'https://github.com/pages-themes/minimal' 10 | s.summary = 'Minimal is a Jekyll theme for GitHub Pages' 11 | 12 | s.files = `git ls-files -z`.split("\x0").select do |f| 13 | f.match(%r{^((_includes|_layouts|_sass|assets)/|(LICENSE|README)((\.(txt|md|markdown)|$)))}i) 14 | end 15 | 16 | s.platform = Gem::Platform::RUBY 17 | s.add_runtime_dependency 'jekyll', '> 3.5', '< 5.0' 18 | s.add_runtime_dependency 'jekyll-seo-tag', '~> 2.0' 19 | s.add_development_dependency 'html-proofer', '~> 3.0' 20 | s.add_development_dependency 'rubocop', '~> 0.50' 21 | s.add_development_dependency 'w3c_validators', '~> 1.3' 22 | end 23 | -------------------------------------------------------------------------------- /docs_src/splash/script/bootstrap: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | gem install bundler 6 | bundle install 7 | -------------------------------------------------------------------------------- /docs_src/splash/script/cibuild: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -e 4 | 5 | bundle exec jekyll build 6 | bundle exec htmlproofer ./_site --check-html --check-sri 7 | bundle exec rubocop -D 8 | bundle exec script/validate-html 9 | gem build jekyll-theme-minimal.gemspec 10 | -------------------------------------------------------------------------------- /docs_src/splash/script/release: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Tag and push a release. 3 | 4 | set -e 5 | 6 | # Make sure we're in the project root. 7 | 8 | cd $(dirname "$0")/.. 9 | 10 | # Make sure the darn thing works 11 | 12 | bundle update 13 | 14 | # Build a new gem archive. 15 | 16 | rm -rf jekyll-theme-minimal-*.gem 17 | gem build -q jekyll-theme-minimal.gemspec 18 | 19 | # Make sure we're on the master branch. 20 | 21 | (git branch | grep -q 'master') || { 22 | echo "Only release from the master branch." 23 | exit 1 24 | } 25 | 26 | # Figure out what version we're releasing. 27 | 28 | tag=v`ls jekyll-theme-minimal-*.gem | sed 's/^jekyll-theme-minimal-\(.*\)\.gem$/\1/'` 29 | 30 | # Make sure we haven't released this version before. 31 | 32 | git fetch -t origin 33 | 34 | (git tag -l | grep -q "$tag") && { 35 | echo "Whoops, there's already a '${tag}' tag." 36 | exit 1 37 | } 38 | 39 | # Tag it and bag it.
40 | 41 | gem push jekyll-theme-minimal-*.gem && git tag "$tag" && 42 | git push origin master && git push origin "$tag" 43 | -------------------------------------------------------------------------------- /docs_src/splash/script/validate-html: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env ruby 2 | # frozen_string_literal: true 3 | 4 | require 'w3c_validators' 5 | 6 | def validator(file) 7 | extension = File.extname(file) 8 | if extension == '.html' 9 | W3CValidators::NuValidator.new 10 | elsif extension == '.css' 11 | W3CValidators::CSSValidator.new 12 | end 13 | end 14 | 15 | def validate(file) 16 | puts "Checking #{file}..." 17 | 18 | path = File.expand_path "../_site/#{file}", __dir__ 19 | results = validator(file).validate_file(path) 20 | 21 | return puts 'Valid!' if results.errors.empty? 22 | 23 | results.errors.each { |err| puts err.to_s } 24 | exit 1 25 | end 26 | 27 | validate 'index.html' 28 | validate File.join 'assets', 'css', 'style.css' 29 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # NOTE The documentation recommends to **not** configure pytest with setup.cfg 2 | # (https://docs.pytest.org/en/6.2.x/customize.html#setup-cfg) 3 | [pytest] 4 | optional_tests: 5 | montecarlo: slow tests using low-precision allclose after Monte-Carlo sampling 6 | filterwarnings = 7 | ignore:cannot collect test class 'TestProblem':pytest.PytestCollectionWarning: 8 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the test suite.""" 2 | 3 | from pytorch_memlab import MemReporter 4 | 5 | 6 | def pytorch_current_memory_usage(): 7 | """Return current memory usage in PyTorch (all devices).""" 8 | reporter = MemReporter() 9 | reporter.collect_tensor() 10 | reporter.get_stats() 11 | 12 | total_mem = 0 13 | for _, tensor_stats in reporter.device_tensor_stat.items(): 14 | for stat in tensor_stats: 15 | _, _, _, mem = stat 16 | total_mem += mem 17 | 18 | return total_mem 19 | -------------------------------------------------------------------------------- /test/adaptive_avg_pool/__init__.py: -------------------------------------------------------------------------------- 1 | """Module tests AdaptiveAvgPoolNDDerivatives. 2 | 3 | Especially the shape checker for equivalence with AvgPoolND. 
4 | """ 5 | -------------------------------------------------------------------------------- /test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py: -------------------------------------------------------------------------------- 1 | """Settings to run test_adaptive_avg_pool_nd.""" 2 | 3 | from typing import Any, Dict, List 4 | 5 | from torch import Size 6 | 7 | SETTINGS: List[Dict[str, Any]] = [ 8 | { 9 | "N": 1, 10 | "shape_target": 2, 11 | "shape_input": (1, 5, 8), 12 | "works": True, 13 | }, 14 | { 15 | "N": 1, 16 | "shape_target": 2, 17 | "shape_input": (1, 8, 7), 18 | "works": False, 19 | }, 20 | { 21 | "N": 2, 22 | "shape_target": Size((4, 3)), 23 | "shape_input": (1, 64, 8, 9), 24 | "works": True, 25 | }, 26 | { 27 | "N": 2, 28 | "shape_target": 2, 29 | "shape_input": (1, 64, 8, 10), 30 | "works": True, 31 | }, 32 | { 33 | "N": 2, 34 | "shape_target": 2, 35 | "shape_input": (1, 64, 8, 9), 36 | "works": False, 37 | }, 38 | { 39 | "N": 2, 40 | "shape_target": (5, 2), 41 | "shape_input": (1, 64, 64, 10), 42 | "works": False, 43 | }, 44 | { 45 | "N": 3, 46 | "shape_target": (None, 2, None), 47 | "shape_input": (1, 64, 7, 10, 5), 48 | "works": True, 49 | }, 50 | ] 51 | -------------------------------------------------------------------------------- /test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py: -------------------------------------------------------------------------------- 1 | """Test the shape checker of AdaptiveAvgPoolNDDerivatives.""" 2 | 3 | from test.adaptive_avg_pool.problem import AdaptiveAvgPoolProblem, make_test_problems 4 | from test.adaptive_avg_pool.settings_adaptive_avg_pool_nd import SETTINGS 5 | from typing import List 6 | 7 | import pytest 8 | 9 | PROBLEMS: List[AdaptiveAvgPoolProblem] = make_test_problems(SETTINGS) 10 | IDS: List[str] = [problem.make_id() for problem in PROBLEMS] 11 | 12 | 13 | @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) 14 | def test_adaptive_avg_pool_check_parameters(problem: AdaptiveAvgPoolProblem): 15 | """Test AdaptiveAvgPoolNDDerivatives.check_parameters(). 16 | 17 | Additionally check if returned parameters are indeed equivalent. 
18 | 19 | Args: 20 | problem: test problem 21 | """ 22 | problem.set_up() 23 | if problem.works: 24 | problem.check_parameters() 25 | problem.check_equivalence() 26 | else: 27 | with pytest.raises(NotImplementedError): 28 | problem.check_parameters() 29 | problem.tear_down() 30 | -------------------------------------------------------------------------------- /test/benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/benchmark/__init__.py -------------------------------------------------------------------------------- /test/benchmark/jvp_activations.py: -------------------------------------------------------------------------------- 1 | from torch import randn 2 | 3 | from backpack import extend 4 | 5 | 6 | def data(module_class, device="cpu"): 7 | N, D = 100, 200 8 | 9 | X = randn(N, D, requires_grad=True, device=device) 10 | module = extend(module_class()).to(device=device) 11 | out = module(X) 12 | 13 | v = randn(N, D, device=device) 14 | 15 | return { 16 | "X": X, 17 | "module": module, 18 | "output": out, 19 | "vout_ag": v, 20 | "vout_bp": v.unsqueeze(2), 21 | "vin_ag": v, 22 | "vin_bp": v.unsqueeze(2), 23 | } 24 | -------------------------------------------------------------------------------- /test/benchmark/jvp_avgpool2d.py: -------------------------------------------------------------------------------- 1 | from torch import randn 2 | from torch.nn import AvgPool2d 3 | 4 | from backpack import extend 5 | 6 | 7 | def data(device="cpu"): 8 | N, C, Hin, Win = 100, 10, 32, 32 9 | KernelSize = 4 10 | 11 | X = randn(N, C, Hin, Win, requires_grad=True, device=device) 12 | module = extend(AvgPool2d(KernelSize)).to(device=device) 13 | out = module(X) 14 | 15 | Hout = int(Hin / KernelSize) 16 | Wout = int(Win / KernelSize) 17 | vout = randn(N, C, Hin, Win, device=device) 18 | vin = randn(N, C, Hout, Wout, device=device) 19 | 20 | return { 21 | "X": X, 22 | "module": module, 23 | "output": out, 24 | "vout_ag": vout, 25 | "vout_bp": vout.view(N, -1, 1), 26 | "vin_ag": vin, 27 | "vin_bp": vin.view(N, -1, 1), 28 | } 29 | -------------------------------------------------------------------------------- /test/benchmark/jvp_conv2d.py: -------------------------------------------------------------------------------- 1 | from torch import randn 2 | from torch.nn import Conv2d 3 | 4 | from backpack import extend 5 | 6 | 7 | def data_conv2d(device="cpu"): 8 | N, Cin, Hin, Win = 100, 10, 32, 32 9 | Cout, KernelH, KernelW = 25, 5, 5 10 | 11 | X = randn(N, Cin, Hin, Win, requires_grad=True, device=device) 12 | module = extend(Conv2d(Cin, Cout, (KernelH, KernelW))).to(device=device) 13 | out = module(X) 14 | 15 | Hout = Hin - (KernelH - 1) 16 | Wout = Win - (KernelW - 1) 17 | vin = randn(N, Cout, Hout, Wout, device=device) 18 | vout = randn(N, Cin, Hin, Win, device=device) 19 | 20 | return { 21 | "X": X, 22 | "module": module, 23 | "output": out, 24 | "vout_ag": vout, 25 | "vout_bp": vout.view(N, -1, 1), 26 | "vin_ag": vin, 27 | "vin_bp": vin.view(N, -1, 1), 28 | } 29 | -------------------------------------------------------------------------------- /test/benchmark/jvp_linear.py: -------------------------------------------------------------------------------- 1 | from torch import randn 2 | from torch.nn import Linear 3 | 4 | from backpack import extend 5 | 6 | 7 | def data_linear(device="cpu"): 8 | N, D1, D2 = 100, 64, 256 9 | 10 | X = randn(N, D1, 
requires_grad=True, device=device) 11 | linear = extend(Linear(D1, D2).to(device=device)) 12 | out = linear(X) 13 | 14 | vin = randn(N, D2, device=device) 15 | vout = randn(N, D1, device=device) 16 | 17 | return { 18 | "X": X, 19 | "module": linear, 20 | "output": out, 21 | "vout_ag": vout, 22 | "vout_bp": vout.unsqueeze(2), 23 | "vin_ag": vin, 24 | "vin_bp": vin.unsqueeze(2), 25 | } 26 | -------------------------------------------------------------------------------- /test/benchmark/jvp_maxpool2d.py: -------------------------------------------------------------------------------- 1 | from torch import randn 2 | from torch.nn import MaxPool2d 3 | 4 | from backpack import extend 5 | 6 | 7 | def data(device="cpu"): 8 | N, C, Hin, Win = 100, 10, 32, 32 9 | KernelSize = 4 10 | 11 | X = randn(N, C, Hin, Win, requires_grad=True, device=device) 12 | module = extend(MaxPool2d(KernelSize)).to(device=device) 13 | out = module(X) 14 | 15 | Hout = int(Hin / KernelSize) 16 | Wout = int(Win / KernelSize) 17 | vout = randn(N, C, Hin, Win, device=device) 18 | vin = randn(N, C, Hout, Wout, device=device) 19 | 20 | return { 21 | "X": X, 22 | "module": module, 23 | "output": out, 24 | "vout_ag": vout, 25 | "vout_bp": vout.view(N, -1, 1), 26 | "vin_ag": vin, 27 | "vin_bp": vin.view(N, -1, 1), 28 | } 29 | -------------------------------------------------------------------------------- /test/benchmark/jvp_zeropad2d.py: -------------------------------------------------------------------------------- 1 | from torch import randn 2 | from torch.nn import ZeroPad2d 3 | 4 | from backpack import extend 5 | 6 | 7 | def data(device="cpu"): 8 | N, C, Hin, Win = 100, 10, 32, 32 9 | padding = [1, 2, 3, 4] 10 | Hout = Hin + padding[2] + padding[3] 11 | Wout = Win + padding[0] + padding[1] 12 | 13 | X = randn(N, C, Hin, Win, requires_grad=True, device=device) 14 | module = extend(ZeroPad2d(padding)).to(device=device) 15 | out = module(X) 16 | 17 | vout = randn(N, C, Hin, Win, device=device) 18 | vin = randn(N, C, Hout, Wout, device=device) 19 | 20 | return { 21 | "X": X, 22 | "module": module, 23 | "output": out, 24 | "vout_ag": vout, 25 | "vout_bp": vout.view(N, -1, 1), 26 | "vin_ag": vin, 27 | "vin_bp": vin.view(N, -1, 1), 28 | } 29 | -------------------------------------------------------------------------------- /test/converter/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains tests for the converter and ResNets.""" 2 | -------------------------------------------------------------------------------- /test/core/__init__.py: -------------------------------------------------------------------------------- 1 | """Test functionality of `backpack.core` module.""" 2 | -------------------------------------------------------------------------------- /test/core/derivatives/embedding_settings.py: -------------------------------------------------------------------------------- 1 | """Settings for testing derivatives of Embedding.""" 2 | 3 | from torch import randint 4 | from torch.nn import Embedding 5 | 6 | EMBEDDING_SETTINGS = [ 7 | { 8 | "module_fn": lambda: Embedding(3, 5), 9 | "input_fn": lambda: randint(0, 3, (4,)), 10 | }, 11 | { 12 | "module_fn": lambda: Embedding(5, 7), 13 | "input_fn": lambda: randint(0, 5, (8, 3, 3)), 14 | }, 15 | ] 16 | -------------------------------------------------------------------------------- /test/core/derivatives/lstm_settings.py: -------------------------------------------------------------------------------- 1 | """Test configurations 
for `backpack.core.derivatives` LSTM layers. 2 | 3 | Required entries: 4 | "module_fn" (callable): Contains a model constructed from `torch.nn` layers 5 | "input_fn" (callable): Used for specifying input function 6 | 7 | Optional entries: 8 | "target_fn" (callable): Fetches the groundtruth/target classes 9 | of regression/classification task 10 | "loss_function_fn" (callable): Loss function used in the model 11 | "device" [list(torch.device)]: List of devices to run the test on. 12 | "id_prefix" (str): Prefix to be included in the test name. 13 | "seed" (int): seed for the random number generator (torch.rand) 14 | """ 15 | 16 | from torch import rand 17 | from torch.nn import LSTM 18 | 19 | LSTM_SETTINGS = [] 20 | 21 | ############################################################################### 22 | # test settings # 23 | ############################################################################### 24 | LSTM_SETTINGS += [ 25 | { 26 | "module_fn": lambda: LSTM(input_size=4, hidden_size=3, batch_first=True), 27 | "input_fn": lambda: rand(size=(3, 5, 4)), 28 | }, 29 | ] 30 | -------------------------------------------------------------------------------- /test/core/derivatives/permute_settings.py: -------------------------------------------------------------------------------- 1 | """Test configurations for `backpack.core.derivatives` Permute. 2 | 3 | Required entries: 4 | "module_fn" (callable): Contains a model constructed from `torch.nn` layers 5 | "input_fn" (callable): Used for specifying input function 6 | 7 | Optional entries: 8 | "target_fn" (callable): Fetches the groundtruth/target classes 9 | of regression/classification task 10 | "loss_function_fn" (callable): Loss function used in the model 11 | "device" [list(torch.device)]: List of devices to run the test on. 12 | "id_prefix" (str): Prefix to be included in the test name. 13 | "seed" (int): seed for the random number generator (torch.rand) 14 | """ 15 | 16 | import torch 17 | 18 | from backpack.custom_module.permute import Permute 19 | 20 | PERMUTE_SETTINGS = [ 21 | { 22 | "module_fn": lambda: Permute(0, 1, 2), 23 | "input_fn": lambda: torch.rand(size=(1, 2, 3)), 24 | }, 25 | { 26 | "module_fn": lambda: Permute(0, 2, 1), 27 | "input_fn": lambda: torch.rand(size=(4, 3, 2)), 28 | }, 29 | { 30 | "module_fn": lambda: Permute(0, 3, 1, 2), 31 | "input_fn": lambda: torch.rand(size=(5, 4, 3, 2)), 32 | }, 33 | ] 34 | -------------------------------------------------------------------------------- /test/core/derivatives/pooling_adaptive_settings.py: -------------------------------------------------------------------------------- 1 | """Test configurations for `backpack.core.derivatives` for adaptive pooling layers. 2 | 3 | Required entries: 4 | "module_fn" (callable): Contains a model constructed from `torch.nn` layers 5 | "input_fn" (callable): Used for specifying input function 6 | 7 | Optional entries: 8 | "target_fn" (callable): Fetches the groundtruth/target classes 9 | of regression/classification task 10 | "loss_function_fn" (callable): Loss function used in the model 11 | "device" [list(torch.device)]: List of devices to run the test on. 12 | "id_prefix" (str): Prefix to be included in the test name.
13 | "seed" (int): seed for the random number for torch.rand 14 | """ 15 | 16 | import torch 17 | 18 | POOLING_ADAPTIVE_SETTINGS = [] 19 | 20 | ############################################################################### 21 | # test settings # 22 | ############################################################################### 23 | POOLING_ADAPTIVE_SETTINGS += [ 24 | { 25 | "module_fn": lambda: torch.nn.AdaptiveAvgPool1d(output_size=(3,)), 26 | "input_fn": lambda: torch.rand(size=(1, 4, 9)), 27 | }, 28 | { 29 | "module_fn": lambda: torch.nn.AdaptiveAvgPool2d(output_size=(3, 5)), 30 | "input_fn": lambda: torch.rand(size=(2, 3, 9, 20)), 31 | }, 32 | { 33 | "module_fn": lambda: torch.nn.AdaptiveAvgPool3d(output_size=(2, 2, 2)), 34 | "input_fn": lambda: torch.rand(size=(1, 3, 4, 8, 8)), 35 | }, 36 | ] 37 | -------------------------------------------------------------------------------- /test/core/derivatives/rnn_settings.py: -------------------------------------------------------------------------------- 1 | """Test configurations for `backpack.core.derivatives` RNN layers. 2 | 3 | Required entries: 4 | "module_fn" (callable): Contains a model constructed from `torch.nn` layers 5 | "input_fn" (callable): Used for specifying input function 6 | 7 | Optional entries: 8 | "target_fn" (callable): Fetches the groundtruth/target classes 9 | of regression/classification task 10 | "loss_function_fn" (callable): Loss function used in the model 11 | "device" [list(torch.device)]: List of devices to run the test on. 12 | "id_prefix" (str): Prefix to be included in the test name. 13 | "seed" (int): seed for the random number for torch.rand 14 | """ 15 | 16 | import torch 17 | 18 | RNN_SETTINGS = [ 19 | { 20 | "module_fn": lambda: torch.nn.RNN( 21 | input_size=4, hidden_size=3, batch_first=True 22 | ), 23 | "input_fn": lambda: torch.rand(size=(3, 5, 4)), 24 | }, 25 | ] 26 | -------------------------------------------------------------------------------- /test/core/derivatives/scale_module_settings.py: -------------------------------------------------------------------------------- 1 | """Test settings for ScaleModule derivatives.""" 2 | 3 | from torch import rand 4 | from torch.nn import Identity 5 | 6 | from backpack.custom_module.scale_module import ScaleModule 7 | 8 | SCALE_MODULE_SETTINGS = [ 9 | { 10 | "module_fn": lambda: ScaleModule(), 11 | "input_fn": lambda: rand(3, 4, 2), 12 | }, 13 | { 14 | "module_fn": lambda: ScaleModule(0.3), 15 | "input_fn": lambda: rand(3, 2), 16 | }, 17 | { 18 | "module_fn": lambda: ScaleModule(5.7), 19 | "input_fn": lambda: rand(2, 3), 20 | }, 21 | { 22 | "module_fn": lambda: Identity(), 23 | "input_fn": lambda: rand(3, 1, 2), 24 | }, 25 | ] 26 | -------------------------------------------------------------------------------- /test/core/derivatives/settings.py: -------------------------------------------------------------------------------- 1 | """Test cases for `backpack.core.derivatives`. 
2 | 3 | Cases are divided into the following layer categories: 4 | 5 | - Activations 6 | - (Transposed) convolutions 7 | - Linear 8 | - Losses 9 | - Padding 10 | - Pooling 11 | """ 12 | 13 | from test.core.derivatives.activation_settings import ACTIVATION_SETTINGS 14 | from test.core.derivatives.convolution_settings import CONVOLUTION_SETTINGS 15 | from test.core.derivatives.linear_settings import LINEAR_SETTINGS 16 | from test.core.derivatives.loss_settings import LOSS_SETTINGS 17 | from test.core.derivatives.padding_settings import PADDING_SETTINGS 18 | from test.core.derivatives.pooling_adaptive_settings import POOLING_ADAPTIVE_SETTINGS 19 | from test.core.derivatives.pooling_settings import POOLING_SETTINGS 20 | 21 | SETTINGS = ( 22 | ACTIVATION_SETTINGS 23 | + CONVOLUTION_SETTINGS 24 | + LINEAR_SETTINGS 25 | + LOSS_SETTINGS 26 | + PADDING_SETTINGS 27 | + POOLING_SETTINGS 28 | + POOLING_ADAPTIVE_SETTINGS 29 | ) 30 | -------------------------------------------------------------------------------- /test/core/derivatives/slicing_settings.py: -------------------------------------------------------------------------------- 1 | """Contains test cases of BackPACK's custom Slicing module.""" 2 | 3 | from torch import rand 4 | 5 | from backpack.custom_module.slicing import Slicing 6 | 7 | CUSTOM_SLICING_SETTINGS = [ 8 | { 9 | "module_fn": lambda: Slicing((slice(None), 0)), 10 | "input_fn": lambda: rand(size=(2, 4, 2, 5)), 11 | }, 12 | { 13 | "module_fn": lambda: Slicing((slice(None),)), 14 | "input_fn": lambda: rand(size=(3, 4, 2, 5)), 15 | }, 16 | { 17 | "module_fn": lambda: Slicing((slice(None), 2)), 18 | "input_fn": lambda: rand(size=(3, 4, 2, 5)), 19 | }, 20 | { 21 | "module_fn": lambda: Slicing((slice(None), 2, slice(1, 2), slice(0, 5, 2))), 22 | "input_fn": lambda: rand(size=(3, 4, 2, 5)), 23 | }, 24 | ] 25 | -------------------------------------------------------------------------------- /test/custom_module/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains tests of BackPACK's custom modules.""" 2 | -------------------------------------------------------------------------------- /test/custom_module/test_slicing.py: -------------------------------------------------------------------------------- 1 | """Tests for ``backpack.custom_module.slicing.Slicing``.""" 2 | 3 | from test.core.derivatives.utils import get_available_devices 4 | from typing import Dict 5 | 6 | import torch 7 | from pytest import mark 8 | from torch import allclose, manual_seed, rand 9 | 10 | from backpack.custom_module.slicing import Slicing 11 | 12 | CONFIGURATIONS = [ 13 | { 14 | "input_fn": lambda: rand(2, 3, 4, 5), 15 | "slice_info": (0,), 16 | "seed": 0, 17 | }, 18 | { 19 | "input_fn": lambda: rand(2, 3, 4, 5), 20 | "slice_info": (slice(None), 0, slice(0, 2), slice(1, 5, 2)), 21 | "seed": 1, 22 | }, 23 | { 24 | "input_fn": lambda: rand(2, 3, 4, 5), 25 | "slice_info": (slice(1, 2), 0, slice(None), slice(1, 5, 2)), 26 | "seed": 1, 27 | }, 28 | ] 29 | DEVICES = get_available_devices() 30 | 31 | 32 | @mark.parametrize("device", DEVICES, ids=str) 33 | @mark.parametrize("config", CONFIGURATIONS, ids=str) 34 | def test_slicing_forward(config: Dict, device: torch.device): 35 | """Test forward pass of the custom slicing module. 36 | 37 | Args: 38 | config: Dictionary specifying the test case. 39 | device: Device to execute the test on. 
40 | """ 41 | manual_seed(config["seed"]) 42 | 43 | input = config["input_fn"]().to(device) 44 | slice_info = config["slice_info"] 45 | 46 | layer = Slicing(slice_info).to(device) 47 | 48 | assert allclose(layer(input), input[slice_info]) 49 | -------------------------------------------------------------------------------- /test/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/extensions/__init__.py -------------------------------------------------------------------------------- /test/extensions/firstorder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/extensions/firstorder/__init__.py -------------------------------------------------------------------------------- /test/extensions/firstorder/batch_grad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/extensions/firstorder/batch_grad/__init__.py -------------------------------------------------------------------------------- /test/extensions/firstorder/batch_grad/batch_grad_settings.py: -------------------------------------------------------------------------------- 1 | """Test cases for BackPACK's ``BatchGrad`` extension. 2 | 3 | The tests are taken from ``test.extensions.firstorder.firstorder_settings``, 4 | but additional custom tests can be defined here by appending it to the list. 5 | """ 6 | 7 | from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS 8 | 9 | SHARED_SETTINGS = FIRSTORDER_SETTINGS 10 | LOCAL_SETTINGS = [] 11 | 12 | BATCH_GRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS 13 | -------------------------------------------------------------------------------- /test/extensions/firstorder/batch_l2_grad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/extensions/firstorder/batch_l2_grad/__init__.py -------------------------------------------------------------------------------- /test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py: -------------------------------------------------------------------------------- 1 | """Test configurations to test batch_l2_grad 2 | 3 | The tests are taken from `test.extensions.firstorder.firstorder_settings`, 4 | but additional custom tests can be defined here by appending it to the list. 
5 | """ 6 | 7 | from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS 8 | 9 | SHARED_SETTINGS = FIRSTORDER_SETTINGS 10 | LOCAL_SETTINGS = [] 11 | 12 | BATCHl2GRAD_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS 13 | -------------------------------------------------------------------------------- /test/extensions/firstorder/sum_grad_squared/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/extensions/firstorder/sum_grad_squared/__init__.py -------------------------------------------------------------------------------- /test/extensions/firstorder/variance/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains tests for BackPACK's ``Variance`` extension.""" 2 | -------------------------------------------------------------------------------- /test/extensions/firstorder/variance/test_variance.py: -------------------------------------------------------------------------------- 1 | """Test BackPACK's ``Variance`` extension.""" 2 | 3 | from test.automated_test import check_sizes_and_values 4 | from test.extensions.firstorder.variance.variance_settings import VARIANCE_SETTINGS 5 | from test.extensions.implementation.autograd import AutogradExtensions 6 | from test.extensions.implementation.backpack import BackpackExtensions 7 | from test.extensions.problem import ExtensionsTestProblem, make_test_problems 8 | 9 | import pytest 10 | 11 | PROBLEMS = make_test_problems(VARIANCE_SETTINGS) 12 | IDS = [problem.make_id() for problem in PROBLEMS] 13 | 14 | 15 | @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) 16 | def test_variance(problem: ExtensionsTestProblem) -> None: 17 | """Test variance of individual gradients. 18 | 19 | Args: 20 | problem: Test case. 21 | """ 22 | problem.set_up() 23 | 24 | backpack_res = BackpackExtensions(problem).variance() 25 | autograd_res = AutogradExtensions(problem).variance() 26 | 27 | rtol = 5e-5 28 | check_sizes_and_values(autograd_res, backpack_res, rtol=rtol) 29 | problem.tear_down() 30 | -------------------------------------------------------------------------------- /test/extensions/firstorder/variance/variance_settings.py: -------------------------------------------------------------------------------- 1 | """Test cases for ``Variance`` extension. 2 | 3 | Uses shared test cases from `test.extensions.firstorder.firstorder_settings`, 4 | and the local cases defined in this file. 
5 | """ 6 | 7 | from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS 8 | 9 | SHARED_SETTINGS = FIRSTORDER_SETTINGS 10 | LOCAL_SETTINGS = [] 11 | 12 | VARIANCE_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS 13 | -------------------------------------------------------------------------------- /test/extensions/implementation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/extensions/implementation/__init__.py -------------------------------------------------------------------------------- /test/extensions/secondorder/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for `backpack.extensions.secondorder`.""" 2 | -------------------------------------------------------------------------------- /test/extensions/secondorder/diag_ggn/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for `backpack.extensions.secondorder.diag_ggn`.""" 2 | -------------------------------------------------------------------------------- /test/extensions/secondorder/diag_hessian/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for `backpack.extensions.secondorder.diag_hessian`.""" 2 | -------------------------------------------------------------------------------- /test/extensions/secondorder/diag_hessian/diagh_settings.py: -------------------------------------------------------------------------------- 1 | """Test cases for DiagHessian and BatchDiagHessian extensions. 2 | 3 | The tests are taken from `test.extensions.secondorder.secondorder_settings`, 4 | but additional custom tests can be defined here by appending it to the list. 
5 | """ 6 | 7 | from test.extensions.automated_settings import ( 8 | make_simple_act_setting, 9 | make_simple_pooling_setting, 10 | ) 11 | from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS 12 | 13 | from torch.nn import ( 14 | AdaptiveAvgPool1d, 15 | AdaptiveAvgPool2d, 16 | AdaptiveAvgPool3d, 17 | Conv1d, 18 | Conv2d, 19 | Conv3d, 20 | LogSigmoid, 21 | ) 22 | 23 | SHARED_SETTINGS = SECONDORDER_SETTINGS 24 | LOCAL_SETTINGS = [ 25 | make_simple_act_setting(LogSigmoid, bias=True), 26 | make_simple_act_setting(LogSigmoid, bias=False), 27 | ] 28 | 29 | ############################################################################### 30 | # test setting: Adaptive Pooling Layers # 31 | ############################################################################### 32 | LOCAL_SETTINGS += [ 33 | make_simple_pooling_setting((3, 3, 7), Conv1d, AdaptiveAvgPool1d, (2,)), 34 | make_simple_pooling_setting((3, 3, 11, 11), Conv2d, AdaptiveAvgPool2d, (2,)), 35 | make_simple_pooling_setting((3, 3, 7, 7, 7), Conv3d, AdaptiveAvgPool3d, (2,)), 36 | ] 37 | 38 | 39 | DiagHESSIAN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS 40 | -------------------------------------------------------------------------------- /test/extensions/secondorder/diag_hessian/test_diag_hessian.py: -------------------------------------------------------------------------------- 1 | from test.automated_test import check_sizes_and_values 2 | from test.extensions.implementation.autograd import AutogradExtensions 3 | from test.extensions.implementation.backpack import BackpackExtensions 4 | from test.extensions.problem import make_test_problems 5 | from test.extensions.secondorder.diag_hessian.diagh_settings import DiagHESSIAN_SETTINGS 6 | 7 | import pytest 8 | 9 | PROBLEMS = make_test_problems(DiagHESSIAN_SETTINGS) 10 | IDS = [problem.make_id() for problem in PROBLEMS] 11 | 12 | 13 | @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) 14 | def test_diag_h(problem): 15 | """Test Diagonal of Hessian 16 | 17 | Args: 18 | problem (ExtensionsTestProblem): Problem for extension test. 19 | """ 20 | problem.set_up() 21 | 22 | backpack_res = BackpackExtensions(problem).diag_h() 23 | autograd_res = AutogradExtensions(problem).diag_h() 24 | 25 | check_sizes_and_values(autograd_res, backpack_res) 26 | problem.tear_down() 27 | 28 | 29 | @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) 30 | def test_diag_h_batch(problem): 31 | """Test Diagonal of Hessian 32 | 33 | Args: 34 | problem (ExtensionsTestProblem): Problem for extension test. 
35 | """ 36 | problem.set_up() 37 | 38 | backpack_res = BackpackExtensions(problem).diag_h_batch() 39 | autograd_res = AutogradExtensions(problem).diag_h_batch() 40 | 41 | check_sizes_and_values(autograd_res, backpack_res) 42 | problem.tear_down() 43 | -------------------------------------------------------------------------------- /test/extensions/secondorder/hbp/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``backpack.extensions.secondorder.hbp`` (Kronecker curvatures).""" 2 | -------------------------------------------------------------------------------- /test/extensions/secondorder/hbp/kflr_settings.py: -------------------------------------------------------------------------------- 1 | """Define test cases for KFLR.""" 2 | 3 | from test.extensions.secondorder.hbp.kfac_settings import ( 4 | _BATCH_SIZE_1_NO_BRANCHING_SETTINGS, 5 | ) 6 | from test.extensions.secondorder.secondorder_settings import ( 7 | GROUP_CONV_SETTINGS, 8 | LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, 9 | ) 10 | 11 | SHARED_NOT_SUPPORTED_SETTINGS = ( 12 | GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS 13 | ) 14 | LOCAL_NOT_SUPPORTED_SETTINGS = [] 15 | 16 | NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS 17 | 18 | BATCH_SIZE_1_SETTINGS = _BATCH_SIZE_1_NO_BRANCHING_SETTINGS 19 | -------------------------------------------------------------------------------- /test/extensions/secondorder/hbp/kfra_settings.py: -------------------------------------------------------------------------------- 1 | """Define test cases for KFRA.""" 2 | 3 | from test.extensions.secondorder.hbp.kfac_settings import ( 4 | _BATCH_SIZE_1_NO_BRANCHING_SETTINGS, 5 | ) 6 | from test.extensions.secondorder.secondorder_settings import ( 7 | GROUP_CONV_SETTINGS, 8 | LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS, 9 | ) 10 | 11 | SHARED_NOT_SUPPORTED_SETTINGS = ( 12 | GROUP_CONV_SETTINGS + LINEAR_ADDITIONAL_DIMENSIONS_SETTINGS 13 | ) 14 | LOCAL_NOT_SUPPORTED_SETTINGS = [] 15 | 16 | NOT_SUPPORTED_SETTINGS = SHARED_NOT_SUPPORTED_SETTINGS + LOCAL_NOT_SUPPORTED_SETTINGS 17 | 18 | BATCH_SIZE_1_SETTINGS = _BATCH_SIZE_1_NO_BRANCHING_SETTINGS 19 | -------------------------------------------------------------------------------- /test/extensions/secondorder/sqrt_ggn/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains tests of ``backpack.extensions.secondorder.sqrt_ggn``.""" 2 | -------------------------------------------------------------------------------- /test/extensions/test_backprop_extension.py: -------------------------------------------------------------------------------- 1 | """Test custom extensions for backprop_extension.""" 2 | 3 | import pytest 4 | from torch.nn import Linear, Module 5 | 6 | from backpack.extensions import BatchGrad, Variance 7 | from backpack.extensions.firstorder.base import FirstOrderModuleExtension 8 | 9 | 10 | def test_set_custom_extension(): 11 | """Test the method set_custom_extension of BackpropExtension.""" 12 | 13 | class _A(Module): 14 | pass 15 | 16 | class _ABatchGrad(FirstOrderModuleExtension): 17 | pass 18 | 19 | class _AVariance(FirstOrderModuleExtension): 20 | pass 21 | 22 | class _MyLinearBatchGrad(FirstOrderModuleExtension): 23 | pass 24 | 25 | grad_batch = BatchGrad() 26 | 27 | # Set module extension 28 | grad_batch.set_module_extension(_A, _ABatchGrad()) 29 | 30 | # setting again should raise a ValueError 31 | with pytest.raises(ValueError): 32 | 
grad_batch.set_module_extension(_A, _ABatchGrad()) 33 | 34 | # setting again with overwrite 35 | grad_batch.set_module_extension(_A, _ABatchGrad(), overwrite=True) 36 | 37 | # in a different extension, set another extension for the same module 38 | variance = Variance() 39 | variance.set_module_extension(_A, _AVariance()) 40 | 41 | # set an extension for an already existing extension 42 | with pytest.raises(ValueError): 43 | grad_batch.set_module_extension(Linear, _MyLinearBatchGrad()) 44 | 45 | grad_batch.set_module_extension(Linear, _MyLinearBatchGrad(), overwrite=True) 46 | -------------------------------------------------------------------------------- /test/extensions/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for testing BackPACK's extensions.""" 2 | 3 | from test.extensions.problem import ExtensionsTestProblem 4 | from typing import List, Union 5 | 6 | from pytest import skip 7 | 8 | 9 | def skip_if_subsampling_conflict( 10 | problem: ExtensionsTestProblem, subsampling: Union[List[int], None] 11 | ) -> None: 12 | """Skip if some samples in subsampling are not contained in input. 13 | 14 | Args: 15 | problem: Test case. 16 | subsampling: Indices of active samples. 17 | """ 18 | N = problem.input.shape[0] 19 | enough_samples = subsampling is None or N > max(subsampling) 20 | if not enough_samples: 21 | skip(f"Not enough samples: N={N}, subsampling={subsampling}") 22 | -------------------------------------------------------------------------------- /test/hessianfree/__init__.py: -------------------------------------------------------------------------------- 1 | """Test matrix-free multiplication methods.""" 2 | -------------------------------------------------------------------------------- /test/hessianfree/test_ggnvp.py: -------------------------------------------------------------------------------- 1 | """Test multiplication with the GGN.""" 2 | 3 | from torch import zeros, zeros_like 4 | from torch.autograd import grad 5 | from torch.nn import Linear 6 | 7 | from backpack.hessianfree.ggnvp import ggn_vector_product 8 | 9 | 10 | def test_ggnvp_no_explicit_dependency(): 11 | """Test GGN-vector-product when the graph is independent of a parameter.""" 12 | x = zeros(1, requires_grad=True) 13 | f = Linear(1, 1) 14 | 15 | y = f(x) 16 | # does not depend on the linear layer's bias 17 | (dy_dx,) = grad(y, x, create_graph=True) 18 | loss = (dy_dx**2).sum() 19 | 20 | # multiply the GGN onto a vector 21 | v = [zeros_like(p) for p in f.parameters()] 22 | ggn_vector_product(loss, dy_dx, f, v) 23 | -------------------------------------------------------------------------------- /test/implementation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/implementation/__init__.py -------------------------------------------------------------------------------- /test/implementation/implementation.py: -------------------------------------------------------------------------------- 1 | class Implementation: 2 | def __init__(self, test_problem, device=None): 3 | self.problem = test_problem 4 | self.model = self.problem.model 5 | self.N = self.problem.N 6 | if device is not None: 7 | self.problem.to(device) 8 | self.device = device 9 | else: 10 | self.device = self.problem.device 11 | 12 | def to(self, device): 13 | self.model.to(device) 14 | return self 15 | 16 | def loss(self, b=None): 17 | 
return self.problem.loss(b) 18 | 19 | def clear(self): 20 | self.problem.clear() 21 | 22 | def gradient(self): 23 | raise NotImplementedError 24 | 25 | def batch_gradients(self): 26 | raise NotImplementedError 27 | 28 | def batch_l2(self): 29 | raise NotImplementedError 30 | 31 | def variance(self): 32 | raise NotImplementedError 33 | 34 | def sgs(self): 35 | raise NotImplementedError 36 | 37 | def diag_ggn(self): 38 | raise NotImplementedError 39 | 40 | def diag_h(self): 41 | raise NotImplementedError 42 | 43 | def hmp(self, mat_list): 44 | raise NotImplementedError 45 | -------------------------------------------------------------------------------- /test/layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | LINEARS = { 4 | "Linear": nn.Linear, 5 | } 6 | 7 | ACTIVATIONS = { 8 | "ReLU": nn.ReLU, 9 | "Sigmoid": nn.Sigmoid, 10 | "Tanh": nn.Tanh, 11 | } 12 | 13 | CONVS = { 14 | "Conv2d": nn.Conv2d, 15 | } 16 | 17 | PADDINGS = { 18 | "ZeroPad2d": nn.ZeroPad2d, 19 | } 20 | 21 | POOLINGS = { 22 | "MaxPool2d": nn.MaxPool2d, 23 | "AvgPool2d": nn.AvgPool2d, 24 | } 25 | 26 | BN = { 27 | "BatchNorm1d": nn.BatchNorm1d, 28 | } 29 | -------------------------------------------------------------------------------- /test/layers_test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/backpack/8aad99d50ce9d6cb8b4924535f262a99eab22cad/test/layers_test.py -------------------------------------------------------------------------------- /test/readme.md: -------------------------------------------------------------------------------- 1 | # Testing 2 | Automated testing based on [`pytest`](https://docs.pytest.org/en/latest/). 3 | Install with `pip install pytest` and run tests with `pytest` from this directory. 4 | 5 | Useful options: 6 | ``` 7 | -v verbose output 8 | -k text select tests containing text in their name 9 | -x stop if a test fails 10 | --tb=no disable trace output 11 | --help 12 | ``` 13 | 14 | ## Optional tests 15 | Uses [`pytest-optional-tests`](https://pypi.org/project/pytest-optional-tests) for 16 | optional tests. Install with `pip install pytest-optional-tests`. 17 | 18 | Optional test categories are defined in `pytest.ini` 19 | and tests are marked with `@pytest.mark.OPTIONAL_TEST_CATEGORY` (e.g. `@pytest.mark.montecarlo`). 20 | 21 | To run the optional tests, use 22 | `pytest --run-optional-tests=OPTIONAL_TEST_CATEGORY`. 23 | 24 | ## Run all tests for BackPACK 25 | In working directory `test/`, run 26 | ```bash 27 | pytest -vx --run-optional-tests=montecarlo 
28 | ``` 29 | -------------------------------------------------------------------------------- /test/test_batch_first.py: -------------------------------------------------------------------------------- 1 | """Tests whether the batch axis is always first.""" 2 | 3 | from pytest import raises 4 | 5 | from backpack.custom_module.permute import Permute 6 | 7 | 8 | def test_permute_batch_axis() -> None: 9 | """Verify that a ValueError is raised for permutations that move the batch axis.""" 10 | Permute(0, 1, 2) 11 | Permute(0, 2, 1) 12 | Permute(0, 2, 3, 1) 13 | with raises(ValueError): 14 | Permute(1, 0, 2) 15 | with raises(ValueError): 16 | Permute(2, 0, 1) 17 | 18 | Permute(1, 2, init_transpose=True) 19 | Permute(3, 1, init_transpose=True) 20 | Permute(2, 1, init_transpose=True) 21 | with raises(ValueError): 22 | Permute(0, 1, init_transpose=True) 23 | with raises(ValueError): 24 | Permute(1, 0, init_transpose=True) 25 | with raises(ValueError): 26 | Permute(2, 0, init_transpose=True) 27 | -------------------------------------------------------------------------------- /test/test_problems_activations.py: -------------------------------------------------------------------------------- 1 | from .layers import ACTIVATIONS, LINEARS 2 | from .networks import single_linear_layer, two_linear_layers 3 | from .problems import make_classification_problem, make_regression_problem 4 | 5 | TEST_SETTINGS = { 6 | "in_features": 7, 7 | "out_features": 3, 8 | "out_features2": 3, 9 | "bias": True, 10 | "batch": 5, 11 | "rtol": 1e-5, 12 | "atol": 1e-5, 13 | } 14 | INPUT_SHAPE = (TEST_SETTINGS["batch"], TEST_SETTINGS["in_features"]) 15 | 16 | TEST_PROBLEMS = {} 17 | 18 | for act_name, act_cls in ACTIVATIONS.items(): 19 | for lin_name, lin_cls in LINEARS.items(): 20 | TEST_PROBLEMS["{}{}-regression".format(lin_name, act_name)] = ( 21 | make_regression_problem( 22 | INPUT_SHAPE, 23 | single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), 24 | ) 25 | ) 26 | 27 | TEST_PROBLEMS["{}{}-classification".format(lin_name, act_name)] = ( 28 | make_classification_problem( 29 | INPUT_SHAPE, 30 | single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), 31 | ) 32 | ) 33 | 34 | TEST_PROBLEMS["{}{}-2layer-classification".format(lin_name, act_name)] = ( 35 | make_classification_problem( 36 | INPUT_SHAPE, 37 | two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), 38 | ) 39 | ) 40 | -------------------------------------------------------------------------------- /test/test_problems_linear.py: -------------------------------------------------------------------------------- 1 | from .layers import LINEARS 2 | from .networks import single_linear_layer, two_linear_layers 3 | from .problems import make_classification_problem, make_regression_problem 4 | 5 | TEST_SETTINGS = { 6 | "in_features": 7, 7 | "out_features": 3, 8 | "out_features2": 3, 9 | "bias": True, 10 | "batch": 5, 11 | "rtol": 1e-5, 12 | "atol": 1e-5, 13 | } 14 | INPUT_SHAPE = (TEST_SETTINGS["batch"], TEST_SETTINGS["in_features"]) 15 | 16 | TEST_PROBLEMS = {} 17 | 18 | for lin_name, lin_cls in LINEARS.items(): 19 | TEST_PROBLEMS["{}-regression".format(lin_name)] = make_regression_problem( 20 | INPUT_SHAPE, single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) 21 | ) 22 | 23 | TEST_PROBLEMS["{}-classification".format(lin_name)] = make_classification_problem( 24 | INPUT_SHAPE, single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) 25 | ) 26 | 27 | TEST_PROBLEMS["{}-2layer-classification".format(lin_name)] = ( 28 | 
make_classification_problem( 29 | INPUT_SHAPE, two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) 30 | ) 31 | ) 32 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Helper functions for tests.""" 2 | 3 | from typing import List 4 | 5 | 6 | def chunk_sizes(total_size: int, num_chunks: int) -> List[int]: 7 | """Return a list containing the sizes of chunks. 8 | 9 | Args: 10 | total_size: Total computation work. 11 | num_chunks: Desired number of chunks; the result may contain fewer chunks (if ``total_size < num_chunks``) or one more (if the work does not split evenly). 12 | 13 | Returns: 14 | Chunk sizes that sum to ``total_size``. 15 | """ 16 | chunk_size = max(total_size // num_chunks, 1) 17 | 18 | if chunk_size == 1: 19 | sizes = total_size * [chunk_size] 20 | else: 21 | equal, rest = divmod(total_size, chunk_size) 22 | sizes = equal * [chunk_size] 23 | 24 | if rest != 0: 25 | sizes.append(rest) 26 | 27 | return sizes 28 | -------------------------------------------------------------------------------- /test/utils/evaluation_mode.py: -------------------------------------------------------------------------------- 1 | """Tools for initializing in evaluation mode, especially BatchNorm.""" 2 | 3 | from typing import Union 4 | 5 | from torch import rand_like 6 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d, Module 7 | 8 | 9 | def initialize_training_false_recursive(module: Module) -> Module: 10 | """Initializes a module recursively in evaluation mode. 11 | 12 | Args: 13 | module: the module to initialize 14 | 15 | Returns: 16 | initialized module in evaluation mode 17 | """ 18 | if isinstance(module, (BatchNorm1d, BatchNorm2d, BatchNorm3d)): 19 | initialize_batch_norm_eval(module) 20 | else: 21 | for module_child in module.children(): 22 | initialize_training_false_recursive(module_child) 23 | return module.train(False) 24 | 25 | 26 | def initialize_batch_norm_eval( 27 | module: Union[BatchNorm1d, BatchNorm2d, BatchNorm3d] 28 | ) -> Union[BatchNorm1d, BatchNorm2d, BatchNorm3d]: 29 | """Initializes a BatchNorm module in evaluation mode. 30 | 31 | Args: 32 | module: BatchNorm module 33 | 34 | Returns: 35 | the initialized BatchNorm module in evaluation mode 36 | """ 37 | module.running_mean = rand_like(module.running_mean) 38 | module.running_var = rand_like(module.running_var) 39 | module.weight.data = rand_like(module.weight) 40 | module.bias.data = rand_like(module.bias) 41 | return module.train(False) 42 | -------------------------------------------------------------------------------- /test/utils/skip_extension_test.py: -------------------------------------------------------------------------------- 1 | """Contains skip conditions for BackPACK's extension tests.""" 2 | 3 | from test.extensions.problem import ExtensionsTestProblem 4 | 5 | from pytest import skip 6 | from torch.nn import BCEWithLogitsLoss 7 | 8 | 9 | def skip_BCEWithLogitsLoss_non_binary_labels(problem: ExtensionsTestProblem) -> None: 10 | """Skip if the case uses BCEWithLogitsLoss as loss function with non-binary labels. 11 | 12 | Args: 13 | problem: Extension test case. 14 | """ 15 | if isinstance(problem.loss_function, BCEWithLogitsLoss) and any( 16 | y not in [0, 1] for y in problem.target.flatten() 17 | ): 18 | skip("Skipping BCEWithLogitsLoss with non-binary labels") 19 | 20 | 21 | def skip_BCEWithLogitsLoss(problem: ExtensionsTestProblem) -> None: 22 | """Skip if the case uses BCEWithLogitsLoss as loss function. 
23 | 24 | Args: 25 | problem: Extension test case. 26 | """ 27 | if isinstance(problem.loss_function, BCEWithLogitsLoss): 28 | skip("Skipping BCEWithLogitsLoss") 29 | -------------------------------------------------------------------------------- /test/utils/test_subsampling.py: -------------------------------------------------------------------------------- 1 | """Contains tests of sub-sampling functionality.""" 2 | 3 | from torch import allclose, manual_seed, rand 4 | 5 | from backpack.utils.subsampling import subsample 6 | 7 | 8 | def test_subsample(): 9 | """Test slicing operations for sub-sampling a tensor's batch axis.""" 10 | manual_seed(0) 11 | tensor = rand(3, 4, 5, 6) 12 | 13 | # leave tensor untouched when `subsampling = None` 14 | assert id(subsample(tensor)) == id(tensor) 15 | assert allclose(subsample(tensor), tensor) 16 | 17 | # slice along correct dimension 18 | idx = [2, 0] 19 | assert allclose(subsample(tensor, dim=0, subsampling=idx), tensor[idx]) 20 | assert allclose(subsample(tensor, dim=1, subsampling=idx), tensor[:, idx]) 21 | assert allclose(subsample(tensor, dim=2, subsampling=idx), tensor[:, :, idx]) 22 | assert allclose(subsample(tensor, dim=3, subsampling=idx), tensor[:, :, :, idx]) 23 | --------------------------------------------------------------------------------
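Editor's note: the following is not a file from the repository, but a minimal usage sketch of the `subsample` helper exercised by `test/utils/test_subsampling.py` above. It assumes only the behavior demonstrated in that test (identity for `subsampling=None`, index selection along `dim`); the tensor shape and index values here are illustrative.

```python
# Hypothetical usage sketch of backpack.utils.subsampling.subsample,
# grounded only in the calls made by test/utils/test_subsampling.py.
from torch import manual_seed, rand

from backpack.utils.subsampling import subsample

manual_seed(0)
features = rand(8, 5)  # batch of 8 samples with 5 features each

# subsampling=None returns the very same tensor object, untouched
assert subsample(features) is features

# keep samples 2 and 0 (in that order) along the batch axis (dim=0)
active = subsample(features, dim=0, subsampling=[2, 0])
assert active.shape == (2, 5)
```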