├── .conda_env.yml
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── CI.yml
│       ├── Lint.yml
│       └── python-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── AUTHORS.md
├── CHANGELOG.md
├── LICENSE.txt
├── README.md
├── black.toml
├── cockpit
│   ├── __init__.py
│   ├── cockpit.py
│   ├── context.py
│   ├── instruments
│   │   ├── __init__.py
│   │   ├── alpha_gauge.py
│   │   ├── cabs_gauge.py
│   │   ├── distance_gauge.py
│   │   ├── early_stopping_gauge.py
│   │   ├── grad_norm_gauge.py
│   │   ├── gradient_tests_gauge.py
│   │   ├── histogram_1d_gauge.py
│   │   ├── histogram_2d_gauge.py
│   │   ├── hyperparameter_gauge.py
│   │   ├── max_ev_gauge.py
│   │   ├── mean_gsnr_gauge.py
│   │   ├── performance_gauge.py
│   │   ├── tic_gauge.py
│   │   ├── trace_gauge.py
│   │   ├── utils_instruments.py
│   │   └── utils_plotting.py
│   ├── plotter.py
│   ├── quantities
│   │   ├── __init__.py
│   │   ├── alpha.py
│   │   ├── bin_adaptation.py
│   │   ├── cabs.py
│   │   ├── distance.py
│   │   ├── early_stopping.py
│   │   ├── grad_hist.py
│   │   ├── grad_norm.py
│   │   ├── hess_max_ev.py
│   │   ├── hess_trace.py
│   │   ├── hooks
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── cleanup.py
│   │   ├── inner_test.py
│   │   ├── loss.py
│   │   ├── mean_gsnr.py
│   │   ├── norm_test.py
│   │   ├── ortho_test.py
│   │   ├── parameters.py
│   │   ├── quantity.py
│   │   ├── tic.py
│   │   ├── time.py
│   │   ├── update_size.py
│   │   ├── utils_hists.py
│   │   ├── utils_quantities.py
│   │   └── utils_transforms.py
│   └── utils
│       ├── __init__.py
│       ├── configuration.py
│       ├── optim.py
│       └── schedules.py
├── docs
│   ├── Makefile
│   ├── requirements_doc.txt
│   └── source
│       ├── _static
│       │   ├── 01_basic_fmnist.png
│       │   ├── 02_advanced_fmnist.png
│       │   ├── 03_deepobs.png
│       │   ├── Banner.svg
│       │   ├── LogoSquare.png
│       │   ├── favicon.ico
│       │   ├── instrument_preview_run.json
│       │   ├── instrument_previews
│       │   │   ├── Alpha.png
│       │   │   ├── CABS.png
│       │   │   ├── Distances.png
│       │   │   ├── EarlyStopping.png
│       │   │   ├── GradientNorm.png
│       │   │   ├── GradientTests.png
│       │   │   ├── HessMaxEV.png
│       │   │   ├── HessTrace.png
│       │   │   ├── Hist1d.png
│       │   │   ├── Hist2d.png
│       │   │   ├── Hyperparameters.png
│       │   │   ├── MeanGSNR.png
│       │   │   ├── Performance.png
│       │   │   └── TIC.png
│       │   ├── showcase.gif
│       │   └── stylefile.css
│       ├── api
│       │   ├── automod
│       │   │   ├── cockpit.instruments.alpha_gauge.rst
│       │   │   ├── cockpit.instruments.cabs_gauge.rst
│       │   │   ├── cockpit.instruments.distance_gauge.rst
│       │   │   ├── cockpit.instruments.early_stopping_gauge.rst
│       │   │   ├── cockpit.instruments.grad_norm_gauge.rst
│       │   │   ├── cockpit.instruments.gradient_tests_gauge.rst
│       │   │   ├── cockpit.instruments.histogram_1d_gauge.rst
│       │   │   ├── cockpit.instruments.histogram_2d_gauge.rst
│       │   │   ├── cockpit.instruments.hyperparameter_gauge.rst
│       │   │   ├── cockpit.instruments.max_ev_gauge.rst
│       │   │   ├── cockpit.instruments.mean_gsnr_gauge.rst
│       │   │   ├── cockpit.instruments.performance_gauge.rst
│       │   │   ├── cockpit.instruments.tic_gauge.rst
│       │   │   ├── cockpit.instruments.trace_gauge.rst
│       │   │   ├── cockpit.quantities.Alpha.rst
│       │   │   ├── cockpit.quantities.CABS.rst
│       │   │   ├── cockpit.quantities.Distance.rst
│       │   │   ├── cockpit.quantities.EarlyStopping.rst
│       │   │   ├── cockpit.quantities.GradHist1d.rst
│       │   │   ├── cockpit.quantities.GradHist2d.rst
│       │   │   ├── cockpit.quantities.GradNorm.rst
│       │   │   ├── cockpit.quantities.HessMaxEV.rst
│       │   │   ├── cockpit.quantities.HessTrace.rst
│       │   │   ├── cockpit.quantities.InnerTest.rst
│       │   │   ├── cockpit.quantities.Loss.rst
│       │   │   ├── cockpit.quantities.MeanGSNR.rst
│       │   │   ├── cockpit.quantities.NormTest.rst
│       │   │   ├── cockpit.quantities.OrthoTest.rst
│       │   │   ├── cockpit.quantities.Parameters.rst
│       │   │   ├── cockpit.quantities.TICDiag.rst
│       │   │   ├── cockpit.quantities.TICTrace.rst
│       │   │   ├── cockpit.quantities.Time.rst
│       │   │   ├── cockpit.quantities.UpdateSize.rst
│       │   │   ├── cockpit.utils.configuration.configuration.rst
│       │   │   ├── cockpit.utils.configuration.quantities_cls_for_configuration.rst
│       │   │   ├── cockpit.utils.schedules.linear.rst
│       │   │   └── cockpit.utils.schedules.logarithmic.rst
│       │   ├── cockpit.rst
│       │   ├── instruments.rst
│       │   ├── plotter.rst
│       │   ├── quantities.rst
│       │   └── utils.rst
│       ├── conf.py
│       ├── examples
│       │   ├── 01_basic_fmnist.rst
│       │   ├── 02_advanced_fmnist.rst
│       │   └── 03_deepobs.rst
│       ├── extract_instrument_previews.py
│       ├── index.rst
│       ├── introduction
│       │   └── good_to_know.rst
│       └── other
│           ├── changelog.rst
│           ├── contributors.rst
│           └── license.rst
├── examples
│   ├── 01_basic_fmnist.py
│   ├── 02_advanced_fmnist.py
│   ├── 03_deepobs.py
│   ├── _utils_deepobs.py
│   └── _utils_examples.py
├── makefile
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── settings.py
    ├── test_bugs
    │   ├── test_issue5.py
    │   └── test_issue6.py
    ├── test_cockpit
    │   ├── __init__.py
    │   ├── settings.py
    │   ├── test_automatic_call_track.py
    │   ├── test_backpack_extensions.py
    │   └── test_multiple_batch_grad_transforms.py
    ├── test_examples
    │   ├── __init__.py
    │   └── test_examples.py
    ├── test_quantities
    │   ├── __init__.py
    │   ├── adam_settings.py
    │   ├── settings.py
    │   ├── test_alpha.py
    │   ├── test_bin_adaptation.py
    │   ├── test_cabs.py
    │   ├── test_early_stopping.py
    │   ├── test_grad_hist.py
    │   ├── test_hess_max_ev.py
    │   ├── test_hess_trace.py
    │   ├── test_inner_test.py
    │   ├── test_mean_gsnr.py
    │   ├── test_norm_test.py
    │   ├── test_ortho_test.py
    │   ├── test_quantity_integration.py
    │   ├── test_tic.py
    │   └── utils.py
    ├── test_utils
    │   ├── __init__.py
    │   ├── test_configurations.py
    │   └── test_schedules.py
    └── utils
        ├── __init__.py
        ├── check.py
        ├── data.py
        ├── harness.py
        ├── models.py
        ├── problem.py
        └── rand.py
/.conda_env.yml:
--------------------------------------------------------------------------------
1 | name: cockpit
2 | dependencies:
3 | - pip
4 | - python
5 | - pytorch::pytorch
6 | - pytorch::torchvision
7 | - conda-forge::black
8 | - conda-forge::darglint
9 | - conda-forge::flake8
10 | - conda-forge::flake8-bugbear
11 | - conda-forge::flake8-comprehensions
12 | - conda-forge::isort
13 | - conda-forge::palettable
14 | - conda-forge::pydocstyle
15 | - conda-forge::pytest
16 | - conda-forge::pytest-cov
17 | - conda-forge::pytest-benchmark
18 | - conda-forge::m2r2
19 | - conda-forge::sphinx
20 | - conda-forge::sphinx-automodapi
21 | - conda-forge::sphinx_rtd_theme
22 | - conda-forge::matplotlib=3.4.3
23 | - conda-forge::tikzplotlib
24 | - pip:
25 | - memory-profiler
26 | - pre-commit
27 | - graphviz
28 | - sphinx-notfound-page
29 | - git+https://git@github.com/fsschneider/DeepOBS.git@develop#egg=deepobs
30 | - git+https://git@github.com/f-dangel/backobs.git@development#egg=backobs
31 | - -e .
32 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us fix a bug
4 | title: ''
5 | labels: "\U0001F195 Status: New, \U0001F41B Type: Bug"
6 | assignees: fsschneider
7 |
8 | ---
9 |
10 | [Provide a general description of the issue]
11 |
12 | ## Description
13 |
14 | [Provide more details on the bug itself]
15 |
16 | ## Steps to Reproduce
17 |
18 | [An ordered list of steps to recreate the issue]
19 |
20 | ## Source or Possible Fix
21 |
22 | [If possible, you can describe what you think is the source of the bug, or possibly even describe a fix]
23 |
24 | ---
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: "\U0001F195 Status: New"
6 | assignees: fsschneider
7 |
8 | ---
9 |
10 | **Are you requesting a new feature or an enhancement of an existing one? Please use the labels accordingly.**
11 |
12 | ## Description
13 |
14 | [Describe the feature you would like to be added]
15 |
16 | ---
17 |
--------------------------------------------------------------------------------
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - '*'
7 | pull_request:
8 | branches:
9 | - development
10 | - master
11 |
12 |
13 | jobs:
14 | pytest:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v1
18 | - name: Set up Python 3.7
19 | uses: actions/setup-python@v1
20 | with:
21 | python-version: 3.7
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install pytest
26 | pip install pytest-cov
27 | pip install requests
28 | pip install coveralls
29 | pip install .
30 | pip install git+https://git@github.com/fsschneider/DeepOBS.git@develop#egg=deepobs
31 | pip install git+https://git@github.com/f-dangel/backobs.git@development#egg=backobs
32 | - name: Run pytest
33 | run: |
34 | make test
35 | - name: Test coveralls
36 | run: coveralls --service=github
37 | env:
38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/Lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | push:
5 | branches:
6 | - '*'
7 | pull_request:
8 | branches:
9 | - development
10 | - master
11 |
12 |
13 | jobs:
14 | black:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v2
18 | - name: Set up Python 3.7
19 | uses: actions/setup-python@v2
20 | with:
21 | python-version: 3.7
22 | - name: Install black
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install black
26 | - name: Run black
27 | run: |
28 | make black-check
29 |
30 | flake8:
31 | runs-on: ubuntu-latest
32 | steps:
33 | - uses: actions/checkout@v2
34 | - name: Set up Python 3.7
35 | uses: actions/setup-python@v2
36 | with:
37 | python-version: 3.7
38 | - name: Install flake8
39 | run: |
40 | python -m pip install --upgrade pip
41 | pip install flake8
42 | pip install flake8-bugbear
43 | pip install flake8-comprehensions
44 | - name: Run flake8
45 | run: |
46 | make flake8
47 |
48 | pydocstyle:
49 | runs-on: ubuntu-latest
50 | steps:
51 | - uses: actions/checkout@v2
52 | - name: Set up Python 3.7
53 | uses: actions/setup-python@v2
54 | with:
55 | python-version: 3.7
56 | - name: Install pydocstyle
57 | run: |
58 | python -m pip install --upgrade pip
59 | pip install pydocstyle
60 | - name: Run pydocstyle
61 | run: |
62 | make pydocstyle-check
63 |
64 | isort:
65 | runs-on: ubuntu-latest
66 | steps:
67 | - uses: actions/checkout@v2
68 | - name: Set up Python 3.7
69 | uses: actions/setup-python@v2
70 | with:
71 | python-version: 3.7
72 | - name: Install isort
73 | run: |
74 | python -m pip install --upgrade pip
75 | pip install isort
76 | - name: Run isort
77 | run: |
78 | make isort-check
79 |
80 | darglint:
81 | runs-on: ubuntu-latest
82 | steps:
83 | - uses: actions/checkout@v2
84 | - name: Set up Python 3.7
85 | uses: actions/setup-python@v2
86 | with:
87 | python-version: 3.7
88 | - name: Install dependencies
89 | run: |
90 | python -m pip install --upgrade pip
91 | pip install darglint
92 | - name: Run darglint
93 | run: |
94 | make darglint-check
95 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v2
16 | - name: Set up Python
17 | uses: actions/setup-python@v2
18 | with:
19 | python-version: "3.x"
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install setuptools wheel twine
24 | - name: Build and publish
25 | env:
26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
28 | run: |
29 | python setup.py sdist bdist_wheel
30 | twine upload dist/*
31 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | potato.yml
3 |
4 | ## Core latex/pdflatex auxiliary files:
5 | *.aux
6 | *.lof
7 | *.log
8 | *.lot
9 | *.fls
10 | *.out
11 | *.toc
12 | *.fmt
13 | *.fot
14 | *.cb
15 | *.cb2
16 | .*.lb
17 |
18 | ## Intermediate documents:
19 | *.dvi
20 | *.xdv
21 | *-converted-to.*
22 | # these rules might exclude image files for figures etc.
23 | # *.ps
24 | # *.eps
25 | # *.pdf
26 |
27 | ## Generated if empty string is given at "Please type another file name for output:"
28 | .pdf
29 |
30 | ## Bibliography auxiliary files (bibtex/biblatex/biber):
31 | *.bbl
32 | *.bcf
33 | *.blg
34 | *-blx.aux
35 | *-blx.bib
36 | *.run.xml
37 |
38 | ## Build tool auxiliary files:
39 | *.fdb_latexmk
40 | *.synctex
41 | *.synctex(busy)
42 | *.synctex.gz
43 | *.synctex.gz(busy)
44 | *.pdfsync
45 |
46 | ## Build tool directories for auxiliary files
47 | # latexrun
48 | latex.out/
49 |
50 | ## Auxiliary and intermediate files from other packages:
51 | # algorithms
52 | *.alg
53 | *.loa
54 |
55 | # achemso
56 | acs-*.bib
57 |
58 | # amsthm
59 | *.thm
60 |
61 | # beamer
62 | *.nav
63 | *.pre
64 | *.snm
65 | *.vrb
66 |
67 | # changes
68 | *.soc
69 |
70 | # comment
71 | *.cut
72 |
73 | # cprotect
74 | *.cpt
75 |
76 | # elsarticle (documentclass of Elsevier journals)
77 | *.spl
78 |
79 | # endnotes
80 | *.ent
81 |
82 | # fixme
83 | *.lox
84 |
85 | # feynmf/feynmp
86 | *.mf
87 | *.mp
88 | *.t[1-9]
89 | *.t[1-9][0-9]
90 | *.tfm
91 |
92 | #(r)(e)ledmac/(r)(e)ledpar
93 | *.end
94 | *.?end
95 | *.[1-9]
96 | *.[1-9][0-9]
97 | *.[1-9][0-9][0-9]
98 | *.[1-9]R
99 | *.[1-9][0-9]R
100 | *.[1-9][0-9][0-9]R
101 | *.eledsec[1-9]
102 | *.eledsec[1-9]R
103 | *.eledsec[1-9][0-9]
104 | *.eledsec[1-9][0-9]R
105 | *.eledsec[1-9][0-9][0-9]
106 | *.eledsec[1-9][0-9][0-9]R
107 |
108 | # glossaries
109 | *.acn
110 | *.acr
111 | *.glg
112 | *.glo
113 | *.gls
114 | *.glsdefs
115 | *.lzo
116 | *.lzs
117 |
118 | # uncomment this for glossaries-extra (will ignore makeindex's style files!)
119 | # *.ist
120 |
121 | # gnuplottex
122 | *-gnuplottex-*
123 |
124 | # gregoriotex
125 | *.gaux
126 | *.gtex
127 |
128 | # htlatex
129 | *.4ct
130 | *.4tc
131 | *.idv
132 | *.lg
133 | *.trc
134 | *.xref
135 |
136 | # hyperref
137 | *.brf
138 |
139 | # knitr
140 | *-concordance.tex
141 | # TODO Comment the next line if you want to keep your tikz graphics files
142 | *.tikz
143 | *-tikzDictionary
144 |
145 | # listings
146 | *.lol
147 |
148 | # luatexja-ruby
149 | *.ltjruby
150 |
151 | # makeidx
152 | *.idx
153 | *.ilg
154 | *.ind
155 |
156 | # minitoc
157 | *.maf
158 | *.mlf
159 | *.mlt
160 | *.mtc[0-9]*
161 | *.slf[0-9]*
162 | *.slt[0-9]*
163 | *.stc[0-9]*
164 |
165 | # minted
166 | _minted*
167 | *.pyg
168 |
169 | # morewrites
170 | *.mw
171 |
172 | # nomencl
173 | *.nlg
174 | *.nlo
175 | *.nls
176 |
177 | # pax
178 | *.pax
179 |
180 | # pdfpcnotes
181 | *.pdfpc
182 |
183 | # sagetex
184 | *.sagetex.sage
185 | *.sagetex.py
186 | *.sagetex.scmd
187 |
188 | # scrwfile
189 | *.wrt
190 |
191 | # sympy
192 | *.sout
193 | *.sympy
194 | sympy-plots-for-*.tex/
195 |
196 | # pdfcomment
197 | *.upa
198 | *.upb
199 |
200 | # pythontex
201 | *.pytxcode
202 | pythontex-files-*/
203 |
204 | # tcolorbox
205 | *.listing
206 |
207 | # thmtools
208 | *.loe
209 |
210 | # TikZ & PGF
211 | *.dpth
212 | *.md5
213 | *.auxlock
214 |
215 | # todonotes
216 | *.tdo
217 |
218 | # vhistory
219 | *.hst
220 | *.ver
221 |
222 | # easy-todo
223 | *.lod
224 |
225 | # xcolor
226 | *.xcp
227 |
228 | # xmpincl
229 | *.xmpi
230 |
231 | # xindy
232 | *.xdy
233 |
234 | # xypic precompiled matrices and outlines
235 | *.xyc
236 | *.xyd
237 |
238 | # endfloat
239 | *.ttt
240 | *.fff
241 |
242 | # Latexian
243 | TSWLatexianTemp*
244 |
245 | ## Editors:
246 | # WinEdt
247 | *.bak
248 | *.sav
249 |
250 | # Texpad
251 | .texpadtmp
252 |
253 | # LyX
254 | *.lyx~
255 |
256 | # Kile
257 | *.backup
258 |
259 | # KBibTeX
260 | *~[0-9]*
261 |
262 | # auto folder when using emacs and auctex
263 | ./auto/*
264 | *.el
265 |
266 | # expex forward references with \gathertags
267 | *-tags.tex
268 |
269 | # standalone packages
270 | *.sta
271 |
272 | # Makeindex log files
273 | *.lpz
274 |
275 |
276 | # Python data files
277 | *-ubyte
278 | *.pt
279 | **/__pycache__
280 |
281 |
282 | # Python IDE files
283 | .idea
284 |
285 | # General
286 | .DS_Store
287 | .AppleDouble
288 | .LSOverride
289 |
290 | ########################
291 |
292 | **/data
293 |
294 | # pip src
295 | **/src
296 | *.egg-info
297 |
298 | # DeepOBS
299 | *.json
300 | !docs/source/_static/instrument_preview_run.json
301 | **/results
302 | **/data_deepobs
303 |
304 | # Torchvision
305 | **MNIST/**/*.gz
306 |
307 | # Tensorboard
308 | **/*events.out.tfevents.*
309 |
310 | # Test
311 | .coverage
312 | **/logfiles
313 |
314 | # Doc
315 | docs/build/
316 |
317 | # local emacs variables
318 | !**/.dir-locals.el
319 |
320 | examples/logfiles/
321 |
322 | .envrc
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: stable
4 | hooks:
5 | - id: black
6 | args: [--config=black.toml]
7 | - repo: https://gitlab.com/pycqa/flake8
8 | rev: "3.7.9"
9 | hooks:
10 | - id: flake8
11 | additional_dependencies:
12 | [
13 | mccabe,
14 | pycodestyle,
15 | pyflakes,
16 | pep8-naming,
17 | flake8-bugbear,
18 | flake8-comprehensions,
19 | ]
20 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # RTD configuration file version
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | sphinx:
10 | builder: html
11 |
12 | # Optionally build your docs in additional formats such as PDF
13 | formats: []
14 |
15 | # Optionally set the version of Python and requirements required to build your docs
16 | python:
17 | version: 3.7
18 | install:
19 | - method: pip
20 | path: .
21 | - requirements: docs/requirements_doc.txt
--------------------------------------------------------------------------------
/AUTHORS.md:
--------------------------------------------------------------------------------
1 | ## Cockpit Contributors
2 |
3 | Many people have helped with the development of **Cockpit**.
4 |
5 | **Maintainers:**
6 |
7 |
25 |
26 |
27 | ---
28 |
29 | Additional support was offered by Philipp Hennig, Jonathan Wenger, and the entire [ProbNum](https://github.com/probabilistic-numerics/probnum) team, as well as the [Methods of Machine Learning](https://uni-tuebingen.de/en/faculties/faculty-of-science/departments/computer-science/lehrstuehle/methods-of-machine-learning/) group.
30 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6 |
7 | ## [Unreleased]
8 |
9 | ## [1.0.2] - 2021-10-26
10 |
11 | ### Added
12 |
13 | - Added references to a separate [experiment repository](https://github.com/fsschneider/cockpit-experiments) that publishes the code for all experiments shown in the paper.
14 |
15 | ### Fixed
16 |
17 | - Protects the `batch_grad` field in the case where non-SGD is used together with other quantities that free `batch_grad` for memory performance. [[#5](https://github.com/f-dangel/cockpit/issues/5), [PR](https://github.com/f-dangel/cockpit/pull/18)]
18 |
19 | ## [1.0.1] - 2021-10-13
20 |
21 | From this version on, `cockpit` will be available as `cockpit-for-pytorch` on
22 | PyPI.
23 |
24 | ### Added
25 | - Make library `pip`-installable as `cockpit-for-pytorch`
26 | [[PR](https://github.com/f-dangel/cockpit/pull/17)]
27 | - Require BackPACK main release
28 | [[PR](https://github.com/f-dangel/cockpit/pull/12)]
29 | - Added a `savename` argument to the `CockpitPlotter.plot()` function, which lets you define the name, and now the `savedir` should really only describe the **directory**. [[PR](https://github.com/f-dangel/cockpit/pull/16), Fixes #8]
30 | - Added optional `savefig_kwargs` argument to the `CockpitPlotter.plot()` function that gets passed to the `matplotlib` function `fig.savefig()` to, e.g., specify DPI value or use a different file format (e.g. PDF). [[PR](https://github.com/f-dangel/cockpit/pull/16), Fixes #10]
31 |
32 | ### Internal
33 | - Fix [#6](https://github.com/f-dangel/cockpit/issues/6): Don't execute
34 | extension hook on modules with children
35 | [[PR](https://github.com/f-dangel/cockpit/pull/7)]
36 |
37 | ## [1.0.0] - 2021-04-30
38 |
39 | ### Added
40 |
41 | - First public release version of **Cockpit**.
42 |
43 | [Unreleased]: https://github.com/f-dangel/cockpit/compare/v1.0.2...HEAD
44 | [1.0.2]: https://github.com/f-dangel/cockpit/compare/1.0.1...1.0.2
45 | [1.0.1]: https://github.com/f-dangel/cockpit/compare/1.0.0...1.0.1
46 | [1.0.0]: https://github.com/f-dangel/cockpit/releases/tag/1.0.0
47 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Frank Schneider, Felix Dangel & Philipp Hennig
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | A Practical Debugging Tool for Training Deep Neural Networks
6 |
7 |
8 | A better status screen for deep learning.
9 |
10 |
11 |
12 |
13 | Installation •
14 | Docs •
15 | Experiments •
16 | License •
17 | Citation
18 |
19 |
20 | [](https://github.com/f-dangel/cockpit/actions/workflows/CI.yml)
21 | [](https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml)
22 | [](https://cockpit.readthedocs.io)
23 | [](https://coveralls.io/github/f-dangel/cockpit?branch=main)
24 | [](https://github.com/f-dangel/cockpit/blob/master/LICENSE)
25 | [](https://github.com/psf/black)
26 | [](https://arxiv.org/abs/2102.06604)
27 |
28 | ---
29 |
30 | ```bash
31 | pip install cockpit-for-pytorch
32 | ```
33 |
34 | ---
35 |
36 | **Cockpit is a visual and statistical debugger specifically designed for deep learning.** Training a deep neural network is often a pain! Successfully training such a network usually requires either years of intuition or expensive parameter searches involving lots of trial and error. Traditional debuggers provide only limited help: They can find *syntactical errors* but not *training bugs* such as ill-chosen learning rates. **Cockpit** offers a closer, more meaningful look into the training process with multiple well-chosen *instruments*.
37 |
38 | ---
39 |
40 | 
41 |
42 |
43 | ## Installation
44 |
45 | To install **Cockpit** simply run
46 |
47 | ```bash
48 | pip install cockpit-for-pytorch
49 | ```
50 |
51 |
52 | Conda environment
53 | For convenience, we also provide a conda environment, which can be created from the provided yml file via `conda env create -f .conda_env.yml`. It includes all the necessary requirements to build the docs, execute the tests, and run the examples.
54 |
55 |
56 |
57 | ## Documentation
58 |
59 | The [documentation](https://cockpit.readthedocs.io/) provides a full tutorial on how to get started using **Cockpit** as well as a detailed documentation of its API.
60 |
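As a quick impression of the API, here is a minimal sketch of how **Cockpit** hooks into a PyTorch training loop, condensed from the basic example in the docs. Treat details such as the `"economy"` configuration label and the exact `info` entries as illustrative; your setup may need different quantities or hyperparameters.

```python
import torch
from backpack import extend

from cockpit import Cockpit, CockpitPlotter
from cockpit.utils.configuration import configuration

# Model and loss must be extended by BackPACK so Cockpit can compute its quantities.
model = extend(torch.nn.Linear(10, 2))
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

opt = torch.optim.SGD(model.parameters(), lr=0.01)
cockpit = Cockpit(model.parameters(), quantities=configuration("economy"))

inputs, labels = torch.rand(8, 10), torch.randint(0, 2, (8,))
outputs = model(inputs)
loss = loss_fn(outputs, labels)

# Wrap the backward pass so Cockpit can track its quantities at this step.
with cockpit(
    0,  # global step
    info={
        "batch_size": 8,
        "individual_losses": individual_loss_fn(outputs, labels),
        "loss": loss,
        "optimizer": opt,
    },
):
    loss.backward(create_graph=cockpit.create_graph(0))
opt.step()

CockpitPlotter().plot(cockpit)  # render all instruments for the tracked data
```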
61 |
62 | ## Experiments
63 |
64 | To showcase the capabilities of **Cockpit** we performed several experiments illustrating the usefulness of our debugging tool. The code for the experiments can be found in a [separate repository](https://github.com/fsschneider/cockpit-experiments). For a discussion of those experiments please refer to our [paper](https://arxiv.org/abs/2102.06604).
65 |
66 |
67 | ## License
68 |
69 | Distributed under the MIT License. See [`LICENSE`](LICENSE.txt) for more information.
70 |
71 |
72 | ## Citation
73 |
74 | If you use **Cockpit**, please consider citing:
75 |
76 | > [Frank Schneider, Felix Dangel, Philipp Hennig
77 | > **Cockpit: A Practical Debugging Tool for Training Deep Neural Networks**
78 | > *arXiv 2102.06604*](http://arxiv.org/abs/2102.06604)
79 |
80 | ```bibtex
81 | @misc{schneider2021cockpit,
82 | title={{Cockpit: A Practical Debugging Tool for Training Deep Neural Networks}},
83 | author={Frank Schneider and Felix Dangel and Philipp Hennig},
84 | year={2021},
85 | eprint={2102.06604},
86 | archivePrefix={arXiv},
87 | primaryClass={cs.LG}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
/black.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 88
3 | target-version = ['py35', 'py36', 'py37']
4 | include = '\.pyi?$'
5 | exclude = '''
6 | (
7 | /(
8 | \.eggs
9 | | \.git
10 | | \.pytest_cache
11 | | \.benchmarks
12 | | docs_src
13 | | docs
14 | | build
15 | | dist
16 | )/
17 | )
18 | '''
19 |
--------------------------------------------------------------------------------
/cockpit/__init__.py:
--------------------------------------------------------------------------------
1 | """Init for Cockpit."""
2 | from pkg_resources import DistributionNotFound, get_distribution
3 |
4 | from cockpit.cockpit import Cockpit
5 | from cockpit.plotter import CockpitPlotter
6 |
7 | # Extract the version number, for accessing it via __version__
8 | try:
9 | # Change here if project is renamed and does not equal the package name
10 | dist_name = __name__
11 | __version__ = get_distribution(dist_name).version
12 | except DistributionNotFound:
13 | __version__ = "unknown"
14 | finally:
15 | del get_distribution, DistributionNotFound
16 |
17 |
18 | __all__ = ["Cockpit", "CockpitPlotter", "__version__", "cockpit.quantities"]
19 |
--------------------------------------------------------------------------------
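The `pkg_resources` lookup above predates `importlib.metadata`. Purely for orientation, and not part of the repo, a sketch of the equivalent on Python 3.8+ would be (assuming the PyPI distribution name `cockpit-for-pytorch` from the changelog):

```python
from importlib.metadata import PackageNotFoundError, version

try:
    # The distribution on PyPI is named "cockpit-for-pytorch" (see CHANGELOG.md).
    __version__ = version("cockpit-for-pytorch")
except PackageNotFoundError:
    # Package is not installed, e.g. when running from a plain source checkout.
    __version__ = "unknown"
```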
/cockpit/instruments/__init__.py:
--------------------------------------------------------------------------------
1 | """All Instruments for the Cockpit."""
2 |
3 | from cockpit.instruments.alpha_gauge import alpha_gauge
4 | from cockpit.instruments.cabs_gauge import cabs_gauge
5 | from cockpit.instruments.distance_gauge import distance_gauge
6 | from cockpit.instruments.early_stopping_gauge import early_stopping_gauge
7 | from cockpit.instruments.grad_norm_gauge import grad_norm_gauge
8 | from cockpit.instruments.gradient_tests_gauge import gradient_tests_gauge
9 | from cockpit.instruments.histogram_1d_gauge import histogram_1d_gauge
10 | from cockpit.instruments.histogram_2d_gauge import histogram_2d_gauge
11 | from cockpit.instruments.hyperparameter_gauge import hyperparameter_gauge
12 | from cockpit.instruments.max_ev_gauge import max_ev_gauge
13 | from cockpit.instruments.mean_gsnr_gauge import mean_gsnr_gauge
14 | from cockpit.instruments.performance_gauge import performance_gauge
15 | from cockpit.instruments.tic_gauge import tic_gauge
16 | from cockpit.instruments.trace_gauge import trace_gauge
17 |
18 | __all__ = [
19 | "alpha_gauge",
20 | "distance_gauge",
21 | "grad_norm_gauge",
22 | "histogram_1d_gauge",
23 | "histogram_2d_gauge",
24 | "gradient_tests_gauge",
25 | "hyperparameter_gauge",
26 | "max_ev_gauge",
27 | "performance_gauge",
28 | "tic_gauge",
29 | "trace_gauge",
30 | "mean_gsnr_gauge",
31 | "early_stopping_gauge",
32 | "cabs_gauge",
33 | ]
34 |
--------------------------------------------------------------------------------
/cockpit/instruments/cabs_gauge.py:
--------------------------------------------------------------------------------
1 | """CABS Gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
6 |
7 |
8 | def cabs_gauge(self, fig, gridspec):
9 | """CABS gauge, showing the CABS rule versus iteration.
10 |
11 | The batch size trades off more accurate gradient approximations against longer
12 | computation. The `CABS criterion `_ describes
13 | the optimal batch size under certain assumptions.
14 |
15 | The instrument shows the suggested batch size (and an exponentially weighted
16 | average) over the course of training, according to
17 |
18 | - `Balles, L., Romero, J., & Hennig, P.,
19 | Coupling adaptive batch sizes with learning rates (2017).
20 | `_
21 |
22 | **Preview**
23 |
24 | .. image:: ../../_static/instrument_previews/CABS.png
25 | :alt: Preview CABS Gauge
26 |
27 | **Requires**
28 |
29 | This instrument requires data from the :class:`~cockpit.quantities.CABS`
30 | quantity class.
31 |
32 | Args:
33 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
34 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
36 | placed
37 | """
38 | # Plot Trace vs iteration
39 | title = "CABS"
40 |
41 | # Check if the required data is available, else skip this instrument
42 | requires = ["CABS"]
43 | plot_possible = check_data(self.tracking_data, requires)
44 | if not plot_possible:
45 | if self.debug:
46 | warnings.warn(
47 | "Couldn't get the required data for the " + title + " instrument",
48 | stacklevel=1,
49 | )
50 | return
51 |
52 | plot_args = {
53 | "x": "iteration",
54 | "y": "CABS",
55 | "data": self.tracking_data,
56 | "x_scale": "symlog" if self.show_log_iter else "linear",
57 | "y_scale": "linear",
58 | "cmap": self.cmap,
59 | "EMA": "y",
60 | "EMA_alpha": self.EMA_alpha,
61 | "EMA_cmap": self.cmap2,
62 | "title": title,
63 | "xlim": "tight",
64 | "ylim": None,
65 | "fontweight": "bold",
66 | "facecolor": self.bg_color_instruments,
67 | }
68 | ax = fig.add_subplot(gridspec)
69 | create_basic_plot(**plot_args, ax=ax)
70 |
--------------------------------------------------------------------------------
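Every gauge in this package follows the same skeleton as `cabs_gauge` above: check that the required column exists in the tracking data, bail out with a warning in debug mode, then hand a `plot_args` dict to `create_basic_plot`. For illustration only, a hypothetical custom gauge reusing that skeleton could look like this (the quantity name `MyQuantity` is made up):

```python
"""Sketch of a custom gauge following the pattern above (not part of the repo)."""

import warnings

from cockpit.instruments.utils_instruments import check_data, create_basic_plot


def my_gauge(self, fig, gridspec):
    """Show a hypothetical tracked quantity versus iteration."""
    title = "My Quantity"

    # Skip this instrument if the required column is missing from the tracking data.
    if not check_data(self.tracking_data, ["MyQuantity"]):
        if self.debug:
            warnings.warn(
                "Couldn't get the required data for the " + title + " instrument",
                stacklevel=1,
            )
        return

    plot_args = {
        "x": "iteration",
        "y": "MyQuantity",
        "data": self.tracking_data,
        "x_scale": "symlog" if self.show_log_iter else "linear",
        "y_scale": "linear",
        "cmap": self.cmap,
        "EMA": "y",  # overlay an exponentially weighted average of y
        "EMA_alpha": self.EMA_alpha,
        "EMA_cmap": self.cmap2,
        "title": title,
        "xlim": "tight",
        "ylim": None,
        "fontweight": "bold",
        "facecolor": self.bg_color_instruments,
    }
    ax = fig.add_subplot(gridspec)
    create_basic_plot(**plot_args, ax=ax)
```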
/cockpit/instruments/distance_gauge.py:
--------------------------------------------------------------------------------
1 | """Distance Gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
6 | from cockpit.quantities.utils_quantities import _root_sum_of_squares
7 |
8 |
9 | def distance_gauge(self, fig, gridspec):
10 | """Distance gauge showing two different quantities related to distance.
11 |
12 | This instrument shows two quantities at once. Firstly, the :math:`L_2`-distance
13 | of the current parameters to their initialization. This describes the total distance
14 | that the optimization trajectory "has traveled so far" and can be seen via the
15 | blue-to-green dots (and the left y-axis).
16 |
17 | Secondly, the update sizes of individual steps are shown via the yellow-to-blue
18 | dots (and the right y-axis). It measures the distance that a single parameter
19 | update covers.
20 |
21 | Both quantities are overlayed with an exponentially weighted average.
22 |
23 | .. image:: ../../_static/instrument_previews/Distances.png
24 | :alt: Preview Distances Gauge
25 |
26 | **Requires**
27 |
28 | The distance instrument requires data from both, the
29 | :class:`~cockpit.quantities.UpdateSize` and the
30 | :class:`~cockpit.quantities.Distance` quantity class.
31 |
32 | Args:
33 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
34 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
36 | placed
37 | """
38 | # Plot Trace vs iteration
39 | title = "Distance"
40 |
41 | # Check if the required data is available, else skip this instrument
42 | requires = ["Distance", "UpdateSize"]
43 | plot_possible = check_data(self.tracking_data, requires)
44 | if not plot_possible:
45 | if self.debug:
46 | warnings.warn(
47 | "Couldn't get the required data for the " + title + " instrument",
48 | stacklevel=1,
49 | )
50 | return
51 |
52 | # Compute
53 | self.tracking_data["Distance_all"] = self.tracking_data.Distance.map(
54 | lambda x: _root_sum_of_squares(x) if type(x) == list else x
55 | )
56 | self.tracking_data["UpdateSize_all"] = self.tracking_data.UpdateSize.map(
57 | lambda x: _root_sum_of_squares(x) if type(x) == list else x
58 | )
59 |
60 | plot_args = {
61 | "x": "iteration",
62 | "y": "Distance_all",
63 | "data": self.tracking_data,
64 | "y_scale": "linear",
65 | "x_scale": "symlog" if self.show_log_iter else "linear",
66 | "cmap": self.cmap,
67 | "EMA": "y",
68 | "EMA_alpha": self.EMA_alpha,
69 | "EMA_cmap": self.cmap2,
70 | "title": title,
71 | "xlim": "tight",
72 | "ylim": None,
73 | "fontweight": "bold",
74 | "facecolor": self.bg_color_instruments,
75 | }
76 | ax = fig.add_subplot(gridspec)
77 | create_basic_plot(**plot_args, ax=ax)
78 |
79 | ax2 = ax.twinx()
80 | plot_args = {
81 | "x": "iteration",
82 | "y": "UpdateSize_all",
83 | "data": self.tracking_data,
84 | "y_scale": "linear",
85 | "x_scale": "symlog" if self.show_log_iter else "linear",
86 | "cmap": self.cmap.reversed(),
87 | "EMA": "y",
88 | "EMA_alpha": self.EMA_alpha,
89 | "EMA_cmap": self.cmap2.reversed(),
90 | "xlim": "tight",
91 | "ylim": None,
92 | "marker": ",",
93 | }
94 | create_basic_plot(**plot_args, ax=ax2)
95 |
--------------------------------------------------------------------------------
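Gauges like the one above overlay an "exponentially weighted average" by passing `EMA="y"` and `EMA_alpha` to `create_basic_plot`. The smoothing itself happens inside `create_basic_plot` (not shown in this dump); whether it matches the pandas convention exactly is an assumption here, but the standard recurrence `s_t = (1 - alpha) * s_{t-1} + alpha * x_t` can be sketched as:

```python
import pandas as pd

tracking_data = pd.DataFrame(
    {"iteration": range(5), "Distance_all": [0.0, 1.0, 1.5, 1.8, 2.0]}
)

# Exponential moving average: s_t = (1 - alpha) * s_{t-1} + alpha * x_t
# (pandas' ewm with adjust=False implements exactly this recurrence).
tracking_data["Distance_all_EMA"] = (
    tracking_data["Distance_all"].ewm(alpha=0.2, adjust=False).mean()
)
print(tracking_data)
```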
/cockpit/instruments/early_stopping_gauge.py:
--------------------------------------------------------------------------------
1 | """Early Stopping Gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
6 |
7 |
8 | def early_stopping_gauge(self, fig, gridspec):
9 | """Early Stopping gauge, showing the LHS of the stopping criterion versus iteration.
10 |
11 | Early stopping of training has been widely used to prevent poor generalization
12 | due to over-fitting. `Mahsereci et al. (2017) `_
13 | proposed an evidence-based stopping criterion based on mini-batch statistics.
14 | This instrument visualizes this criterion versus iteration, overlayed
15 | with an exponentially weighted average. If the stopping criterion becomes
16 | positive, this suggests stopping the training according to
17 |
18 | - `Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
19 | Early stopping without a validation set (2017).
20 | `_
21 |
22 | **Preview**
23 |
24 | .. image:: ../../_static/instrument_previews/EarlyStopping.png
25 | :alt: Preview EarlyStopping Gauge
26 |
27 | **Requires**
28 |
29 | This instrument requires data from the :class:`~cockpit.quantities.EarlyStopping`
30 | quantity class.
31 |
32 | Args:
33 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
34 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
36 | placed
37 | """
38 | # Plot Trace vs iteration
39 | title = "Early stopping"
40 |
41 | # Check if the required data is available, else skip this instrument
42 | requires = ["EarlyStopping"]
43 | plot_possible = check_data(self.tracking_data, requires)
44 | if not plot_possible:
45 | if self.debug:
46 | warnings.warn(
47 | "Couldn't get the required data for the " + title + " instrument",
48 | stacklevel=1,
49 | )
50 | return
51 |
52 | plot_args = {
53 | "x": "iteration",
54 | "y": "EarlyStopping",
55 | "data": self.tracking_data,
56 | "x_scale": "symlog" if self.show_log_iter else "linear",
57 | "y_scale": "linear",
58 | "cmap": self.cmap,
59 | "EMA": "y",
60 | "EMA_alpha": self.EMA_alpha,
61 | "EMA_cmap": self.cmap2,
62 | "title": title,
63 | "xlim": "tight",
64 | "ylim": None,
65 | "fontweight": "bold",
66 | "facecolor": self.bg_color_instruments,
67 | }
68 | ax = fig.add_subplot(gridspec)
69 | create_basic_plot(**plot_args, ax=ax)
70 |
--------------------------------------------------------------------------------
/cockpit/instruments/grad_norm_gauge.py:
--------------------------------------------------------------------------------
1 | """Gradient Norm Gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
6 | from cockpit.quantities.utils_quantities import _root_sum_of_squares
7 |
8 |
9 | def grad_norm_gauge(self, fig, gridspec):
10 | """Showing the gradient norm versus iteration.
11 |
12 | If the training gets stuck due to a small
13 | :class:`~cockpit.quantities.UpdateSize`, this can be the result of either a badly
14 | chosen learning rate or a flat plateau in the loss landscape.
15 | This instrument shows the gradient norm at each iteration, overlayed with an
16 | exponentially weighted average, and can thus distinguish these two cases.
17 |
18 | **Preview**
19 |
20 | .. image:: ../../_static/instrument_previews/GradientNorm.png
21 | :alt: Preview GradientNorm Gauge
22 |
23 | **Requires**
24 |
25 | The gradient norm instrument requires data from the
26 | :class:`~cockpit.quantities.GradNorm` quantity class.
27 |
28 | Args:
29 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
30 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
31 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
32 | placed
33 | """
34 | # Plot Trace vs iteration
35 | title = "Gradient Norm"
36 |
37 | # Check if the required data is available, else skip this instrument
38 | requires = ["GradNorm"]
39 | plot_possible = check_data(self.tracking_data, requires)
40 | if not plot_possible:
41 | if self.debug:
42 | warnings.warn(
43 | "Couldn't get the required data for the " + title + " instrument",
44 | stacklevel=1,
45 | )
46 | return
47 |
48 | # Compute
49 | self.tracking_data["GradNorm_all"] = self.tracking_data.GradNorm.map(
50 | lambda x: _root_sum_of_squares(x) if type(x) == list else x
51 | )
52 |
53 | plot_args = {
54 | "x": "iteration",
55 | "y": "GradNorm_all",
56 | "data": self.tracking_data,
57 | "x_scale": "symlog" if self.show_log_iter else "linear",
58 | "y_scale": "linear",
59 | "cmap": self.cmap,
60 | "EMA": "y",
61 | "EMA_alpha": self.EMA_alpha,
62 | "EMA_cmap": self.cmap2,
63 | "title": title,
64 | "xlim": "tight",
65 | "ylim": None,
66 | "fontweight": "bold",
67 | "facecolor": self.bg_color_instruments,
68 | }
69 | ax = fig.add_subplot(gridspec)
70 | create_basic_plot(**plot_args, ax=ax)
71 |
--------------------------------------------------------------------------------
/cockpit/instruments/histogram_1d_gauge.py:
--------------------------------------------------------------------------------
1 | """One-dimensional Histogram Gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import _beautify_plot, check_data
6 |
7 |
8 | def histogram_1d_gauge(self, fig, gridspec, y_scale="log"):
9 | """One-dimensional histogram of the individual gradient elements.
10 |
11 | This instrument provides a histogram of the gradient element values across all
12 | individual gradients in a mini-batch. The histogram shows the distribution for
13 | the last tracked iteration only.
14 |
15 | **Preview**
16 |
17 | .. image:: ../../_static/instrument_previews/Hist1d.png
18 | :alt: Preview Hist1d Gauge
19 |
20 | **Requires**
21 |
22 | This one-dimensional histogram instrument requires data from the
23 | :class:`~cockpit.quantities.GradHist1d` quantity class.
24 |
25 | Args:
26 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
27 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
28 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
29 | placed
30 | y_scale (str, optional): Scale of the y-axis. Defaults to "log".
31 | """
32 | # Plot
33 | title = "Gradient Element Histogram"
34 |
35 | # Check if the required data is available, else skip this instrument
36 | requires = ["GradHist1d"]
37 | plot_possible = check_data(self.tracking_data, requires, min_elements=1)
38 | if not plot_possible:
39 | if self.debug:
40 | warnings.warn(
41 | "Couldn't get the required data for the " + title + " instrument",
42 | stacklevel=1,
43 | )
44 | return
45 |
46 | ax = fig.add_subplot(gridspec)
47 |
48 | plot_args = {
49 | "title": title,
50 | "fontweight": "bold",
51 | "facecolor": self.bg_color_instruments,
52 | "xlabel": "Gradient Element Value",
53 | "ylabel": "Frequency",
54 | "y_scale": y_scale,
55 | }
56 |
57 | vals, mid_points, width = _get_histogram_data(self.tracking_data)
58 |
59 | ax.bar(mid_points, vals, width=width, color=self.primary_color)
60 |
61 | _beautify_plot(ax=ax, **plot_args)
62 |
63 | ax.set_title(title, fontweight="bold", fontsize="large")
64 |
65 |
66 | def _get_histogram_data(tracking_data):
67 | """Returns the histogram data for the plot.
68 |
69 | Currently we return the bin values, mid points, and width of the last
70 | iteration tracked before this plot.
71 |
72 | Args:
73 | tracking_data (pandas.DataFrame): DataFrame holding the tracking data.
74 |
75 | Returns:
76 | list: Bins of the histogram.
77 | list: Mid points of the bins.
78 | list: Width of the bins.
79 | """
80 | clean_data = tracking_data.GradHist1d.dropna()
81 | last_step_data = clean_data[clean_data.index[-1]]
82 |
83 | vals = last_step_data["hist"]
84 | bins = last_step_data["edges"]
85 |
86 | width = bins[1] - bins[0]
87 |
88 | mid_points = (bins[1:] + bins[:-1]) / 2
89 |
90 | return vals, mid_points, width
91 |
--------------------------------------------------------------------------------
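`_get_histogram_data` above implicitly documents the layout of a tracked `GradHist1d` entry: a dict with `"hist"` (bin counts) and `"edges"` (bin boundaries, one more entry than counts), where `"edges"` must support elementwise arithmetic. A small sketch of that layout with numpy and illustrative values:

```python
import numpy as np

# n bins are described by n counts and n + 1 edges.
last_step_data = {
    "hist": np.array([2, 30, 5]),
    "edges": np.array([-1.5, -0.5, 0.5, 1.5]),
}

bins = last_step_data["edges"]
width = bins[1] - bins[0]  # assumes equally spaced bins
mid_points = (bins[1:] + bins[:-1]) / 2  # -> array([-1., 0., 1.])
```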
/cockpit/instruments/hyperparameter_gauge.py:
--------------------------------------------------------------------------------
1 | """Hyperparameter Gauge."""
2 |
3 | import warnings
4 |
5 | import seaborn as sns
6 |
7 | from cockpit.instruments.utils_instruments import (
8 | _add_last_value_to_legend,
9 | _beautify_plot,
10 | check_data,
11 | )
12 |
13 |
14 | def hyperparameter_gauge(self, fig, gridspec):
15 | """Hyperparameter gauge, currently showing the learning rate over time.
16 |
17 | This instrument visualizes the hyperparameter values over the course of the
18 | training. Currently, it shows the learning rate, the most likely parameter to
19 | be adapted during training. The current learning rate is additionally shown
20 | in the figure's legend.
21 |
22 | **Preview**
23 |
24 | .. image:: ../../_static/instrument_previews/Hyperparameters.png
25 | :alt: Preview Hyperparameter Gauge
26 |
27 | **Requires**
28 |
29 | This instrument requires the learning rate data passed via the
30 | :func:`cockpit.Cockpit.log()` method.
31 |
32 | Args:
33 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
34 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
36 | placed
37 | """
38 | # Plot Trace vs iteration
39 | title = "Hyperparameters"
40 |
41 | # Check if the required data is available, else skip this instrument
42 | requires = ["iteration", "learning_rate"]
43 | plot_possible = check_data(self.tracking_data, requires)
44 | if not plot_possible:
45 | if self.debug:
46 | warnings.warn(
47 | "Couldn't get the required data for the " + title + " instrument",
48 | stacklevel=1,
49 | )
50 | return
51 |
52 | ax = fig.add_subplot(gridspec)
53 |
54 | clean_learning_rate = self.tracking_data[["iteration", "learning_rate"]].dropna()
55 |
56 | # Plot Settings
57 | plot_args = {
58 | "x": "iteration",
59 | "y": "learning_rate",
60 | "data": clean_learning_rate,
61 | }
62 | ylabel = plot_args["y"].replace("_", " ").title()
63 | sns.lineplot(
64 | **plot_args, ax=ax, label=ylabel, linewidth=2, color=self.secondary_color
65 | )
66 |
67 | _beautify_plot(
68 | ax=ax,
69 | xlabel=plot_args["x"],
70 | ylabel=ylabel,
71 | x_scale="symlog" if self.show_log_iter else "linear",
72 | title=title,
73 | xlim="tight",
74 | fontweight="bold",
75 | facecolor=self.bg_color_instruments2,
76 | )
77 |
78 | ax.legend()
79 | _add_last_value_to_legend(ax)
80 |
--------------------------------------------------------------------------------
/cockpit/instruments/max_ev_gauge.py:
--------------------------------------------------------------------------------
1 | """Max EV Gauge."""
2 |
3 | import warnings
4 |
5 | from matplotlib import ticker
6 |
7 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
8 |
9 |
10 | def max_ev_gauge(self, fig, gridspec):
11 | """Showing the largest eigenvalue of the Hessian versus iteration.
12 |
13 | The largest eigenvalue of the Hessian indicates the loss surface's sharpest
14 | valley. Together with the :func:`~cockpit.instruments.trace_gauge()`, which
15 | provides a notion of "average curvature", it can help understand the "average
16 | condition number" of the loss landscape at the current point. The instrument
17 | shows the largest eigenvalue of the Hessian versus iteration, overlayed with
18 | an exponentially weighted average.
19 |
20 | **Preview**
21 |
22 | .. image:: ../../_static/instrument_previews/HessMaxEV.png
23 | :alt: Preview HessMaxEV Gauge
24 |
25 | **Requires**
26 |
27 | This instrument requires data from the :class:`~cockpit.quantities.HessMaxEV`
28 | quantity class.
29 |
30 | Args:
31 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
32 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
33 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
34 | placed.
35 | """
36 | # Plot Trace vs iteration
37 | title = "Max Eigenvalue"
38 |
39 | # Check if the required data is available, else skip this instrument
40 | requires = ["HessMaxEV"]
41 | plot_possible = check_data(self.tracking_data, requires)
42 | if not plot_possible:
43 | if self.debug:
44 | warnings.warn(
45 | "Couldn't get the required data for the " + title + " instrument",
46 | stacklevel=1,
47 | )
48 | return
49 |
50 | plot_args = {
51 | "x": "iteration",
52 | "y": "HessMaxEV",
53 | "data": self.tracking_data,
54 | "x_scale": "symlog" if self.show_log_iter else "linear",
55 | "y_scale": "log",
56 | "cmap": self.cmap,
57 | "EMA": "y",
58 | "EMA_alpha": self.EMA_alpha,
59 | "EMA_cmap": self.cmap2,
60 | "title": title,
61 | "xlim": "tight",
62 | "ylim": None,
63 | "fontweight": "bold",
64 | "facecolor": self.bg_color_instruments,
65 | }
66 | # Plot into the assigned gridspec slot
67 | ax = fig.add_subplot(gridspec)
68 | create_basic_plot(**plot_args, ax=ax)
69 |
70 | ax.yaxis.set_minor_formatter(ticker.FormatStrFormatter("%.2g"))
71 |
--------------------------------------------------------------------------------
/cockpit/instruments/mean_gsnr_gauge.py:
--------------------------------------------------------------------------------
1 | """Mean GSNR gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
6 |
7 |
8 | def mean_gsnr_gauge(self, fig, gridspec):
9 | """Mean GSNR gauge, showing the mean GSNR versus iteration.
10 |
11 | The mean GSNR describes the average gradient signal-to-noise-ratio. `Recent
12 | work `_ used this quantity to study the
13 | generalization performances of neural networks, noting "that larger GSNR during
14 | training process leads to better generalization performance. The instrument
15 | shows the mean GSNR versus iteration, overlayed with an exponentially weighted
16 | average.
17 |
18 | **Preview**
19 |
20 | .. image:: ../../_static/instrument_previews/MeanGSNR.png
21 | :alt: Preview MeanGSNR Gauge
22 |
23 | **Requires**
24 |
25 | This instrument requires data from the :class:`~cockpit.quantities.MeanGSNR`
26 | quantity class.
27 |
28 | Args:
29 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
30 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
31 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
32 | placed
33 | """
34 | # Plot Trace vs iteration
35 | title = "Mean GSNR"
36 |
37 | # Check if the required data is available, else skip this instrument
38 | requires = ["MeanGSNR"]
39 | plot_possible = check_data(self.tracking_data, requires)
40 | if not plot_possible:
41 | if self.debug:
42 | warnings.warn(
43 | "Couldn't get the required data for the " + title + " instrument",
44 | stacklevel=1,
45 | )
46 | return
47 |
48 | plot_args = {
49 | "x": "iteration",
50 | "y": "MeanGSNR",
51 | "data": self.tracking_data,
52 | "x_scale": "symlog" if self.show_log_iter else "linear",
53 | "y_scale": "linear",
54 | "cmap": self.cmap,
55 | "EMA": "y",
56 | "EMA_alpha": self.EMA_alpha,
57 | "EMA_cmap": self.cmap2,
58 | "title": title,
59 | "xlim": "tight",
60 | "ylim": None,
61 | "fontweight": "bold",
62 | "facecolor": self.bg_color_instruments,
63 | }
64 | ax = fig.add_subplot(gridspec)
65 | create_basic_plot(**plot_args, ax=ax)
66 |
--------------------------------------------------------------------------------
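The docstring above does not spell out the GSNR itself. Assuming the standard definition from the cited line of work, the GSNR of a parameter is the squared mean of its per-sample gradient over the gradient's variance, and the plotted quantity averages this over all :math:`D` parameters:

```latex
\mathrm{GSNR}(\theta_j) = \frac{\bigl(\mathbb{E}[g_j]\bigr)^2}{\mathrm{Var}[g_j]},
\qquad
\text{mean GSNR} = \frac{1}{D} \sum_{j=1}^{D} \mathrm{GSNR}(\theta_j),
```

where :math:`g_j` denotes the per-sample gradient with respect to parameter :math:`\theta_j`.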
/cockpit/instruments/performance_gauge.py:
--------------------------------------------------------------------------------
1 | """Performance Gauge."""
2 |
3 | import warnings
4 |
5 | import seaborn as sns
6 |
7 | from cockpit.instruments.utils_instruments import (
8 | _add_last_value_to_legend,
9 | check_data,
10 | create_basic_plot,
11 | )
12 |
13 |
14 | def performance_gauge(self, fig, gridspec):
15 | """Plotting train/valid accuracy vs. epoch and mini-batch loss vs. iteration.
16 |
17 | This instrument visualizes the currently most popular diagnostic metrics. It
18 | shows the mini-batch loss in each iteration (overlayed with an exponentially
19 | weighted average) as well as accuracies for both the training and the
20 | validation set. The current accuracy numbers are also shown in the legend.
21 |
22 | **Preview**
23 |
24 | .. image:: ../../_static/instrument_previews/Performance.png
25 | :alt: Preview Performance Gauge
26 |
27 | **Requires**
28 |
29 | This instrument visualizes quantities passed via the
30 | :func:`cockpit.Cockpit.log()` method.
31 |
32 | Args:
33 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
34 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
35 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
36 | placed
37 | """
38 | # Plot Trace vs iteration
39 | title = "Performance Plot"
40 |
41 | # Check if the required data is available, else skip this instrument
42 | requires = ["iteration", "Loss"]
43 | plot_possible = check_data(self.tracking_data, requires)
44 | if not plot_possible:
45 | if self.debug:
46 | warnings.warn(
47 | "Couldn't get the loss data for the " + title + " instrument",
48 | stacklevel=1,
49 | )
50 | return
51 |
52 | # Mini-batch train loss
53 | plot_args = {
54 | "x": "iteration",
55 | "y": "Loss",
56 | "data": self.tracking_data,
57 | "EMA": "y",
58 | "EMA_alpha": self.EMA_alpha,
59 | "EMA_cmap": self.cmap2,
60 | "x_scale": "symlog" if self.show_log_iter else "linear",
61 | "y_scale": "linear",
62 | "cmap": self.cmap,
63 | "title": title,
64 | "xlim": "tight",
65 | "ylim": None,
66 | "fontweight": "bold",
67 | "facecolor": self.bg_color_instruments2,
68 | }
69 | ax = fig.add_subplot(gridspec)
70 | create_basic_plot(**plot_args, ax=ax)
71 |
72 | requires = ["iteration", "train_accuracy", "valid_accuracy"]
73 | plot_possible = check_data(self.tracking_data, requires)
74 | if not plot_possible:
75 | if self.debug:
76 | warnings.warn(
77 | "Couldn't get the accuracy data for the " + title + " instrument",
78 | stacklevel=1,
79 | )
80 | return
81 | else:
82 | clean_accuracies = self.tracking_data[
83 | ["iteration", "train_accuracy", "valid_accuracy"]
84 | ].dropna()
85 |
86 | # Train Accuracy
87 | plot_args = {
88 | "x": "iteration",
89 | "y": "train_accuracy",
90 | "data": clean_accuracies,
91 | }
92 | ax2 = ax.twinx()
93 | sns.lineplot(
94 | **plot_args,
95 | ax=ax2,
96 | label=plot_args["y"].title().replace("_", " "),
97 | linewidth=2,
98 | color=self.primary_color,
99 | )
100 |
101 | # Validation Accuracy
102 | plot_args = {
103 | "x": "iteration",
104 | "y": "valid_accuracy",
105 | "data": clean_accuracies,
106 | }
107 | sns.lineplot(
108 | **plot_args,
109 | ax=ax2,
110 | label=plot_args["y"].title().replace("_", " "),
111 | linewidth=2,
112 | color=self.secondary_color,
113 | )
114 |
115 | # Customization
116 | ax2.set_ylim([0, 1])
117 | ax2.set_ylabel("Accuracy")
118 | _add_last_value_to_legend(ax2, percentage=True)
119 |
--------------------------------------------------------------------------------
/cockpit/instruments/tic_gauge.py:
--------------------------------------------------------------------------------
1 | """TIC Gauge."""
2 |
3 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
4 |
5 |
6 | def tic_gauge(self, fig, gridspec):
7 | """TIC gauge, showing the TIC versus iteration.
8 |
9 | The TIC (either approximated via traces or using a diagonal approximation)
10 | describes the relation between the curvature and the gradient noise. `Recent
11 | work `_ suggested that *at a local minimum*,
12 | this quantity can estimate the generalization gap. This instrument shows the
13 | TIC versus iteration, overlayed with an exponentially weighted average.
14 |
15 | **Preview**
16 |
17 | .. image:: ../../_static/instrument_previews/TIC.png
18 | :alt: Preview TIC Gauge
19 |
20 | **Requires**
21 |
22 | This instrument requires data from the :class:`~cockpit.quantities.TICDiag`
23 | or :class:`~cockpit.quantities.TICTrace` quantity class.
24 |
25 | Args:
26 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
27 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
28 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
29 | placed
30 | """
31 | # Plot Trace vs iteration
32 | title = "TIC"
33 |
34 | if check_data(self.tracking_data, ["TICDiag"]):
35 | plot_args = {
36 | "x": "iteration",
37 | "y": "TICDiag",
38 | "data": self.tracking_data,
39 | "x_scale": "symlog" if self.show_log_iter else "linear",
40 | "y_scale": "linear",
41 | "cmap": self.cmap,
42 | "EMA": "y",
43 | "EMA_alpha": self.EMA_alpha,
44 | "EMA_cmap": self.cmap2,
45 | "title": title,
46 | "xlim": "tight",
47 | "ylim": None,
48 | "fontweight": "bold",
49 | "facecolor": self.bg_color_instruments,
50 | }
51 | ax = fig.add_subplot(gridspec)
52 | create_basic_plot(**plot_args, ax=ax)
53 |
54 | if check_data(self.tracking_data, ["TICTrace"]):
55 | if "ax" in locals():
56 | ax2 = ax.twinx()
57 | else:
58 | ax2 = fig.add_subplot(gridspec)
59 | plot_args = {
60 | "x": "iteration",
61 | "y": "TICTrace",
62 | "data": self.tracking_data,
63 | "x_scale": "symlog" if self.show_log_iter else "linear",
64 | "y_scale": "linear",
65 | "cmap": self.cmap_backup,
66 | "EMA": "y",
67 | "EMA_alpha": self.EMA_alpha,
68 | "EMA_cmap": self.cmap2_backup,
69 | "title": title,
70 | "xlim": "tight",
71 | "ylim": None,
72 | "fontweight": "bold",
73 | "facecolor": self.bg_color_instruments,
74 | }
75 | create_basic_plot(**plot_args, ax=ax2)
76 |
--------------------------------------------------------------------------------
/cockpit/instruments/trace_gauge.py:
--------------------------------------------------------------------------------
1 | """Trace Gauge."""
2 |
3 | import warnings
4 |
5 | from cockpit.instruments.utils_instruments import check_data, create_basic_plot
6 |
7 |
8 | def trace_gauge(self, fig, gridspec):
9 | """Trace gauge, showing the trace of the Hessian versus iteration.
10 |
11 | The trace of the Hessian is the sum of its eigenvalues and thus can indicate
12 | the overall or average curvature of the loss landscape at the current point.
13 | Increasing values for the trace indicate a steeper curvature, for example, a
14 | narrower valley. This instrument shows the trace versus iteration, overlayed
15 | with an exponentially weighted average.
16 |
17 | **Preview**
18 |
19 | .. image:: ../../_static/instrument_previews/HessTrace.png
20 | :alt: Preview HessTrace Gauge
21 |
22 | **Requires**
23 |
24 | The trace instrument requires data from the :class:`~cockpit.quantities.HessTrace`
25 | quantity class.
26 |
27 | Args:
28 | self (CockpitPlotter): The cockpit plotter requesting this instrument.
29 | fig (matplotlib.figure.Figure): Figure of the Cockpit.
30 | gridspec (matplotlib.gridspec.GridSpec): GridSpec where the instrument should be
31 |             placed.
32 | """
33 | # Plot Trace vs iteration
34 | title = "Trace"
35 |
36 | # Check if the required data is available, else skip this instrument
37 | requires = ["HessTrace"]
38 | plot_possible = check_data(self.tracking_data, requires)
39 | if not plot_possible:
40 | if self.debug:
41 | warnings.warn(
42 | "Couldn't get the required data for the " + title + " instrument",
43 | stacklevel=1,
44 | )
45 | return
46 |
47 | # Compute
48 | self.tracking_data["HessTrace_all"] = self.tracking_data.HessTrace.map(
49 |         lambda x: sum(x) if isinstance(x, list) else x
50 | )
51 |
52 | plot_args = {
53 | "x": "iteration",
54 | "y": "HessTrace_all",
55 | "data": self.tracking_data,
56 | "x_scale": "symlog" if self.show_log_iter else "linear",
57 | "y_scale": "linear",
58 | "cmap": self.cmap,
59 | "EMA": "y",
60 | "EMA_alpha": self.EMA_alpha,
61 | "EMA_cmap": self.cmap2,
62 | "title": title,
63 | "xlim": "tight",
64 | "ylim": None,
65 | "fontweight": "bold",
66 | "facecolor": self.bg_color_instruments,
67 | }
68 | ax = fig.add_subplot(gridspec)
69 | create_basic_plot(**plot_args, ax=ax)
70 |
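A toy illustration of the ``HessTrace_all`` collapse above: layer-wise traces are summed into a scalar, while already-scalar entries pass through unchanged (example values are made up):

import pandas as pd

df = pd.DataFrame({"HessTrace": [[0.5, 1.25, 0.25], 2.5]})
df["HessTrace_all"] = df.HessTrace.map(
    lambda x: sum(x) if isinstance(x, list) else x
)
# -> 2.0 and 2.5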
--------------------------------------------------------------------------------
/cockpit/instruments/utils_plotting.py:
--------------------------------------------------------------------------------
1 | """Utility functions for the CockpitPlotter."""
2 |
3 | import numpy as np
4 | from matplotlib.colors import LinearSegmentedColormap
5 |
6 |
7 | def _extract_problem_info(source):
8 | """Split the logpath to identify test problem, data set, etc.
9 |
10 | Args:
11 | source (Cockpit or str): ``Cockpit`` instance, or string containing the
12 | path to a .json log produced with ``Cockpit.write``, where
13 | information will be fetched from.
14 |
15 | Returns:
16 |         dict: Dictionary of logpath, testproblem, optimizer, etc.
17 | """
18 | if isinstance(source, str):
19 | # Split logpath if possible
20 | try:
21 | dicty = {
22 | "logpath": source + ".json",
23 | "optimizer": source.split("/")[-3],
24 | "testproblem": source.split("/")[-4],
25 | "dataset": source.split("/")[-4].split("_", 1)[0],
26 | "model": source.split("/")[-4].split("_", 1)[1],
27 | }
28 | except Exception:
29 | dicty = {
30 | "logpath": source + ".json",
31 | "optimizer": "",
32 | "testproblem": "",
33 | "dataset": "",
34 | "model": "",
35 | }
36 | else:
37 | # Source is Cockpit instance
38 | dicty = {
39 | "logpath": "",
40 | "optimizer": source._optimizer_name,
41 | "testproblem": "",
42 | "dataset": "",
43 | "model": "",
44 | }
45 | return dicty
46 |
47 |
48 | def legend():
49 | """Creates the legend of the whole cockpit, combining the individual instruments."""
50 | pass
51 |
52 |
53 | def _alpha_cmap(color, ncolors=256):
54 |     """Create a colormap that fades from transparent to a given color.
55 |
56 | Args:
57 | color (tuple): A matplotlib-compatible color.
58 | ncolors (int, optional): Number of "steps" in the colormap.
59 | Defaults to 256.
60 |
61 | Returns:
62 |         matplotlib.colors.LinearSegmentedColormap: A matplotlib colormap.
63 | """
64 | color_array = np.array(ncolors * [list(color)])
65 |
66 | # change alpha values
67 | color_array[:, -1] = np.linspace(0.0, 1.0, ncolors)
68 | # create a colormap object
69 | cmap = LinearSegmentedColormap.from_list(name="alpha_cmap", colors=color_array)
70 |
71 | return cmap
72 |
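A minimal usage sketch, assuming an RGBA color tuple so that the linear ramp overwrites the alpha channel:

cmap = _alpha_cmap((1.0, 0.0, 0.0, 1.0))  # fades from transparent to red
r, g, b, a = cmap(0.5)
# r == 1.0, while a is roughly 0.5 halfway through the map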
--------------------------------------------------------------------------------
/cockpit/quantities/__init__.py:
--------------------------------------------------------------------------------
1 | """Quantities tracked during training."""
2 |
3 | from cockpit.quantities.alpha import Alpha
4 | from cockpit.quantities.cabs import CABS
5 | from cockpit.quantities.distance import Distance
6 | from cockpit.quantities.early_stopping import EarlyStopping
7 | from cockpit.quantities.grad_hist import GradHist1d, GradHist2d
8 | from cockpit.quantities.grad_norm import GradNorm
9 | from cockpit.quantities.hess_max_ev import HessMaxEV
10 | from cockpit.quantities.hess_trace import HessTrace
11 | from cockpit.quantities.inner_test import InnerTest
12 | from cockpit.quantities.loss import Loss
13 | from cockpit.quantities.mean_gsnr import MeanGSNR
14 | from cockpit.quantities.norm_test import NormTest
15 | from cockpit.quantities.ortho_test import OrthoTest
16 | from cockpit.quantities.parameters import Parameters
17 | from cockpit.quantities.tic import TICDiag, TICTrace
18 | from cockpit.quantities.time import Time
19 | from cockpit.quantities.update_size import UpdateSize
20 |
21 | __all__ = [
22 | "Loss",
23 | "Parameters",
24 | "Distance",
25 | "UpdateSize",
26 | "GradNorm",
27 | "Time",
28 | "Alpha",
29 | "CABS",
30 | "EarlyStopping",
31 | "GradHist1d",
32 | "GradHist2d",
33 | "NormTest",
34 | "InnerTest",
35 | "OrthoTest",
36 | "HessMaxEV",
37 | "HessTrace",
38 | "TICDiag",
39 | "TICTrace",
40 | "MeanGSNR",
41 | ]
42 |
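A hedged usage sketch, assuming the shared ``track_schedule``/``verbose`` constructor signature documented throughout this package:

from cockpit import quantities
from cockpit.utils import schedules

quants = [
    quantities.Loss(track_schedule=schedules.linear(interval=1)),
    quantities.GradNorm(track_schedule=schedules.linear(interval=10)),
]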
--------------------------------------------------------------------------------
/cockpit/quantities/cabs.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the CABS criterion for adaptive batch size."""
2 |
3 | from backpack.extensions import BatchGrad
4 |
5 | from cockpit.context import get_batch_size, get_optimizer
6 | from cockpit.quantities.quantity import SingleStepQuantity
7 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared
8 | from cockpit.utils.optim import ComputeStep
9 |
10 |
11 | class CABS(SingleStepQuantity):
12 |     """Quantity class for the batch size suggested by the CABS criterion.
13 |
14 | CABS uses the current learning rate and variance of the stochastic gradients
15 | to suggest an optimal batch size.
16 |
17 | Only applies to SGD without momentum.
18 |
19 | Note: Proposed in
20 |
21 | - Balles, L., Romero, J., & Hennig, P.,
22 | Coupling adaptive batch sizes with learning rates (2017).
23 | """
24 |
25 | def get_lr(self, optimizer):
26 | """Extract the learning rate.
27 |
28 | Args:
29 | optimizer (torch.optim.Optimizer): A PyTorch optimizer.
30 |
31 | Returns:
32 | float: Learning rate
33 |
34 | Raises:
35 | ValueError: If the learning rate varies over parameter groups.
36 | """
37 | lrs = {group["lr"] for group in optimizer.param_groups}
38 |
39 | if len(lrs) != 1:
40 | raise ValueError(f"Found non-unique learning rates {lrs}")
41 |
42 | return lrs.pop()
43 |
44 | def extensions(self, global_step):
45 | """Return list of BackPACK extensions required for the computation.
46 |
47 | Args:
48 | global_step (int): The current iteration number.
49 |
50 | Returns:
51 | list: (Potentially empty) list with required BackPACK quantities.
52 | """
53 | ext = []
54 |
55 | if self.should_compute(global_step):
56 | ext.append(BatchGrad())
57 |
58 | return ext
59 |
60 | def extension_hooks(self, global_step):
61 | """Return list of BackPACK extension hooks required for the computation.
62 |
63 | Args:
64 | global_step (int): The current iteration number.
65 |
66 | Returns:
67 | [callable]: List of required BackPACK extension hooks for the current
68 | iteration.
69 | """
70 | hooks = []
71 |
72 | if self.should_compute(global_step):
73 | hooks.append(BatchGradTransformsHook_SumGradSquared())
74 |
75 | return hooks
76 |
77 | def _compute(self, global_step, params, batch_loss):
78 | """Compute the CABS rule. Return suggested batch size.
79 |
80 |         Evaluates Equation (22) of
81 |
82 | - Balles, L., Romero, J., & Hennig, P.,
83 | Coupling adaptive batch sizes with learning rates (2017).
84 |
85 | Args:
86 | global_step (int): The current iteration number.
87 | params ([torch.Tensor]): List of torch.Tensors holding the network's
88 | parameters.
89 | batch_loss (torch.Tensor): Mini-batch loss from current step.
90 |
91 | Returns:
92 | float: Batch size suggested by CABS.
93 |
94 | Raises:
95 | ValueError: If the optimizer differs from SGD with default arguments.
96 | """
97 | optimizer = get_optimizer(global_step)
98 | if not ComputeStep.is_sgd_default_kwargs(optimizer):
99 | raise ValueError("This criterion only supports zero-momentum SGD.")
100 |
101 | B = get_batch_size(global_step)
102 | lr = self.get_lr(optimizer)
103 |
104 | grad_squared = self._fetch_grad(params, aggregate=True) ** 2
105 |         # compensate BackPACK's 1/B scaling
106 | sgs_compensated = (
107 | B ** 2
108 | * self._fetch_sum_grad_squared_via_batch_grad_transforms(
109 | params, aggregate=True
110 | )
111 | )
112 |
113 | return (
114 | lr * (sgs_compensated - B * grad_squared).sum() / (B * batch_loss)
115 | ).item()
116 |
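A hedged usage sketch under the stated zero-momentum SGD restriction; ``model`` and the surrounding training loop are assumed:

from torch.optim import SGD
from cockpit.quantities import CABS
from cockpit.utils import schedules

cabs = CABS(track_schedule=schedules.linear(interval=10))
# optimizer = SGD(model.parameters(), lr=0.1)  # default momentum etc.
# Passed to a Cockpit, the tracked values are the suggested batch sizes.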
--------------------------------------------------------------------------------
/cockpit/quantities/distance.py:
--------------------------------------------------------------------------------
1 | """Class for tracking distance from initialization."""
2 |
3 | from cockpit.quantities.quantity import TwoStepQuantity
4 |
5 |
6 | class Distance(TwoStepQuantity):
7 | """Distance Quantity class tracking distance of the parameters from their init."""
8 |
9 | CACHE_KEY = "params"
10 | """str: String under which the parameters are cached for computation.
11 | Default: ``'params'``.
12 | """
13 | INIT_GLOBAL_STEP = 0
14 | """int: Iteration number used as reference. Defaults to ``0``."""
15 |
16 | def extensions(self, global_step):
17 | """Return list of BackPACK extensions required for the computation.
18 |
19 | Args:
20 | global_step (int): The current iteration number.
21 |
22 | Returns:
23 | list: (Potentially empty) list with required BackPACK quantities.
24 | """
25 | return []
26 |
27 | def is_start(self, global_step):
28 | """Return whether current iteration is start point.
29 |
30 |         Only the initialization (iteration ``0``) is a start point.
31 |
32 | Args:
33 | global_step (int): The current iteration number.
34 |
35 | Returns:
36 | bool: Whether ``global_step`` is a start point.
37 | """
38 | return global_step == self.INIT_GLOBAL_STEP
39 |
40 | def is_end(self, global_step):
41 | """Return whether current iteration is end point.
42 |
43 | Args:
44 | global_step (int): The current iteration number.
45 |
46 | Returns:
47 | bool: Whether ``global_step`` is an end point.
48 | """
49 | return self._track_schedule(global_step)
50 |
51 | def _compute_start(self, global_step, params, batch_loss):
52 | """Perform computations at start point (store initial parameter values).
53 |
54 | Modifies ``self._cache``.
55 |
56 | Args:
57 | global_step (int): The current iteration number.
58 | params ([torch.Tensor]): List of torch.Tensors holding the network's
59 | parameters.
60 | batch_loss (torch.Tensor): Mini-batch loss from current step.
61 | """
62 | params_copy = [p.data.clone().detach() for p in params]
63 |
64 | def block_fn(step):
65 | """Block deletion of parameters for all non-negative iterations.
66 |
67 | Args:
68 | step (int): Iteration number.
69 |
70 | Returns:
71 |                 bool: Whether deletion is blocked in the specified iteration.
72 | """
73 | return step >= self.INIT_GLOBAL_STEP
74 |
75 | self.save_to_cache(global_step, self.CACHE_KEY, params_copy, block_fn)
76 |
77 | def _compute_end(self, global_step, params, batch_loss):
78 | """Compute and return the current distance from initialization.
79 |
80 | Args:
81 | global_step (int): The current iteration number.
82 | params ([torch.Tensor]): List of torch.Tensors holding the network's
83 | parameters.
84 | batch_loss (torch.Tensor): Mini-batch loss from current step.
85 |
86 | Returns:
87 | [float]: Layer-wise L2-distances to initialization.
88 | """
89 | params_init = self.load_from_cache(self.INIT_GLOBAL_STEP, self.CACHE_KEY)
90 |
91 | distance = [
92 | (p.data - p_init).norm(2).item() for p, p_init in zip(params, params_init)
93 | ]
94 |
95 | return distance
96 |
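A toy illustration of the layer-wise distance computed in ``_compute_end``:

import torch

params_init = [torch.zeros(3), torch.zeros(2, 2)]
params_now = [torch.ones(3), 2 * torch.ones(2, 2)]
distance = [
    (p - p0).norm(2).item() for p, p0 in zip(params_now, params_init)
]
# -> [sqrt(3) ≈ 1.73, 4.0]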
--------------------------------------------------------------------------------
/cockpit/quantities/early_stopping.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the EB criterion for early stopping."""
2 |
3 | from backpack.extensions import BatchGrad
4 |
5 | from cockpit.context import get_batch_size, get_optimizer
6 | from cockpit.quantities.quantity import SingleStepQuantity
7 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared
8 | from cockpit.utils.optim import ComputeStep
9 |
10 |
11 | class EarlyStopping(SingleStepQuantity):
12 | """Quantity class for the evidence-based early-stopping criterion.
13 |
14 | This criterion uses local statistics of the gradients to indicate when training
15 | should be stopped. If the criterion exceeds zero, training should be stopped.
16 |
17 | Note: Proposed in
18 |
19 | - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
20 | Early stopping without a validation set (2017).
21 | """
22 |
23 | def __init__(self, track_schedule, verbose=False, epsilon=1e-5):
24 | """Initialization sets the tracking schedule & creates the output dict.
25 |
26 | Args:
27 | track_schedule (callable): Function that maps the ``global_step``
28 | to a boolean, which determines if the quantity should be computed.
29 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
30 |             epsilon (float, optional): Stabilization constant. Defaults to ``1e-5``.
31 | """
32 | super().__init__(track_schedule, verbose=verbose)
33 |
34 | self._epsilon = epsilon
35 |
36 | def extensions(self, global_step):
37 | """Return list of BackPACK extensions required for the computation.
38 |
39 | Args:
40 | global_step (int): The current iteration number.
41 |
42 | Returns:
43 | list: (Potentially empty) list with required BackPACK quantities.
44 | """
45 | ext = []
46 |
47 | if self.should_compute(global_step):
48 | ext.append(BatchGrad())
49 |
50 | return ext
51 |
52 | def extension_hooks(self, global_step):
53 | """Return list of BackPACK extension hooks required for the computation.
54 |
55 | Args:
56 | global_step (int): The current iteration number.
57 |
58 | Returns:
59 | [callable]: List of required BackPACK extension hooks for the current
60 | iteration.
61 | """
62 | hooks = []
63 |
64 | if self.should_compute(global_step):
65 | hooks.append(BatchGradTransformsHook_SumGradSquared())
66 |
67 | return hooks
68 |
69 | def _compute(self, global_step, params, batch_loss):
70 | """Compute the EB early stopping criterion.
71 |
72 |         Evaluates the left-hand side of Equation (7) in
73 |
74 | - Mahsereci, M., Balles, L., Lassner, C., & Hennig, P.,
75 | Early stopping without a validation set (2017).
76 |
77 | If this value exceeds 0, training should be stopped.
78 |
79 | Args:
80 | global_step (int): The current iteration number.
81 | params ([torch.Tensor]): List of torch.Tensors holding the network's
82 | parameters.
83 | batch_loss (torch.Tensor): Mini-batch loss from current step.
84 |
85 | Returns:
86 | float: Result of the Early stopping criterion. Training should stop
87 | if it is larger than 0.
88 |
89 | Raises:
90 | ValueError: If the used optimizer differs from SGD with default parameters.
91 | """
92 | if not ComputeStep.is_sgd_default_kwargs(get_optimizer(global_step)):
93 | raise ValueError("This criterion only supports zero-momentum SGD.")
94 |
95 | B = get_batch_size(global_step)
96 |
97 | grad_squared = self._fetch_grad(params, aggregate=True) ** 2
98 |
99 | # compensate BackPACK's 1/B scaling
100 | sgs_compensated = (
101 | B ** 2
102 | * self._fetch_sum_grad_squared_via_batch_grad_transforms(
103 | params, aggregate=True
104 | )
105 | )
106 |
107 | diag_variance = (sgs_compensated - B * grad_squared) / (B - 1)
108 |
109 | snr = grad_squared / (diag_variance + self._epsilon)
110 |
111 | return 1 - B * snr.mean().item()
112 |
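A toy re-derivation of the criterion from raw (unscaled) individual gradients of a single parameter, to make the algebra above concrete:

import torch

B = 8
batch_grad = torch.randn(B, 4)            # one gradient per sample
grad_squared = batch_grad.mean(0) ** 2
diag_variance = batch_grad.var(0, unbiased=True)
eb = 1 - B * (grad_squared / (diag_variance + 1e-5)).mean()
# Training should stop once ``eb`` exceeds zero.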
--------------------------------------------------------------------------------
/cockpit/quantities/grad_norm.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the Gradient Norm."""
2 |
3 | from cockpit.quantities.quantity import ByproductQuantity
4 |
5 |
6 | class GradNorm(ByproductQuantity):
7 |     """Quantity class for tracking the norm of the mean gradient."""
8 |
9 | def _compute(self, global_step, params, batch_loss):
10 | """Evaluate the gradient norm at the current point.
11 |
12 | Args:
13 | global_step (int): The current iteration number.
14 | params ([torch.Tensor]): List of torch.Tensors holding the network's
15 | parameters.
16 | batch_loss (torch.Tensor): Mini-batch loss from current step.
17 |
18 | Returns:
19 |             [float]: Layer-wise ℓ₂ norms of the mean gradient.
20 | """
21 | return [p.grad.data.norm(2).item() for p in params]
22 |
--------------------------------------------------------------------------------
/cockpit/quantities/hess_trace.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the Trace of the Hessian or an approximation thereof."""
2 |
3 | from backpack import extensions
4 |
5 | from cockpit.quantities.quantity import SingleStepQuantity
6 |
7 |
8 | class HessTrace(SingleStepQuantity):
9 |     """Quantity class tracking the trace of the Hessian during training."""
10 |
11 | extensions_from_str = {
12 | "diag_h": extensions.DiagHessian,
13 | "diag_ggn_exact": extensions.DiagGGNExact,
14 | "diag_ggn_mc": extensions.DiagGGNMC,
15 | }
16 |
17 | def __init__(self, track_schedule, verbose=False, curvature="diag_h"):
18 | """Initialization sets the tracking schedule & creates the output dict.
19 |
20 | Note:
21 | The curvature options ``"diag_h"`` and ``"diag_ggn_exact"`` are more
22 | expensive than ``"diag_ggn_mc"``, but more precise. For a classification
23 | task with ``C`` classes, the former require that ``C`` times more
24 | information be backpropagated through the computation graph.
25 |
26 | Args:
27 | track_schedule (callable): Function that maps the ``global_step``
28 | to a boolean, which determines if the quantity should be computed.
29 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
30 | curvature (string): Which diagonal curvature approximation should be used.
31 | Options are ``"diag_h"``, ``"diag_ggn_exact"``, ``"diag_ggn_mc"``.
32 | """
33 | super().__init__(track_schedule, verbose=verbose)
34 |
35 | self._curvature = curvature
36 |
37 | def extensions(self, global_step):
38 | """Return list of BackPACK extensions required for the computation.
39 |
40 | Args:
41 | global_step (int): The current iteration number.
42 |
43 | Raises:
44 | KeyError: If curvature string has unknown associated extension.
45 |
46 | Returns:
47 | list: (Potentially empty) list with required BackPACK quantities.
48 | """
49 | ext = []
50 |
51 | if self.should_compute(global_step):
52 | try:
53 | ext.append(self.extensions_from_str[self._curvature]())
54 | except KeyError as e:
55 | available = list(self.extensions_from_str.keys())
56 | raise KeyError(f"Available: {available}") from e
57 |
58 | return ext
59 |
60 | def _compute(self, global_step, params, batch_loss):
61 | """Evaluate the trace of the Hessian at the current point.
62 |
63 | Args:
64 | global_step (int): The current iteration number.
65 | params ([torch.Tensor]): List of torch.Tensors holding the network's
66 | parameters.
67 | batch_loss (torch.Tensor): Mini-batch loss from current step.
68 |
69 | Returns:
70 | list: Trace of the Hessian at the current point.
71 | """
72 | return [
73 | diag_c.sum().item()
74 | for diag_c in self._fetch_diag_curvature(
75 | params, self._curvature, aggregate=False
76 | )
77 | ]
78 |
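A hedged usage sketch, tracking the exact Hessian diagonal every 10 steps:

from cockpit.quantities import HessTrace
from cockpit.utils import schedules

hess_trace = HessTrace(
    track_schedule=schedules.linear(interval=10), curvature="diag_h"
)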
--------------------------------------------------------------------------------
/cockpit/quantities/hooks/__init__.py:
--------------------------------------------------------------------------------
1 | """BackPACK extension hooks.
2 |
3 | Cockpit leverages extension hooks to compact information from BackPACK buffers,
4 | which can then be freed immediately during backpropagation.
5 | """
6 |
--------------------------------------------------------------------------------
/cockpit/quantities/hooks/cleanup.py:
--------------------------------------------------------------------------------
1 | """Contains hook that deletes BackPACK buffers during backpropagation."""
2 |
3 | from typing import Set
4 |
5 | from torch import Tensor
6 |
7 | from cockpit.quantities.hooks.base import ParameterExtensionHook
8 |
9 |
10 | class CleanupHook(ParameterExtensionHook):
11 | """Deletes specified BackPACK buffers during backpropagation."""
12 |
13 | def __init__(self, delete_savefields: Set[str]):
14 | """Store savefields to be deleted in the backward pass.
15 |
16 | Args:
17 |             delete_savefields: Names of the buffers to delete.
18 | """
19 | super().__init__()
20 | self._delete_savefields = delete_savefields
21 |
22 | def param_hook(self, param: Tensor):
23 | """Delete BackPACK buffers in parameter.
24 |
25 | Args:
26 | param: Trainable parameter which hosts BackPACK quantities.
27 | """
28 | for savefield in self._delete_savefields:
29 | if hasattr(param, savefield):
30 | delattr(param, savefield)
31 |
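A hedged usage sketch: freeing per-sample gradients once all transforms have consumed them:

from cockpit.quantities.hooks.cleanup import CleanupHook

hook = CleanupHook({"grad_batch"})
# Registered as an extension hook, it removes ``param.grad_batch`` right
# after each parameter's gradients were processed during backpropagation.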
--------------------------------------------------------------------------------
/cockpit/quantities/inner_test.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the Inner Product Test."""
2 |
3 |
4 | from backpack.extensions import BatchGrad
5 |
6 | from cockpit.quantities.quantity import SingleStepQuantity
7 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchDotGrad
8 |
9 |
10 | class InnerTest(SingleStepQuantity):
11 |     """Quantity class for tracking the result of the inner product test.
12 |
13 | Note: Inner Product test as proposed in
14 |
15 | - Bollapragada, R., Byrd, R., & Nocedal, J.,
16 | Adaptive Sampling Strategies for Stochastic Optimization (2017).
17 | https://arxiv.org/abs/1710.11258
18 | """
19 |
20 | def extensions(self, global_step):
21 | """Return list of BackPACK extensions required for the computation.
22 |
23 | Args:
24 | global_step (int): The current iteration number.
25 |
26 | Returns:
27 | list: (Potentially empty) list with required BackPACK quantities.
28 | """
29 | ext = []
30 |
31 | if self.should_compute(global_step):
32 | ext.append(BatchGrad())
33 |
34 | return ext
35 |
36 | def extension_hooks(self, global_step):
37 | """Return list of BackPACK extension hooks required for the computation.
38 |
39 | Args:
40 | global_step (int): The current iteration number.
41 |
42 | Returns:
43 | [callable]: List of required BackPACK extension hooks for the current
44 | iteration.
45 | """
46 | hooks = []
47 |
48 | if self.should_compute(global_step):
49 | hooks.append(BatchGradTransformsHook_BatchDotGrad())
50 |
51 | return hooks
52 |
53 | def _compute(self, global_step, params, batch_loss):
54 | """Track the practical version of the inner product test.
55 |
56 | Return maximum θ for which the inner product test would pass.
57 |
58 | The inner product test is defined by Equation (2.6) in bollapragada2017adaptive.
59 |
60 | Args:
61 | global_step (int): The current iteration number.
62 | params ([torch.Tensor]): List of torch.Tensors holding the network's
63 | parameters.
64 | batch_loss (torch.Tensor): Mini-batch loss from current step.
65 |
66 | Returns:
67 | float: Maximum θ for which the inner product test would pass.
68 | """
69 | batch_dot = self._fetch_batch_dot_via_batch_grad_transforms(
70 | params, aggregate=True
71 | )
72 | grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True)
73 | batch_size = batch_dot.size(0)
74 |
75 | var_projection = self._compute_projection_variance(
76 | batch_size, batch_dot, grad_l2_squared
77 | )
78 |
79 | return self._compute_theta_max(
80 | batch_size, var_projection, grad_l2_squared
81 | ).item()
82 |
83 | def _compute_theta_max(self, batch_size, var_projection, grad_l2_squared):
84 | """Return maximum θ for which the inner product test would pass.
85 |
86 | Args:
87 | batch_size (int): Mini-batch size.
88 | var_projection (torch.Tensor): The sample variance of individual
89 | gradient projections on the mini-batch gradient.
90 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.
91 |
92 | Returns:
93 |             torch.Tensor: Maximum θ for which the inner product test would pass.
94 | """
95 | return (var_projection / batch_size / grad_l2_squared ** 2).sqrt()
96 |
97 | def _compute_projection_variance(self, batch_size, batch_dot, grad_l2_squared):
98 | """Compute sample variance of individual gradient projections onto the gradient.
99 |
100 |         The sample variance of projections is given directly after Equation (2.6)
101 |         in bollapragada2017adaptive (https://arxiv.org/pdf/1710.11258.pdf).
102 |
103 | Args:
104 | batch_size (int): Mini-batch size.
105 | batch_dot (torch.Tensor): Individual gradient pairwise dot product.
106 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.
107 |
108 | Returns:
109 | torch.Tensor: The sample variance of individual gradient projections on the
110 | mini-batch gradient.
111 | """
112 | projections = batch_size * batch_dot.sum(1)
113 |
114 | return (1 / (batch_size - 1)) * (
115 | (projections ** 2).sum() - batch_size * grad_l2_squared ** 2
116 | )
117 |
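A toy recomputation of θ for a single flattened parameter, assuming ``1/B``-scaled individual gradients as BackPACK's ``BatchGrad`` provides them:

import torch

B, D = 4, 3
g = torch.randn(B, D)
batch_dot = g @ g.T / B ** 2                 # pairwise dots of scaled gradients
grad_l2_squared = (g.mean(0) ** 2).sum()
projections = B * batch_dot.sum(1)           # projections onto the mean gradient
var_projection = (
    (projections ** 2).sum() - B * grad_l2_squared ** 2
) / (B - 1)
theta_max = (var_projection / B / grad_l2_squared ** 2).sqrt()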
--------------------------------------------------------------------------------
/cockpit/quantities/loss.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the loss."""
2 |
3 | from cockpit.quantities.quantity import ByproductQuantity
4 |
5 |
6 | class Loss(ByproductQuantity):
7 | """Loss Quantity class tracking the mini-batch training loss during training."""
8 |
9 | def _compute(self, global_step, params, batch_loss):
10 | """Track the loss at the current point.
11 |
12 | Args:
13 | global_step (int): The current iteration number.
14 | params ([torch.Tensor]): List of torch.Tensors holding the network's
15 | parameters.
16 | batch_loss (torch.Tensor): Mini-batch loss from current step.
17 |
18 | Returns:
19 | float: Mini-batch loss at the current iteration.
20 | """
21 | return batch_loss.item()
22 |
--------------------------------------------------------------------------------
/cockpit/quantities/mean_gsnr.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the mean gradient signal-to-noise ratio (GSNR)."""
2 |
3 |
4 | from backpack.extensions import BatchGrad
5 |
6 | from cockpit.context import get_batch_size
7 | from cockpit.quantities.quantity import SingleStepQuantity
8 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_SumGradSquared
9 |
10 |
11 | class MeanGSNR(SingleStepQuantity):
12 |     """Quantity class for the mean gradient signal-to-noise ratio (GSNR).
13 |
14 | Note: Mean gradient signal-to-noise ratio as defined by
15 |
16 | - Liu, J., et al.
17 | Understanding Why Neural Networks Generalize Well Through GSNR of
18 | Parameters (2020).
19 | https://arxiv.org/abs/2001.07384
20 | """
21 |
22 | def __init__(self, track_schedule, verbose=False, epsilon=1e-5):
23 | """Initialize.
24 |
25 | Args:
26 | track_schedule (callable): Function that maps the ``global_step``
27 | to a boolean, which determines if the quantity should be computed.
28 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
29 | epsilon (float): Stabilization constant. Defaults to 1e-5.
30 | """
31 | super().__init__(track_schedule, verbose=verbose)
32 |
33 | self._epsilon = epsilon
34 |
35 | def extensions(self, global_step):
36 | """Return list of BackPACK extensions required for the computation.
37 |
38 | Args:
39 | global_step (int): The current iteration number.
40 |
41 | Returns:
42 | list: (Potentially empty) list with required BackPACK quantities.
43 | """
44 | ext = []
45 |
46 | if self.should_compute(global_step):
47 | ext.append(BatchGrad())
48 |
49 | return ext
50 |
51 | def extension_hooks(self, global_step):
52 | """Return list of BackPACK extension hooks required for the computation.
53 |
54 | Args:
55 | global_step (int): The current iteration number.
56 |
57 | Returns:
58 | [callable]: List of required BackPACK extension hooks for the current
59 | iteration.
60 | """
61 | hooks = []
62 |
63 | if self.should_compute(global_step):
64 | hooks.append(BatchGradTransformsHook_SumGradSquared())
65 |
66 | return hooks
67 |
68 | def _compute(self, global_step, params, batch_loss):
69 | """Track the mean GSNR.
70 |
71 | Args:
72 | global_step (int): The current iteration number.
73 | params ([torch.Tensor]): List of torch.Tensors holding the network's
74 | parameters.
75 | batch_loss (torch.Tensor): Mini-batch loss from current step.
76 |
77 | Returns:
78 | float: Mean GSNR of the current iteration.
79 | """
80 | return self._compute_gsnr(global_step, params, batch_loss).mean().item()
81 |
82 | def _compute_gsnr(self, global_step, params, batch_loss):
83 | """Compute gradient signal-to-noise ratio.
84 |
85 | Args:
86 | global_step (int): The current iteration number.
87 | params ([torch.Tensor]): List of parameters considered in the computation.
88 | batch_loss (torch.Tensor): Mini-batch loss from current step.
89 |
90 | Returns:
91 | float: Mean GSNR of the current iteration.
92 | """
93 | grad_squared = self._fetch_grad(params, aggregate=True) ** 2
94 | sum_grad_squared = self._fetch_sum_grad_squared_via_batch_grad_transforms(
95 | params, aggregate=True
96 | )
97 |
98 | batch_size = get_batch_size(global_step)
99 |
100 | return grad_squared / (
101 | batch_size * sum_grad_squared - grad_squared + self._epsilon
102 | )
103 |
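A toy re-derivation from raw individual gradients, matching the computation above: the GSNR is the squared mean gradient over the (biased) gradient variance, averaged over parameter entries:

import torch

B = 8
batch_grad = torch.randn(B, 5)
mean_squared = batch_grad.mean(0) ** 2
variance = batch_grad.var(0, unbiased=False)
mean_gsnr = (mean_squared / (variance + 1e-5)).mean()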
--------------------------------------------------------------------------------
/cockpit/quantities/norm_test.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the Norm Test."""
2 |
3 | from backpack.extensions import BatchGrad
4 |
5 | from cockpit.quantities.quantity import SingleStepQuantity
6 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchL2Grad
7 |
8 |
9 | class NormTest(SingleStepQuantity):
10 |     """Quantity class for the norm test.
11 |
12 | Note: Norm test as proposed in
13 |
14 | - Byrd, R., Chin, G., Nocedal, J., & Wu, Y.,
15 | Sample size selection in optimization methods for machine learning (2012).
16 | https://link.springer.com/article/10.1007%2Fs10107-012-0572-5
17 | """
18 |
19 | def extensions(self, global_step):
20 | """Return list of BackPACK extensions required for the computation.
21 |
22 | Args:
23 | global_step (int): The current iteration number.
24 |
25 | Returns:
26 | list: (Potentially empty) list with required BackPACK quantities.
27 | """
28 | ext = []
29 |
30 | if self.should_compute(global_step):
31 | ext.append(BatchGrad())
32 |
33 | return ext
34 |
35 | def extension_hooks(self, global_step):
36 | """Return list of BackPACK extension hooks required for the computation.
37 |
38 | Args:
39 | global_step (int): The current iteration number.
40 |
41 | Returns:
42 | [callable]: List of required BackPACK extension hooks for the current
43 | iteration.
44 | """
45 | hooks = []
46 |
47 | if self.should_compute(global_step):
48 | hooks.append(BatchGradTransformsHook_BatchL2Grad())
49 |
50 | return hooks
51 |
52 | def _compute(self, global_step, params, batch_loss):
53 | """Track the practical version of the norm test.
54 |
55 | Return maximum θ for which the norm test would pass.
56 |
57 | The norm test is defined by Equation (3.9) in byrd2012sample.
58 |
59 | Args:
60 | global_step (int): The current iteration number.
61 | params ([torch.Tensor]): List of torch.Tensors holding the network's
62 | parameters.
63 | batch_loss (torch.Tensor): Mini-batch loss from current step.
64 |
65 | Returns:
66 | float: Maximum θ for which the norm test would pass.
67 | """
68 | batch_l2_squared = self._fetch_batch_l2_squared_via_batch_grad_transforms(
69 | params, aggregate=True
70 | )
71 | grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True)
72 | batch_size = batch_l2_squared.size(0)
73 |
74 | var_l1 = self._compute_variance_l1(
75 | batch_size, batch_l2_squared, grad_l2_squared
76 | )
77 |
78 | return self._compute_theta_max(batch_size, var_l1, grad_l2_squared).item()
79 |
80 | def _compute_theta_max(self, batch_size, var_l1, grad_l2_squared):
81 | """Return maximum θ for which the norm test would pass.
82 |
83 | Args:
84 | batch_size (int): Mini-batch size.
85 |             var_l1 (torch.Tensor): ℓ₁ norm of the gradient sample variance.
86 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.
87 |
88 | Returns:
89 |             torch.Tensor: Maximum θ for which the norm test would pass.
90 | """
91 | return (var_l1 / batch_size / grad_l2_squared).sqrt()
92 |
93 | def _compute_variance_l1(self, batch_size, batch_l2_squared, grad_l2_squared):
94 | """Compute the sample variance ℓ₁ norm.
95 |
96 | It shows up in Equations (3.9) and (3.11) in byrd2012sample and relies
97 | on the sample variance (Equation 3.6). The ℓ₁ norm can be computed using
98 | individual gradient squared ℓ₂ norms and the mini-batch gradient squared
99 | ℓ₂ norm.
100 |
101 | Args:
102 | batch_size (int): Mini-batch size.
103 |             batch_l2_squared (torch.Tensor): Squared ℓ₂ norms of individual gradients.
104 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.
105 |
106 | Returns:
107 | torch.Tensor: The sample variance ℓ₁ norm.
108 | """
109 | return (1 / (batch_size - 1)) * (
110 | batch_size ** 2 * batch_l2_squared.sum() - batch_size * grad_l2_squared
111 | )
112 |
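A toy recomputation of θ in the same setting as above (``1/B``-scaled individual gradients of a single flattened parameter):

import torch

B, D = 4, 3
g = torch.randn(B, D)
batch_l2_squared = ((g / B) ** 2).sum(1)     # per-sample squared ℓ₂ norms
grad_l2_squared = (g.mean(0) ** 2).sum()
var_l1 = (
    B ** 2 * batch_l2_squared.sum() - B * grad_l2_squared
) / (B - 1)
theta_max = (var_l1 / B / grad_l2_squared).sqrt()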
--------------------------------------------------------------------------------
/cockpit/quantities/ortho_test.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the Orthogonality Test."""
2 |
3 | from backpack.extensions import BatchGrad
4 |
5 | from cockpit.quantities.quantity import SingleStepQuantity
6 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook_BatchDotGrad
7 |
8 |
9 | class OrthoTest(SingleStepQuantity):
10 | """Quantity Class for the orthogonality test.
11 |
12 | Note: Orthogonality test as proposed in
13 |
14 | - Bollapragada, R., Byrd, R., & Nocedal, J.,
15 | Adaptive Sampling Strategies for Stochastic Optimization (2017).
16 | https://arxiv.org/abs/1710.11258
17 | """
18 |
19 | def extensions(self, global_step):
20 | """Return list of BackPACK extensions required for the computation.
21 |
22 | Args:
23 | global_step (int): The current iteration number.
24 |
25 | Returns:
26 | list: (Potentially empty) list with required BackPACK quantities.
27 | """
28 | ext = []
29 |
30 | if self.should_compute(global_step):
31 | ext.append(BatchGrad())
32 |
33 | return ext
34 |
35 | def extension_hooks(self, global_step):
36 | """Return list of BackPACK extension hooks required for the computation.
37 |
38 | Args:
39 | global_step (int): The current iteration number.
40 |
41 | Returns:
42 | [callable]: List of required BackPACK extension hooks for the current
43 | iteration.
44 | """
45 | hooks = []
46 |
47 | if self.should_compute(global_step):
48 | hooks.append(BatchGradTransformsHook_BatchDotGrad())
49 |
50 | return hooks
51 |
52 | def _compute(self, global_step, params, batch_loss):
53 | """Track the practical version of the orthogonality test.
54 |
55 | Return maximum ν for which the orthogonality test would pass.
56 |
57 | The orthogonality test is defined by Equation (3.3) in bollapragada2017adaptive.
58 |
59 | Args:
60 | global_step (int): The current iteration number.
61 | params ([torch.Tensor]): List of torch.Tensors holding the network's
62 | parameters.
63 | batch_loss (torch.Tensor): Mini-batch loss from current step.
64 |
65 | Returns:
66 | float: Maximum ν for which the orthogonality test would pass.
67 | """
68 | batch_dot = self._fetch_batch_dot_via_batch_grad_transforms(
69 | params, aggregate=True
70 | )
71 | batch_size = batch_dot.size(0)
72 | grad_l2_squared = self._fetch_grad_l2_squared(params, aggregate=True)
73 |
74 | var_orthogonal_projection = self._compute_orthogonal_projection_variance(
75 | batch_size, batch_dot, grad_l2_squared
76 | )
77 |
78 | return self._compute_nu_max(
79 | batch_size, var_orthogonal_projection, grad_l2_squared
80 | ).item()
81 |
82 | def _compute_nu_max(self, batch_size, var_orthogonal_projection, grad_l2_squared):
83 | """Return maximum ν for which the orthogonality test would pass.
84 |
85 | The orthogonality test is defined by Equation (3.3) in
86 | bollapragada2017adaptive.
87 |
88 | Args:
89 | batch_size (int): Mini-batch size.
90 |             var_orthogonal_projection (torch.Tensor): Sample variance of
91 |                 individual gradient orthogonal projections.
91 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.
92 |
93 | Returns:
94 |             torch.Tensor: Maximum ν for which the orthogonality test would pass.
95 | """
96 | return (var_orthogonal_projection / batch_size / grad_l2_squared).sqrt()
97 |
98 | def _compute_orthogonal_projection_variance(
99 | self, batch_size, batch_dot, grad_l2_squared
100 | ):
101 | """Compute sample variance of individual gradient orthogonal projections.
102 |
103 | The sample variance of orthogonal projections shows up in Equation (3.3) in
104 | bollapragada2017adaptive (https://arxiv.org/pdf/1710.11258.pdf)
105 |
106 | Args:
107 | batch_size (int): Mini-batch size.
108 | batch_dot (torch.Tensor): Individual gradient pairwise dot product.
109 | grad_l2_squared (torch.Tensor): Squared ℓ₂ norm of mini-batch gradient.
110 |
111 | Returns:
112 | torch.Tensor: The sample variance of individual gradient orthogonal
113 | projections on the mini-batch gradient.
114 | """
115 | batch_l2_squared = batch_dot.diag()
116 | projections = batch_size * batch_dot.sum(1)
117 |
118 | return (1 / (batch_size - 1)) * (
119 | batch_size ** 2 * batch_l2_squared.sum()
120 | - (projections ** 2 / grad_l2_squared).sum()
121 | )
122 |
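A toy recomputation of ν in the same setting:

import torch

B, D = 4, 3
g = torch.randn(B, D)
batch_dot = g @ g.T / B ** 2
grad_l2_squared = (g.mean(0) ** 2).sum()
batch_l2_squared = batch_dot.diag()
projections = B * batch_dot.sum(1)
var_orth = (
    B ** 2 * batch_l2_squared.sum()
    - (projections ** 2 / grad_l2_squared).sum()
) / (B - 1)
nu_max = (var_orth / B / grad_l2_squared).sqrt()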
--------------------------------------------------------------------------------
/cockpit/quantities/parameters.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the Individual Parameters."""
2 |
3 | from cockpit.quantities.quantity import ByproductQuantity
4 |
5 |
6 | class Parameters(ByproductQuantity):
7 |     """Parameter Quantity class tracking the current parameters in each iteration."""
8 |
9 | def _compute(self, global_step, params, batch_loss):
10 |         """Store the current parameters.
11 |
12 | Args:
13 | global_step (int): The current iteration number.
14 | params ([torch.Tensor]): List of torch.Tensors holding the network's
15 | parameters.
16 | batch_loss (torch.Tensor): Mini-batch loss from current step.
17 |
18 | Returns:
19 | list: Current model parameters.
20 | """
21 | return [p.data.tolist() for p in params]
22 |
--------------------------------------------------------------------------------
/cockpit/quantities/time.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the time."""
2 |
3 | import time
4 |
5 | from cockpit.quantities.quantity import ByproductQuantity
6 |
7 |
8 | class Time(ByproductQuantity):
9 | """Time Quantity Class tracking the time during training."""
10 |
11 | def _compute(self, global_step, params, batch_loss):
12 | """Return the time at the current point.
13 |
14 | Args:
15 | global_step (int): The current iteration number.
16 | params ([torch.Tensor]): List of torch.Tensors holding the network's
17 | parameters.
18 | batch_loss (torch.Tensor): Mini-batch loss from current step.
19 |
20 | Returns:
21 | float: Current time as given by ``time.time()``.
22 | """
23 | return time.time()
24 |
--------------------------------------------------------------------------------
/cockpit/quantities/update_size.py:
--------------------------------------------------------------------------------
1 | """Class for tracking the update size."""
2 |
3 | from cockpit.quantities.quantity import TwoStepQuantity
4 |
5 |
6 | class UpdateSize(TwoStepQuantity):
7 | """Quantity class for tracking parameter update sizes."""
8 |
9 | CACHE_KEY = "params"
10 | """str: String under which the parameters are cached for computation.
11 | Default: ``'params'``.
12 | """
13 | SAVE_SHIFT = 1
14 |     """int: Difference between the iteration at which information is computed
15 |     and the iteration under which it is stored. For instance, if set to ``1``, the
16 | information computed at iteration ``n + 1`` is saved under iteration ``n``.
17 | Defaults to ``1``.
18 | """
19 |
20 | def extensions(self, global_step):
21 | """Return list of BackPACK extensions required for the computation.
22 |
23 | Args:
24 | global_step (int): The current iteration number.
25 |
26 | Returns:
27 | list: (Potentially empty) list with required BackPACK quantities.
28 | """
29 | return []
30 |
31 | def is_start(self, global_step):
32 | """Return whether current iteration is start point.
33 |
34 | Args:
35 | global_step (int): The current iteration number.
36 |
37 | Returns:
38 | bool: Whether ``global_step`` is a start point.
39 | """
40 | return self._track_schedule(global_step)
41 |
42 | def is_end(self, global_step):
43 | """Return whether current iteration is end point.
44 |
45 | Args:
46 | global_step (int): The current iteration number.
47 |
48 | Returns:
49 | bool: Whether ``global_step`` is an end point.
50 | """
51 | return self._track_schedule(global_step - self.SAVE_SHIFT)
52 |
53 | def _compute_start(self, global_step, params, batch_loss):
54 | """Perform computations at start point (store current parameter values).
55 |
56 | Modifies ``self._cache``.
57 |
58 | Args:
59 | global_step (int): The current iteration number.
60 | params ([torch.Tensor]): List of torch.Tensors holding the network's
61 | parameters.
62 | batch_loss (torch.Tensor): Mini-batch loss from current step.
63 | """
64 | params_copy = [p.data.clone().detach() for p in params]
65 |
66 | def block_fn(step):
67 | """Block deletion of parameters for current and next iteration.
68 |
69 | Args:
70 | step (int): Iteration number.
71 |
72 | Returns:
73 | bool: Whether deletion is blocked in the specified iteration
74 | """
75 | return 0 <= step - global_step <= self.SAVE_SHIFT
76 |
77 | self.save_to_cache(global_step, self.CACHE_KEY, params_copy, block_fn)
78 |
79 | def _compute_end(self, global_step, params, batch_loss):
80 | """Compute and return update size.
81 |
82 | Args:
83 | global_step (int): The current iteration number.
84 | params ([torch.Tensor]): List of torch.Tensors holding the network's
85 | parameters.
86 | batch_loss (torch.Tensor): Mini-batch loss from current step.
87 |
88 | Returns:
89 | [float]: Layer-wise L2-norms of parameter updates.
90 | """
91 | params_start = self.load_from_cache(
92 | global_step - self.SAVE_SHIFT, self.CACHE_KEY
93 | )
94 |
95 | update_size = [
96 | (p.data - p_start).norm(2).item()
97 | for p, p_start in zip(params, params_start)
98 | ]
99 |
100 | return update_size
101 |
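A sketch of the shifted schedule semantics (``track_schedule`` here is a hypothetical example):

def track_schedule(step):
    return step == 10

# With SAVE_SHIFT = 1: is_start(10) is True, so the parameters are cached at
# iteration 10; is_end(11) is True, so ‖θ₁₁ − θ₁₀‖ is computed at iteration 11
# and stored under iteration 10.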
--------------------------------------------------------------------------------
/cockpit/quantities/utils_quantities.py:
--------------------------------------------------------------------------------
1 | """Utility Functions for the Quantities and the Tracking in General."""
2 |
3 | import torch
4 |
5 |
6 | def _layerwise_dot_product(x_s, y_s):
7 | """Computes the dot product of two parameter vectors layerwise.
8 |
9 | Args:
10 | x_s (list): First list of parameter vectors.
11 | y_s (list): Second list of parameter vectors.
12 |
13 | Returns:
14 | torch.Tensor: 1-D list of scalars. Each scalar is a dot product of one layer.
15 | """
16 | return [torch.sum(x * y).item() for x, y in zip(x_s, y_s)]
17 |
18 |
19 | def _root_sum_of_squares(values):
20 |     """Returns the root of the sum of squares of a given list.
21 |
22 |     Args:
23 |         values (list): A list of floats.
24 |
25 |     Returns:
26 |         float: Root sum of squares.
27 |     """
28 |     return sum(el ** 2 for el in values) ** 0.5
29 |
30 |
31 | def abs_max(tensor):
32 | """Return maximum absolute entry in ``tensor``."""
33 | min_val, max_val = tensor.min(), tensor.max()
34 | return max(min_val.abs(), max_val.abs())
35 |
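Quick sanity checks for the helpers above:

import torch

x_s = [torch.ones(2), torch.tensor([1.0, 2.0])]
y_s = [torch.ones(2), torch.tensor([3.0, 4.0])]
assert _layerwise_dot_product(x_s, y_s) == [2.0, 11.0]
assert _root_sum_of_squares([3.0, 4.0]) == 5.0
assert abs_max(torch.tensor([-3.0, 2.0])) == 3.0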
--------------------------------------------------------------------------------
/cockpit/quantities/utils_transforms.py:
--------------------------------------------------------------------------------
1 | """Utility functions for Transforms."""
2 |
3 | import string
4 | import weakref
5 |
6 | from torch import Tensor, einsum
7 |
8 | from cockpit.quantities.hooks.base import ParameterExtensionHook
9 |
10 |
11 | def BatchGradTransformsHook_BatchL2Grad():
12 | """Compute individual gradient ℓ₂ norms via individual gradients."""
13 | return BatchGradTransformsHook({"batch_l2": batch_l2_transform})
14 |
15 |
16 | def batch_l2_transform(batch_grad):
17 | """Transform individual gradients into individual ℓ₂ norms."""
18 | sum_axes = list(range(batch_grad.dim()))[1:]
19 | return (batch_grad ** 2).sum(sum_axes)
20 |
21 |
22 | def BatchGradTransformsHook_BatchDotGrad():
23 | """Compute pairwise individual gradient dot products via individual gradients."""
24 | return BatchGradTransformsHook({"batch_dot": batch_dot_transform})
25 |
26 |
27 | def batch_dot_transform(batch_grad):
28 | """Transform individual gradients into pairwise dot products."""
29 | # make einsum string
30 | letters = get_first_n_alphabet(batch_grad.dim() + 1)
31 | n1, n2, sum_out = letters[0], letters[1], "".join(letters[2:])
32 |
33 | einsum_equation = f"{n1}{sum_out},{n2}{sum_out}->{n1}{n2}"
34 |
35 | return einsum(einsum_equation, batch_grad, batch_grad)
36 |
37 |
38 | def get_first_n_alphabet(n):
39 |     """Return the first ``n`` lowercase letters of the alphabet as a string."""
40 | return string.ascii_lowercase[:n]
41 |
42 |
43 | def BatchGradTransformsHook_SumGradSquared():
44 | """Compute sum of squared individual gradients via individual gradients."""
45 | return BatchGradTransformsHook({"sum_grad_squared": sum_grad_squared_transform})
46 |
47 |
48 | def sum_grad_squared_transform(batch_grad):
49 | """Transform individual gradients into second non-centered moment."""
50 | return (batch_grad ** 2).sum(0)
51 |
52 |
53 | class BatchGradTransformsHook(ParameterExtensionHook):
54 | """Hook implementation of ``BatchGradTransforms``."""
55 |
56 | def __init__(self, transforms, savefield=None):
57 | """Store transformations and potential savefield.
58 |
59 | Args:
60 | transforms (dict): Values are functions that are evaluated on a parameter's
61 | ``grad_batch`` attribute. The result is stored in a dictionary stored
62 | under ``grad_batch_transforms``.
63 | savefield (str, optional): Attribute name under which the hook's result
64 | is saved in a parameter. If ``None``, it is assumed that the hook acts
65 | via side effects and no output needs to be stored.
66 | """
67 | super().__init__(savefield=savefield)
68 | self._transforms = transforms
69 |
70 | def param_hook(self, param: Tensor):
71 | """Execute all transformations and store results as dictionary in the parameter.
72 |
73 | Args:
74 | param: Trainable parameter which hosts BackPACK quantities.
75 | """
76 | param.grad_batch._param_weakref = weakref.ref(param)
77 | # TODO Delete after backward pass with Cockpit
78 | param.grad_batch_transforms = {
79 | key: func(param.grad_batch) for key, func in self._transforms.items()
80 | }
81 |
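Quick sanity checks of the three transforms on a toy individual-gradients tensor of shape ``[B, ...]``:

import torch

batch_grad = torch.randn(4, 3, 2)
l2 = batch_l2_transform(batch_grad)            # shape [4]
dots = batch_dot_transform(batch_grad)         # shape [4, 4]
sgs = sum_grad_squared_transform(batch_grad)   # shape [3, 2]
assert torch.allclose(dots.diag(), l2)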
--------------------------------------------------------------------------------
/cockpit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility functions to configure cockpit."""
2 |
--------------------------------------------------------------------------------
/cockpit/utils/configuration.py:
--------------------------------------------------------------------------------
1 | """Configuration utilities for cockpit."""
2 |
3 | from cockpit import quantities
4 | from cockpit.utils import schedules
5 |
6 |
7 | def configuration(label, track_schedule=None, verbose=False):
8 | """Use pre-defined collections of quantities that should be used for tracking.
9 |
10 | Currently supports three different configurations:
11 |
12 | - ``"economy"``: Combines the :class:`~cockpit.quantities.Alpha`,
13 | :class:`~cockpit.quantities.Distance`, :class:`~cockpit.quantities.GradHist1d`,
14 | :class:`~cockpit.quantities.GradNorm`, :class:`~cockpit.quantities.InnerTest`,
15 | :class:`~cockpit.quantities.Loss`, :class:`~cockpit.quantities.NormTest`,
16 | :class:`~cockpit.quantities.OrthoTest` and :class:`~cockpit.quantities.UpdateSize`
17 | quantities.
18 | - ``"business"``: Same as ``"economy"`` but additionally with
19 | :class:`~cockpit.quantities.TICDiag` and :class:`~cockpit.quantities.HessTrace`.
20 | - ``"full"``: Same as ``"business"`` but additionally with
21 | :class:`~cockpit.quantities.HessMaxEV` and
22 | :class:`~cockpit.quantities.GradHist2d`.
23 |
24 | Args:
25 | label (str): String specifying the configuration type. Possible configurations
26 | are (least to most expensive) ``'economy'``, ``'business'``, ``'full'``.
27 | track_schedule (callable, optional): Function that maps the ``global_step``
28 | to a boolean, which determines if the quantity should be computed.
29 | Defaults to ``None``.
30 | verbose (bool, optional): Turns on verbose mode. Defaults to ``False``.
31 |
32 | Returns:
33 | list: Instantiated quantities for a cockpit configuration.
34 | """
35 | if track_schedule is None:
36 | track_schedule = schedules.linear(interval=1, offset=0)
37 |
38 | quants = []
39 | for q_cls in quantities_cls_for_configuration(label):
40 | quants.append(q_cls(track_schedule=track_schedule, verbose=verbose))
41 |
42 | return quants
43 |
44 |
45 | def quantities_cls_for_configuration(label):
46 | """Return the quantity classes for a cockpit configuration.
47 |
48 | Currently supports three different configurations:
49 |
50 | - ``"economy"``: Combines the :class:`~cockpit.quantities.Alpha`,
51 | :class:`~cockpit.quantities.Distance`, :class:`~cockpit.quantities.GradHist1d`,
52 | :class:`~cockpit.quantities.GradNorm`, :class:`~cockpit.quantities.InnerTest`,
53 | :class:`~cockpit.quantities.Loss`, :class:`~cockpit.quantities.NormTest`,
54 | :class:`~cockpit.quantities.OrthoTest` and :class:`~cockpit.quantities.UpdateSize`
55 | quantities.
56 | - ``"business"``: Same as ``"economy"`` but additionally with
57 | :class:`~cockpit.quantities.TICDiag` and :class:`~cockpit.quantities.HessTrace`.
58 | - ``"full"``: Same as ``"business"`` but additionally with
59 | :class:`~cockpit.quantities.HessMaxEV` and
60 | :class:`~cockpit.quantities.GradHist2d`.
61 |
62 | Args:
63 | label (str): String specifying the configuration type. Possible configurations
64 | are (least to most expensive) ``'economy'``, ``'business'``, ``'full'``.
65 |
66 | Returns:
67 | [Quantity]: A list of quantity classes used in the
68 | specified configuration.
69 | """
70 | economy = [
71 | quantities.Alpha,
72 | quantities.Distance,
73 | quantities.GradHist1d,
74 | quantities.GradNorm,
75 | quantities.InnerTest,
76 | quantities.Loss,
77 | quantities.NormTest,
78 | quantities.OrthoTest,
79 | quantities.UpdateSize,
80 | ]
81 | business = economy + [
82 | quantities.TICDiag,
83 | quantities.HessTrace,
84 | ]
85 | full = business + [
86 | quantities.HessMaxEV,
87 | quantities.GradHist2d,
88 | ]
89 |
90 | configs = {
91 | "full": full,
92 | "business": business,
93 | "economy": economy,
94 | }
95 |
96 | return configs[label]
97 |
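A hedged usage sketch, building the ``"economy"`` quantities tracked every 5 steps:

from cockpit.utils import schedules
from cockpit.utils.configuration import configuration

quants = configuration("economy", track_schedule=schedules.linear(interval=5))
assert len(quants) == 9  # one instance per class listed above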
--------------------------------------------------------------------------------
/cockpit/utils/optim.py:
--------------------------------------------------------------------------------
1 | """Utility functions to investigate optimizers.
2 |
3 | Some quantities either require computation of the optimizer update step, or are only
4 | defined for certain optimizers.
5 | """
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class ComputeStep:
12 | """Update step computation from BackPACK quantities for different optimizers.
13 |
14 | Note:
15 | The ``.grad`` attribute cannot be used to compute update steps as this code
16 | is invoked as hook during backpropagation at a time where the ``.grad`` field
17 | has not yet been updated with the latest gradients.
18 | """
19 |
20 | @staticmethod
21 | def compute_update_step(optimizer, parameter_ids):
22 | """Compute an optimizer's update step.
23 |
24 | Args:
25 | optimizer (torch.optim.Optimizer): A PyTorch optimizer.
26 | parameter_ids ([id]): List of parameter ids for which the updates are
27 | computed.
28 |
29 | Returns:
30 | dict: Mapping between parameters and their updates. Keys are parameter ids
31 | and items are ``torch.Tensor``s representing the update.
32 |
33 | Raises:
34 | NotImplementedError: If the optimizer's update step is not implemented.
35 | """
36 | if ComputeStep.is_sgd_default_kwargs(optimizer):
37 | return ComputeStep.update_sgd_default_kwargs(optimizer, parameter_ids)
38 |
39 | raise NotImplementedError
40 |
41 | @staticmethod
42 | def is_sgd_default_kwargs(optimizer):
43 | """Return whether the input is momentum-free SGD with default values.
44 |
45 | Args:
46 | optimizer (torch.optim.Optimizer): A PyTorch optimizer.
47 |
48 | Returns:
49 | bool: Whether the input is momentum-free SGD with default values.
50 | """
51 | if not isinstance(optimizer, torch.optim.SGD):
52 | return False
53 |
54 | for group in optimizer.param_groups:
55 | if not np.isclose(group["weight_decay"], 0.0):
56 | return False
57 |
58 | if not np.isclose(group["momentum"], 0.0):
59 | return False
60 |
61 | if not np.isclose(group["dampening"], 0.0):
62 | return False
63 |
64 | if not group["nesterov"] is False:
65 | return False
66 |
67 | return True
68 |
69 | @staticmethod
70 | def update_sgd_default_kwargs(optimizer, parameter_ids):
71 | """Return the update of momentum-free SGD with default values.
72 |
73 | Args:
74 | optimizer (torch.optim.SGD): Zero-momentum default SGD.
75 | parameter_ids ([id]): List of parameter ids for which the updates are
76 | computed.
77 |
78 | Returns:
79 | dict: Mapping between parameters and their updates. Keys are parameter ids
80 | and items are ``torch.Tensor``s representing the update.
81 | """
82 | updates = {}
83 |
84 | for group in optimizer.param_groups:
85 | for p in group["params"]:
86 | if id(p) in parameter_ids:
87 | lr = group["lr"]
88 | updates[id(p)] = -lr * p.grad_batch.sum(0).detach()
89 |
90 |         assert len(updates) == len(
91 |             parameter_ids
92 |         ), "Could not compute step for all specified parameters"
93 |
94 |         return updates
95 |
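A quick illustration of the SGD check (the parameter list is a toy example):

import torch

params = [torch.nn.Parameter(torch.zeros(2))]
assert ComputeStep.is_sgd_default_kwargs(torch.optim.SGD(params, lr=0.1))
assert not ComputeStep.is_sgd_default_kwargs(
    torch.optim.SGD(params, lr=0.1, momentum=0.9)
)
assert not ComputeStep.is_sgd_default_kwargs(torch.optim.Adam(params, lr=0.1))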
--------------------------------------------------------------------------------
/cockpit/utils/schedules.py:
--------------------------------------------------------------------------------
1 | """Convenient schedule functions."""
2 |
3 | import torch
4 |
5 |
6 | def linear(interval, offset=0):
7 |     """Creates a linear schedule that tracks the iterations ``{offset + n * interval | n >= 0}``.
8 |
9 | Args:
10 | interval (int): The regular tracking interval.
11 | offset (int, optional): Offset of tracking. Defaults to 0.
12 |
13 | Returns:
14 | callable: Function that given the global_step returns whether it should track.
15 | """
16 | docstring = "Track at iterations {" + f"{offset} + n * {interval} " + "| n >= 0}."
17 |
18 | def schedule(global_step):
19 | shifted = global_step - offset
20 | if shifted < 0:
21 | return False
22 | else:
23 | return shifted % interval == 0
24 |
25 | schedule.__doc__ = docstring
26 |
27 | return schedule
28 |
29 |
30 | def logarithmic(start, end, steps=300, base=10, init=True):
31 | """Creates a logarithmic tracking schedule.
32 |
33 | Args:
34 |         start (float): Exponent of the first tracked step, i.e. ``base ** start``.
35 |         end (float): Exponent of the last tracked step, i.e. ``base ** end``.
36 | steps (int, optional): Number of log spaced points. Defaults to 300.
37 | base (int, optional): Logarithmic base. Defaults to 10.
38 | init (bool, optional): Whether 0 should be included. Defaults to True.
39 |
40 | Returns:
41 | callable: Function that given the global_step returns whether it should track.
42 | """
43 | # TODO Compute match and avoid array lookup
44 | scheduled_steps = torch.logspace(start, end, steps, base=base, dtype=int)
45 |
46 | if init:
47 | zero = torch.tensor([0], dtype=int)
48 | scheduled_steps = torch.cat((scheduled_steps, zero)).int()
49 |
50 | def schedule(global_step):
51 | return global_step in scheduled_steps
52 |
53 | return schedule
54 |
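Quick examples of both schedules:

track = linear(interval=10, offset=2)
assert [step for step in range(25) if track(step)] == [2, 12, 22]

log_track = logarithmic(0, 2, steps=3)  # tracks roughly {0, 1, 10, 100}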
--------------------------------------------------------------------------------
/docs/requirements_doc.txt:
--------------------------------------------------------------------------------
1 | # Sphinx builds the documentation
2 | sphinx
3 |
4 | # Sphinx extensions
5 | sphinx-rtd-theme
6 | sphinx-automodapi
7 | #sphinx-copybutton
8 | sphinx-notfound-page
9 |
10 | # Markdown conversion
11 | m2r2
--------------------------------------------------------------------------------
/docs/source/_static/01_basic_fmnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/01_basic_fmnist.png
--------------------------------------------------------------------------------
/docs/source/_static/02_advanced_fmnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/02_advanced_fmnist.png
--------------------------------------------------------------------------------
/docs/source/_static/03_deepobs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/03_deepobs.png
--------------------------------------------------------------------------------
/docs/source/_static/LogoSquare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/LogoSquare.png
--------------------------------------------------------------------------------
/docs/source/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/favicon.ico
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/Alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Alpha.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/CABS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/CABS.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/Distances.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Distances.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/EarlyStopping.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/EarlyStopping.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/GradientNorm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/GradientNorm.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/GradientTests.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/GradientTests.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/HessMaxEV.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/HessMaxEV.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/HessTrace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/HessTrace.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/Hist1d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Hist1d.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/Hist2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Hist2d.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/Hyperparameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Hyperparameters.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/MeanGSNR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/MeanGSNR.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/Performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/Performance.png
--------------------------------------------------------------------------------
/docs/source/_static/instrument_previews/TIC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/instrument_previews/TIC.png
--------------------------------------------------------------------------------
/docs/source/_static/showcase.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/f-dangel/cockpit/af91391ddab2a8aef85905b081ccf67d94c1a0e5/docs/source/_static/showcase.gif
--------------------------------------------------------------------------------
/docs/source/_static/stylefile.css:
--------------------------------------------------------------------------------
1 | /* Stylefile shamelessly copied (with permission) from ProbNum (Jonathan Wenger et al.) */
2 |
3 | /* Color variables */
4 | :root {
5 | --pn-green: #107d79;
6 | --pn-lightgreen: #e9f4f4;
7 | --pn-brightgreen: #16b3ad;
8 | --pn-orange: #ff9933;
9 | --pn-darkred: #820000;
10 | --pn-darkgray: #333333;
11 | --pn-lightgray: #f2f2f2;
12 | --background-gray: #fcfcfc;
13 | }
14 |
15 | /* Font */
16 | body {
17 | font-family: Roboto light, Helvetica, sans-serif;
18 | }
19 |
20 | .rst-content .toctree-wrapper > p.caption, h1, h2, h3, h4, h5, h6, legend {
21 | margin-top: 0;
22 | font-weight: 700;
23 | font-family: Open sans light, Helvetica, sans-serif;
24 | }
25 |
26 | /* Content */
27 | .wy-nav-content {
28 | max-width: 1000px;
29 | }
30 | .wy-nav-content a:visited {
31 | color: var(--pn-green);
32 | }
33 |
34 | /* Sidebar */
35 |
36 | /* Logo area */
37 | .wy-nav-top {
38 | background: var(--pn-green);
39 | }
40 | .wy-nav-side {
41 | background: var(--pn-darkgray);
42 | }
43 | /* Home button */
44 | .wy-side-nav-search .wy-dropdown > a, .wy-side-nav-search > a {
45 | color: var(--pn-darkgray);
46 | }
47 | /* Version number */
48 | .wy-side-nav-search > div.version {
49 | color: var(--pn-darkgray);
50 | }
51 | /* Sidebar headers */
52 | .wy-menu-vertical header, .wy-menu-vertical p.caption {
53 | color: var(--pn-brightgreen);
54 | }
55 | /* Sidebar links */
56 | .wy-menu-vertical a:active {
57 | background-color: var(--pn-green);
58 | }
59 | .wy-menu-vertical a:hover {
60 | color: var(--pn-brightgreen);
61 | }
62 |
63 | /* Hyperlinks */
64 | a {
65 | color: var(--pn-green);
66 | text-decoration: none;
67 | }
68 |
69 | a:hover {
70 | color: var(--pn-green);
71 | text-decoration: underline;
72 | }
73 |
74 | /* API Documentation */
75 | .rst-content code, .rst-content tt {
76 | color: var(--pn-darkgray);
77 | }
78 | .rst-content .viewcode-link {
79 | color: var(--pn-green);
80 | }
81 |
82 | /* Code block */
83 | .highlight {
84 | background: var(--pn-lightgray);
85 | }
86 |
87 | /* Code inline */
88 | .rst-content code.literal, .rst-content tt.literal{
89 | color: var(--pn-green);
90 | padding: 1px 2px;
91 | background-color: var(--pn-lightgreen);
92 | border-radius: 4px;
93 | }
94 |
95 | /* Alert boxes (e.g. documentation "see also") */
96 | html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list) > dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list) > dt {
97 | border-left: 3px solid grey;
98 | background: #fff0;
99 | color: #555;
100 | }
101 |
102 | /* Footnotes */
103 | html.writer-html5 .rst-content dl.footnote > dd {
104 | margin: 0 0 0;
105 | }
106 |
107 | /* Block quotes */
108 | .rst-content blockquote {
109 | margin-left: 0px;
110 | line-height: 24px;
111 | margin-bottom: 0px;
112 | border-left: 5px solid #f2f2f2;
113 | padding-left: 18px;
114 | }
115 |
116 | /* Sphinx gallery */
117 | div.sphx-glr-thumbcontainer {
118 | background: var(--pn-lightgray);
119 | border: solid var(--pn-lightgray) 1px;
120 | border-radius: 10px;
121 | margin-bottom: 15px;
122 | }
123 | div.sphx-glr-thumbcontainer:hover {
124 | box-shadow: 0 0 10px #107d7930;
125 | border: solid var(--pn-green) 1px;
126 | }
127 |
128 | /* Plot animations */
129 | .anim-state label {
130 | margin-right: 8px;
131 | display: inline;
132 | }
133 |
134 | /* Author attribution */
135 | .rst-content .section .authorlist ul li{
136 | float: left;
137 | list-style: none;
138 | margin-left: 16px;
139 | margin-bottom: 0px;
140 | }
141 |
142 | .avatar {
143 | vertical-align: middle;
144 | border-radius: 10%;
145 | }
146 |
147 | /* ReadTheDocs */
148 | .fa {
149 | font-family: Open Sans light, Helvetica, sans-serif;
150 | }
151 |
152 | .rst-versions {
153 | font-family: Open sans light, Helvetica, sans-serif;
154 | }
155 |
156 | .rst-versions .rst-current-version {
157 | color: var(--pn-brightgreen);
158 | }
159 |
160 | .keep-us-sustainable {
161 | border: 1px dotted var(--pn-lightgray);
162 | }
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.alpha_gauge.rst:
--------------------------------------------------------------------------------
1 | alpha_gauge
2 | ===========
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: alpha_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.cabs_gauge.rst:
--------------------------------------------------------------------------------
1 | cabs_gauge
2 | ==========
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: cabs_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.distance_gauge.rst:
--------------------------------------------------------------------------------
1 | distance_gauge
2 | ==============
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: distance_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.early_stopping_gauge.rst:
--------------------------------------------------------------------------------
1 | early_stopping_gauge
2 | ====================
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: early_stopping_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.grad_norm_gauge.rst:
--------------------------------------------------------------------------------
1 | grad_norm_gauge
2 | ===============
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: grad_norm_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.gradient_tests_gauge.rst:
--------------------------------------------------------------------------------
1 | gradient_tests_gauge
2 | ====================
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: gradient_tests_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.histogram_1d_gauge.rst:
--------------------------------------------------------------------------------
1 | histogram_1d_gauge
2 | ==================
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: histogram_1d_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.histogram_2d_gauge.rst:
--------------------------------------------------------------------------------
1 | histogram_2d_gauge
2 | ==================
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: histogram_2d_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.hyperparameter_gauge.rst:
--------------------------------------------------------------------------------
1 | hyperparameter_gauge
2 | ====================
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: hyperparameter_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.max_ev_gauge.rst:
--------------------------------------------------------------------------------
1 | max_ev_gauge
2 | ============
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: max_ev_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.mean_gsnr_gauge.rst:
--------------------------------------------------------------------------------
1 | mean_gsnr_gauge
2 | ===============
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: mean_gsnr_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.performance_gauge.rst:
--------------------------------------------------------------------------------
1 | performance_gauge
2 | =================
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: performance_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.tic_gauge.rst:
--------------------------------------------------------------------------------
1 | tic_gauge
2 | =========
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: tic_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.instruments.trace_gauge.rst:
--------------------------------------------------------------------------------
1 | trace_gauge
2 | ===========
3 |
4 | .. currentmodule:: cockpit.instruments
5 |
6 | .. autoclass:: trace_gauge
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.Alpha.rst:
--------------------------------------------------------------------------------
1 | Alpha
2 | =====
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: Alpha
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.CABS.rst:
--------------------------------------------------------------------------------
1 | CABS
2 | ====
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: CABS
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.Distance.rst:
--------------------------------------------------------------------------------
1 | Distance
2 | ========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: Distance
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.EarlyStopping.rst:
--------------------------------------------------------------------------------
1 | EarlyStopping
2 | =============
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: EarlyStopping
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.GradHist1d.rst:
--------------------------------------------------------------------------------
1 | GradHist1d
2 | ==========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: GradHist1d
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.GradHist2d.rst:
--------------------------------------------------------------------------------
1 | GradHist2d
2 | ==========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: GradHist2d
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.GradNorm.rst:
--------------------------------------------------------------------------------
1 | GradNorm
2 | ========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: GradNorm
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.HessMaxEV.rst:
--------------------------------------------------------------------------------
1 | HessMaxEV
2 | =========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: HessMaxEV
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.HessTrace.rst:
--------------------------------------------------------------------------------
1 | HessTrace
2 | =========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: HessTrace
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.InnerTest.rst:
--------------------------------------------------------------------------------
1 | InnerTest
2 | =========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: InnerTest
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.Loss.rst:
--------------------------------------------------------------------------------
1 | Loss
2 | ====
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: Loss
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.MeanGSNR.rst:
--------------------------------------------------------------------------------
1 | MeanGSNR
2 | ========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: MeanGSNR
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.NormTest.rst:
--------------------------------------------------------------------------------
1 | NormTest
2 | ========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: NormTest
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.OrthoTest.rst:
--------------------------------------------------------------------------------
1 | OrthoTest
2 | =========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: OrthoTest
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.Parameters.rst:
--------------------------------------------------------------------------------
1 | Parameters
2 | ==========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: Parameters
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.TICDiag.rst:
--------------------------------------------------------------------------------
1 | TICDiag
2 | =======
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: TICDiag
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.TICTrace.rst:
--------------------------------------------------------------------------------
1 | TICTrace
2 | ========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: TICTrace
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.Time.rst:
--------------------------------------------------------------------------------
1 | Time
2 | ====
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: Time
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.quantities.UpdateSize.rst:
--------------------------------------------------------------------------------
1 | UpdateSize
2 | ==========
3 |
4 | .. currentmodule:: cockpit.quantities
5 |
6 | .. autoclass:: UpdateSize
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.utils.configuration.configuration.rst:
--------------------------------------------------------------------------------
1 | configuration
2 | =============
3 |
4 | .. currentmodule:: cockpit.utils.configuration
5 |
6 | .. autofunction:: configuration
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.utils.configuration.quantities_cls_for_configuration.rst:
--------------------------------------------------------------------------------
1 | quantities_cls_for_configuration
2 | ================================
3 |
4 | .. currentmodule:: cockpit.utils.configuration
5 |
6 | .. autofunction:: quantities_cls_for_configuration
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.utils.schedules.linear.rst:
--------------------------------------------------------------------------------
1 | linear
2 | ======
3 |
4 | .. currentmodule:: cockpit.utils.schedules
5 |
6 | .. autofunction:: linear
7 |
--------------------------------------------------------------------------------
/docs/source/api/automod/cockpit.utils.schedules.logarithmic.rst:
--------------------------------------------------------------------------------
1 | logarithmic
2 | ===========
3 |
4 | .. currentmodule:: cockpit.utils.schedules
5 |
6 | .. autofunction:: logarithmic
7 |
--------------------------------------------------------------------------------
/docs/source/api/cockpit.rst:
--------------------------------------------------------------------------------
1 | =======
2 | Cockpit
3 | =======
4 |
5 | .. autoclass:: cockpit.Cockpit
6 | :members:
--------------------------------------------------------------------------------
/docs/source/api/instruments.rst:
--------------------------------------------------------------------------------
1 | .. _instruments:
2 |
3 | ===========
4 | Instruments
5 | ===========
6 |
7 | **Cockpit** offers a large set of so-called *instruments* that take tracked
8 | *quantities* and visualize them.
9 |
10 | .. automodsumm:: cockpit.instruments
11 |
12 | .. toctree::
13 | :glob:
14 | :hidden:
15 |
16 | automod/cockpit.instruments.*
--------------------------------------------------------------------------------
/docs/source/api/plotter.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | Cockpit Plotter
3 | ===============
4 |
5 | .. autoclass:: cockpit.CockpitPlotter
6 | :members:
--------------------------------------------------------------------------------
/docs/source/api/quantities.rst:
--------------------------------------------------------------------------------
1 | .. _quantities:
2 |
3 | ==========
4 | Quantities
5 | ==========
6 |
7 | **Cockpit** offers a large set of so-called *quantities* that can be efficiently
8 | tracked during the training process.
9 |
10 | .. automodsumm:: cockpit.quantities
11 |
12 | .. toctree::
13 | :glob:
14 | :hidden:
15 |
16 | automod/cockpit.quantities.*
--------------------------------------------------------------------------------
/docs/source/api/utils.rst:
--------------------------------------------------------------------------------
1 | =====
2 | Utils
3 | =====
4 |
5 |
6 | Configuration
7 | =============
8 |
9 | .. automodapi:: cockpit.utils.configuration
10 | :no-heading:
11 |
12 |
13 | Schedules
14 | =========
15 |
16 | .. automodapi:: cockpit.utils.schedules
17 | :no-heading:
18 |
--------------------------------------------------------------------------------
/docs/source/examples/03_deepobs.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | DeepOBS Example
3 | ===============
4 |
5 | **Cockpit** integrates easily with
6 | `DeepOBS <https://deepobs.github.io>`_.
7 | This will directly give you access to dozens of deep learning problems
8 | that you can explore with **Cockpit**.
9 |
10 | .. Note::
11 | This example requires a `DeepOBS <https://github.com/fsschneider/DeepOBS>`__
12 | and a `BackOBS <https://github.com/f-dangel/backobs>`_ installation.
13 | You can install them by running
14 |
15 | .. code:: bash
16 |
17 | pip install 'git+https://github.com/fsschneider/DeepOBS.git@develop#egg=deepobs'
18 |
19 | and
20 |
21 | .. code:: bash
22 |
23 | pip install 'git+https://github.com/f-dangel/backobs.git@master#egg=backobs'
24 |
25 | Note that currently only the 1.2.0 beta version of DeepOBS supports PyTorch;
26 | it will be installed by the above command.
27 |
28 | .. note::
29 |
30 | In the following example, we will use an additional :download:`utility file
31 | <../../../examples/_utils_deepobs.py>` which automatically integrates **Cockpit**
32 | into the DeepOBS training loop.
33 |
34 | Having the two `utility files from our repository
35 | <https://github.com/f-dangel/cockpit/tree/main/examples>`_, we can run
36 |
37 | .. code:: bash
38 |
39 | python 03_deepobs.py
40 |
41 | which executes the following :download:`example script <../../../examples/03_deepobs.py>`:
42 |
43 | .. literalinclude:: ../../../examples/03_deepobs.py
44 | :language: python
45 | :linenos:
46 |
47 | Just like before, we can define a list of quantities (here we use the
48 | :mod:`~cockpit.utils.configuration` ``"full"``), which we now pass to the
49 | ``DeepOBSRunner``. It will automatically pass them on to the :class:`~cockpit.Cockpit`.
50 |
51 | With the arguments of the ``runner.run()`` function, we can define whether the
52 | :class:`~cockpit.CockpitPlotter` plots should be shown and/or stored.
53 |
54 | **Cockpit** will show a status screen every few epochs, write to a logfile, and
55 | save its final plot once training has completed.
56 |
57 | .. code-block:: console
58 |
59 | $ python 03_deepobs.py
60 |
61 | ********************************
62 | Evaluating after 0 of 15 epochs...
63 | TRAIN: loss 7.0372
64 | VALID: loss 7.07626
65 | TEST: loss 7.06894
66 | ********************************
67 | [cockpit|plot] Showing current Cockpit.
68 | ********************************
69 | Evaluating after 1 of 15 epochs...
70 | TRAIN: loss 7.00634
71 | VALID: loss 7.01242
72 | TEST: loss 7.00535
73 | ********************************
74 | ********************************
75 | Evaluating after 2 of 15 epochs...
76 | TRAIN: loss 6.98335
77 | VALID: loss 6.94937
78 | TEST: loss 6.94255
79 | ********************************
80 | [cockpit|plot] Showing current Cockpit.
81 |
82 | [...]
83 |
84 | The fifty epochs on the deep quadratic
85 | (``quadratic_deep``) DeepOBS test
86 | problem will result in a **Cockpit** plot similar to this:
87 |
88 | .. image:: ../_static/03_deepobs.png
89 | :alt: Preview Cockpit DeepOBS Example
90 |
--------------------------------------------------------------------------------
/docs/source/extract_instrument_previews.py:
--------------------------------------------------------------------------------
1 | """Quick file to extract previews of instruments."""
2 |
3 |
4 | import os
5 |
6 | import matplotlib.pyplot as plt
7 | from matplotlib.transforms import Bbox
8 |
9 | import cockpit
10 |
11 | HERE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | FILE_PATH = "_static"
13 | FILE_NAME = "instrument_preview_run"
14 | SECONDARY_FILE_NAME = FILE_NAME + "_secondary"
15 | SAVE_PATH = "instrument_previews"
16 |
17 | full_path = os.path.join(HERE_DIR, FILE_PATH, FILE_NAME)
18 | cp = cockpit.CockpitPlotter(secondary_screen=False)
19 | cp.plot(full_path, show_plot=False, show_log_iter=True, debug=True)
20 |
21 | # Store preview images
22 | preview_dict = {
23 | "Hyperparameters": [[3, 1], [11.2, 4.8]],
24 | "Performance": [[11.2, 1], [27.8, 4.8]],
25 | "GradientNorm": [[4, 5], [10.5, 7.75]],
26 | "Distances": [[4, 7.75], [10.6, 10.25]],
27 | "Alpha": [[4, 10.15], [10.7, 12.75]],
28 | "GradientTests": [[11.75, 10.15], [19, 12.75]],
29 | "HessMaxEV": [[20, 10.15], [26.5, 12.75]],
30 | "HessTrace": [[20, 7.75], [26.5, 10.25]],
31 | "TIC": [[20, 5.1], [26.5, 7.75]],
32 | "Hist1d": [[11.75, 7.75], [19, 10.25]],
33 | "Hist2d": [[11.75, 5.1], [18.5, 7.75]],
34 | }
35 |
36 | for instrument in preview_dict:
37 | plt.savefig(
38 | os.path.join(HERE_DIR, FILE_PATH, SAVE_PATH, instrument),
39 | bbox_inches=Bbox(preview_dict[instrument]),
40 | )
41 |
42 | # plt.savefig(
43 | # os.path.join(HERE_DIR, FILE_PATH, SAVE_PATH, "cockpit.png"),
44 | # format="png",
45 | # bbox_inches=Bbox([[11.75, 5.1], [18.5, 7.75]]),
46 | # )
47 |
48 | # Secondary Screen
49 | plt.close("all")
50 | cp = cockpit.CockpitPlotter(secondary_screen=True)
51 | secondary_full_path = os.path.join(HERE_DIR, FILE_PATH, SECONDARY_FILE_NAME)
52 | cp.plot(secondary_full_path, show_plot=False, show_log_iter=False)
53 |
54 | # Store preview images
55 | preview_dict = {
56 | "MeanGSNR": [[3.8, 10.45], [11, 13.15]],
57 | "CABS": [[11.5, 10.45], [19, 13.15]],
58 | "EarlyStopping": [[19, 10.45], [26.5, 13.15]],
59 | }
60 |
61 | for instrument in preview_dict:
62 | plt.savefig(
63 | os.path.join(HERE_DIR, FILE_PATH, SAVE_PATH, instrument),
64 | bbox_inches=Bbox(preview_dict[instrument]),
65 | )
66 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | =======
2 | Cockpit
3 | =======
4 |
5 | |CI Status| |Lint Status| |Doc Status| |Coverage| |License| |Code Style| |arXiv|
6 |
7 | ----
8 |
9 | .. code:: bash
10 |
11 | pip install cockpit-for-pytorch
12 |
13 | ----
14 |
15 | **Cockpit is a visual and statistical debugger specifically designed for deep
16 | learning.** Training a deep neural network is often a pain! Successfully training
17 | such a network usually requires either years of intuition or expensive parameter
18 | searches involving lots of trial and error. Traditional debuggers provide only
19 | limited help: They can find *syntactical errors* but not *training bugs* such as
20 | ill-chosen learning rates. **Cockpit** offers a closer, more meaningful look
21 | into the training process with multiple well-chosen *instruments*.
22 |
23 | ----
24 |
25 | .. image:: _static/showcase.gif
26 |
27 |
28 | To install **Cockpit**, simply run
29 |
30 | .. code:: bash
31 |
32 | pip install cockpit-for-pytorch
33 |
34 |
35 | .. toctree::
36 | :maxdepth: 1
37 | :caption: Getting Started
38 |
39 | examples/01_basic_fmnist
40 | examples/02_advanced_fmnist
41 | examples/03_deepobs
42 | introduction/good_to_know
43 |
44 | .. toctree::
45 | :maxdepth: 1
46 | :caption: API Documentation
47 |
48 | api/cockpit
49 | api/plotter
50 | api/quantities
51 | api/instruments
52 | api/utils
53 |
54 | .. toctree::
55 | :maxdepth: 1
56 | :caption: Other
57 |
58 | GitHub Repository <https://github.com/f-dangel/cockpit>
59 | other/contributors
60 | other/license
61 | other/changelog
62 |
63 |
64 | .. |CI Status| image:: https://github.com/f-dangel/cockpit/actions/workflows/CI.yml/badge.svg
65 | :target: https://github.com/f-dangel/cockpit/actions/workflows/CI.yml
66 | :alt: CI Status
67 |
68 | .. |Lint Status| image:: https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml/badge.svg
69 | :target: https://github.com/f-dangel/cockpit/actions/workflows/Lint.yml
70 | :alt: Lint Status
71 |
72 | .. |Doc Status| image:: https://img.shields.io/readthedocs/cockpit/latest.svg?logo=read%20the%20docs&logoColor=white&label=Doc
73 | :target: https://cockpit.readthedocs.io
74 | :alt: Doc Status
75 |
76 | .. |Coverage| image:: https://coveralls.io/repos/github/f-dangel/cockpit/badge.svg?branch=main&t=piyZHm
77 | :target: https://coveralls.io/github/f-dangel/cockpit?branch=main
78 | :alt: Coverage
79 |
80 | .. |License| image:: https://img.shields.io/badge/License-MIT-green.svg
81 | :target: https://github.com/f-dangel/cockpit/blob/main/LICENSE.txt
82 | :alt: License
83 |
84 | .. |Code Style| image:: https://img.shields.io/badge/code%20style-black-000000.svg
85 | :target: https://github.com/psf/black
86 | :alt: Code Style
87 |
88 | .. |arXiv| image:: https://img.shields.io/static/v1?logo=arxiv&logoColor=white&label=Preprint&message=2102.06604&color=B31B1B
89 | :target: https://arxiv.org/abs/2102.06604
90 | :alt: arXiv
91 |
--------------------------------------------------------------------------------
/docs/source/introduction/good_to_know.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Good to Know
3 | ============
4 |
5 | We try to make Cockpit's usage as easy and convenient as possible. Still, there
6 | are limitations. Here are some common pitfalls and recommendations.
7 |
8 | BackPACK
9 | ########
10 |
11 | Most of Cockpit's quantities use BackPACK_ as the back-end for efficient
12 | computation. Please pay attention to the following points for smooth
13 | integration:
14 |
15 | - Don't forget to `extend the model and loss function
16 | <https://docs.backpack.pt>`_
17 | yourself [1]_ to activate BackPACK_.
18 |
19 | - Verify that your model architecture is `supported by BackPACK
20 | <https://docs.backpack.pt>`_.
21 |
22 | - Your loss function must use ``"mean"`` reduction, i.e. the loss must have the
23 | following structure
24 |
25 | .. math::
26 |
27 | \mathcal{L}(\mathbf{\theta}) = \frac{1}{N} \sum_{n=1}^{N}
28 | \ell(f(\mathbf{x}_n, \mathbf{\theta}), \mathbf{y}_n)\,.
29 |
30 | This avoids an ambiguous scale in individual gradients, which is documented in
31 | `BackPACK's individual gradient extension
32 | <https://docs.backpack.pt>`_.
33 | Otherwise, Cockpit quantities will use incorrectly scaled individual gradients
34 | in their computation.
35 |
36 | It's also a good idea to read through BackPACK's `Good to know
37 | <https://docs.backpack.pt>`_ section.
38 |
39 | Performance
40 | ###########
41 |
42 | Slow run time and memory errors are annoying. Here are some tweaks to reduce run
43 | time and memory consumption:
44 |
45 | - Use schedules to reduce the tracking frequency. You can specify custom
46 | schedules to literally select any iteration to be tracked, or rely on
47 | pre-defined :mod:`~cockpit.utils.schedules`.
48 |
49 | - Exclude :py:class:`GradHist2d <cockpit.quantities.GradHist2d>` from your
50 | quantities. The two-dimensional histogram implementation uses
51 | :py:func:`torch.scatter_add`, which can be slow on GPU due to atomic additions.
52 |
53 | - Exclude :py:class:`HessMaxEV <cockpit.quantities.HessMaxEV>` from your
54 | quantities. It requires multiple Hessian-vector products, which are executed
55 | sequentially, and the full computation graph must be kept in memory.
56 |
57 | - Spot :ref:`quantities ` whose constructor contains a ``curvature``
58 | argument. It defaults to the most accurate, but also most expensive type. You
59 | may want to sacrifice accuracy for memory and run time performance by
60 | selecting a cheaper option.
61 |
62 |
63 | .. [1] Leaving this responsibility to users is a deliberate choice, as Cockpit
64 | does not always need the package. Some specific, albeit very limited,
65 | configurations work without BackPACK_, as they rely only on functionality
66 | built into PyTorch_.
67 |
68 | .. _BackPACK: https://backpack.pt/
69 | .. _PyTorch: https://pytorch.org/
70 |
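
To make the points above concrete, here is a minimal setup sketch (illustrative, not part of these docs): both the model and the ``"mean"``-reduction loss are extended for BackPACK, a separate ``"none"``-reduction loss provides the individual losses, and a coarser tracking schedule keeps the overhead low.

import torch
from backpack import extend

from cockpit import Cockpit
from cockpit.utils import schedules
from cockpit.utils.configuration import configuration

# Extend both the model and the "mean"-reduction loss to activate BackPACK.
model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 2)))
loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
# Individual losses use "none" reduction and need no extending.
individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")

# Tracking only every 10th iteration reduces run time and memory.
cockpit = Cockpit(
    model.parameters(),
    quantities=configuration("full", track_schedule=schedules.linear(10)),
)
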
--------------------------------------------------------------------------------
/docs/source/other/changelog.rst:
--------------------------------------------------------------------------------
1 | .. mdinclude:: ../../../CHANGELOG.md
--------------------------------------------------------------------------------
/docs/source/other/contributors.rst:
--------------------------------------------------------------------------------
1 | .. mdinclude:: ../../../AUTHORS.md
--------------------------------------------------------------------------------
/docs/source/other/license.rst:
--------------------------------------------------------------------------------
1 | .. _license:
2 |
3 | =======
4 | License
5 | =======
6 |
7 | **Cockpit** is published under the MIT License.
8 |
9 | .. literalinclude:: ../../../LICENSE.txt
10 | :language: text
--------------------------------------------------------------------------------
/examples/01_basic_fmnist.py:
--------------------------------------------------------------------------------
1 | """A basic example of using Cockpit with PyTorch for Fashion-MNIST."""
2 |
3 | import torch
4 | from _utils_examples import fmnist_data
5 | from backpack import extend
6 |
7 | from cockpit import Cockpit, CockpitPlotter
8 | from cockpit.utils.configuration import configuration
9 |
10 | # Build Fashion-MNIST classifier
11 | fmnist_data = fmnist_data()
12 | model = extend(torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10)))
13 | loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
14 | individual_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
15 |
16 | # Create SGD Optimizer
17 | opt = torch.optim.SGD(model.parameters(), lr=1e-2)
18 |
19 | # Create Cockpit and a plotter
20 | cockpit = Cockpit(model.parameters(), quantities=configuration("full"))
21 | plotter = CockpitPlotter()
22 |
23 | # Main training loop
24 | max_steps, global_step = 5, 0
25 | for inputs, labels in iter(fmnist_data):
26 | opt.zero_grad()
27 |
28 | # forward pass
29 | outputs = model(inputs)
30 | loss = loss_fn(outputs, labels)
31 | losses = individual_loss_fn(outputs, labels)
32 |
33 | # backward pass
34 | with cockpit(
35 | global_step,
36 | info={
37 | "batch_size": inputs.shape[0],
38 | "individual_losses": losses,
39 | "loss": loss,
40 | "optimizer": opt,
41 | },
42 | ):
43 | loss.backward(create_graph=cockpit.create_graph(global_step))
44 |
45 | # optimizer step
46 | opt.step()
47 | global_step += 1
48 |
49 | print(f"Step: {global_step:5d} | Loss: {loss.item():.4f}")
50 |
51 | plotter.plot(cockpit)
52 |
53 | if global_step >= max_steps:
54 | break
55 |
56 | plotter.plot(cockpit, block=True)
57 |
--------------------------------------------------------------------------------
/examples/02_advanced_fmnist.py:
--------------------------------------------------------------------------------
1 | """A slightly advanced example of using Cockpit with PyTorch for Fashion-MNIST."""
2 |
3 | import torch
4 | from _utils_examples import cnn, fmnist_data, get_logpath
5 | from backpack import extend, extensions
6 |
7 | from cockpit import Cockpit, CockpitPlotter, quantities
8 | from cockpit.utils import schedules
9 |
10 | # Build Fashion-MNIST classifier
11 | fmnist_data = fmnist_data()
12 | model = extend(cnn()) # Use a basic convolutional network
13 | loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="mean"))
14 | individual_loss_fn = extend(torch.nn.CrossEntropyLoss(reduction="none"))
15 |
16 | # Create SGD Optimizer
17 | opt = torch.optim.SGD(model.parameters(), lr=5e-1)
18 |
19 | # Create Cockpit and a plotter
20 | # Customize the tracked quantities and their tracking schedule
21 | quantities = [
22 | quantities.GradNorm(schedules.linear(interval=1)),
23 | quantities.Distance(schedules.linear(interval=1)),
24 | quantities.UpdateSize(schedules.linear(interval=1)),
25 | quantities.HessMaxEV(schedules.linear(interval=3)),
26 | quantities.GradHist1d(schedules.linear(interval=10), bins=10),
27 | ]
28 | cockpit = Cockpit(model.parameters(), quantities=quantities)
29 | plotter = CockpitPlotter()
30 |
31 | # Main training loop
32 | max_steps, global_step = 50, 0
33 | for inputs, labels in iter(fmnist_data):
34 | opt.zero_grad()
35 |
36 | # forward pass
37 | outputs = model(inputs)
38 | loss = loss_fn(outputs, labels)
39 | losses = individual_loss_fn(outputs, labels)
40 |
41 | # backward pass
42 | with cockpit(
43 | global_step,
44 | extensions.DiagHessian(), # Other BackPACK quantities can be computed as well
45 | info={
46 | "batch_size": inputs.shape[0],
47 | "individual_losses": losses,
48 | "loss": loss,
49 | "optimizer": opt,
50 | },
51 | ):
52 | loss.backward(create_graph=cockpit.create_graph(global_step))
53 |
54 | # optimizer step
55 | opt.step()
56 | global_step += 1
57 |
58 | print(f"Step: {global_step:5d} | Loss: {loss.item():.4f}")
59 |
60 | if global_step % 10 == 0:
61 | plotter.plot(
62 | cockpit,
63 | savedir=get_logpath(),
64 | show_plot=False,
65 | save_plot=True,
66 | savename_append=str(global_step),
67 | )
68 |
69 | if global_step >= max_steps:
70 | break
71 |
72 | # Write Cockpit to json file.
73 | cockpit.write(get_logpath())
74 |
75 | # Plot results from file
76 | plotter.plot(
77 | get_logpath(),
78 | savedir=get_logpath(),
79 | show_plot=False,
80 | save_plot=True,
81 | savename_append="_final",
82 | )
83 |
--------------------------------------------------------------------------------
/examples/03_deepobs.py:
--------------------------------------------------------------------------------
1 | """An example of using Cockpit with DeepOBS."""
2 |
3 | from _utils_deepobs import DeepOBSRunner
4 | from _utils_examples import get_logpath
5 | from torch.optim import SGD
6 |
7 | from cockpit.utils import configuration, schedules
8 |
9 | optimizer = SGD
10 | hyperparams = {"lr": {"type": float, "default": 0.001}}
11 |
12 | track_schedule = schedules.linear(10)
13 | plot_schedule = schedules.linear(20)
14 | quantities = configuration.configuration("full", track_schedule=track_schedule)
15 |
16 | runner = DeepOBSRunner(optimizer, hyperparams, quantities, plot_schedule=plot_schedule)
17 |
18 |
19 | def const_schedule(num_epochs):
20 | """Constant learning rate schedule."""
21 | return lambda epoch: 1.0
22 |
23 |
24 | runner.run(
25 | testproblem="quadratic_deep",
26 | output_dir=get_logpath(),
27 | l2_reg=0.0,
28 | num_epochs=50,
29 | show_plots=True,
30 | save_plots=False,
31 | save_final_plot=True,
32 | save_animation=False,
33 | lr_schedule=const_schedule,
34 | )
35 |
--------------------------------------------------------------------------------
/examples/_utils_examples.py:
--------------------------------------------------------------------------------
1 | """Utility functions for the basic PyTorch example."""
2 |
3 | import os
4 | import warnings
5 |
6 | import torch
7 | import torchvision
8 |
9 | HERE = os.path.abspath(__file__)
10 | HEREDIR = os.path.dirname(HERE)
11 | EXAMPLESDIR = os.path.dirname(HEREDIR)
12 |
13 | # Ignore the PyTorch warning that is irrelevant for us
14 | warnings.filterwarnings("ignore", message="Using a non-full backward hook ")
15 |
16 |
17 | def fmnist_data(batch_size=64, shuffle=True):
18 | """Returns a dataloader for Fashion-MNIST.
19 |
20 | Args:
21 | batch_size (int, optional): Batch size. Defaults to 64.
22 | shuffle (bool, optional): Whether the data should be shuffled. Defaults to True.
23 |
24 | Returns:
25 | torch.utils.data.DataLoader: Dataloader for Fashion-MNIST data.
26 | """
27 | # Additionally set the random seed for reproducibility
28 | torch.manual_seed(0)
29 |
30 | fmnist_dataset = torchvision.datasets.FashionMNIST(
31 | root=os.path.join(EXAMPLESDIR, "data"),
32 | train=True,
33 | transform=torchvision.transforms.Compose(
34 | [
35 | torchvision.transforms.ToTensor(),
36 | torchvision.transforms.Normalize((0.1307,), (0.3081,)),
37 | ]
38 | ),
39 | download=True,
40 | )
41 |
42 | return torch.utils.data.dataloader.DataLoader(
43 | fmnist_dataset,
44 | batch_size=batch_size,
45 | shuffle=shuffle,
46 | )
47 |
48 |
49 | def cnn():
50 | """Basic Conv-Net for (Fashion-)MNIST."""
51 | return torch.nn.Sequential(
52 | torch.nn.Conv2d(1, 11, 5),
53 | torch.nn.ReLU(),
54 | torch.nn.MaxPool2d(2, 2),
55 | torch.nn.Conv2d(11, 16, 5),
56 | torch.nn.ReLU(),
57 | torch.nn.MaxPool2d(2, 2),
58 | torch.nn.Flatten(),
59 | torch.nn.Linear(16 * 4 * 4, 80),
60 | torch.nn.ReLU(),
61 | torch.nn.Linear(80, 84),
62 | torch.nn.ReLU(),
63 | torch.nn.Linear(84, 10),
64 | )
65 |
66 |
67 | def get_logpath(suffix=""):
68 | """Create a logpath and return it.
69 |
70 | Args:
71 | suffix (str, optional): suffix to add to the output. Defaults to "".
72 |
73 | Returns:
74 | str: Path to the logfile (output of Cockpit).
75 | """
76 | save_dir = os.path.join(EXAMPLESDIR, "logfiles")
77 | os.makedirs(save_dir, exist_ok=True)
78 | log_path = os.path.join(save_dir, f"cockpit_output{suffix}")
79 | return log_path
80 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | .PHONY: help
2 | .PHONY: black black-check flake8
3 | .PHONY: test
4 | .PHONY: conda-env
5 | .PHONY: black isort format
6 | .PHONY: black-check isort-check format-check code-standard-check
7 | .PHONY: flake8
8 | .PHONY: pydocstyle-check
9 | .PHONY: darglint-check
10 | .PHONY: build-docs
11 | .PHONY: clean-all
12 |
13 | .DEFAULT: help
14 | help:
15 | @echo "test"
16 | @echo " Run pytest on the project and report coverage"
17 | @echo "black"
18 | @echo " Run black on the project"
19 | @echo "black-check"
20 | @echo " Check if black would change files"
21 | @echo "flake8"
22 | @echo " Run flake8 on the project"
23 | @echo "pydocstyle-check"
24 | @echo " Run pydocstyle on the project"
25 | @echo "darglint-check"
26 | @echo " Run darglint on the project"
27 | @echo "code-standard-check"
28 | @echo " Run all linters on the project to check quality standards."
29 | @echo "conda-env"
30 | @echo " Create conda environment 'cockpit' with dev setup"
31 | @echo "build-docs"
32 | @echo " Build the docs"
33 | @echo "clean-all"
34 | @echo " Removes all unnecessary files."
35 |
36 | ### TESTING ###
37 | # Run pytest with the Matplotlib backend agg to not show plots
38 | test:
39 | @MPLBACKEND=agg pytest -vx --cov=cockpit .
40 |
41 | ### LINTING & FORMATTING ###
42 |
43 | # Uses black.toml config instead of pyproject.toml to avoid pip issues. See
44 | # - https://github.com/psf/black/issues/683
45 | # - https://github.com/pypa/pip/pull/6370
46 | # - https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
47 | black:
48 | @black . --config=black.toml
49 |
50 | black-check:
51 | @black . --config=black.toml --check
52 |
53 | flake8:
54 | @flake8 .
55 |
56 | pydocstyle-check:
57 | @pydocstyle --count .
58 |
59 | darglint-check:
60 | @darglint --verbosity 2 .
61 |
62 | isort:
63 | @isort .
64 |
65 | isort-check:
66 | @isort . --check
67 |
68 | format:
69 | @make black
70 | @make isort
71 | @make black-check
72 |
73 | format-check: black-check isort-check pydocstyle-check darglint-check
74 |
75 | code-standard-check:
76 | @make black
77 | @make isort
78 | @make black-check
79 | @make flake8
80 | @make pydocstyle-check
81 |
82 | ### CONDA ###
83 | conda-env:
84 | @conda env create --file .conda_env.yml
85 |
86 | ### DOCS ###
87 | build-docs:
88 | @find . -type d -name "automod" -exec rm -r {} +
89 | @cd docs && make clean && make html
90 |
91 | ### CLEAN ###
92 | clean-all:
93 | @find . -name '*.pyc' -delete
94 | @find . -name '*.pyo' -delete
95 | @find . -name '*~' -delete
96 | @find . -type d -name "__pycache__" -delete
97 | @rm -fr .pytest_cache/
98 | @rm -fr .benchmarks/
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = cockpit-for-pytorch
3 | url = https://github.com/f-dangel/cockpit
4 | description = A Practical Debugging Tool for Training Deep Neural Networks.
5 | long_description = file: README.md, CHANGELOG.md, LICENSE.txt
6 | long_description_content_type = text/markdown
7 | author = Frank Schneider and Felix Dangel
8 | author_email = f.schneider@uni-tuebingen.de
9 | license = MIT
10 | keywords = deep-learning, machine-learning, debugging
11 | platforms = any
12 | classifiers =
13 | Development Status :: 4 - Beta
14 | License :: OSI Approved :: MIT License
15 | Operating System :: OS Independent
16 | Programming Language :: Python :: 3.7
17 | Programming Language :: Python :: 3.8
18 | Programming Language :: Python :: 3.9
19 |
20 | [options]
21 | # Define which packages are required to run
22 | install_requires =
23 | json-tricks
24 | matplotlib>=3.4.0
25 | numpy
26 | pandas
27 | scipy
28 | seaborn
29 | torch
30 | backpack-for-pytorch>=1.3.0
31 | zip_safe = False
32 | packages = find:
33 | python_requires = >=3.7
34 | setup_requires =
35 | setuptools_scm
36 |
37 | [options.packages.find]
38 | exclude = test*
39 |
40 | [flake8]
41 | # Configure flake8 linting, minor differences to pytorch.
42 | select = B,C,E,F,P,W,B9
43 | max-line-length = 80
44 | max-complexity = 10
45 | ignore =
46 | # replaced by B950 (max-line-length + 10%)
47 | E501, # max-line-length
48 | # ignored because pytorch uses dict
49 | C408, # use {} instead of dict()
50 | # Not Black-compatible
51 | E203, # whitespace before :
52 | E231, # missing whitespace after ','
53 | W291, # trailing whitespace
54 | W503, # line break before binary operator
55 | W504, # line break after binary operator
56 | exclude = docs,docs_src,build,.git,src,tex
57 |
58 | [pydocstyle]
59 | convention = google
60 | # exclude directories, see
61 | # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088
62 | match_dir = ^(?!(docs|docs_src|build|.git|src|exp)).*
63 | match = .*\.py
64 |
65 | [isort]
66 | profile=black
67 |
68 | [darglint]
69 | docstring_style = google
70 | # short, long, full
71 | strictness = short
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup file for python_package_template using setup.cfg for configuration."""
2 | import sys
3 |
4 | from pkg_resources import VersionConflict, require
5 | from setuptools import setup
6 |
7 | # Use setup.cfg if possible
8 | try:
9 | require("setuptools>=38.3")
10 | except VersionConflict:
11 | print("Error: version of setuptools is too old (<38.3)!")
12 | sys.exit(1)
13 |
14 |
15 | if __name__ == "__main__":
16 | setup(use_scm_version=True)
17 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``cockpit``."""
2 |
--------------------------------------------------------------------------------
/tests/settings.py:
--------------------------------------------------------------------------------
1 | """Problem settings shared across all submodules."""
2 |
3 |
4 | SETTINGS = []
5 |
--------------------------------------------------------------------------------
/tests/test_bugs/test_issue5.py:
--------------------------------------------------------------------------------
1 | """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/5."""
2 |
3 | from backpack import extend
4 | from torch import manual_seed, rand
5 | from torch.nn import Flatten, Linear, MSELoss, Sequential
6 | from torch.optim import Adam
7 |
8 | from cockpit import Cockpit
9 | from cockpit.quantities import Alpha, GradHist1d
10 | from cockpit.utils.schedules import linear
11 |
12 |
13 | def test_BatchGradTransformsHook_deletes_attribute_required_by_Alpha():
14 | """If the optimizer is not SGD, ``Alpha`` needs access to ``.grad_batch``.
15 |
16 | But if an extension that uses ``BatchGradTransformsHook`` is used at the same time,
17 | it will delete the ``grad_batch`` attribute during the backward pass. Consequently,
18 | ``Alpha`` cannot access the attribute anymore. This leads to the error.
19 | """
20 | manual_seed(0)
21 |
22 | N, D_in, D_out = 2, 3, 1
23 | model = extend(Sequential(Flatten(), Linear(D_in, D_out)))
24 |
25 | opt_not_sgd = Adam(model.parameters(), lr=1e-3)
26 | loss_fn = extend(MSELoss(reduction="mean"))
27 | individual_loss_fn = MSELoss(reduction="none")
28 |
29 | on_first = linear(1)
30 | alpha = Alpha(on_first)
31 | uses_BatchGradTransformsHook = GradHist1d(on_first)
32 |
33 | cockpit = Cockpit(
34 | model.parameters(), quantities=[alpha, uses_BatchGradTransformsHook]
35 | )
36 |
37 | global_step = 0
38 | inputs, labels = rand(N, D_in), rand(N, D_out)
39 |
40 | # forward pass
41 | outputs = model(inputs)
42 | loss = loss_fn(outputs, labels)
43 | losses = individual_loss_fn(outputs, labels)
44 |
45 | # backward pass
46 | with cockpit(
47 | global_step,
48 | info={
49 | "batch_size": N,
50 | "individual_losses": losses,
51 | "loss": loss,
52 | "optimizer": opt_not_sgd,
53 | },
54 | ):
55 | loss.backward(create_graph=cockpit.create_graph(global_step))
56 |
--------------------------------------------------------------------------------
/tests/test_bugs/test_issue6.py:
--------------------------------------------------------------------------------
1 | """Reproduces the bug described in https://github.com/f-dangel/cockpit/issues/6."""
2 |
3 | from backpack import extend
4 | from torch import Tensor, manual_seed, rand
5 | from torch.nn import Linear, Module, MSELoss, ReLU
6 | from torch.optim import SGD
7 |
8 | from cockpit import Cockpit
9 | from cockpit.quantities import GradHist1d
10 | from cockpit.utils.schedules import linear
11 |
12 |
13 | def test_extension_hook_executes_on_custom_module():
14 | """Cockpit's extension hook is only skipped for known containers like Sequential.
15 |
16 | It will thus execute on custom containers and lead to crashes whenever a quantity
17 | that uses extension hooks is used.
18 | """
19 | manual_seed(0)
20 | N, D_in, D_out = 2, 3, 1
21 |
22 | # NOTE Inheriting from Sequential passes
23 | class CustomModule(Module):
24 | """Custom container that is not skipped by the extension hook."""
25 |
26 | def __init__(self):
27 | super().__init__()
28 | self.linear = Linear(D_in, D_out)
29 | self.relu = ReLU()
30 |
31 | def forward(self, x: Tensor) -> Tensor:
32 | return self.relu(self.linear(x))
33 |
34 | uses_extension_hook = GradHist1d(linear(interval=1))
35 | config = [uses_extension_hook]
36 |
37 | model = extend(CustomModule())
38 | cockpit = Cockpit(model.parameters(), quantities=config)
39 |
40 | opt = SGD(model.parameters(), lr=0.1)
41 |
42 | loss_fn = extend(MSELoss(reduction="mean"))
43 | individual_loss_fn = MSELoss(reduction="none")
44 |
45 | global_step = 0
46 | inputs, labels = rand(N, D_in), rand(N, D_out)
47 |
48 | # forward pass
49 | outputs = model(inputs)
50 | loss = loss_fn(outputs, labels)
51 | losses = individual_loss_fn(outputs, labels)
52 |
53 | # backward pass
54 | with cockpit(
55 | global_step,
56 | info={
57 | "batch_size": N,
58 | "individual_losses": losses,
59 | "loss": loss,
60 | "optimizer": opt,
61 | },
62 | ):
63 | loss.backward(create_graph=cockpit.create_graph(global_step))
64 |
--------------------------------------------------------------------------------
/tests/test_cockpit/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``cockpit.Cockpit``."""
2 |
--------------------------------------------------------------------------------
/tests/test_cockpit/settings.py:
--------------------------------------------------------------------------------
1 | """Settings used by the tests in this submodule."""
2 |
3 | import torch
4 |
5 | from tests.settings import SETTINGS as GLOBAL_SETTINGS
6 | from tests.utils.data import load_toy_data
7 | from tests.utils.models import load_toy_model
8 | from tests.utils.problem import make_problems_with_ids
9 |
10 | LOCAL_SETTINGS = [
11 | {
12 | "data_fn": lambda: load_toy_data(batch_size=5),
13 | "model_fn": load_toy_model,
14 | "individual_loss_function_fn": lambda: torch.nn.CrossEntropyLoss(
15 | reduction="none"
16 | ),
17 | "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
18 | "iterations": 1,
19 | },
20 | ]
21 | SETTINGS = GLOBAL_SETTINGS + LOCAL_SETTINGS
22 |
23 | PROBLEMS, PROBLEMS_IDS = make_problems_with_ids(SETTINGS)
24 |
--------------------------------------------------------------------------------
/tests/test_cockpit/test_automatic_call_track.py:
--------------------------------------------------------------------------------
1 | """Check that ``track`` is called when leaving the ``cockpit`` context."""
2 |
3 | import pytest
4 |
5 | from cockpit import quantities
6 | from cockpit.utils.schedules import linear
7 | from tests.test_cockpit.settings import PROBLEMS, PROBLEMS_IDS
8 | from tests.utils.harness import SimpleTestHarness
9 | from tests.utils.problem import instantiate
10 |
11 |
12 | # TODO Reconsider purpose of this test
13 | class CustomTestHarness(SimpleTestHarness):
14 | """Custom Test Harness checking that track gets called when leaving the context."""
15 |
16 | def check_in_context(self):
17 | """Check that track has not been called yet."""
18 | assert self.problem.iterations == 1, "Test only checks the first step"
19 | global_step = 0
20 | assert global_step not in self.cockpit.quantities[0].output.keys()
21 |
22 | def check_after_context(self):
23 | """Verify that track has been called after the context is left."""
24 | assert self.problem.iterations == 1, "Test only checks the first step"
25 | global_step = 0
26 | assert global_step in self.cockpit.quantities[0].output.keys()
27 |
28 |
29 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
30 | def test_track_writes_output(problem):
31 | """Test that a ``cockpit``'s ``track`` function writes to the output."""
32 | quantity = quantities.Time(track_schedule=linear(1))
33 |
34 | with instantiate(problem):
35 | testing_harness = CustomTestHarness(problem)
36 | cockpit_kwargs = {"quantities": [quantity]}
37 | testing_harness.test(cockpit_kwargs)
38 |
--------------------------------------------------------------------------------
/tests/test_cockpit/test_backpack_extensions.py:
--------------------------------------------------------------------------------
1 | """Test if BackPACK quantities other than that required by cockpit can be computed."""
2 |
3 | import pytest
4 | from backpack.extensions import DiagHessian
5 |
6 | from cockpit import quantities
7 | from cockpit.utils.schedules import linear
8 | from tests.test_cockpit.settings import PROBLEMS, PROBLEMS_IDS
9 | from tests.utils.harness import SimpleTestHarness
10 | from tests.utils.problem import instantiate
11 |
12 |
13 | class CustomTestHarness(SimpleTestHarness):
14 | """Create a Custom Test Harness that checks whether the BackPACK buffers exist."""
15 |
16 | def check_in_context(self):
17 | """Check that the BackPACK buffers exists in the context."""
18 | for param in self.problem.model.parameters():
19 | # required by TICDiag and user
20 | assert hasattr(param, "diag_h")
21 | # required by TICDiag only
22 | assert hasattr(param, "grad_batch_transforms")
23 | assert "sum_grad_squared" in param.grad_batch_transforms
24 |
25 | def check_after_context(self):
26 | """Check that the buffers are not deleted when specified by the user."""
27 | for param in self.problem.model.parameters():
28 | assert hasattr(param, "diag_h")
29 | # not protected by user
30 | assert not hasattr(param, "grad_batch_transforms")
31 |
32 |
33 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
34 | def test_backpack_extensions(problem):
35 | """Check if backpack quantities can be computed inside cockpit."""
36 | quantity = quantities.TICDiag(track_schedule=linear(1))
37 |
38 | with instantiate(problem):
39 | testing_harness = CustomTestHarness(problem)
40 | cockpit_kwargs = {"quantities": [quantity]}
41 | testing_harness.test(cockpit_kwargs, DiagHessian())
42 |
--------------------------------------------------------------------------------
/tests/test_cockpit/test_multiple_batch_grad_transforms.py:
--------------------------------------------------------------------------------
1 | """Tests for using multiple batch grad transforms in Cockpit."""
2 |
3 | import pytest
4 |
5 | from cockpit.cockpit import Cockpit
6 | from cockpit.quantities.utils_transforms import BatchGradTransformsHook
7 |
8 |
9 | def test_merge_batch_grad_transforms():
10 | """Test merging of multiple ``BatchGradTransforms``."""
11 | bgt1 = BatchGradTransformsHook({"x": lambda t: t, "y": lambda t: t})
12 | bgt2 = BatchGradTransformsHook({"v": lambda t: t, "w": lambda t: t})
13 |
14 | merged_bgt = Cockpit._merge_batch_grad_transform_hooks([bgt1, bgt2])
15 | assert isinstance(merged_bgt, BatchGradTransformsHook)
16 |
17 | merged_keys = ["x", "y", "v", "w"]
18 | assert len(merged_bgt._transforms.keys()) == len(merged_keys)
19 |
20 | for key in merged_keys:
21 | assert key in merged_bgt._transforms.keys()
22 |
23 | assert id(bgt1._transforms["x"]) == id(merged_bgt._transforms["x"])
24 | assert id(bgt2._transforms["w"]) == id(merged_bgt._transforms["w"])
25 |
26 |
27 | def test_merge_batch_grad_transforms_same_key_different_trafo():
28 | """Merging ``BatchGradTransforms`` with same key but different trafo should fail."""
29 | bgt1 = BatchGradTransformsHook({"x": lambda t: t, "y": lambda t: t})
30 | bgt2 = BatchGradTransformsHook({"x": lambda t: t, "w": lambda t: t})
31 |
32 | with pytest.raises(ValueError):
33 | _ = Cockpit._merge_batch_grad_transform_hooks([bgt1, bgt2])
34 |
35 |
36 | def test_merge_batch_grad_transforms_same_key_same_trafo():
37 | """Test merging multiple ``BatchGradTransforms`` with same key and same trafo."""
38 |
39 | def func(t):
40 | return t
41 |
42 | bgt1 = BatchGradTransformsHook({"x": func})
43 | bgt2 = BatchGradTransformsHook({"x": func})
44 |
45 | merged = Cockpit._merge_batch_grad_transform_hooks([bgt1, bgt2])
46 |
47 | assert len(merged._transforms.keys()) == 1
48 | assert id(merged._transforms["x"]) == id(func)
49 |
--------------------------------------------------------------------------------
/tests/test_examples/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``examples``."""
2 |
--------------------------------------------------------------------------------
/tests/test_examples/test_examples.py:
--------------------------------------------------------------------------------
1 | """Test whether the example scripts run."""
2 |
3 | import os
4 | import pathlib
5 | import runpy
6 | import sys
7 |
8 | import pytest
9 |
10 | SCRIPTS = sorted(pathlib.Path(__file__, "../../..", "examples").resolve().glob("*.py"))
11 | SCRIPTS_STR, SCRIPTS_ID = [], []
12 | for s in SCRIPTS:
13 | if not str(s).split("/")[-1].startswith("_"):
14 | SCRIPTS_STR.append(str(s))
15 | SCRIPTS_ID.append(str(s).split("/")[-1].split(".")[0])
16 |
17 |
18 | @pytest.mark.parametrize("script", SCRIPTS_STR, ids=SCRIPTS_ID)
19 | def test_example_scripts(script):
20 | """Run a single example script.
21 |
22 | Args:
23 | script (str): Script that should be run.
24 | """
25 | sys.path.append(os.path.dirname(script))
26 | del sys.argv[1:] # Clear CLI arguments from pytest
27 | runpy.run_path(str(script))
28 |
--------------------------------------------------------------------------------
/tests/test_quantities/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``cockpit.quantities``."""
2 |
--------------------------------------------------------------------------------
/tests/test_quantities/adam_settings.py:
--------------------------------------------------------------------------------
1 | """Problem settings using Adam as optimizer.
2 |
3 | Some quantities are only defined for zero-momentum SGD (``CABS``, ``EarlyStopping``),
4 | or use a different computation strategy (``Alpha``). This behavior needs to be tested.
5 | """
6 |
7 | import torch
8 |
9 | from tests.utils.data import load_toy_data
10 | from tests.utils.models import load_toy_model
11 | from tests.utils.problem import make_problems_with_ids
12 |
13 | ADAM_SETTINGS = [
14 | {
15 | "data_fn": lambda: load_toy_data(batch_size=4),
16 | "model_fn": load_toy_model,
17 | "individual_loss_function_fn": lambda: torch.nn.CrossEntropyLoss(
18 | reduction="none"
19 | ),
20 | "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
21 | "iterations": 5,
22 | "optimizer_fn": lambda parameters: torch.optim.Adam(parameters, lr=0.01),
23 | },
24 | ]
25 |
26 | ADAM_PROBLEMS, ADAM_IDS = make_problems_with_ids(ADAM_SETTINGS)
27 |
--------------------------------------------------------------------------------
/tests/test_quantities/settings.py:
--------------------------------------------------------------------------------
1 | """Settings used by the tests in this submodule."""
2 |
3 | import torch
4 |
5 | from cockpit.utils.schedules import linear, logarithmic
6 | from tests.settings import SETTINGS as GLOBAL_SETTINGS
7 | from tests.utils.data import load_toy_data
8 | from tests.utils.models import load_toy_model
9 | from tests.utils.problem import make_problems_with_ids
10 |
11 | LOCAL_SETTINGS = [
12 | {
13 | "data_fn": lambda: load_toy_data(batch_size=5),
14 | "model_fn": load_toy_model,
15 | "individual_loss_function_fn": lambda: torch.nn.CrossEntropyLoss(
16 | reduction="none"
17 | ),
18 | "loss_function_fn": lambda: torch.nn.CrossEntropyLoss(reduction="mean"),
19 | "iterations": 5,
20 | },
21 | ]
22 |
23 | SETTINGS = GLOBAL_SETTINGS + LOCAL_SETTINGS
24 |
25 | PROBLEMS, PROBLEMS_IDS = make_problems_with_ids(SETTINGS)
26 |
27 | INDEPENDENT_RUNS = [True, False]
28 | INDEPENDENT_RUNS_IDS = [f"independent_runs={run}" for run in INDEPENDENT_RUNS]
29 |
30 | CPU_PROBLEMS = []
31 | CPU_PROBLEMS_ID = []
32 | for problem, problem_id in zip(PROBLEMS, PROBLEMS_IDS):
33 | if "cpu" in str(problem.device):
34 | CPU_PROBLEMS.append(problem)
35 | CPU_PROBLEMS_ID.append(problem_id)
36 |
37 | QUANTITY_KWARGS = [
38 | {
39 | "track_schedule": linear(interval=1, offset=2), # [1, 3, 5, ...]
40 | "verbose": True,
41 | },
42 | {
43 | "track_schedule": logarithmic(
44 | start=0, end=1, steps=4, init=False
45 | ), # [1, 2, 4, 10]
46 | "verbose": True,
47 | },
48 | ]
49 | QUANTITY_KWARGS_IDS = [f"q_kwargs={q_kwargs}" for q_kwargs in QUANTITY_KWARGS]
50 |
--------------------------------------------------------------------------------
/tests/test_quantities/test_cabs.py:
--------------------------------------------------------------------------------
1 | """Compare ``CABS`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 |
5 | from cockpit.context import get_individual_losses, get_optimizer
6 | from cockpit.quantities import CABS
7 | from cockpit.utils.optim import ComputeStep
8 | from tests.test_quantities.adam_settings import ADAM_IDS, ADAM_PROBLEMS
9 | from tests.test_quantities.settings import (
10 | INDEPENDENT_RUNS,
11 | INDEPENDENT_RUNS_IDS,
12 | PROBLEMS,
13 | PROBLEMS_IDS,
14 | QUANTITY_KWARGS,
15 | QUANTITY_KWARGS_IDS,
16 | )
17 | from tests.test_quantities.utils import autograd_diagonal_variance, get_compare_fn
18 |
19 |
20 | class AutogradCABS(CABS):
21 | """``torch.autograd`` implementation of ``CABS``."""
22 |
23 | def extensions(self, global_step):
24 | """Return list of BackPACK extensions required for the computation.
25 |
26 | Args:
27 | global_step (int): The current iteration number.
28 |
29 | Returns:
30 | list: (Potentially empty) list with required BackPACK quantities.
31 | """
32 | return []
33 |
34 | def extension_hooks(self, global_step):
35 | """Return list of BackPACK extension hooks required for the computation.
36 |
37 | Args:
38 | global_step (int): The current iteration number.
39 |
40 | Returns:
41 | [callable]: List of required BackPACK extension hooks for the current
42 | iteration.
43 | """
44 | return []
45 |
46 | def create_graph(self, global_step):
47 | """Return whether access to the forward pass computation graph is needed.
48 |
49 | Args:
50 | global_step (int): The current iteration number.
51 |
52 | Returns:
53 | bool: ``True`` if the computation graph shall not be deleted,
54 | else ``False``.
55 | """
56 | return self.should_compute(global_step)
57 |
58 | def _compute(self, global_step, params, batch_loss):
59 | """Evaluate the CABS criterion.
60 |
61 | Args:
62 | global_step (int): The current iteration number.
63 | params ([torch.Tensor]): List of torch.Tensors holding the network's
64 | parameters.
65 | batch_loss (torch.Tensor): Mini-batch loss from current step.
66 |
67 | Returns:
68 | float: Evaluated CABS criterion.
69 |
70 | Raises:
71 | ValueError: If the optimizer differs from SGD with default arguments.
72 | """
73 | optimizer = get_optimizer(global_step)
74 | if not ComputeStep.is_sgd_default_kwargs(optimizer):
75 | raise ValueError("This criterion only supports zero-momentum SGD.")
76 |
77 | losses = get_individual_losses(global_step)
78 | batch_axis = 0
79 | trace_variance = autograd_diagonal_variance(
80 | losses, params, concat=True, unbiased=False
81 | ).sum(batch_axis)
82 | lr = self.get_lr(optimizer)
83 |
84 | return (lr * trace_variance / batch_loss).item()
85 |
86 |
87 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
88 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
89 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
90 | def test_cabs(problem, independent_runs, q_kwargs):
91 | """Compare BackPACK and ``torch.autograd`` implementation of CABS.
92 |
93 | Args:
94 | problem (tests.utils.Problem): Settings for train loop.
95 | independent_runs (bool): Whether to use two separate runs to compute the
96 | output of every quantity.
97 | q_kwargs (dict): Keyword arguments handed over to both quantities.
98 | """
99 | compare_fn = get_compare_fn(independent_runs)
100 | compare_fn(problem, (CABS, AutogradCABS), q_kwargs)
101 |
102 |
103 | @pytest.mark.parametrize("problem", ADAM_PROBLEMS, ids=ADAM_IDS)
104 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
105 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
106 | def test_cabs_no_adam(problem, independent_runs, q_kwargs):
107 | """Verify Adam is not supported by CABS criterion.
108 |
109 | Args:
110 | problem (tests.utils.Problem): Settings for train loop.
111 | independent_runs (bool): Whether to use two separate runs to compute the
112 | output of every quantity.
113 | q_kwargs (dict): Keyword arguments handed over to both quantities.
114 | """
115 | with pytest.raises(ValueError):
116 | test_cabs(problem, independent_runs, q_kwargs)
117 |
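For reference, the CABS criterion can be read directly off ``AutogradCABS._compute`` above (notation introduced here: ``g_{n,j}`` is the j-th coordinate of the n-th individual gradient, ``B`` the batch size, ``eta`` the learning rate, ``L`` the mini-batch loss; the variance is the biased one, matching ``unbiased=False``):

\[
\text{CABS} = \frac{\eta}{L} \sum_{j} \operatorname{Var}_n\left[ g_{n,j} \right],
\qquad
\operatorname{Var}_n\left[ g_{n,j} \right] = \frac{1}{B} \sum_{n=1}^{B} g_{n,j}^2 - \left( \frac{1}{B} \sum_{n=1}^{B} g_{n,j} \right)^{\!2} .
\]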
--------------------------------------------------------------------------------
/tests/test_quantities/test_early_stopping.py:
--------------------------------------------------------------------------------
1 | """Compare ``EarlyStopping`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 | import torch
5 |
6 | from cockpit.context import get_batch_size, get_individual_losses
7 | from cockpit.quantities import EarlyStopping
8 | from tests.test_quantities.adam_settings import ADAM_IDS, ADAM_PROBLEMS
9 | from tests.test_quantities.settings import (
10 | INDEPENDENT_RUNS,
11 | INDEPENDENT_RUNS_IDS,
12 | PROBLEMS,
13 | PROBLEMS_IDS,
14 | QUANTITY_KWARGS,
15 | QUANTITY_KWARGS_IDS,
16 | )
17 | from tests.test_quantities.utils import autograd_diagonal_variance, get_compare_fn
18 |
19 |
20 | class AutogradEarlyStopping(EarlyStopping):
21 | """``torch.autograd`` implementation of ``EarlyStopping``."""
22 |
23 | def extensions(self, global_step):
24 | """Return list of BackPACK extensions required for the computation.
25 |
26 | Args:
27 | global_step (int): The current iteration number.
28 |
29 | Returns:
30 | list: (Potentially empty) list with required BackPACK quantities.
31 | """
32 | return []
33 |
34 | def extension_hooks(self, global_step):
35 | """Return list of BackPACK extension hooks required for the computation.
36 |
37 | Args:
38 | global_step (int): The current iteration number.
39 |
40 | Returns:
41 | [callable]: List of required BackPACK extension hooks for the current
42 | iteration.
43 | """
44 | return []
45 |
46 | def create_graph(self, global_step):
47 | """Return whether access to the forward pass computation graph is needed.
48 |
49 | Args:
50 | global_step (int): The current iteration number.
51 |
52 | Returns:
53 | bool: ``True`` if the computation graph shall not be deleted,
54 | else ``False``.
55 | """
56 | return self.should_compute(global_step)
57 |
58 | def _compute(self, global_step, params, batch_loss):
59 | """Evaluate the early stopping criterion.
60 |
61 | Args:
62 | global_step (int): The current iteration number.
63 | params ([torch.Tensor]): List of torch.Tensors holding the network's
64 | parameters.
65 | batch_loss (torch.Tensor): Mini-batch loss from current step.
66 |
67 | Returns:
68 | float: Early stopping criterion.
69 | """
70 | grad_squared = torch.cat([p.grad.flatten() for p in params]) ** 2
71 |
72 | losses = get_individual_losses(global_step)
73 | diag_variance = autograd_diagonal_variance(losses, params, concat=True)
74 |
75 | B = get_batch_size(global_step)
76 |
77 | return 1 - B * (grad_squared / (diag_variance + self._epsilon)).mean().item()
78 |
79 |
80 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
81 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
82 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
83 | def test_early_stopping(problem, independent_runs, q_kwargs):
84 | """Compare BackPACK and ``torch.autograd`` implementation of EarlyStopping.
85 |
86 | Args:
87 | problem (tests.utils.Problem): Settings for train loop.
88 | independent_runs (bool): Whether to use two separate runs to compute the
89 | output of every quantity.
90 | q_kwargs (dict): Keyword arguments handed over to both quantities.
91 | """
92 | rtol, atol = 5e-3, 1e-5
93 |
94 | compare_fn = get_compare_fn(independent_runs)
95 | compare_fn(
96 | problem, (EarlyStopping, AutogradEarlyStopping), q_kwargs, rtol=rtol, atol=atol
97 | )
98 |
99 |
100 | @pytest.mark.parametrize("problem", ADAM_PROBLEMS, ids=ADAM_IDS)
101 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
102 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
103 | def test_early_stopping_no_adam(problem, independent_runs, q_kwargs):
104 | """Verify Adam is not supported by EarlyStopping criterion.
105 |
106 | Args:
107 | problem (tests.utils.Problem): Settings for train loop.
108 | independent_runs (bool): Whether to use two separate runs to compute the
109 | output of every quantity.
110 | q_kwargs (dict): Keyword arguments handed over to both quantities.
111 | """
112 | with pytest.raises(ValueError):
113 | test_early_stopping(problem, independent_runs, q_kwargs)
114 |
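Reading the criterion off ``AutogradEarlyStopping._compute`` above, with ``g`` the mean mini-batch gradient, ``D`` its number of entries, ``B`` the batch size, and ``epsilon`` the stabilization constant ``self._epsilon``:

\[
\text{EarlyStopping} = 1 - B \cdot \frac{1}{D} \sum_{j=1}^{D} \frac{g_j^2}{\operatorname{Var}_n\left[ g_{n,j} \right] + \varepsilon} .
\]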
--------------------------------------------------------------------------------
/tests/test_quantities/test_hess_max_ev.py:
--------------------------------------------------------------------------------
1 | """Compare ``HessMaxEV`` quantity with ``torch.autograd``."""
2 |
3 | import warnings
4 |
5 | import pytest
6 |
7 | from cockpit.quantities import HessMaxEV
8 | from tests.test_quantities.settings import (
9 | INDEPENDENT_RUNS,
10 | INDEPENDENT_RUNS_IDS,
11 | PROBLEMS,
12 | PROBLEMS_IDS,
13 | QUANTITY_KWARGS,
14 | QUANTITY_KWARGS_IDS,
15 | )
16 | from tests.test_quantities.utils import (
17 | autograd_hessian_maximum_eigenvalue,
18 | get_compare_fn,
19 | )
20 |
21 |
22 | class AutogradHessMaxEV(HessMaxEV):
23 | """``torch.autograd`` implementation of ``HessMaxEV``.
24 |
25 | Requires storing the full Hessian in memory and can hence only be applied to small
26 | networks.
27 | """
28 |
29 | def _compute(self, global_step, params, batch_loss):
30 | """Evaluate the maximum mini-batch loss Hessian eigenvalue.
31 |
32 | Args:
33 | global_step (int): The current iteration number.
34 | params ([torch.Tensor]): List of torch.Tensors holding the network's
35 | parameters.
36 | batch_loss (torch.Tensor): Mini-batch loss from current step.
37 |
38 | Returns:
39 | float: Maximum Hessian eigenvalue (of the mini-batch loss).
40 | """
41 | self._maybe_warn_dimension(sum(p.numel() for p in params))
42 |
43 | return autograd_hessian_maximum_eigenvalue(batch_loss, params).item()
44 |
45 | @staticmethod
46 | def _maybe_warn_dimension(dim):
47 | """Warn user if the Hessian is large."""
48 | MAX_DIM = 1000
49 |
50 | if dim >= MAX_DIM:
51 | warnings.warn(f"Computing Hessians of size ({dim}, {dim}) is expensive")
52 |
53 |
54 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
55 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
56 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
57 | def test_hess_max_ev(problem, independent_runs, q_kwargs):
58 | """Compare BackPACK and ``torch.autograd`` implementation of Hessian max eigenvalue.
59 |
60 | Args:
61 | problem (tests.utils.Problem): Settings for train loop.
62 | independent_runs (bool): Whether to use two separate runs to compute the
63 | output of every quantity.
64 | q_kwargs (dict): Keyword arguments handed over to both quantities.
65 | """
66 | atol, rtol = 1e-4, 1e-2
67 |
68 | compare_fn = get_compare_fn(independent_runs)
69 | compare_fn(problem, (HessMaxEV, AutogradHessMaxEV), q_kwargs, rtol=rtol, atol=atol)
70 |
--------------------------------------------------------------------------------
/tests/test_quantities/test_hess_trace.py:
--------------------------------------------------------------------------------
1 | """Compare ``HessTrace`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 |
5 | from cockpit.quantities import HessTrace
6 | from tests.test_quantities.settings import (
7 | INDEPENDENT_RUNS,
8 | INDEPENDENT_RUNS_IDS,
9 | PROBLEMS,
10 | PROBLEMS_IDS,
11 | QUANTITY_KWARGS,
12 | QUANTITY_KWARGS_IDS,
13 | )
14 | from tests.test_quantities.utils import autograd_diag_hessian, get_compare_fn
15 |
16 |
17 | class AutogradHessTrace(HessTrace):
18 | """``torch.autograd`` implementation of ``HessTrace``."""
19 |
20 | def extensions(self, global_step):
21 | """Return list of BackPACK extensions required for the computation.
22 |
23 | Args:
24 | global_step (int): The current iteration number.
25 |
26 | Returns:
27 | list: (Potentially empty) list with required BackPACK quantities.
28 | """
29 | return []
30 |
31 | def create_graph(self, global_step):
32 | """Return whether access to the forward pass computation graph is needed.
33 |
34 | Args:
35 | global_step (int): The current iteration number.
36 |
37 | Returns:
38 | bool: ``True`` if the computation graph shall not be deleted,
39 | else ``False``.
40 | """
41 | return self.should_compute(global_step)
42 |
43 | def _compute(self, global_step, params, batch_loss):
44 | """Evaluate the trace of the Hessian at the current point.
45 |
46 | Args:
47 | global_step (int): The current iteration number.
48 | params ([torch.Tensor]): List of torch.Tensors holding the network's
49 | parameters.
50 | batch_loss (torch.Tensor): Mini-batch loss from current step.
51 |
52 | Returns:
53 | list: Traces of the Hessian by layer.
54 | """
55 | return [
56 | diag_h.sum().item() for diag_h in autograd_diag_hessian(batch_loss, params)
57 | ]
58 |
59 |
60 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
61 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
62 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
63 | def test_hess_trace(problem, independent_runs, q_kwargs):
64 | """Compare BackPACK and ``torch.autograd`` implementation of Hessian trace.
65 |
66 | Args:
67 | problem (tests.utils.Problem): Settings for train loop.
68 | independent_runs (bool): Whether to use two separate runs to compute the
69 | output of every quantity.
70 | q_kwargs (dict): Keyword arguments handed over to both quantities.
71 | """
72 | compare_fn = get_compare_fn(independent_runs)
73 | compare_fn(problem, (HessTrace, AutogradHessTrace), q_kwargs)
74 |
--------------------------------------------------------------------------------
/tests/test_quantities/test_inner_test.py:
--------------------------------------------------------------------------------
1 | """Compare ``InnerTest`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 | import torch
5 |
6 | from cockpit.context import get_batch_size, get_individual_losses
7 | from cockpit.quantities import InnerTest
8 | from tests.test_quantities.settings import (
9 | INDEPENDENT_RUNS,
10 | INDEPENDENT_RUNS_IDS,
11 | PROBLEMS,
12 | PROBLEMS_IDS,
13 | QUANTITY_KWARGS,
14 | QUANTITY_KWARGS_IDS,
15 | )
16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn
17 |
18 |
19 | class AutogradInnerTest(InnerTest):
20 | """``torch.autograd`` implementation of ``InnerTest``."""
21 |
22 | def extensions(self, global_step):
23 | """Return list of BackPACK extensions required for the computation.
24 |
25 | Args:
26 | global_step (int): The current iteration number.
27 |
28 | Returns:
29 | list: (Potentially empty) list with required BackPACK quantities.
30 | """
31 | return []
32 |
33 | def extension_hooks(self, global_step):
34 | """Return list of BackPACK extension hooks required for the computation.
35 |
36 | Args:
37 | global_step (int): The current iteration number.
38 |
39 | Returns:
40 | [callable]: List of required BackPACK extension hooks for the current
41 | iteration.
42 | """
43 | return []
44 |
45 | def create_graph(self, global_step):
46 | """Return whether access to the forward pass computation graph is needed.
47 |
48 | Args:
49 | global_step (int): The current iteration number.
50 |
51 | Returns:
52 | bool: ``True`` if the computation graph shall not be deleted,
53 | else ``False``.
54 | """
55 | return self.should_compute(global_step)
56 |
57 | def _compute(self, global_step, params, batch_loss):
58 | """Evaluate the inner-product test.
59 |
60 | Args:
61 | global_step (int): The current iteration number.
62 | params ([torch.Tensor]): List of torch.Tensors holding the network's
63 | parameters.
64 | batch_loss (torch.Tensor): Mini-batch loss from current step.
65 |
66 | Returns:
67 | float: Result of the inner-product test.
68 | """
69 | losses = get_individual_losses(global_step)
70 | individual_gradients_flat = autograd_individual_gradients(
71 | losses, params, concat=True
72 | )
73 | grad = torch.cat([p.grad.flatten() for p in params])
74 |
75 | projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)
76 | grad_norm = grad.norm()
77 |
78 | N_axis = 0
79 | batch_size = get_batch_size(global_step)
80 |
81 | return (
82 | (
83 | 1
84 | / (batch_size * (batch_size - 1))
85 | * ((projections ** 2).sum(N_axis) / grad_norm ** 4 - batch_size)
86 | )
87 | .sqrt()
88 | .item()
89 | )
90 |
91 |
92 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
93 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
94 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
95 | def test_inner_test(problem, independent_runs, q_kwargs):
96 | """Compare BackPACK and ``torch.autograd`` implementation of InnerTest.
97 |
98 | Args:
99 | problem (tests.utils.Problem): Settings for train loop.
100 | independent_runs (bool): Whether to use two separate runs to compute the
101 | output of every quantity.
102 | q_kwargs (dict): Keyword arguments handed over to both quantities.
103 | """
104 | compare_fn = get_compare_fn(independent_runs)
105 | compare_fn(problem, (InnerTest, AutogradInnerTest), q_kwargs)
106 |
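In formulas, ``AutogradInnerTest._compute`` above evaluates (with ``g_n`` the n-th individual gradient, ``g`` the mean mini-batch gradient, and ``B`` the batch size):

\[
\sqrt{ \frac{1}{B (B - 1)} \left( \frac{\sum_{n=1}^{B} (g_n^\top g)^2}{\lVert g \rVert^4} - B \right) } .
\]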
--------------------------------------------------------------------------------
/tests/test_quantities/test_mean_gsnr.py:
--------------------------------------------------------------------------------
1 | """Compare ``MeanGSNR`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 | import torch
5 |
6 | from cockpit.context import get_individual_losses
7 | from cockpit.quantities import MeanGSNR
8 | from tests.test_quantities.settings import (
9 | INDEPENDENT_RUNS,
10 | INDEPENDENT_RUNS_IDS,
11 | PROBLEMS,
12 | PROBLEMS_IDS,
13 | QUANTITY_KWARGS,
14 | QUANTITY_KWARGS_IDS,
15 | )
16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn
17 |
18 |
19 | class AutogradMeanGSNR(MeanGSNR):
20 | """``torch.autograd`` implementation of ``MeanGSNR``."""
21 |
22 | def extensions(self, global_step):
23 | """Return list of BackPACK extensions required for the computation.
24 |
25 | Args:
26 | global_step (int): The current iteration number.
27 |
28 | Returns:
29 | list: (Potentially empty) list with required BackPACK quantities.
30 | """
31 | return []
32 |
33 | def extension_hooks(self, global_step):
34 | """Return list of BackPACK extension hooks required for the computation.
35 |
36 | Args:
37 | global_step (int): The current iteration number.
38 |
39 | Returns:
40 | [callable]: List of required BackPACK extension hooks for the current
41 | iteration.
42 | """
43 | return []
44 |
45 | def create_graph(self, global_step):
46 | """Return whether access to the forward pass computation graph is needed.
47 |
48 | Args:
49 | global_step (int): The current iteration number.
50 |
51 | Returns:
52 | bool: ``True`` if the computation graph shall not be deleted,
53 | else ``False``.
54 | """
55 | return self.should_compute(global_step)
56 |
57 | def _compute(self, global_step, params, batch_loss):
58 | """Evaluate the MeanGSNR.
59 |
60 | Args:
61 | global_step (int): The current iteration number.
62 | params ([torch.Tensor]): List of torch.Tensors holding the network's
63 | parameters.
64 | batch_loss (torch.Tensor): Mini-batch loss from current step.
65 |
66 | Returns:
67 | float: Mean GSNR of the current iteration.
68 | """
69 | losses = get_individual_losses(global_step)
70 | individual_gradients_flat = autograd_individual_gradients(
71 | losses, params, concat=True
72 | )
73 |
74 | grad_squared = torch.cat([p.grad.flatten() for p in params]) ** 2
75 |
76 | N_axis = 0
77 | second_moment_flat = (individual_gradients_flat ** 2).mean(N_axis)
78 |
79 | gsnr = grad_squared / (second_moment_flat - grad_squared + self._epsilon)
80 |
81 | return gsnr.mean().item()
82 |
83 |
84 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
85 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
86 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
87 | def test_mean_gsnr(problem, independent_runs, q_kwargs):
88 | """Compare BackPACK and ``torch.autograd`` implementation of MeanGSNR.
89 |
90 | Args:
91 | problem (tests.utils.Problem): Settings for train loop.
92 | independent_runs (bool): Whether to use two separate runs to compute the
93 | output of every quantity.
94 | q_kwargs (dict): Keyword arguments handed over to both quantities.
95 | """
96 | rtol, atol = 5e-3, 1e-5
97 |
98 | compare_fn = get_compare_fn(independent_runs)
99 | compare_fn(problem, (MeanGSNR, AutogradMeanGSNR), q_kwargs, rtol=rtol, atol=atol)
100 |
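In formulas, ``AutogradMeanGSNR._compute`` above evaluates (with ``g`` the mean mini-batch gradient, ``D`` its number of entries, ``B`` the batch size, and ``epsilon`` = ``self._epsilon``):

\[
\text{MeanGSNR} = \frac{1}{D} \sum_{j=1}^{D} \frac{g_j^2}{\frac{1}{B} \sum_{n=1}^{B} g_{n,j}^2 - g_j^2 + \varepsilon} .
\]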
--------------------------------------------------------------------------------
/tests/test_quantities/test_norm_test.py:
--------------------------------------------------------------------------------
1 | """Compare ``NormTest`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 | import torch
5 |
6 | from cockpit.context import get_batch_size, get_individual_losses
7 | from cockpit.quantities import NormTest
8 | from tests.test_quantities.settings import (
9 | INDEPENDENT_RUNS,
10 | INDEPENDENT_RUNS_IDS,
11 | PROBLEMS,
12 | PROBLEMS_IDS,
13 | QUANTITY_KWARGS,
14 | QUANTITY_KWARGS_IDS,
15 | )
16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn
17 |
18 |
19 | class AutogradNormTest(NormTest):
20 | """``torch.autograd`` implementation of ``NormTest``."""
21 |
22 | def extensions(self, global_step):
23 | """Return list of BackPACK extensions required for the computation.
24 |
25 | Args:
26 | global_step (int): The current iteration number.
27 |
28 | Returns:
29 | list: (Potentially empty) list with required BackPACK quantities.
30 | """
31 | return []
32 |
33 | def extension_hooks(self, global_step):
34 | """Return list of BackPACK extension hooks required for the computation.
35 |
36 | Args:
37 | global_step (int): The current iteration number.
38 |
39 | Returns:
40 | [callable]: List of required BackPACK extension hooks for the current
41 | iteration.
42 | """
43 | return []
44 |
45 | def create_graph(self, global_step):
46 | """Return whether access to the forward pass computation graph is needed.
47 |
48 | Args:
49 | global_step (int): The current iteration number.
50 |
51 | Returns:
52 | bool: ``True`` if the computation graph shall not be deleted,
53 | else ``False``.
54 | """
55 | return self.should_compute(global_step)
56 |
57 | def _compute(self, global_step, params, batch_loss):
58 | """Evaluate the norm test.
59 |
60 | Args:
61 | global_step (int): The current iteration number.
62 | params ([torch.Tensor]): List of torch.Tensors holding the network's
63 | parameters.
64 | batch_loss (torch.Tensor): Mini-batch loss from current step.
65 |
66 | Returns:
67 | float: Result of the norm test.
68 | """
69 | losses = get_individual_losses(global_step)
70 | individual_gradients_flat = autograd_individual_gradients(
71 | losses, params, concat=True
72 | )
73 | sum_of_squares = (individual_gradients_flat ** 2).sum()
74 |
75 | grad_norm = torch.cat([p.grad.flatten() for p in params]).norm()
76 |
77 | batch_size = get_batch_size(global_step)
78 |
79 | return (
80 | (
81 | 1
82 | / (batch_size * (batch_size - 1))
83 | * (sum_of_squares / grad_norm ** 2 - batch_size)
84 | )
85 | .sqrt()
86 | .item()
87 | )
88 |
89 |
90 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
91 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
92 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
93 | def test_norm_test(problem, independent_runs, q_kwargs):
94 | """Compare BackPACK and ``torch.autograd`` implementation of NormTest.
95 |
96 | Args:
97 | problem (tests.utils.Problem): Settings for train loop.
98 | independent_runs (bool): Whether to use two separate runs to compute the
99 | output of every quantity.
100 | q_kwargs (dict): Keyword arguments handed over to both quantities.
101 | """
102 | compare_fn = get_compare_fn(independent_runs)
103 | compare_fn(problem, (NormTest, AutogradNormTest), q_kwargs)
104 |
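In formulas, ``AutogradNormTest._compute`` above evaluates (same notation as the inner-product test: ``g_n`` individual gradients, ``g`` mean gradient, ``B`` batch size):

\[
\sqrt{ \frac{1}{B (B - 1)} \left( \frac{\sum_{n=1}^{B} \lVert g_n \rVert^2}{\lVert g \rVert^2} - B \right) } .
\]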
--------------------------------------------------------------------------------
/tests/test_quantities/test_ortho_test.py:
--------------------------------------------------------------------------------
1 | """Compare ``OrthoTest`` quantity with ``torch.autograd``."""
2 |
3 | import pytest
4 | import torch
5 |
6 | from cockpit.context import get_batch_size, get_individual_losses
7 | from cockpit.quantities import OrthoTest
8 | from tests.test_quantities.settings import (
9 | INDEPENDENT_RUNS,
10 | INDEPENDENT_RUNS_IDS,
11 | PROBLEMS,
12 | PROBLEMS_IDS,
13 | QUANTITY_KWARGS,
14 | QUANTITY_KWARGS_IDS,
15 | )
16 | from tests.test_quantities.utils import autograd_individual_gradients, get_compare_fn
17 |
18 |
19 | class AutogradOrthoTest(OrthoTest):
20 | """``torch.autograd`` implementation of ``OrthoTest``."""
21 |
22 | def extensions(self, global_step):
23 | """Return list of BackPACK extensions required for the computation.
24 |
25 | Args:
26 | global_step (int): The current iteration number.
27 |
28 | Returns:
29 | list: (Potentially empty) list with required BackPACK quantities.
30 | """
31 | return []
32 |
33 | def extension_hooks(self, global_step):
34 | """Return list of BackPACK extension hooks required for the computation.
35 |
36 | Args:
37 | global_step (int): The current iteration number.
38 |
39 | Returns:
40 | [callable]: List of required BackPACK extension hooks for the current
41 | iteration.
42 | """
43 | return []
44 |
45 | def create_graph(self, global_step):
46 | """Return whether access to the forward pass computation graph is needed.
47 |
48 | Args:
49 | global_step (int): The current iteration number.
50 |
51 | Returns:
52 | bool: ``True`` if the computation graph shall not be deleted,
53 | else ``False``.
54 | """
55 | return self.should_compute(global_step)
56 |
57 | def _compute(self, global_step, params, batch_loss):
58 | """Evaluate the norm test.
59 |
60 | Args:
61 | global_step (int): The current iteration number.
62 | params ([torch.Tensor]): List of torch.Tensors holding the network's
63 | parameters.
64 | batch_loss (torch.Tensor): Mini-batch loss from current step.
65 |
66 | Returns:
67 | float: Result of the orthogonality test.
68 | """
69 | losses = get_individual_losses(global_step)
70 | individual_gradients_flat = autograd_individual_gradients(
71 | losses, params, concat=True
72 | )
73 | D_axis = 1
74 | individual_l2_norms_squared = (individual_gradients_flat ** 2).sum(D_axis)
75 |
76 | grad = torch.cat([p.grad.flatten() for p in params])
77 | grad_norm = grad.norm()
78 |
79 | projections = torch.einsum("ni,i->n", individual_gradients_flat, grad)
80 |
81 | batch_size = get_batch_size(global_step)
82 |
83 | return (
84 | (
85 | 1
86 | / (batch_size * (batch_size - 1))
87 | * (
88 | individual_l2_norms_squared / grad_norm ** 2
89 | - (projections ** 2) / grad_norm ** 4
90 | ).sum()
91 | )
92 | .sqrt()
93 | .item()
94 | )
95 |
96 |
97 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
98 | @pytest.mark.parametrize("independent_runs", INDEPENDENT_RUNS, ids=INDEPENDENT_RUNS_IDS)
99 | @pytest.mark.parametrize("q_kwargs", QUANTITY_KWARGS, ids=QUANTITY_KWARGS_IDS)
100 | def test_ortho_test(problem, independent_runs, q_kwargs):
101 | """Compare BackPACK and ``torch.autograd`` implementation of OrthoTest.
102 |
103 | Args:
104 | problem (tests.utils.Problem): Settings for train loop.
105 | independent_runs (bool): Whether to use two separate runs to compute the
106 | output of every quantity.
107 | q_kwargs (dict): Keyword arguments handed over to both quantities.
108 | """
109 | compare_fn = get_compare_fn(independent_runs)
110 | compare_fn(problem, (OrthoTest, AutogradOrthoTest), q_kwargs)
111 |
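In formulas, ``AutogradOrthoTest._compute`` above evaluates:

\[
\sqrt{ \frac{1}{B (B - 1)} \sum_{n=1}^{B} \left( \frac{\lVert g_n \rVert^2}{\lVert g \rVert^2} - \frac{(g_n^\top g)^2}{\lVert g \rVert^4} \right) } .
\]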
--------------------------------------------------------------------------------
/tests/test_quantities/test_quantity_integration.py:
--------------------------------------------------------------------------------
1 | """Intergation tests for ``cockpit.quantities``."""
2 |
3 | import pytest
4 |
5 | from cockpit import quantities
6 | from cockpit.quantities import __all__
7 | from cockpit.quantities.quantity import SingleStepQuantity, TwoStepQuantity
8 | from cockpit.utils.schedules import linear
9 | from tests.test_quantities.settings import PROBLEMS, PROBLEMS_IDS
10 | from tests.utils.harness import SimpleTestHarness
11 | from tests.utils.problem import instantiate
12 |
13 | QUANTITIES = [
14 | getattr(quantities, q)
15 | for q in __all__
16 | if q != "Quantity"
17 | and q != "SingleStepQuantity"
18 | and q != "TwoStepQuantity"
19 | and q != "ByproductQuantity"
20 | ]
21 | IDS = [q_cls.__name__ for q_cls in QUANTITIES]
22 |
23 |
24 | @pytest.mark.parametrize("problem", PROBLEMS, ids=PROBLEMS_IDS)
25 | @pytest.mark.parametrize("quantity_cls", QUANTITIES, ids=IDS)
26 | def test_quantity_integration_and_track_events(problem, quantity_cls):
27 | """Check if ``Cockpit`` with a single quantity works.
28 |
29 | Args:
30 | problem (tests.utils.Problem): Settings for train loop.
31 | quantity_cls (Class): Quantity class that should be tested.
32 | """
33 | interval, offset = 1, 2
34 | schedule = linear(interval, offset=offset)
35 | quantity = quantity_cls(track_schedule=schedule, verbose=True)
36 |
37 | with instantiate(problem):
38 | iterations = problem.iterations
39 | testing_harness = SimpleTestHarness(problem)
40 | cockpit_kwargs = {"quantities": [quantity]}
41 | testing_harness.test(cockpit_kwargs)
42 |
43 | def is_track_event(iteration):
44 | if isinstance(quantity, SingleStepQuantity):
45 | return schedule(iteration)
46 | elif isinstance(quantity, TwoStepQuantity):
47 | end_iter = quantity.SAVE_SHIFT + iteration
48 | return quantity.is_end(end_iter) and end_iter < iterations
49 | else:
50 | raise ValueError(f"Unknown quantity: {quantity}")
51 |
52 | track_events = sorted(i for i in range(iterations) if is_track_event(i))
53 | output_events = sorted(quantity.get_output().keys())
54 |
55 | assert output_events == track_events
56 |
--------------------------------------------------------------------------------
/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for ``cockpit.utils``."""
2 |
--------------------------------------------------------------------------------
/tests/test_utils/test_configurations.py:
--------------------------------------------------------------------------------
1 | """Tests for ``cockpit.utils.configuration``."""
2 |
3 | import pytest
4 |
5 | from cockpit import quantities
6 | from cockpit.utils.configuration import quantities_cls_for_configuration
7 |
8 |
9 | @pytest.mark.parametrize("label", ["full", "business", "economy"])
10 | def test_quantities_cls_for_configuration(label):
11 | """Check cockpit configurations contain the correct quantities."""
12 | economy = [
13 | quantities.Alpha,
14 | quantities.GradNorm,
15 | quantities.UpdateSize,
16 | quantities.Distance,
17 | quantities.InnerTest,
18 | quantities.OrthoTest,
19 | quantities.NormTest,
20 | quantities.GradHist1d,
21 | quantities.Loss,
22 | ]
23 | business = economy + [quantities.TICDiag, quantities.HessTrace]
24 | full = business + [quantities.HessMaxEV, quantities.GradHist2d]
25 |
26 | configs = {
27 | "full": set(full),
28 | "business": set(business),
29 | "economy": set(economy),
30 | }
31 |
32 | quants = set(quantities_cls_for_configuration(label))
33 |
34 | assert quants == configs[label]
35 |
--------------------------------------------------------------------------------
/tests/test_utils/test_schedules.py:
--------------------------------------------------------------------------------
1 | """Test for ``cockpit.utils.schedules``."""
2 |
3 | import pytest
4 |
5 | from cockpit.utils import schedules
6 |
7 | MAX_STEP = 100
8 |
9 |
10 | @pytest.mark.parametrize("interval", [1, 2, 3])
11 | @pytest.mark.parametrize("offset", [0, 1, 2, -1])
12 | def test_linear_schedule(interval, offset):
13 | """Check linear schedule.
14 |
15 | Args:
16 | interval (int): The regular tracking interval.
17 | offset (int, optional): Offset of tracking. Defaults to 0.
18 | """
19 | schedule = schedules.linear(interval, offset)
20 | tracking = []
21 | for i in range(MAX_STEP):
22 | tracking.append(schedule(i))
23 |
24 | # If offset is negative, start from the first tracked step
25 | if offset < 0:
26 | offset = interval + offset
27 |
28 | # check that all steps that should be tracked are true
29 | assert all(tracking[offset::interval])
30 | # Check that everything else is false
31 | assert sum(tracking[offset::interval]) == sum(tracking)
32 |
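For illustration, a minimal sketch of the tracking pattern this test verifies (``linear(interval, offset)`` fires at ``offset, offset + interval, offset + 2 * interval, ...``):

from cockpit.utils.schedules import linear

schedule = linear(interval=2, offset=1)
# Steps at which tracking fires within the first 10 iterations.
print([step for step in range(10) if schedule(step)])  # [1, 3, 5, 7, 9]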
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility functions for ``Cockpit`` tests."""
2 |
--------------------------------------------------------------------------------
/tests/utils/check.py:
--------------------------------------------------------------------------------
1 | """Utility functions to compare results of two implementations."""
2 |
3 | import numpy
4 |
5 |
6 | def compare_outputs(output1, output2, rtol=1e-5, atol=1e-7):
7 | """Compare outputs of two quantities."""
8 | assert len(list(output1.keys())) == len(
9 | list(output2.keys())
10 | ), "Different number of entries"
11 |
12 | for key in output1.keys():
13 | if isinstance(output1[key], dict):
14 | compare_outputs(output1[key], output2[key])
15 | else:
16 | val1, val2 = output1[key], output2[key]
17 |
18 | compare_fn = get_compare_function(val1, val2)
19 |
20 | compare_fn(val1, val2, atol=atol, rtol=rtol)
21 |
22 |
23 | def get_compare_function(value1, value2):
24 | """Return the function used to compare ``value1`` with ``value2``."""
25 | if isinstance(value1, float) and isinstance(value2, float):
26 | compare_fn = compare_floats
27 | elif isinstance(value1, int) and isinstance(value2, int):
28 | compare_fn = compare_ints
29 | elif isinstance(value1, numpy.ndarray) and isinstance(value2, numpy.ndarray):
30 | compare_fn = compare_arrays
31 | elif isinstance(value1, list) and isinstance(value2, list):
32 | compare_fn = compare_lists
33 | elif isinstance(value1, tuple) and isinstance(value2, tuple):
34 | compare_fn = compare_tuples
35 | else:
36 | raise NotImplementedError(
37 | "No comparison available for these data types: "
38 | + f"{type(value1)}, {type(value2)}."
39 | )
40 |
41 | return compare_fn
42 |
43 |
44 | def compare_tuples(tuple1, tuple2, rtol=1e-5, atol=1e-7):
45 | """Compare two tuples."""
46 | assert len(tuple1) == len(tuple2), "Different number of entries"
47 |
48 | for value1, value2 in zip(tuple1, tuple2):
49 | compare_fn = get_compare_function(value1, value2)
50 | compare_fn(value1, value2, rtol=rtol, atol=atol)
51 |
52 |
53 | def compare_arrays(array1, array2, rtol=1e-5, atol=1e-7):
54 | """Compare two ``numpy`` arrays."""
55 | assert numpy.allclose(array1, array2, rtol=rtol, atol=atol)
56 |
57 |
58 | def compare_floats(float1, float2, rtol=1e-5, atol=1e-7):
59 | """Compare two floats."""
60 | assert numpy.isclose(float1, float2, atol=atol, rtol=rtol), f"{float1} ≠ {float2}"
61 |
62 |
63 | def compare_ints(int1, int2, rtol=None, atol=None):
64 | """Compare two integers.
65 |
66 | ``rtol`` and ``atol`` are ignored in the comparison, but required to keep the
67 | interface identical among comparison functions.
68 |
69 | Args:
70 | int1 (int): First integer.
71 | int2 (int): Another integer.
72 | rtol (any): Ignored, see above.
73 | atol (any): Ignored, see above.
74 | """
75 | assert int1 == int2
76 |
77 |
78 | def compare_lists(list1, list2, rtol=1e-5, atol=1e-7):
79 | """Compare two lists containing floats."""
80 | assert len(list1) == len(
81 | list2
82 | ), f"Lists don't match in size: {len(list1)} ≠ {len(list2)}"
83 |
84 | for val1, val2 in zip(list1, list2):
85 | compare_fn = get_compare_function(val1, val2)
86 | compare_fn(val1, val2, rtol=rtol, atol=atol)
87 |
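A minimal usage sketch of the dispatch above (hypothetical values; nested dicts recurse, and each leaf is compared by its typed comparator):

from tests.utils.check import compare_outputs

output1 = {0: 1.0, 1: [2.0, 3.0], 2: (4, 5)}
output2 = {0: 1.0 + 1e-8, 1: [2.0, 3.0], 2: (4, 5)}
# Floats must agree within the default tolerances; ints must match exactly.
compare_outputs(output1, output2)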
--------------------------------------------------------------------------------
/tests/utils/data.py:
--------------------------------------------------------------------------------
1 | """Utility functions to create toy input for Cockpit's tests."""
2 |
3 | import torch
4 | from torch.utils.data.dataloader import DataLoader
5 | from torch.utils.data.dataset import Dataset
6 |
7 |
8 | def load_toy_data(batch_size):
9 | """Build a ``DataLoader`` with specified batch size from the toy data."""
10 | return DataLoader(ToyData(), batch_size=batch_size)
11 |
12 |
13 | class ToyData(Dataset):
14 | """Toy data set used for testing. Consists of small random "images" and labels."""
15 |
16 | def __init__(self, center=0.1):
17 | """Init the toy data set.
18 |
19 | Args:
20 | center (float): Center around which the data is randomly distributed.
21 | """
22 | super(ToyData, self).__init__()
23 | self._center = center
24 |
25 | def __getitem__(self, index):
26 | """Return item with index `index` of data set.
27 |
28 | Args:
29 | index (int): Index of sample to access. Ignored for now.
30 |
31 | Returns:
32 | [tuple]: Tuple of (random) input and (random) label.
33 | """
34 | item_input = torch.rand(1, 5, 5) + self._center
35 | item_label = torch.randint(size=(), low=0, high=3)
36 | return (item_input, item_label)
37 |
38 | def __len__(self):
39 | """Length of dataset. Arbitrarily set to 10 000."""
40 | return 10000 # arbitrary number of samples in the data set
41 |
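A short sketch of the shapes this loader produces (inputs are random 1x5x5 "images", labels are class indices in {0, 1, 2}):

from tests.utils.data import load_toy_data

loader = load_toy_data(batch_size=5)
inputs, labels = next(iter(loader))
print(inputs.shape)  # torch.Size([5, 1, 5, 5])
print(labels.shape)  # torch.Size([5])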
--------------------------------------------------------------------------------
/tests/utils/harness.py:
--------------------------------------------------------------------------------
1 | """Base class for executing and hooking into a training loop to execute checks."""
2 |
3 |
4 | from backpack import extend
5 |
6 | from cockpit import Cockpit
7 | from tests.utils.rand import restore_rng_state
8 |
9 |
10 | class SimpleTestHarness:
11 | """Class for running a simple test loop with the Cockpit.
12 |
13 | Args:
14 | problem (string): The (instantiated) problem to test on.
15 | """
16 |
17 | def __init__(self, problem):
18 | """Store the instantiated problem."""
19 | self.problem = problem
20 |
21 | def test(self, cockpit_kwargs, *backpack_exts):
22 | """Run the test loop.
23 |
24 | Args:
25 | cockpit_kwargs (dict): Arguments for the cockpit.
26 | *backpack_exts (list): List of user-defined BackPACK extensions.
27 | """
28 | problem = self.problem
29 |
30 | data = problem.data
31 | device = problem.device
32 | iterations = problem.iterations
33 |
34 | # Extend
35 | model = extend(problem.model)
36 | loss_fn = extend(problem.loss_function)
37 | individual_loss_fn = extend(problem.individual_loss_function)
38 |
39 | # Create Optimizer
40 | optimizer = problem.optimizer
41 |
42 | # Initialize Cockpit
43 | self.cockpit = Cockpit(model.parameters(), **cockpit_kwargs)
44 |
45 | # print(cockpit_exts)
46 |
47 | # Main training loop
48 | global_step = 0
49 | for inputs, labels in iter(data):
50 | inputs, labels = inputs.to(device), labels.to(device)
51 | optimizer.zero_grad()
52 |
53 | # forward pass
54 | outputs = model(inputs)
55 | loss = loss_fn(outputs, labels)
56 | losses = individual_loss_fn(outputs, labels)
57 |
58 | # code inside this block does not alter random number generation
59 | with restore_rng_state():
60 | # backward pass
61 | with self.cockpit(
62 | global_step,
63 | *backpack_exts,
64 | info={
65 | "batch_size": inputs.shape[0],
66 | "individual_losses": losses,
67 | "loss": loss,
68 | "optimizer": optimizer,
69 | },
70 | ):
71 | loss.backward(create_graph=self.cockpit.create_graph(global_step))
72 | self.check_in_context()
73 |
74 | self.check_after_context()
75 |
76 | # optimizer step
77 | optimizer.step()
78 | global_step += 1
79 |
80 | if global_step >= iterations:
81 | break
82 |
83 | def check_in_context(self):
84 | """Check that will be executed within the cockpit context."""
85 | pass
86 |
87 | def check_after_context(self):
88 | """Check that will be executed directly after the cockpit context."""
89 | pass
90 |
--------------------------------------------------------------------------------
/tests/utils/models.py:
--------------------------------------------------------------------------------
1 | """Utility functions to create toy models for Cockpit's tests."""
2 |
3 | import torch
4 |
5 |
6 | def load_toy_model():
7 | """Build a tor model that can be used in conjunction with ``ToyData``."""
8 | return torch.nn.Sequential(
9 | torch.nn.Flatten(),
10 | torch.nn.Linear(25, 4),
11 | torch.nn.ReLU(),
12 | torch.nn.Linear(4, 4),
13 | torch.nn.ReLU(),
14 | torch.nn.Linear(4, 3),
15 | )
16 |
--------------------------------------------------------------------------------
/tests/utils/rand.py:
--------------------------------------------------------------------------------
1 | """Utility functions for random number generation."""
2 |
3 | import torch
4 |
5 |
6 | class restore_rng_state:
7 | """Restores PyTorch seed to the value of initialization.
8 |
9 | This has the effect that code inside this context does not influence the outer
10 | loop's random generator state.
11 | """
12 |
13 | def __init__(self):
14 | """Store the current PyTorch seed."""
15 | self._rng_state = torch.get_rng_state()
16 |
17 | def __enter__(self):
18 | """Do nothing."""
19 | pass
20 |
21 | def __exit__(self, exc_type, exc_value, traceback):
22 | """Restore the random generator state at initialization."""
23 | torch.set_rng_state(self._rng_state)
24 |
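A minimal sketch of the guarantee this context provides: random draws inside the block do not advance the outer generator state.

import torch

from tests.utils.rand import restore_rng_state

torch.manual_seed(0)
expected = torch.rand(3)

torch.manual_seed(0)
with restore_rng_state():
    torch.rand(100)  # consumed inside the context only
assert torch.allclose(torch.rand(3), expected)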
--------------------------------------------------------------------------------