├── .flake8 ├── .gitignore ├── .travis.yml ├── LICENSE.md ├── Makefile ├── README.md ├── poetry.lock ├── pyproject.toml ├── tests ├── conftest.py ├── test_disable_enable.py ├── test_output_check.py ├── test_param_and_output_check.py ├── test_param_check.py └── test_verbose.py └── torcheck ├── __init__.py ├── output_spec.py ├── param_spec.py ├── registry.py └── utils ├── __init__.py └── message_utils.py /.flake8: -------------------------------------------------------------------------------- 1 | # For code-style check, currently the following exceptions are allowed: 2 | # - F401: imported but unused 3 | [flake8] 4 | count = True 5 | max-line-length = 88 6 | statistics = True 7 | ignore = F401 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | os: linux 3 | env: 4 | global: 5 | - secure: "qsSXaKcSgMx8buByRF8yHZpnt0ONzkjbqk/QWezjCLZSagX7+59K0/swNmsmwV+d5uMYUpE9SVCMvFG8zZ4M3BqEZ/6BH8abwnO47+0Xd4mp8oH5MgtdeZK+qIgvn1AvkpxVM3+87mf506pLsk9ADhkHqg7uTS045bLvMy3F9fTvxEMIFOWiRJbu+CKcxOFxrlHFrSdxqY2/q+Gbzjb7eJVj4SDrSXBv+Nu514pv4mlOmoHNaqlhM4eFHTnC7cKjaianOHhJhLmMptWhMxwhRns/8SdtmicpAUNBW4w3ZObg9CFfyZhlBzab2T7YG8w/fU445rof9S3wZhfDwsGLFnfAJ0gz+PMUnsdAXjnJRjbi68F2b/CdvBg1YlWWyM30//yQ4/vIfdWUNrqT0YME3VEXUpkDCutvWe3OfW38VOAhFxVHx+jRZeDWuJ8Ax88TL+AmIPUoIAQ5KHcTpRa7U2SQHoqkd+VK1rX9ri8nrkrIpYiJSL1X33mgIb+Xkc6CjqEvKNYTxKE0XOMSx8wIb/3QkIsTZB0nU6jASpF1yN8TwmBW0Yfnu1kZNUqPO369/jMi5tXhwgdBDMO8Sh1KvCZCqKCgT5YC/+5fhyYtAiEJ6wpfOMJgoMkjEkjb3UlzI0xGi6ih0A6WzyTkjQriDUbFulVN8wnkoYRfsNrgPhk=" 6 | language: python 7 | python: 8 | - 3.8 9 | - 3.9 10 | before_install: 11 | - pip install poetry 12 | install: 13 | - poetry install 14 | script: 15 | - poetry run black --check torcheck/ tests/ 16 | - poetry run flake8 torcheck/ tests/ 17 | - poetry run pytest --cov=torcheck/ tests/ 18 | after_success: 19 | - bash <(curl -s https://codecov.io/bash) 20 | before_deploy: 21 | - poetry config pypi-token.pypi $PYPI_TOKEN 22 | - poetry build 23 | deploy: 24 | provider: 
script 25 | script: poetry publish 26 | skip_cleanup: true 27 | on: 28 | tags: true 29 | python: 3.8 30 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 pengyan510 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | poetry run pytest tests/ 3 | 4 | test-cov: 5 | poetry run pytest --cov=torcheck/ tests/ 6 | 7 | format: 8 | poetry run black torcheck/ tests/ 9 | 10 | lint: 11 | poetry run black --check torcheck/ tests/ 12 | poetry run flake8 torcheck/ tests/ 13 | 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torcheck 2 | [![Build Status](https://travis-ci.com/pengyan510/torcheck.svg?branch=master)](https://travis-ci.com/pengyan510/torcheck) 3 | [![License](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | [![codecov](https://codecov.io/gh/pengyan510/torcheck/branch/master/graph/badge.svg?token=Q8ADT16N8A)](https://codecov.io/gh/pengyan510/torcheck) 5 | [![PyPI version](https://badge.fury.io/py/torcheck.svg)](https://badge.fury.io/py/torcheck) 6 | 7 | Torcheck is a machine learning sanity check toolkit for PyTorch. 8 | 9 | For a general introduction, please check this out: [Testing Your PyTorch Models with Torcheck](https://towardsdatascience.com/testing-your-pytorch-models-with-torcheck-cb689ecbc08c) 10 | 11 | ## About 12 | The creation of torcheck is inspired by Chase Roberts' [Medium post](https://thenerdstation.medium.com/mltest-automatically-test-neural-network-models-in-one-function-call-eb6f1fa5019d). The innovation and major benefit is that you no longer 13 | need to write additional testing code for your model training. Just add a few 14 | lines of code specifying the checks before training, torcheck will then take over and 15 | perform the checks simultaneously while the training happens. 16 | 17 | Another benefit is that torcheck allows you to check your model on different levels. 
18 | Instead of checking the whole model, you can specify checks for a submodule, a linear 19 | layer, or even the weight tensor! This enables more customization around the sanity 20 | checks. 21 | 22 | ## Installation 23 | ``` 24 | pip install torcheck 25 | ``` 26 | 27 | ## Torcheck in 5 minutes 28 | OK, suppose you have coded up a standard PyTorch training routine like this: 29 | ``` 30 | model = Model() 31 | optimizer = torch.optim.Adam( 32 | model.parameters(), 33 | lr=0.001, 34 | ) 35 | 36 | # torcheck code goes here 37 | 38 | for epoch in range(num_epochs): 39 | for x, y in dataloader: 40 | # calculate loss and backward propagation 41 | ``` 42 | 43 | By simply adding a few lines of code right before the for loop, you can be more confident 44 | about whether your model is training as expected! 45 | 46 | ### Step 1: Registering your optimizer(s) 47 | First, register the optimizer(s) with torcheck: 48 | ``` 49 | torcheck.register(optimizer) 50 | ``` 51 | 52 | ### Step 2: Adding sanity checks 53 | Torcheck enables you to perform a wide range of checks, on both module level and tensor 54 | level. 55 | 56 | A rule of thumb is to use APIs with `add_module` prefix when checking something that 57 | subclasses from `nn.Module`, and use APIs with `add_tensor` prefix when checking tensors. 58 | 59 | #### Parameters change/not change 60 | You can check whether model parameters actually get updated during the training. 61 | Or you can check whether they remain constant if you want them to be frozen. 
62 | 63 | For our example, some of the possible checks are: 64 | 65 | ``` 66 | # check all the model parameters will change 67 | # module_name is optional, but it makes error messages more informative when checks fail 68 | torcheck.add_module_changing_check(model, module_name="my_model") 69 | ``` 70 | 71 | ``` 72 | # check the linear layer's parameters won't change 73 | torcheck.add_module_unchanging_check(model.linear_0, module_name="linear_layer_0") 74 | ``` 75 | 76 | ``` 77 | # check the linear layer's weight parameters will change 78 | torcheck.add_tensor_changing_check( 79 | model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model" 80 | ) 81 | ``` 82 | 83 | ``` 84 | # check the linear layer's bias parameters won't change 85 | torcheck.add_tensor_unchanging_check( 86 | model.linear_0.bias, tensor_name="linear_0.bias", module_name="my_model" 87 | ) 88 | ``` 89 | 90 | #### Output range check 91 | The basic use case is that you can check whether model outputs are all within a range, 92 | say (-1, 1). 93 | 94 | You can also check that model outputs are not all within a range. This is useful when 95 | you want softmax to behave correctly. It enables you to check model outputs are not all 96 | within (0, 1). 97 | 98 | You can check the final model output or intermediate output of a submodule. 
99 | ``` 100 | # check model outputs are within (-1, 1) 101 | torcheck.add_module_output_range_check( 102 | model, output_range=(-1, 1), module_name="my_model" 103 | ) 104 | ``` 105 | 106 | ``` 107 | # check outputs from the linear layer are within (-5, 5) 108 | torcheck.add_module_output_range_check( 109 | model.linear_0, output_range=(-5, 5), module_name="linear_layer_0" 110 | ) 111 | 112 | ``` 113 | 114 | ``` 115 | # check model outputs are not all within (0, 1) 116 | # aka softmax hasn't been applied before loss calculation 117 | torcheck.add_module_output_range_check( 118 | model, 119 | output_range=(0, 1), 120 | negate_range=True, 121 | module_name="my_model", 122 | ) 123 | ``` 124 | 125 | #### NaN check 126 | Check whether parameters become NaN during training, or model outputs contain NaN. 127 | 128 | ``` 129 | # check whether model parameters become NaN or outputs contain NaN 130 | torcheck.add_module_nan_check(model, module_name="my_model") 131 | ``` 132 | 133 | ``` 134 | # check whether linear layer's weight parameters become NaN 135 | torcheck.add_tensor_nan_check( 136 | model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model" 137 | ) 138 | ``` 139 | 140 | #### Inf check 141 | Check whether parameters become infinite (positive or negative infinity) during training, 142 | or model outputs contain infinite value. 
143 | 144 | ``` 145 | # check whether model parameters become infinite or outputs contain infinite value 146 | torcheck.add_module_inf_check(model, module_name="my_model") 147 | ``` 148 | 149 | ``` 150 | # check whether linear layer's weight parameters become infinite 151 | torcheck.add_tensor_inf_check( 152 | model.linear_0.weight, tensor_name="linear_0.weight", module_name="my_model" 153 | ) 154 | ``` 155 | 156 | #### Adding multiple checks in one call 157 | You can add all checks for a module/tensor in one call: 158 | ``` 159 | # add all checks for model together 160 | torcheck.add_module( 161 | model, 162 | module_name="my_model", 163 | changing=True, 164 | output_range=(-1, 1), 165 | check_nan=True, 166 | check_inf=True, 167 | ) 168 | ``` 169 | 170 | ``` 171 | # add all checks for linear layer's weight together 172 | torcheck.add_tensor( 173 | model.linear_0.weight, 174 | tensor_name="linear_0.weight", 175 | module_name="my_model", 176 | changing=True, 177 | check_nan=True, 178 | check_inf=True, 179 | ) 180 | ``` 181 | 182 | ### Step 3: Training and fixing 183 | After adding all the checks, run the training as usual and fix errors if any. 184 | 185 | By default torcheck's error messages don't include tensor value information. If you 186 | think it would be helpful, you can add the following line inside your torcheck code: 187 | ``` 188 | torcheck.verbose_on() 189 | ``` 190 | 191 | You can turn it off again by calling 192 | ``` 193 | torcheck.verbose_off() 194 | ``` 195 | 196 | ### (Optional) Step 4: Turning off checks 197 | When your model has passed all the checks, you can easily turn them off by calling 198 | ``` 199 | torcheck.disable() 200 | ``` 201 | This is useful when you want to run your model on a validation set, or you just want to 202 | remove the checking overhead from training. 
203 | 204 | If you want to turn on the checks again, just call 205 | ``` 206 | torcheck.enable() 207 | ``` 208 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "appdirs" 3 | version = "1.4.4" 4 | description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 5 | category = "dev" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [[package]] 10 | name = "atomicwrites" 11 | version = "1.4.0" 12 | description = "Atomic file writes." 13 | category = "dev" 14 | optional = false 15 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 16 | 17 | [[package]] 18 | name = "attrs" 19 | version = "21.2.0" 20 | description = "Classes Without Boilerplate" 21 | category = "dev" 22 | optional = false 23 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 24 | 25 | [package.extras] 26 | dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"] 27 | docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] 28 | tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"] 29 | tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"] 30 | 31 | [[package]] 32 | name = "black" 33 | version = "21.5b2" 34 | description = "The uncompromising code formatter." 
35 | category = "dev" 36 | optional = false 37 | python-versions = ">=3.6.2" 38 | 39 | [package.dependencies] 40 | appdirs = "*" 41 | click = ">=7.1.2" 42 | mypy-extensions = ">=0.4.3" 43 | pathspec = ">=0.8.1,<1" 44 | regex = ">=2020.1.8" 45 | toml = ">=0.10.1" 46 | 47 | [package.extras] 48 | colorama = ["colorama (>=0.4.3)"] 49 | d = ["aiohttp (>=3.6.0)", "aiohttp-cors (>=0.4.0)"] 50 | python2 = ["typed-ast (>=1.4.2)"] 51 | uvloop = ["uvloop (>=0.15.2)"] 52 | 53 | [[package]] 54 | name = "click" 55 | version = "8.0.1" 56 | description = "Composable command line interface toolkit" 57 | category = "dev" 58 | optional = false 59 | python-versions = ">=3.6" 60 | 61 | [package.dependencies] 62 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 63 | 64 | [[package]] 65 | name = "colorama" 66 | version = "0.4.4" 67 | description = "Cross-platform colored terminal text." 68 | category = "dev" 69 | optional = false 70 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 71 | 72 | [[package]] 73 | name = "coverage" 74 | version = "5.5" 75 | description = "Code coverage measurement for Python" 76 | category = "dev" 77 | optional = false 78 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" 79 | 80 | [package.extras] 81 | toml = ["toml"] 82 | 83 | [[package]] 84 | name = "flake8" 85 | version = "3.9.2" 86 | description = "the modular source code checker: pep8 pyflakes and co" 87 | category = "dev" 88 | optional = false 89 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" 90 | 91 | [package.dependencies] 92 | mccabe = ">=0.6.0,<0.7.0" 93 | pycodestyle = ">=2.7.0,<2.8.0" 94 | pyflakes = ">=2.3.0,<2.4.0" 95 | 96 | [[package]] 97 | name = "iniconfig" 98 | version = "1.1.1" 99 | description = "iniconfig: brain-dead simple config-ini parsing" 100 | category = "dev" 101 | optional = false 102 | python-versions = "*" 103 | 104 | [[package]] 105 | name = "mccabe" 106 | version = "0.6.1" 107 | 
description = "McCabe checker, plugin for flake8" 108 | category = "dev" 109 | optional = false 110 | python-versions = "*" 111 | 112 | [[package]] 113 | name = "mypy-extensions" 114 | version = "0.4.3" 115 | description = "Experimental type system extensions for programs checked with the mypy typechecker." 116 | category = "dev" 117 | optional = false 118 | python-versions = "*" 119 | 120 | [[package]] 121 | name = "numpy" 122 | version = "1.20.3" 123 | description = "NumPy is the fundamental package for array computing with Python." 124 | category = "main" 125 | optional = false 126 | python-versions = ">=3.7" 127 | 128 | [[package]] 129 | name = "packaging" 130 | version = "20.9" 131 | description = "Core utilities for Python packages" 132 | category = "dev" 133 | optional = false 134 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 135 | 136 | [package.dependencies] 137 | pyparsing = ">=2.0.2" 138 | 139 | [[package]] 140 | name = "pathspec" 141 | version = "0.8.1" 142 | description = "Utility library for gitignore style pattern matching of file paths." 
143 | category = "dev" 144 | optional = false 145 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 146 | 147 | [[package]] 148 | name = "pluggy" 149 | version = "0.13.1" 150 | description = "plugin and hook calling mechanisms for python" 151 | category = "dev" 152 | optional = false 153 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 154 | 155 | [package.extras] 156 | dev = ["pre-commit", "tox"] 157 | 158 | [[package]] 159 | name = "py" 160 | version = "1.10.0" 161 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 162 | category = "dev" 163 | optional = false 164 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 165 | 166 | [[package]] 167 | name = "pycodestyle" 168 | version = "2.7.0" 169 | description = "Python style guide checker" 170 | category = "dev" 171 | optional = false 172 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 173 | 174 | [[package]] 175 | name = "pyflakes" 176 | version = "2.3.1" 177 | description = "passive checker of Python programs" 178 | category = "dev" 179 | optional = false 180 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 181 | 182 | [[package]] 183 | name = "pyparsing" 184 | version = "2.4.7" 185 | description = "Python parsing module" 186 | category = "dev" 187 | optional = false 188 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 189 | 190 | [[package]] 191 | name = "pytest" 192 | version = "6.2.4" 193 | description = "pytest: simple powerful testing with Python" 194 | category = "dev" 195 | optional = false 196 | python-versions = ">=3.6" 197 | 198 | [package.dependencies] 199 | atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} 200 | attrs = ">=19.2.0" 201 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 202 | iniconfig = "*" 203 | packaging = "*" 204 | pluggy = ">=0.12,<1.0.0a1" 205 | py = ">=1.8.2" 206 | toml = "*" 207 | 208 | [package.extras] 209 | testing 
= ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] 210 | 211 | [[package]] 212 | name = "pytest-cov" 213 | version = "2.12.1" 214 | description = "Pytest plugin for measuring coverage." 215 | category = "dev" 216 | optional = false 217 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 218 | 219 | [package.dependencies] 220 | coverage = ">=5.2.1" 221 | pytest = ">=4.6" 222 | toml = "*" 223 | 224 | [package.extras] 225 | testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"] 226 | 227 | [[package]] 228 | name = "regex" 229 | version = "2021.4.4" 230 | description = "Alternative regular expression module, to replace re." 231 | category = "dev" 232 | optional = false 233 | python-versions = "*" 234 | 235 | [[package]] 236 | name = "toml" 237 | version = "0.10.2" 238 | description = "Python Library for Tom's Obvious, Minimal Language" 239 | category = "dev" 240 | optional = false 241 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 242 | 243 | [[package]] 244 | name = "torch" 245 | version = "1.8.1" 246 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 247 | category = "main" 248 | optional = false 249 | python-versions = ">=3.6.2" 250 | 251 | [package.dependencies] 252 | numpy = "*" 253 | typing-extensions = "*" 254 | 255 | [[package]] 256 | name = "typing-extensions" 257 | version = "3.10.0.0" 258 | description = "Backported and Experimental Type Hints for Python 3.5+" 259 | category = "main" 260 | optional = false 261 | python-versions = "*" 262 | 263 | [metadata] 264 | lock-version = "1.1" 265 | python-versions = "^3.8" 266 | content-hash = "d7d1d0ff09fc34f5e3e4ae34ad1281f5301ad309c5190a638d59c89743384d57" 267 | 268 | [metadata.files] 269 | appdirs = [ 270 | {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, 271 | {file = "appdirs-1.4.4.tar.gz", hash = 
"sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, 272 | ] 273 | atomicwrites = [ 274 | {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, 275 | {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, 276 | ] 277 | attrs = [ 278 | {file = "attrs-21.2.0-py2.py3-none-any.whl", hash = "sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1"}, 279 | {file = "attrs-21.2.0.tar.gz", hash = "sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb"}, 280 | ] 281 | black = [ 282 | {file = "black-21.5b2-py3-none-any.whl", hash = "sha256:e5cf21ebdffc7a9b29d73912b6a6a9a4df4ce70220d523c21647da2eae0751ef"}, 283 | {file = "black-21.5b2.tar.gz", hash = "sha256:1fc0e0a2c8ae7d269dfcf0c60a89afa299664f3e811395d40b1922dff8f854b5"}, 284 | ] 285 | click = [ 286 | {file = "click-8.0.1-py3-none-any.whl", hash = "sha256:fba402a4a47334742d782209a7c79bc448911afe1149d07bdabdf480b3e2f4b6"}, 287 | {file = "click-8.0.1.tar.gz", hash = "sha256:8c04c11192119b1ef78ea049e0a6f0463e4c48ef00a30160c704337586f3ad7a"}, 288 | ] 289 | colorama = [ 290 | {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, 291 | {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, 292 | ] 293 | coverage = [ 294 | {file = "coverage-5.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:b6d534e4b2ab35c9f93f46229363e17f63c53ad01330df9f2d6bd1187e5eaacf"}, 295 | {file = "coverage-5.5-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:b7895207b4c843c76a25ab8c1e866261bcfe27bfaa20c192de5190121770672b"}, 296 | {file = "coverage-5.5-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:c2723d347ab06e7ddad1a58b2a821218239249a9e4365eaff6649d31180c1669"}, 297 | {file = 
"coverage-5.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:900fbf7759501bc7807fd6638c947d7a831fc9fdf742dc10f02956ff7220fa90"}, 298 | {file = "coverage-5.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:004d1880bed2d97151facef49f08e255a20ceb6f9432df75f4eef018fdd5a78c"}, 299 | {file = "coverage-5.5-cp27-cp27m-win32.whl", hash = "sha256:06191eb60f8d8a5bc046f3799f8a07a2d7aefb9504b0209aff0b47298333302a"}, 300 | {file = "coverage-5.5-cp27-cp27m-win_amd64.whl", hash = "sha256:7501140f755b725495941b43347ba8a2777407fc7f250d4f5a7d2a1050ba8e82"}, 301 | {file = "coverage-5.5-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:372da284cfd642d8e08ef606917846fa2ee350f64994bebfbd3afb0040436905"}, 302 | {file = "coverage-5.5-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:8963a499849a1fc54b35b1c9f162f4108017b2e6db2c46c1bed93a72262ed083"}, 303 | {file = "coverage-5.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:869a64f53488f40fa5b5b9dcb9e9b2962a66a87dab37790f3fcfb5144b996ef5"}, 304 | {file = "coverage-5.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:4a7697d8cb0f27399b0e393c0b90f0f1e40c82023ea4d45d22bce7032a5d7b81"}, 305 | {file = "coverage-5.5-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:8d0a0725ad7c1a0bcd8d1b437e191107d457e2ec1084b9f190630a4fb1af78e6"}, 306 | {file = "coverage-5.5-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:51cb9476a3987c8967ebab3f0fe144819781fca264f57f89760037a2ea191cb0"}, 307 | {file = "coverage-5.5-cp310-cp310-win_amd64.whl", hash = "sha256:c0891a6a97b09c1f3e073a890514d5012eb256845c451bd48f7968ef939bf4ae"}, 308 | {file = "coverage-5.5-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:3487286bc29a5aa4b93a072e9592f22254291ce96a9fbc5251f566b6b7343cdb"}, 309 | {file = "coverage-5.5-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:deee1077aae10d8fa88cb02c845cfba9b62c55e1183f52f6ae6a2df6a2187160"}, 310 | {file = "coverage-5.5-cp35-cp35m-manylinux1_x86_64.whl", hash = 
"sha256:f11642dddbb0253cc8853254301b51390ba0081750a8ac03f20ea8103f0c56b6"}, 311 | {file = "coverage-5.5-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:6c90e11318f0d3c436a42409f2749ee1a115cd8b067d7f14c148f1ce5574d701"}, 312 | {file = "coverage-5.5-cp35-cp35m-manylinux2010_x86_64.whl", hash = "sha256:30c77c1dc9f253283e34c27935fded5015f7d1abe83bc7821680ac444eaf7793"}, 313 | {file = "coverage-5.5-cp35-cp35m-win32.whl", hash = "sha256:9a1ef3b66e38ef8618ce5fdc7bea3d9f45f3624e2a66295eea5e57966c85909e"}, 314 | {file = "coverage-5.5-cp35-cp35m-win_amd64.whl", hash = "sha256:972c85d205b51e30e59525694670de6a8a89691186012535f9d7dbaa230e42c3"}, 315 | {file = "coverage-5.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:af0e781009aaf59e25c5a678122391cb0f345ac0ec272c7961dc5455e1c40066"}, 316 | {file = "coverage-5.5-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:74d881fc777ebb11c63736622b60cb9e4aee5cace591ce274fb69e582a12a61a"}, 317 | {file = "coverage-5.5-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:92b017ce34b68a7d67bd6d117e6d443a9bf63a2ecf8567bb3d8c6c7bc5014465"}, 318 | {file = "coverage-5.5-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:d636598c8305e1f90b439dbf4f66437de4a5e3c31fdf47ad29542478c8508bbb"}, 319 | {file = "coverage-5.5-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:41179b8a845742d1eb60449bdb2992196e211341818565abded11cfa90efb821"}, 320 | {file = "coverage-5.5-cp36-cp36m-win32.whl", hash = "sha256:040af6c32813fa3eae5305d53f18875bedd079960822ef8ec067a66dd8afcd45"}, 321 | {file = "coverage-5.5-cp36-cp36m-win_amd64.whl", hash = "sha256:5fec2d43a2cc6965edc0bb9e83e1e4b557f76f843a77a2496cbe719583ce8184"}, 322 | {file = "coverage-5.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:18ba8bbede96a2c3dde7b868de9dcbd55670690af0988713f0603f037848418a"}, 323 | {file = "coverage-5.5-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:2910f4d36a6a9b4214bb7038d537f015346f413a975d57ca6b43bf23d6563b53"}, 324 | {file = 
"coverage-5.5-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f0b278ce10936db1a37e6954e15a3730bea96a0997c26d7fee88e6c396c2086d"}, 325 | {file = "coverage-5.5-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:796c9c3c79747146ebd278dbe1e5c5c05dd6b10cc3bcb8389dfdf844f3ead638"}, 326 | {file = "coverage-5.5-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:53194af30d5bad77fcba80e23a1441c71abfb3e01192034f8246e0d8f99528f3"}, 327 | {file = "coverage-5.5-cp37-cp37m-win32.whl", hash = "sha256:184a47bbe0aa6400ed2d41d8e9ed868b8205046518c52464fde713ea06e3a74a"}, 328 | {file = "coverage-5.5-cp37-cp37m-win_amd64.whl", hash = "sha256:2949cad1c5208b8298d5686d5a85b66aae46d73eec2c3e08c817dd3513e5848a"}, 329 | {file = "coverage-5.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:217658ec7187497e3f3ebd901afdca1af062b42cfe3e0dafea4cced3983739f6"}, 330 | {file = "coverage-5.5-cp38-cp38-manylinux1_i686.whl", hash = "sha256:1aa846f56c3d49205c952d8318e76ccc2ae23303351d9270ab220004c580cfe2"}, 331 | {file = "coverage-5.5-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:24d4a7de75446be83244eabbff746d66b9240ae020ced65d060815fac3423759"}, 332 | {file = "coverage-5.5-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:d1f8bf7b90ba55699b3a5e44930e93ff0189aa27186e96071fac7dd0d06a1873"}, 333 | {file = "coverage-5.5-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:970284a88b99673ccb2e4e334cfb38a10aab7cd44f7457564d11898a74b62d0a"}, 334 | {file = "coverage-5.5-cp38-cp38-win32.whl", hash = "sha256:01d84219b5cdbfc8122223b39a954820929497a1cb1422824bb86b07b74594b6"}, 335 | {file = "coverage-5.5-cp38-cp38-win_amd64.whl", hash = "sha256:2e0d881ad471768bf6e6c2bf905d183543f10098e3b3640fc029509530091502"}, 336 | {file = "coverage-5.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d1f9ce122f83b2305592c11d64f181b87153fc2c2bbd3bb4a3dde8303cfb1a6b"}, 337 | {file = "coverage-5.5-cp39-cp39-manylinux1_i686.whl", hash = "sha256:13c4ee887eca0f4c5a247b75398d4114c37882658300e153113dafb1d76de529"}, 338 | {file 
= "coverage-5.5-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:52596d3d0e8bdf3af43db3e9ba8dcdaac724ba7b5ca3f6358529d56f7a166f8b"}, 339 | {file = "coverage-5.5-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:2cafbbb3af0733db200c9b5f798d18953b1a304d3f86a938367de1567f4b5bff"}, 340 | {file = "coverage-5.5-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:44d654437b8ddd9eee7d1eaee28b7219bec228520ff809af170488fd2fed3e2b"}, 341 | {file = "coverage-5.5-cp39-cp39-win32.whl", hash = "sha256:d314ed732c25d29775e84a960c3c60808b682c08d86602ec2c3008e1202e3bb6"}, 342 | {file = "coverage-5.5-cp39-cp39-win_amd64.whl", hash = "sha256:13034c4409db851670bc9acd836243aeee299949bd5673e11844befcb0149f03"}, 343 | {file = "coverage-5.5-pp36-none-any.whl", hash = "sha256:f030f8873312a16414c0d8e1a1ddff2d3235655a2174e3648b4fa66b3f2f1079"}, 344 | {file = "coverage-5.5-pp37-none-any.whl", hash = "sha256:2a3859cb82dcbda1cfd3e6f71c27081d18aa251d20a17d87d26d4cd216fb0af4"}, 345 | {file = "coverage-5.5.tar.gz", hash = "sha256:ebe78fe9a0e874362175b02371bdfbee64d8edc42a044253ddf4ee7d3c15212c"}, 346 | ] 347 | flake8 = [ 348 | {file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"}, 349 | {file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"}, 350 | ] 351 | iniconfig = [ 352 | {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, 353 | {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, 354 | ] 355 | mccabe = [ 356 | {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, 357 | {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, 358 | ] 359 | mypy-extensions = [ 360 | {file = 
"mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, 361 | {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, 362 | ] 363 | numpy = [ 364 | {file = "numpy-1.20.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:70eb5808127284c4e5c9e836208e09d685a7978b6a216db85960b1a112eeace8"}, 365 | {file = "numpy-1.20.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6ca2b85a5997dabc38301a22ee43c82adcb53ff660b89ee88dded6b33687e1d8"}, 366 | {file = "numpy-1.20.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c5bf0e132acf7557fc9bb8ded8b53bbbbea8892f3c9a1738205878ca9434206a"}, 367 | {file = "numpy-1.20.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db250fd3e90117e0312b611574cd1b3f78bec046783195075cbd7ba9c3d73f16"}, 368 | {file = "numpy-1.20.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:637d827248f447e63585ca3f4a7d2dfaa882e094df6cfa177cc9cf9cd6cdf6d2"}, 369 | {file = "numpy-1.20.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:8b7bb4b9280da3b2856cb1fc425932f46fba609819ee1c62256f61799e6a51d2"}, 370 | {file = "numpy-1.20.3-cp37-cp37m-win32.whl", hash = "sha256:67d44acb72c31a97a3d5d33d103ab06d8ac20770e1c5ad81bdb3f0c086a56cf6"}, 371 | {file = "numpy-1.20.3-cp37-cp37m-win_amd64.whl", hash = "sha256:43909c8bb289c382170e0282158a38cf306a8ad2ff6dfadc447e90f9961bef43"}, 372 | {file = "numpy-1.20.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f1452578d0516283c87608a5a5548b0cdde15b99650efdfd85182102ef7a7c17"}, 373 | {file = "numpy-1.20.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6e51534e78d14b4a009a062641f465cfaba4fdcb046c3ac0b1f61dd97c861b1b"}, 374 | {file = "numpy-1.20.3-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:e515c9a93aebe27166ec9593411c58494fa98e5fcc219e47260d9ab8a1cc7f9f"}, 375 | {file = "numpy-1.20.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1c09247ccea742525bdb5f4b5ceeacb34f95731647fe55774aa36557dbb5fa4"}, 376 | {file = "numpy-1.20.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:66fbc6fed94a13b9801fb70b96ff30605ab0a123e775a5e7a26938b717c5d71a"}, 377 | {file = "numpy-1.20.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ea9cff01e75a956dbee133fa8e5b68f2f92175233de2f88de3a682dd94deda65"}, 378 | {file = "numpy-1.20.3-cp38-cp38-win32.whl", hash = "sha256:f39a995e47cb8649673cfa0579fbdd1cdd33ea497d1728a6cb194d6252268e48"}, 379 | {file = "numpy-1.20.3-cp38-cp38-win_amd64.whl", hash = "sha256:1676b0a292dd3c99e49305a16d7a9f42a4ab60ec522eac0d3dd20cdf362ac010"}, 380 | {file = "numpy-1.20.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:830b044f4e64a76ba71448fce6e604c0fc47a0e54d8f6467be23749ac2cbd2fb"}, 381 | {file = "numpy-1.20.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:55b745fca0a5ab738647d0e4db099bd0a23279c32b31a783ad2ccea729e632df"}, 382 | {file = "numpy-1.20.3-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5d050e1e4bc9ddb8656d7b4f414557720ddcca23a5b88dd7cff65e847864c400"}, 383 | {file = "numpy-1.20.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9c65473ebc342715cb2d7926ff1e202c26376c0dcaaee85a1fd4b8d8c1d3b2f"}, 384 | {file = "numpy-1.20.3-cp39-cp39-win32.whl", hash = "sha256:16f221035e8bd19b9dc9a57159e38d2dd060b48e93e1d843c49cb370b0f415fd"}, 385 | {file = "numpy-1.20.3-cp39-cp39-win_amd64.whl", hash = "sha256:6690080810f77485667bfbff4f69d717c3be25e5b11bb2073e76bb3f578d99b4"}, 386 | {file = "numpy-1.20.3-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4e465afc3b96dbc80cf4a5273e5e2b1e3451286361b4af70ce1adb2984d392f9"}, 387 | {file = "numpy-1.20.3.zip", hash = 
"sha256:e55185e51b18d788e49fe8305fd73ef4470596b33fc2c1ceb304566b99c71a69"}, 388 | ] 389 | packaging = [ 390 | {file = "packaging-20.9-py2.py3-none-any.whl", hash = "sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a"}, 391 | {file = "packaging-20.9.tar.gz", hash = "sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5"}, 392 | ] 393 | pathspec = [ 394 | {file = "pathspec-0.8.1-py2.py3-none-any.whl", hash = "sha256:aa0cb481c4041bf52ffa7b0d8fa6cd3e88a2ca4879c533c9153882ee2556790d"}, 395 | {file = "pathspec-0.8.1.tar.gz", hash = "sha256:86379d6b86d75816baba717e64b1a3a3469deb93bb76d613c9ce79edc5cb68fd"}, 396 | ] 397 | pluggy = [ 398 | {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, 399 | {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, 400 | ] 401 | py = [ 402 | {file = "py-1.10.0-py2.py3-none-any.whl", hash = "sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a"}, 403 | {file = "py-1.10.0.tar.gz", hash = "sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3"}, 404 | ] 405 | pycodestyle = [ 406 | {file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"}, 407 | {file = "pycodestyle-2.7.0.tar.gz", hash = "sha256:c389c1d06bf7904078ca03399a4816f974a1d590090fecea0c63ec26ebaf1cef"}, 408 | ] 409 | pyflakes = [ 410 | {file = "pyflakes-2.3.1-py2.py3-none-any.whl", hash = "sha256:7893783d01b8a89811dd72d7dfd4d84ff098e5eed95cfa8905b22bbffe52efc3"}, 411 | {file = "pyflakes-2.3.1.tar.gz", hash = "sha256:f5bc8ecabc05bb9d291eb5203d6810b49040f6ff446a756326104746cc00c1db"}, 412 | ] 413 | pyparsing = [ 414 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, 415 | {file = "pyparsing-2.4.7.tar.gz", hash = 
"sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, 416 | ] 417 | pytest = [ 418 | {file = "pytest-6.2.4-py3-none-any.whl", hash = "sha256:91ef2131a9bd6be8f76f1f08eac5c5317221d6ad1e143ae03894b862e8976890"}, 419 | {file = "pytest-6.2.4.tar.gz", hash = "sha256:50bcad0a0b9c5a72c8e4e7c9855a3ad496ca6a881a3641b4260605450772c54b"}, 420 | ] 421 | pytest-cov = [ 422 | {file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"}, 423 | {file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"}, 424 | ] 425 | regex = [ 426 | {file = "regex-2021.4.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:619d71c59a78b84d7f18891fe914446d07edd48dc8328c8e149cbe0929b4e000"}, 427 | {file = "regex-2021.4.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:47bf5bf60cf04d72bf6055ae5927a0bd9016096bf3d742fa50d9bf9f45aa0711"}, 428 | {file = "regex-2021.4.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:281d2fd05555079448537fe108d79eb031b403dac622621c78944c235f3fcf11"}, 429 | {file = "regex-2021.4.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:bd28bc2e3a772acbb07787c6308e00d9626ff89e3bfcdebe87fa5afbfdedf968"}, 430 | {file = "regex-2021.4.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:7c2a1af393fcc09e898beba5dd59196edaa3116191cc7257f9224beaed3e1aa0"}, 431 | {file = "regex-2021.4.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:c38c71df845e2aabb7fb0b920d11a1b5ac8526005e533a8920aea97efb8ec6a4"}, 432 | {file = "regex-2021.4.4-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:96fcd1888ab4d03adfc9303a7b3c0bd78c5412b2bfbe76db5b56d9eae004907a"}, 433 | {file = "regex-2021.4.4-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:ade17eb5d643b7fead300a1641e9f45401c98eee23763e9ed66a43f92f20b4a7"}, 434 | {file = "regex-2021.4.4-cp36-cp36m-win32.whl", hash = "sha256:e8e5b509d5c2ff12f8418006d5a90e9436766133b564db0abaec92fd27fcee29"}, 
435 | {file = "regex-2021.4.4-cp36-cp36m-win_amd64.whl", hash = "sha256:11d773d75fa650cd36f68d7ca936e3c7afaae41b863b8c387a22aaa78d3c5c79"}, 436 | {file = "regex-2021.4.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d3029c340cfbb3ac0a71798100ccc13b97dddf373a4ae56b6a72cf70dfd53bc8"}, 437 | {file = "regex-2021.4.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:18c071c3eb09c30a264879f0d310d37fe5d3a3111662438889ae2eb6fc570c31"}, 438 | {file = "regex-2021.4.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:4c557a7b470908b1712fe27fb1ef20772b78079808c87d20a90d051660b1d69a"}, 439 | {file = "regex-2021.4.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:01afaf2ec48e196ba91b37451aa353cb7eda77efe518e481707e0515025f0cd5"}, 440 | {file = "regex-2021.4.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:3a9cd17e6e5c7eb328517969e0cb0c3d31fd329298dd0c04af99ebf42e904f82"}, 441 | {file = "regex-2021.4.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:90f11ff637fe8798933fb29f5ae1148c978cccb0452005bf4c69e13db951e765"}, 442 | {file = "regex-2021.4.4-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:919859aa909429fb5aa9cf8807f6045592c85ef56fdd30a9a3747e513db2536e"}, 443 | {file = "regex-2021.4.4-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:339456e7d8c06dd36a22e451d58ef72cef293112b559010db3d054d5560ef439"}, 444 | {file = "regex-2021.4.4-cp37-cp37m-win32.whl", hash = "sha256:67bdb9702427ceddc6ef3dc382455e90f785af4c13d495f9626861763ee13f9d"}, 445 | {file = "regex-2021.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:32e65442138b7b76dd8173ffa2cf67356b7bc1768851dded39a7a13bf9223da3"}, 446 | {file = "regex-2021.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1e1c20e29358165242928c2de1482fb2cf4ea54a6a6dea2bd7a0e0d8ee321500"}, 447 | {file = "regex-2021.4.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:314d66636c494ed9c148a42731b3834496cc9a2c4251b1661e40936814542b14"}, 448 | {file = "regex-2021.4.4-cp38-cp38-manylinux1_x86_64.whl", hash = 
"sha256:6d1b01031dedf2503631d0903cb563743f397ccaf6607a5e3b19a3d76fc10480"}, 449 | {file = "regex-2021.4.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:741a9647fcf2e45f3a1cf0e24f5e17febf3efe8d4ba1281dcc3aa0459ef424dc"}, 450 | {file = "regex-2021.4.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:4c46e22a0933dd783467cf32b3516299fb98cfebd895817d685130cc50cd1093"}, 451 | {file = "regex-2021.4.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e512d8ef5ad7b898cdb2d8ee1cb09a8339e4f8be706d27eaa180c2f177248a10"}, 452 | {file = "regex-2021.4.4-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:980d7be47c84979d9136328d882f67ec5e50008681d94ecc8afa8a65ed1f4a6f"}, 453 | {file = "regex-2021.4.4-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:ce15b6d103daff8e9fee13cf7f0add05245a05d866e73926c358e871221eae87"}, 454 | {file = "regex-2021.4.4-cp38-cp38-win32.whl", hash = "sha256:a91aa8619b23b79bcbeb37abe286f2f408d2f2d6f29a17237afda55bb54e7aac"}, 455 | {file = "regex-2021.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:c0502c0fadef0d23b128605d69b58edb2c681c25d44574fc673b0e52dce71ee2"}, 456 | {file = "regex-2021.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:598585c9f0af8374c28edd609eb291b5726d7cbce16be6a8b95aa074d252ee17"}, 457 | {file = "regex-2021.4.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:ee54ff27bf0afaf4c3b3a62bcd016c12c3fdb4ec4f413391a90bd38bc3624605"}, 458 | {file = "regex-2021.4.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:7d9884d86dd4dd489e981d94a65cd30d6f07203d90e98f6f657f05170f6324c9"}, 459 | {file = "regex-2021.4.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:bf5824bfac591ddb2c1f0a5f4ab72da28994548c708d2191e3b87dd207eb3ad7"}, 460 | {file = "regex-2021.4.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:563085e55b0d4fb8f746f6a335893bda5c2cef43b2f0258fe1020ab1dd874df8"}, 461 | {file = "regex-2021.4.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b9c3db21af35e3b3c05764461b262d6f05bbca08a71a7849fd79d47ba7bc33ed"}, 462 | {file = 
"regex-2021.4.4-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:3916d08be28a1149fb97f7728fca1f7c15d309a9f9682d89d79db75d5e52091c"}, 463 | {file = "regex-2021.4.4-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:fd45ff9293d9274c5008a2054ecef86a9bfe819a67c7be1afb65e69b405b3042"}, 464 | {file = "regex-2021.4.4-cp39-cp39-win32.whl", hash = "sha256:fa4537fb4a98fe8fde99626e4681cc644bdcf2a795038533f9f711513a862ae6"}, 465 | {file = "regex-2021.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:97f29f57d5b84e73fbaf99ab3e26134e6687348e95ef6b48cfd2c06807005a07"}, 466 | {file = "regex-2021.4.4.tar.gz", hash = "sha256:52ba3d3f9b942c49d7e4bc105bb28551c44065f139a65062ab7912bef10c9afb"}, 467 | ] 468 | toml = [ 469 | {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, 470 | {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, 471 | ] 472 | torch = [ 473 | {file = "torch-1.8.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:f23eeb1a48cc39209d986c418ad7e02227eee973da45c0c42d36b1aec72f4940"}, 474 | {file = "torch-1.8.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4ace9c5bb94d5a7b9582cd089993201658466e9c59ff88bd4e9e08f6f072d1cf"}, 475 | {file = "torch-1.8.1-cp36-cp36m-win_amd64.whl", hash = "sha256:6ffa1e7ae079c7cb828712cb0cdaae5cc4fb87c16a607e6d14526b62c20bcc17"}, 476 | {file = "torch-1.8.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:16f2630d9604c4ee28ea7d6e388e2264cd7bc6031c6ecd796bae3f56b5efa9a3"}, 477 | {file = "torch-1.8.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95b7bbbacc3f28fe438f418392ceeae146a01adc03b29d44917d55214ac234c9"}, 478 | {file = "torch-1.8.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:55137feb2f5a0dc7aced5bba690dcdb7652054ad3452b09a2bbb59f02a11e9ff"}, 479 | {file = "torch-1.8.1-cp37-cp37m-win_amd64.whl", hash = "sha256:8ad2252bf09833dcf46a536a78544e349b8256a370e03a98627ebfb118d9555b"}, 480 | {file = 
"torch-1.8.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:1388b30fbd262c1a053d6c9ace73bb0bd8f5871b4892b6f3e02d1d7bc9768563"}, 481 | {file = "torch-1.8.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:e7ad1649adb7dc2a450e70a3e51240b84fa4746c69c8f98989ce0c254f9fba3a"}, 482 | {file = "torch-1.8.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:3e4190c04dfd89c59bad06d5fe451446643a65e6d2607cc989eb1001ee76e12f"}, 483 | {file = "torch-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:5c2e9a33d44cdb93ebd739b127ffd7da786bf5f740539539195195b186a05f6c"}, 484 | {file = "torch-1.8.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:c6ede2ae4dcd8214b63e047efabafa92493605205a947574cf358216ca4e440a"}, 485 | {file = "torch-1.8.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ce7d435426f3dd14f95710d779aa46e9cd5e077d512488e813f7589fdc024f78"}, 486 | {file = "torch-1.8.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:a50ea8ed900927fb30cadb63aa7a32fdd59c7d7abe5012348dfbe35a8355c083"}, 487 | {file = "torch-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:dac4d10494e74f7e553c92d7263e19ea501742c4825ddd26c4decfa27be95981"}, 488 | {file = "torch-1.8.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:225ee4238c019b28369c71977327deeeb2bd1c6b8557e6fcf631b8866bdc5447"}, 489 | ] 490 | typing-extensions = [ 491 | {file = "typing_extensions-3.10.0.0-py2-none-any.whl", hash = "sha256:0ac0f89795dd19de6b97debb0c6af1c70987fd80a2d62d1958f7e56fcc31b497"}, 492 | {file = "typing_extensions-3.10.0.0-py3-none-any.whl", hash = "sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84"}, 493 | {file = "typing_extensions-3.10.0.0.tar.gz", hash = "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342"}, 494 | ] 495 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "torcheck" 3 | version = "1.0.1" 4 | 
description = "A machine learning sanity check toolkit for PyTorch" 5 | authors = ["Peng Yan "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/pengyan510/torcheck" 9 | repository = "https://github.com/pengyan510/torcheck" 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.8" 13 | torch = "^1.8" 14 | 15 | [tool.poetry.dev-dependencies] 16 | black = "^21.4b0" 17 | flake8 = "^3.9" 18 | pytest-cov = "^2.11" 19 | 20 | [build-system] 21 | requires = ["poetry-core>=1.0.0"] 22 | build-backend = "poetry.core.masonry.api" 23 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.optim import Adam 6 | from torch.utils.data import TensorDataset, DataLoader 7 | 8 | 9 | class NetBase(nn.Module): 10 | def __init__(self): 11 | super(NetBase, self).__init__() 12 | self.fc1 = nn.Linear(5, 5) 13 | self.fc2 = nn.Linear(5, 2) 14 | self.relu = nn.ReLU() 15 | 16 | 17 | class ChangingNet(NetBase): 18 | def forward(self, x): 19 | output = self.relu(self.fc1(x)) 20 | output = self.fc2(output) 21 | return output 22 | 23 | 24 | class UnchangingNet(NetBase): 25 | def forward(self, x): 26 | output = self.relu(self.fc1(x)) 27 | output = self.fc2(x) 28 | return output 29 | 30 | 31 | class NaNNet(NetBase): 32 | def forward(self, x): 33 | output = self.relu(self.fc1(x)) 34 | output = self.fc2(output) 35 | output = output - torch.max(output) 36 | output = torch.sqrt(output) 37 | return output 38 | 39 | 40 | class InfNet(NetBase): 41 | def forward(self, x): 42 | output = self.relu(self.fc1(x)) 43 | output = self.fc2(output) 44 | output = output / 0 45 | return output 46 | 47 | 48 | class BoundedNet(NetBase): 49 | def forward(self, x): 50 | output = self.relu(self.fc1(x)) 51 | output = self.fc2(output) 52 | output = 
F.softmax(output, dim=1) 53 | return output 54 | 55 | 56 | @pytest.fixture(scope="function") 57 | def changing_model(): 58 | return ChangingNet() 59 | 60 | 61 | @pytest.fixture(scope="function") 62 | def unchanging_model(): 63 | return UnchangingNet() 64 | 65 | 66 | @pytest.fixture(scope="function") 67 | def nan_model(): 68 | return NaNNet() 69 | 70 | 71 | @pytest.fixture(scope="function") 72 | def inf_model(): 73 | return InfNet() 74 | 75 | 76 | @pytest.fixture(scope="function") 77 | def bounded_model(): 78 | return BoundedNet() 79 | 80 | 81 | @pytest.fixture(scope="function") 82 | def changing_model_optimizer(changing_model): 83 | return Adam(changing_model.parameters(), lr=0.001) 84 | 85 | 86 | @pytest.fixture(scope="function") 87 | def unchanging_model_optimizer(unchanging_model): 88 | return Adam(unchanging_model.parameters(), lr=0.001) 89 | 90 | 91 | @pytest.fixture(scope="function") 92 | def nan_model_optimizer(nan_model): 93 | return Adam(nan_model.parameters(), lr=0.001) 94 | 95 | 96 | @pytest.fixture(scope="function") 97 | def inf_model_optimizer(inf_model): 98 | return Adam(inf_model.parameters(), lr=0.001) 99 | 100 | 101 | @pytest.fixture(scope="function") 102 | def bounded_model_optimizer(bounded_model): 103 | return Adam(bounded_model.parameters(), lr=0.001) 104 | 105 | 106 | @pytest.fixture(scope="function") 107 | def correct_model_optimizer(correct_model): 108 | return Adam(correct_model.parameters(), lr=0.001) 109 | 110 | 111 | @pytest.fixture(scope="function") 112 | def nonan_model_optimizer(nonan_model): 113 | return Adam(nonan_model.parameters(), lr=0.001) 114 | 115 | 116 | @pytest.fixture(scope="function") 117 | def noinf_model_optimizer(noinf_model): 118 | return Adam(noinf_model.parameters(), lr=0.001) 119 | 120 | 121 | @pytest.fixture(scope="function") 122 | def unbounded_model_optimizer(unbounded_model): 123 | return Adam(unbounded_model.parameters(), lr=0.001) 124 | 125 | 126 | @pytest.fixture(scope="function") 127 | def dataloader(): 128 
| torch.manual_seed(42) 129 | x_data = torch.randn(8, 5) 130 | y_data = torch.randint(low=0, high=2, size=(8,)) 131 | dataset = TensorDataset(x_data, y_data) 132 | return DataLoader(dataset, batch_size=4) 133 | 134 | 135 | @pytest.fixture(scope="function") 136 | def run_training(): 137 | def func(model, dataloader, optimizer): 138 | for x_from_data, y_from_data in dataloader: 139 | y_from_model = model(x_from_data) 140 | loss = F.cross_entropy(y_from_model, y_from_data) 141 | loss.backward() 142 | optimizer.step() 143 | optimizer.zero_grad() 144 | 145 | return func 146 | 147 | 148 | correct_model = changing_model 149 | nonan_model = changing_model 150 | noinf_model = changing_model 151 | unbounded_model = changing_model 152 | -------------------------------------------------------------------------------- /tests/test_disable_enable.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torcheck 3 | 4 | 5 | def test_disable( 6 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 7 | ): 8 | torcheck.register(unchanging_model_optimizer) 9 | torcheck.add_module_changing_check(unchanging_model, module_name="NeuralNet") 10 | torcheck.disable() 11 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 12 | 13 | 14 | def test_disable_enable( 15 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 16 | ): 17 | torcheck.register(unchanging_model_optimizer) 18 | torcheck.add_module_changing_check(unchanging_model, module_name="NeuralNet") 19 | torcheck.disable() 20 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 21 | torcheck.enable() 22 | with pytest.raises( 23 | RuntimeError, 24 | match=( 25 | r"Module NeuralNet's fc1\.weight should change\.\n" 26 | r".*fc1.bias should change" 27 | ), 28 | ): 29 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 30 | 
-------------------------------------------------------------------------------- /tests/test_output_check.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torcheck 3 | 4 | 5 | def test_module_output_range_check_with_bounded_model( 6 | bounded_model_optimizer, bounded_model, dataloader, run_training 7 | ): 8 | torcheck.add_module_output_range_check( 9 | bounded_model, output_range=(0, 1), module_name="NeuralNet" 10 | ) 11 | run_training(bounded_model, dataloader, bounded_model_optimizer) 12 | 13 | 14 | def test_module_output_range_check_with_unbounded_model( 15 | unbounded_model_optimizer, unbounded_model, dataloader, run_training 16 | ): 17 | torcheck.add_module_output_range_check( 18 | unbounded_model, output_range=(0, 1), module_name="NeuralNet" 19 | ) 20 | with pytest.raises( 21 | RuntimeError, match=r"Module NeuralNet's output should all > 0 and < 1" 22 | ): 23 | run_training(unbounded_model, dataloader, unbounded_model_optimizer) 24 | 25 | 26 | def test_module_output_negate_range_check_with_bounded_model( 27 | bounded_model_optimizer, bounded_model, dataloader, run_training 28 | ): 29 | torcheck.add_module_output_range_check( 30 | bounded_model, 31 | output_range=(0, 1), 32 | negate_range=True, 33 | module_name="NeuralNet", 34 | ) 35 | with pytest.raises( 36 | RuntimeError, 37 | match=r"Module NeuralNet's output shouldn't all > 0 and < 1", 38 | ): 39 | run_training(bounded_model, dataloader, bounded_model_optimizer) 40 | 41 | 42 | def test_module_output_negate_range_check_with_unbounded_model( 43 | unbounded_model_optimizer, unbounded_model, dataloader, run_training 44 | ): 45 | torcheck.add_module_output_range_check( 46 | unbounded_model, 47 | output_range=(0, 1), 48 | negate_range=True, 49 | module_name="NeuralNet", 50 | ) 51 | run_training(unbounded_model, dataloader, unbounded_model_optimizer) 52 | -------------------------------------------------------------------------------- 
/tests/test_param_and_output_check.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torcheck 3 | 4 | 5 | def test_module_nan_check_with_nan_model( 6 | nan_model_optimizer, nan_model, dataloader, run_training 7 | ): 8 | torcheck.register(nan_model_optimizer) 9 | torcheck.add_module_nan_check(nan_model, module_name="NeuralNet") 10 | with pytest.raises(RuntimeError, match=r"Module NeuralNet's output contains NaN"): 11 | run_training(nan_model, dataloader, nan_model_optimizer) 12 | 13 | 14 | def test_module_nan_check_with_nonan_model( 15 | nonan_model_optimizer, nonan_model, dataloader, run_training 16 | ): 17 | torcheck.register(nonan_model_optimizer) 18 | torcheck.add_module_nan_check(nonan_model, module_name="NeuralNet") 19 | run_training(nonan_model, dataloader, nonan_model_optimizer) 20 | 21 | 22 | def test_module_inf_check_with_inf_model( 23 | inf_model_optimizer, inf_model, dataloader, run_training 24 | ): 25 | torcheck.register(inf_model_optimizer) 26 | torcheck.add_module_inf_check(inf_model, module_name="NeuralNet") 27 | with pytest.raises(RuntimeError, match=r"Module NeuralNet's output contains inf"): 28 | run_training(inf_model, dataloader, inf_model_optimizer) 29 | 30 | 31 | def test_module_inf_check_with_noinf_model( 32 | noinf_model_optimizer, noinf_model, dataloader, run_training 33 | ): 34 | torcheck.register(noinf_model_optimizer) 35 | torcheck.add_module_inf_check(noinf_model, module_name="NeuralNet") 36 | run_training(noinf_model, dataloader, noinf_model_optimizer) 37 | 38 | 39 | def test_module_multiple_check_with_correct_model( 40 | correct_model_optimizer, correct_model, dataloader, run_training 41 | ): 42 | torcheck.register(correct_model_optimizer) 43 | torcheck.add_module( 44 | correct_model, 45 | module_name="NeuralNet", 46 | changing=True, 47 | output_range=(0, 1), 48 | negate_range=True, 49 | check_nan=True, 50 | check_inf=True, 51 | ) 52 | run_training(correct_model, 
dataloader, correct_model_optimizer) 53 | -------------------------------------------------------------------------------- /tests/test_param_check.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torcheck 3 | 4 | 5 | def test_module_changing_check_with_changing_model( 6 | changing_model_optimizer, changing_model, dataloader, run_training 7 | ): 8 | torcheck.register(changing_model_optimizer) 9 | torcheck.add_module_changing_check(changing_model, module_name="NeuralNet") 10 | run_training(changing_model, dataloader, changing_model_optimizer) 11 | 12 | 13 | def test_module_changing_check_with_unchanging_model( 14 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 15 | ): 16 | torcheck.register(unchanging_model_optimizer) 17 | torcheck.add_module_changing_check(unchanging_model, module_name="NeuralNet") 18 | with pytest.raises( 19 | RuntimeError, 20 | match=( 21 | r"Module NeuralNet's fc1\.weight should change\.\n" 22 | r".*fc1.bias should change" 23 | ), 24 | ): 25 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 26 | 27 | 28 | def test_module_unchanging_check_with_changing_model( 29 | changing_model_optimizer, changing_model, dataloader, run_training 30 | ): 31 | torcheck.register(changing_model_optimizer) 32 | torcheck.add_module_unchanging_check(changing_model, module_name="NeuralNet") 33 | with pytest.raises( 34 | RuntimeError, 35 | match=( 36 | r"Module NeuralNet's fc1\.weight should not change\." 
37 | r"(.|\n)*fc2\.weight should not change" 38 | ), 39 | ): 40 | run_training(changing_model, dataloader, changing_model_optimizer) 41 | 42 | 43 | def test_module_unchanging_check_with_unchanging_model( 44 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 45 | ): 46 | torcheck.register(unchanging_model_optimizer) 47 | torcheck.add_module_unchanging_check( 48 | unchanging_model.fc1, module_name="First Layer" 49 | ) 50 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 51 | 52 | 53 | def test_tensor_changing_check_with_changing_model( 54 | changing_model_optimizer, changing_model, dataloader, run_training 55 | ): 56 | torcheck.register(changing_model_optimizer) 57 | torcheck.add_tensor_changing_check( 58 | changing_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 59 | ) 60 | torcheck.add_tensor_changing_check( 61 | changing_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 62 | ) 63 | torcheck.add_tensor_changing_check( 64 | changing_model.fc2.weight, tensor_name="fc2.weight", module_name="NeuralNet" 65 | ) 66 | torcheck.add_tensor_changing_check( 67 | changing_model.fc2.bias, tensor_name="fc2.bias", module_name="NeuralNet" 68 | ) 69 | run_training(changing_model, dataloader, changing_model_optimizer) 70 | 71 | 72 | def test_tensor_changing_check_with_unchanging_model( 73 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 74 | ): 75 | torcheck.register(unchanging_model_optimizer) 76 | torcheck.add_tensor_changing_check( 77 | unchanging_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 78 | ) 79 | torcheck.add_tensor_changing_check( 80 | unchanging_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 81 | ) 82 | torcheck.add_tensor_changing_check( 83 | unchanging_model.fc2.weight, tensor_name="fc2.weight", module_name="NeuralNet" 84 | ) 85 | torcheck.add_tensor_changing_check( 86 | unchanging_model.fc2.bias, tensor_name="fc2.bias", 
module_name="NeuralNet" 87 | ) 88 | with pytest.raises( 89 | RuntimeError, 90 | match=( 91 | r"Module NeuralNet's fc1\.weight should change\.\n" 92 | r".*fc1.bias should change" 93 | ), 94 | ): 95 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 96 | 97 | 98 | def test_tensor_unchanging_check_with_changing_model( 99 | changing_model_optimizer, changing_model, dataloader, run_training 100 | ): 101 | torcheck.register(changing_model_optimizer) 102 | torcheck.add_tensor_unchanging_check( 103 | changing_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 104 | ) 105 | torcheck.add_tensor_unchanging_check( 106 | changing_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 107 | ) 108 | with pytest.raises( 109 | RuntimeError, 110 | match=( 111 | r"Module NeuralNet's fc1\.weight should not change\.\n" 112 | r".*fc1.bias should not change" 113 | ), 114 | ): 115 | run_training(changing_model, dataloader, changing_model_optimizer) 116 | 117 | 118 | def test_tensor_unchanging_check_with_unchanging_model( 119 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 120 | ): 121 | torcheck.register(unchanging_model_optimizer) 122 | torcheck.add_tensor_unchanging_check( 123 | unchanging_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 124 | ) 125 | torcheck.add_tensor_unchanging_check( 126 | unchanging_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 127 | ) 128 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 129 | 130 | 131 | def test_tensor_nan_check_with_nan_model( 132 | nan_model_optimizer, nan_model, dataloader, run_training 133 | ): 134 | torcheck.register(nan_model_optimizer) 135 | torcheck.add_tensor_nan_check( 136 | nan_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 137 | ) 138 | torcheck.add_tensor_nan_check( 139 | nan_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 140 | ) 141 | 
torcheck.add_tensor_nan_check( 142 | nan_model.fc2.weight, tensor_name="fc2.weight", module_name="NeuralNet" 143 | ) 144 | torcheck.add_tensor_nan_check( 145 | nan_model.fc2.bias, tensor_name="fc2.bias", module_name="NeuralNet" 146 | ) 147 | with pytest.raises( 148 | RuntimeError, 149 | match=( 150 | r"Module NeuralNet's fc1\.weight contains NaN\.\n" 151 | r".*fc1.bias contains NaN\.\n.*fc2.weight contains NaN\.\n" 152 | r".*fc2.bias contains NaN" 153 | ), 154 | ): 155 | run_training(nan_model, dataloader, nan_model_optimizer) 156 | 157 | 158 | def test_tensor_nan_check_with_nonan_model( 159 | nonan_model_optimizer, nonan_model, dataloader, run_training 160 | ): 161 | torcheck.register(nonan_model_optimizer) 162 | torcheck.add_tensor_nan_check( 163 | nonan_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 164 | ) 165 | torcheck.add_tensor_nan_check( 166 | nonan_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 167 | ) 168 | torcheck.add_tensor_nan_check( 169 | nonan_model.fc2.weight, tensor_name="fc2.weight", module_name="NeuralNet" 170 | ) 171 | torcheck.add_tensor_nan_check( 172 | nonan_model.fc2.bias, tensor_name="fc2.bias", module_name="NeuralNet" 173 | ) 174 | run_training(nonan_model, dataloader, nonan_model_optimizer) 175 | 176 | 177 | def _test_tensor_inf_check_with_inf_model( 178 | inf_model_optimizer, inf_model, dataloader, run_training 179 | ): 180 | """TODO: design a test case with inf gradient values""" 181 | torcheck.register(inf_model_optimizer) 182 | torcheck.add_tensor_inf_check( 183 | inf_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 184 | ) 185 | torcheck.add_tensor_inf_check( 186 | inf_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 187 | ) 188 | torcheck.add_tensor_inf_check( 189 | inf_model.fc2.weight, tensor_name="fc2.weight", module_name="NeuralNet" 190 | ) 191 | torcheck.add_tensor_inf_check( 192 | inf_model.fc2.bias, tensor_name="fc2.bias", module_name="NeuralNet" 193 | 
) 194 | with pytest.raises( 195 | RuntimeError, 196 | match=( 197 | r"Module NeuralNet's fc1\.weight contains inf\.\n" 198 | r".*fc1.bias contains inf\.\n.*fc2.weight contains inf\.\n" 199 | r".*fc2.bias contains inf" 200 | ), 201 | ): 202 | run_training(inf_model, dataloader, inf_model_optimizer) 203 | 204 | 205 | def test_tensor_inf_check_with_noinf_model( 206 | noinf_model_optimizer, noinf_model, dataloader, run_training 207 | ): 208 | torcheck.register(noinf_model_optimizer) 209 | torcheck.add_tensor_inf_check( 210 | noinf_model.fc1.weight, tensor_name="fc1.weight", module_name="NeuralNet" 211 | ) 212 | torcheck.add_tensor_inf_check( 213 | noinf_model.fc1.bias, tensor_name="fc1.bias", module_name="NeuralNet" 214 | ) 215 | torcheck.add_tensor_inf_check( 216 | noinf_model.fc2.weight, tensor_name="fc2.weight", module_name="NeuralNet" 217 | ) 218 | torcheck.add_tensor_inf_check( 219 | noinf_model.fc2.bias, tensor_name="fc2.bias", module_name="NeuralNet" 220 | ) 221 | run_training(noinf_model, dataloader, noinf_model_optimizer) 222 | 223 | 224 | def test_tensor_multiple_check_with_correct_model( 225 | correct_model_optimizer, correct_model, dataloader, run_training 226 | ): 227 | torcheck.register(correct_model_optimizer) 228 | torcheck.add_tensor( 229 | correct_model.fc1.weight, 230 | tensor_name="fc1.weight", 231 | module_name="NeuralNet", 232 | changing=True, 233 | check_nan=True, 234 | check_inf=True, 235 | ) 236 | torcheck.add_tensor( 237 | correct_model.fc1.bias, 238 | tensor_name="fc1.bias", 239 | module_name="NeuralNet", 240 | changing=True, 241 | check_nan=True, 242 | check_inf=True, 243 | ) 244 | run_training(correct_model, dataloader, correct_model_optimizer) 245 | -------------------------------------------------------------------------------- /tests/test_verbose.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torcheck 3 | 4 | 5 | def test_verbose_on( 6 | unchanging_model_optimizer, 
unchanging_model, dataloader, run_training 7 | ): 8 | torcheck.verbose_on() 9 | torcheck.register(unchanging_model_optimizer) 10 | torcheck.add_module_changing_check(unchanging_model, module_name="NeuralNet") 11 | with pytest.raises( 12 | RuntimeError, 13 | match=( 14 | r"Module NeuralNet's fc1\.weight should change\.\n" 15 | r"The tensor is:(.|\n)*" 16 | r"fc1\.bias should change\.\n" 17 | r"The tensor is:(.|\n)*" 18 | ), 19 | ): 20 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 21 | 22 | 23 | def test_verbose_off( 24 | unchanging_model_optimizer, unchanging_model, dataloader, run_training 25 | ): 26 | torcheck.register(unchanging_model_optimizer) 27 | torcheck.add_module_changing_check(unchanging_model, module_name="NeuralNet") 28 | torcheck.verbose_on() 29 | with pytest.raises(RuntimeError): 30 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 31 | torcheck.verbose_off() 32 | with pytest.raises( 33 | RuntimeError, 34 | match=( 35 | r"Module NeuralNet's fc1\.weight should change\.\n" 36 | r".*fc1.bias should change" 37 | ), 38 | ): 39 | run_training(unchanging_model, dataloader, unchanging_model_optimizer) 40 | -------------------------------------------------------------------------------- /torcheck/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from .registry import Registry 4 | from .utils import verbose_on, verbose_off, is_verbose 5 | 6 | __version__ = importlib.metadata.version(__name__) 7 | 8 | registry = Registry() 9 | 10 | register = registry.register 11 | add_tensor = registry.add_tensor 12 | add_tensor_changing_check = registry.add_tensor_changing_check 13 | add_tensor_unchanging_check = registry.add_tensor_unchanging_check 14 | add_tensor_nan_check = registry.add_tensor_nan_check 15 | add_tensor_inf_check = registry.add_tensor_inf_check 16 | add_module = registry.add_module 17 | add_module_changing_check = 
registry.add_module_changing_check 18 | add_module_unchanging_check = registry.add_module_unchanging_check 19 | add_module_output_range_check = registry.add_module_output_range_check 20 | add_module_nan_check = registry.add_module_nan_check 21 | add_module_inf_check = registry.add_module_inf_check 22 | 23 | disable_optimizers = registry.disable_optimizers 24 | disable_modules = registry.disable_modules 25 | disable = registry.disable 26 | enable_optimizers = registry.enable_optimizers 27 | enable_modules = registry.enable_modules 28 | enable = registry.enable 29 | -------------------------------------------------------------------------------- /torcheck/output_spec.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | import warnings 4 | 5 | import torch 6 | 7 | from .utils import message_utils 8 | 9 | 10 | @dataclass 11 | class OutputSpec: 12 | module_name: str = None 13 | range: Union[list, tuple] = None 14 | negate: bool = False 15 | check_nan: bool = False 16 | check_inf: bool = False 17 | 18 | @property 19 | def name(self): 20 | if self.module_name is None: 21 | return "Module's output" 22 | else: 23 | return f"Module {self.module_name}'s output" 24 | 25 | @property 26 | def condition(self): 27 | low, high = self.range 28 | if low is None: 29 | return f"< {high}" 30 | elif high is None: 31 | return f"> {low}" 32 | else: 33 | return f"> {low} and < {high}" 34 | 35 | def update( 36 | self, 37 | module_name=None, 38 | range=None, 39 | negate=False, 40 | check_nan=False, 41 | check_inf=False, 42 | ): 43 | if module_name is not None and module_name != self.module_name: 44 | old_name = self.name 45 | self.module_name = module_name 46 | warnings.warn(f"{old_name} is renamed as {self.name}.") 47 | if range is not None: 48 | self.range = range 49 | self.negate = negate 50 | if check_nan: 51 | self.check_nan = True 52 | if check_inf: 53 | self.check_inf = True 54 | 55 | 
# NOTE: methods of OutputSpec (torcheck/output_spec.py); the enclosing
# @dataclass header lies on an earlier line of this dump.
def validate(self, output):
    """Run every enabled check against *output*; raise RuntimeError on failure."""
    findings = []
    if self.range is not None:
        findings.append(self.validate_range(output))
    if self.check_nan:
        findings.append(self.validate_nan(output))
    if self.check_inf:
        findings.append(self.validate_inf(output))
    # Helpers return None on success; keep only actual failures.
    findings = [item for item in findings if item is not None]
    if findings:
        raise RuntimeError(message_utils.make_message(findings, output))

def validate_range(self, output):
    """Check that output values lie inside (or, when negated, not all inside) self.range."""
    lower, upper = self.range
    # Open bounds (None) are treated as always satisfied.
    in_range = torch.ones_like(output, dtype=torch.bool)
    if lower is not None:
        in_range = output >= lower
    if upper is not None:
        in_range = in_range & (output <= upper)
    all_inside = torch.all(in_range).item()
    if self.negate:
        if all_inside:
            return f"{self.name} shouldn't all {self.condition}"
    elif not all_inside:
        return f"{self.name} should all {self.condition}. Some are out of range"

def validate_nan(self, output):
    """Return an error string if any element of *output* is NaN."""
    if torch.isnan(output).any().item():
        return f"{self.name} contains NaN."

def validate_inf(self, output):
    """Return an error string if any element of *output* is infinite."""
    if torch.isinf(output).any().item():
        return f"{self.name} contains inf."
# ---- torcheck/param_spec.py ----
from dataclasses import dataclass, field
import warnings

import torch

from .utils import message_utils


@dataclass
class SpecItem:
    """Checks registered for one tensor (typically a model parameter)."""

    tensor: torch.Tensor
    tensor_name: str
    module_name: str = None
    # True: tensor must change between steps; False: must stay fixed; None: no check.
    changing: bool = None
    check_nan: bool = False
    check_inf: bool = False
    _old_copy: torch.Tensor = field(init=False, default=None)

    def __post_init__(self):
        # Snapshot the tensor so change / no-change can be detected later.
        if self.changing is not None:
            self._old_copy = self.tensor.detach().clone()

    @property
    def name(self):
        """Human-readable identifier used in error messages."""
        if self.module_name is None:
            return self.tensor_name
        return f"Module {self.module_name}'s {self.tensor_name}"

    def update(
        self,
        tensor_name,
        module_name=None,
        changing=None,
        check_nan=False,
        check_inf=False,
    ):
        """Merge new check settings into this spec, warning when it is renamed."""
        renamed = tensor_name != self.tensor_name or (
            module_name is not None and module_name != self.module_name
        )
        if renamed:
            previous = self.name
            self.tensor_name = tensor_name
            if module_name is not None:
                self.module_name = module_name
            warnings.warn(f"{previous} is renamed as {self.name}")
        if changing is not None:
            self.changing = changing
        # nan/inf checks are sticky: once requested they stay on.
        self.check_nan = self.check_nan or check_nan
        self.check_inf = self.check_inf or check_inf

    def validate(self):
        """Run all enabled checks; return "" on success or a combined message."""
        problems = []
        if self.changing is not None:
            problems.append(self.validate_changing())
        if self.check_nan:
            problems.append(self.validate_nan())
        if self.check_inf:
            problems.append(self.validate_inf())
        problems = [p for p in problems if p is not None]
        return message_utils.make_message(problems, self.tensor)

    def validate_changing(self):
        """Compare the tensor against its snapshot; report a violation if any."""
        if self.changing:
            if torch.equal(self.tensor, self._old_copy):
                return f"{self.name} should change."
        elif not torch.equal(self.tensor, self._old_copy):
            return f"{self.name} should not change."
        # Refresh the snapshot only when the check passed (early returns skip this).
        self._old_copy = self.tensor.detach().clone()

    def validate_nan(self):
        """Return an error string if the tensor contains any NaN."""
        if torch.isnan(self.tensor).any().item():
            return f"{self.name} contains NaN."

    def validate_inf(self):
        """Return an error string if the tensor contains any infinity."""
        if torch.isinf(self.tensor).any().item():
            return f"{self.name} contains inf."


@dataclass
class ParamSpec:
    """Collection of SpecItems, keyed by the tensor object itself."""

    specs: dict = field(default_factory=dict)

    def add(
        self,
        tensor,
        tensor_name,
        module_name=None,
        changing=None,
        check_nan=False,
        check_inf=False,
    ):
        """Register checks for *tensor*, merging with an existing spec if present."""
        existing = self.specs.get(tensor)
        if existing is not None:
            existing.update(
                tensor_name=tensor_name,
                module_name=module_name,
                changing=changing,
                check_nan=check_nan,
                check_inf=check_inf,
            )
        else:
            self.specs[tensor] = SpecItem(
                tensor=tensor,
                tensor_name=tensor_name,
                module_name=module_name,
                changing=changing,
                check_nan=check_nan,
                check_inf=check_inf,
            )

    def validate(self):
        """Validate every spec; raise a RuntimeError aggregating all failures."""
        failures = []
        for spec in self.specs.values():
            message = spec.validate()
            if len(message) > 0:
                failures.append(message)
        if failures:
            joined = "\n".join(failures)
            raise RuntimeError(
                f"The following errors are detected while training:\n{joined}"
            )
# ---- torcheck/registry.py ----
from dataclasses import dataclass, field
from functools import singledispatchmethod

import torch
import torch.nn as nn

from .param_spec import ParamSpec
from .output_spec import OutputSpec


@dataclass
class Registry:
    """Central registry wiring torcheck's checks into optimizers and modules."""

    optimizer_to_spec: dict = field(default_factory=dict, init=False)
    tensor_to_optimizer: dict = field(default_factory=dict, init=False)
    active_optimizers: set = field(default_factory=set, init=False)
    module_to_spec: dict = field(default_factory=dict, init=False)
    active_modules: set = field(default_factory=set, init=False)

    @singledispatchmethod
    def _run_check(self, component):
        # Fallback for unknown component types: no wrapping.
        pass

    @_run_check.register
    def _(self, optimizer: torch.optim.Optimizer):
        # Build a decorator that validates parameter specs after each step().
        def wrap(step_fn):
            def checked_step(*args, **kwargs):
                result = step_fn(*args, **kwargs)
                if optimizer in self.active_optimizers:
                    self.optimizer_to_spec[optimizer].validate()
                return result

            return checked_step

        return wrap

    @_run_check.register
    def _(self, module: nn.Module):
        # Build a decorator that validates the output after each forward().
        def wrap(forward_fn):
            def checked_forward(*args, **kwargs):
                result = forward_fn(*args, **kwargs)
                if module in self.active_modules:
                    self.module_to_spec[module].validate(result)
                return result

            return checked_forward

        return wrap

    def register(self, optimizer):
        """Hook *optimizer* so its step() triggers validation of added checks."""
        if optimizer in self.optimizer_to_spec:
            raise RuntimeError("The optimizer has already been registered.")
        self.optimizer_to_spec[optimizer] = ParamSpec()
        # Monkey-patch step() with the validating wrapper.
        optimizer.step = self._run_check(optimizer)(optimizer.step)
        for group in optimizer.param_groups:
            for param in group["params"]:
                self.tensor_to_optimizer[param] = optimizer
        self.active_optimizers.add(optimizer)

    def add_tensor(
        self,
        tensor,
        tensor_name,
        module_name=None,
        changing=None,
        check_nan=False,
        check_inf=False,
    ):
        """Attach checks to a tensor whose optimizer was registered earlier."""
        owner = self.tensor_to_optimizer.get(tensor, None)
        if owner is None:
            raise RuntimeError(
                "The tensor doesn't belong to any optimizer. "
                "Please register its optimizer first."
            )
        self.optimizer_to_spec[owner].add(
            tensor=tensor,
            tensor_name=tensor_name,
            module_name=module_name,
            changing=changing,
            check_nan=check_nan,
            check_inf=check_inf,
        )

    def _add_param_check(
        self, module, module_name=None, changing=None, check_nan=False, check_inf=False
    ):
        """Register parameter-level checks for every parameter of *module*."""
        if not isinstance(module, nn.Module):
            raise RuntimeError(
                f"Module should be nn.Module type, but is {type(module)}."
            )
        for param_name, param in module.named_parameters():
            self.add_tensor(
                tensor=param,
                tensor_name=param_name,
                module_name=module_name,
                changing=changing,
                check_nan=check_nan,
                check_inf=check_inf,
            )

    def _add_output_check(
        self,
        module,
        module_name=None,
        output_range=None,
        negate_range=False,
        check_nan=False,
        check_inf=False,
    ):
        """Register output-level checks, hooking module.forward on first use."""
        if not isinstance(module, nn.Module):
            raise RuntimeError(
                f"Module should be nn.Module type, but is {type(module)}."
            )
        spec = self.module_to_spec.get(module)
        if spec is not None:
            spec.update(
                module_name=module_name,
                range=output_range,
                negate=negate_range,
                check_nan=check_nan,
                check_inf=check_inf,
            )
        else:
            self.module_to_spec[module] = OutputSpec(
                module_name=module_name,
                range=output_range,
                negate=negate_range,
                check_nan=check_nan,
                check_inf=check_inf,
            )
            self.active_modules.add(module)
            # Patch forward() only once, when the module is first seen.
            module.forward = self._run_check(module)(module.forward)

    def add_module(
        self,
        module,
        module_name=None,
        changing=None,
        output_range=None,
        negate_range=False,
        check_nan=False,
        check_inf=False,
    ):
        """Register any combination of parameter and output checks for *module*."""
        if (changing is not None) or check_nan or check_inf:
            self._add_param_check(
                module=module,
                module_name=module_name,
                changing=changing,
                check_nan=check_nan,
                check_inf=check_inf,
            )
        if (output_range is not None) or check_nan or check_inf:
            self._add_output_check(
                module=module,
                module_name=module_name,
                output_range=output_range,
                negate_range=negate_range,
                check_nan=check_nan,
                check_inf=check_inf,
            )

    # -- thin convenience wrappers around add_tensor / the private helpers --

    def add_tensor_changing_check(self, tensor, tensor_name, module_name=None):
        self.add_tensor(
            tensor=tensor,
            tensor_name=tensor_name,
            module_name=module_name,
            changing=True,
        )

    def add_tensor_unchanging_check(self, tensor, tensor_name, module_name=None):
        self.add_tensor(
            tensor=tensor,
            tensor_name=tensor_name,
            module_name=module_name,
            changing=False,
        )

    def add_tensor_nan_check(self, tensor, tensor_name, module_name=None):
        self.add_tensor(
            tensor=tensor,
            tensor_name=tensor_name,
            module_name=module_name,
            check_nan=True,
        )

    def add_tensor_inf_check(self, tensor, tensor_name, module_name=None):
        self.add_tensor(
            tensor=tensor,
            tensor_name=tensor_name,
            module_name=module_name,
            check_inf=True,
        )

    def add_module_changing_check(self, module, module_name=None):
        self._add_param_check(module, module_name=module_name, changing=True)

    def add_module_unchanging_check(self, module, module_name=None):
        self._add_param_check(module, module_name=module_name, changing=False)

    def add_module_output_range_check(
        self, module, output_range, negate_range=False, module_name=None
    ):
        self._add_output_check(
            module,
            output_range=output_range,
            negate_range=negate_range,
            module_name=module_name,
        )

    def add_module_nan_check(self, module, module_name=None):
        self.add_module(module, module_name=module_name, check_nan=True)

    def add_module_inf_check(self, module, module_name=None):
        self.add_module(module, module_name=module_name, check_inf=True)

    def disable_optimizers(self, *optimizers):
        """Stop running checks for the given optimizers."""
        for optimizer in optimizers:
            self.active_optimizers.remove(optimizer)

    def disable_modules(self, *modules):
        """Stop running output checks for the given modules."""
        for module in modules:
            self.active_modules.remove(module)

    def disable(self, optimizers=None, modules=None):
        """Disable the given (default: all currently active) optimizers and modules."""
        self.disable_optimizers(
            *(self.active_optimizers if optimizers is None else optimizers)
        )
        self.disable_modules(
            *(self.active_modules if modules is None else modules)
        )

    def enable_optimizers(self, *optimizers):
        """Resume running checks for the given optimizers."""
        for optimizer in optimizers:
            self.active_optimizers.add(optimizer)

# (Registry, continued — torcheck/registry.py)
def enable_modules(self, *modules):
    """Resume running output checks for the given modules."""
    for module in modules:
        self.active_modules.add(module)

def enable(self, optimizers=None, modules=None):
    """Enable the given (default: every registered) optimizers and modules."""
    self.enable_optimizers(
        *(self.optimizer_to_spec.keys() if optimizers is None else optimizers)
    )
    self.enable_modules(
        *(self.module_to_spec.keys() if modules is None else modules)
    )

# ---- torcheck/utils/__init__.py ----
from .message_utils import verbose_on, verbose_off, is_verbose

# ---- torcheck/utils/message_utils.py ----
# Module-level flag: when True, error messages include the offending tensor.
_is_verbose = False


def verbose_on():
    """Include the offending tensor's contents in subsequent error messages."""
    global _is_verbose
    _is_verbose = True


def verbose_off():
    """Report only the error summary, without tensor contents."""
    global _is_verbose
    _is_verbose = False


def is_verbose():
    """Return the current verbosity setting."""
    return _is_verbose


def make_message(error_items, tensor):
    """Join error fragments with spaces; append the tensor itself when verbose."""
    if not error_items:
        return ""
    message = " ".join(error_items)
    if is_verbose():
        message += f"\nThe tensor is:\n{tensor}\n"
    return message