├── .conda_env.yml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── CI.yml │ ├── Lint.yml │ └── python-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── AUTHORS.md ├── CHANGELOG.md ├── LICENSE.txt ├── README.md ├── black.toml ├── cockpit ├── __init__.py ├── cockpit.py ├── context.py ├── instruments │ ├── __init__.py │ ├── alpha_gauge.py │ ├── cabs_gauge.py │ ├── distance_gauge.py │ ├── early_stopping_gauge.py │ ├── grad_norm_gauge.py │ ├── gradient_tests_gauge.py │ ├── histogram_1d_gauge.py │ ├── histogram_2d_gauge.py │ ├── hyperparameter_gauge.py │ ├── max_ev_gauge.py │ ├── mean_gsnr_gauge.py │ ├── performance_gauge.py │ ├── tic_gauge.py │ ├── trace_gauge.py │ ├── utils_instruments.py │ └── utils_plotting.py ├── plotter.py ├── quantities │ ├── __init__.py │ ├── alpha.py │ ├── bin_adaptation.py │ ├── cabs.py │ ├── distance.py │ ├── early_stopping.py │ ├── grad_hist.py │ ├── grad_norm.py │ ├── hess_max_ev.py │ ├── hess_trace.py │ ├── hooks │ │ ├── __init__.py │ │ ├── base.py │ │ └── cleanup.py │ ├── inner_test.py │ ├── loss.py │ ├── mean_gsnr.py │ ├── norm_test.py │ ├── ortho_test.py │ ├── parameters.py │ ├── quantity.py │ ├── tic.py │ ├── time.py │ ├── update_size.py │ ├── utils_hists.py │ ├── utils_quantities.py │ └── utils_transforms.py └── utils │ ├── __init__.py │ ├── configuration.py │ ├── optim.py │ └── schedules.py ├── docs ├── Makefile ├── requirements_doc.txt └── source │ ├── _static │ ├── 01_basic_fmnist.png │ ├── 02_advanced_fmnist.png │ ├── 03_deepobs.png │ ├── Banner.svg │ ├── LogoSquare.png │ ├── favicon.ico │ ├── instrument_preview_run.json │ ├── instrument_previews │ │ ├── Alpha.png │ │ ├── CABS.png │ │ ├── Distances.png │ │ ├── EarlyStopping.png │ │ ├── GradientNorm.png │ │ ├── GradientTests.png │ │ ├── HessMaxEV.png │ │ ├── HessTrace.png │ │ ├── Hist1d.png │ │ ├── Hist2d.png │ │ ├── Hyperparameters.png │ │ ├── MeanGSNR.png │ │ ├── Performance.png │ │ └── TIC.png │ ├── showcase.gif │ └── stylefile.css │ ├── api │ ├── automod │ │ ├── cockpit.instruments.alpha_gauge.rst │ │ ├── cockpit.instruments.cabs_gauge.rst │ │ ├── cockpit.instruments.distance_gauge.rst │ │ ├── cockpit.instruments.early_stopping_gauge.rst │ │ ├── cockpit.instruments.grad_norm_gauge.rst │ │ ├── cockpit.instruments.gradient_tests_gauge.rst │ │ ├── cockpit.instruments.histogram_1d_gauge.rst │ │ ├── cockpit.instruments.histogram_2d_gauge.rst │ │ ├── cockpit.instruments.hyperparameter_gauge.rst │ │ ├── cockpit.instruments.max_ev_gauge.rst │ │ ├── cockpit.instruments.mean_gsnr_gauge.rst │ │ ├── cockpit.instruments.performance_gauge.rst │ │ ├── cockpit.instruments.tic_gauge.rst │ │ ├── cockpit.instruments.trace_gauge.rst │ │ ├── cockpit.quantities.Alpha.rst │ │ ├── cockpit.quantities.CABS.rst │ │ ├── cockpit.quantities.Distance.rst │ │ ├── cockpit.quantities.EarlyStopping.rst │ │ ├── cockpit.quantities.GradHist1d.rst │ │ ├── cockpit.quantities.GradHist2d.rst │ │ ├── cockpit.quantities.GradNorm.rst │ │ ├── cockpit.quantities.HessMaxEV.rst │ │ ├── cockpit.quantities.HessTrace.rst │ │ ├── cockpit.quantities.InnerTest.rst │ │ ├── cockpit.quantities.Loss.rst │ │ ├── cockpit.quantities.MeanGSNR.rst │ │ ├── cockpit.quantities.NormTest.rst │ │ ├── cockpit.quantities.OrthoTest.rst │ │ ├── cockpit.quantities.Parameters.rst │ │ ├── cockpit.quantities.TICDiag.rst │ │ ├── cockpit.quantities.TICTrace.rst │ │ ├── cockpit.quantities.Time.rst │ │ ├── cockpit.quantities.UpdateSize.rst │ │ ├── cockpit.utils.configuration.configuration.rst │ │ ├── 
cockpit.utils.configuration.quantities_cls_for_configuration.rst │ │ ├── cockpit.utils.schedules.linear.rst │ │ └── cockpit.utils.schedules.logarithmic.rst │ ├── cockpit.rst │ ├── instruments.rst │ ├── plotter.rst │ ├── quantities.rst │ └── utils.rst │ ├── conf.py │ ├── examples │ ├── 01_basic_fmnist.rst │ ├── 02_advanced_fmnist.rst │ └── 03_deepobs.rst │ ├── extract_instrument_previews.py │ ├── index.rst │ ├── introduction │ └── good_to_know.rst │ └── other │ ├── changelog.rst │ ├── contributors.rst │ └── license.rst ├── examples ├── 01_basic_fmnist.py ├── 02_advanced_fmnist.py ├── 03_deepobs.py ├── _utils_deepobs.py └── _utils_examples.py ├── makefile ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── settings.py ├── test_bugs ├── test_issue5.py └── test_issue6.py ├── test_cockpit ├── __init__.py ├── settings.py ├── test_automatic_call_track.py ├── test_backpack_extensions.py └── test_multiple_batch_grad_transforms.py ├── test_examples ├── __init__.py └── test_examples.py ├── test_quantities ├── __init__.py ├── adam_settings.py ├── settings.py ├── test_alpha.py ├── test_bin_adaptation.py ├── test_cabs.py ├── test_early_stopping.py ├── test_grad_hist.py ├── test_hess_max_ev.py ├── test_hess_trace.py ├── test_inner_test.py ├── test_mean_gsnr.py ├── test_norm_test.py ├── test_ortho_test.py ├── test_quantity_integration.py ├── test_tic.py └── utils.py ├── test_utils ├── __init__.py ├── test_configurations.py └── test_schedules.py └── utils ├── __init__.py ├── check.py ├── data.py ├── harness.py ├── models.py ├── problem.py └── rand.py /.conda_env.yml: -------------------------------------------------------------------------------- 1 | name: cockpit 2 | dependencies: 3 | - pip 4 | - python 5 | - pytorch::pytorch 6 | - pytorch::torchvision 7 | - conda-forge::black 8 | - conda-forge::darglint 9 | - conda-forge::flake8 10 | - conda-forge::flake8-bugbear 11 | - conda-forge::flake8-comprehensions 12 | - conda-forge::isort 13 | - conda-forge::palettable 14 | - conda-forge::pydocstyle 15 | - conda-forge::pytest 16 | - conda-forge::pytest-cov 17 | - conda-forge::pytest-benchmark 18 | - conda-forge::m2r2 19 | - conda-forge::sphinx 20 | - conda-forge::sphinx-automodapi 21 | - conda-forge::sphinx_rtd_theme 22 | - conda-forge::matplotlib=3.4.3 23 | - conda-forge::tikzplotlib 24 | - pip: 25 | - memory-profiler 26 | - pre-commit 27 | - graphviz 28 | - sphinx-notfound-page 29 | - git+https://git@github.com/fsschneider/DeepOBS.git@develop#egg=deepobs 30 | - git+https://git@github.com/f-dangel/backobs.git@development#egg=backobs 31 | - -e . 
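# Sketch of typical usage for this environment file (standard conda commands,
# shown as an assumption rather than a project-mandated workflow):
#   conda env create -f .conda_env.yml
#   conda activate cockpit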
32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us fix a bug 4 | title: '' 5 | labels: "\U0001F195 Status: New, \U0001F41B Type: Bug" 6 | assignees: fsschneider 7 | 8 | --- 9 | 10 | [Provide a general description of the issue] 11 | 12 | ## Description 13 | 14 | [Provide more details on the bug itself] 15 | 16 | ## Steps to Reproduce 17 | 18 | [An ordered list of steps to recreate the issue] 19 | 20 | ## Source or Possible Fix 21 | 22 | [If possible, you can describe what you think is the source of the bug, or possibly even describe a fix] 23 | 24 | --- 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: "\U0001F195 Status: New" 6 | assignees: fsschneider 7 | 8 | --- 9 | 10 | **Are you requesting a new feature or an enhancement of an existing one? Please use the labels accordingly.** 11 | 12 | ## Description 13 | 14 | [Describe the feature you would like to be added] 15 | 16 | --- 17 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - '*' 7 | pull_request: 8 | branches: 9 | - development 10 | - master 11 | 12 | 13 | jobs: 14 | pytest: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v1 18 | - name: Set up Python 3.7 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: 3.7 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install pytest 26 | pip install pytest-cov 27 | pip install requests 28 | pip install coveralls 29 | pip install . 
30 | pip install git+https://git@github.com/fsschneider/DeepOBS.git@develop#egg=deepobs 31 | pip install git+https://git@github.com/f-dangel/backobs.git@development#egg=backobs 32 | - name: Run pytest 33 | run: | 34 | make test 35 | - name: Test coveralls 36 | run: coveralls --service=github 37 | env: 38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/Lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - '*' 7 | pull_request: 8 | branches: 9 | - development 10 | - master 11 | 12 | 13 | jobs: 14 | black: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python 3.7 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: 3.7 22 | - name: Install black 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install black 26 | - name: Run black 27 | run: | 28 | make black-check 29 | 30 | flake8: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - uses: actions/checkout@v2 34 | - name: Set up Python 3.7 35 | uses: actions/setup-python@v2 36 | with: 37 | python-version: 3.7 38 | - name: Install flake8 39 | run: | 40 | python -m pip install --upgrade pip 41 | pip install flake8 42 | pip install flake8-bugbear 43 | pip install flake8-comprehensions 44 | - name: Run flake8 45 | run: | 46 | make flake8 47 | 48 | pydocstyle: 49 | runs-on: ubuntu-latest 50 | steps: 51 | - uses: actions/checkout@v2 52 | - name: Set up Python 3.7 53 | uses: actions/setup-python@v2 54 | with: 55 | python-version: 3.7 56 | - name: Install pydocstyle 57 | run: | 58 | python -m pip install --upgrade pip 59 | pip install pydocstyle 60 | - name: Run pydocstyle 61 | run: | 62 | make pydocstyle-check 63 | 64 | isort: 65 | runs-on: ubuntu-latest 66 | steps: 67 | - uses: actions/checkout@v2 68 | - name: Set up Python 3.7 69 | uses: actions/setup-python@v2 70 | with: 71 | python-version: 3.7 72 | - name: Install isort 73 | run: | 74 | python -m pip install --upgrade pip 75 | pip install isort 76 | - name: Run isort 77 | run: | 78 | make isort-check 79 | 80 | darglint: 81 | runs-on: ubuntu-latest 82 | steps: 83 | - uses: actions/checkout@v2 84 | - name: Set up Python 3.7 85 | uses: actions/setup-python@v2 86 | with: 87 | python-version: 3.7 88 | - name: Install dependencies 89 | run: | 90 | python -m pip install --upgrade pip 91 | pip install darglint 92 | - name: Run darglint 93 | run: | 94 | make darglint-check 95 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: "3.x" 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install setuptools wheel twine 24 | - name: Build and publish 25 | env: 26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 27 | TWINE_PASSWORD: ${{ 
secrets.PYPI_PASSWORD }} 28 | run: | 29 | python setup.py sdist bdist_wheel 30 | twine upload dist/* 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | potato.yml 3 | 4 | ## Core latex/pdflatex auxiliary files: 5 | *.aux 6 | *.lof 7 | *.log 8 | *.lot 9 | *.fls 10 | *.out 11 | *.toc 12 | *.fmt 13 | *.fot 14 | *.cb 15 | *.cb2 16 | .*.lb 17 | 18 | ## Intermediate documents: 19 | *.dvi 20 | *.xdv 21 | *-converted-to.* 22 | # these rules might exclude image files for figures etc. 23 | # *.ps 24 | # *.eps 25 | # *.pdf 26 | 27 | ## Generated if empty string is given at "Please type another file name for output:" 28 | .pdf 29 | 30 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 31 | *.bbl 32 | *.bcf 33 | *.blg 34 | *-blx.aux 35 | *-blx.bib 36 | *.run.xml 37 | 38 | ## Build tool auxiliary files: 39 | *.fdb_latexmk 40 | *.synctex 41 | *.synctex(busy) 42 | *.synctex.gz 43 | *.synctex.gz(busy) 44 | *.pdfsync 45 | 46 | ## Build tool directories for auxiliary files 47 | # latexrun 48 | latex.out/ 49 | 50 | ## Auxiliary and intermediate files from other packages: 51 | # algorithms 52 | *.alg 53 | *.loa 54 | 55 | # achemso 56 | acs-*.bib 57 | 58 | # amsthm 59 | *.thm 60 | 61 | # beamer 62 | *.nav 63 | *.pre 64 | *.snm 65 | *.vrb 66 | 67 | # changes 68 | *.soc 69 | 70 | # comment 71 | *.cut 72 | 73 | # cprotect 74 | *.cpt 75 | 76 | # elsarticle (documentclass of Elsevier journals) 77 | *.spl 78 | 79 | # endnotes 80 | *.ent 81 | 82 | # fixme 83 | *.lox 84 | 85 | # feynmf/feynmp 86 | *.mf 87 | *.mp 88 | *.t[1-9] 89 | *.t[1-9][0-9] 90 | *.tfm 91 | 92 | #(r)(e)ledmac/(r)(e)ledpar 93 | *.end 94 | *.?end 95 | *.[1-9] 96 | *.[1-9][0-9] 97 | *.[1-9][0-9][0-9] 98 | *.[1-9]R 99 | *.[1-9][0-9]R 100 | *.[1-9][0-9][0-9]R 101 | *.eledsec[1-9] 102 | *.eledsec[1-9]R 103 | *.eledsec[1-9][0-9] 104 | *.eledsec[1-9][0-9]R 105 | *.eledsec[1-9][0-9][0-9] 106 | *.eledsec[1-9][0-9][0-9]R 107 | 108 | # glossaries 109 | *.acn 110 | *.acr 111 | *.glg 112 | *.glo 113 | *.gls 114 | *.glsdefs 115 | *.lzo 116 | *.lzs 117 | 118 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 
119 | # *.ist 120 | 121 | # gnuplottex 122 | *-gnuplottex-* 123 | 124 | # gregoriotex 125 | *.gaux 126 | *.gtex 127 | 128 | # htlatex 129 | *.4ct 130 | *.4tc 131 | *.idv 132 | *.lg 133 | *.trc 134 | *.xref 135 | 136 | # hyperref 137 | *.brf 138 | 139 | # knitr 140 | *-concordance.tex 141 | # TODO Comment the next line if you want to keep your tikz graphics files 142 | *.tikz 143 | *-tikzDictionary 144 | 145 | # listings 146 | *.lol 147 | 148 | # luatexja-ruby 149 | *.ltjruby 150 | 151 | # makeidx 152 | *.idx 153 | *.ilg 154 | *.ind 155 | 156 | # minitoc 157 | *.maf 158 | *.mlf 159 | *.mlt 160 | *.mtc[0-9]* 161 | *.slf[0-9]* 162 | *.slt[0-9]* 163 | *.stc[0-9]* 164 | 165 | # minted 166 | _minted* 167 | *.pyg 168 | 169 | # morewrites 170 | *.mw 171 | 172 | # nomencl 173 | *.nlg 174 | *.nlo 175 | *.nls 176 | 177 | # pax 178 | *.pax 179 | 180 | # pdfpcnotes 181 | *.pdfpc 182 | 183 | # sagetex 184 | *.sagetex.sage 185 | *.sagetex.py 186 | *.sagetex.scmd 187 | 188 | # scrwfile 189 | *.wrt 190 | 191 | # sympy 192 | *.sout 193 | *.sympy 194 | sympy-plots-for-*.tex/ 195 | 196 | # pdfcomment 197 | *.upa 198 | *.upb 199 | 200 | # pythontex 201 | *.pytxcode 202 | pythontex-files-*/ 203 | 204 | # tcolorbox 205 | *.listing 206 | 207 | # thmtools 208 | *.loe 209 | 210 | # TikZ & PGF 211 | *.dpth 212 | *.md5 213 | *.auxlock 214 | 215 | # todonotes 216 | *.tdo 217 | 218 | # vhistory 219 | *.hst 220 | *.ver 221 | 222 | # easy-todo 223 | *.lod 224 | 225 | # xcolor 226 | *.xcp 227 | 228 | # xmpincl 229 | *.xmpi 230 | 231 | # xindy 232 | *.xdy 233 | 234 | # xypic precompiled matrices and outlines 235 | *.xyc 236 | *.xyd 237 | 238 | # endfloat 239 | *.ttt 240 | *.fff 241 | 242 | # Latexian 243 | TSWLatexianTemp* 244 | 245 | ## Editors: 246 | # WinEdt 247 | *.bak 248 | *.sav 249 | 250 | # Texpad 251 | .texpadtmp 252 | 253 | # LyX 254 | *.lyx~ 255 | 256 | # Kile 257 | *.backup 258 | 259 | # KBibTeX 260 | *~[0-9]* 261 | 262 | # auto folder when using emacs and auctex 263 | ./auto/* 264 | *.el 265 | 266 | # expex forward references with \gathertags 267 | *-tags.tex 268 | 269 | # standalone packages 270 | *.sta 271 | 272 | # Makeindex log files 273 | *.lpz 274 | 275 | 276 | # Python data files 277 | *-ubyte 278 | *.pt 279 | **/__pycache__ 280 | 281 | 282 | # Python IDE files 283 | .idea 284 | 285 | # General 286 | .DS_Store 287 | .AppleDouble 288 | .LSOverride 289 | 290 | ######################## 291 | 292 | **/data 293 | 294 | # pip src 295 | **/src 296 | *.egg-info 297 | 298 | # DeepOBS 299 | *.json 300 | !docs/source/_static/instrument_preview_run.json 301 | **/results 302 | **/data_deepobs 303 | 304 | # Torchvision 305 | **MNIST/**/*.gz 306 | 307 | # Tensorboard 308 | **/*events.out.tfevents.* 309 | 310 | # Test 311 | .coverage 312 | **/logfiles 313 | 314 | # Doc 315 | docs/build/ 316 | 317 | # local emacs variables 318 | !**/.dir-locals.el 319 | 320 | examples/logfiles/ 321 | 322 | .envrc -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | args: [--config=black.toml] 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: "3.7.9" 9 | hooks: 10 | - id: flake8 11 | additional_dependencies: 12 | [ 13 | mccabe, 14 | pycodestyle, 15 | pyflakes, 16 | pep8-naming, 17 | flake8-bugbear, 18 | flake8-comprehensions, 19 | ] 20 | 
-------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # RTD configuration file version 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | builder: html 11 | 12 | # Optionally build your docs in additional formats such as PDF 13 | formats: [] 14 | 15 | # Optionally set the version of Python and requirements required to build your docs 16 | python: 17 | version: 3.7 18 | install: 19 | - method: pip 20 | path: . 21 | - requirements: docs/requirements_doc.txt -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | ## Cockpit Contributors 2 | 3 | Many people have helped with the development of **Cockpit**. 4 | 5 | **Maintainers:** 6 | 7 |
8 | - [Frank Schneider](https://github.com/fsschneider) 9 | - [Felix Dangel](https://github.com/f-dangel) 26 | 27 | --- 28 | 29 | Additional support was offered by Philipp Hennig, Jonathan Wenger, and the entire [ProbNum](https://github.com/probabilistic-numerics/probnum) team and the [Methods of Machine Learning](https://uni-tuebingen.de/en/faculties/faculty-of-science/departments/computer-science/lehrstuehle/methods-of-machine-learning/) group. 30 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## [Unreleased] 8 | 9 | ## [1.0.2] - 2021-10-26 10 | 11 | ### Added 12 | 13 | - Added references to a separate [experiment repository](https://github.com/fsschneider/cockpit-experiments) that publishes the code for all experiments shown in the paper. 14 | 15 | ### Fixed 16 | 17 | - Protects the `batch_grad` field in the case where non-SGD is used together with other quantities that free `batch_grad` for memory performance. [[#5](https://github.com/f-dangel/cockpit/issues/5), [PR](https://github.com/f-dangel/cockpit/pull/18)] 18 | 19 | ## [1.0.1] - 2021-10-13 20 | 21 | From this version on, `cockpit` will be available as `cockpit-for-pytorch` on 22 | PyPI. 23 | 24 | ### Added 25 | - Make library `pip`-installable as `cockpit-for-pytorch` 26 | [[PR](https://github.com/f-dangel/cockpit/pull/17)] 27 | - Require BackPACK main release 28 | [[PR](https://github.com/f-dangel/cockpit/pull/12)] 29 | - Added a `savename` argument to the `CockpitPlotter.plot()` function, which lets you define the name, and now the `savedir` should really only describe the **directory**. [[PR](https://github.com/f-dangel/cockpit/pull/16), Fixes #8] 30 | - Added optional `savefig_kwargs` argument to the `CockpitPlotter.plot()` function that gets passed to the `matplotlib` function `fig.savefig()` to, e.g., specify DPI value or use a different file format (e.g. PDF). [[PR](https://github.com/f-dangel/cockpit/pull/16), Fixes #10] 31 | 32 | ### Internal 33 | - Fix [#6](https://github.com/f-dangel/cockpit/issues/6): Don't execute 34 | extension hook on modules with children 35 | [[PR](https://github.com/f-dangel/cockpit/pull/7)] 36 | 37 | ## [1.0.0] - 2021-04-30 38 | 39 | ### Added 40 | 41 | - First public release version of **Cockpit**. 
42 | 43 | [Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.2...HEAD 44 | [1.0.2]: https://github.com/f-dangel/cockpit/compare/1.0.1...1.0.2 45 | [1.0.1]: https://github.com/f-dangel/cockpit/compare/1.0.0...1.0.1 46 | [1.0.0]: https://github.com/f-dangel/cockpit/releases/tag/1.0.0 47 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Frank Schneider, Felix Dangel & Philipp Hennig 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | <!-- Logo --> 4 | 5 | # **Cockpit** 6 | 7 | **A Practical Debugging Tool for Training Deep Neural Networks** 8 | 9 | A better status screen for deep learning. 10 | 11 | Installation • Docs • Experiments • License • Citation 12 | 
19 | 20 | [![CI](https://github.com/f-dangel/cockpit/actions/workflows/CI.yml/badge.svg)](https://github.com/f-dangel/cockpit/actions/workflows/CI.yml) 21 | [![Lint](https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml/badge.svg)](https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml) 22 | [![Doc](https://img.shields.io/readthedocs/cockpit/latest.svg?logo=read%20the%20docs&logoColor=white&label=Doc)](https://cockpit.readthedocs.io) 23 | [![Coverage](https://coveralls.io/repos/github/f-dangel/cockpit/badge.svg?branch=main&t=piyZHm)](https://coveralls.io/github/f-dangel/cockpit?branch=main) 24 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/f-dangel/cockpit/blob/master/LICENSE) 25 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 26 | [![arXiv](https://img.shields.io/static/v1?logo=arxiv&logoColor=white&label=Preprint&message=2102.06604&color=B31B1B)](https://arxiv.org/abs/2102.06604) 27 | 28 | --- 29 | 30 | ```bash 31 | pip install cockpit-for-pytorch 32 | ``` 33 | 34 | --- 35 | 36 | **Cockpit is a visual and statistical debugger specifically designed for deep learning.** Training a deep neural network is often a pain! Successfully training such a network usually requires either years of intuition or expensive parameter searches involving lots of trial and error. Traditional debuggers provide only limited help: They can find *syntactical errors* but not *training bugs* such as ill-chosen learning rates. **Cockpit** offers a closer, more meaningful look into the training process with multiple well-chosen *instruments*. 37 | 38 | --- 39 | 40 | ![CockpitAnimation](docs/source/_static/showcase.gif) 41 | 42 | 43 | ## Installation 44 | 45 | To install **Cockpit** simply run 46 | 47 | ```bash 48 | pip install cockpit-for-pytorch 49 | ``` 50 | 51 |
52 | **Conda environment:** For convenience, we also provide a conda environment, which can be installed via the conda yml file. It includes all the necessary requirements to build the docs, execute the tests, and run the examples. 53 | 54 | 
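To give a feel for the workflow, here is a condensed sketch of how **Cockpit** hooks into a PyTorch training loop. It follows the spirit of `examples/01_basic_fmnist.py`; treat the toy model and the exact `info` keys as illustrative and consult the [documentation](https://cockpit.readthedocs.io/) for the authoritative API.

```python
import torch
from backpack import extend  # Cockpit builds on BackPACK

from cockpit import Cockpit, CockpitPlotter
from cockpit.utils.configuration import configuration

# toy model and losses; model and mean loss must be extended for BackPACK
model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10)))
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
opt = torch.optim.SGD(model.parameters(), lr=0.01)

cockpit = Cockpit(model.parameters(), quantities=configuration("economy"))
plotter = CockpitPlotter()

inputs, labels = torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))
for global_step in range(5):
    opt.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    losses = individual_loss_fn(outputs, labels)

    # the backward pass runs inside the cockpit context
    with cockpit(
        global_step,
        info={
            "batch_size": inputs.shape[0],
            "individual_losses": losses,
            "loss": loss,
            "optimizer": opt,
        },
    ):
        loss.backward(create_graph=cockpit.create_graph(global_step))
    opt.step()

plotter.plot(cockpit)
```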
55 | 56 | 57 | ## Documentation 58 | 59 | The [documentation](https://cockpit.readthedocs.io/) provides a full tutorial on how to get started using **Cockpit** as well as detailed documentation of its API. 60 | 61 | 62 | ## Experiments 63 | 64 | To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. The code for the experiments can be found in a [separate repository](https://github.com/fsschneider/cockpit-experiments). For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604). 65 | 66 | 67 | ## License 68 | 69 | Distributed under the MIT License. See [`LICENSE`](LICENSE.txt) for more information. 70 | 71 | 72 | ## Citation 73 | 74 | If you use **Cockpit**, please consider citing: 75 | 76 | > [Frank Schneider, Felix Dangel, Philipp Hennig
77 | > **Cockpit: A Practical Debugging Tool for Training Deep Neural Networks**
78 | > *arXiv 2102.06604*](http://arxiv.org/abs/2102.06604) 79 | 80 | ```bibtex 81 | @misc{schneider2021cockpit, 82 | title={{Cockpit: A Practical Debugging Tool for Training Deep Neural Networks}}, 83 | author={Frank Schneider and Felix Dangel and Philipp Hennig}, 84 | year={2021}, 85 | eprint={2102.06604}, 86 | archivePrefix={arXiv}, 87 | primaryClass={cs.LG} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /black.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | target-version = ['py35', 'py36', 'py37'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | ( 7 | /( 8 | \.eggs 9 | | \.git 10 | | \.pytest_cache 11 | | \.benchmarks 12 | | docs_src 13 | | docs 14 | | build 15 | | dist 16 | )/ 17 | ) 18 | ''' 19 | -------------------------------------------------------------------------------- /cockpit/__init__.py: -------------------------------------------------------------------------------- 1 | """Init for Cockpit.""" 2 | from pkg_resources import DistributionNotFound, get_distribution 3 | 4 | from cockpit.cockpit import Cockpit 5 | from cockpit.plotter import CockpitPlotter 6 | 7 | # Extract the version number, for accessing it via __version__ 8 | try: 9 | # Change here if project is renamed and does not equal the package name 10 | dist_name = __name__ 11 | __version__ = get_distribution(dist_name).version 12 | except DistributionNotFound: 13 | __version__ = "unknown" 14 | finally: 15 | del get_distribution, DistributionNotFound 16 | 17 | 18 | __all__ = ["Cockpit", "CockpitPlotter", "__version__", "cockpit.quantities"] 19 | -------------------------------------------------------------------------------- /cockpit/instruments/__init__.py: -------------------------------------------------------------------------------- 1 | """All Instruments for the Cockpit.""" 2 | 3 | from cockpit.instruments.alpha_gauge import alpha_gauge 4 | from cockpit.instruments.cabs_gauge import cabs_gauge 5 | from cockpit.instruments.distance_gauge import distance_gauge 6 | from cockpit.instruments.early_stopping_gauge import early_stopping_gauge 7 | from cockpit.instruments.grad_norm_gauge import grad_norm_gauge 8 | from cockpit.instruments.gradient_tests_gauge import gradient_tests_gauge 9 | from cockpit.instruments.histogram_1d_gauge import histogram_1d_gauge 10 | from cockpit.instruments.histogram_2d_gauge import histogram_2d_gauge 11 | from cockpit.instruments.hyperparameter_gauge import hyperparameter_gauge 12 | from cockpit.instruments.max_ev_gauge import max_ev_gauge 13 | from cockpit.instruments.mean_gsnr_gauge import mean_gsnr_gauge 14 | from cockpit.instruments.performance_gauge import performance_gauge 15 | from cockpit.instruments.tic_gauge import tic_gauge 16 | from cockpit.instruments.trace_gauge import trace_gauge 17 | 18 | __all__ = [ 19 | "alpha_gauge", 20 | "distance_gauge", 21 | "grad_norm_gauge", 22 | "histogram_1d_gauge", 23 | "histogram_2d_gauge", 24 | "gradient_tests_gauge", 25 | "hyperparameter_gauge", 26 | "max_ev_gauge", 27 | "performance_gauge", 28 | "tic_gauge", 29 | "trace_gauge", 30 | "mean_gsnr_gauge", 31 | "early_stopping_gauge", 32 | "cabs_gauge", 33 | ] 34 | -------------------------------------------------------------------------------- /cockpit/instruments/cabs_gauge.py: -------------------------------------------------------------------------------- 1 | """CABS Gauge.""" 2 | 3 | import warnings 4 | 5 | from cockpit.instruments.utils_instruments import 
check_data, create_basic_plot 6 | 7 | 8 | def cabs_gauge(self, fig, gridspec): 9 | """CABS gauge, showing the CABS rule versus iteration. 10 | 11 | The batch size trades off more accurate gradient approximations against longer 12 | computation. The `CABS criterion <https://arxiv.org/abs/1612.05086>`_ describes 13 | the optimal batch size under certain assumptions. 14 | 15 | The instrument shows the suggested batch size (and an exponentially weighted 16 | average) over the course of training, according to 17 | 18 | - `Balles, L., Romero, J., & Hennig, P., 19 | Coupling adaptive batch sizes with learning rates (2017). 20 | <https://arxiv.org/abs/1612.05086>`_ 21 | 22 | **Preview** 23 | 24 | .. image:: ../../_static/instrument_previews/CABS.png 25 | :alt: Preview CABS Gauge 26 | 27 | **Requires** 28 | 29 | This instrument requires data from the :class:`~cockpit.quantities.CABS` 30 | quantity class. 31 | 32 | Args: 33 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 34 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 36 | placed 37 | """ 38 | # Plot CABS vs iteration 39 | title = "CABS" 40 | 41 | # Check if the required data is available, else skip this instrument 42 | requires = ["CABS"] 43 | plot_possible = check_data(self.tracking_data, requires) 44 | if not plot_possible: 45 | if self.debug: 46 | warnings.warn( 47 | "Couldn't get the required data for the " + title + " instrument", 48 | stacklevel=1, 49 | ) 50 | return 51 | 52 | plot_args = { 53 | "x": "iteration", 54 | "y": "CABS", 55 | "data": self.tracking_data, 56 | "x_scale": "symlog" if self.show_log_iter else "linear", 57 | "y_scale": "linear", 58 | "cmap": self.cmap, 59 | "EMA": "y", 60 | "EMA_alpha": self.EMA_alpha, 61 | "EMA_cmap": self.cmap2, 62 | "title": title, 63 | "xlim": "tight", 64 | "ylim": None, 65 | "fontweight": "bold", 66 | "facecolor": self.bg_color_instruments, 67 | } 68 | ax = fig.add_subplot(gridspec) 69 | create_basic_plot(**plot_args, ax=ax) 70 | 
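# For intuition only: under the assumptions of Balles et al. (2017), the CABS
# rule couples batch size to learning rate roughly as
# b* ~ lr * tr(gradient covariance) / loss. A toy sketch with illustrative
# names (the CABS quantity class computes this properly from BackPACK stats):
def _cabs_rule_sketch(lr, grad_variance_trace, loss):
    # suggested batch size per the CABS criterion (sketch)
    return lr * grad_variance_trace / loss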
-------------------------------------------------------------------------------- /cockpit/instruments/distance_gauge.py: -------------------------------------------------------------------------------- 1 | """Distance Gauge.""" 2 | 3 | import warnings 4 | 5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot 6 | from cockpit.quantities.utils_quantities import _root_sum_of_squares 7 | 8 | 9 | def distance_gauge(self, fig, gridspec): 10 | """Distance gauge showing two different quantities related to distance. 11 | 12 | This instrument shows two quantities at once. Firstly, the :math:`L_2`-distance 13 | of the current parameters to their initialization. This describes the total distance 14 | that the optimization trajectory "has traveled so far" and can be seen via the 15 | blue-to-green dots (and the left y-axis). 16 | 17 | Secondly, the update sizes of individual steps are shown via the yellow-to-blue 18 | dots (and the right y-axis). It measures the distance that a single parameter 19 | update covers. 20 | 21 | Both quantities are overlayed with an exponentially weighted average. 22 | 23 | .. image:: ../../_static/instrument_previews/Distances.png 24 | :alt: Preview Distances Gauge 25 | 26 | **Requires** 27 | 28 | The distance instrument requires data from both the 29 | :class:`~cockpit.quantities.UpdateSize` and the 30 | :class:`~cockpit.quantities.Distance` quantity class. 31 | 32 | Args: 33 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 34 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 36 | placed 37 | """ 38 | # Plot distance vs iteration 39 | title = "Distance" 40 | 41 | # Check if the required data is available, else skip this instrument 42 | requires = ["Distance", "UpdateSize"] 43 | plot_possible = check_data(self.tracking_data, requires) 44 | if not plot_possible: 45 | if self.debug: 46 | warnings.warn( 47 | "Couldn't get the required data for the " + title + " instrument", 48 | stacklevel=1, 49 | ) 50 | return 51 | 52 | # Compute 53 | self.tracking_data["Distance_all"] = self.tracking_data.Distance.map( 54 | lambda x: _root_sum_of_squares(x) if type(x) == list else x 55 | ) 56 | self.tracking_data["UpdateSize_all"] = self.tracking_data.UpdateSize.map( 57 | lambda x: _root_sum_of_squares(x) if type(x) == list else x 58 | ) 59 | 60 | plot_args = { 61 | "x": "iteration", 62 | "y": "Distance_all", 63 | "data": self.tracking_data, 64 | "y_scale": "linear", 65 | "x_scale": "symlog" if self.show_log_iter else "linear", 66 | "cmap": self.cmap, 67 | "EMA": "y", 68 | "EMA_alpha": self.EMA_alpha, 69 | "EMA_cmap": self.cmap2, 70 | "title": title, 71 | "xlim": "tight", 72 | "ylim": None, 73 | "fontweight": "bold", 74 | "facecolor": self.bg_color_instruments, 75 | } 76 | ax = fig.add_subplot(gridspec) 77 | create_basic_plot(**plot_args, ax=ax) 78 | 79 | ax2 = ax.twinx() 80 | plot_args = { 81 | "x": "iteration", 82 | "y": "UpdateSize_all", 83 | "data": self.tracking_data, 84 | "y_scale": "linear", 85 | "x_scale": "symlog" if self.show_log_iter else "linear", 86 | "cmap": self.cmap.reversed(), 87 | "EMA": "y", 88 | "EMA_alpha": self.EMA_alpha, 89 | "EMA_cmap": self.cmap2.reversed(), 90 | "xlim": "tight", 91 | "ylim": None, 92 | "marker": ",", 93 | } 94 | create_basic_plot(**plot_args, ax=ax2) 95 | 
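# The "Compute" step above collapses layer-wise values into one scalar. A
# self-contained sketch of that root-sum-of-squares reduction (the real
# helper lives in cockpit.quantities.utils_quantities):
import math


def _root_sum_of_squares_sketch(layerwise_values):
    # combine per-layer L2 distances into the full-parameter L2 distance
    return math.sqrt(sum(value ** 2 for value in layerwise_values))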
-------------------------------------------------------------------------------- /cockpit/instruments/early_stopping_gauge.py: -------------------------------------------------------------------------------- 1 | """Early Stopping Gauge.""" 2 | 3 | import warnings 4 | 5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot 6 | 7 | 8 | def early_stopping_gauge(self, fig, gridspec): 9 | """Early Stopping gauge, showing the LHS of the stopping criterion versus iteration. 10 | 11 | Early stopping of the training process is widely used to prevent poor generalization 12 | due to over-fitting. `Mahsereci et al. (2017) <https://arxiv.org/abs/1703.09580>`_ 13 | proposed an evidence-based stopping criterion based on mini-batch statistics. 14 | This instrument visualizes this criterion versus iteration, overlayed 15 | with an exponentially weighted average. If the stopping criterion becomes 16 | positive, this suggests stopping the training according to 17 | 18 | - `Mahsereci, M., Balles, L., Lassner, C., & Hennig, P., 19 | Early stopping without a validation set (2017). 20 | <https://arxiv.org/abs/1703.09580>`_ 21 | 22 | **Preview** 23 | 24 | .. image:: ../../_static/instrument_previews/EarlyStopping.png 25 | :alt: Preview EarlyStopping Gauge 26 | 27 | **Requires** 28 | 29 | This instrument requires data from the :class:`~cockpit.quantities.EarlyStopping` 30 | quantity class. 31 | 32 | Args: 33 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 34 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 36 | placed 37 | """ 38 | # Plot stopping criterion vs iteration 39 | title = "Early stopping" 40 | 41 | # Check if the required data is available, else skip this instrument 42 | requires = ["EarlyStopping"] 43 | plot_possible = check_data(self.tracking_data, requires) 44 | if not plot_possible: 45 | if self.debug: 46 | warnings.warn( 47 | "Couldn't get the required data for the " + title + " instrument", 48 | stacklevel=1, 49 | ) 50 | return 51 | 52 | plot_args = { 53 | "x": "iteration", 54 | "y": "EarlyStopping", 55 | "data": self.tracking_data, 56 | "x_scale": "symlog" if self.show_log_iter else "linear", 57 | "y_scale": "linear", 58 | "cmap": self.cmap, 59 | "EMA": "y", 60 | "EMA_alpha": self.EMA_alpha, 61 | "EMA_cmap": self.cmap2, 62 | "title": title, 63 | "xlim": "tight", 64 | "ylim": None, 65 | "fontweight": "bold", 66 | "facecolor": self.bg_color_instruments, 67 | } 68 | ax = fig.add_subplot(gridspec) 69 | create_basic_plot(**plot_args, ax=ax) 70 | 
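# One common form of the evidence-based criterion (Mahsereci et al., 2017),
# shown only as a sketch: with per-parameter gradient means g, gradient
# variance estimates s, batch size B, and d parameters, training would stop
# once 1 - (B / d) * sum(g**2 / s) turns positive. All names illustrative;
# the EarlyStopping quantity class computes the tracked value.
import numpy as np


def _eb_criterion_sketch(g, s, batch_size):
    return 1.0 - (batch_size / g.size) * np.sum(g ** 2 / s)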
-------------------------------------------------------------------------------- /cockpit/instruments/grad_norm_gauge.py: -------------------------------------------------------------------------------- 1 | """Gradient Norm Gauge.""" 2 | 3 | import warnings 4 | 5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot 6 | from cockpit.quantities.utils_quantities import _root_sum_of_squares 7 | 8 | 9 | def grad_norm_gauge(self, fig, gridspec): 10 | """Showing the gradient norm versus iteration. 11 | 12 | If the training gets stuck due to a small 13 | :class:`~cockpit.quantities.UpdateSize`, this can be the result of either a badly 14 | chosen learning rate or a flat plateau in the loss landscape. 15 | This instrument shows the gradient norm at each iteration, overlayed with an 16 | exponentially weighted average, and can thus distinguish these two cases. 17 | 18 | **Preview** 19 | 20 | .. image:: ../../_static/instrument_previews/GradientNorm.png 21 | :alt: Preview GradientNorm Gauge 22 | 23 | **Requires** 24 | 25 | The gradient norm instrument requires data from the 26 | :class:`~cockpit.quantities.GradNorm` quantity class. 27 | 28 | Args: 29 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 30 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 31 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 32 | placed 33 | """ 34 | # Plot gradient norm vs iteration 35 | title = "Gradient Norm" 36 | 37 | # Check if the required data is available, else skip this instrument 38 | requires = ["GradNorm"] 39 | plot_possible = check_data(self.tracking_data, requires) 40 | if not plot_possible: 41 | if self.debug: 42 | warnings.warn( 43 | "Couldn't get the required data for the " + title + " instrument", 44 | stacklevel=1, 45 | ) 46 | return 47 | 48 | # Compute 49 | self.tracking_data["GradNorm_all"] = self.tracking_data.GradNorm.map( 50 | lambda x: _root_sum_of_squares(x) if type(x) == list else x 51 | ) 52 | 53 | plot_args = { 54 | "x": "iteration", 55 | "y": "GradNorm_all", 56 | "data": self.tracking_data, 57 | "x_scale": "symlog" if self.show_log_iter else "linear", 58 | "y_scale": "linear", 59 | "cmap": self.cmap, 60 | "EMA": "y", 61 | "EMA_alpha": self.EMA_alpha, 62 | "EMA_cmap": self.cmap2, 63 | "title": title, 64 | "xlim": "tight", 65 | "ylim": None, 66 | "fontweight": "bold", 67 | "facecolor": self.bg_color_instruments, 68 | } 69 | ax = fig.add_subplot(gridspec) 70 | create_basic_plot(**plot_args, ax=ax) 71 | 
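# What "GradNorm_all" boils down to: the overall gradient norm is the
# root-sum-of-squares of the per-parameter-group norms. Equivalently, in
# plain PyTorch (a sketch, not the tracked quantity itself):
import torch


def _overall_grad_norm_sketch(parameters):
    # L2 norm over all parameter gradients after a backward pass
    return torch.sqrt(
        sum((p.grad ** 2).sum() for p in parameters if p.grad is not None)
    )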
31 | """ 32 | # Plot 33 | title = "Gradient Element Histogram" 34 | 35 | # Check if the required data is available, else skip this instrument 36 | requires = ["GradHist1d"] 37 | plot_possible = check_data(self.tracking_data, requires, min_elements=1) 38 | if not plot_possible: 39 | if self.debug: 40 | warnings.warn( 41 | "Couldn't get the required data for the " + title + " instrument", 42 | stacklevel=1, 43 | ) 44 | return 45 | 46 | ax = fig.add_subplot(gridspec) 47 | 48 | plot_args = { 49 | "title": title, 50 | "fontweight": "bold", 51 | "facecolor": self.bg_color_instruments, 52 | "xlabel": "Gradient Element Value", 53 | "ylabel": "Frequency", 54 | "y_scale": y_scale, 55 | } 56 | 57 | vals, mid_points, width = _get_histogram_data(self.tracking_data) 58 | 59 | ax.bar(mid_points, vals, width=width, color=self.primary_color) 60 | 61 | _beautify_plot(ax=ax, **plot_args) 62 | 63 | ax.set_title(title, fontweight="bold", fontsize="large") 64 | 65 | 66 | def _get_histogram_data(tracking_data): 67 | """Returns the histogram data for the plot. 68 | 69 | Currently we return the bins and values of the last iteration tracked before 70 | this plot. 71 | 72 | Args: 73 | tracking_data (pandas.DataFrame): DataFrame holding the tracking data. 74 | 75 | Returns: 76 | list: Bins of the histogram. 77 | list: Mid points of the bins. 78 | list: Width of the bins. 79 | """ 80 | clean_data = tracking_data.GradHist1d.dropna() 81 | last_step_data = clean_data[clean_data.index[-1]] 82 | 83 | vals = last_step_data["hist"] 84 | bins = last_step_data["edges"] 85 | 86 | width = bins[1] - bins[0] 87 | 88 | mid_points = (bins[1:] + bins[:-1]) / 2 89 | 90 | return vals, mid_points, width 91 | -------------------------------------------------------------------------------- /cockpit/instruments/hyperparameter_gauge.py: -------------------------------------------------------------------------------- 1 | """Hyperparameter Gauge.""" 2 | 3 | import warnings 4 | 5 | import seaborn as sns 6 | 7 | from cockpit.instruments.utils_instruments import ( 8 | _add_last_value_to_legend, 9 | _beautify_plot, 10 | check_data, 11 | ) 12 | 13 | 14 | def hyperparameter_gauge(self, fig, gridspec): 15 | """Hyperparameter gauge, currently showing the learning rate over time. 16 | 17 | This instrument visualizes the hyperparameters values over the course of the 18 | training. Currently, it shows the learning rate, the most likely parameter to 19 | be adapted during training. The current learning rate is additionally shown 20 | in the figure's legend. 21 | 22 | **Preview** 23 | 24 | .. image:: ../../_static/instrument_previews/Hyperparameters.png 25 | :alt: Preview Hyperparameter Gauge 26 | 27 | **Requires** 28 | 29 | This instrument requires the learning rate data passed via the 30 | :func:`cockpit.Cockpit.log()` method. 31 | 32 | Args: 33 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 34 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 
-------------------------------------------------------------------------------- /cockpit/instruments/hyperparameter_gauge.py: -------------------------------------------------------------------------------- 1 | """Hyperparameter Gauge.""" 2 | 3 | import warnings 4 | 5 | import seaborn as sns 6 | 7 | from cockpit.instruments.utils_instruments import ( 8 | _add_last_value_to_legend, 9 | _beautify_plot, 10 | check_data, 11 | ) 12 | 13 | 14 | def hyperparameter_gauge(self, fig, gridspec): 15 | """Hyperparameter gauge, currently showing the learning rate over time. 16 | 17 | This instrument visualizes the hyperparameter values over the course of the 18 | training. Currently, it shows the learning rate, the hyperparameter most likely to 19 | be adapted during training. The current learning rate is additionally shown 20 | in the figure's legend. 21 | 22 | **Preview** 23 | 24 | .. image:: ../../_static/instrument_previews/Hyperparameters.png 25 | :alt: Preview Hyperparameter Gauge 26 | 27 | **Requires** 28 | 29 | This instrument requires the learning rate data passed via the 30 | :func:`cockpit.Cockpit.log()` method. 31 | 32 | Args: 33 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 34 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 36 | placed 37 | """ 38 | # Plot hyperparameters vs iteration 39 | title = "Hyperparameters" 40 | 41 | # Check if the required data is available, else skip this instrument 42 | requires = ["iteration", "learning_rate"] 43 | plot_possible = check_data(self.tracking_data, requires) 44 | if not plot_possible: 45 | if self.debug: 46 | warnings.warn( 47 | "Couldn't get the required data for the " + title + " instrument", 48 | stacklevel=1, 49 | ) 50 | return 51 | 52 | ax = fig.add_subplot(gridspec) 53 | 54 | clean_learning_rate = self.tracking_data[["iteration", "learning_rate"]].dropna() 55 | 56 | # Plot Settings 57 | plot_args = { 58 | "x": "iteration", 59 | "y": "learning_rate", 60 | "data": clean_learning_rate, 61 | } 62 | ylabel = plot_args["y"].replace("_", " ").title() 63 | sns.lineplot( 64 | **plot_args, ax=ax, label=ylabel, linewidth=2, color=self.secondary_color 65 | ) 66 | 67 | _beautify_plot( 68 | ax=ax, 69 | xlabel=plot_args["x"], 70 | ylabel=ylabel, 71 | x_scale="symlog" if self.show_log_iter else "linear", 72 | title=title, 73 | xlim="tight", 74 | fontweight="bold", 75 | facecolor=self.bg_color_instruments2, 76 | ) 77 | 78 | ax.legend() 79 | _add_last_value_to_legend(ax) 80 | 
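# The learning rate plotted here is whatever the user logs; with a standard
# PyTorch optimizer it can be read out like this (mirroring CABS.get_lr in
# cockpit/quantities/cabs.py, shown here as a sketch):
def _current_lr_sketch(optimizer):
    lrs = {group["lr"] for group in optimizer.param_groups}
    if len(lrs) != 1:
        raise ValueError(f"Found non-unique learning rates {lrs}")
    return lrs.pop()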
35 | """ 36 | # Plot Trace vs iteration 37 | title = "Max Eigenvalue" 38 | 39 | # Check if the required data is available, else skip this instrument 40 | requires = ["HessMaxEV"] 41 | plot_possible = check_data(self.tracking_data, requires) 42 | if not plot_possible: 43 | if self.debug: 44 | warnings.warn( 45 | "Couldn't get the required data for the " + title + " instrument", 46 | stacklevel=1, 47 | ) 48 | return 49 | 50 | plot_args = { 51 | "x": "iteration", 52 | "y": "HessMaxEV", 53 | "data": self.tracking_data, 54 | "x_scale": "symlog" if self.show_log_iter else "linear", 55 | "y_scale": "log", 56 | "cmap": self.cmap, 57 | "EMA": "y", 58 | "EMA_alpha": self.EMA_alpha, 59 | "EMA_cmap": self.cmap2, 60 | "title": title, 61 | "xlim": "tight", 62 | "ylim": None, 63 | "fontweight": "bold", 64 | "facecolor": self.bg_color_instruments, 65 | } 66 | # part that should be plotted 67 | ax = fig.add_subplot(gridspec) 68 | create_basic_plot(**plot_args, ax=ax) 69 | 70 | ax.yaxis.set_minor_formatter(ticker.FormatStrFormatter("%.2g")) 71 | -------------------------------------------------------------------------------- /cockpit/instruments/mean_gsnr_gauge.py: -------------------------------------------------------------------------------- 1 | """Mean GSNR gauge.""" 2 | 3 | import warnings 4 | 5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot 6 | 7 | 8 | def mean_gsnr_gauge(self, fig, gridspec): 9 | """Mean GSNR gauge, showing the mean GSNR versus iteration. 10 | 11 | The mean GSNR describes the average gradient signal-to-noise-ratio. `Recent 12 | work `_ used this quantity to study the 13 | generalization performances of neural networks, noting "that larger GSNR during 14 | training process leads to better generalization performance. The instrument 15 | shows the mean GSNR versus iteration, overlayed with an exponentially weighted 16 | average. 17 | 18 | **Preview** 19 | 20 | .. image:: ../../_static/instrument_previews/MeanGSNR.png 21 | :alt: Preview MeanGSNR Gauge 22 | 23 | **Requires** 24 | 25 | This instrument requires data from the :class:`~cockpit.quantities.MeanGSNR` 26 | quantity class. 27 | 28 | Args: 29 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 30 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 
-------------------------------------------------------------------------------- /cockpit/instruments/mean_gsnr_gauge.py: -------------------------------------------------------------------------------- 1 | """Mean GSNR gauge.""" 2 | 3 | import warnings 4 | 5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot 6 | 7 | 8 | def mean_gsnr_gauge(self, fig, gridspec): 9 | """Mean GSNR gauge, showing the mean GSNR versus iteration. 10 | 11 | The mean GSNR describes the average gradient signal-to-noise-ratio. `Recent 12 | work <https://arxiv.org/abs/2001.07384>`_ used this quantity to study the 13 | generalization performance of neural networks, noting "that larger GSNR during 14 | training process leads to better generalization performance". The instrument 15 | shows the mean GSNR versus iteration, overlayed with an exponentially weighted 16 | average. 17 | 18 | **Preview** 19 | 20 | .. image:: ../../_static/instrument_previews/MeanGSNR.png 21 | :alt: Preview MeanGSNR Gauge 22 | 23 | **Requires** 24 | 25 | This instrument requires data from the :class:`~cockpit.quantities.MeanGSNR` 26 | quantity class. 27 | 28 | Args: 29 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 30 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 31 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 32 | placed 33 | """ 34 | # Plot mean GSNR vs iteration 35 | title = "Mean GSNR" 36 | 37 | # Check if the required data is available, else skip this instrument 38 | requires = ["MeanGSNR"] 39 | plot_possible = check_data(self.tracking_data, requires) 40 | if not plot_possible: 41 | if self.debug: 42 | warnings.warn( 43 | "Couldn't get the required data for the " + title + " instrument", 44 | stacklevel=1, 45 | ) 46 | return 47 | 48 | plot_args = { 49 | "x": "iteration", 50 | "y": "MeanGSNR", 51 | "data": self.tracking_data, 52 | "x_scale": "symlog" if self.show_log_iter else "linear", 53 | "y_scale": "linear", 54 | "cmap": self.cmap, 55 | "EMA": "y", 56 | "EMA_alpha": self.EMA_alpha, 57 | "EMA_cmap": self.cmap2, 58 | "title": title, 59 | "xlim": "tight", 60 | "ylim": None, 61 | "fontweight": "bold", 62 | "facecolor": self.bg_color_instruments, 63 | } 64 | ax = fig.add_subplot(gridspec) 65 | create_basic_plot(**plot_args, ax=ax) 66 | 
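# The GSNR of a single gradient element is its squared mean divided by its
# variance across the mini-batch; the mean GSNR averages this over all
# parameters. A numpy sketch (rows = per-sample gradients; eps is an
# illustrative stabilizer, not necessarily the one MeanGSNR uses):
import numpy as np


def _mean_gsnr_sketch(per_sample_grads, eps=1e-10):
    mean = per_sample_grads.mean(axis=0)
    var = per_sample_grads.var(axis=0)
    return float(np.mean(mean ** 2 / (var + eps)))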
-------------------------------------------------------------------------------- /cockpit/instruments/performance_gauge.py: -------------------------------------------------------------------------------- 1 | """Performance Gauge.""" 2 | 3 | import warnings 4 | 5 | import seaborn as sns 6 | 7 | from cockpit.instruments.utils_instruments import ( 8 | _add_last_value_to_legend, 9 | check_data, 10 | create_basic_plot, 11 | ) 12 | 13 | 14 | def performance_gauge(self, fig, gridspec): 15 | """Plotting train/valid accuracy and mini-batch loss vs. iteration. 16 | 17 | This instrument visualizes the most popular diagnostic metrics. It 18 | shows the mini-batch loss in each iteration (overlayed with an exponentially 19 | weighted average) as well as accuracies for both the training and the 20 | validation set. The current accuracy numbers are also shown in the legend. 21 | 22 | **Preview** 23 | 24 | .. image:: ../../_static/instrument_previews/Performance.png 25 | :alt: Preview Performance Gauge 26 | 27 | **Requires** 28 | 29 | This instrument visualizes quantities passed via the 30 | :func:`cockpit.Cockpit.log()` method. 31 | 32 | Args: 33 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 34 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 36 | placed 37 | """ 38 | # Plot performance vs iteration 39 | title = "Performance Plot" 40 | 41 | # Check if the required data is available, else skip this instrument 42 | requires = ["iteration", "Loss"] 43 | plot_possible = check_data(self.tracking_data, requires) 44 | if not plot_possible: 45 | if self.debug: 46 | warnings.warn( 47 | "Couldn't get the loss data for the " + title + " instrument", 48 | stacklevel=1, 49 | ) 50 | return 51 | 52 | # Mini-batch train loss 53 | plot_args = { 54 | "x": "iteration", 55 | "y": "Loss", 56 | "data": self.tracking_data, 57 | "EMA": "y", 58 | "EMA_alpha": self.EMA_alpha, 59 | "EMA_cmap": self.cmap2, 60 | "x_scale": "symlog" if self.show_log_iter else "linear", 61 | "y_scale": "linear", 62 | "cmap": self.cmap, 63 | "title": title, 64 | "xlim": "tight", 65 | "ylim": None, 66 | "fontweight": "bold", 67 | "facecolor": self.bg_color_instruments2, 68 | } 69 | ax = fig.add_subplot(gridspec) 70 | create_basic_plot(**plot_args, ax=ax) 71 | 72 | requires = ["iteration", "train_accuracy", "valid_accuracy"] 73 | plot_possible = check_data(self.tracking_data, requires) 74 | if not plot_possible: 75 | if self.debug: 76 | warnings.warn( 77 | "Couldn't get the accuracy data for the " + title + " instrument", 78 | stacklevel=1, 79 | ) 80 | return 81 | else: 82 | clean_accuracies = self.tracking_data[ 83 | ["iteration", "train_accuracy", "valid_accuracy"] 84 | ].dropna() 85 | 86 | # Train Accuracy 87 | plot_args = { 88 | "x": "iteration", 89 | "y": "train_accuracy", 90 | "data": clean_accuracies, 91 | } 92 | ax2 = ax.twinx() 93 | sns.lineplot( 94 | **plot_args, 95 | ax=ax2, 96 | label=plot_args["y"].title().replace("_", " "), 97 | linewidth=2, 98 | color=self.primary_color, 99 | ) 100 | 101 | # Valid Accuracy 102 | plot_args = { 103 | "x": "iteration", 104 | "y": "valid_accuracy", 105 | "data": clean_accuracies, 106 | } 107 | sns.lineplot( 108 | **plot_args, 109 | ax=ax2, 110 | label=plot_args["y"].title().replace("_", " "), 111 | linewidth=2, 112 | color=self.secondary_color, 113 | ) 114 | 115 | # Customization 116 | ax2.set_ylim([0, 1]) 117 | ax2.set_ylabel("Accuracy") 118 | _add_last_value_to_legend(ax2, percentage=True) 119 | 
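# Every gauge overlays raw values with an exponentially weighted average
# (the EMA_alpha entries above). A sketch of the underlying recurrence;
# whether cockpit weights the old or the new value with alpha is an
# implementation detail not shown here:
def _ema_sketch(values, alpha):
    averaged, prev = [], None
    for value in values:
        prev = value if prev is None else alpha * prev + (1 - alpha) * value
        averaged.append(prev)
    return averaged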
-------------------------------------------------------------------------------- /cockpit/instruments/tic_gauge.py: -------------------------------------------------------------------------------- 1 | """TIC Gauge.""" 2 | 3 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot 4 | 5 | 6 | def tic_gauge(self, fig, gridspec): 7 | """TIC gauge, showing the TIC versus iteration. 8 | 9 | The TIC (either approximated via traces or using a diagonal approximation) 10 | describes the relation between the curvature and the gradient noise. `Recent 11 | work <https://arxiv.org/abs/1906.07774>`_ suggested that *at a local minimum*, 12 | this quantity can estimate the generalization gap. This instrument shows the 13 | TIC versus iteration, overlayed with an exponentially weighted average. 14 | 15 | **Preview** 16 | 17 | .. image:: ../../_static/instrument_previews/TIC.png 18 | :alt: Preview TIC Gauge 19 | 20 | **Requires** 21 | 22 | This instrument requires data from the :class:`~cockpit.quantities.TICDiag` 23 | or :class:`~cockpit.quantities.TICTrace` quantity class. 24 | 25 | Args: 26 | self (CockpitPlotter): The cockpit plotter requesting this instrument. 27 | fig (matplotlib.figure.Figure): Figure of the Cockpit. 28 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 29 | placed 30 | """ 31 | # Plot TIC vs iteration 32 | title = "TIC" 33 | 34 | if check_data(self.tracking_data, ["TICDiag"]): 35 | plot_args = { 36 | "x": "iteration", 37 | "y": "TICDiag", 38 | "data": self.tracking_data, 39 | "x_scale": "symlog" if self.show_log_iter else "linear", 40 | "y_scale": "linear", 41 | "cmap": self.cmap, 42 | "EMA": "y", 43 | "EMA_alpha": self.EMA_alpha, 44 | "EMA_cmap": self.cmap2, 45 | "title": title, 46 | "xlim": "tight", 47 | "ylim": None, 48 | "fontweight": "bold", 49 | "facecolor": self.bg_color_instruments, 50 | } 51 | ax = fig.add_subplot(gridspec) 52 | create_basic_plot(**plot_args, ax=ax) 53 | 54 | if check_data(self.tracking_data, ["TICTrace"]): 55 | if "ax" in locals(): 56 | ax2 = ax.twinx() 57 | else: 58 | ax2 = fig.add_subplot(gridspec) 59 | plot_args = { 60 | "x": "iteration", 61 | "y": "TICTrace", 62 | "data": self.tracking_data, 63 | "x_scale": "symlog" if self.show_log_iter else "linear", 64 | "y_scale": "linear", 65 | "cmap": self.cmap_backup, 66 | "EMA": "y", 67 | "EMA_alpha": self.EMA_alpha, 68 | "EMA_cmap": self.cmap2_backup, 69 | "title": title, 70 | "xlim": "tight", 71 | "ylim": None, 72 | "fontweight": "bold", 73 | "facecolor": self.bg_color_instruments, 74 | } 75 | create_basic_plot(**plot_args, ax=ax2) 76 | 
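# For intuition only: the TIC compares gradient noise C with curvature H via
# tr(C H^{-1}). TICDiag's namesake diagonal approximation turns this into an
# elementwise ratio; a numpy sketch with illustrative damping (the tracked
# quantity is computed by the TICDiag/TICTrace classes):
import numpy as np


def _tic_diag_sketch(hess_diag, grad_second_moment_diag, damping=1e-8):
    return float(np.sum(grad_second_moment_diag / (hess_diag + damping)))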
30 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be 31 | placed 32 | """ 33 | # Plot Trace vs iteration 34 | title = "Trace" 35 | 36 | # Check if the required data is available, else skip this instrument 37 | requires = ["HessTrace"] 38 | plot_possible = check_data(self.tracking_data, requires) 39 | if not plot_possible: 40 | if self.debug: 41 | warnings.warn( 42 | "Couldn't get the required data for the " + title + " instrument", 43 | stacklevel=1, 44 | ) 45 | return 46 | 47 | # Compute the total trace as the sum of the layer-wise traces 48 | self.tracking_data["HessTrace_all"] = self.tracking_data.HessTrace.map( 49 | lambda x: sum(x) if isinstance(x, list) else x 50 | ) 51 | 52 | plot_args = { 53 | "x": "iteration", 54 | "y": "HessTrace_all", 55 | "data": self.tracking_data, 56 | "x_scale": "symlog" if self.show_log_iter else "linear", 57 | "y_scale": "linear", 58 | "cmap": self.cmap, 59 | "EMA": "y", 60 | "EMA_alpha": self.EMA_alpha, 61 | "EMA_cmap": self.cmap2, 62 | "title": title, 63 | "xlim": "tight", 64 | "ylim": None, 65 | "fontweight": "bold", 66 | "facecolor": self.bg_color_instruments, 67 | } 68 | ax = fig.add_subplot(gridspec) 69 | create_basic_plot(**plot_args, ax=ax) 70 | -------------------------------------------------------------------------------- /cockpit/instruments/utils_plotting.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the CockpitPlotter.""" 2 | 3 | import numpy as np 4 | from matplotlib.colors import LinearSegmentedColormap 5 | 6 | 7 | def _extract_problem_info(source): 8 | """Split the logpath to identify test problem, data set, etc. 9 | 10 | Args: 11 | source (Cockpit or str): ``Cockpit`` instance, or string containing the 12 | path to a .json log produced with ``Cockpit.write``, where 13 | information will be fetched from. 14 | 15 | Returns: 16 | [dict]: Dictionary of logpath, testproblem, optimizer, etc. 17 | """ 18 | if isinstance(source, str): 19 | # Split logpath if possible 20 | try: 21 | dicty = { 22 | "logpath": source + ".json", 23 | "optimizer": source.split("/")[-3], 24 | "testproblem": source.split("/")[-4], 25 | "dataset": source.split("/")[-4].split("_", 1)[0], 26 | "model": source.split("/")[-4].split("_", 1)[1], 27 | } 28 | except Exception: 29 | dicty = { 30 | "logpath": source + ".json", 31 | "optimizer": "", 32 | "testproblem": "", 33 | "dataset": "", 34 | "model": "", 35 | } 36 | else: 37 | # Source is Cockpit instance 38 | dicty = { 39 | "logpath": "", 40 | "optimizer": source._optimizer_name, 41 | "testproblem": "", 42 | "dataset": "", 43 | "model": "", 44 | } 45 | return dicty 46 | 47 | 48 | def legend(): 49 | """Creates the legend of the whole cockpit, combining the individual instruments.""" 50 | pass 51 | 52 | 53 | def _alpha_cmap(color, ncolors=256): 54 | """Create a color map that goes from transparent to a given color. 55 | 56 | Args: 57 | color (tuple): A matplotlib-compatible color. 58 | ncolors (int, optional): Number of "steps" in the colormap. 59 | Defaults to 256.
60 | 61 | Returns: 62 | [matplotlib.cmap]: A matplotlib colormap 63 | """ 64 | color_array = np.array(ncolors * [list(color)]) 65 | 66 | # change alpha values 67 | color_array[:, -1] = np.linspace(0.0, 1.0, ncolors) 68 | # create a colormap object 69 | cmap = LinearSegmentedColormap.from_list(name="alpha_cmap", colors=color_array) 70 | 71 | return cmap 72 | -------------------------------------------------------------------------------- /cockpit/quantities/__init__.py: -------------------------------------------------------------------------------- 1 | """Quantities tracked during training.""" 2 | 3 | from cockpit.quantities.alpha import Alpha 4 | from cockpit.quantities.cabs import CABS 5 | from cockpit.quantities.distance import Distance 6 | from cockpit.quantities.early_stopping import EarlyStopping 7 | from cockpit.quantities.grad_hist import GradHist1d, GradHist2d 8 | from cockpit.quantities.grad_norm import GradNorm 9 | from cockpit.quantities.hess_max_ev import HessMaxEV 10 | from cockpit.quantities.hess_trace import HessTrace 11 | from cockpit.quantities.inner_test import InnerTest 12 | from cockpit.quantities.loss import Loss 13 | from cockpit.quantities.mean_gsnr import MeanGSNR 14 | from cockpit.quantities.norm_test import NormTest 15 | from cockpit.quantities.ortho_test import OrthoTest 16 | from cockpit.quantities.parameters import Parameters 17 | from cockpit.quantities.tic import TICDiag, TICTrace 18 | from cockpit.quantities.time import Time 19 | from cockpit.quantities.update_size import UpdateSize 20 | 21 | __all__ = [ 22 | "Loss", 23 | "Parameters", 24 | "Distance", 25 | "UpdateSize", 26 | "GradNorm", 27 | "Time", 28 | "Alpha", 29 | "CABS", 30 | "EarlyStopping", 31 | "GradHist1d", 32 | "GradHist2d", 33 | "NormTest", 34 | "InnerTest", 35 | "OrthoTest", 36 | "HessMaxEV", 37 | "HessTrace", 38 | "TICDiag", 39 | "TICTrace", 40 | "MeanGSNR", 41 | ] 42 | -------------------------------------------------------------------------------- /cockpit/quantities/cabs.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the CABS criterion for adaptive batch size.""" 2 | 3 | from backpack.extensions import BatchGrad 4 | 5 | from cockpit.context import get_batch_size, get_optimizer 6 | from cockpit.quantities.quantity import SingleStepQuantity 7 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared 8 | from cockpit.utils.optim import ComputeStep 9 | 10 | 11 | class CABS(SingleStepQuantity): 12 | """CABS Quantity class for the suggested batch size using the CABS criterion. 13 | 14 | CABS uses the current learning rate and variance of the stochastic gradients 15 | to suggest an optimal batch size. 16 | 17 | Only applies to SGD without momentum. 18 | 19 | Note: Proposed in 20 | 21 | - Balles, L., Romero, J., & Hennig, P., 22 | Coupling adaptive batch sizes with learning rates (2017). 23 | """ 24 | 25 | def get_lr(self, optimizer): 26 | """Extract the learning rate. 27 | 28 | Args: 29 | optimizer (torch.optim.Optimizer): A PyTorch optimizer. 30 | 31 | Returns: 32 | float: Learning rate 33 | 34 | Raises: 35 | ValueError: If the learning rate varies over parameter groups. 36 | """ 37 | lrs = {group["lr"] for group in optimizer.param_groups} 38 | 39 | if len(lrs) != 1: 40 | raise ValueError(f"Found non-unique learning rates {lrs}") 41 | 42 | return lrs.pop() 43 | 44 | def extensions(self, global_step): 45 | """Return list of BackPACK extensions required for the computation. 
46 | 47 | Args: 48 | global_step (int): The current iteration number. 49 | 50 | Returns: 51 | list: (Potentially empty) list with required BackPACK quantities. 52 | """ 53 | ext = [] 54 | 55 | if self.should_compute(global_step): 56 | ext.append(BatchGrad()) 57 | 58 | return ext 59 | 60 | def extension_hooks(self, global_step): 61 | """Return list of BackPACK extension hooks required for the computation. 62 | 63 | Args: 64 | global_step (int): The current iteration number. 65 | 66 | Returns: 67 | [callable]: List of required BackPACK extension hooks for the current 68 | iteration. 69 | """ 70 | hooks = [] 71 | 72 | if self.should_compute(global_step): 73 | hooks.append(BatchGradTransformsHook_SumGradSquared()) 74 | 75 | return hooks 76 | 77 | def _compute(self, global_step, params, batch_loss): 78 | """Compute the CABS rule. Return suggested batch size. 79 | 80 | Evaluates Equ. 22 of 81 | 82 | - Balles, L., Romero, J., & Hennig, P., 83 | Coupling adaptive batch sizes with learning rates (2017). 84 | 85 | Args: 86 | global_step (int): The current iteration number. 87 | params ([torch.Tensor]): List of torch.Tensors holding the network's 88 | parameters. 89 | batch_loss (torch.Tensor): Mini-batch loss from current step. 90 | 91 | Returns: 92 | float: Batch size suggested by CABS. 93 | 94 | Raises: 95 | ValueError: If the optimizer differs from SGD with default arguments. 96 | """ 97 | optimizer = get_optimizer(global_step) 98 | if not ComputeStep.is_sgd_default_kwargs(optimizer): 99 | raise ValueError("This criterion only supports zero-momentum SGD.") 100 | 101 | B = get_batch_size(global_step) 102 | lr = self.get_lr(optimizer) 103 | 104 | grad_squared = self._fetch_grad(params, aggregate=True) ** 2 105 | # compensate BackPACK's 1/B scaling 106 | sgs_compensated = ( 107 | B ** 2 108 | * self._fetch_sum_grad_squared_via_batch_grad_transforms( 109 | params, aggregate=True 110 | ) 111 | ) 112 | 113 | return ( 114 | lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss) 115 | ).item() 116 | -------------------------------------------------------------------------------- /cockpit/quantities/distance.py: -------------------------------------------------------------------------------- 1 | """Class for tracking distance from initialization.""" 2 | 3 | from cockpit.quantities.quantity import TwoStepQuantity 4 | 5 | 6 | class Distance(TwoStepQuantity): 7 | """Distance Quantity class tracking distance of the parameters from their init.""" 8 | 9 | CACHE_KEY = "params" 10 | """str: String under which the parameters are cached for computation. 11 | Default: ``'params'``. 12 | """ 13 | INIT_GLOBAL_STEP = 0 14 | """int: Iteration number used as reference. Defaults to ``0``.""" 15 | 16 | def extensions(self, global_step): 17 | """Return list of BackPACK extensions required for the computation. 18 | 19 | Args: 20 | global_step (int): The current iteration number. 21 | 22 | Returns: 23 | list: (Potentially empty) list with required BackPACK quantities. 24 | """ 25 | return [] 26 | 27 | def is_start(self, global_step): 28 | """Return whether current iteration is start point. 29 | 30 | Only the initialization (first iteration) is a start point. 31 | 32 | Args: 33 | global_step (int): The current iteration number. 34 | 35 | Returns: 36 | bool: Whether ``global_step`` is a start point. 37 | """ 38 | return global_step == self.INIT_GLOBAL_STEP 39 | 40 | def is_end(self, global_step): 41 | """Return whether current iteration is end point.
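
        For example (a sketch; ``schedules`` refers to ``cockpit.utils.schedules``)::

            dist = Distance(track_schedule=schedules.linear(interval=10))
            dist.is_end(20)  # -> True: every tenth iteration is an end point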
42 | 43 | Args: 44 | global_step (int): The current iteration number. 45 | 46 | Returns: 47 | bool: Whether ``global_step`` is an end point. 48 | """ 49 | return self._track_schedule(global_step) 50 | 51 | def _compute_start(self, global_step, params, batch_loss): 52 | """Perform computations at start point (store initial parameter values). 53 | 54 | Modifies ``self._cache``. 55 | 56 | Args: 57 | global_step (int): The current iteration number. 58 | params ([torch.Tensor]): List of torch.Tensors holding the network's 59 | parameters. 60 | batch_loss (torch.Tensor): Mini-batch loss from current step. 61 | """ 62 | params_copy = [p.data.clone().detach() for p in params] 63 | 64 | def block_fn(step): 65 | """Block deletion of parameters for all non-negative iterations. 66 | 67 | Args: 68 | step (int): Iteration number. 69 | 70 | Returns: 71 | bool: Whether deletion is blocked in the specified iteration 72 | """ 73 | return step >= self.INIT_GLOBAL_STEP 74 | 75 | self.save_to_cache(global_step, self.CACHE_KEY, params_copy, block_fn) 76 | 77 | def _compute_end(self, global_step, params, batch_loss): 78 | """Compute and return the current distance from initialization. 79 | 80 | Args: 81 | global_step (int): The current iteration number. 82 | params ([torch.Tensor]): List of torch.Tensors holding the network's 83 | parameters. 84 | batch_loss (torch.Tensor): Mini-batch loss from current step. 85 | 86 | Returns: 87 | [float]: Layer-wise L2-distances to initialization. 88 | """ 89 | params_init = self.load_from_cache(self.INIT_GLOBAL_STEP, self.CACHE_KEY) 90 | 91 | distance = [ 92 | (p.data - p_init).norm(2).item() for p, p_init in zip(params, params_init) 93 | ] 94 | 95 | return distance 96 | -------------------------------------------------------------------------------- /cockpit/quantities/early_stopping.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the EB criterion for early stopping.""" 2 | 3 | from backpack.extensions import BatchGrad 4 | 5 | from cockpit.context import get_batch_size, get_optimizer 6 | from cockpit.quantities.quantity import SingleStepQuantity 7 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared 8 | from cockpit.utils.optim import ComputeStep 9 | 10 | 11 | class EarlyStopping(SingleStepQuantity): 12 | """Quantity class for the evidence-based early-stopping criterion. 13 | 14 | This criterion uses local statistics of the gradients to indicate when training 15 | should be stopped. Training should stop once the criterion exceeds zero. 16 | 17 | Note: Proposed in 18 | 19 | - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P., 20 | Early stopping without a validation set (2017). 21 | """ 22 | 23 | def __init__(self, track_schedule, verbose=False, epsilon=1e-5): 24 | """Initialization sets the tracking schedule & creates the output dict. 25 | 26 | Args: 27 | track_schedule (callable): Function that maps the ``global_step`` 28 | to a boolean, which determines if the quantity should be computed. 29 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``. 30 | epsilon (float): Stabilization constant. Defaults to 1e-5. 31 | """ 32 | super().__init__(track_schedule, verbose=verbose) 33 | 34 | self._epsilon = epsilon 35 | 36 | def extensions(self, global_step): 37 | """Return list of BackPACK extensions required for the computation. 38 | 39 | Args: 40 | global_step (int): The current iteration number.
41 | 42 | Returns: 43 | list: (Potentially empty) list with required BackPACK quantities. 44 | """ 45 | ext = [] 46 | 47 | if self.should_compute(global_step): 48 | ext.append(BatchGrad()) 49 | 50 | return ext 51 | 52 | def extension_hooks(self, global_step): 53 | """Return list of BackPACK extension hooks required for the computation. 54 | 55 | Args: 56 | global_step (int): The current iteration number. 57 | 58 | Returns: 59 | [callable]: List of required BackPACK extension hooks for the current 60 | iteration. 61 | """ 62 | hooks = [] 63 | 64 | if self.should_compute(global_step): 65 | hooks.append(BatchGradTransformsHook_SumGradSquared()) 66 | 67 | return hooks 68 | 69 | def _compute(self, global_step, params, batch_loss): 70 | """Compute the EB early stopping criterion. 71 | 72 | Evaluates the left-hand side of Equ. 7 in 73 | 74 | - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P., 75 | Early stopping without a validation set (2017). 76 | 77 | If this value exceeds 0, training should be stopped. 78 | 79 | Args: 80 | global_step (int): The current iteration number. 81 | params ([torch.Tensor]): List of torch.Tensors holding the network's 82 | parameters. 83 | batch_loss (torch.Tensor): Mini-batch loss from current step. 84 | 85 | Returns: 86 | float: Result of the early-stopping criterion. Training should stop 87 | if it is larger than 0. 88 | 89 | Raises: 90 | ValueError: If the used optimizer differs from SGD with default parameters. 91 | """ 92 | if not ComputeStep.is_sgd_default_kwargs(get_optimizer(global_step)): 93 | raise ValueError("This criterion only supports zero-momentum SGD.") 94 | 95 | B = get_batch_size(global_step) 96 | 97 | grad_squared = self._fetch_grad(params, aggregate=True) ** 2 98 | 99 | # compensate BackPACK's 1/B scaling 100 | sgs_compensated = ( 101 | B ** 2 102 | * self._fetch_sum_grad_squared_via_batch_grad_transforms( 103 | params, aggregate=True 104 | ) 105 | ) 106 | 107 | diag_variance = (sgs_compensated - B * grad_squared) / (B - 1) 108 | 109 | snr = grad_squared / (diag_variance + self._epsilon) 110 | 111 | return 1 - B * snr.mean().item() 112 | -------------------------------------------------------------------------------- /cockpit/quantities/grad_norm.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Gradient Norm.""" 2 | 3 | from cockpit.quantities.quantity import ByproductQuantity 4 | 5 | 6 | class GradNorm(ByproductQuantity): 7 | """Quantity class for tracking the norm of the mean gradient.""" 8 | 9 | def _compute(self, global_step, params, batch_loss): 10 | """Evaluate the gradient norm at the current point. 11 | 12 | Args: 13 | global_step (int): The current iteration number. 14 | params ([torch.Tensor]): List of torch.Tensors holding the network's 15 | parameters. 16 | batch_loss (torch.Tensor): Mini-batch loss from current step. 17 | 18 | Returns: 19 | [float]: Layer-wise ℓ₂ norms of the mean gradient.
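
        Example:
            A minimal sketch of the computation (``model`` is a hypothetical
            network, not part of this module)::

                import torch

                model = torch.nn.Linear(10, 2)
                model(torch.rand(8, 10)).sum().backward()

                # layer-wise ℓ₂ norms of the mean gradient, as returned here
                norms = [p.grad.data.norm(2).item() for p in model.parameters()]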
20 | """ 21 | return [p.grad.data.norm(2).item() for p in params] 22 | -------------------------------------------------------------------------------- /cockpit/quantities/hess_trace.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Trace of the Hessian or an approximation thereof.""" 2 | 3 | from backpack import extensions 4 | 5 | from cockpit.quantities.quantity import SingleStepQuantity 6 | 7 | 8 | class HessTrace(SingleStepQuantity): 9 | """Quantitiy Class tracking the trace of the Hessian during training.""" 10 | 11 | extensions_from_str = { 12 | "diag_h": extensions.DiagHessian, 13 | "diag_ggn_exact": extensions.DiagGGNExact, 14 | "diag_ggn_mc": extensions.DiagGGNMC, 15 | } 16 | 17 | def __init__(self, track_schedule, verbose=False, curvature="diag_h"): 18 | """Initialization sets the tracking schedule & creates the output dict. 19 | 20 | Note: 21 | The curvature options ``"diag_h"`` and ``"diag_ggn_exact"`` are more 22 | expensive than ``"diag_ggn_mc"``, but more precise. For a classification 23 | task with ``C`` classes, the former require that ``C`` times more 24 | information be backpropagated through the computation graph. 25 | 26 | Args: 27 | track_schedule (callable): Function that maps the ``global_step`` 28 | to a boolean, which determines if the quantity should be computed. 29 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``. 30 | curvature (string): Which diagonal curvature approximation should be used. 31 | Options are ``"diag_h"``, ``"diag_ggn_exact"``, ``"diag_ggn_mc"``. 32 | """ 33 | super().__init__(track_schedule, verbose=verbose) 34 | 35 | self._curvature = curvature 36 | 37 | def extensions(self, global_step): 38 | """Return list of BackPACK extensions required for the computation. 39 | 40 | Args: 41 | global_step (int): The current iteration number. 42 | 43 | Raises: 44 | KeyError: If curvature string has unknown associated extension. 45 | 46 | Returns: 47 | list: (Potentially empty) list with required BackPACK quantities. 48 | """ 49 | ext = [] 50 | 51 | if self.should_compute(global_step): 52 | try: 53 | ext.append(self.extensions_from_str[self._curvature]()) 54 | except KeyError as e: 55 | available = list(self.extensions_from_str.keys()) 56 | raise KeyError(f"Available: {available}") from e 57 | 58 | return ext 59 | 60 | def _compute(self, global_step, params, batch_loss): 61 | """Evaluate the trace of the Hessian at the current point. 62 | 63 | Args: 64 | global_step (int): The current iteration number. 65 | params ([torch.Tensor]): List of torch.Tensors holding the network's 66 | parameters. 67 | batch_loss (torch.Tensor): Mini-batch loss from current step. 68 | 69 | Returns: 70 | list: Trace of the Hessian at the current point. 71 | """ 72 | return [ 73 | diag_c.sum().item() 74 | for diag_c in self._fetch_diag_curvature( 75 | params, self._curvature, aggregate=False 76 | ) 77 | ] 78 | -------------------------------------------------------------------------------- /cockpit/quantities/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | """BackPACK extension hooks. 2 | 3 | Cockpit leverages extension hooks to compact information from BackPACK buffers 4 | which can then be freed immediately during backpropagation. 
5 | """ 6 | -------------------------------------------------------------------------------- /cockpit/quantities/hooks/cleanup.py: -------------------------------------------------------------------------------- 1 | """Contains hook that deletes BackPACK buffers during backpropagation.""" 2 | 3 | from typing import Set 4 | 5 | from torch import Tensor 6 | 7 | from cockpit.quantities.hooks.base import ParameterExtensionHook 8 | 9 | 10 | class CleanupHook(ParameterExtensionHook): 11 | """Deletes specified BackPACK buffers during backpropagation.""" 12 | 13 | def __init__(self, delete_savefields: Set[str]): 14 | """Store savefields to be deleted in the backward pass. 15 | 16 | Args: 17 | delete_savefields: Name of buffers to delete. 18 | """ 19 | super().__init__() 20 | self._delete_savefields = delete_savefields 21 | 22 | def param_hook(self, param: Tensor): 23 | """Delete BackPACK buffers in parameter. 24 | 25 | Args: 26 | param: Trainable parameter which hosts BackPACK quantities. 27 | """ 28 | for savefield in self._delete_savefields: 29 | if hasattr(param, savefield): 30 | delattr(param, savefield) 31 | -------------------------------------------------------------------------------- /cockpit/quantities/inner_test.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Inner Product Test.""" 2 | 3 | 4 | from backpack.extensions import BatchGrad 5 | 6 | from cockpit.quantities.quantity import SingleStepQuantity 7 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchDotGrad 8 | 9 | 10 | class InnerTest(SingleStepQuantity): 11 | """Quantitiy Class for tracking the result of the inner product test. 12 | 13 | Note: Inner Product test as proposed in 14 | 15 | - Bollapragada, R., Byrd, R., & Nocedal, J., 16 | Adaptive Sampling Strategies for Stochastic Optimization (2017). 17 | https://arxiv.org/abs/1710.11258 18 | """ 19 | 20 | def extensions(self, global_step): 21 | """Return list of BackPACK extensions required for the computation. 22 | 23 | Args: 24 | global_step (int): The current iteration number. 25 | 26 | Returns: 27 | list: (Potentially empty) list with required BackPACK quantities. 28 | """ 29 | ext = [] 30 | 31 | if self.should_compute(global_step): 32 | ext.append(BatchGrad()) 33 | 34 | return ext 35 | 36 | def extension_hooks(self, global_step): 37 | """Return list of BackPACK extension hooks required for the computation. 38 | 39 | Args: 40 | global_step (int): The current iteration number. 41 | 42 | Returns: 43 | [callable]: List of required BackPACK extension hooks for the current 44 | iteration. 45 | """ 46 | hooks = [] 47 | 48 | if self.should_compute(global_step): 49 | hooks.append(BatchGradTransformsHook_BatchDotGrad()) 50 | 51 | return hooks 52 | 53 | def _compute(self, global_step, params, batch_loss): 54 | """Track the practical version of the inner product test. 55 | 56 | Return maximum θ for which the inner product test would pass. 57 | 58 | The inner product test is defined by Equation (2.6) in bollapragada2017adaptive. 59 | 60 | Args: 61 | global_step (int): The current iteration number. 62 | params ([torch.Tensor]): List of torch.Tensors holding the network's 63 | parameters. 64 | batch_loss (torch.Tensor): Mini-batch loss from current step. 65 | 66 | Returns: 67 | float: Maximum θ for which the inner product test would pass. 
68 | """ 69 | batch_dot = self._fetch_batch_dot_via_batch_grad_transforms( 70 | params, aggregate=True 71 | ) 72 | grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True) 73 | batch_size = batch_dot.size(0) 74 | 75 | var_projection = self._compute_projection_variance( 76 | batch_size, batch_dot, grad_l2_squared 77 | ) 78 | 79 | return self._compute_theta_max( 80 | batch_size, var_projection, grad_l2_squared 81 | ).item() 82 | 83 | def _compute_theta_max(self, batch_size, var_projection, grad_l2_squared): 84 | """Return maximum θ for which the inner product test would pass. 85 | 86 | Args: 87 | batch_size (int): Mini-batch size. 88 | var_projection (torch.Tensor): The sample variance of individual 89 | gradient projections on the mini-batch gradient. 90 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient. 91 | 92 | Returns: 93 | [type]: [description] 94 | """ 95 | return (var_projection / batch_size / grad_l2_squared ** 2).sqrt() 96 | 97 | def _compute_projection_variance(self, batch_size, batch_dot, grad_l2_squared): 98 | """Compute sample variance of individual gradient projections onto the gradient. 99 | 100 | The sample variance of projections is given by Equation (line after 2.6) in 101 | bollapragada2017adaptive (https://arxiv.org/pdf/1710.11258.pdf) 102 | 103 | Args: 104 | batch_size (int): Mini-batch size. 105 | batch_dot (torch.Tensor): Individual gradient pairwise dot product. 106 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient. 107 | 108 | Returns: 109 | torch.Tensor: The sample variance of individual gradient projections on the 110 | mini-batch gradient. 111 | """ 112 | projections = batch_size * batch_dot.sum(1) 113 | 114 | return (1 / (batch_size - 1)) * ( 115 | (projections ** 2).sum() - batch_size * grad_l2_squared ** 2 116 | ) 117 | -------------------------------------------------------------------------------- /cockpit/quantities/loss.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the loss.""" 2 | 3 | from cockpit.quantities.quantity import ByproductQuantity 4 | 5 | 6 | class Loss(ByproductQuantity): 7 | """Loss Quantity class tracking the mini-batch training loss during training.""" 8 | 9 | def _compute(self, global_step, params, batch_loss): 10 | """Track the loss at the current point. 11 | 12 | Args: 13 | global_step (int): The current iteration number. 14 | params ([torch.Tensor]): List of torch.Tensors holding the network's 15 | parameters. 16 | batch_loss (torch.Tensor): Mini-batch loss from current step. 17 | 18 | Returns: 19 | float: Mini-batch loss at the current iteration. 20 | """ 21 | return batch_loss.item() 22 | -------------------------------------------------------------------------------- /cockpit/quantities/mean_gsnr.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Mean Gradient Signal to Noise Ration (GSNR).""" 2 | 3 | 4 | from backpack.extensions import BatchGrad 5 | 6 | from cockpit.context import get_batch_size 7 | from cockpit.quantities.quantity import SingleStepQuantity 8 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared 9 | 10 | 11 | class MeanGSNR(SingleStepQuantity): 12 | """Quantitiy Class for the mean gradient signal-to-noise ratio (GSNR). 13 | 14 | Note: Mean gradient signal-to-noise ratio as defined by 15 | 16 | - Liu, J., et al. 
17 | Understanding Why Neural Networks Generalize Well Through GSNR of 18 | Parameters (2020). 19 | https://arxiv.org/abs/2001.07384 20 | """ 21 | 22 | def __init__(self, track_schedule, verbose=False, epsilon=1e-5): 23 | """Initialize. 24 | 25 | Args: 26 | track_schedule (callable): Function that maps the ``global_step`` 27 | to a boolean, which determines if the quantity should be computed. 28 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``. 29 | epsilon (float): Stabilization constant. Defaults to 1e-5. 30 | """ 31 | super().__init__(track_schedule, verbose=verbose) 32 | 33 | self._epsilon = epsilon 34 | 35 | def extensions(self, global_step): 36 | """Return list of BackPACK extensions required for the computation. 37 | 38 | Args: 39 | global_step (int): The current iteration number. 40 | 41 | Returns: 42 | list: (Potentially empty) list with required BackPACK quantities. 43 | """ 44 | ext = [] 45 | 46 | if self.should_compute(global_step): 47 | ext.append(BatchGrad()) 48 | 49 | return ext 50 | 51 | def extension_hooks(self, global_step): 52 | """Return list of BackPACK extension hooks required for the computation. 53 | 54 | Args: 55 | global_step (int): The current iteration number. 56 | 57 | Returns: 58 | [callable]: List of required BackPACK extension hooks for the current 59 | iteration. 60 | """ 61 | hooks = [] 62 | 63 | if self.should_compute(global_step): 64 | hooks.append(BatchGradTransformsHook_SumGradSquared()) 65 | 66 | return hooks 67 | 68 | def _compute(self, global_step, params, batch_loss): 69 | """Track the mean GSNR. 70 | 71 | Args: 72 | global_step (int): The current iteration number. 73 | params ([torch.Tensor]): List of torch.Tensors holding the network's 74 | parameters. 75 | batch_loss (torch.Tensor): Mini-batch loss from current step. 76 | 77 | Returns: 78 | float: Mean GSNR of the current iteration. 79 | """ 80 | return self._compute_gsnr(global_step, params, batch_loss).mean().item() 81 | 82 | def _compute_gsnr(self, global_step, params, batch_loss): 83 | """Compute gradient signal-to-noise ratio. 84 | 85 | Args: 86 | global_step (int): The current iteration number. 87 | params ([torch.Tensor]): List of parameters considered in the computation. 88 | batch_loss (torch.Tensor): Mini-batch loss from current step. 89 | 90 | Returns: 91 | torch.Tensor: GSNR of each parameter at the current iteration. 92 | """ 93 | grad_squared = self._fetch_grad(params, aggregate=True) ** 2 94 | sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms( 95 | params, aggregate=True 96 | ) 97 | 98 | batch_size = get_batch_size(global_step) 99 | 100 | return grad_squared / ( 101 | batch_size * sum_grad_squared - grad_squared + self._epsilon 102 | ) 103 | -------------------------------------------------------------------------------- /cockpit/quantities/norm_test.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Norm Test.""" 2 | 3 | from backpack.extensions import BatchGrad 4 | 5 | from cockpit.quantities.quantity import SingleStepQuantity 6 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchL2Grad 7 | 8 | 9 | class NormTest(SingleStepQuantity): 10 | """Quantity class for the norm test. 11 | 12 | Note: Norm test as proposed in 13 | 14 | - Byrd, R., Chin, G., Nocedal, J., & Wu, Y., 15 | Sample size selection in optimization methods for machine learning (2012).
16 | https://link.springer.com/article/10.1007%2Fs10107-012-0572-5 17 | """ 18 | 19 | def extensions(self, global_step): 20 | """Return list of BackPACK extensions required for the computation. 21 | 22 | Args: 23 | global_step (int): The current iteration number. 24 | 25 | Returns: 26 | list: (Potentially empty) list with required BackPACK quantities. 27 | """ 28 | ext = [] 29 | 30 | if self.should_compute(global_step): 31 | ext.append(BatchGrad()) 32 | 33 | return ext 34 | 35 | def extension_hooks(self, global_step): 36 | """Return list of BackPACK extension hooks required for the computation. 37 | 38 | Args: 39 | global_step (int): The current iteration number. 40 | 41 | Returns: 42 | [callable]: List of required BackPACK extension hooks for the current 43 | iteration. 44 | """ 45 | hooks = [] 46 | 47 | if self.should_compute(global_step): 48 | hooks.append(BatchGradTransformsHook_BatchL2Grad()) 49 | 50 | return hooks 51 | 52 | def _compute(self, global_step, params, batch_loss): 53 | """Track the practical version of the norm test. 54 | 55 | Return maximum θ for which the norm test would pass. 56 | 57 | The norm test is defined by Equation (3.9) in byrd2012sample. 58 | 59 | Args: 60 | global_step (int): The current iteration number. 61 | params ([torch.Tensor]): List of torch.Tensors holding the network's 62 | parameters. 63 | batch_loss (torch.Tensor): Mini-batch loss from current step. 64 | 65 | Returns: 66 | float: Maximum θ for which the norm test would pass. 67 | """ 68 | batch_l2_squared = self._fetch_batch_l2_squared_via_batch_grad_transforms( 69 | params, aggregate=True 70 | ) 71 | grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True) 72 | batch_size = batch_l2_squared.size(0) 73 | 74 | var_l1 = self._compute_variance_l1( 75 | batch_size, batch_l2_squared, grad_l2_squared 76 | ) 77 | 78 | return self._compute_theta_max(batch_size, var_l1, grad_l2_squared).item() 79 | 80 | def _compute_theta_max(self, batch_size, var_l1, grad_l2_squared): 81 | """Return maximum θ for which the norm test would pass. 82 | 83 | Args: 84 | batch_size (int): Mini-batch size. 85 | var_l1 (torch.Tensor): The sample variance ℓ₁ norm. 86 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient. 87 | 88 | Returns: 89 | torch.Tensor: Maximum θ for which the norm test would pass. 90 | """ 91 | return (var_l1 / batch_size / grad_l2_squared).sqrt() 92 | 93 | def _compute_variance_l1(self, batch_size, batch_l2_squared, grad_l2_squared): 94 | """Compute the sample variance ℓ₁ norm. 95 | 96 | It shows up in Equations (3.9) and (3.11) in byrd2012sample and relies 97 | on the sample variance (Equation 3.6). The ℓ₁ norm can be computed using 98 | individual gradient squared ℓ₂ norms and the mini-batch gradient squared 99 | ℓ₂ norm. 100 | 101 | Args: 102 | batch_size (int): Mini-batch size. 103 | batch_l2_squared (torch.Tensor): Individual gradient squared ℓ₂ norms. 104 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient. 105 | 106 | Returns: 107 | torch.Tensor: The sample variance ℓ₁ norm.
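
            As a sketch of the identity implemented below (gᵢ: individual
            gradients, g: mini-batch gradient, B: batch size):
            ‖Var‖₁ = (1 / (B − 1)) · (B² · Σᵢ ‖gᵢ‖² − B · ‖g‖²).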
108 | """ 109 | return (1 / (batch_size - 1)) * ( 110 | batch_size ** 2 * batch_l2_squared.sum() - batch_size * grad_l2_squared 111 | ) 112 | -------------------------------------------------------------------------------- /cockpit/quantities/ortho_test.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Orthogonality Test.""" 2 | 3 | from backpack.extensions import BatchGrad 4 | 5 | from cockpit.quantities.quantity import SingleStepQuantity 6 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchDotGrad 7 | 8 | 9 | class OrthoTest(SingleStepQuantity): 10 | """Quantity Class for the orthogonality test. 11 | 12 | Note: Orthogonality test as proposed in 13 | 14 | - Bollapragada, R., Byrd, R., & Nocedal, J., 15 | Adaptive Sampling Strategies for Stochastic Optimization (2017). 16 | https://arxiv.org/abs/1710.11258 17 | """ 18 | 19 | def extensions(self, global_step): 20 | """Return list of BackPACK extensions required for the computation. 21 | 22 | Args: 23 | global_step (int): The current iteration number. 24 | 25 | Returns: 26 | list: (Potentially empty) list with required BackPACK quantities. 27 | """ 28 | ext = [] 29 | 30 | if self.should_compute(global_step): 31 | ext.append(BatchGrad()) 32 | 33 | return ext 34 | 35 | def extension_hooks(self, global_step): 36 | """Return list of BackPACK extension hooks required for the computation. 37 | 38 | Args: 39 | global_step (int): The current iteration number. 40 | 41 | Returns: 42 | [callable]: List of required BackPACK extension hooks for the current 43 | iteration. 44 | """ 45 | hooks = [] 46 | 47 | if self.should_compute(global_step): 48 | hooks.append(BatchGradTransformsHook_BatchDotGrad()) 49 | 50 | return hooks 51 | 52 | def _compute(self, global_step, params, batch_loss): 53 | """Track the practical version of the orthogonality test. 54 | 55 | Return maximum ν for which the orthogonality test would pass. 56 | 57 | The orthogonality test is defined by Equation (3.3) in bollapragada2017adaptive. 58 | 59 | Args: 60 | global_step (int): The current iteration number. 61 | params ([torch.Tensor]): List of torch.Tensors holding the network's 62 | parameters. 63 | batch_loss (torch.Tensor): Mini-batch loss from current step. 64 | 65 | Returns: 66 | float: Maximum ν for which the orthogonality test would pass. 67 | """ 68 | batch_dot = self._fetch_batch_dot_via_batch_grad_transforms( 69 | params, aggregate=True 70 | ) 71 | batch_size = batch_dot.size(0) 72 | grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True) 73 | 74 | var_orthogonal_projection = self._compute_orthogonal_projection_variance( 75 | batch_size, batch_dot, grad_l2_squared 76 | ) 77 | 78 | return self._compute_nu_max( 79 | batch_size, var_orthogonal_projection, grad_l2_squared 80 | ).item() 81 | 82 | def _compute_nu_max(self, batch_size, var_orthogonal_projection, grad_l2_squared): 83 | """Return maximum ν for which the orthogonality test would pass. 84 | 85 | The orthogonality test is defined by Equation (3.3) in 86 | bollapragada2017adaptive. 87 | 88 | Args: 89 | batch_size (int): Mini-batch size. 90 | var_orthogonal_projection (torch.Tensor): [description] 91 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient. 92 | 93 | Returns: 94 | [type]: Maximum ν for which the orthogonality test would pass. 
95 | """ 96 | return (var_orthogonal_projection / batch_size / grad_l2_squared).sqrt() 97 | 98 | def _compute_orthogonal_projection_variance( 99 | self, batch_size, batch_dot, grad_l2_squared 100 | ): 101 | """Compute sample variance of individual gradient orthogonal projections. 102 | 103 | The sample variance of orthogonal projections shows up in Equation (3.3) in 104 | bollapragada2017adaptive (https://arxiv.org/pdf/1710.11258.pdf) 105 | 106 | Args: 107 | batch_size (int): Mini-batch size. 108 | batch_dot (torch.Tensor): Individual gradient pairwise dot product. 109 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient. 110 | 111 | Returns: 112 | torch.Tensor: The sample variance of individual gradient orthogonal 113 | projections on the mini-batch gradient. 114 | """ 115 | batch_l2_squared = batch_dot.diag() 116 | projections = batch_size * batch_dot.sum(1) 117 | 118 | return (1 / (batch_size - 1)) * ( 119 | batch_size ** 2 * batch_l2_squared.sum() 120 | - (projections ** 2 / grad_l2_squared).sum() 121 | ) 122 | -------------------------------------------------------------------------------- /cockpit/quantities/parameters.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the Individual Parameters.""" 2 | 3 | from cockpit.quantities.quantity import ByproductQuantity 4 | 5 | 6 | class Parameters(ByproductQuantity): 7 | """Parameter Quantitiy class tracking the current parameters in each iteration.""" 8 | 9 | def _compute(self, global_step, params, batch_loss): 10 | """Store the current parameter. 11 | 12 | Args: 13 | global_step (int): The current iteration number. 14 | params ([torch.Tensor]): List of torch.Tensors holding the network's 15 | parameters. 16 | batch_loss (torch.Tensor): Mini-batch loss from current step. 17 | 18 | Returns: 19 | list: Current model parameters. 20 | """ 21 | return [p.data.tolist() for p in params] 22 | -------------------------------------------------------------------------------- /cockpit/quantities/time.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the time.""" 2 | 3 | import time 4 | 5 | from cockpit.quantities.quantity import ByproductQuantity 6 | 7 | 8 | class Time(ByproductQuantity): 9 | """Time Quantity Class tracking the time during training.""" 10 | 11 | def _compute(self, global_step, params, batch_loss): 12 | """Return the time at the current point. 13 | 14 | Args: 15 | global_step (int): The current iteration number. 16 | params ([torch.Tensor]): List of torch.Tensors holding the network's 17 | parameters. 18 | batch_loss (torch.Tensor): Mini-batch loss from current step. 19 | 20 | Returns: 21 | float: Current time as given by ``time.time()``. 22 | """ 23 | return time.time() 24 | -------------------------------------------------------------------------------- /cockpit/quantities/update_size.py: -------------------------------------------------------------------------------- 1 | """Class for tracking the update size.""" 2 | 3 | from cockpit.quantities.quantity import TwoStepQuantity 4 | 5 | 6 | class UpdateSize(TwoStepQuantity): 7 | """Quantity class for tracking parameter update sizes.""" 8 | 9 | CACHE_KEY = "params" 10 | """str: String under which the parameters are cached for computation. 11 | Default: ``'params'``. 12 | """ 13 | SAVE_SHIFT = 1 14 | """int: Difference between iteration at which information is computed versus 15 | iteration under which it is stored. 
For instance, if set to ``1``, the 16 | information computed at iteration ``n + 1`` is saved under iteration ``n``. 17 | Defaults to ``1``. 18 | """ 19 | 20 | def extensions(self, global_step): 21 | """Return list of BackPACK extensions required for the computation. 22 | 23 | Args: 24 | global_step (int): The current iteration number. 25 | 26 | Returns: 27 | list: (Potentially empty) list with required BackPACK quantities. 28 | """ 29 | return [] 30 | 31 | def is_start(self, global_step): 32 | """Return whether current iteration is start point. 33 | 34 | Args: 35 | global_step (int): The current iteration number. 36 | 37 | Returns: 38 | bool: Whether ``global_step`` is a start point. 39 | """ 40 | return self._track_schedule(global_step) 41 | 42 | def is_end(self, global_step): 43 | """Return whether current iteration is end point. 44 | 45 | Args: 46 | global_step (int): The current iteration number. 47 | 48 | Returns: 49 | bool: Whether ``global_step`` is an end point. 50 | """ 51 | return self._track_schedule(global_step - self.SAVE_SHIFT) 52 | 53 | def _compute_start(self, global_step, params, batch_loss): 54 | """Perform computations at start point (store current parameter values). 55 | 56 | Modifies ``self._cache``. 57 | 58 | Args: 59 | global_step (int): The current iteration number. 60 | params ([torch.Tensor]): List of torch.Tensors holding the network's 61 | parameters. 62 | batch_loss (torch.Tensor): Mini-batch loss from current step. 63 | """ 64 | params_copy = [p.data.clone().detach() for p in params] 65 | 66 | def block_fn(step): 67 | """Block deletion of parameters for current and next iteration. 68 | 69 | Args: 70 | step (int): Iteration number. 71 | 72 | Returns: 73 | bool: Whether deletion is blocked in the specified iteration 74 | """ 75 | return 0 <= step - global_step <= self.SAVE_SHIFT 76 | 77 | self.save_to_cache(global_step, self.CACHE_KEY, params_copy, block_fn) 78 | 79 | def _compute_end(self, global_step, params, batch_loss): 80 | """Compute and return update size. 81 | 82 | Args: 83 | global_step (int): The current iteration number. 84 | params ([torch.Tensor]): List of torch.Tensors holding the network's 85 | parameters. 86 | batch_loss (torch.Tensor): Mini-batch loss from current step. 87 | 88 | Returns: 89 | [float]: Layer-wise L2-norms of parameter updates. 90 | """ 91 | params_start = self.load_from_cache( 92 | global_step - self.SAVE_SHIFT, self.CACHE_KEY 93 | ) 94 | 95 | update_size = [ 96 | (p.data - p_start).norm(2).item() 97 | for p, p_start in zip(params, params_start) 98 | ] 99 | 100 | return update_size 101 | -------------------------------------------------------------------------------- /cockpit/quantities/utils_quantities.py: -------------------------------------------------------------------------------- 1 | """Utility Functions for the Quantities and the Tracking in General.""" 2 | 3 | import torch 4 | 5 | 6 | def _layerwise_dot_product(x_s, y_s): 7 | """Computes the dot product of two parameter vectors layerwise. 8 | 9 | Args: 10 | x_s (list): First list of parameter vectors. 11 | y_s (list): Second list of parameter vectors. 12 | 13 | Returns: 14 | [float]: List of scalars; each scalar is the dot product of one layer. 15 | """ 16 | return [torch.sum(x * y).item() for x, y in zip(x_s, y_s)] 17 | 18 | 19 | def _root_sum_of_squares(sequence): 20 | """Returns the root of the sum of squares of a given sequence.
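
    For example (a sketch): ``_root_sum_of_squares([3, 4])`` returns ``5.0``.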
21 | 22 | Args: 23 | sequence (list): A list of floats 24 | 25 | Returns: 26 | float: Root sum of squares 27 | """ 28 | return sum(el ** 2 for el in sequence) ** 0.5 29 | 30 | 31 | def abs_max(tensor): 32 | """Return maximum absolute entry in ``tensor``. 33 | 34 | Args: 35 | tensor (torch.Tensor): Input tensor. 36 | 37 | Returns: 38 | torch.Tensor: Maximum absolute entry in ``tensor``. 39 | """ 40 | min_val, max_val = tensor.min(), tensor.max() 41 | return max(min_val.abs(), max_val.abs()) 42 | -------------------------------------------------------------------------------- /cockpit/quantities/utils_transforms.py: -------------------------------------------------------------------------------- 1 | """Utility functions for Transforms.""" 2 | 3 | import string 4 | import weakref 5 | 6 | from torch import Tensor, einsum 7 | 8 | from cockpit.quantities.hooks.base import ParameterExtensionHook 9 | 10 | 11 | def BatchGradTransformsHook_BatchL2Grad(): 12 | """Compute individual gradient ℓ₂ norms via individual gradients.""" 13 | return BatchGradTransformsHook({"batch_l2": batch_l2_transform}) 14 | 15 | 16 | def batch_l2_transform(batch_grad): 17 | """Transform individual gradients into individual ℓ₂ norms.""" 18 | sum_axes = list(range(batch_grad.dim()))[1:] 19 | return (batch_grad ** 2).sum(sum_axes) 20 | 21 | 22 | def BatchGradTransformsHook_BatchDotGrad(): 23 | """Compute pairwise individual gradient dot products via individual gradients.""" 24 | return BatchGradTransformsHook({"batch_dot": batch_dot_transform}) 25 | 26 | 27 | def batch_dot_transform(batch_grad): 28 | """Transform individual gradients into pairwise dot products.""" 29 | # make einsum string 30 | letters = get_first_n_alphabet(batch_grad.dim() + 1) 31 | n1, n2, sum_out = letters[0], letters[1], "".join(letters[2:]) 32 | 33 | einsum_equation = f"{n1}{sum_out},{n2}{sum_out}->{n1}{n2}" 34 | 35 | return einsum(einsum_equation, batch_grad, batch_grad) 36 | 37 | 38 | def get_first_n_alphabet(n): 39 | """Return the first ``n`` lowercase letters of the alphabet.""" 40 | return string.ascii_lowercase[:n] 41 | 42 | 43 | def BatchGradTransformsHook_SumGradSquared(): 44 | """Compute sum of squared individual gradients via individual gradients.""" 45 | return BatchGradTransformsHook({"sum_grad_squared": sum_grad_squared_transform}) 46 | 47 | 48 | def sum_grad_squared_transform(batch_grad): 49 | """Transform individual gradients into second non-centered moment.""" 50 | return (batch_grad ** 2).sum(0) 51 | 52 | 53 | class BatchGradTransformsHook(ParameterExtensionHook): 54 | """Hook implementation of ``BatchGradTransforms``.""" 55 | 56 | def __init__(self, transforms, savefield=None): 57 | """Store transformations and potential savefield. 58 | 59 | Args: 60 | transforms (dict): Values are functions that are evaluated on a parameter's 61 | ``grad_batch`` attribute. The result is stored in a dictionary stored 62 | under ``grad_batch_transforms``. 63 | savefield (str, optional): Attribute name under which the hook's result 64 | is saved in a parameter. If ``None``, it is assumed that the hook acts 65 | via side effects and no output needs to be stored. 66 | """ 67 | super().__init__(savefield=savefield) 68 | self._transforms = transforms 69 | 70 | def param_hook(self, param: Tensor): 71 | """Execute all transformations and store results as dictionary in the parameter. 72 | 73 | Args: 74 | param: Trainable parameter which hosts BackPACK quantities.
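
        Example:
            A sketch of the effect, using the ``batch_l2`` transform defined
            above::

                hook = BatchGradTransformsHook({"batch_l2": batch_l2_transform})
                # after the hook ran for ``param`` during backpropagation:
                # ``param.grad_batch_transforms["batch_l2"]`` holds the squared
                # ℓ₂ norms of the individual gradients, shape ``[N]``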
75 | """ 76 | param.grad_batch._param_weakref = weakref.ref(param) 77 | # TODO Delete after backward pass with Cockpit 78 | param.grad_batch_transforms = { 79 | key: func(param.grad_batch) for key, func in self._transforms.items() 80 | } 81 | -------------------------------------------------------------------------------- /cockpit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions to configure cockpit.""" 2 | -------------------------------------------------------------------------------- /cockpit/utils/configuration.py: -------------------------------------------------------------------------------- 1 | """Configuration utilities for cockpit.""" 2 | 3 | from cockpit import quantities 4 | from cockpit.utils import schedules 5 | 6 | 7 | def configuration(label, track_schedule=None, verbose=False): 8 | """Use pre-defined collections of quantities that should be used for tracking. 9 | 10 | Currently supports three different configurations: 11 | 12 | - ``"economy"``: Combines the :class:`~cockpit.quantities.Alpha`, 13 | :class:`~cockpit.quantities.Distance`, :class:`~cockpit.quantities.GradHist1d`, 14 | :class:`~cockpit.quantities.GradNorm`, :class:`~cockpit.quantities.InnerTest`, 15 | :class:`~cockpit.quantities.Loss`, :class:`~cockpit.quantities.NormTest`, 16 | :class:`~cockpit.quantities.OrthoTest` and :class:`~cockpit.quantities.UpdateSize` 17 | quantities. 18 | - ``"business"``: Same as ``"economy"`` but additionally with 19 | :class:`~cockpit.quantities.TICDiag` and :class:`~cockpit.quantities.HessTrace`. 20 | - ``"full"``: Same as ``"business"`` but additionally with 21 | :class:`~cockpit.quantities.HessMaxEV` and 22 | :class:`~cockpit.quantities.GradHist2d`. 23 | 24 | Args: 25 | label (str): String specifying the configuration type. Possible configurations 26 | are (least to most expensive) ``'economy'``, ``'business'``, ``'full'``. 27 | track_schedule (callable, optional): Function that maps the ``global_step`` 28 | to a boolean, which determines if the quantity should be computed. 29 | Defaults to ``None``. 30 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``. 31 | 32 | Returns: 33 | list: Instantiated quantities for a cockpit configuration. 34 | """ 35 | if track_schedule is None: 36 | track_schedule = schedules.linear(interval=1, offset=0) 37 | 38 | quants = [] 39 | for q_cls in quantities_cls_for_configuration(label): 40 | quants.append(q_cls(track_schedule=track_schedule, verbose=verbose)) 41 | 42 | return quants 43 | 44 | 45 | def quantities_cls_for_configuration(label): 46 | """Return the quantity classes for a cockpit configuration. 47 | 48 | Currently supports three different configurations: 49 | 50 | - ``"economy"``: Combines the :class:`~cockpit.quantities.Alpha`, 51 | :class:`~cockpit.quantities.Distance`, :class:`~cockpit.quantities.GradHist1d`, 52 | :class:`~cockpit.quantities.GradNorm`, :class:`~cockpit.quantities.InnerTest`, 53 | :class:`~cockpit.quantities.Loss`, :class:`~cockpit.quantities.NormTest`, 54 | :class:`~cockpit.quantities.OrthoTest` and :class:`~cockpit.quantities.UpdateSize` 55 | quantities. 56 | - ``"business"``: Same as ``"economy"`` but additionally with 57 | :class:`~cockpit.quantities.TICDiag` and :class:`~cockpit.quantities.HessTrace`. 58 | - ``"full"``: Same as ``"business"`` but additionally with 59 | :class:`~cockpit.quantities.HessMaxEV` and 60 | :class:`~cockpit.quantities.GradHist2d`. 
61 | 62 | Args: 63 | label (str): String specifying the configuration type. Possible configurations 64 | are (least to most expensive) ``'economy'``, ``'business'``, ``'full'``. 65 | 66 | Returns: 67 | [Quantity]: A list of quantity classes used in the 68 | specified configuration. 69 | """ 70 | economy = [ 71 | quantities.Alpha, 72 | quantities.Distance, 73 | quantities.GradHist1d, 74 | quantities.GradNorm, 75 | quantities.InnerTest, 76 | quantities.Loss, 77 | quantities.NormTest, 78 | quantities.OrthoTest, 79 | quantities.UpdateSize, 80 | ] 81 | business = economy + [ 82 | quantities.TICDiag, 83 | quantities.HessTrace, 84 | ] 85 | full = business + [ 86 | quantities.HessMaxEV, 87 | quantities.GradHist2d, 88 | ] 89 | 90 | configs = { 91 | "full": full, 92 | "business": business, 93 | "economy": economy, 94 | } 95 | 96 | return configs[label] 97 | -------------------------------------------------------------------------------- /cockpit/utils/optim.py: -------------------------------------------------------------------------------- 1 | """Utility functions to investigate optimizers. 2 | 3 | Some quantities either require computation of the optimizer update step, or are only 4 | defined for certain optimizers. 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class ComputeStep: 12 | """Update step computation from BackPACK quantities for different optimizers. 13 | 14 | Note: 15 | The ``.grad`` attribute cannot be used to compute update steps as this code 16 | is invoked as a hook during backpropagation, at a time when the ``.grad`` field 17 | has not yet been updated with the latest gradients. 18 | """ 19 | 20 | @staticmethod 21 | def compute_update_step(optimizer, parameter_ids): 22 | """Compute an optimizer's update step. 23 | 24 | Args: 25 | optimizer (torch.optim.Optimizer): A PyTorch optimizer. 26 | parameter_ids ([id]): List of parameter ids for which the updates are 27 | computed. 28 | 29 | Returns: 30 | dict: Mapping between parameters and their updates. Keys are parameter ids 31 | and items are ``torch.Tensor``s representing the update. 32 | 33 | Raises: 34 | NotImplementedError: If the optimizer's update step is not implemented. 35 | """ 36 | if ComputeStep.is_sgd_default_kwargs(optimizer): 37 | return ComputeStep.update_sgd_default_kwargs(optimizer, parameter_ids) 38 | 39 | raise NotImplementedError 40 | 41 | @staticmethod 42 | def is_sgd_default_kwargs(optimizer): 43 | """Return whether the input is momentum-free SGD with default values. 44 | 45 | Args: 46 | optimizer (torch.optim.Optimizer): A PyTorch optimizer. 47 | 48 | Returns: 49 | bool: Whether the input is momentum-free SGD with default values. 50 | """ 51 | if not isinstance(optimizer, torch.optim.SGD): 52 | return False 53 | 54 | for group in optimizer.param_groups: 55 | if not np.isclose(group["weight_decay"], 0.0): 56 | return False 57 | 58 | if not np.isclose(group["momentum"], 0.0): 59 | return False 60 | 61 | if not np.isclose(group["dampening"], 0.0): 62 | return False 63 | 64 | if group["nesterov"] is not False: 65 | return False 66 | 67 | return True 68 | 69 | @staticmethod 70 | def update_sgd_default_kwargs(optimizer, parameter_ids): 71 | """Return the update of momentum-free SGD with default values. 72 | 73 | Args: 74 | optimizer (torch.optim.SGD): Zero-momentum default SGD. 75 | parameter_ids ([id]): List of parameter ids for which the updates are 76 | computed. 77 | 78 | Returns: 79 | dict: Mapping between parameters and their updates.
Keys are parameter ids 80 | and items are ``torch.Tensor``s representing the update. 81 | """ 82 | updates = {} 83 | 84 | for group in optimizer.param_groups: 85 | for p in group["params"]: 86 | if id(p) in parameter_ids: 87 | lr = group["lr"] 88 | updates[id(p)] = -lr * p.grad_batch.sum(0).detach() 89 | 90 | assert len(updates.keys()) == len( 91 | parameter_ids 92 | ), "Could not compute step for all specified parameters" 93 | 94 | return updates 95 | -------------------------------------------------------------------------------- /cockpit/utils/schedules.py: -------------------------------------------------------------------------------- 1 | """Convenient schedule functions.""" 2 | 3 | import torch 4 | 5 | 6 | def linear(interval, offset=0): 7 | """Creates a linear schedule that tracks when ``{offset + n * interval | n >= 0}``. 8 | 9 | Args: 10 | interval (int): The regular tracking interval. 11 | offset (int, optional): Offset of tracking. Defaults to 0. 12 | 13 | Returns: 14 | callable: Function that given the global_step returns whether it should track. 15 | """ 16 | docstring = "Track at iterations {" + f"{offset} + n * {interval} " + "| n >= 0}." 17 | 18 | def schedule(global_step): 19 | shifted = global_step - offset 20 | if shifted < 0: 21 | return False 22 | else: 23 | return shifted % interval == 0 24 | 25 | schedule.__doc__ = docstring 26 | 27 | return schedule 28 | 29 | 30 | def logarithmic(start, end, steps=300, base=10, init=True): 31 | """Creates a logarithmic tracking schedule. 32 | 33 | Args: 34 | start (float): The start exponent; tracking starts at ``base ** start``. 35 | end (float): The end exponent; the log-spaced steps end at ``base ** end``. 36 | steps (int, optional): Number of log spaced points. Defaults to 300. 37 | base (int, optional): Logarithmic base. Defaults to 10. 38 | init (bool, optional): Whether 0 should be included. Defaults to True. 39 | 40 | Returns: 41 | callable: Function that given the global_step returns whether it should track.
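
        Example (a sketch; the tracked steps follow ``torch.logspace``)::

            schedule = logarithmic(0, 2, steps=3)  # tracks {0, 1, 10, 100}
            schedule(10)  # -> True
            schedule(11)  # -> False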
42 | """ 43 | # TODO Compute match and avoid array lookup 44 | scheduled_steps = torch.logspace(start, end, steps, base=base, dtype=int) 45 | 46 | if init: 47 | zero = torch.tensor([0], dtype=int) 48 | scheduled_steps = torch.cat((scheduled_steps, zero)).int() 49 | 50 | def schedule(global_step): 51 | return global_step in scheduled_steps 52 | 53 | return schedule 54 | -------------------------------------------------------------------------------- /docs/requirements_doc.txt: -------------------------------------------------------------------------------- 1 | # Sphinx builds the documentation 2 | sphinx 3 | 4 | # Sphinx extensions 5 | sphinx-rtd-theme 6 | sphinx-automodapi 7 | #sphinx-copybutton 8 | sphinx-notfound-page 9 | 10 | # Markdown conversion 11 | m2r2 -------------------------------------------------------------------------------- /docs/source/_static/01_basic_fmnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/01_basic_fmnist.png -------------------------------------------------------------------------------- /docs/source/_static/02_advanced_fmnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/02_advanced_fmnist.png -------------------------------------------------------------------------------- /docs/source/_static/03_deepobs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/03_deepobs.png -------------------------------------------------------------------------------- /docs/source/_static/LogoSquare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/LogoSquare.png -------------------------------------------------------------------------------- /docs/source/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/Alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Alpha.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/CABS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/CABS.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/Distances.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Distances.png 
-------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/EarlyStopping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/EarlyStopping.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/GradientNorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/GradientNorm.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/GradientTests.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/GradientTests.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/HessMaxEV.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/HessMaxEV.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/HessTrace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/HessTrace.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/Hist1d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Hist1d.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/Hist2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Hist2d.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/Hyperparameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Hyperparameters.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/MeanGSNR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/MeanGSNR.png -------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/Performance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Performance.png
-------------------------------------------------------------------------------- /docs/source/_static/instrument_previews/TIC.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/TIC.png
-------------------------------------------------------------------------------- /docs/source/_static/showcase.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/showcase.gif
-------------------------------------------------------------------------------- /docs/source/_static/stylefile.css: --------------------------------------------------------------------------------
1 | /* Stylefile shamelessly copied (with permission) from ProbNum (Jonathan Wenger et al.) */
2 | 
3 | /* Color variables */
4 | :root {
5 |     --pn-green: #107d79;
6 |     --pn-lightgreen: #e9f4f4;
7 |     --pn-brightgreen: #16b3ad;
8 |     --pn-orange: #ff9933;
9 |     --pn-darkred: #820000;
10 |     --pn-darkgray: #333333;
11 |     --pn-lightgray: #f2f2f2;
12 |     --background-gray: #fcfcfc;
13 | }
14 | 
15 | /* Font */
16 | body {
17 |     font-family: Roboto light, Helvetica, sans-serif;
18 | }
19 | 
20 | .rst-content .toctree-wrapper > p.caption, h1, h2, h3, h4, h5, h6, legend {
21 |     margin-top: 0;
22 |     font-weight: 700;
23 |     font-family: Open sans light, Helvetica, sans-serif;
24 | }
25 | 
26 | /* Content */
27 | .wy-nav-content {
28 |     max-width: 1000px;
29 | }
30 | .wy-nav-content a:visited {
31 |     color: var(--pn-green);
32 | }
33 | 
34 | /* Sidebar */
35 | 
36 | /* Logo area */
37 | .wy-nav-top {
38 |     background: var(--pn-green);
39 | }
40 | .wy-nav-side {
41 |     background: var(--pn-darkgray);
42 | }
43 | /* Home button */
44 | .wy-side-nav-search .wy-dropdown > a, .wy-side-nav-search > a {
45 |     color: var(--pn-darkgray);
46 | }
47 | /* Version number */
48 | .wy-side-nav-search > div.version {
49 |     color: var(--pn-darkgray);
50 | }
51 | /* Sidebar headers */
52 | .wy-menu-vertical header, .wy-menu-vertical p.caption {
53 |     color: var(--pn-brightgreen);
54 | }
55 | /* Sidebar links */
56 | .wy-menu-vertical a:active {
57 |     background-color: var(--pn-green);
58 | }
59 | .wy-menu-vertical a:hover {
60 |     color: var(--pn-brightgreen);
61 | }
62 | 
63 | /* Hyperlinks */
64 | a {
65 |     color: var(--pn-green);
66 |     text-decoration: none;
67 | }
68 | 
69 | a:hover {
70 |     color: var(--pn-green);
71 |     text-decoration: underline;
72 | }
73 | 
74 | /* API Documentation */
75 | .rst-content code, .rst-content tt {
76 |     color: var(--pn-darkgray);
77 | }
78 | .rst-content .viewcode-link {
79 |     color: var(--pn-green);
80 | }
81 | 
82 | /* Code block */
83 | .highlight {
84 |     background: var(--pn-lightgray);
85 | }
86 | 
87 | /* Code inline */
88 | .rst-content code.literal, .rst-content tt.literal {
89 |     color: var(--pn-green);
90 |     padding: 1px 2px;
91 |     background-color: var(--pn-lightgreen);
92 |     border-radius: 4px;
93 | }
94 | 
95 | /* Alert boxes (e.g.
documentation "see also") */
96 | html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list) > dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list) > dt {
97 |     border-left: 3px solid grey;
98 |     background: #fff0;
99 |     color: #555;
100 | }
101 | 
102 | /* Footnotes */
103 | html.writer-html5 .rst-content dl.footnote > dd {
104 |     margin: 0 0 0;
105 | }
106 | 
107 | /* Block quotes */
108 | .rst-content blockquote {
109 |     margin-left: 0px;
110 |     line-height: 24px;
111 |     margin-bottom: 0px;
112 |     border-left: 5px solid #f2f2f2;
113 |     padding-left: 18px;
114 | }
115 | 
116 | /* Sphinx gallery */
117 | div.sphx-glr-thumbcontainer {
118 |     background: var(--pn-lightgray);
119 |     border: solid var(--pn-lightgray) 1px;
120 |     border-radius: 10px;
121 |     margin-bottom: 15px;
122 | }
123 | div.sphx-glr-thumbcontainer:hover {
124 |     box-shadow: 0 0 10px #107d7930;
125 |     border: solid var(--pn-green) 1px;
126 | }
127 | 
128 | /* Plot animations */
129 | .anim-state label {
130 |     margin-right: 8px;
131 |     display: inline;
132 | }
133 | 
134 | /* Author attribution */
135 | .rst-content .section .authorlist ul li {
136 |     float: left;
137 |     list-style: none;
138 |     margin-left: 16px;
139 |     margin-bottom: 0px;
140 | }
141 | 
142 | .avatar {
143 |     vertical-align: middle;
144 |     border-radius: 10%;
145 | }
146 | 
147 | /* ReadTheDocs */
148 | .fa {
149 |     font-family: Open Sans light, Helvetica, sans-serif;
150 | }
151 | 
152 | .rst-versions {
153 |     font-family: Open sans light, Helvetica, sans-serif;
154 | }
155 | 
156 | .rst-versions .rst-current-version {
157 |     color: var(--pn-brightgreen);
158 | }
159 | 
160 | .keep-us-sustainable {
161 |     border: 1px dotted var(--pn-lightgray);
162 | }
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.alpha_gauge.rst: --------------------------------------------------------------------------------
1 | alpha_gauge
2 | ===========
3 | 
4 | .. currentmodule:: cockpit.instruments
5 | 
6 | .. autoclass:: alpha_gauge
7 | 
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.cabs_gauge.rst: --------------------------------------------------------------------------------
1 | cabs_gauge
2 | ==========
3 | 
4 | .. currentmodule:: cockpit.instruments
5 | 
6 | .. autoclass:: cabs_gauge
7 | 
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.distance_gauge.rst: --------------------------------------------------------------------------------
1 | distance_gauge
2 | ==============
3 | 
4 | .. currentmodule:: cockpit.instruments
5 | 
6 | .. autoclass:: distance_gauge
7 | 
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.early_stopping_gauge.rst: --------------------------------------------------------------------------------
1 | early_stopping_gauge
2 | ====================
3 | 
4 | .. currentmodule:: cockpit.instruments
5 | 
6 | .. autoclass:: early_stopping_gauge
7 | 
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.grad_norm_gauge.rst: --------------------------------------------------------------------------------
1 | grad_norm_gauge
2 | ===============
3 | 
4 | .. currentmodule:: cockpit.instruments
5 | 
6 | ..
autoclass:: grad_norm_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.gradient_tests_gauge.rst: -------------------------------------------------------------------------------- 1 | gradient_tests_gauge 2 | ==================== 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: gradient_tests_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.histogram_1d_gauge.rst: -------------------------------------------------------------------------------- 1 | histogram_1d_gauge 2 | ================== 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: histogram_1d_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.histogram_2d_gauge.rst: -------------------------------------------------------------------------------- 1 | histogram_2d_gauge 2 | ================== 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: histogram_2d_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.hyperparameter_gauge.rst: -------------------------------------------------------------------------------- 1 | hyperparameter_gauge 2 | ==================== 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: hyperparameter_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.max_ev_gauge.rst: -------------------------------------------------------------------------------- 1 | max_ev_gauge 2 | ============ 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: max_ev_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.mean_gsnr_gauge.rst: -------------------------------------------------------------------------------- 1 | mean_gsnr_gauge 2 | =============== 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: mean_gsnr_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.performance_gauge.rst: -------------------------------------------------------------------------------- 1 | performance_gauge 2 | ================= 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: performance_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.tic_gauge.rst: -------------------------------------------------------------------------------- 1 | tic_gauge 2 | ========= 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: tic_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.instruments.trace_gauge.rst: -------------------------------------------------------------------------------- 1 | trace_gauge 2 | =========== 3 | 4 | .. currentmodule:: cockpit.instruments 5 | 6 | .. autoclass:: trace_gauge 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.Alpha.rst: -------------------------------------------------------------------------------- 1 | Alpha 2 | ===== 3 | 4 | .. 
currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: Alpha 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.CABS.rst: -------------------------------------------------------------------------------- 1 | CABS 2 | ==== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: CABS 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.Distance.rst: -------------------------------------------------------------------------------- 1 | Distance 2 | ======== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: Distance 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.EarlyStopping.rst: -------------------------------------------------------------------------------- 1 | EarlyStopping 2 | ============= 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: EarlyStopping 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.GradHist1d.rst: -------------------------------------------------------------------------------- 1 | GradHist1d 2 | ========== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: GradHist1d 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.GradHist2d.rst: -------------------------------------------------------------------------------- 1 | GradHist2d 2 | ========== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: GradHist2d 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.GradNorm.rst: -------------------------------------------------------------------------------- 1 | GradNorm 2 | ======== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: GradNorm 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.HessMaxEV.rst: -------------------------------------------------------------------------------- 1 | HessMaxEV 2 | ========= 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: HessMaxEV 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.HessTrace.rst: -------------------------------------------------------------------------------- 1 | HessTrace 2 | ========= 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: HessTrace 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.InnerTest.rst: -------------------------------------------------------------------------------- 1 | InnerTest 2 | ========= 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: InnerTest 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.Loss.rst: -------------------------------------------------------------------------------- 1 | Loss 2 | ==== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. 
autoclass:: Loss 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.MeanGSNR.rst: -------------------------------------------------------------------------------- 1 | MeanGSNR 2 | ======== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: MeanGSNR 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.NormTest.rst: -------------------------------------------------------------------------------- 1 | NormTest 2 | ======== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: NormTest 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.OrthoTest.rst: -------------------------------------------------------------------------------- 1 | OrthoTest 2 | ========= 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: OrthoTest 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.Parameters.rst: -------------------------------------------------------------------------------- 1 | Parameters 2 | ========== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: Parameters 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.TICDiag.rst: -------------------------------------------------------------------------------- 1 | TICDiag 2 | ======= 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: TICDiag 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.TICTrace.rst: -------------------------------------------------------------------------------- 1 | TICTrace 2 | ======== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: TICTrace 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.Time.rst: -------------------------------------------------------------------------------- 1 | Time 2 | ==== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: Time 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.quantities.UpdateSize.rst: -------------------------------------------------------------------------------- 1 | UpdateSize 2 | ========== 3 | 4 | .. currentmodule:: cockpit.quantities 5 | 6 | .. autoclass:: UpdateSize 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.utils.configuration.configuration.rst: -------------------------------------------------------------------------------- 1 | configuration 2 | ============= 3 | 4 | .. currentmodule:: cockpit.utils.configuration 5 | 6 | .. autofunction:: configuration 7 | -------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.utils.configuration.quantities_cls_for_configuration.rst: -------------------------------------------------------------------------------- 1 | quantities_cls_for_configuration 2 | ================================ 3 | 4 | .. currentmodule:: cockpit.utils.configuration 5 | 6 | .. 
autofunction:: quantities_cls_for_configuration
7 | 
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.utils.schedules.linear.rst: --------------------------------------------------------------------------------
1 | linear
2 | ======
3 | 
4 | .. currentmodule:: cockpit.utils.schedules
5 | 
6 | .. autofunction:: linear
7 | 
-------------------------------------------------------------------------------- /docs/source/api/automod/cockpit.utils.schedules.logarithmic.rst: --------------------------------------------------------------------------------
1 | logarithmic
2 | ===========
3 | 
4 | .. currentmodule:: cockpit.utils.schedules
5 | 
6 | .. autofunction:: logarithmic
7 | 
-------------------------------------------------------------------------------- /docs/source/api/cockpit.rst: --------------------------------------------------------------------------------
1 | =======
2 | Cockpit
3 | =======
4 | 
5 | .. autoclass:: cockpit.Cockpit
6 |     :members:
-------------------------------------------------------------------------------- /docs/source/api/instruments.rst: --------------------------------------------------------------------------------
1 | .. _instruments:
2 | 
3 | ===========
4 | Instruments
5 | ===========
6 | 
7 | **Cockpit** offers a large set of so-called *instruments* that take tracked
8 | *quantities* and visualize them.
9 | 
10 | .. automodsumm:: cockpit.instruments
11 | 
12 | .. toctree::
13 |     :glob:
14 |     :hidden:
15 | 
16 |     automod/cockpit.instruments.*
-------------------------------------------------------------------------------- /docs/source/api/plotter.rst: --------------------------------------------------------------------------------
1 | ===============
2 | Cockpit Plotter
3 | ===============
4 | 
5 | .. autoclass:: cockpit.CockpitPlotter
6 |     :members:
-------------------------------------------------------------------------------- /docs/source/api/quantities.rst: --------------------------------------------------------------------------------
1 | .. _quantities:
2 | 
3 | ==========
4 | Quantities
5 | ==========
6 | 
7 | **Cockpit** offers a large set of so-called *quantities* that can be efficiently
8 | tracked during the training process.
9 | 
10 | .. automodsumm:: cockpit.quantities
11 | 
12 | .. toctree::
13 |     :glob:
14 |     :hidden:
15 | 
16 |     automod/cockpit.quantities.*
-------------------------------------------------------------------------------- /docs/source/api/utils.rst: --------------------------------------------------------------------------------
1 | =====
2 | Utils
3 | =====
4 | 
5 | 
6 | Configuration
7 | =============
8 | 
9 | .. automodapi:: cockpit.utils.configuration
10 |     :no-heading:
11 | 
12 | 
13 | Schedules
14 | =============
15 | 
16 | .. automodapi:: cockpit.utils.schedules
17 |     :no-heading:
18 | 
--------------------------------------------------------------------------------
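Before turning to the DeepOBS example below, a minimal sketch of how the two utility modules documented above fit together: ``configuration`` builds the list of quantity instances, and a schedule decides at which steps they track. The variable name ``quants`` is only for illustration:

from cockpit.utils import schedules
from cockpit.utils.configuration import configuration

# Quantities of the pre-defined "full" configuration, tracked every 10th step.
quants = configuration("full", track_schedule=schedules.linear(interval=10))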
/docs/source/examples/03_deepobs.rst: --------------------------------------------------------------------------------
1 | ===============
2 | DeepOBS Example
3 | ===============
4 | 
5 | **Cockpit** integrates easily with
6 | `DeepOBS `_.
7 | This directly gives you access to dozens of deep learning problems
8 | that you can explore with **Cockpit**.
9 | 
10 | .. Note::
11 |     This example requires a `DeepOBS `__
12 |     and a `BackOBS `_ installation.
13 |     You can install them by running
14 | 
15 |     .. code:: bash
16 | 
17 |         pip install 'git+https://github.com/fsschneider/DeepOBS.git@develop#egg=deepobs'
18 | 
19 |     and
20 | 
21 |     .. code:: bash
22 | 
23 |         pip install 'git+https://github.com/f-dangel/backobs.git@master#egg=backobs'
24 | 
25 |     Note that currently only the 1.2.0 beta version of DeepOBS supports PyTorch;
26 |     it is the version installed by the above command.
27 | 
28 | .. note::
29 | 
30 |     In the following example, we will use an additional :download:`utility file
31 |     <../../../examples/_utils_deepobs.py>` which automatically integrates **Cockpit**
32 |     into the DeepOBS training loop.
33 | 
34 | With the two `utility files from our repository
35 | `_ in place, we can run
36 | 
37 | .. code:: bash
38 | 
39 |     python 03_deepobs.py
40 | 
41 | which executes the following :download:`example script <../../../examples/03_deepobs.py>`:
42 | 
43 | .. literalinclude:: ../../../examples/03_deepobs.py
44 |     :language: python
45 |     :linenos:
46 | 
47 | Just like before, we define a list of quantities (here we use the ``"full"``
48 | :mod:`~cockpit.utils.configuration`), which we this time pass to the
49 | ``DeepOBSRunner``. The runner automatically passes it on to the :class:`~cockpit.Cockpit`.
50 | 
51 | With the arguments of the ``runner.run()`` function, we can define whether we want
52 | the :class:`~cockpit.CockpitPlotter` plots to be shown and/or stored.
53 | 
54 | **Cockpit** will show a status screen every few epochs, write to a logfile, and
55 | save its final plot once training has completed.
56 | 
57 | .. code-block:: console
58 | 
59 |     $ python 03_deepobs.py
60 | 
61 |     ********************************
62 |     Evaluating after 0 of 15 epochs...
63 |     TRAIN: loss 7.0372
64 |     VALID: loss 7.07626
65 |     TEST: loss 7.06894
66 |     ********************************
67 |     [cockpit|plot] Showing current Cockpit.
68 |     ********************************
69 |     Evaluating after 1 of 15 epochs...
70 |     TRAIN: loss 7.00634
71 |     VALID: loss 7.01242
72 |     TEST: loss 7.00535
73 |     ********************************
74 |     ********************************
75 |     Evaluating after 2 of 15 epochs...
76 |     TRAIN: loss 6.98335
77 |     VALID: loss 6.94937
78 |     TEST: loss 6.94255
79 |     ********************************
80 |     [cockpit|plot] Showing current Cockpit.
81 | 
82 |     [...]
83 | 
84 | The fifty epochs on the `deep quadratic
85 | `_
86 | problem will result in a **Cockpit** plot similar to this:
87 | 
88 | ..
image:: ../_static/03_deepobs.png 89 | :alt: Preview Cockpit DeepOBS Example 90 | -------------------------------------------------------------------------------- /docs/source/extract_instrument_previews.py: -------------------------------------------------------------------------------- 1 | """Quick file to extract previews of instruments.""" 2 | 3 | 4 | import os 5 | 6 | import matplotlib.pyplot as plt 7 | from matplotlib.transforms import Bbox 8 | 9 | import cockpit 10 | 11 | HERE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | FILE_PATH = "_static" 13 | FILE_NAME = "instrument_preview_run" 14 | SECONDARY_FILE_NAME = FILE_NAME + "_secondary" 15 | SAVE_PATH = "instrument_previews" 16 | 17 | full_path = os.path.join(HERE_DIR, FILE_PATH, FILE_NAME) 18 | cp = cockpit.CockpitPlotter(secondary_screen=False) 19 | cp.plot(full_path, show_plot=False, show_log_iter=True, debug=True) 20 | 21 | # Store preview images 22 | preview_dict = { 23 | "Hyperparameters": [[3, 1], [11.2, 4.8]], 24 | "Performance": [[11.2, 1], [27.8, 4.8]], 25 | "GradientNorm": [[4, 5], [10.5, 7.75]], 26 | "Distances": [[4, 7.75], [10.6, 10.25]], 27 | "Alpha": [[4, 10.15], [10.7, 12.75]], 28 | "GradientTests": [[11.75, 10.15], [19, 12.75]], 29 | "HessMaxEV": [[20, 10.15], [26.5, 12.75]], 30 | "HessTrace": [[20, 7.75], [26.5, 10.25]], 31 | "TIC": [[20, 5.1], [26.5, 7.75]], 32 | "Hist1d": [[11.75, 7.75], [19, 10.25]], 33 | "Hist2d": [[11.75, 5.1], [18.5, 7.75]], 34 | } 35 | 36 | for instrument in preview_dict: 37 | plt.savefig( 38 | os.path.join(HERE_DIR, FILE_PATH, SAVE_PATH, instrument), 39 | bbox_inches=Bbox(preview_dict[instrument]), 40 | ) 41 | 42 | # plt.savefig( 43 | # os.path.join(HERE_DIR, FILE_PATH, SAVE_PATH, "cockpit.png"), 44 | # format="png", 45 | # bbox_inches=Bbox([[11.75, 5.1], [18.5, 7.75]]), 46 | # ) 47 | 48 | # Secondary Screen 49 | plt.close("all") 50 | cp = cockpit.CockpitPlotter(secondary_screen=True) 51 | secondary_full_path = os.path.join(HERE_DIR, FILE_PATH, SECONDARY_FILE_NAME) 52 | cp.plot(secondary_full_path, show_plot=False, show_log_iter=False) 53 | 54 | # Store preview images 55 | preview_dict = { 56 | "MeanGSNR": [[3.8, 10.45], [11, 13.15]], 57 | "CABS": [[11.5, 10.45], [19, 13.15]], 58 | "EarlyStopping": [[19, 10.45], [26.5, 13.15]], 59 | } 60 | 61 | for instrument in preview_dict: 62 | plt.savefig( 63 | os.path.join(HERE_DIR, FILE_PATH, SAVE_PATH, instrument), 64 | bbox_inches=Bbox(preview_dict[instrument]), 65 | ) 66 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Cockpit 3 | ======= 4 | 5 | |CI Status| |Lint Status| |Doc Status| |Coverage| |License| |Code Style| |arXiv| 6 | 7 | ---- 8 | 9 | .. code:: bash 10 | 11 | pip install cockpit-for-pytorch 12 | 13 | ---- 14 | 15 | **Cockpit is a visual and statistical debugger specifically designed for deep 16 | learning.** Training a deep neural network is often a pain! Successfully training 17 | such a network usually requires either years of intuition or expensive parameter 18 | searches involving lots of trial and error. Traditional debuggers provide only 19 | limited help: They can find *syntactical errors* but not *training bugs* such as 20 | ill-chosen learning rates. **Cockpit** offers a closer, more meaningful look 21 | into the training process with multiple well-chosen *instruments*. 22 | 23 | ---- 24 | 25 | .. 
image:: _static/showcase.gif
26 | 
27 | 
28 | To install **Cockpit**, simply run
29 | 
30 | .. code:: bash
31 | 
32 |     pip install cockpit-for-pytorch
33 | 
34 | 
35 | .. toctree::
36 |     :maxdepth: 1
37 |     :caption: Getting Started
38 | 
39 |     examples/01_basic_fmnist
40 |     examples/02_advanced_fmnist
41 |     examples/03_deepobs
42 |     introduction/good_to_know
43 | 
44 | .. toctree::
45 |     :maxdepth: 1
46 |     :caption: API Documentation
47 | 
48 |     api/cockpit
49 |     api/plotter
50 |     api/quantities
51 |     api/instruments
52 |     api/utils
53 | 
54 | .. toctree::
55 |     :maxdepth: 1
56 |     :caption: Other
57 | 
58 |     GitHub Repository 
59 |     other/contributors
60 |     other/license
61 |     other/changelog
62 | 
63 | 
64 | .. |CI Status| image:: https://github.com/f-dangel/cockpit/actions/workflows/CI.yml/badge.svg
65 |     :target: https://github.com/f-dangel/cockpit/actions/workflows/CI.yml
66 |     :alt: CI Status
67 | 
68 | .. |Lint Status| image:: https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml/badge.svg
69 |     :target: https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml
70 |     :alt: Lint Status
71 | 
72 | .. |Doc Status| image:: https://img.shields.io/readthedocs/cockpit/latest.svg?logo=read%20the%20docs&logoColor=white&label=Doc
73 |     :target: https://cockpit.readthedocs.io
74 |     :alt: Doc Status
75 | 
76 | .. |Coverage| image:: https://coveralls.io/repos/github/f-dangel/cockpit/badge.svg?branch=main&t=piyZHm
77 |     :target: https://coveralls.io/github/f-dangel/cockpit?branch=main
78 |     :alt: Coverage
79 | 
80 | .. |License| image:: https://img.shields.io/badge/License-MIT-green.svg
81 |     :target: https://github.com/f-dangel/cockpit/blob/master/LICENSE
82 |     :alt: License
83 | 
84 | .. |Code Style| image:: https://img.shields.io/badge/code%20style-black-000000.svg
85 |     :target: https://github.com/psf/black
86 |     :alt: Code Style
87 | 
88 | .. |arXiv| image:: https://img.shields.io/static/v1?logo=arxiv&logoColor=white&label=Preprint&message=2102.06604&color=B31B1B
89 |     :target: https://arxiv.org/abs/2102.06604
90 |     :alt: arXiv
91 | 
-------------------------------------------------------------------------------- /docs/source/introduction/good_to_know.rst: --------------------------------------------------------------------------------
1 | ============
2 | Good to Know
3 | ============
4 | 
5 | We try to make Cockpit's usage as easy and convenient as possible. Still, there
6 | are limitations. Here are some common pitfalls and recommendations.
7 | 
8 | BackPACK
9 | ########
10 | 
11 | Most of Cockpit's quantities use BackPACK_ as the back-end for efficient
12 | computation. Please pay attention to the following points for smooth
13 | integration:
14 | 
15 | - Don't forget to `extend the model and loss function
16 |   `_
17 |   yourself [1]_ to activate BackPACK_.
18 | 
19 | - Verify that your model architecture is `supported by BackPACK
20 |   `_.
21 | 
22 | - Your loss function must use ``"mean"`` reduction, that is, the loss must have the
23 |   following structure
24 | 
25 |   .. math::
26 | 
27 |     \mathcal{L}(\mathbf{\theta}) = \frac{1}{N} \sum_{n=1}^{N}
28 |     \ell(f(\mathbf{x}_n, \mathbf{\theta}), \mathbf{y}_n)\,.
29 | 
30 |   This avoids an ambiguous scale in individual gradients, which is documented in
31 |   `BackPACK's individual gradient extension
32 |   `_.
33 |   Otherwise, Cockpit quantities will use incorrectly scaled individual gradients
34 |   in their computation (see the sketch after this list).
35 | 
36 | It's also a good idea to read through BackPACK's `Good to know
37 | `_ section.
38 | 
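A minimal sketch of the points above; the tiny linear model is a placeholder, and
the essential parts are ``extend`` and ``reduction="mean"``:

.. code:: python

    import torch
    from backpack import extend

    # Extend both the model and the loss function to activate BackPACK.
    model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10)))
    loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))

    # With "mean" reduction, the loss is the average of the per-sample losses,
    # which gives individual gradients an unambiguous scale.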
Here are some tweaks to reduce run 43 | time and memory consumption: 44 | 45 | - Use schedules to reduce the tracking frequency. You can specify custom 46 | schedules to literally select any iteration to be tracked, or rely on 47 | pre-defined :mod:`~cockpit.utils.schedules`. 48 | 49 | - Exclude :py:class:`GradHist2d ` from your quantities. The 50 | two-dimensional histogram implementation uses :py:func:`torch.scatter_add`, 51 | which can be slow on GPU due to atomic additions. 52 | 53 | - Exclude :py:class:`HessMaxEV ` from your quantities. It 54 | requires multiple Hessian-vector products, that are executed sequentially. 55 | Also, this requires the full computation be kept in memory. 56 | 57 | - Spot :ref:`quantities ` whose constructor contains a ``curvature`` 58 | argument. It defaults to the most accurate, but also most expensive type. You 59 | may want to sacrifice accuracy for memory and run time performance by 60 | selecting a cheaper option. 61 | 62 | 63 | .. [1] Leaving this responsibility to users is a deliberate choice, as Cockpit 64 | does not always need the package. Specific configurations, that are very 65 | limited though, work without BackPACK_ as they rely only on functionality 66 | built into PyTorch_. 67 | 68 | .. _BackPACK: https://backpack.pt/ 69 | .. _PyTorch: https://pytorch.org/ 70 | -------------------------------------------------------------------------------- /docs/source/other/changelog.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../CHANGELOG.md -------------------------------------------------------------------------------- /docs/source/other/contributors.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../../../AUTHORS.md -------------------------------------------------------------------------------- /docs/source/other/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | **Cockpit** is published under the MIT License. 8 | 9 | .. 
literalinclude:: ../../../LICENSE.txt 10 | :language: text -------------------------------------------------------------------------------- /examples/01_basic_fmnist.py: -------------------------------------------------------------------------------- 1 | """A basic example of using Cockpit with PyTorch for Fashion-MNIST.""" 2 | 3 | import torch 4 | from _utils_examples import fmnist_data 5 | from backpack import extend 6 | 7 | from cockpit import Cockpit, CockpitPlotter 8 | from cockpit.utils.configuration import configuration 9 | 10 | # Build Fashion-MNIST classifier 11 | fmnist_data = fmnist_data() 12 | model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))) 13 | loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean")) 14 | individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none") 15 | 16 | # Create SGD Optimizer 17 | opt = torch.optim.SGD(model.parameters(), lr=1e-2) 18 | 19 | # Create Cockpit and a plotter 20 | cockpit = Cockpit(model.parameters(), quantities=configuration("full")) 21 | plotter = CockpitPlotter() 22 | 23 | # Main training loop 24 | max_steps, global_step = 5, 0 25 | for inputs, labels in iter(fmnist_data): 26 | opt.zero_grad() 27 | 28 | # forward pass 29 | outputs = model(inputs) 30 | loss = loss_fn(outputs, labels) 31 | losses = individual_loss_fn(outputs, labels) 32 | 33 | # backward pass 34 | with cockpit( 35 | global_step, 36 | info={ 37 | "batch_size": inputs.shape[0], 38 | "individual_losses": losses, 39 | "loss": loss, 40 | "optimizer": opt, 41 | }, 42 | ): 43 | loss.backward(create_graph=cockpit.create_graph(global_step)) 44 | 45 | # optimizer step 46 | opt.step() 47 | global_step += 1 48 | 49 | print(f"Step: {global_step:5d} | Loss: {loss.item():.4f}") 50 | 51 | plotter.plot(cockpit) 52 | 53 | if global_step >= max_steps: 54 | break 55 | 56 | plotter.plot(cockpit, block=True) 57 | -------------------------------------------------------------------------------- /examples/02_advanced_fmnist.py: -------------------------------------------------------------------------------- 1 | """A slightly advanced example of using Cockpit with PyTorch for Fashion-MNIST.""" 2 | 3 | import torch 4 | from _utils_examples import cnn, fmnist_data, get_logpath 5 | from backpack import extend, extensions 6 | 7 | from cockpit import Cockpit, CockpitPlotter, quantities 8 | from cockpit.utils import schedules 9 | 10 | # Build Fashion-MNIST classifier 11 | fmnist_data = fmnist_data() 12 | model = extend(cnn()) # Use a basic convolutional network 13 | loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean")) 14 | individual_loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="none")) 15 | 16 | # Create SGD Optimizer 17 | opt = torch.optim.SGD(model.parameters(), lr=5e-1) 18 | 19 | # Create Cockpit and a plotter 20 | # Customize the tracked quantities and their tracking schedule 21 | quantities = [ 22 | quantities.GradNorm(schedules.linear(interval=1)), 23 | quantities.Distance(schedules.linear(interval=1)), 24 | quantities.UpdateSize(schedules.linear(interval=1)), 25 | quantities.HessMaxEV(schedules.linear(interval=3)), 26 | quantities.GradHist1d(schedules.linear(interval=10), bins=10), 27 | ] 28 | cockpit = Cockpit(model.parameters(), quantities=quantities) 29 | plotter = CockpitPlotter() 30 | 31 | # Main training loop 32 | max_steps, global_step = 50, 0 33 | for inputs, labels in iter(fmnist_data): 34 | opt.zero_grad() 35 | 36 | # forward pass 37 | outputs = model(inputs) 38 | loss = loss_fn(outputs, labels) 39 | losses = 
individual_loss_fn(outputs, labels) 40 | 41 | # backward pass 42 | with cockpit( 43 | global_step, 44 | extensions.DiagHessian(), # Other BackPACK quantities can be computed as well 45 | info={ 46 | "batch_size": inputs.shape[0], 47 | "individual_losses": losses, 48 | "loss": loss, 49 | "optimizer": opt, 50 | }, 51 | ): 52 | loss.backward(create_graph=cockpit.create_graph(global_step)) 53 | 54 | # optimizer step 55 | opt.step() 56 | global_step += 1 57 | 58 | print(f"Step: {global_step:5d} | Loss: {loss.item():.4f}") 59 | 60 | if global_step % 10 == 0: 61 | plotter.plot( 62 | cockpit, 63 | savedir=get_logpath(), 64 | show_plot=False, 65 | save_plot=True, 66 | savename_append=str(global_step), 67 | ) 68 | 69 | if global_step >= max_steps: 70 | break 71 | 72 | # Write Cockpit to json file. 73 | cockpit.write(get_logpath()) 74 | 75 | # Plot results from file 76 | plotter.plot( 77 | get_logpath(), 78 | savedir=get_logpath(), 79 | show_plot=False, 80 | save_plot=True, 81 | savename_append="_final", 82 | ) 83 | -------------------------------------------------------------------------------- /examples/03_deepobs.py: -------------------------------------------------------------------------------- 1 | """An example of using Cockpit with DeepOBS.""" 2 | 3 | from _utils_deepobs import DeepOBSRunner 4 | from _utils_examples import get_logpath 5 | from torch.optim import SGD 6 | 7 | from cockpit.utils import configuration, schedules 8 | 9 | optimizer = SGD 10 | hyperparams = {"lr": {"type": float, "default": 0.001}} 11 | 12 | track_schedule = schedules.linear(10) 13 | plot_schedule = schedules.linear(20) 14 | quantities = configuration.configuration("full", track_schedule=track_schedule) 15 | 16 | runner = DeepOBSRunner(optimizer, hyperparams, quantities, plot_schedule=plot_schedule) 17 | 18 | 19 | def const_schedule(num_epochs): 20 | """Constant learning rate schedule.""" 21 | return lambda epoch: 1.0 22 | 23 | 24 | runner.run( 25 | testproblem="quadratic_deep", 26 | output_dir=get_logpath(), 27 | l2_reg=0.0, 28 | num_epochs=50, 29 | show_plots=True, 30 | save_plots=False, 31 | save_final_plot=True, 32 | save_animation=False, 33 | lr_schedule=const_schedule, 34 | ) 35 | -------------------------------------------------------------------------------- /examples/_utils_examples.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the basic PyTorch example.""" 2 | 3 | import os 4 | import warnings 5 | 6 | import torch 7 | import torchvision 8 | 9 | HERE = os.path.abspath(__file__) 10 | HEREDIR = os.path.dirname(HERE) 11 | EXAMPLESDIR = os.path.dirname(HEREDIR) 12 | 13 | # Ignore the PyTorch warning that is irrelevant for us 14 | warnings.filterwarnings("ignore", message="Using a non-full backward hook ") 15 | 16 | 17 | def fmnist_data(batch_size=64, shuffle=True): 18 | """Returns a dataloader for Fashion-MNIST. 19 | 20 | Args: 21 | batch_size (int, optional): Batch size. Defaults to 64. 22 | shuffle (bool, optional): Whether the data should be shuffled. Defaults to True. 23 | 24 | Returns: 25 | torch.DataLoader: Dataloader for Fashion-MNIST data. 
26 |     """
27 |     # Additionally set the random seed for reproducibility
28 |     torch.manual_seed(0)
29 | 
30 |     fmnist_dataset = torchvision.datasets.FashionMNIST(
31 |         root=os.path.join(EXAMPLESDIR, "data"),
32 |         train=True,
33 |         transform=torchvision.transforms.Compose(
34 |             [
35 |                 torchvision.transforms.ToTensor(),
36 |                 torchvision.transforms.Normalize((0.1307,), (0.3081,)),
37 |             ]
38 |         ),
39 |         download=True,
40 |     )
41 | 
42 |     return torch.utils.data.dataloader.DataLoader(
43 |         fmnist_dataset,
44 |         batch_size=batch_size,
45 |         shuffle=shuffle,
46 |     )
47 | 
48 | 
49 | def cnn():
50 |     """Basic Conv-Net for (Fashion-)MNIST."""
51 |     return torch.nn.Sequential(
52 |         torch.nn.Conv2d(1, 11, 5),
53 |         torch.nn.ReLU(),
54 |         torch.nn.MaxPool2d(2, 2),
55 |         torch.nn.Conv2d(11, 16, 5),
56 |         torch.nn.ReLU(),
57 |         torch.nn.MaxPool2d(2, 2),
58 |         torch.nn.Flatten(),
59 |         torch.nn.Linear(16 * 4 * 4, 80),
60 |         torch.nn.ReLU(),
61 |         torch.nn.Linear(80, 84),
62 |         torch.nn.ReLU(),
63 |         torch.nn.Linear(84, 10),
64 |     )
65 | 
66 | 
67 | def get_logpath(suffix=""):
68 |     """Create a logpath and return it.
69 | 
70 |     Args:
71 |         suffix (str, optional): Suffix to add to the output. Defaults to "".
72 | 
73 |     Returns:
74 |         str: Path to the logfile (output of Cockpit).
75 |     """
76 |     save_dir = os.path.join(EXAMPLESDIR, "logfiles")
77 |     os.makedirs(save_dir, exist_ok=True)
78 |     log_path = os.path.join(save_dir, f"cockpit_output{suffix}")
79 |     return log_path
80 | 
-------------------------------------------------------------------------------- /makefile: --------------------------------------------------------------------------------
1 | .PHONY: help
2 | .PHONY: black black-check flake8
3 | .PHONY: test
4 | .PHONY: conda-env
5 | .PHONY: black isort format
6 | .PHONY: black-check isort-check format-check code-standard-check
7 | .PHONY: flake8
8 | .PHONY: pydocstyle-check
9 | .PHONY: darglint-check
10 | .PHONY: build-docs
11 | .PHONY: clean-all
12 | 
13 | .DEFAULT: help
14 | help:
15 | 	@echo "test"
16 | 	@echo " Run pytest on the project and report coverage"
17 | 	@echo "black"
18 | 	@echo " Run black on the project"
19 | 	@echo "black-check"
20 | 	@echo " Check if black would change files"
21 | 	@echo "flake8"
22 | 	@echo " Run flake8 on the project"
23 | 	@echo "pydocstyle-check"
24 | 	@echo " Run pydocstyle on the project"
25 | 	@echo "darglint-check"
26 | 	@echo " Run darglint on the project"
27 | 	@echo "code-standard-check"
28 | 	@echo " Run all linters on the project to check quality standards."
29 | 	@echo "conda-env"
30 | 	@echo " Create conda environment 'cockpit' with dev setup"
31 | 	@echo "build-docs"
32 | 	@echo " Build the docs"
33 | 	@echo "clean-all"
34 | 	@echo " Removes all unnecessary files."
35 | 
36 | ### TESTING ###
37 | # Run pytest with the Matplotlib backend agg to not show plots
38 | test:
39 | 	@MPLBACKEND=agg pytest -vx --cov=cockpit .
40 | 
41 | ### LINTING & FORMATTING ###
42 | 
43 | # Uses black.toml config instead of pyproject.toml to avoid pip issues. See
44 | # - https://github.com/psf/black/issues/683
45 | # - https://github.com/pypa/pip/pull/6370
46 | # - https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
47 | black:
48 | 	@black . --config=black.toml
49 | 
50 | black-check:
51 | 	@black . --config=black.toml --check
52 | 
53 | flake8:
54 | 	@flake8 .
55 | 
56 | pydocstyle-check:
57 | 	@pydocstyle --count .
58 | 
59 | darglint-check:
60 | 	@darglint --verbosity 2 .
61 | 
62 | isort:
63 | 	@isort .
64 | 
65 | isort-check:
66 | 	@isort .
--check 67 | 68 | format: 69 | @make black 70 | @make isort 71 | @make black-check 72 | 73 | format-check: black-check isort-check pydocstyle-check darglint-check 74 | 75 | code-standard-check: 76 | @make black 77 | @make isort 78 | @make black-check 79 | @make flake8 80 | @make pydocstyle-check 81 | 82 | ### CONDA ### 83 | conda-env: 84 | @conda env create --file .conda_env.yml 85 | 86 | ### DOCS ### 87 | build-docs: 88 | @find . -type d -name "automod" -exec rm -r {} + 89 | @cd docs && make clean && make html 90 | 91 | ### CLEAN ### 92 | clean-all: 93 | @find . -name '*.pyc' -delete 94 | @find . -name '*.pyo' -delete 95 | @find . -name '*~' -delete 96 | @find . -type d -name "__pycache__" -delete 97 | @rm -fr .pytest_cache/ 98 | @rm -fr .benchmarks/ -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = cockpit-for-pytorch 3 | url = https://github.com/f-dangel/cockpit 4 | description = A Practical Debugging Tool for Training Deep Neural Networks. 5 | long_description = file: README.md, CHANGELOG.md, LICENSE.txt 6 | long-description-content-type = text/markdown 7 | author = Frank Schneider and Felix Dangel 8 | author-email = f.schneider@uni-tuebingen.de 9 | license = MIT 10 | keywords = deep-learning, machine-learning, debugging 11 | platforms = any 12 | classifiers = 13 | Development Status :: 4 - Beta 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | Programming Language :: Python :: 3.7 17 | Programming Language :: Python :: 3.8 18 | Programming Language :: Python :: 3.9 19 | 20 | [options] 21 | # Define which packages are required to run 22 | install_requires = 23 | json-tricks 24 | matplotlib>=3.4.0 25 | numpy 26 | pandas 27 | scipy 28 | seaborn 29 | torch 30 | backpack-for-pytorch>=1.3.0 31 | zip_safe = False 32 | packages = find: 33 | python_requires = >=3.7 34 | setup_requires = 35 | setuptools_scm 36 | 37 | [options.packages.find] 38 | exclude = test* 39 | 40 | [flake8] 41 | # Configure flake8 linting, minor differences to pytorch. 
42 | select = B,C,E,F,P,W,B9
43 | max-line-length = 80
44 | max-complexity = 10
45 | ignore =
46 |     # replaced by B950 (max-line-length + 10%)
47 |     E501, # max-line-length
48 |     # ignored because pytorch uses dict
49 |     C408, # use {} instead of dict()
50 |     # Not Black-compatible
51 |     E203, # whitespace before :
52 |     E231, # missing whitespace after ','
53 |     W291, # trailing whitespace
54 |     W503, # line break before binary operator
55 |     W504, # line break after binary operator
56 | exclude = docs,docs_src,build,.git,src,tex
57 | 
58 | [pydocstyle]
59 | convention = google
60 | # exclude directories, see
61 | # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088
62 | match_dir = ^(?!(docs|docs_src|build|.git|src|exp)).*
63 | match = .*\.py
64 | 
65 | [isort]
66 | profile=black
67 | 
68 | [darglint]
69 | docstring_style = google
70 | # short, long, full
71 | strictness = short
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | """Setup file for cockpit, using setup.cfg for configuration."""
2 | import sys
3 | 
4 | from pkg_resources import VersionConflict, require
5 | from setuptools import setup
6 | 
7 | # Use setup.cfg if possible
8 | try:
9 |     require("setuptools>=38.3")
10 | except VersionConflict:
11 |     print("Error: version of setuptools is too old (<38.3)!")
12 |     sys.exit(1)
13 | 
14 | 
15 | if __name__ == "__main__":
16 |     setup(use_scm_version=True)
17 | 
-------------------------------------------------------------------------------- /tests/__init__.py: --------------------------------------------------------------------------------
1 | """Tests for ``cockpit``."""
2 | 
-------------------------------------------------------------------------------- /tests/settings.py: --------------------------------------------------------------------------------
1 | """Problem settings shared across all submodules."""
2 | 
3 | 
4 | SETTINGS = []
5 | 
-------------------------------------------------------------------------------- /tests/test_bugs/test_issue5.py: --------------------------------------------------------------------------------
1 | """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
2 | 
3 | from backpack import extend
4 | from torch import manual_seed, rand
5 | from torch.nn import Flatten, Linear, MSELoss, Sequential
6 | from torch.optim import Adam
7 | 
8 | from cockpit import Cockpit
9 | from cockpit.quantities import Alpha, GradHist1d
10 | from cockpit.utils.schedules import linear
11 | 
12 | 
13 | def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
14 |     """If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``.
15 | 
16 |     But if an extension that uses ``BatchGradTransformsHook`` is used at the same time,
17 |     it will delete the ``grad_batch`` attribute during the backward pass. Consequently,
18 |     ``Alpha`` cannot access the attribute anymore. This leads to the error.
19 | """ 20 | manual_seed(0) 21 | 22 | N, D_in, D_out = 2, 3, 1 23 | model = extend(Sequential(Flatten(), Linear(D_in, D_out))) 24 | 25 | opt_not_sgd = Adam(model.parameters(), lr=1e-3) 26 | loss_fn = extend(MSELoss(reduction="mean")) 27 | individual_loss_fn = MSELoss(reduction="none") 28 | 29 | on_first = linear(1) 30 | alpha = Alpha(on_first) 31 | uses_BatchGradTransformsHook = GradHist1d(on_first) 32 | 33 | cockpit = Cockpit( 34 | model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook] 35 | ) 36 | 37 | global_step = 0 38 | inputs, labels = rand(N, D_in), rand(N, D_out) 39 | 40 | # forward pass 41 | outputs = model(inputs) 42 | loss = loss_fn(outputs, labels) 43 | losses = individual_loss_fn(outputs, labels) 44 | 45 | # backward pass 46 | with cockpit( 47 | global_step, 48 | info={ 49 | "batch_size": N, 50 | "individual_losses": losses, 51 | "loss": loss, 52 | "optimizer": opt_not_sgd, 53 | }, 54 | ): 55 | loss.backward(create_graph=cockpit.create_graph(global_step)) 56 | -------------------------------------------------------------------------------- /tests/test_bugs/test_issue6.py: -------------------------------------------------------------------------------- 1 | """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/6.""" 2 | 3 | from backpack import extend 4 | from torch import Tensor, manual_seed, rand 5 | from torch.nn import Linear, Module, MSELoss, ReLU 6 | from torch.optim import SGD 7 | 8 | from cockpit import Cockpit 9 | from cockpit.quantities import GradHist1d 10 | from cockpit.utils.schedules import linear 11 | 12 | 13 | def test_extension_hook_executes_on_custom_module(): 14 | """Cockpit's extension hook is only skipped for known containers like Sequential. 15 | 16 | It will thus execute on custom containers and lead to crashes whenever a quantity 17 | that uses extension hooks is used. 
18 | """ 19 | manual_seed(0) 20 | N, D_in, D_out = 2, 3, 1 21 | 22 | # NOTE Inheriting from Sequential passes 23 | class CustomModule(Module): 24 | """Custom container that is not skipped by the extension hook.""" 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.linear = Linear(D_in, D_out) 29 | self.relu = ReLU() 30 | 31 | def forward(self, x: Tensor) -> Tensor: 32 | return self.relu(self.linear(x)) 33 | 34 | uses_extension_hook = GradHist1d(linear(interval=1)) 35 | config = [uses_extension_hook] 36 | 37 | model = extend(CustomModule()) 38 | cockpit = Cockpit(model.parameters(), quantities=config) 39 | 40 | opt = SGD(model.parameters(), lr=0.1) 41 | 42 | loss_fn = extend(MSELoss(reduction="mean")) 43 | individual_loss_fn = MSELoss(reduction="none") 44 | 45 | global_step = 0 46 | inputs, labels = rand(N, D_in), rand(N, D_out) 47 | 48 | # forward pass 49 | outputs = model(inputs) 50 | loss = loss_fn(outputs, labels) 51 | losses = individual_loss_fn(outputs, labels) 52 | 53 | # backward pass 54 | with cockpit( 55 | global_step, 56 | info={ 57 | "batch_size": N, 58 | "individual_losses": losses, 59 | "loss": loss, 60 | "optimizer": opt, 61 | }, 62 | ): 63 | loss.backward(create_graph=cockpit.create_graph(global_step)) 64 | -------------------------------------------------------------------------------- /tests/test_cockpit/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``cockpit.Cockpit``.""" 2 | -------------------------------------------------------------------------------- /tests/test_cockpit/settings.py: -------------------------------------------------------------------------------- 1 | """Settings used by the tests in this submodule.""" 2 | 3 | import torch 4 | 5 | from tests.settings import SETTINGS as GLOBAL_SETTINGS 6 | from tests.utils.data import load_toy_data 7 | from tests.utils.models import load_toy_model 8 | from tests.utils.problem import make_problems_with_ids 9 | 10 | LOCAL_SETTINGS = [ 11 | { 12 | "data_fn": lambda: load_toy_data(batch_size=5), 13 | "model_fn": load_toy_model, 14 | "individual_loss_function_fn": lambda: torch.nn.CrossEntropyLoss( 15 | reduction="none" 16 | ), 17 | "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), 18 | "iterations": 1, 19 | }, 20 | ] 21 | SETTINGS = GLOBAL_SETTINGS + LOCAL_SETTINGS 22 | 23 | PROBLEMS, PROBLEMS_IDS = make_problems_with_ids(SETTINGS) 24 | -------------------------------------------------------------------------------- /tests/test_cockpit/test_automatic_call_track.py: -------------------------------------------------------------------------------- 1 | """Check that ``track`` is called when leaving the ``cockpit`` context.""" 2 | 3 | import pytest 4 | 5 | from cockpit import quantities 6 | from cockpit.utils.schedules import linear 7 | from tests.test_cockpit.settings import PROBLEMS, PROBLEMS_IDS 8 | from tests.utils.harness import SimpleTestHarness 9 | from tests.utils.problem import instantiate 10 | 11 | 12 | # TODO Reconsider purpose of this test 13 | class CustomTestHarness(SimpleTestHarness): 14 | """Custom Test Harness checking that track gets called when leaving the context.""" 15 | 16 | def check_in_context(self): 17 | """Check that track has not been called yet.""" 18 | assert self.problem.iterations == 1, "Test only checks the first step" 19 | global_step = 0 20 | assert global_step not in self.cockpit.quantities[0].output.keys() 21 | 22 | def check_after_context(self): 23 | """Verify that track has been called after 
the context is left.""" 24 | assert self.problem.iterations == 1, "Test only checks the first step" 25 | global_step = 0 26 | assert global_step in self.cockpit.quantities[0].output.keys() 27 | 28 | 29 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 30 | def test_track_writes_output(problem): 31 | """Test that a ``cockpit``'s ``track`` function writes to the output.""" 32 | quantity = quantities.Time(track_schedule=linear(1)) 33 | 34 | with instantiate(problem): 35 | testing_harness = CustomTestHarness(problem) 36 | cockpit_kwargs = {"quantities": [quantity]} 37 | testing_harness.test(cockpit_kwargs) 38 | -------------------------------------------------------------------------------- /tests/test_cockpit/test_backpack_extensions.py: -------------------------------------------------------------------------------- 1 | """Test if BackPACK quantities other than that required by cockpit can be computed.""" 2 | 3 | import pytest 4 | from backpack.extensions import DiagHessian 5 | 6 | from cockpit import quantities 7 | from cockpit.utils.schedules import linear 8 | from tests.test_cockpit.settings import PROBLEMS, PROBLEMS_IDS 9 | from tests.utils.harness import SimpleTestHarness 10 | from tests.utils.problem import instantiate 11 | 12 | 13 | class CustomTestHarness(SimpleTestHarness): 14 | """Create a Custom Test Harness that checks whether the BackPACK buffers exist.""" 15 | 16 | def check_in_context(self): 17 | """Check that the BackPACK buffers exists in the context.""" 18 | for param in self.problem.model.parameters(): 19 | # required by TICDiag and user 20 | assert hasattr(param, "diag_h") 21 | # required by TICDiag only 22 | assert hasattr(param, "grad_batch_transforms") 23 | assert "sum_grad_squared" in param.grad_batch_transforms 24 | 25 | def check_after_context(self): 26 | """Check that the buffers are not deleted when specified by the user.""" 27 | for param in self.problem.model.parameters(): 28 | assert hasattr(param, "diag_h") 29 | # not protected by user 30 | assert not hasattr(param, "grad_batch_transforms") 31 | 32 | 33 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 34 | def test_backpack_extensions(problem): 35 | """Check if backpack quantities can be computed inside cockpit.""" 36 | quantity = quantities.TICDiag(track_schedule=linear(1)) 37 | 38 | with instantiate(problem): 39 | testing_harness = CustomTestHarness(problem) 40 | cockpit_kwargs = {"quantities": [quantity]} 41 | testing_harness.test(cockpit_kwargs, DiagHessian()) 42 | -------------------------------------------------------------------------------- /tests/test_cockpit/test_multiple_batch_grad_transforms.py: -------------------------------------------------------------------------------- 1 | """Tests for using multiple batch grad transforms in Cockpit.""" 2 | 3 | import pytest 4 | 5 | from cockpit.cockpit import Cockpit 6 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook 7 | 8 | 9 | def test_merge_batch_grad_transforms(): 10 | """Test merging of multiple ``BatchGradTransforms``.""" 11 | bgt1 = BatchGradTransformsHook({"x": lambda t: t, "y": lambda t: t}) 12 | bgt2 = BatchGradTransformsHook({"v": lambda t: t, "w": lambda t: t}) 13 | 14 | merged_bgt = Cockpit._merge_batch_grad_transform_hooks([bgt1, bgt2]) 15 | assert isinstance(merged_bgt, BatchGradTransformsHook) 16 | 17 | merged_keys = ["x", "y", "v", "w"] 18 | assert len(merged_bgt._transforms.keys()) == len(merged_keys) 19 | 20 | for key in merged_keys: 21 | assert key in merged_bgt._transforms.keys() 
22 | 23 | assert id(bgt1._transforms["x"]) == id(merged_bgt._transforms["x"]) 24 | assert id(bgt2._transforms["w"]) == id(merged_bgt._transforms["w"]) 25 | 26 | 27 | def test_merge_batch_grad_transforms_same_key_different_trafo(): 28 | """Merging ``BatchGradTransforms`` with same key but different trafo should fail.""" 29 | bgt1 = BatchGradTransformsHook({"x": lambda t: t, "y": lambda t: t}) 30 | bgt2 = BatchGradTransformsHook({"x": lambda t: t, "w": lambda t: t}) 31 | 32 | with pytest.raises(ValueError): 33 | _ = Cockpit._merge_batch_grad_transform_hooks([bgt1, bgt2]) 34 | 35 | 36 | def test_merge_batch_grad_transforms_same_key_same_trafo(): 37 | """Test merging multiple ``BatchGradTransforms`` with same key and same trafo.""" 38 | 39 | def func(t): 40 | return t 41 | 42 | bgt1 = BatchGradTransformsHook({"x": func}) 43 | bgt2 = BatchGradTransformsHook({"x": func}) 44 | 45 | merged = Cockpit._merge_batch_grad_transform_hooks([bgt1, bgt2]) 46 | 47 | assert len(merged._transforms.keys()) == 1 48 | assert id(merged._transforms["x"]) == id(func) 49 | -------------------------------------------------------------------------------- /tests/test_examples/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``examples``.""" 2 | -------------------------------------------------------------------------------- /tests/test_examples/test_examples.py: -------------------------------------------------------------------------------- 1 | """Test whether the example scripts run.""" 2 | 3 | import os 4 | import pathlib 5 | import runpy 6 | import sys 7 | 8 | import pytest 9 | 10 | SCRIPTS = sorted(pathlib.Path(__file__, "../../..", "examples").resolve().glob("*.py")) 11 | SCRIPTS_STR, SCRIPTS_ID = [], [] 12 | for s in SCRIPTS: 13 | if not str(s).split("/")[-1].startswith("_"): 14 | SCRIPTS_STR.append(str(s)) 15 | SCRIPTS_ID.append(str(s).split("/")[-1].split(".")[0]) 16 | 17 | 18 | @pytest.mark.parametrize("script", SCRIPTS_STR, ids=SCRIPTS_ID) 19 | def test_example_scripts(script): 20 | """Run a single example script. 21 | 22 | Args: 23 | script (str): Script that should be run. 24 | """ 25 | sys.path.append(os.path.dirname(script)) 26 | del sys.argv[1:] # Clear CLI arguments from pytest 27 | runpy.run_path(str(script)) 28 | -------------------------------------------------------------------------------- /tests/test_quantities/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``cockpit.quantities``.""" 2 | -------------------------------------------------------------------------------- /tests/test_quantities/adam_settings.py: -------------------------------------------------------------------------------- 1 | """Problem settings using Adam as optimizer. 2 | 3 | Some quantities are only defined for zero-momentum SGD (``CABS``, ``EarlyStopping``), 4 | or use a different computation strategy (``Alpha``). This behavior needs to be tested. 
5 | """ 6 | 7 | import torch 8 | 9 | from tests.utils.data import load_toy_data 10 | from tests.utils.models import load_toy_model 11 | from tests.utils.problem import make_problems_with_ids 12 | 13 | ADAM_SETTINGS = [ 14 | { 15 | "data_fn": lambda: load_toy_data(batch_size=4), 16 | "model_fn": load_toy_model, 17 | "individual_loss_function_fn": lambda: torch.nn.CrossEntropyLoss( 18 | reduction="none" 19 | ), 20 | "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), 21 | "iterations": 5, 22 | "optimizer_fn": lambda parameters: torch.optim.Adam(parameters, lr=0.01), 23 | }, 24 | ] 25 | 26 | ADAM_PROBLEMS, ADAM_IDS = make_problems_with_ids(ADAM_SETTINGS) 27 | -------------------------------------------------------------------------------- /tests/test_quantities/settings.py: -------------------------------------------------------------------------------- 1 | """Settings used by the tests in this submodule.""" 2 | 3 | import torch 4 | 5 | from cockpit.utils.schedules import linear, logarithmic 6 | from tests.settings import SETTINGS as GLOBAL_SETTINGS 7 | from tests.utils.data import load_toy_data 8 | from tests.utils.models import load_toy_model 9 | from tests.utils.problem import make_problems_with_ids 10 | 11 | LOCAL_SETTINGS = [ 12 | { 13 | "data_fn": lambda: load_toy_data(batch_size=5), 14 | "model_fn": load_toy_model, 15 | "individual_loss_function_fn": lambda: torch.nn.CrossEntropyLoss( 16 | reduction="none" 17 | ), 18 | "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"), 19 | "iterations": 5, 20 | }, 21 | ] 22 | 23 | SETTINGS = GLOBAL_SETTINGS + LOCAL_SETTINGS 24 | 25 | PROBLEMS, PROBLEMS_IDS = make_problems_with_ids(SETTINGS) 26 | 27 | INDEPENDENT_RUNS = [True, False] 28 | INDEPENDENT_RUNS_IDS = [f"independent_runs={run}" for run in INDEPENDENT_RUNS] 29 | 30 | CPU_PROBLEMS = [] 31 | CPU_PROBLEMS_ID = [] 32 | for problem, problem_id in zip(PROBLEMS, PROBLEMS_IDS): 33 | if "cpu" in str(problem.device): 34 | CPU_PROBLEMS.append(problem) 35 | CPU_PROBLEMS_ID.append(problem_id) 36 | 37 | QUANTITY_KWARGS = [ 38 | { 39 | "track_schedule": linear(interval=1, offset=2), # [1, 3, 5, ...] 40 | "verbose": True, 41 | }, 42 | { 43 | "track_schedule": logarithmic( 44 | start=0, end=1, steps=4, init=False 45 | ), # [1, 2, 4, 10] 46 | "verbose": True, 47 | }, 48 | ] 49 | QUANTITY_KWARGS_IDS = [f"q_kwargs={q_kwargs}" for q_kwargs in QUANTITY_KWARGS] 50 | -------------------------------------------------------------------------------- /tests/test_quantities/test_cabs.py: -------------------------------------------------------------------------------- 1 | """Compare ``CABS`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | 5 | from cockpit.context import get_individual_losses, get_optimizer 6 | from cockpit.quantities import CABS 7 | from cockpit.utils.optim import ComputeStep 8 | from tests.test_quantities.adam_settings import ADAM_IDS, ADAM_PROBLEMS 9 | from tests.test_quantities.settings import ( 10 | INDEPENDENT_RUNS, 11 | INDEPENDENT_RUNS_IDS, 12 | PROBLEMS, 13 | PROBLEMS_IDS, 14 | QUANTITY_KWARGS, 15 | QUANTITY_KWARGS_IDS, 16 | ) 17 | from tests.test_quantities.utils import autograd_diagonal_variance, get_compare_fn 18 | 19 | 20 | class AutogradCABS(CABS): 21 | """``torch.autograd`` implementation of ``CABS``.""" 22 | 23 | def extensions(self, global_step): 24 | """Return list of BackPACK extensions required for the computation. 25 | 26 | Args: 27 | global_step (int): The current iteration number. 
28 | 29 | Returns: 30 | list: (Potentially empty) list with required BackPACK quantities. 31 | """ 32 | return [] 33 | 34 | def extension_hooks(self, global_step): 35 | """Return list of BackPACK extension hooks required for the computation. 36 | 37 | Args: 38 | global_step (int): The current iteration number. 39 | 40 | Returns: 41 | [callable]: List of required BackPACK extension hooks for the current 42 | iteration. 43 | """ 44 | return [] 45 | 46 | def create_graph(self, global_step): 47 | """Return whether access to the forward pass computation graph is needed. 48 | 49 | Args: 50 | global_step (int): The current iteration number. 51 | 52 | Returns: 53 | bool: ``True`` if the computation graph shall not be deleted, 54 | else ``False``. 55 | """ 56 | return self.should_compute(global_step) 57 | 58 | def _compute(self, global_step, params, batch_loss): 59 | """Evaluate the CABS criterion. 60 | 61 | Args: 62 | global_step (int): The current iteration number. 63 | params ([torch.Tensor]): List of torch.Tensors holding the network's 64 | parameters. 65 | batch_loss (torch.Tensor): Mini-batch loss from current step. 66 | 67 | Returns: 68 | float: Evaluated CABS criterion. 69 | 70 | Raises: 71 | ValueError: If the optimizer differs from SGD with default arguments. 72 | """ 73 | optimizer = get_optimizer(global_step) 74 | if not ComputeStep.is_sgd_default_kwargs(optimizer): 75 | raise ValueError("This criterion only supports zero-momentum SGD.") 76 | 77 | losses = get_individual_losses(global_step) 78 | batch_axis = 0 79 | trace_variance = autograd_diagonal_variance( 80 | losses, params, concat=True, unbiased=False 81 | ).sum(batch_axis) 82 | lr = self.get_lr(optimizer) 83 | 84 | return (lr * trace_variance / batch_loss).item()  # CABS: lr * Tr(gradient variance) / loss 85 | 86 | 87 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 88 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 89 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 90 | def test_cabs(problem, independent_runs, q_kwargs): 91 | """Compare BackPACK and ``torch.autograd`` implementation of CABS. 92 | 93 | Args: 94 | problem (tests.utils.Problem): Settings for train loop. 95 | independent_runs (bool): Whether to use two separate runs to compute the 96 | output of every quantity. 97 | q_kwargs (dict): Keyword arguments handed over to both quantities. 98 | """ 99 | compare_fn = get_compare_fn(independent_runs) 100 | compare_fn(problem, (CABS, AutogradCABS), q_kwargs) 101 | 102 | 103 | @pytest.mark.parametrize("problem", ADAM_PROBLEMS, ids=ADAM_IDS) 104 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 105 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 106 | def test_cabs_no_adam(problem, independent_runs, q_kwargs): 107 | """Verify Adam is not supported by the CABS criterion. 108 | 109 | Args: 110 | problem (tests.utils.Problem): Settings for train loop. 111 | independent_runs (bool): Whether to use two separate runs to compute the 112 | output of every quantity. 113 | q_kwargs (dict): Keyword arguments handed over to both quantities.
114 | """ 115 | with pytest.raises(ValueError): 116 | test_cabs(problem, independent_runs, q_kwargs) 117 | -------------------------------------------------------------------------------- /tests/test_quantities/test_early_stopping.py: -------------------------------------------------------------------------------- 1 | """Compare ``EarlyStopping`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from cockpit.context import get_batch_size, get_individual_losses 7 | from cockpit.quantities import EarlyStopping 8 | from tests.test_quantities.adam_settings import ADAM_IDS, ADAM_PROBLEMS 9 | from tests.test_quantities.settings import ( 10 | INDEPENDENT_RUNS, 11 | INDEPENDENT_RUNS_IDS, 12 | PROBLEMS, 13 | PROBLEMS_IDS, 14 | QUANTITY_KWARGS, 15 | QUANTITY_KWARGS_IDS, 16 | ) 17 | from tests.test_quantities.utils import autograd_diagonal_variance, get_compare_fn 18 | 19 | 20 | class AutogradEarlyStopping(EarlyStopping): 21 | """``torch.autograd`` implementation of ``EarlyStopping``.""" 22 | 23 | def extensions(self, global_step): 24 | """Return list of BackPACK extensions required for the computation. 25 | 26 | Args: 27 | global_step (int): The current iteration number. 28 | 29 | Returns: 30 | list: (Potentially empty) list with required BackPACK quantities. 31 | """ 32 | return [] 33 | 34 | def extension_hooks(self, global_step): 35 | """Return list of BackPACK extension hooks required for the computation. 36 | 37 | Args: 38 | global_step (int): The current iteration number. 39 | 40 | Returns: 41 | [callable]: List of required BackPACK extension hooks for the current 42 | iteration. 43 | """ 44 | return [] 45 | 46 | def create_graph(self, global_step): 47 | """Return whether access to the forward pass computation graph is needed. 48 | 49 | Args: 50 | global_step (int): The current iteration number. 51 | 52 | Returns: 53 | bool: ``True`` if the computation graph shall not be deleted, 54 | else ``False``. 55 | """ 56 | return self.should_compute(global_step) 57 | 58 | def _compute(self, global_step, params, batch_loss): 59 | """Evaluate the early stopping criterion. 60 | 61 | Args: 62 | global_step (int): The current iteration number. 63 | params ([torch.Tensor]): List of torch.Tensors holding the network's 64 | parameters. 65 | batch_loss (torch.Tensor): Mini-batch loss from current step. 66 | 67 | Returns: 68 | float: Early stopping criterion. 69 | """ 70 | grad_squared = torch.cat([p.grad.flatten() for p in params]) ** 2 71 | 72 | losses = get_individual_losses(global_step) 73 | diag_variance = autograd_diagonal_variance(losses, params, concat=True) 74 | 75 | B = get_batch_size(global_step) 76 | 77 | return 1 - B * (grad_squared / (diag_variance + self._epsilon)).mean().item() 78 | 79 | 80 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 81 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 82 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 83 | def test_early_stopping(problem, independent_runs, q_kwargs): 84 | """Compare BackPACK and ``torch.autograd`` implementation of EarlyStopping. 85 | 86 | Args: 87 | problem (tests.utils.Problem): Settings for train loop. 88 | independent_runs (bool): Whether to use to separate runs to compute the 89 | output of every quantity. 90 | q_kwargs (dict): Keyword arguments handed over to both quantities. 
91 | """ 92 | rtol, atol = 5e-3, 1e-5 93 | 94 | compare_fn = get_compare_fn(independent_runs) 95 | compare_fn( 96 | problem, (EarlyStopping, AutogradEarlyStopping), q_kwargs, rtol=rtol, atol=atol 97 | ) 98 | 99 | 100 | @pytest.mark.parametrize("problem", ADAM_PROBLEMS, ids=ADAM_IDS) 101 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 102 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 103 | def test_early_stopping_no_adam(problem, independent_runs, q_kwargs): 104 | """Verify Adam is not supported by EarlyStopping criterion. 105 | 106 | Args: 107 | problem (tests.utils.Problem): Settings for train loop. 108 | independent_runs (bool): Whether to use to separate runs to compute the 109 | output of every quantity. 110 | q_kwargs (dict): Keyword arguments handed over to both quantities. 111 | """ 112 | with pytest.raises(ValueError): 113 | test_early_stopping(problem, independent_runs, q_kwargs) 114 | -------------------------------------------------------------------------------- /tests/test_quantities/test_hess_max_ev.py: -------------------------------------------------------------------------------- 1 | """Compare ``HessMaxEV`` quantity with ``torch.autograd``.""" 2 | 3 | import warnings 4 | 5 | import pytest 6 | 7 | from cockpit.quantities import HessMaxEV 8 | from tests.test_quantities.settings import ( 9 | INDEPENDENT_RUNS, 10 | INDEPENDENT_RUNS_IDS, 11 | PROBLEMS, 12 | PROBLEMS_IDS, 13 | QUANTITY_KWARGS, 14 | QUANTITY_KWARGS_IDS, 15 | ) 16 | from tests.test_quantities.utils import ( 17 | autograd_hessian_maximum_eigenvalue, 18 | get_compare_fn, 19 | ) 20 | 21 | 22 | class AutogradHessMaxEV(HessMaxEV): 23 | """``torch.autograd`` implementation of ``HessMaxEV``. 24 | 25 | Requires storing the full Hessian in memory and can hence only be applied to small 26 | networks. 27 | """ 28 | 29 | def _compute(self, global_step, params, batch_loss): 30 | """Evaluate the maximum mini-batch loss Hessian eigenvalue. 31 | 32 | Args: 33 | global_step (int): The current iteration number. 34 | params ([torch.Tensor]): List of torch.Tensors holding the network's 35 | parameters. 36 | batch_loss (torch.Tensor): Mini-batch loss from current step. 37 | 38 | Returns: 39 | float: Maximum Hessian eigenvalue (of the mini-batch loss). 40 | """ 41 | self._maybe_warn_dimension(sum(p.numel() for p in params)) 42 | 43 | return autograd_hessian_maximum_eigenvalue(batch_loss, params).item() 44 | 45 | @staticmethod 46 | def _maybe_warn_dimension(dim): 47 | """Warn user if the Hessian is large.""" 48 | MAX_DIM = 1000 49 | 50 | if dim >= MAX_DIM: 51 | warnings.warn(f"Computing Hessians of size ({dim}, {dim}) is expensive") 52 | 53 | 54 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 55 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 56 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 57 | def test_hess_max_ev(problem, independent_runs, q_kwargs): 58 | """Compare BackPACK and ``torch.autograd`` implementation of Hessian max eigenvalue. 59 | 60 | Args: 61 | problem (tests.utils.Problem): Settings for train loop. 62 | independent_runs (bool): Whether to use to separate runs to compute the 63 | output of every quantity. 64 | q_kwargs (dict): Keyword arguments handed over to both quantities. 
65 | """ 66 | atol, rtol = 1e-4, 1e-2 67 | 68 | compare_fn = get_compare_fn(independent_runs) 69 | compare_fn(problem, (HessMaxEV, AutogradHessMaxEV), q_kwargs, rtol=rtol, atol=atol) 70 | -------------------------------------------------------------------------------- /tests/test_quantities/test_hess_trace.py: -------------------------------------------------------------------------------- 1 | """Compare ``HessTrace`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | 5 | from cockpit.quantities import HessTrace 6 | from tests.test_quantities.settings import ( 7 | INDEPENDENT_RUNS, 8 | INDEPENDENT_RUNS_IDS, 9 | PROBLEMS, 10 | PROBLEMS_IDS, 11 | QUANTITY_KWARGS, 12 | QUANTITY_KWARGS_IDS, 13 | ) 14 | from tests.test_quantities.utils import autograd_diag_hessian, get_compare_fn 15 | 16 | 17 | class AutogradHessTrace(HessTrace): 18 | """``torch.autograd`` implementation of ``HessTrace``.""" 19 | 20 | def extensions(self, global_step): 21 | """Return list of BackPACK extensions required for the computation. 22 | 23 | Args: 24 | global_step (int): The current iteration number. 25 | 26 | Returns: 27 | list: (Potentially empty) list with required BackPACK quantities. 28 | """ 29 | return [] 30 | 31 | def create_graph(self, global_step): 32 | """Return whether access to the forward pass computation graph is needed. 33 | 34 | Args: 35 | global_step (int): The current iteration number. 36 | 37 | Returns: 38 | bool: ``True`` if the computation graph shall not be deleted, 39 | else ``False``. 40 | """ 41 | return self.should_compute(global_step) 42 | 43 | def _compute(self, global_step, params, batch_loss): 44 | """Evaluate the trace of the Hessian at the current point. 45 | 46 | Args: 47 | global_step (int): The current iteration number. 48 | params ([torch.Tensor]): List of torch.Tensors holding the network's 49 | parameters. 50 | batch_loss (torch.Tensor): Mini-batch loss from current step. 51 | 52 | Returns: 53 | list: Traces of the Hessian by layer. 54 | """ 55 | return [ 56 | diag_h.sum().item() for diag_h in autograd_diag_hessian(batch_loss, params) 57 | ] 58 | 59 | 60 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 61 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 62 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 63 | def test_hess_trace(problem, independent_runs, q_kwargs): 64 | """Compare BackPACK and ``torch.autograd`` implementation of Hessian trace. 65 | 66 | Args: 67 | problem (tests.utils.Problem): Settings for train loop. 68 | independent_runs (bool): Whether to use to separate runs to compute the 69 | output of every quantity. 70 | q_kwargs (dict): Keyword arguments handed over to both quantities. 
71 | """ 72 | compare_fn = get_compare_fn(independent_runs) 73 | compare_fn(problem, (HessTrace, AutogradHessTrace), q_kwargs) 74 | -------------------------------------------------------------------------------- /tests/test_quantities/test_inner_test.py: -------------------------------------------------------------------------------- 1 | """Compare ``InnerTest`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from cockpit.context import get_batch_size, get_individual_losses 7 | from cockpit.quantities import InnerTest 8 | from tests.test_quantities.settings import ( 9 | INDEPENDENT_RUNS, 10 | INDEPENDENT_RUNS_IDS, 11 | PROBLEMS, 12 | PROBLEMS_IDS, 13 | QUANTITY_KWARGS, 14 | QUANTITY_KWARGS_IDS, 15 | ) 16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn 17 | 18 | 19 | class AutogradInnerTest(InnerTest): 20 | """``torch.autograd`` implementation of ``InnerTest``.""" 21 | 22 | def extensions(self, global_step): 23 | """Return list of BackPACK extensions required for the computation. 24 | 25 | Args: 26 | global_step (int): The current iteration number. 27 | 28 | Returns: 29 | list: (Potentially empty) list with required BackPACK quantities. 30 | """ 31 | return [] 32 | 33 | def extension_hooks(self, global_step): 34 | """Return list of BackPACK extension hooks required for the computation. 35 | 36 | Args: 37 | global_step (int): The current iteration number. 38 | 39 | Returns: 40 | [callable]: List of required BackPACK extension hooks for the current 41 | iteration. 42 | """ 43 | return [] 44 | 45 | def create_graph(self, global_step): 46 | """Return whether access to the forward pass computation graph is needed. 47 | 48 | Args: 49 | global_step (int): The current iteration number. 50 | 51 | Returns: 52 | bool: ``True`` if the computation graph shall not be deleted, 53 | else ``False``. 54 | """ 55 | return self.should_compute(global_step) 56 | 57 | def _compute(self, global_step, params, batch_loss): 58 | """Evaluate the inner-product test. 59 | 60 | Args: 61 | global_step (int): The current iteration number. 62 | params ([torch.Tensor]): List of torch.Tensors holding the network's 63 | parameters. 64 | batch_loss (torch.Tensor): Mini-batch loss from current step. 65 | 66 | Returns: 67 | float: Resut of the inner-product test. 68 | """ 69 | losses = get_individual_losses(global_step) 70 | individual_gradients_flat = autograd_individual_gradients( 71 | losses, params, concat=True 72 | ) 73 | grad = torch.cat([p.grad.flatten() for p in params]) 74 | 75 | projections = torch.einsum("ni,i->n", individual_gradients_flat, grad) 76 | grad_norm = grad.norm() 77 | 78 | N_axis = 0 79 | batch_size = get_batch_size(global_step) 80 | 81 | return ( 82 | ( 83 | 1 84 | / (batch_size * (batch_size - 1)) 85 | * ((projections ** 2).sum(N_axis) / grad_norm ** 4 - batch_size) 86 | ) 87 | .sqrt() 88 | .item() 89 | ) 90 | 91 | 92 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 93 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 94 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 95 | def test_inner_test(problem, independent_runs, q_kwargs): 96 | """Compare BackPACK and ``torch.autograd`` implementation of InnerTest. 97 | 98 | Args: 99 | problem (tests.utils.Problem): Settings for train loop. 100 | independent_runs (bool): Whether to use to separate runs to compute the 101 | output of every quantity. 
102 | q_kwargs (dict): Keyword arguments handed over to both quantities. 103 | """ 104 | compare_fn = get_compare_fn(independent_runs) 105 | compare_fn(problem, (InnerTest, AutogradInnerTest), q_kwargs) 106 | -------------------------------------------------------------------------------- /tests/test_quantities/test_mean_gsnr.py: -------------------------------------------------------------------------------- 1 | """Compare ``MeanGSNR`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from cockpit.context import get_individual_losses 7 | from cockpit.quantities import MeanGSNR 8 | from tests.test_quantities.settings import ( 9 | INDEPENDENT_RUNS, 10 | INDEPENDENT_RUNS_IDS, 11 | PROBLEMS, 12 | PROBLEMS_IDS, 13 | QUANTITY_KWARGS, 14 | QUANTITY_KWARGS_IDS, 15 | ) 16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn 17 | 18 | 19 | class AutogradMeanGSNR(MeanGSNR): 20 | """``torch.autograd`` implementation of ``MeanGSNR``.""" 21 | 22 | def extensions(self, global_step): 23 | """Return list of BackPACK extensions required for the computation. 24 | 25 | Args: 26 | global_step (int): The current iteration number. 27 | 28 | Returns: 29 | list: (Potentially empty) list with required BackPACK quantities. 30 | """ 31 | return [] 32 | 33 | def extension_hooks(self, global_step): 34 | """Return list of BackPACK extension hooks required for the computation. 35 | 36 | Args: 37 | global_step (int): The current iteration number. 38 | 39 | Returns: 40 | [callable]: List of required BackPACK extension hooks for the current 41 | iteration. 42 | """ 43 | return [] 44 | 45 | def create_graph(self, global_step): 46 | """Return whether access to the forward pass computation graph is needed. 47 | 48 | Args: 49 | global_step (int): The current iteration number. 50 | 51 | Returns: 52 | bool: ``True`` if the computation graph shall not be deleted, 53 | else ``False``. 54 | """ 55 | return self.should_compute(global_step) 56 | 57 | def _compute(self, global_step, params, batch_loss): 58 | """Evaluate the MeanGSNR. 59 | 60 | Args: 61 | global_step (int): The current iteration number. 62 | params ([torch.Tensor]): List of torch.Tensors holding the network's 63 | parameters. 64 | batch_loss (torch.Tensor): Mini-batch loss from current step. 65 | 66 | Returns: 67 | float: Mean GSNR of the current iteration. 68 | """ 69 | losses = get_individual_losses(global_step) 70 | individual_gradients_flat = autograd_individual_gradients( 71 | losses, params, concat=True 72 | ) 73 | 74 | grad_squared = torch.cat([p.grad.flatten() for p in params]) ** 2 75 | 76 | N_axis = 0 77 | second_moment_flat = (individual_gradients_flat ** 2).mean(N_axis) 78 | 79 | gsnr = grad_squared / (second_moment_flat - grad_squared + self._epsilon) 80 | 81 | return gsnr.mean().item() 82 | 83 | 84 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 85 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 86 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 87 | def test_mean_gsnr(problem, independent_runs, q_kwargs): 88 | """Compare BackPACK and ``torch.autograd`` implementation of MeanGSNR. 89 | 90 | Args: 91 | problem (tests.utils.Problem): Settings for train loop. 92 | independent_runs (bool): Whether to use two separate runs to compute the 93 | output of every quantity. 94 | q_kwargs (dict): Keyword arguments handed over to both quantities.
95 | """ 96 | rtol, atol = 5e-3, 1e-5 97 | 98 | compare_fn = get_compare_fn(independent_runs) 99 | compare_fn(problem, (MeanGSNR, AutogradMeanGSNR), q_kwargs, rtol=rtol, atol=atol) 100 | -------------------------------------------------------------------------------- /tests/test_quantities/test_norm_test.py: -------------------------------------------------------------------------------- 1 | """Compare ``NormTest`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from cockpit.context import get_batch_size, get_individual_losses 7 | from cockpit.quantities import NormTest 8 | from tests.test_quantities.settings import ( 9 | INDEPENDENT_RUNS, 10 | INDEPENDENT_RUNS_IDS, 11 | PROBLEMS, 12 | PROBLEMS_IDS, 13 | QUANTITY_KWARGS, 14 | QUANTITY_KWARGS_IDS, 15 | ) 16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn 17 | 18 | 19 | class AutogradNormTest(NormTest): 20 | """``torch.autograd`` implementation of ``NormTest``.""" 21 | 22 | def extensions(self, global_step): 23 | """Return list of BackPACK extensions required for the computation. 24 | 25 | Args: 26 | global_step (int): The current iteration number. 27 | 28 | Returns: 29 | list: (Potentially empty) list with required BackPACK quantities. 30 | """ 31 | return [] 32 | 33 | def extension_hooks(self, global_step): 34 | """Return list of BackPACK extension hooks required for the computation. 35 | 36 | Args: 37 | global_step (int): The current iteration number. 38 | 39 | Returns: 40 | [callable]: List of required BackPACK extension hooks for the current 41 | iteration. 42 | """ 43 | return [] 44 | 45 | def create_graph(self, global_step): 46 | """Return whether access to the forward pass computation graph is needed. 47 | 48 | Args: 49 | global_step (int): The current iteration number. 50 | 51 | Returns: 52 | bool: ``True`` if the computation graph shall not be deleted, 53 | else ``False``. 54 | """ 55 | return self.should_compute(global_step) 56 | 57 | def _compute(self, global_step, params, batch_loss): 58 | """Evaluate the norm test. 59 | 60 | Args: 61 | global_step (int): The current iteration number. 62 | params ([torch.Tensor]): List of torch.Tensors holding the network's 63 | parameters. 64 | batch_loss (torch.Tensor): Mini-batch loss from current step. 65 | 66 | Returns: 67 | float: Result of the norm test. 68 | """ 69 | losses = get_individual_losses(global_step) 70 | individual_gradients_flat = autograd_individual_gradients( 71 | losses, params, concat=True 72 | ) 73 | sum_of_squares = (individual_gradients_flat ** 2).sum() 74 | 75 | grad_norm = torch.cat([p.grad.flatten() for p in params]).norm() 76 | 77 | batch_size = get_batch_size(global_step) 78 | 79 | return ( 80 | ( 81 | 1 82 | / (batch_size * (batch_size - 1)) 83 | * (sum_of_squares / grad_norm ** 2 - batch_size) 84 | ) 85 | .sqrt() 86 | .item() 87 | ) 88 | 89 | 90 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 91 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 92 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 93 | def test_norm_test(problem, independent_runs, q_kwargs): 94 | """Compare BackPACK and ``torch.autograd`` implementation of NormTest. 95 | 96 | Args: 97 | problem (tests.utils.Problem): Settings for train loop. 98 | independent_runs (bool): Whether to use to separate runs to compute the 99 | output of every quantity. 100 | q_kwargs (dict): Keyword arguments handed over to both quantities. 
101 | """ 102 | compare_fn = get_compare_fn(independent_runs) 103 | compare_fn(problem, (NormTest, AutogradNormTest), q_kwargs) 104 | -------------------------------------------------------------------------------- /tests/test_quantities/test_ortho_test.py: -------------------------------------------------------------------------------- 1 | """Compare ``OrthoTest`` quantity with ``torch.autograd``.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from cockpit.context import get_batch_size, get_individual_losses 7 | from cockpit.quantities import OrthoTest 8 | from tests.test_quantities.settings import ( 9 | INDEPENDENT_RUNS, 10 | INDEPENDENT_RUNS_IDS, 11 | PROBLEMS, 12 | PROBLEMS_IDS, 13 | QUANTITY_KWARGS, 14 | QUANTITY_KWARGS_IDS, 15 | ) 16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn 17 | 18 | 19 | class AutogradOrthoTest(OrthoTest): 20 | """``torch.autograd`` implementation of ``OrthoTest``.""" 21 | 22 | def extensions(self, global_step): 23 | """Return list of BackPACK extensions required for the computation. 24 | 25 | Args: 26 | global_step (int): The current iteration number. 27 | 28 | Returns: 29 | list: (Potentially empty) list with required BackPACK quantities. 30 | """ 31 | return [] 32 | 33 | def extension_hooks(self, global_step): 34 | """Return list of BackPACK extension hooks required for the computation. 35 | 36 | Args: 37 | global_step (int): The current iteration number. 38 | 39 | Returns: 40 | [callable]: List of required BackPACK extension hooks for the current 41 | iteration. 42 | """ 43 | return [] 44 | 45 | def create_graph(self, global_step): 46 | """Return whether access to the forward pass computation graph is needed. 47 | 48 | Args: 49 | global_step (int): The current iteration number. 50 | 51 | Returns: 52 | bool: ``True`` if the computation graph shall not be deleted, 53 | else ``False``. 54 | """ 55 | return self.should_compute(global_step) 56 | 57 | def _compute(self, global_step, params, batch_loss): 58 | """Evaluate the norm test. 59 | 60 | Args: 61 | global_step (int): The current iteration number. 62 | params ([torch.Tensor]): List of torch.Tensors holding the network's 63 | parameters. 64 | batch_loss (torch.Tensor): Mini-batch loss from current step. 65 | 66 | Returns: 67 | foat: Result of the norm test. 68 | """ 69 | losses = get_individual_losses(global_step) 70 | individual_gradients_flat = autograd_individual_gradients( 71 | losses, params, concat=True 72 | ) 73 | D_axis = 1 74 | individual_l2_norms_squared = (individual_gradients_flat ** 2).sum(D_axis) 75 | 76 | grad = torch.cat([p.grad.flatten() for p in params]) 77 | grad_norm = grad.norm() 78 | 79 | projections = torch.einsum("ni,i->n", individual_gradients_flat, grad) 80 | 81 | batch_size = get_batch_size(global_step) 82 | 83 | return ( 84 | ( 85 | 1 86 | / (batch_size * (batch_size - 1)) 87 | * ( 88 | individual_l2_norms_squared / grad_norm ** 2 89 | - (projections ** 2) / grad_norm ** 4 90 | ).sum() 91 | ) 92 | .sqrt() 93 | .item() 94 | ) 95 | 96 | 97 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 98 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS) 99 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS) 100 | def test_ortho_test(problem, independent_runs, q_kwargs): 101 | """Compare BackPACK and ``torch.autograd`` implementation of OrthoTest. 102 | 103 | Args: 104 | problem (tests.utils.Problem): Settings for train loop. 
105 | independent_runs (bool): Whether to use two separate runs to compute the 106 | output of every quantity. 107 | q_kwargs (dict): Keyword arguments handed over to both quantities. 108 | """ 109 | compare_fn = get_compare_fn(independent_runs) 110 | compare_fn(problem, (OrthoTest, AutogradOrthoTest), q_kwargs) 111 | -------------------------------------------------------------------------------- /tests/test_quantities/test_quantity_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for ``cockpit.quantities``.""" 2 | 3 | import pytest 4 | 5 | from cockpit import quantities 6 | from cockpit.quantities import __all__ 7 | from cockpit.quantities.quantity import SingleStepQuantity, TwoStepQuantity 8 | from cockpit.utils.schedules import linear 9 | from tests.test_quantities.settings import PROBLEMS, PROBLEMS_IDS 10 | from tests.utils.harness import SimpleTestHarness 11 | from tests.utils.problem import instantiate 12 | 13 | QUANTITIES = [ 14 | getattr(quantities, q) 15 | for q in __all__ 16 | if q != "Quantity" 17 | and q != "SingleStepQuantity" 18 | and q != "TwoStepQuantity" 19 | and q != "ByproductQuantity" 20 | ] 21 | IDS = [q_cls.__name__ for q_cls in QUANTITIES] 22 | 23 | 24 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS) 25 | @pytest.mark.parametrize("quantity_cls", QUANTITIES, ids=IDS) 26 | def test_quantity_integration_and_track_events(problem, quantity_cls): 27 | """Check if ``Cockpit`` with a single quantity works. 28 | 29 | Args: 30 | problem (tests.utils.Problem): Settings for train loop. 31 | quantity_cls (Class): Quantity class that should be tested. 32 | """ 33 | interval, offset = 1, 2 34 | schedule = linear(interval, offset=offset) 35 | quantity = quantity_cls(track_schedule=schedule, verbose=True) 36 | 37 | with instantiate(problem): 38 | iterations = problem.iterations 39 | testing_harness = SimpleTestHarness(problem) 40 | cockpit_kwargs = {"quantities": [quantity]} 41 | testing_harness.test(cockpit_kwargs) 42 | 43 | def is_track_event(iteration): 44 | if isinstance(quantity, SingleStepQuantity): 45 | return schedule(iteration) 46 | elif isinstance(quantity, TwoStepQuantity): 47 | end_iter = quantity.SAVE_SHIFT + iteration 48 | return quantity.is_end(end_iter) and end_iter < iterations 49 | else: 50 | raise ValueError(f"Unknown quantity: {quantity}") 51 | 52 | track_events = sorted(i for i in range(iterations) if is_track_event(i)) 53 | output_events = sorted(quantity.get_output().keys()) 54 | 55 | assert output_events == track_events 56 | -------------------------------------------------------------------------------- /tests/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for ``cockpit.utils``.""" 2 | -------------------------------------------------------------------------------- /tests/test_utils/test_configurations.py: -------------------------------------------------------------------------------- 1 | """Tests for ``cockpit.utils.configuration``.""" 2 | 3 | import pytest 4 | 5 | from cockpit import quantities 6 | from cockpit.utils.configuration import quantities_cls_for_configuration 7 | 8 | 9 | @pytest.mark.parametrize("label", ["full", "business", "economy"]) 10 | def test_quantities_cls_for_configuration(label): 11 | """Check cockpit configurations contain the correct quantities.""" 12 | economy = [ 13 | quantities.Alpha, 14 | quantities.GradNorm, 15 | quantities.UpdateSize, 16 | quantities.Distance, 17 |
quantities.InnerTest, 18 | quantities.OrthoTest, 19 | quantities.NormTest, 20 | quantities.GradHist1d, 21 | quantities.Loss, 22 | ] 23 | business = economy + [quantities.TICDiag, quantities.HessTrace] 24 | full = business + [quantities.HessMaxEV, quantities.GradHist2d] 25 | 26 | configs = { 27 | "full": set(full), 28 | "business": set(business), 29 | "economy": set(economy), 30 | } 31 | 32 | quants = set(quantities_cls_for_configuration(label)) 33 | 34 | assert quants == configs[label] 35 | -------------------------------------------------------------------------------- /tests/test_utils/test_schedules.py: -------------------------------------------------------------------------------- 1 | """Tests for ``cockpit.utils.schedules``.""" 2 | 3 | import pytest 4 | 5 | from cockpit.utils import schedules 6 | 7 | MAX_STEP = 100 8 | 9 | 10 | @pytest.mark.parametrize("interval", [1, 2, 3]) 11 | @pytest.mark.parametrize("offset", [0, 1, 2, -1]) 12 | def test_linear_schedule(interval, offset): 13 | """Check linear schedule. 14 | 15 | Args: 16 | interval (int): The regular tracking interval. 17 | offset (int): Offset of tracking. 18 | """ 19 | schedule = schedules.linear(interval, offset) 20 | tracking = [] 21 | for i in range(MAX_STEP): 22 | tracking.append(schedule(i)) 23 | 24 | # If offset is negative, start from the first tracked step 25 | if offset < 0: 26 | offset = interval + offset 27 | 28 | # Check that all steps that should be tracked are true 29 | assert all(tracking[offset::interval]) 30 | # Check that everything else is false 31 | assert sum(tracking[offset::interval]) == sum(tracking) 32 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions for ``Cockpit`` tests.""" 2 | -------------------------------------------------------------------------------- /tests/utils/check.py: -------------------------------------------------------------------------------- 1 | """Utility functions to compare results of two implementations.""" 2 | 3 | import numpy 4 | 5 | 6 | def compare_outputs(output1, output2, rtol=1e-5, atol=1e-7): 7 | """Compare outputs of two quantities.""" 8 | assert len(list(output1.keys())) == len( 9 | list(output2.keys()) 10 | ), "Different number of entries" 11 | 12 | for key in output1.keys(): 13 | if isinstance(output1[key], dict): 14 | compare_outputs(output1[key], output2[key], rtol=rtol, atol=atol) 15 | else: 16 | val1, val2 = output1[key], output2[key] 17 | 18 | compare_fn = get_compare_function(val1, val2) 19 | 20 | compare_fn(val1, val2, atol=atol, rtol=rtol) 21 | 22 | 23 | def get_compare_function(value1, value2): 24 | """Return the function used to compare ``value1`` with ``value2``.""" 25 | if isinstance(value1, float) and isinstance(value2, float): 26 | compare_fn = compare_floats 27 | elif isinstance(value1, int) and isinstance(value2, int): 28 | compare_fn = compare_ints 29 | elif isinstance(value1, numpy.ndarray) and isinstance(value2, numpy.ndarray): 30 | compare_fn = compare_arrays 31 | elif isinstance(value1, list) and isinstance(value2, list): 32 | compare_fn = compare_lists 33 | elif isinstance(value1, tuple) and isinstance(value2, tuple): 34 | compare_fn = compare_tuples 35 | else: 36 | raise NotImplementedError( 37 | "No comparison available for these data types: " 38 | + f"{type(value1)}, {type(value2)}."
39 | ) 40 | 41 | return compare_fn 42 | 43 | 44 | def compare_tuples(tuple1, tuple2, rtol=1e-5, atol=1e-7): 45 | """Compare two tuples.""" 46 | assert len(tuple1) == len(tuple2), "Different number of entries" 47 | 48 | for value1, value2 in zip(tuple1, tuple2): 49 | compare_fn = get_compare_function(value1, value2) 50 | compare_fn(value1, value2, rtol=rtol, atol=atol) 51 | 52 | 53 | def compare_arrays(array1, array2, rtol=1e-5, atol=1e-7): 54 | """Compare two ``numpy`` arrays.""" 55 | assert numpy.allclose(array1, array2, rtol=rtol, atol=atol) 56 | 57 | 58 | def compare_floats(float1, float2, rtol=1e-5, atol=1e-7): 59 | """Compare two floats.""" 60 | assert numpy.isclose(float1, float2, atol=atol, rtol=rtol), f"{float1} ≠ {float2}" 61 | 62 | 63 | def compare_ints(int1, int2, rtol=None, atol=None): 64 | """Compare two integers. 65 | 66 | ``rtol`` and ``atol`` are ignored in the comparison, but required to keep the 67 | interface identical among comparison functions. 68 | 69 | Args: 70 | int1 (int): First integer. 71 | int2 (int): Another integer. 72 | rtol (any): Ignored, see above. 73 | atol (any): Ignored, see above. 74 | """ 75 | assert int1 == int2 76 | 77 | 78 | def compare_lists(list1, list2, rtol=1e-5, atol=1e-7): 79 | """Compare two lists entry by entry.""" 80 | assert len(list1) == len( 81 | list2 82 | ), f"Lists don't match in size: {len(list1)} ≠ {len(list2)}" 83 | 84 | for val1, val2 in zip(list1, list2): 85 | compare_fn = get_compare_function(val1, val2) 86 | compare_fn(val1, val2, rtol=rtol, atol=atol) 87 | -------------------------------------------------------------------------------- /tests/utils/data.py: -------------------------------------------------------------------------------- 1 | """Utility functions to create toy input for Cockpit's tests.""" 2 | 3 | import torch 4 | from torch.utils.data.dataloader import DataLoader 5 | from torch.utils.data.dataset import Dataset 6 | 7 | 8 | def load_toy_data(batch_size): 9 | """Build a ``DataLoader`` with specified batch size from the toy data.""" 10 | return DataLoader(ToyData(), batch_size=batch_size) 11 | 12 | 13 | class ToyData(Dataset): 14 | """Toy data set used for testing. Consists of small random "images" and labels.""" 15 | 16 | def __init__(self, center=0.1): 17 | """Init the toy data set. 18 | 19 | Args: 20 | center (float): Center around which the data is randomly distributed. 21 | """ 22 | super(ToyData, self).__init__() 23 | self._center = center 24 | 25 | def __getitem__(self, index): 26 | """Return item with index `index` of data set. 27 | 28 | Args: 29 | index (int): Index of sample to access. Ignored for now. 30 | 31 | Returns: 32 | [tuple]: Tuple of (random) input and (random) label. 33 | """ 34 | item_input = torch.rand(1, 5, 5) + self._center 35 | item_label = torch.randint(size=(), low=0, high=3) 36 | return (item_input, item_label) 37 | 38 | def __len__(self): 39 | """Length of dataset. Arbitrarily set to 10 000.""" 40 | return 10000  # arbitrary size; samples are generated randomly on access 41 | -------------------------------------------------------------------------------- /tests/utils/harness.py: -------------------------------------------------------------------------------- 1 | """Base class for executing and hooking into a training loop to execute checks.""" 2 | 3 | 4 | from backpack import extend 5 | 6 | from cockpit import Cockpit 7 | from tests.utils.rand import restore_rng_state 8 | 9 | 10 | class SimpleTestHarness: 11 | """Class for running a simple test loop with the Cockpit.
12 | 13 | Args: 14 | problem (tests.utils.problem.Problem): The (instantiated) problem to test on. 15 | """ 16 | 17 | def __init__(self, problem): 18 | """Store the instantiated problem.""" 19 | self.problem = problem 20 | 21 | def test(self, cockpit_kwargs, *backpack_exts): 22 | """Run the test loop. 23 | 24 | Args: 25 | cockpit_kwargs (dict): Arguments for the cockpit. 26 | *backpack_exts (list): List of user-defined BackPACK extensions. 27 | """ 28 | problem = self.problem 29 | 30 | data = problem.data 31 | device = problem.device 32 | iterations = problem.iterations 33 | 34 | # Extend 35 | model = extend(problem.model) 36 | loss_fn = extend(problem.loss_function) 37 | individual_loss_fn = extend(problem.individual_loss_function) 38 | 39 | # Create Optimizer 40 | optimizer = problem.optimizer 41 | 42 | # Initialize Cockpit 43 | self.cockpit = Cockpit(model.parameters(), **cockpit_kwargs) 44 | 45 | 46 | 47 | # Main training loop 48 | global_step = 0 49 | for inputs, labels in iter(data): 50 | inputs, labels = inputs.to(device), labels.to(device) 51 | optimizer.zero_grad() 52 | 53 | # forward pass 54 | outputs = model(inputs) 55 | loss = loss_fn(outputs, labels) 56 | losses = individual_loss_fn(outputs, labels) 57 | 58 | # code inside this block does not alter random number generation 59 | with restore_rng_state(): 60 | # backward pass 61 | with self.cockpit( 62 | global_step, 63 | *backpack_exts, 64 | info={ 65 | "batch_size": inputs.shape[0], 66 | "individual_losses": losses, 67 | "loss": loss, 68 | "optimizer": optimizer, 69 | }, 70 | ): 71 | loss.backward(create_graph=self.cockpit.create_graph(global_step)) 72 | self.check_in_context() 73 | 74 | self.check_after_context() 75 | 76 | # optimizer step 77 | optimizer.step() 78 | global_step += 1 79 | 80 | if global_step >= iterations: 81 | break 82 | 83 | def check_in_context(self): 84 | """Hook that is executed within the cockpit context.""" 85 | pass 86 | 87 | def check_after_context(self): 88 | """Hook that is executed directly after the cockpit context.""" 89 | pass 90 | -------------------------------------------------------------------------------- /tests/utils/models.py: -------------------------------------------------------------------------------- 1 | """Utility functions to create toy models for Cockpit's tests.""" 2 | 3 | import torch 4 | 5 | 6 | def load_toy_model(): 7 | """Build a toy model that can be used in conjunction with ``ToyData``.""" 8 | return torch.nn.Sequential( 9 | torch.nn.Flatten(), 10 | torch.nn.Linear(25, 4), 11 | torch.nn.ReLU(), 12 | torch.nn.Linear(4, 4), 13 | torch.nn.ReLU(), 14 | torch.nn.Linear(4, 3), 15 | ) 16 | -------------------------------------------------------------------------------- /tests/utils/rand.py: -------------------------------------------------------------------------------- 1 | """Utility functions for random number generation.""" 2 | 3 | import torch 4 | 5 | 6 | class restore_rng_state: 7 | """Restores the PyTorch random number generator state to its value at initialization. 8 | 9 | This has the effect that code inside this context does not influence the outer 10 | loop's random generator state.
11 | """ 12 | 13 | def __init__(self): 14 | """Store the current PyTorch seed.""" 15 | self._rng_state = torch.get_rng_state() 16 | 17 | def __enter__(self): 18 | """Do nothing.""" 19 | pass 20 | 21 | def __exit__(self, exc_type, exc_value, traceback): 22 | """Restore the random generator state at initialization.""" 23 | torch.set_rng_state(self._rng_state) 24 | --------------------------------------------------------------------------------