├── snoop.png ├── tests ├── __init__.py ├── test_snoop.py ├── mini_toolbox │ ├── __init__.py │ ├── contextlib.py │ └── pathlib.py ├── utils.py └── test_torchsnooper.py ├── .gitignore ├── setup.cfg ├── .github ├── dependabot.yml └── workflows │ ├── deploy-pypi.yml │ ├── deploy-test-pypi.yml │ └── tests.yml ├── setup.py ├── LICENSE ├── torchsnooper └── __init__.py └── README.md /snoop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zasdfgbnm/TorchSnooper/HEAD/snoop.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest.register_assert_rewrite('tests.utils') 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .eggs 3 | *.egg-info 4 | .pytest_cache 5 | build 6 | dist 7 | /test.py 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501 3 | exclude = 4 | .git, 5 | __pycache__, 6 | build, 7 | .eggs, 8 | tests/utils.py -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "13:00" 8 | open-pull-requests-limit: 10 9 | -------------------------------------------------------------------------------- /.github/workflows/deploy-pypi.yml: -------------------------------------------------------------------------------- 1 | name: deploy-pypi 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | 
# Copyright 2019 Xiang Gao and collaborators.
# This program is distributed under the MIT license.
import setuptools


# Decode the README explicitly as UTF-8: without `encoding`, open() falls back
# to the locale's preferred encoding, which can mis-decode the file or raise
# UnicodeDecodeError on systems whose default is not UTF-8.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# Package metadata. The version is derived from git tags via setuptools_scm
# (`use_scm_version=True`), so no version string is hard-coded here.
setuptools.setup(
    name='TorchSnooper',
    author='Xiang Gao',
    author_email='qasdfgtyuiop@gmail.com',
    description="Debug PyTorch code using PySnooper.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url='https://github.com/zasdfgbnm/TorchSnooper',
    packages=setuptools.find_packages(exclude=['tests']),
    use_scm_version=True,
    setup_requires=['setuptools_scm'],
    install_requires=[
        'pysnooper>=0.1.0',
        'numpy',
    ],
    tests_require=[
        'pytest',
        'torch',
        'python-toolbox',
        'coverage',
        'snoop',
    ],
)
pull_request: 5 | push: 6 | branches: 7 | - master 8 | schedule: 9 | - cron: '0 0 * * *' 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | max-parallel: 4 17 | matrix: 18 | python-version: [3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v1 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install twine wheel 30 | - name: Deploy 31 | run: | 32 | rm -rf dist/* 33 | git tag $(date +'v%Y.%m.%d.%H.%M.%S') 34 | python setup.py sdist bdist_wheel 35 | twine upload --repository-url https://test.pypi.org/legacy/ -u zasdfgbnm-bot -p ${{secrets.zasdfgbnm_bot_test_pypi_password}} dist/* 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018- Xiang Gao and other contributors 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | schedule: 9 | - cron: '0 0 * * *' 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | max-parallel: 4 17 | matrix: 18 | python-version: [3.6, 3.7, 3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v1 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | pip install --upgrade pip 29 | pip install --upgrade numpy setuptools wheel six 30 | pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 31 | pip install . 32 | - name: Lint with flake8 33 | run: | 34 | pip install flake8 35 | flake8 . 
--count --show-source --statistics 36 | - name: Test with pytest 37 | run: | 38 | python setup.py test 39 | -------------------------------------------------------------------------------- /tests/test_snoop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | import math 4 | import sys 5 | import torchsnooper 6 | from python_toolbox import sys_tools 7 | import re 8 | import snoop 9 | import copy 10 | 11 | 12 | ansi_escape = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]') 13 | default_config = copy.copy(snoop.config) 14 | 15 | 16 | def func(): 17 | x = torch.tensor(math.inf) 18 | x = torch.tensor(math.nan) 19 | x = torch.tensor(1.0, requires_grad=True) 20 | x = torch.tensor([1.0, math.nan, math.inf]) 21 | x = numpy.zeros((2, 2)) 22 | x = (x, x) 23 | 24 | 25 | verbose_expect = ''' 26 | 01:24:31.56 >>> Call to func in File "test_snoop.py", line 16 27 | 01:24:31.56 16 | def func(): 28 | 01:24:31.56 17 | x = torch.tensor(math.inf) 29 | 01:24:31.56 .......... x = tensor<(), float32, cpu, has_inf> 30 | 01:24:31.56 .......... x.data = tensor(inf) 31 | 01:24:31.56 18 | x = torch.tensor(math.nan) 32 | 01:24:31.56 .......... x = tensor<(), float32, cpu, has_nan> 33 | 01:24:31.56 .......... x.data = tensor(nan) 34 | 01:24:31.56 19 | x = torch.tensor(1.0, requires_grad=True) 35 | 01:24:31.56 .......... x = tensor<(), float32, cpu, grad> 36 | 01:24:31.56 .......... x.data = tensor(1.) 37 | 01:24:31.56 20 | x = torch.tensor([1.0, math.nan, math.inf]) 38 | 01:24:31.56 .......... x = tensor<(3,), float32, cpu, has_nan, has_inf> 39 | 01:24:31.56 .......... x.data = tensor([1., nan, inf]) 40 | 01:24:31.56 21 | x = numpy.zeros((2, 2)) 41 | 01:24:31.56 .......... x = ndarray<(2, 2), float64> 42 | 01:24:31.56 .......... x.data = 43 | 01:24:31.56 22 | x = (x, x) 44 | 01:24:31.56 .......... 
def assert_output(verbose, expect):
    """Trace ``func`` with snoop and compare the captured output to *expect*.

    Args:
        verbose: Forwarded to ``torchsnooper.register_snoop``; selects whether
            the extra ``<name>.data`` lines are expected in the trace.
        expect: The expected trace text. Both the captured and the expected
            text go through ``clean_output`` before being compared.

    Raises:
        AssertionError: If a trace function is still installed before or after
            the traced call, or if the cleaned outputs differ.
    """
    torchsnooper.register_snoop(verbose=verbose)
    try:
        with sys_tools.OutputCapturer(stdout=False, stderr=True) as output_capturer:
            assert sys.gettrace() is None
            snoop(func)()
            assert sys.gettrace() is None
        output = output_capturer.string_io.getvalue()
        # Strip ANSI colour escape sequences so only the plain text is compared.
        output = ansi_escape.sub('', output)
        assert clean_output(output) == clean_output(expect)
    finally:
        # Restore the configuration even when an assertion above fails, so one
        # failing test cannot leak its snoop settings into the next one.
        # A fresh copy is installed each time: `register_snoop` rebinds
        # attributes on `snoop.config`, so handing out `default_config` itself
        # would let the next call modify the pristine snapshot.
        snoop.config = copy.copy(default_config)
class TensorFormat:
    """Callable that renders a tensor as a one-line ``tensor<...>`` summary.

    Each entry of ``properties`` names a ``repr_<property>`` method of this
    class. The strings those methods return are joined with ``', '``; empty
    strings are skipped so that "uninteresting" properties leave no trace.

    Args:
        property_name: When true, every property is rendered as
            ``name=value`` and boolean-like properties are spelled out
            (``requires_grad=False``) instead of being omitted.
        properties: Names of the properties to render, in output order.
    """

    def __init__(self, property_name=False, properties=('shape', 'dtype', 'device', 'requires_grad', 'has_nan', 'has_inf', 'memory_format')):
        self.properties = properties
        self.properties_name = property_name

    def repr_shape(self, tensor):
        """Return the shape as a tuple string, using ``name=size`` for named dims."""
        if not hasattr(tensor, 'names') or not tensor.has_names():
            return str(tuple(tensor.shape))
        dims = [
            '{}'.format(size) if name is None else '{}={}'.format(name, size)
            for name, size in zip(tensor.names, tensor.shape)
        ]
        # Joining (rather than trimming a trailing ', ') keeps the parentheses
        # balanced even if there are no dimensions at all.
        return '(' + ', '.join(dims) + ')'

    def repr_dtype(self, tensor):
        """Return the dtype name with the leading ``torch.`` removed."""
        return str(tensor.dtype)[len('torch.'):]

    def repr_device(self, tensor):
        """Return ``str(tensor.device)``."""
        return str(tensor.device)

    def _repr_flag(self, value, label):
        """Render a boolean: ``'True'``/``'False'`` in named mode, else *label* or ``''``."""
        if self.properties_name:
            return str(value)
        return label if value else ''

    def repr_requires_grad(self, tensor):
        """Return ``'grad'`` when the tensor requires grad (bool text in named mode)."""
        return self._repr_flag(tensor.requires_grad, 'grad')

    def repr_has_nan(self, tensor):
        """Return ``'has_nan'`` when a floating/complex tensor contains a NaN."""
        # Only dtypes listed in FLOATING_POINTS are inspected; anything else is
        # reported as not having NaN without calling torch.isnan.
        result = tensor.dtype in FLOATING_POINTS and bool(torch.isnan(tensor).any())
        return self._repr_flag(result, 'has_nan')

    def repr_has_inf(self, tensor):
        """Return ``'has_inf'`` when a floating/complex tensor contains an infinity."""
        result = tensor.dtype in FLOATING_POINTS and bool(torch.isinf(tensor).any())
        return self._repr_flag(result, 'has_inf')

    @staticmethod
    def _is_contiguous_in(tensor, format_name):
        """Check contiguity in ``torch.<format_name>``; False if torch lacks that format."""
        memory_format = getattr(torch, format_name, None)
        if memory_format is None:
            # This PyTorch build does not define the memory format, so no
            # tensor can be laid out in it.
            return False
        return tensor.is_contiguous(memory_format=memory_format)

    def repr_memory_format(self, tensor):
        """Describe the memory layout of *tensor*.

        Returns ``''`` for an ordinary contiguous tensor (or
        ``'torch.contiguous'`` in named mode), ``'channels_last'`` /
        ``'channels_last_3d'`` for those layouts (``torch.``-prefixed in named
        mode), and ``'discontiguous'`` otherwise.
        """
        if hasattr(torch, 'contiguous_format'):
            ctg = tensor.is_contiguous(memory_format=torch.contiguous_format)
        else:
            ctg = tensor.is_contiguous()
        # The channels-last formats are looked up defensively so that a
        # PyTorch build without them does not raise AttributeError here.
        cl = self._is_contiguous_in(tensor, 'channels_last')
        cl3d = self._is_contiguous_in(tensor, 'channels_last_3d')
        if ctg:
            mf = 'contiguous'
        elif cl:
            mf = 'channels_last'
        elif cl3d:
            mf = 'channels_last_3d'
        else:
            mf = 'discontiguous'
        if self.properties_name:
            if ctg or cl or cl3d:
                mf = 'torch.' + mf
            return mf
        if ctg:
            return ''
        return mf

    def __call__(self, tensor):
        """Return the ``tensor<...>`` summary string for *tensor*.

        Raises:
            ValueError: If ``properties`` contains a name for which this class
                has no ``repr_<name>`` method.
        """
        parts = []
        for p in self.properties:
            formatter = getattr(self, 'repr_' + p, None)
            if formatter is None:
                # Name the offending property so a typo is easy to spot.
                raise ValueError('Unknown tensor property: {!r}'.format(p))
            text = formatter(tensor)
            if self.properties_name:
                text = p + '=' + text
            parts.append(text)
        # Properties that rendered to '' are dropped rather than leaving
        # dangling separators.
        return 'tensor<' + ', '.join(text for text in parts if text) + '>'
def return_types_repr(self, x):
    """Format a ``torch.return_types`` result as ``name(field=<tensor summary>, ...)``.

    Args:
        x: An object whose class name identifies which named fields it has
            (for example the result of ``torch.max(t, dim)``).

    Returns:
        A string such as ``max(values=tensor<...>, indices=tensor<...>)``.
        For a class name that is not in the table below, a warning is issued
        and a generic string is returned instead of ``None``, so callers that
        treat the result as text keep working.
    """
    # Field names, in output order, for every return type known to this method.
    known_fields = {
        'max': ('values', 'indices'),
        'min': ('values', 'indices'),
        'median': ('values', 'indices'),
        'mode': ('values', 'indices'),
        'sort': ('values', 'indices'),
        'topk': ('values', 'indices'),
        'kthvalue': ('values', 'indices'),
        'svd': ('U', 'S', 'V'),
        'slogdet': ('sign', 'logabsdet'),
        'qr': ('Q', 'R'),
        'solve': ('solution', 'LU'),
        'geqrf': ('a', 'tau'),
        'symeig': ('eigenvalues', 'eigenvectors'),
        'eig': ('eigenvalues', 'eigenvectors'),
        'triangular_solve': ('solution', 'cloned_coefficient'),
        'gels': ('solution', 'QR'),
    }
    name = type(x).__name__
    fields = known_fields.get(name)
    if fields is not None:
        rendered = ', '.join(
            '{}={}'.format(field, self.tensor_format(getattr(x, field)))
            for field in fields
        )
        return name + '(' + rendered + ')'

    warnings.warn('Unknown return_types encountered, open a bug report!')
    # Fallback: still return a string. Tuple-like results are rendered
    # positionally, summarising tensors the same way as the known types.
    if isinstance(x, tuple):
        rendered = ', '.join(
            self.tensor_format(item) if torch.is_tensor(item) else repr(item)
            for item in x
        )
        return name + '(' + rendered + ')'
    return repr(x)
def register_snoop(verbose=False, tensor_format=default_format, numpy_format=default_numpy_format):
    """Configure the ``snoop`` package to print tensors and arrays as summaries.

    Registers ``cheap_repr`` representations for ``torch.Tensor`` and
    ``numpy.ndarray`` and adjusts ``snoop.config.watch_extras``.

    Args:
        verbose: When true, an additional watch is appended that reports
            ``<source>.data`` for watched values, printed with the value's own
            ``__repr__`` rather than the compact summary.
        tensor_format: Callable mapping a tensor to its summary string.
        numpy_format: Callable mapping an ndarray to its summary string.
    """
    # Imported inside the function, so `snoop` and `cheap_repr` are only
    # needed when this function is actually called.
    import snoop
    import cheap_repr
    import snoop.configuration
    # The registered callables receive two arguments; the second is ignored.
    cheap_repr.register_repr(torch.Tensor)(lambda x, _: tensor_format(x))
    cheap_repr.register_repr(numpy.ndarray)(lambda x, _: numpy_format(x))
    # NOTE(review): the result is discarded; presumably this call exercises the
    # freshly registered tensor repr once up front — confirm the intent.
    cheap_repr.cheap_repr(torch.zeros(6))
    # Drop snoop's own length/shape and dtype watches from the configuration.
    unwanted = {
        snoop.configuration.len_shape_watch,
        snoop.configuration.dtype_watch,
    }
    snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted)
    if verbose:

        class TensorWrap:
            # Wraps a value so that its repr is the wrapped object's own
            # `__repr__()` output; the wrapper is not a torch.Tensor, so the
            # summary repr registered above is not applied to it.

            def __init__(self, tensor):
                self.tensor = tensor

            def __repr__(self):
                return self.tensor.__repr__()

        # Extra watch: for a watched `source`/`value`, report `<source>.data`
        # with the wrapped `value.data`.
        snoop.config.watch_extras += (
            lambda source, value: ('{}.data'.format(source), TensorWrap(value.data)),
        )
class TempValueSetter(object):
    '''
    Context manager for temporarily setting a value to a variable.

    The value is set to the variable before the suite starts, and gets reset
    back to the old value after the suite finishes. The old value is restored
    even when the "no fiddling" check fails.
    '''

    def __init__(self, variable, value, assert_no_fiddling=True):
        '''
        Construct the `TempValueSetter`.

        `variable` may be either an `(object, attribute_string)`, a `(dict,
        key)` pair, or a `(getter, setter)` pair.

        `value` is the temporary value to set to the variable.

        If `assert_no_fiddling` is true, `__exit__` asserts that the variable
        still equals the value it had right after it was set.

        Raises `Exception` when `variable` is none of the supported pairs.
        '''
        self.assert_no_fiddling = assert_no_fiddling

        # The user may pass any of three kinds of pair; inspect `variable` to
        # work out which one it is and derive a `(getter, setter)` pair.
        bad_input_exception = Exception(
            '`variable` must be either an `(object, attribute_string)` pair, '
            'a `(dict, key)` pair, or a `(getter, setter)` pair.'
        )

        try:
            first, second = variable
        except Exception:
            raise bad_input_exception

        if hasattr(first, '__getitem__') and hasattr(first, 'get') and \
           hasattr(first, '__setitem__') and hasattr(first, '__delitem__'):
            # `first` is a dictoid; so we were probably handed a `(dict, key)`
            # pair. `NotInDict` stands for "the key was absent": getting
            # returns it for a missing key, and setting it deletes the key.
            self.getter = lambda: first.get(second, NotInDict)
            self.setter = lambda value: (first.__setitem__(second, value) if
                                         value is not NotInDict else
                                         first.__delitem__(second))
        elif callable(second):
            # `second` is a callable; so we were probably handed a `(getter,
            # setter)` pair.
            if not callable(first):
                raise bad_input_exception
            self.getter, self.setter = first, second
        else:
            # All that's left is the `(object, attribute_string)` case.
            if not isinstance(second, str):
                raise bad_input_exception
            parent, attribute_name = first, second
            self.getter = lambda: getattr(parent, attribute_name)
            self.setter = lambda value: setattr(parent, attribute_name, value)

        # `self.getter`: returns the current value of the variable.
        # `self.setter`: sets the variable's value.

        # The value to temporarily set to the variable.
        self.value = value

        # Whether we are currently inside the `with` suite.
        self.active = False

    def __enter__(self):
        self.active = True

        # The old value of the variable, before entering the suite.
        self.old_value = self.getter()

        self.setter(self.value)

        # In `__exit__` we'll want to check if anyone changed the value of the
        # variable in the suite, which is unallowed. But we can't compare to
        # `.value`, because sometimes when you set a value to a variable, some
        # mechanism modifies that value for various reasons, resulting in a
        # supposedly equivalent, but not identical, value. For example this
        # happens when you set the current working directory on Mac OS.
        #
        # So here we record the value right after setting, and after any
        # possible processing the system did to it:
        self._value_right_after_setting = self.getter()

        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        try:
            if self.assert_no_fiddling:
                # Asserting no-one inside the suite changed our variable:
                assert self.getter() == self._value_right_after_setting
        finally:
            # Always put the old value back, even when the check above fails;
            # otherwise the temporary value would stay installed after the
            # `with` block has ended.
            self.setter(self.old_value)
            self.active = False
class TempSysPathAdder(object):
    '''
    Context manager that puts a path on `sys.path` for the duration of a suite.

    On entry, the path is appended to `sys.path` unless it is already there.
    On exit, only what this manager appended is removed again.

    Example:

        with TempSysPathAdder('path/to/fubar/package'):
            import fubar
            fubar.do_stuff()

    '''
    def __init__(self, addition):
        # Stored as a one-element list holding the string form of the path.
        self.addition = [str(addition)]

    def __enter__(self):
        # Remember exactly which entries we add, so `__exit__` leaves alone
        # anything that was on `sys.path` before we got here.
        newly_added = []
        for candidate in self.addition:
            if candidate not in sys.path:
                newly_added.append(candidate)
        self.entries_not_in_sys_path = newly_added
        sys.path.extend(newly_added)
        return self

    def __exit__(self, *args, **kwargs):
        for added in self.entries_not_in_sys_path:
            # Nobody but us is allowed to have removed it in the meantime.
            assert added in sys.path
            sys.path.remove(added)
at the shape/dtype/etc. of every step of your model, but are tired of manually writing prints? 15 | 16 | Are you bothered by errors like `RuntimeError: Expected object of scalar type Double but got scalar type Float`, and want to quickly figure out the problem? 17 | 18 | TorchSnooper is a [PySnooper](https://github.com/cool-RR/PySnooper) extension that helps you debug these errors. 19 | 20 | To use TorchSnooper, use it just as you would use PySnooper. Remember to replace the `pysnooper.snoop` with `torchsnooper.snoop` in your code. 21 | 22 | To install: 23 | 24 | ``` 25 | pip install torchsnooper 26 | ``` 27 | 28 | TorchSnooper also supports [snoop](https://github.com/alexmojaki/snoop). To use TorchSnooper with snoop, simply execute: 29 | ```python 30 | torchsnooper.register_snoop() 31 | ``` 32 | or 33 | ```python 34 | torchsnooper.register_snoop(verbose=True) 35 | ``` 36 | at the beginning, and use snoop normally. 37 | 38 | # Example 1: Monitoring device and dtype 39 | 40 | We're writing a simple function: 41 | 42 | ```python 43 | def myfunc(mask, x): 44 | y = torch.zeros(6) 45 | y.masked_scatter_(mask, x) 46 | return y 47 | ``` 48 | 49 | and use it like below: 50 | 51 | ```python 52 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda') 53 | source = torch.tensor([1.0, 2.0, 3.0], device='cuda') 54 | y = myfunc(mask, source) 55 | ``` 56 | 57 | The above code seems to be correct, but unfortunately, we are getting the following error: 58 | 59 | ``` 60 | RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'mask' 61 | ``` 62 | 63 | What is the problem? Let's snoop it! 
Decorate our function with `torchsnooper.snoop()`: 64 | 65 | ```python 66 | import torch 67 | import torchsnooper 68 | 69 | @torchsnooper.snoop() 70 | def myfunc(mask, x): 71 | y = torch.zeros(6) 72 | y.masked_scatter_(mask, x) 73 | return y 74 | 75 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda') 76 | source = torch.tensor([1.0, 2.0, 3.0], device='cuda') 77 | y = myfunc(mask, source) 78 | ``` 79 | 80 | Run our script, and we will see: 81 | 82 | ``` 83 | Starting var:.. mask = tensor<(6,), int64, cuda:0> 84 | Starting var:.. x = tensor<(3,), float32, cuda:0> 85 | 21:41:42.941668 call 5 def myfunc(mask, x): 86 | 21:41:42.941834 line 6 y = torch.zeros(6) 87 | New var:....... y = tensor<(6,), float32, cpu> 88 | 21:41:42.943443 line 7 y.masked_scatter_(mask, x) 89 | 21:41:42.944404 exception 7 y.masked_scatter_(mask, x) 90 | ``` 91 | 92 | Now pay attention to the devices of tensors, we notice 93 | ``` 94 | New var:....... y = tensor<(6,), float32, cpu> 95 | ``` 96 | 97 | Now, it's clear that, the problem is because `y` is a tensor on CPU, that is, 98 | we forget to specify the device on `y = torch.zeros(6)`. Changing it to 99 | `y = torch.zeros(6, device='cuda')`, this problem is solved. 100 | 101 | But when running the script again we are getting another error: 102 | 103 | ``` 104 | RuntimeError: Expected object of scalar type Byte but got scalar type Long for argument #2 'mask' 105 | ``` 106 | 107 | Look at the trace above again, pay attention to the dtype of variables, we notice 108 | 109 | ``` 110 | Starting var:.. mask = tensor<(6,), int64, cuda:0> 111 | ``` 112 | 113 | OK, the problem is that, we didn't make the `mask` in the input a byte tensor. Changing the line into 114 | ``` 115 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda', dtype=torch.uint8) 116 | ``` 117 | Problem solved. 
118 | 119 | # Example 1.5: Using Snoop instead of PySnooper 120 | 121 | We could also choose to use [snoop](https://github.com/alexmojaki/snoop) instead of [PySnooper](https://github.com/cool-RR/PySnooper). 122 | 123 | Remember to install `snoop` manually since it is not a dependency of TorchSnooper: 124 | 125 | ``` 126 | pip install snoop 127 | ``` 128 | 129 | The code in example 1 using snoop looks like: 130 | 131 | ```python 132 | import torch 133 | import torchsnooper 134 | import snoop 135 | 136 | torchsnooper.register_snoop() 137 | 138 | @snoop 139 | def myfunc(mask, x): 140 | y = torch.zeros(6) 141 | y.masked_scatter_(mask, x) 142 | return y 143 | 144 | mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda') 145 | source = torch.tensor([1.0, 2.0, 3.0], device='cuda') 146 | y = myfunc(mask, source) 147 | ``` 148 | 149 | and the screenshot looks like: 150 | 151 | ![snoop](snoop.png) 152 | 153 | # Example 2: Monitoring shape 154 | 155 | We are building a linear model 156 | 157 | ```python 158 | class Model(torch.nn.Module): 159 | 160 | def __init__(self): 161 | super().__init__() 162 | self.layer = torch.nn.Linear(2, 1) 163 | 164 | def forward(self, x): 165 | return self.layer(x) 166 | ``` 167 | 168 | and we want to fit `y = x1 + 2 * x2 + 3`, so we create a dataset: 169 | 170 | ```python 171 | x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]) 172 | y = torch.tensor([3.0, 5.0, 4.0, 6.0]) 173 | ``` 174 | 175 | We train our model on this dataset using SGD optimizer: 176 | 177 | ```python 178 | model = Model() 179 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 180 | for _ in range(10): 181 | optimizer.zero_grad() 182 | pred = model(x) 183 | squared_diff = (y - pred) ** 2 184 | loss = squared_diff.mean() 185 | print(loss.item()) 186 | loss.backward() 187 | optimizer.step() 188 | ``` 189 | 190 | But unfortunately, the loss does not go down to a low enough number. 191 | 192 | What's wrong? Let's snoop it! 
Put the training loop inside snoop: 193 | 194 | ```python 195 | with torchsnooper.snoop(): 196 | for _ in range(100): 197 | optimizer.zero_grad() 198 | pred = model(x) 199 | squared_diff = (y - pred) ** 2 200 | loss = squared_diff.mean() 201 | print(loss.item()) 202 | loss.backward() 203 | optimizer.step() 204 | ``` 205 | 206 | Part of the trace looks like: 207 | 208 | ``` 209 | New var:....... x = tensor<(4, 2), float32, cpu> 210 | New var:....... y = tensor<(4,), float32, cpu> 211 | New var:....... model = Model( (layer): Linear(in_features=2, out_features=1, bias=True)) 212 | New var:....... optimizer = SGD (Parameter Group 0 dampening: 0 lr: 0....omentum: 0 nesterov: False weight_decay: 0) 213 | 22:27:01.024233 line 21 for _ in range(100): 214 | New var:....... _ = 0 215 | 22:27:01.024439 line 22 optimizer.zero_grad() 216 | 22:27:01.024574 line 23 pred = model(x) 217 | New var:....... pred = tensor<(4, 1), float32, cpu, grad> 218 | 22:27:01.026442 line 24 squared_diff = (y - pred) ** 2 219 | New var:....... squared_diff = tensor<(4, 4), float32, cpu, grad> 220 | 22:27:01.027369 line 25 loss = squared_diff.mean() 221 | New var:....... loss = tensor<(), float32, cpu, grad> 222 | 22:27:01.027616 line 26 print(loss.item()) 223 | 22:27:01.027793 line 27 loss.backward() 224 | 22:27:01.050189 line 28 optimizer.step() 225 | ``` 226 | 227 | We notice that `y` has shape `(4,)`, but `pred` has shape `(4, 1)`. As a result, `squared_diff` has shape `(4, 4)` due to broadcasting! 228 | 229 | This is not the expected behavior, so let's fix it with `pred = model(x).squeeze()`. Now everything looks good: 230 | 231 | ``` 232 | New var:....... x = tensor<(4, 2), float32, cpu> 233 | New var:....... y = tensor<(4,), float32, cpu> 234 | New var:....... model = Model( (layer): Linear(in_features=2, out_features=1, bias=True)) 235 | New var:....... 
optimizer = SGD (Parameter Group 0 dampening: 0 lr: 0....omentum: 0 nesterov: False weight_decay: 0) 236 | 22:28:19.778089 line 21 for _ in range(100): 237 | New var:....... _ = 0 238 | 22:28:19.778293 line 22 optimizer.zero_grad() 239 | 22:28:19.778436 line 23 pred = model(x).squeeze() 240 | New var:....... pred = tensor<(4,), float32, cpu, grad> 241 | 22:28:19.780250 line 24 squared_diff = (y - pred) ** 2 242 | New var:....... squared_diff = tensor<(4,), float32, cpu, grad> 243 | 22:28:19.781099 line 25 loss = squared_diff.mean() 244 | New var:....... loss = tensor<(), float32, cpu, grad> 245 | 22:28:19.781361 line 26 print(loss.item()) 246 | 22:28:19.781537 line 27 loss.backward() 247 | 22:28:19.798983 line 28 optimizer.step() 248 | ``` 249 | 250 | And the final model converges to the desired values. 251 | 252 | # Example 3: Monitoring nan and inf 253 | 254 | Let's say we have a model that outputs the likelihood of something. For this example, we will just use a mock: 255 | 256 | ```python 257 | class MockModel(torch.nn.Module): 258 | 259 | def __init__(self): 260 | super(MockModel, self).__init__() 261 | self.unused = torch.nn.Linear(6, 4) 262 | 263 | def forward(self, x): 264 | return torch.tensor([0.0, 0.25, 0.9, 0.75]) + self.unused(x) * 0.0 265 | 266 | model = MockModel() 267 | ``` 268 | 269 | During training, we want to minimize the negative log likelihood, so we have the following code: 270 | 271 | ```python 272 | for epoch in range(100): 273 | batch_input = torch.randn(6, 6) 274 | likelihood = model(batch_input) 275 | log_likelihood = likelihood.log() 276 | target = -log_likelihood.mean() 277 | print(target.item()) 278 | 279 | optimizer.zero_grad() 280 | target.backward() 281 | optimizer.step() 282 | ``` 283 | 284 | Unfortunately, we first get `inf` then `nan` for our target during training. What's wrong? 
Let's snoop it: 285 | 286 | ```python 287 | with torchsnooper.snoop(): 288 | for epoch in range(100): 289 | batch_input = torch.randn(6, 6) 290 | likelihood = model(batch_input) 291 | log_likelihood = likelihood.log() 292 | target = -log_likelihood.mean() 293 | print(target.item()) 294 | 295 | optimizer.zero_grad() 296 | target.backward() 297 | optimizer.step() 298 | ``` 299 | 300 | We will see that part of the snoop output looks like: 301 | 302 | ``` 303 | 19:30:20.928316 line 18 for epoch in range(100): 304 | New var:....... epoch = 0 305 | 19:30:20.928575 line 19 batch_input = torch.randn(6, 6) 306 | New var:....... batch_input = tensor<(6, 6), float32, cpu> 307 | 19:30:20.929671 line 20 likelihood = model(batch_input) 308 | New var:....... likelihood = tensor<(6, 4), float32, cpu, grad> 309 | 19:30:20.930284 line 21 log_likelihood = likelihood.log() 310 | New var:....... log_likelihood = tensor<(6, 4), float32, cpu, grad, has_inf> 311 | 19:30:20.930672 line 22 target = -log_likelihood.mean() 312 | New var:....... target = tensor<(), float32, cpu, grad, has_inf> 313 | 19:30:20.931136 line 23 print(target.item()) 314 | 19:30:20.931508 line 25 optimizer.zero_grad() 315 | 19:30:20.931871 line 26 target.backward() 316 | inf 317 | 19:30:20.960028 line 27 optimizer.step() 318 | 19:30:20.960673 line 18 for epoch in range(100): 319 | Modified var:.. epoch = 1 320 | 19:30:20.961043 line 19 batch_input = torch.randn(6, 6) 321 | 19:30:20.961423 line 20 likelihood = model(batch_input) 322 | Modified var:.. likelihood = tensor<(6, 4), float32, cpu, grad, has_nan> 323 | 19:30:20.961910 line 21 log_likelihood = likelihood.log() 324 | Modified var:.. log_likelihood = tensor<(6, 4), float32, cpu, grad, has_nan> 325 | 19:30:20.962302 line 22 target = -log_likelihood.mean() 326 | Modified var:.. 
target = tensor<(), float32, cpu, grad, has_nan> 327 | 19:30:20.962715 line 23 print(target.item()) 328 | 19:30:20.963089 line 25 optimizer.zero_grad() 329 | 19:30:20.963464 line 26 target.backward() 330 | 19:30:20.964051 line 27 optimizer.step() 331 | ``` 332 | 333 | Reading the output, we find that at the first epoch (`epoch = 0`), the `log_likelihood` has a `has_inf` flag. 334 | The `has_inf` flag means that your tensor contains `inf` among its values. The same flag appears for `target`. 335 | And at the second epoch, starting from `likelihood`, tensors all have a `has_nan` flag. 336 | 337 | From our experience with deep learning, we would guess this is because the first epoch has `inf`, which causes 338 | the gradient to be `nan`, and when parameters are updated, these `nan` values propagate to the parameters and cause all 339 | future steps to have `nan` results. 340 | 341 | Taking a deeper look, we figure out that the `likelihood` contains 0 in it, which leads to `log(0) = -inf`. Changing 342 | the line 343 | ```python 344 | log_likelihood = likelihood.log() 345 | ``` 346 | into 347 | ```python 348 | log_likelihood = likelihood.clamp(min=1e-8).log() 349 | ``` 350 | solves the problem. 351 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | # Copyright 2019 Ram Rachum and collaborators. 4 | # This program is distributed under the MIT license. 5 | import os 6 | import re 7 | import abc 8 | import inspect 9 | 10 | from pysnooper.utils import DEFAULT_REPR_RE 11 | 12 | try: 13 | from itertools import zip_longest 14 | except ImportError: 15 | from itertools import izip_longest as zip_longest 16 | 17 | from . 
import mini_toolbox 18 | 19 | import pysnooper.pycompat 20 | 21 | 22 | def get_function_arguments(function, exclude=()): 23 | try: 24 | getfullargspec = inspect.getfullargspec 25 | except AttributeError: 26 | result = inspect.getargspec(function).args 27 | else: 28 | result = getfullargspec(function).args 29 | for exclude_item in exclude: 30 | result.remove(exclude_item) 31 | return result 32 | 33 | 34 | class _BaseEntry(pysnooper.pycompat.ABC): 35 | def __init__(self, prefix=''): 36 | self.prefix = prefix 37 | 38 | @abc.abstractmethod 39 | def check(self, s): 40 | pass 41 | 42 | def __repr__(self): 43 | init_arguments = get_function_arguments(self.__init__, 44 | exclude=('self',)) 45 | attributes = { 46 | key: repr(getattr(self, key)) for key in init_arguments 47 | if getattr(self, key) is not None 48 | } 49 | return '%s(%s)' % ( 50 | type(self).__name__, 51 | ', '.join('{key}={value}'.format(**locals()) for key, value 52 | in attributes.items()) 53 | ) 54 | 55 | 56 | 57 | class _BaseValueEntry(_BaseEntry): 58 | def __init__(self, prefix=''): 59 | _BaseEntry.__init__(self, prefix=prefix) 60 | self.line_pattern = re.compile( 61 | r"""^%s(?P(?: {4})*)(?P[^:]*):""" 62 | r"""\.{2,7} (?P.*)$""" % (re.escape(self.prefix),) 63 | ) 64 | 65 | @abc.abstractmethod 66 | def _check_preamble(self, preamble): 67 | pass 68 | 69 | @abc.abstractmethod 70 | def _check_content(self, preamble): 71 | pass 72 | 73 | def check(self, s): 74 | match = self.line_pattern.match(s) 75 | if not match: 76 | return False 77 | _, preamble, content = match.groups() 78 | return (self._check_preamble(preamble) and 79 | self._check_content(content)) 80 | 81 | 82 | class ElapsedTimeEntry(_BaseEntry): 83 | def __init__(self, elapsed_time_value=None, tolerance=0.2, prefix=''): 84 | _BaseEntry.__init__(self, prefix=prefix) 85 | self.line_pattern = re.compile( 86 | r"""^%s(?P(?: {4})*)Elapsed time: (?P