├── .gitignore ├── MIT_LICENSE.txt ├── README.md ├── example.py ├── pyproject.toml ├── setup.cfg └── tensorcheck ├── __init__.py ├── tensorcheck.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **/*.swp 3 | **/.DS_Store 4 | 5 | # Distribution / packaging 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | share/python-wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | MANIFEST 24 | -------------------------------------------------------------------------------- /MIT_LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorCheck 2 | 3 | `pip install tensorcheck` 4 | 5 | Run-time validation of tensors for machine-learning systems. 6 | 7 | This is a naive way of validating tensor inputs to functions by using a 8 | function decorator as a hook for introspection. It is not designed to be a 9 | composable type system, but simply to allow a concise description of which 10 | tensors are expected when designing new architectures or writing 11 | visualization code. 12 | 13 | ### Supports: 14 | 1. Dtype validation for `np.ndarray` and `torch.Tensor` 15 | 2. Shape validation, including generic shape variables 16 | 3. Value-range validation (minimum and maximum bounds) 17 | 18 | ## Example Usage 19 | 20 | ```python 21 | import numpy as np 22 | import torch 23 | from tensorcheck import tensorcheck 24 | 25 | @tensorcheck({ 26 | "img": { 27 | "dtype": np.uint8, 28 | "shape": [1, 3, "H", "W"], 29 | "range": [0, 255] 30 | }, 31 | "mask": { 32 | "dtype": torch.float32, 33 | "shape": [1, 1, "H", "W"], 34 | "range": [0, 1] 35 | }, 36 | "return": { 37 | "dtype": np.float32, 38 | "shape": [1, 3, "H", "W"], 39 | "range": [0, 255] 40 | }, 41 | }) 42 | def apply_mask(img, mask): 43 | # ...do compute 44 | return img * mask.numpy() 45 | 46 | 47 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 48 | y = torch.rand(1, 1, 10, 7) 49 | apply_mask(x, y) 50 | # > tensorcheck.ShapeException: /mask/ dim 3 of torch.Size([1, 1, 10, 7]) is not W=8 51 | 52 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 53 | y = 2 * torch.rand(1, 1, 10, 8) 54 | apply_mask(x, y) 55 | # > tensorcheck.UpperBoundException: /mask/ max value 1.9982...
is greater than 1 56 | 57 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.float) 58 | y = torch.rand(1, 1, 10, 8) 59 | apply_mask(x, y) 60 | # > tensorcheck.DataTypeException: /img/ dtype float64 is not 61 | 62 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 63 | y = torch.rand(1, 1, 10, 8).int() 64 | apply_mask(x, y) 65 | # > tensorcheck.DataTypeException: /mask/ dtype torch.int32 is not torch.float32 66 | 67 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 68 | y = torch.rand(1, 1, 10, 8) 69 | apply_mask(x, y) 70 | # > Success 71 | ``` 72 | 73 | ## Future Work 74 | - [x] Add support for return types 75 | - [ ] Validate all annotations (e.g. shape must be list of string/int/float) 76 | - [ ] Work on multiple return values 77 | - [ ] Guarantee whole range is used, e.g. [0, 1] should fail on [0, 255] 78 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tensorcheck import tensorcheck 4 | 5 | @tensorcheck({ 6 | "img": { 7 | "dtype": np.uint8, 8 | "shape": [1, 3, "H", "W"], 9 | "range": [0, 255] 10 | }, 11 | "mask": { 12 | "dtype": torch.float32, 13 | "shape": [1, 1, "H", "W"], 14 | "range": [0, 1] 15 | }, 16 | "return": { 17 | "dtype": np.float32, 18 | "shape": [1, 3, "H", "W"], 19 | "range": [0, 255] 20 | }, 21 | }) 22 | def apply_mask(img, mask): 23 | # ...do compute 24 | return img * mask.numpy() 25 | 26 | 27 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 28 | y = torch.rand(1, 1, 10, 7) 29 | apply_mask(x, y) 30 | # > tensorcheck.ShapeException: /mask/ dim 3 of torch.Size([1, 1, 10, 7]) is not W=8 31 | 32 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 33 | y = 2 * torch.rand(1, 1, 10, 8) 34 | apply_mask(x, y) 35 | # > tensorcheck.UpperBoundException: /mask/ max value 1.9982... 
is greater than 1 36 | 37 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.float) 38 | y = torch.rand(1, 1, 10, 8) 39 | apply_mask(x, y) 40 | # > tensorcheck.DataTypeException: /img/ dtype float64 is not 41 | 42 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 43 | y = torch.rand(1, 1, 10, 8).int() 44 | apply_mask(x, y) 45 | # > tensorcheck.DataTypeException: /mask/ dtype torch.int32 is not torch.float32 46 | 47 | x = np.random.uniform(0, 255, size=[1, 3, 10, 8]).astype(np.uint8) 48 | y = torch.rand(1, 1, 10, 8) 49 | apply_mask(x, y) 50 | # > Success 51 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = tensorcheck 3 | version = 0.0.2 4 | url = https://github.com/tobyshooters/tensorcheck 5 | author = Cristóbal Sciutto 6 | author_email = cristobal.sciutto@gmail.com 7 | description = Run-time validation of tensors for machine-learning systems. 
8 | license_files = MIT_LICENSE.txt 9 | classifiers = 10 | Programming Language :: Python :: 3 11 | License :: OSI Approved :: MIT License 12 | Operating System :: OS Independent 13 | 14 | [options] 15 | packages = tensorcheck 16 | install_requires = 17 | torch 18 | numpy 19 | -------------------------------------------------------------------------------- /tensorcheck/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensorcheck import * 2 | -------------------------------------------------------------------------------- /tensorcheck/tensorcheck.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from inspect import getcallargs 5 | from functools import wraps 6 | 7 | class AnnotationException(Exception): pass 8 | class TypeException(Exception): pass 9 | class DataTypeException(Exception): pass 10 | class ShapeException(Exception): pass 11 | class LowerBoundException(Exception): pass 12 | class UpperBoundException(Exception): pass 13 | 14 | 15 | class TypeAsserter: 16 | 17 | def __init__(self, annotations): 18 | self.argument_annotations = {k:v for k,v in annotations.items() if k != "return"} 19 | self.return_annotation = annotations.get("return", None) 20 | # Keeps map from generic shape variables to first seen sizes 21 | self.generic_shapes = {} 22 | 23 | 24 | def assert_arguments(self, named_args): 25 | for name, annotation in self.argument_annotations.items(): 26 | if name not in named_args: 27 | raise AnnotationException(f'{name} is not a parameter of the function') 28 | self.assert_type(name, named_args[name], annotation) 29 | 30 | 31 | def assert_return(self, output): 32 | if self.return_annotation is not None: 33 | self.assert_type("return", output, self.return_annotation) 34 | 35 | 36 | def assert_type(self, name, arg, annotation): 37 | if type(arg) not in [np.ndarray, torch.Tensor]: 38 | raise TypeException(f'Annotation 
must be np.ndarray or torch.Tensor, not {type(arg)}') 39 | 40 | for check, wish in annotation.items(): 41 | 42 | if check == "dtype": 43 | # Reference: stackoverflow.com/questions/12569452/how-to-identify-numpy-types-in-python 44 | if isinstance(arg, np.ndarray) and not isinstance(arg.flat[0], wish): 45 | raise DataTypeException(f'/{name}/ dtype {arg.flat[0].dtype} is not {wish}') 46 | 47 | elif isinstance(arg, torch.Tensor) and not arg.dtype == wish: 48 | raise DataTypeException(f'/{name}/ dtype {arg.dtype} is not {wish}') 49 | 50 | elif check == "shape": 51 | if not isinstance(wish, list): 52 | raise AnnotationException(f'{wish} is not a valid shape annotation.') 53 | 54 | # Build up shape from cache, while updating unseen variables 55 | concrete_wish = [] 56 | for wish_dim, wish_size in enumerate(wish): 57 | if isinstance(wish_size, str): 58 | if not wish_size in self.generic_shapes: 59 | self.generic_shapes[wish_size] = arg.shape[wish_dim] 60 | concrete_wish.append(self.generic_shapes[wish_size]) 61 | elif isinstance(wish_size, (int, float)): 62 | concrete_wish.append(wish_size) 63 | else: 64 | raise AnnotationException(f'{wish_size} in shape annotation is not an int, float, or string.') 65 | 66 | if len(arg.shape) != len(concrete_wish): 67 | err = f'/{name}/ {arg.shape} is not of same length as desired shape {concrete_wish}' 68 | raise ShapeException(err) 69 | 70 | for wish_dim, wish_size in enumerate(concrete_wish): 71 | if arg.shape[wish_dim] != wish_size: 72 | errs = f'/{name}/ dim {wish_dim} of {arg.shape} is not {wish[wish_dim]}={concrete_wish[wish_dim]}' 73 | raise ShapeException(errs) 74 | 75 | elif check == "range": 76 | rmin, rmax = wish 77 | 78 | if not arg.min() >= float(rmin): 79 | errs = f'/{name}/ min value {arg.min()} is less than {rmin}' 80 | raise LowerBoundException(errs) 81 | 82 | if not arg.max() <= float(rmax): 83 | errs = f'/{name}/ max value {arg.max()} is greater than {rmax}' 84 | raise UpperBoundException(errs) 85 | 86 | 87 | 88 | def 
tensorcheck(annotations): 89 | """ 90 | The type checking decorator for Numpy arrays. 91 | Reference: https://realpython.com/primer-on-python-decorators 92 | """ 93 | def decorator(func): 94 | @wraps(func) 95 | def func_with_asserts(*args, **kwargs): 96 | asserter = TypeAsserter(annotations) 97 | named_args = getcallargs(func, *args, **kwargs) 98 | asserter.assert_arguments(named_args) 99 | output = func(**named_args) 100 | asserter.assert_return(output) 101 | return output 102 | return func_with_asserts 103 | return decorator 104 | -------------------------------------------------------------------------------- /tensorcheck/test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from tensorcheck import * 3 | 4 | import numpy as np 5 | 6 | class TestTensorChecker(unittest.TestCase): 7 | 8 | def test_supports_args_and_kwargs(self): 9 | try: 10 | @tensorcheck({ 11 | "a": {}, 12 | "b": {}, 13 | }) 14 | def inference(a, b): 15 | return 16 | 17 | x = np.random.randn(1, 1, 3, 2) 18 | y = np.random.randn(1, 1, 2, 2) 19 | inference(x, b=y) 20 | 21 | except AnnotationException: 22 | self.fail() 23 | 24 | 25 | def test_annotated_key_must_be_parameter(self): 26 | with self.assertRaises(AnnotationException): 27 | @tensorcheck({ 28 | "b": {}, 29 | }) 30 | def inference(a): 31 | return 32 | 33 | x = np.random.randn(1, 1, 3, 2) 34 | y = np.random.randn(1, 1, 2, 2) 35 | inference(x) 36 | 37 | 38 | def test_dtype_success(self): 39 | try: 40 | @tensorcheck({ 41 | "a": { "dtype": np.float, "dtype": np.float64 }, 42 | "b": { "dtype": np.uint8 }, 43 | "c": { "dtype": torch.float }, 44 | }) 45 | def inference(a, b, c): 46 | return 47 | 48 | x = np.random.randn(1, 1, 3, 2) 49 | img = np.random.uniform(0, 255, size=[1, 3, 5, 5]).astype(np.uint8) 50 | t = torch.randn(1, 1, 3, 3) 51 | 52 | inference(x, img, t) 53 | 54 | except DataTypeException: 55 | self.fail() 56 | 57 | 58 | def test_dtype_fail_numpy(self): 59 | with 
self.assertRaises(DataTypeException): 60 | @tensorcheck({ 61 | "a": { "dtype": np.int8 } 62 | }) 63 | def inference(a): 64 | return 65 | 66 | x = np.random.randn(1, 1, 3, 2) 67 | inference(x) 68 | 69 | def test_dtype_fail_torch(self): 70 | with self.assertRaises(DataTypeException): 71 | @tensorcheck({ 72 | "a": { "dtype": torch.int } 73 | }) 74 | def inference(a): 75 | return 76 | 77 | x = torch.randn(1, 1, 3, 2) 78 | inference(x) 79 | 80 | 81 | def test_shape_success(self): 82 | try: 83 | @tensorcheck({ 84 | "a": { "shape": [1, 1, 3, 2] }, 85 | "b": { "shape": [1, 1, 3, 2] }, 86 | }) 87 | def inference(a, b): 88 | return 89 | 90 | x = np.random.randn(1, 1, 3, 2) 91 | t = torch.randn(1, 1, 3, 2) 92 | inference(x, t) 93 | 94 | except ShapeException: 95 | self.fail() 96 | 97 | 98 | def test_shape_fail_numpy(self): 99 | with self.assertRaises(ShapeException): 100 | @tensorcheck({ 101 | "a": { "shape": [1, 1, 2, 2] }, 102 | }) 103 | def inference(a): 104 | return 105 | 106 | x = np.random.randn(1, 1, 3, 2) 107 | inference(x) 108 | 109 | def test_shape_fail_torch(self): 110 | with self.assertRaises(ShapeException): 111 | @tensorcheck({ 112 | "a": { "shape": [1, 1, 2, 2] }, 113 | }) 114 | def inference(a): 115 | return 116 | 117 | x = torch.randn(1, 1, 3, 2) 118 | inference(x) 119 | 120 | 121 | def test_generic_shape_success(self): 122 | try: 123 | @tensorcheck({ 124 | "a": { "shape": [1, 1, 3, "W"] }, 125 | "b": { "shape": [1, 1, 2, "W"] }, 126 | }) 127 | def inference(a, b): 128 | return 129 | 130 | x = np.random.randn(1, 1, 3, 2) 131 | y = np.random.randn(1, 1, 2, 2) 132 | inference(x, y) 133 | 134 | except: 135 | self.fail() 136 | 137 | 138 | def test_shape_length_not_equal_fails(self): 139 | with self.assertRaises(ShapeException): 140 | @tensorcheck({ 141 | "a": { "shape": [1, 1, 3] }, 142 | }) 143 | def inference(a): 144 | return 145 | 146 | x = np.random.randn(1, 1, 3, 1) 147 | inference(x) 148 | 149 | 150 | def 
test_generic_shape_success_across_torch_and_numpy(self): 151 | try: 152 | @tensorcheck({ 153 | "a": { "dtype": np.float, "shape": [1, 1, 2, "W"] }, 154 | "b": { "shape": torch.float, "shape": [1, 1, 3, "W"] }, 155 | }) 156 | def inference(a, b): 157 | return 158 | 159 | x = np.random.randn(1, 1, 2, 2) 160 | y = torch.randn(1, 1, 3, 2) 161 | inference(x, y) 162 | 163 | except: 164 | self.fail() 165 | 166 | 167 | def test_generic_shape_fail(self): 168 | with self.assertRaises(ShapeException): 169 | @tensorcheck({ 170 | "a": { "shape": [1, 1, "H", 2] }, 171 | "b": { "shape": [1, 1, "H", 2] }, 172 | }) 173 | def inference(a, b): 174 | return 175 | 176 | x = np.random.randn(1, 1, 3, 2) 177 | y = np.random.randn(1, 1, 2, 2) 178 | inference(x, y) 179 | 180 | 181 | def test_range_success(self): 182 | try: 183 | @tensorcheck({ 184 | "a": { "dtype": np.float, "range": [0, 1] }, 185 | "b": { "dtype": torch.float, "range": [-10, 10] }, 186 | }) 187 | def inference(a, b): 188 | return 189 | 190 | x = np.random.uniform(size=[1, 1, 2, 2]) 191 | y = torch.randn(1, 1, 3, 2) 192 | inference(x, y) 193 | 194 | except: 195 | self.fail() 196 | 197 | 198 | def test_upperbound_fail_numpy(self): 199 | with self.assertRaises(UpperBoundException): 200 | @tensorcheck({ 201 | "a": { "dtype": np.float, "range": [0, 1] }, 202 | }) 203 | def inference(a): 204 | return 205 | 206 | x = 2 * np.ones([1, 1, 2, 2]) 207 | inference(x) 208 | 209 | 210 | def test_lowerbound_fail_numpy(self): 211 | with self.assertRaises(LowerBoundException): 212 | @tensorcheck({ 213 | "a": { "dtype": np.float, "range": [0, 1] }, 214 | }) 215 | def inference(a): 216 | return 217 | 218 | x = -1 * np.ones([1, 1, 2, 2]) 219 | inference(x) 220 | 221 | 222 | def test_upperbound_fail_torch(self): 223 | with self.assertRaises(UpperBoundException): 224 | @tensorcheck({ 225 | "a": { "dtype": torch.float, "range": [0, 1] }, 226 | }) 227 | def inference(a): 228 | return 229 | 230 | x = 2 * torch.ones(1, 1, 2, 2) 231 | inference(x) 
232 | 233 | 234 | def test_lowerbound_fail_torch(self): 235 | with self.assertRaises(LowerBoundException): 236 | @tensorcheck({ 237 | "a": { "dtype": torch.float, "range": [0, 1] }, 238 | }) 239 | def inference(a): 240 | return 241 | 242 | x = -1 * torch.ones(1, 1, 2, 2) 243 | inference(x) 244 | 245 | 246 | def test_rgb_image_and_alpha_mask(self): 247 | try: 248 | @tensorcheck({ 249 | "img": { "dtype": np.uint8, "shape": [1, 3, "H", "W"], "range": [0, 255] }, 250 | "mask": { "dtype": np.float, "shape": [1, 1, "H", "W"], "range": [0, 1] }, 251 | }) 252 | def inference(img, mask): 253 | return 254 | 255 | img = np.random.uniform(0, 255, size=[1, 3, 5, 5]).astype(np.uint8) 256 | mask = np.random.uniform(size=[1, 1, 5, 5]) 257 | inference(img, mask) 258 | 259 | except: 260 | self.fail() 261 | 262 | 263 | def test_return_with_outer_product_shape(self): 264 | try: 265 | @tensorcheck({ 266 | "a": { "shape": [2] }, 267 | "b": { "shape": [3] }, 268 | "return": { "shape": [2, 3] }, 269 | }) 270 | def inference(a, b): 271 | return np.outer(a, b) 272 | 273 | inference(np.array([-3, 2]), np.array([2, 5, -1])) 274 | 275 | except: 276 | self.fail() 277 | 278 | 279 | def test_return_with_type_cast(self): 280 | try: 281 | @tensorcheck({ 282 | "a": { "dtype": np.float }, 283 | "return": { "dtype": np.uint8 }, 284 | }) 285 | def cast_to_int(a): 286 | return a.astype(np.uint8) 287 | 288 | img = np.random.uniform(0, 255, size=[1, 3, 3]) 289 | cast_to_int(img) 290 | 291 | except: 292 | self.fail() 293 | 294 | 295 | def test_return_fail_with_type(self): 296 | with self.assertRaises(DataTypeException): 297 | @tensorcheck({ 298 | "a": { "dtype": np.float }, 299 | "return": { "dtype": np.uint8 }, 300 | }) 301 | def identity(a): 302 | return a 303 | 304 | img = np.random.uniform(0, 255, size=[1, 3, 3]) 305 | identity(img) 306 | 307 | 308 | 309 | if __name__ == '__main__': 310 | unittest.main(verbosity=2) 311 | --------------------------------------------------------------------------------