├── tensorhue ├── py.typed ├── _print_opts.py ├── eastereggs.py ├── __init__.py ├── colors.py ├── converters.py └── viz.py ├── requirements.txt ├── setup.py ├── .github ├── images.png ├── tensorhue.png ├── tensor_types.png ├── confusion_matrix.png ├── PULL_REQUEST_TEMPLATE.md ├── ISSUE_TEMPLATE │ ├── help-wanted.md │ ├── feature_request.md │ └── bug-report.md └── workflows │ ├── pylint.yml │ ├── tests.yml │ ├── update-coverage-badge.yml │ └── pypi-release.yml ├── tests ├── test_resources │ ├── test_image_rgb.jpg │ ├── test_image_rgba.png │ ├── test_image_indexed.gif │ ├── test_image_greyscale.gif │ ├── test_image_greyscale.jpg │ └── test_image_greyscale.png ├── test_eastereggs.py ├── test__print_opts.py ├── test_converter_pillow.py ├── test_converter_tensorflow.py ├── test_converter_jax.py ├── test_converter_torch.py ├── test_colors.py └── test_viz.py ├── .flake8 ├── requirements-dev.txt ├── reports └── junit │ └── junit.xml ├── .pre-commit-config.yaml ├── pyproject.toml ├── coverage-badge.svg ├── LICENSE ├── setup.cfg ├── .gitignore ├── README.md └── .pylintrc /tensorhue/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | rich 3 | matplotlib 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /.github/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/.github/images.png -------------------------------------------------------------------------------- /.github/tensorhue.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/.github/tensorhue.png -------------------------------------------------------------------------------- /.github/tensor_types.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/.github/tensor_types.png -------------------------------------------------------------------------------- /.github/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/.github/confusion_matrix.png -------------------------------------------------------------------------------- /tests/test_resources/test_image_rgb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/tests/test_resources/test_image_rgb.jpg -------------------------------------------------------------------------------- /tests/test_resources/test_image_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/tests/test_resources/test_image_rgba.png -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503, F403, F401 3 | max-line-length = 120 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 -------------------------------------------------------------------------------- /tests/test_resources/test_image_indexed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/tests/test_resources/test_image_indexed.gif 
-------------------------------------------------------------------------------- /tests/test_resources/test_image_greyscale.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/tests/test_resources/test_image_greyscale.gif -------------------------------------------------------------------------------- /tests/test_resources/test_image_greyscale.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/tests/test_resources/test_image_greyscale.jpg -------------------------------------------------------------------------------- /tests/test_resources/test_image_greyscale.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epistoteles/TensorHue/HEAD/tests/test_resources/test_image_greyscale.png -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | pre-commit 3 | pylint 4 | torch 5 | tensorflow 6 | jax 7 | pillow 8 | tox 9 | pytest 10 | pytest-cov 11 | mypy 12 | flake8 13 | genbadge[all] 14 | -------------------------------------------------------------------------------- /tests/test_eastereggs.py: -------------------------------------------------------------------------------- 1 | from tensorhue.eastereggs import pride 2 | 3 | 4 | def test_pride_output(capsys): 5 | pride() 6 | captured = capsys.readouterr() 7 | out = captured.out.rstrip("\n") 8 | assert len(out.split("\n")) == 3 9 | assert out.count("▀") == 30 10 | -------------------------------------------------------------------------------- /reports/junit/junit.xml: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | exclude: \.svg$ 8 | - id: trailing-whitespace 9 | exclude: (\.md|\.svg$) 10 | - repo: https://github.com/psf/black 11 | rev: 22.10.0 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v1.2.3 16 | hooks: 17 | - id: flake8 18 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. 4 | 5 | This PR fixes #000000. 6 | 7 | ## Checklist: 8 | 9 | - [ ] PR is related to an existing issue - if none exists, open a new issue yourself 10 | - [ ] PR includes the issue number that it fixes, e.g. "This PR fixes #000000."
11 | - [ ] PR passes all automated checks (linting, tests) with a green checkmark 12 | -------------------------------------------------------------------------------- /tests/test__print_opts.py: -------------------------------------------------------------------------------- 1 | from tensorhue._print_opts import set_printoptions, PRINT_OPTS 2 | from tensorhue.colors import ColorScheme 3 | 4 | 5 | def test_default_print_opts(): 6 | assert PRINT_OPTS.edgeitems == 3 7 | assert isinstance(PRINT_OPTS.colorscheme, ColorScheme) 8 | 9 | 10 | def test_set_printopts(): 11 | set_printoptions(edgeitems=42) 12 | assert PRINT_OPTS.edgeitems == 42 13 | cs = ColorScheme(true_color=(0, 0, 0)) 14 | set_printoptions(colorscheme=cs) 15 | assert PRINT_OPTS.colorscheme == cs 16 | assert PRINT_OPTS.edgeitems == 42 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help-wanted.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Help wanted 3 | about: You need help using TensorHue 4 | title: "[HELP]: " 5 | labels: help-wanted 6 | assignees: epistoteles 7 | 8 | --- 9 | 10 | **What you want to achieve** 11 | A clear and concise description of what you are trying to achieve. 12 | 13 | **What you have tried so far** 14 | Steps you have tried so far - ideally including some code. 15 | 16 | ``` 17 | short code snippets here 18 | ``` 19 | 20 | If you want to add more than a few lines of code, please link a [GitHub Gist](https://gist.github.com/). 21 | 22 | **Additional context** 23 | Add any other context or screenshots about the request here. 
24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ["py39", "py310", "py311"] 4 | extend-exclude = "(\\.md|\\.svg$)" 5 | 6 | [build-system] 7 | requires = ["setuptools>=42.0", "wheel"] 8 | build-backend = "setuptools.build_meta" 9 | 10 | [tool.pytest.ini_options] 11 | addopts = "--cov=tensorhue" 12 | testpaths = [ 13 | "tests", 14 | ] 15 | 16 | [tool.mypy] 17 | mypy_path = "tensorhue" 18 | check_untyped_defs = true 19 | disallow_any_generics = true 20 | ignore_missing_imports = true 21 | no_implicit_optional = true 22 | show_error_codes = true 23 | strict_equality = true 24 | warn_redundant_casts = true 25 | warn_return_any = true 26 | warn_unreachable = true 27 | warn_unused_configs = true 28 | no_implicit_reexport = true 29 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] 11 | steps: 12 | - name: Checkout repository 13 | uses: actions/checkout@v4 14 | 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install pylint 24 | 25 | - name: Lint code with pylint 26 | run: | 27 | pylint $(git ls-files '*.py') 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: 
Suggest an idea for this project 4 | title: "[FEATURE]: " 5 | labels: enhancement 6 | assignees: epistoteles 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I want to use x but it does not work with y because z happens. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]: " 5 | labels: bug 6 | assignees: epistoteles 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 16 | ``` 17 | any code here 18 | ``` 19 | 20 | **Current behavior** 21 | The buggy behavior you are currently experiencing. 22 | 23 | **Expected behavior** 24 | A description of what you expected to happen. 25 | 26 | **Screenshots** 27 | If applicable, add screenshots to help explain your problem. 28 | 29 | **Additional context** 30 | - Python version: e.g. 3.10 31 | - TensorHue version: e.g. 0.0.11 32 | - OS: e.g. Ubuntu 22.04 33 | - Shell: e.g. bash, zsh, ...
34 | -------------------------------------------------------------------------------- /tests/test_converter_pillow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import pytest 4 | from tensorhue.converters import _tensor_to_numpy_pillow 5 | 6 | 7 | @pytest.mark.parametrize("thumbnail", [True, False]) 8 | def test_image_modes(thumbnail): 9 | "Tests .jpg, .png and .gif images in different color modes (RGB, L, RGBA, etc.)" 10 | image_dir = "./tests/test_resources/" 11 | for file in os.listdir(image_dir): 12 | img = Image.open(image_dir + file) 13 | array = _tensor_to_numpy_pillow(img, thumbnail=thumbnail, max_size=(100, 138)) 14 | assert array.shape == ((100, 100, 3) if thumbnail else (600, 600, 3)) 15 | if img.getbands()[-1] == "A": 16 | assert array[0, 0, 0] == 0 # top left pixel is black (PIL default: transparent -> black) 17 | else: 18 | assert array[0, 0, 0] == 255 # top left pixel is white 19 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | paths: 6 | - 'tensorhue/**' 7 | - 'tests/**' 8 | - '*.toml' 9 | - 'setup.*' 10 | - 'requirements*' 11 | pull_request: 12 | 13 | jobs: 14 | test: 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest, windows-latest] 19 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] 20 | 21 | steps: 22 | - name: Checkout repository 23 | uses: actions/checkout@v4 24 | 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install tox tox-gh-actions 34 | pip install -e .
35 | 36 | - name: Test with tox 37 | run: tox 38 | -------------------------------------------------------------------------------- /tensorhue/_print_opts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from tensorhue.colors import ColorScheme 3 | 4 | 5 | @dataclass 6 | class __PrinterOptions: 7 | colorscheme: ColorScheme = ColorScheme() 8 | edgeitems: int = 3 9 | 10 | 11 | PRINT_OPTS = __PrinterOptions() 12 | 13 | 14 | # We could use **kwargs, but this will give better docs 15 | def set_printoptions( 16 | edgeitems: int = None, 17 | colorscheme: ColorScheme = None, 18 | ): 19 | """Set options for printing. Items shamelessly taken from NumPy 20 | 21 | Args: 22 | colorscheme: The color scheme to use. 23 | edgeitems: Number of array items in summary at beginning and end of 24 | each dimension (default = 3). 25 | """ 26 | if edgeitems is not None: 27 | assert isinstance(edgeitems, int) 28 | PRINT_OPTS.edgeitems = edgeitems 29 | if colorscheme is not None: 30 | assert isinstance(colorscheme, ColorScheme) 31 | PRINT_OPTS.colorscheme = colorscheme 32 | -------------------------------------------------------------------------------- /coverage-badge.svg: -------------------------------------------------------------------------------- 1 | coverage: 87.32%coverage87.32% -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Korbinian Koch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tensorhue/eastereggs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib.colors import LinearSegmentedColormap 3 | from rich.color_triplet import ColorTriplet 4 | from tensorhue.colors import ColorScheme 5 | from tensorhue.viz import viz, get_terminal_size 6 | 7 | 8 | def pride(width: int = None): 9 | """ 10 | Prints a pride flag in the terminal 11 | 12 | Args: 13 | width (int, optional): The width of the pride flag. If none is specified, 14 | the full width of the terminal is used. 
15 | """ 16 | if width is None: 17 | width = get_terminal_size(default_width=10).columns 18 | pride_colors = [ 19 | ColorTriplet(228, 3, 3), 20 | ColorTriplet(255, 140, 0), 21 | ColorTriplet(255, 237, 0), 22 | ColorTriplet(0, 128, 38), 23 | ColorTriplet(0, 76, 255), 24 | ColorTriplet(115, 41, 130), 25 | ] 26 | pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride") 27 | pride_cs = ColorScheme(colormap=pride_cm) 28 | arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), width, axis=1) 29 | viz(arr, colorscheme=pride_cs, legend=False) 30 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = tensorhue 3 | version = attr: tensorhue.__version__ 4 | author = epistoteles 5 | author_email = 6 | description = TensorHue lets you visualize tensors in your console. 7 | long_description_content_type = text/markdown 8 | long_description = 9 | TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier. 
10 | 11 | Learn more at: https://github.com/epistoteles/tensorhue 12 | keywords = 13 | classifiers = 14 | Development Status :: 3 - Alpha 15 | Intended Audience :: Developers 16 | Programming Language :: Python :: 3 17 | Operating System :: Unix 18 | Operating System :: MacOS :: MacOS X 19 | Operating System :: Microsoft :: Windows 20 | 21 | [options] 22 | packages = find: 23 | python_requires = >=3.7 24 | zip_safe = no 25 | install_requires = 26 | numpy 27 | rich 28 | matplotlib 29 | 30 | [options.packages.find] 31 | include = tensorhue* 32 | 33 | [project] 34 | name = "tensorhue" 35 | dynamic = ["version"] 36 | 37 | [tool.setuptools.dynamic] 38 | version = {attr = "tensorhue.__version__"} 39 | 40 | [options.extras_require] 41 | testing = 42 | pre-commit 43 | pylint 44 | torch 45 | tox 46 | pytest 47 | pytest-cov 48 | mypy 49 | flake8 50 | 51 | [options.package_data] 52 | tensorhue = py.typed 53 | -------------------------------------------------------------------------------- /tests/test_converter_tensorflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import pytest 4 | from tensorhue.converters import _tensor_to_numpy_tensorflow 5 | 6 | 7 | class NonConvertibleTensor: 8 | def numpy(self): 9 | raise RuntimeError("This tensor cannot be converted to numpy") 10 | 11 | 12 | def test_tensor_dtypes(): 13 | dtypes = { 14 | tf.float32: "float32", 15 | tf.double: "float64", 16 | tf.int32: "int32", 17 | tf.int64: "int64", 18 | tf.bool: "bool", 19 | tf.complex128: "complex128", 20 | } 21 | tf_tensor = tf.constant([0.0, 1.0, 2.0, float("nan"), float("inf")]) 22 | for dtype_tf, dtype_np in dtypes.items(): 23 | tensor_casted = tf.cast(tf_tensor, dtype_tf) 24 | converted = _tensor_to_numpy_tensorflow(tensor_casted) 25 | assert np.array_equal( 26 | converted.dtype, dtype_np 27 | ), f"dtype mismatch in torch to numpy conversion: expected {dtype_np}, got {converted.dtype}" 28 | 29 | 30 | 
def test_runtime_error_for_non_convertible_tensor(): 31 | non_convertible = NonConvertibleTensor() 32 | with pytest.raises(NotImplementedError) as exc_info: 33 | _tensor_to_numpy_tensorflow(non_convertible) 34 | assert "This tensor cannot be converted to numpy" in str(exc_info.value) 35 | 36 | 37 | def test_unexpected_exception_for_other_errors(): 38 | class UnexpectedErrorTensor: 39 | def numpy(self): 40 | raise ValueError("Unexpected error") 41 | 42 | with pytest.raises(RuntimeError) as exc_info: 43 | _tensor_to_numpy_tensorflow(UnexpectedErrorTensor()) 44 | assert "Unexpected error" in str(exc_info.value) 45 | -------------------------------------------------------------------------------- /tests/test_converter_jax.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import jax.numpy as jnp 3 | from jax import core 4 | import numpy as np 5 | from tensorhue.converters import _tensor_to_numpy_jax 6 | 7 | 8 | class NonConvertibleTensor: 9 | pass 10 | 11 | 12 | def test_jax_device_array(): 13 | data = [[1, 2], [3, 4]] 14 | device_array = jnp.array(data) 15 | assert np.array_equal(_tensor_to_numpy_jax(device_array), np.array(data)) 16 | 17 | 18 | def test_tensor_dtypes(): 19 | dtypes = { 20 | jnp.float32: "float32", 21 | jnp.bfloat16: "bfloat16", 22 | jnp.int32: "int32", 23 | jnp.uint8: "uint8", 24 | bool: "bool", 25 | jnp.complex64: "complex64", 26 | } 27 | jnp_array = jnp.array([0.0, 1.0, 2.0, jnp.nan, jnp.inf]) 28 | for dtype_jnp, dtype_np in dtypes.items(): 29 | jnp_casted = jnp_array.astype(dtype_jnp) 30 | converted = _tensor_to_numpy_jax(jnp_casted) 31 | assert np.array_equal( 32 | converted.dtype, dtype_np 33 | ), f"dtype mismatch in jax.numpy to numpy conversion: expected {dtype_np}, got {converted.dtype}" 34 | 35 | 36 | def test_jax_incompatible_arrays(): 37 | shape = (2, 2) 38 | dtype = jnp.float32 39 | 40 | shaped_array = core.ShapedArray(shape, dtype) 41 | with pytest.raises(NotImplementedError) as 
exc_info: 42 | _tensor_to_numpy_jax(shaped_array) 43 | assert "cannot be visualized" in str(exc_info.value) 44 | 45 | 46 | def test_runtime_error_for_non_convertible_tensor(): 47 | non_convertible = NonConvertibleTensor() 48 | with pytest.raises(NotImplementedError) as exc_info: 49 | _tensor_to_numpy_jax(non_convertible) 50 | assert "Got non-visualizable dtype 'object'." in str(exc_info.value) 51 | -------------------------------------------------------------------------------- /tensorhue/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | from tensorhue.colors import COLORS, ColorScheme 4 | from tensorhue._print_opts import PRINT_OPTS, set_printoptions 5 | from tensorhue.eastereggs import pride 6 | from tensorhue.viz import viz, _viz, _viz_image 7 | 8 | 9 | __version__ = "0.1.0" # single source of version truth 10 | 11 | __all__ = ["set_printoptions", "viz", "pride"] 12 | 13 | 14 | # show deprecation warning for t.viz() usage 15 | # delete everything below this line after version 0.2.0 16 | 17 | 18 | def _viz_is_deprecated(self): 19 | raise DeprecationWarning("The tensor.viz() function has been deprecated. 
Please use tensorhue.viz(tensor) instead.") 20 | 21 | 22 | if "torch" in sys.modules: 23 | torch = sys.modules["torch"] 24 | setattr(torch.Tensor, "viz", _viz_is_deprecated) 25 | if "jax" in sys.modules: 26 | jax = sys.modules["jax"] 27 | setattr(jax.Array, "viz", _viz_is_deprecated) 28 | jaxlib = sys.modules["jaxlib"] 29 | if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X 30 | setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz_is_deprecated) 31 | if "ArrayImpl" in { 32 | x[0] for x in inspect.getmembers(jaxlib.xla_extension) 33 | }: # jax >= 0.4.X (not sure about the exact version this changed) 34 | setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz_is_deprecated) 35 | if "tensorflow" in sys.modules: 36 | tensorflow = sys.modules["tensorflow"] 37 | setattr(tensorflow.Tensor, "viz", _viz_is_deprecated) 38 | composite_tensor = sys.modules["tensorflow.python.framework.composite_tensor"] 39 | setattr(composite_tensor.CompositeTensor, "viz", _viz_is_deprecated) 40 | if "PIL" in sys.modules: 41 | PIL = sys.modules["PIL"] 42 | setattr(PIL.Image.Image, "viz", _viz_is_deprecated) 43 | -------------------------------------------------------------------------------- /tests/test_converter_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | from tensorhue.converters import _tensor_to_numpy_torch 5 | 6 | 7 | class NonConvertibleTensor(torch.Tensor): 8 | def numpy(self): 9 | raise RuntimeError("This tensor cannot be converted to numpy") 10 | 11 | 12 | @pytest.mark.filterwarnings("ignore::UserWarning:torch") 13 | def test_masked_tensor(): 14 | ones = torch.ones(5, 5) 15 | mask = torch.eye(5).bool() 16 | masked_torch = torch.masked.MaskedTensor(ones, mask) 17 | masked_numpy = np.ma.masked_array(ones.numpy(), ~mask.numpy()) 18 | converted = _tensor_to_numpy_torch(masked_torch) 19 | assert np.array_equal( 20 | converted, 
masked_numpy 21 | ), "Converting masked tensor to masked array failed. Please check if the torch.masked API changed and raise an issue." 22 | 23 | 24 | def test_tensor_dtypes(): 25 | dtypes = { 26 | torch.FloatTensor: "float32", 27 | torch.DoubleTensor: "float64", 28 | torch.IntTensor: "int32", 29 | torch.LongTensor: "int64", 30 | torch.bool: "bool", 31 | torch.complex128: "complex128", 32 | } 33 | torch_tensor = torch.Tensor([0.0, 1.0, 2.0, torch.nan, torch.inf]) 34 | for dtype_torch, dtype_np in dtypes.items(): 35 | torch_casted = torch_tensor.type(dtype_torch) 36 | converted = _tensor_to_numpy_torch(torch_casted) 37 | assert np.array_equal( 38 | converted.dtype, dtype_np 39 | ), f"dtype mismatch in torch to numpy conversion: expected {dtype_np}, got {converted.dtype}" 40 | 41 | 42 | def test_runtime_error_for_non_convertible_tensor(): 43 | non_convertible = NonConvertibleTensor() 44 | with pytest.raises(NotImplementedError) as exc_info: 45 | _tensor_to_numpy_torch(non_convertible) 46 | assert "This tensor cannot be converted to numpy" in str(exc_info.value) 47 | 48 | 49 | def test_unexpected_exception_for_other_errors(): 50 | class UnexpectedErrorTensor: 51 | def numpy(self): 52 | raise ValueError("Unexpected error") 53 | 54 | with pytest.raises(RuntimeError) as exc_info: 55 | _tensor_to_numpy_torch(UnexpectedErrorTensor()) 56 | assert "Unexpected error" in str(exc_info.value) 57 | -------------------------------------------------------------------------------- /.github/workflows/update-coverage-badge.yml: -------------------------------------------------------------------------------- 1 | name: Update code coverage badge 2 | 3 | on: 4 | push: 5 | branches: 6 | main 7 | paths: 8 | - 'tensorhue/**' 9 | - 'tests/**' 10 | pull_request: 11 | branches: 12 | main 13 | paths: 14 | - 'tensorhue/**' 15 | - 'tests/**' 16 | 17 | jobs: 18 | update_coverage: 19 | name: "Update coverage badge" 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: "Checkout repository" 23 |
uses: actions/checkout@v4 24 | 25 | - name: "Set up Python" 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.x' 29 | 30 | - name: "Install dependencies" 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements-dev.txt 34 | pip install -e . 35 | 36 | - name: "Run coverage and generate badge" 37 | run: | 38 | pytest . 39 | coverage report 40 | coverage xml 41 | coverage html 42 | genbadge coverage --input-file coverage.xml 43 | 44 | - name: "Check if coverage-badge.svg Changed" 45 | id: check_svg_change 46 | run: | 47 | if git diff --quiet --exit-code -- coverage-badge.svg; then 48 | echo "No changes in coverage-badge.svg" 49 | echo "svg_changed=false" >> $GITHUB_ENV 50 | else 51 | echo "Changes detected in coverage-badge.svg" 52 | echo "svg_changed=true" >> $GITHUB_ENV 53 | fi 54 | 55 | - name: "Commit and push changes" 56 | if: env.svg_changed == 'true' 57 | run: | 58 | git config user.email "${{ github.run_id }}+github-actions[bot]@users.noreply.github.com" 59 | git config user.name "github-actions[bot]" 60 | 61 | if [[ "${{ github.event_name }}" == 'push' ]]; then 62 | target_branch=$(echo "${{ github.ref }}" | awk -F'/' '{print $3}') 63 | else 64 | target_branch="${{ github.event.pull_request.head.ref }}" 65 | fi 66 | 67 | git fetch origin "${target_branch}:${target_branch}" 68 | git checkout "${target_branch}" || git checkout -b "${target_branch}" 69 | git push --set-upstream origin "${target_branch}" 70 | 71 | git add coverage-badge.svg 72 | git commit -m "gh-actions[bot]: update code coverage badge" 73 | git push 74 | env: 75 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 76 | -------------------------------------------------------------------------------- /.github/workflows/pypi-release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see:
https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Release python package on PyPi 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | pypi-publish: 20 | 21 | runs-on: ubuntu-latest 22 | environment: release 23 | permissions: 24 | id-token: write 25 | 26 | steps: 27 | - name: Checkout repository 28 | uses: actions/checkout@v4 29 | 30 | - name: Set up Python 3.x 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: '3.x' 34 | 35 | - name: Install dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install build 39 | 40 | - name: Check version consistency 41 | run: | 42 | import re 43 | import sys 44 | import os 45 | 46 | # Read the version from tensorhue/__init__.py 47 | with open('tensorhue/__init__.py') as f: 48 | content = f.read() 49 | match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', content) 50 | if match: 51 | version = match.group(1) 52 | else: 53 | print("Could not find __version__ in tensorhue/__init__.py") 54 | sys.exit(1) 55 | 56 | # Extract the version from the GitHub release tag 57 | github_ref = os.environ.get('GITHUB_REF') 58 | if github_ref.startswith('refs/tags/'): 59 | release_version = github_ref.split('/')[-1] 60 | if release_version.startswith('v'): 61 | release_version = release_version[1:] # Remove 'v' prefix 62 | 63 | if release_version != version: 64 | print(f"Version mismatch: GitHub release {release_version} != Package version {version}") 65 | sys.exit(1) 66 | print(f"Version {version} matches GitHub release tag (without 'v' prefix).") 67 | else: 68 | print("Not a release tag, skipping version check.") 69 | shell: python 70 | 71 | - 
name: Build package 72 | run: python -m build 73 | 74 | - name: Publish package on PyPi 75 | uses: pypa/gh-action-pypi-publish@release/v1 76 | -------------------------------------------------------------------------------- /tests/test_colors.py: -------------------------------------------------------------------------------- 1 | from rich.color_triplet import ColorTriplet 2 | import numpy as np 3 | from matplotlib.colors import Colormap, CenteredNorm 4 | from matplotlib import colormaps 5 | from tensorhue.colors import ColorScheme, COLORS 6 | 7 | 8 | def test_COLORS(): 9 | for key, value in COLORS.items(): 10 | assert isinstance(key, str) 11 | assert isinstance(value, ColorTriplet) 12 | 13 | 14 | def test_ColorScheme(): 15 | cs = ColorScheme() 16 | 17 | assert isinstance(cs.colormap, Colormap) 18 | assert cs.masked_color == COLORS["masked"] 19 | assert cs.true_color == COLORS["true"] 20 | assert np.array_equal(cs.colormap.get_bad()[:3], np.array(cs.masked_color.normalized)) 21 | 22 | values1 = np.array([-0.5, 0.0, 0.5, 0.75]) 23 | result1 = cs(values1) 24 | values2 = values1 * 2 25 | result2 = cs(values2) 26 | assert np.array_equal(result1, result2) 27 | 28 | cs.colormap = colormaps["cividis"] 29 | assert np.array_equal(cs.colormap.get_bad()[:3], np.array(cs.masked_color.normalized)) 30 | 31 | cs.masked_color = COLORS["black"] 32 | assert np.array_equal(cs.colormap.get_bad()[:3], np.array([0, 0, 0])) 33 | 34 | cs.inf_color = COLORS["black"] 35 | assert np.array_equal(cs.colormap.get_over()[:3], np.array([0, 0, 0])) 36 | 37 | cs.ninf_color = COLORS["black"] 38 | assert np.array_equal(cs.colormap.get_under()[:3], np.array([0, 0, 0])) 39 | 40 | bool_array = np.array([True, False]) 41 | assert np.array_equal(cs(bool_array), np.array([cs.true_color, cs.false_color])) 42 | 43 | 44 | def test_vmin_vmax(): 45 | cs = ColorScheme(colormap=colormaps["magma"]) 46 | 47 | values1 = np.array([-0.5, 0.0, 0.5, 0.75]) 48 | 49 | result1 = cs(values1) 50 | assert np.array_equal( 51 | 
result1, 52 | np.array([[0, 0, 3, 255], [140, 41, 128, 255], [253, 159, 108, 255], [251, 252, 191, 255]], dtype=np.uint8), 53 | ) 54 | 55 | result2 = cs(values1, vmin=-0.5) 56 | assert np.array_equal(result1, result2) 57 | 58 | result3 = cs(values1, vmax=0.75) 59 | assert np.array_equal(result1, result3) 60 | 61 | result4 = cs(values1, vmin=-0.5, vmax=0.75) 62 | assert np.array_equal(result1, result4) 63 | 64 | result5 = cs(values1, vmin=-1) 65 | assert np.array_equal( 66 | result5, 67 | np.array([[94, 23, 127, 255], [211, 66, 109, 255], [254, 187, 128, 255], [251, 252, 191, 255]], dtype=np.uint8), 68 | ) 69 | 70 | result6 = cs(values1, vmax=0.4) 71 | assert np.array_equal( 72 | result6, 73 | np.array([[0, 0, 3, 255], [205, 63, 112, 255], [255, 255, 255, 255], [255, 255, 255, 255]], dtype=np.uint8), 74 | ) 75 | 76 | cs = ColorScheme(colormap=colormaps["bwr"], normalize=CenteredNorm()) 77 | 78 | result7 = cs(values1) 79 | assert np.array_equal( 80 | result7, 81 | np.array([[84, 84, 255, 255], [255, 254, 254, 255], [255, 84, 84, 255], [255, 0, 0, 255]], dtype=np.uint8), 82 | ) 83 | 84 | result8 = cs(values1, vmin=-1) 85 | assert np.array_equal( 86 | result8, 87 | np.array( 88 | [[128, 128, 255, 255], [255, 254, 254, 255], [255, 126, 126, 255], [255, 62, 62, 255]], dtype=np.uint8 89 | ), 90 | ) 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | venv* 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | TensorHue 3 |
4 |
5 |
6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | > [!IMPORTANT] 14 | > t.viz() has been deprecated. Please use tensorhue.viz(t) instead. 15 | 16 | > [!NOTE] 17 | > TensorHue is currently in alpha. We appreciate any feedback! 18 | 19 | # TensorHue - tensors, visualized 20 | 21 | TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier. 22 | 23 | You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow, and a large set of related libraries, including Numpy, Pillow, torchvision, and more. 24 | 25 | TensorHue automagically detects which kind of tensor you are visualizing and adjusts accordingly: 26 | 27 |
28 | tensor types 29 |
30 | 31 | ## Getting started 32 | 33 | Install TensorHue with pip: 34 | 35 | ```bash 36 | pip install tensorhue 37 | ``` 38 | 39 | Using TensorHue is easy: simply import TensorHue together with the library of your choice: 40 | 41 | ```python 42 | import torch 43 | import tensorhue 44 | ``` 45 | 46 | Or, alternatively: 47 | 48 | ```python 49 | from tensorhue import viz 50 | ``` 51 | 52 | That's it! You can now visualize any tensor by passing it to `tensorhue.viz()` in your Python console: 53 | 54 | ```python 55 | t = torch.rand(20,20) 56 | tensorhue.viz(t) ✅ 57 | ``` 58 | 59 | ## Images 60 | 61 | Pillow images can be visualized in RGB and other color modes: 62 | 63 | ```python 64 | from torchvision.datasets import CIFAR10 65 | dataset = CIFAR10('.', download=True) 66 | img = dataset[0][0] 67 | tensorhue.viz(img) ✅ 68 | ``` 69 | 70 |
71 | image visualization 72 |
73 | 74 | By default, images get downscaled to the size of your terminal, but you can make them even smaller if you want: 75 | 76 | ```python 77 | tensorhue.viz(img, max_size=(40,40)) ✅ 78 | ``` 79 | 80 | ## Custom colors 81 | 82 | You can pass along your own ColorScheme when visualizing a specific tensor: 83 | 84 | ```python 85 | from tensorhue import ColorScheme 86 | from matplotlib import colormaps 87 | 88 | cs = ColorScheme(colormap=colormaps['inferno'], 89 | true_color=(255,255,255), 90 | false_color=(0,0,0)) 91 | tensorhue.viz(t, colorscheme=cs) ✅ 92 | ``` 93 | 94 | Alternatively, you can override the default ColorScheme: 95 | 96 | 97 | ```python 98 | tensorhue.set_printoptions(colorscheme=cs) 99 | ``` 100 | 101 | ## Advanced colormaps and normalization 102 | 103 | By default, TensorHue normalizes numerical values to the range [0, 1] and then applies the matplotlib colormap. If you want to use diverging colormaps such as `coolwarm` or `bwr` and the value 0 to be mapped to the middle of the colormap, you need to specify the normalizer, e.g. `matplotlib.colors.CenteredNorm`: 104 | 105 | ```python 106 | from matplotlib.colors import CenteredNorm 107 | cs = ColorScheme(colormap=colormaps['bwr'], 108 | normalize=CenteredNorm(vcenter=0)) 109 | tensorhue.viz(t, colorscheme=cs) ✅ 110 | ``` 111 | 112 | You can also specify the normalization range manually, for example when you want to visualize a confusion matrix where colors should be mapped to the range [0, 1], but the actual values of the tensor are in the range [0.12, 0.73]: 113 | 114 | ```python 115 | tensorhue.viz(conf_matrix, vmin=0, vmax=1, scale=3) 116 | ``` 117 | 118 |
119 | confusion matrix 120 |
121 | 122 | The `scale` parameter scales up the 'pixels' of the tensor so that small tensors are easier to view. 123 | -------------------------------------------------------------------------------- /tests/test_viz.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import torch 4 | import jax.numpy as jnp 5 | import tensorflow as tf 6 | import numpy as np 7 | from PIL import Image 8 | from tensorhue.viz import viz 9 | from tensorhue.converters import _tensor_to_numpy_torch, _tensor_to_numpy_jax, _tensor_to_numpy_tensorflow 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "tensor", 14 | [ 15 | np.ones(10), 16 | _tensor_to_numpy_torch(torch.ones(10)), 17 | _tensor_to_numpy_jax(jnp.ones(10)), 18 | _tensor_to_numpy_tensorflow(tf.ones(10)), 19 | ], 20 | ) 21 | def test_1d_tensor(tensor, capsys): 22 | viz(tensor) 23 | captured = capsys.readouterr() 24 | out = captured.out.rstrip("\n") 25 | assert len(out.split("\n")) == 2 26 | assert out.count("▀") == 10 27 | assert out.split("\n")[-1] == f"shape = {tensor.shape}" 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "tensor", 32 | [ 33 | np.ones((10, 10)), 34 | _tensor_to_numpy_torch(torch.ones(10, 10)), 35 | _tensor_to_numpy_jax(jnp.ones((10, 10))), 36 | _tensor_to_numpy_tensorflow(tf.ones((10, 10))), 37 | ], 38 | ) 39 | def test_2d_tensor(tensor, capsys): 40 | viz(tensor) 41 | captured = capsys.readouterr() 42 | out = captured.out.rstrip("\n") 43 | assert len(out.split("\n")) == 6 44 | assert out.count("▀") == 100 / 2 45 | assert out.split("\n")[-1] == f"shape = {tensor.shape}" 46 | 47 | 48 | @pytest.mark.parametrize( 49 | "tensor", 50 | [ 51 | np.ones(200), 52 | _tensor_to_numpy_torch(torch.ones(200)), 53 | _tensor_to_numpy_jax(jnp.ones(200)), 54 | _tensor_to_numpy_tensorflow(tf.ones(200)), 55 | ], 56 | ) 57 | def test_1d_tensor_too_wide(tensor, capsys): 58 | viz(tensor) 59 | captured = capsys.readouterr() 60 | out = captured.out.rstrip("\n") 61 | assert 
out.count(" ··· ") == 1 62 | assert out.count("▀") == 95 63 | assert out.split("\n")[-1] == f"shape = {tensor.shape}" 64 | 65 | 66 | @pytest.mark.parametrize( 67 | "tensor", 68 | [ 69 | np.ones((10, 200)), 70 | _tensor_to_numpy_torch(torch.ones(10, 200)), 71 | _tensor_to_numpy_jax(jnp.ones((10, 200))), 72 | _tensor_to_numpy_tensorflow(tf.ones((10, 200))), 73 | ], 74 | ) 75 | def test_2d_tensor_too_wide(tensor, capsys): 76 | viz(tensor) 77 | captured = capsys.readouterr() 78 | out = captured.out.rstrip("\n") 79 | assert out.count(" ··· ") == 5 80 | assert out.count("▀") == 950 / 2 81 | assert out.split("\n")[-1] == f"shape = {tensor.shape}" 82 | 83 | 84 | @pytest.mark.parametrize( 85 | "tensor", 86 | [ 87 | np.ones(10), 88 | _tensor_to_numpy_torch(torch.ones(10)), 89 | _tensor_to_numpy_jax(jnp.ones(10)), 90 | _tensor_to_numpy_tensorflow(tf.ones(10)), 91 | ], 92 | ) 93 | def test_no_legend(tensor, capsys): 94 | viz(tensor, legend=False) 95 | captured = capsys.readouterr() 96 | out = captured.out.rstrip("\n") 97 | assert len(out.split("\n")) == 1 98 | assert out.count("▀") == 10 99 | 100 | 101 | @pytest.mark.parametrize("scale", [1, 2, 4, 8]) 102 | def test_scale(scale, capsys): 103 | tensor = np.ones((4, 4)) 104 | viz(tensor, scale=scale) 105 | captured = capsys.readouterr() 106 | out = captured.out.rstrip("\n") 107 | assert out.count("▀") == (8 * (scale**2)) 108 | 109 | 110 | @pytest.mark.parametrize("image_filename", os.listdir("./tests/test_resources/")) 111 | @pytest.mark.parametrize("thumbnail", [True, False]) 112 | def test_viz_image(image_filename, thumbnail, capsys): 113 | filepath = "./tests/test_resources/" + image_filename 114 | image = Image.open(filepath) 115 | viz(image, thumbnail=thumbnail) 116 | captured = capsys.readouterr() 117 | out = captured.out.rstrip("\n") 118 | assert out.count(" ··· ") == (0 if thumbnail else 300) 119 | assert len(out.split("\n")) == (100 if thumbnail else 600) 120 | assert out.count("▀") == (5000 if thumbnail else 28500) 121 
| 122 | 123 | @pytest.mark.parametrize("thumbnail", [True, False]) 124 | def test_viz_image_legend(thumbnail, capsys): 125 | filepath = "./tests/test_resources/test_image_rgba.png" 126 | image = Image.open(filepath) 127 | viz(image, legend=True, thumbnail=thumbnail) 128 | captured = capsys.readouterr() 129 | out = captured.out.rstrip("\n") 130 | assert out.split("\n")[-1] == "size = (600, 600), mode = RGBA" 131 | -------------------------------------------------------------------------------- /tensorhue/colors.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import warnings 3 | 4 | from rich.color_triplet import ColorTriplet 5 | from matplotlib import colormaps 6 | from matplotlib.colors import Colormap, Normalize, CenteredNorm 7 | import numpy as np 8 | 9 | 10 | COLORS = { 11 | "masked": ColorTriplet(127, 127, 127), # medium grey 12 | "default_dark": ColorTriplet(64, 17, 159), # dark purple 13 | "default_medium": ColorTriplet(255, 55, 140), # pink 14 | "default_bright": ColorTriplet(255, 210, 240), # light rose 15 | "true": ColorTriplet(125, 215, 82), # green 16 | "false": ColorTriplet(255, 80, 80), # red 17 | "accessible_true": ColorTriplet(255, 255, 255), # TODO 18 | "accessible_false": ColorTriplet(0, 0, 0), # TODO 19 | "black": ColorTriplet(0, 0, 0), # black 20 | "white": ColorTriplet(255, 255, 255), # white 21 | } 22 | 23 | 24 | class ColorScheme: 25 | def __init__( 26 | self, 27 | colormap: Colormap = colormaps["magma"], 28 | normalize: Normalize = Normalize(), 29 | masked_color: ColorTriplet = COLORS["masked"], 30 | true_color: ColorTriplet = COLORS["true"], 31 | false_color: ColorTriplet = COLORS["false"], 32 | inf_color: ColorTriplet = COLORS["white"], 33 | ninf_color: ColorTriplet = COLORS["black"], 34 | ): 35 | self._colormap = colormap 36 | self.normalize = normalize 37 | self._masked_color = masked_color 38 | self.true_color = true_color 39 | self.false_color = false_color 
40 | self._inf_color = inf_color 41 | self._ninf_color = ninf_color 42 | 43 | self.colormap.set_extremes( 44 | bad=self.masked_color.normalized, under=self.ninf_color.normalized, over=self.inf_color.normalized 45 | ) 46 | 47 | @property 48 | def colormap(self): 49 | return self._colormap 50 | 51 | @colormap.setter 52 | def colormap(self, value): 53 | self._colormap = value 54 | self._colormap.set_extremes( 55 | bad=self._masked_color.normalized, under=self._ninf_color.normalized, over=self._inf_color.normalized 56 | ) 57 | 58 | @property 59 | def masked_color(self): 60 | return self._masked_color 61 | 62 | @masked_color.setter 63 | def masked_color(self, value): 64 | self._masked_color = value 65 | self._colormap.set_bad(value.normalized) 66 | 67 | @property 68 | def inf_color(self): 69 | return self._inf_color 70 | 71 | @inf_color.setter 72 | def inf_color(self, value): 73 | self._inf_color = value 74 | self._colormap.set_over(value.normalized) 75 | 76 | @property 77 | def ninf_color(self): 78 | return self._ninf_color 79 | 80 | @ninf_color.setter 81 | def ninf_color(self, value): 82 | self._ninf_color = value 83 | self._colormap.set_under(value.normalized) 84 | 85 | def __call__(self, data: np.ndarray, **kwargs) -> np.ndarray: 86 | if data.dtype == "bool": 87 | true_values = np.array(self.true_color, dtype=np.uint8) 88 | false_values = np.array(self.false_color, dtype=np.uint8) 89 | return np.where(data[..., np.newaxis], true_values, false_values) 90 | data_noinf = np.where(np.isinf(data), np.nan, data) 91 | if "vmin" not in kwargs: 92 | vmin = np.nanmin(data_noinf) 93 | else: 94 | vmin = float(kwargs["vmin"]) 95 | if "vmax" not in kwargs: 96 | vmax = np.nanmax(data_noinf) 97 | else: 98 | vmax = float(kwargs["vmax"]) 99 | if isinstance(self.normalize, CenteredNorm): 100 | vcenter = self.normalize.vcenter 101 | diff_vmin = vmin - vcenter 102 | diff_vmax = vmax - vcenter 103 | max_abs_diff = max(abs(diff_vmin), abs(diff_vmax)) 104 | vmin = vcenter - max_abs_diff 
105 | vmax = vcenter + max_abs_diff 106 | if "vmin" in kwargs and "vmax" in kwargs: 107 | warnings.warn( 108 | f"You shouldn't specify both 'vmin' and 'vmax' when using CenteredNorm. 'vmin' and 'vmax' must be symmetric around 'vcenter' and are thus inferred from a single value. Using: vmin={vmin}, vcenter={vcenter}, vmax={vmax}." 109 | ) 110 | self.normalize.vmin = vmin 111 | self.normalize.vmax = vmax 112 | return self.colormap(self.normalize(data), bytes=True) 113 | 114 | def __repr__(self): 115 | return ( 116 | f"ColorScheme(\n" 117 | f" colormap={self._colormap},\n" 118 | f" normalize={self.normalize},\n" 119 | f" masked_color={self._masked_color},\n" 120 | f" true_color={self.true_color},\n" 121 | f" false_color={self.false_color},\n" 122 | f" inf_color={self._inf_color},\n" 123 | f" ninf_color={self._ninf_color}\n" 124 | f")" 125 | ) 126 | -------------------------------------------------------------------------------- /tensorhue/converters.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import warnings 3 | import numpy as np 4 | 5 | 6 | def tensor_to_numpy(tensor, **kwargs) -> np.ndarray: 7 | """ 8 | Converts a tensor of unknown type to a numpy array. 9 | 10 | Args: 11 | tensor (Any): The tensor to be converted. 12 | **kwargs: Additional keyword arguments that are passed to the underlying converter functions. 13 | 14 | Returns: 15 | The converted numpy array. 
16 | """ 17 | mro_strings = mro_to_strings(tensor.__class__.__mro__) 18 | 19 | if "numpy.ndarray" in mro_strings: 20 | return tensor 21 | if "torch.Tensor" in mro_strings: 22 | return _tensor_to_numpy_torch(tensor, **kwargs) 23 | if "tensorflow.python.types.core.Tensor" in mro_strings: 24 | return _tensor_to_numpy_tensorflow(tensor, **kwargs) 25 | if "jaxlib.xla_extension.DeviceArray" in mro_strings: 26 | return _tensor_to_numpy_jax(tensor, **kwargs) 27 | if "PIL.Image.Image" in mro_strings: 28 | return _tensor_to_numpy_pillow(tensor, **kwargs) 29 | raise NotImplementedError( 30 | f"Conversion of tensor of type {type(tensor)} is not supported. Please raise an issue of you think this is a bug or should be implemented." 31 | ) 32 | 33 | 34 | def mro_to_strings(mro) -> list[str]: 35 | """ 36 | Converts the __mro__ of a class to a list of module.class_name strings. 37 | 38 | Args: 39 | mro (tuple[type]): The __mro__ to be converted. 40 | 41 | Returns: 42 | The converted list of strings. 43 | """ 44 | return [f"{c.__module__}.{c.__name__}" for c in mro] 45 | 46 | 47 | def _tensor_to_numpy_torch(tensor) -> np.ndarray: 48 | if tensor.__class__.__name__ == "MaskedTensor": # hacky - but we shouldn't import torch here 49 | return np.ma.masked_array(tensor.get_data(), ~tensor.get_mask()) 50 | try: # pylint: disable=duplicate-code 51 | return tensor.numpy() 52 | except RuntimeError as e: 53 | raise NotImplementedError( 54 | f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." 
55 | ) from e 56 | except Exception as e: 57 | raise RuntimeError( 58 | f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" 59 | ) from e 60 | 61 | 62 | def _tensor_to_numpy_tensorflow(tensor) -> np.ndarray: 63 | if tensor.__class__.__name__ == "RaggedTensor": # hacky - but we shouldn't import torch here 64 | warnings.warn( 65 | "Tensorflow RaggedTensors are currently converted to dense tensors by filling with the value 0. Values that are actually 0 and filled-in values will appear indistinguishable. This behavior will change in the future." 66 | ) 67 | return _tensor_to_numpy_tensorflow(tensor.to_tensor()) 68 | if tensor.__class__.__name__ == "SparseTensor": 69 | raise ValueError("Tensorflow SparseTensors are not yet supported by TensorHue.") 70 | try: # pylint: disable=duplicate-code 71 | return tensor.numpy() 72 | except RuntimeError as e: 73 | raise NotImplementedError( 74 | f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." 75 | ) from e 76 | except Exception as e: 77 | raise RuntimeError( 78 | f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" 79 | ) from e 80 | 81 | 82 | def _tensor_to_numpy_jax(tensor) -> np.ndarray: 83 | not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"} 84 | if {c.__name__ for c in tensor.__class__.__mro__}.intersection( 85 | not_implemented 86 | ): # hacky - but we shouldn't import jax here 87 | raise NotImplementedError( 88 | f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong." 
89 | ) 90 | try: 91 | array = np.asarray(tensor) 92 | if array.dtype == "object": 93 | raise RuntimeError("Got non-visualizable dtype 'object'.") 94 | return array 95 | except RuntimeError as e: 96 | raise NotImplementedError( 97 | f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." 98 | ) from e 99 | except Exception as e: 100 | raise RuntimeError( 101 | f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" 102 | ) from e 103 | 104 | 105 | def _tensor_to_numpy_pillow(image, thumbnail, max_size) -> np.ndarray: 106 | try: 107 | image = image.convert("RGB") 108 | except Exception as e: 109 | raise ValueError("Could not convert image from mode '{mode}' to 'RGB'.") from e 110 | 111 | if thumbnail: 112 | image.thumbnail(max_size) 113 | 114 | array = np.array(image) 115 | assert array.dtype == "uint8" 116 | 117 | return array 118 | -------------------------------------------------------------------------------- /tensorhue/viz.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations # backwards type hint compatability for Python 3.7 and 3.8 2 | import os 3 | import sys 4 | import warnings 5 | from rich.console import Console 6 | import numpy as np 7 | from tensorhue.colors import ColorScheme 8 | from tensorhue._print_opts import PRINT_OPTS 9 | from tensorhue.converters import tensor_to_numpy, mro_to_strings 10 | 11 | 12 | def viz(tensor, **kwargs): 13 | try: 14 | mro_strings = mro_to_strings(tensor.__class__.__mro__) 15 | if "PIL.Image.Image" in mro_strings: 16 | _viz_image(tensor, **kwargs) 17 | else: 18 | _viz(tensor, **kwargs) 19 | except Exception as e: 20 | raise NotImplementedError( 21 | f"TensorHue currently does not support type {type(tensor)}. Please raise an issue if you want to visualize them.." 
22 | ) from e 23 | 24 | 25 | def _viz(tensor, colorscheme: ColorScheme = None, legend: bool = True, scale: int = 1, **kwargs): 26 | """ 27 | Prints a tensor using colored Unicode art representation. 28 | 29 | Args: 30 | tensor (Any): The tensor to be visualized. 31 | colorscheme (ColorScheme, optional): The color scheme to use. 32 | Defaults to None, which means the global default color scheme is used. 33 | legend (bool, optional): Whether or not to include legend information (like the shape) 34 | scale (int, optional): Scales the size of the entire tensor up, making the unicode 'pixels' larger. 35 | **kwargs: Additional keyword arguments that are passed to the underlying viz function (vmin or vmax) 36 | """ 37 | if not isinstance(scale, int): 38 | raise ValueError("scale must be an integer.") 39 | 40 | if colorscheme is None: 41 | colorscheme = PRINT_OPTS.colorscheme 42 | 43 | np_array = tensor_to_numpy(tensor) 44 | shape = np_array.shape 45 | ndim = np_array.ndim 46 | 47 | if ndim == 1: 48 | np_array = np_array[np.newaxis, :] 49 | elif ndim > 2: 50 | raise NotImplementedError( 51 | "Visualization of tensors with more than 2 dimensions is under development. Please slice them for now." 52 | ) 53 | 54 | np_array = np.repeat(np.repeat(np_array, scale, axis=1), scale, axis=0) 55 | 56 | result_lines = _viz_2d(np_array, colorscheme, **kwargs) 57 | 58 | if legend: 59 | result_lines.append(f"[italic]shape = {shape}[/]") 60 | 61 | c = Console(log_path=False, record=False) 62 | c.print("\n".join(result_lines)) 63 | 64 | 65 | def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme = None, **kwargs) -> list[str]: 66 | """ 67 | Constructs a list of rich-compatible strings out of a 2D numpy array. 68 | 69 | Args: 70 | array_2d (np.ndarray): The 2-dimensional numpy array (or 3-dimensional if the values are already RGB). 71 | colorscheme (ColorScheme): The color scheme to use. If None, the array must be 3-dimensional (already RGB values). 
72 | **kwargs: Additional keyword arguments that are passed to the underlying viz function (vmin or vmax) 73 | """ 74 | terminal_width = get_terminal_size().columns 75 | shape = array_2d.shape 76 | 77 | if shape[1] > terminal_width: 78 | slice_left = (terminal_width - 5) // 2 79 | slice_right = slice_left + (terminal_width - 5) % 2 80 | if colorscheme is not None: 81 | colors_right = colorscheme(array_2d[:, -slice_right:])[..., :3] 82 | else: 83 | assert ( 84 | array_2d.ndim == 3 and array_2d.shape[-1] == 3 85 | ), "Array shape must be 3-dimensional (*, *, 3) when colorscheme=None." 86 | colors_right = array_2d[:, -slice_right:, :] 87 | else: 88 | slice_left = shape[1] 89 | slice_right = colors_right = False 90 | 91 | if colorscheme is not None: 92 | colors_left = colorscheme(array_2d[:, :slice_left], **kwargs)[..., :3] 93 | else: 94 | assert ( 95 | array_2d.ndim == 3 and array_2d.shape[-1] == 3 96 | ), "Array shape must be 3-dimensional (*, *, 3) when colorscheme=None." 97 | colors_left = array_2d[:, :slice_left, :] 98 | 99 | result_lines = _construct_unicode_string(colors_left, colors_right) 100 | 101 | return result_lines 102 | 103 | 104 | def _construct_unicode_string(colors_left: np.ndarray, colors_right: np.ndarray) -> str: 105 | result_lines = [""] 106 | 107 | for y in range(0, colors_left.shape[0] - 1, 2): 108 | for x in range(colors_left.shape[1]): 109 | result_lines[ 110 | -1 111 | ] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]" 112 | if isinstance(colors_right, np.ndarray): 113 | result_lines[-1] += " ··· " 114 | for x in range(colors_right.shape[1]): 115 | result_lines[ 116 | -1 117 | ] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]" 118 | result_lines.append("") 119 | 120 | if colors_left.shape[0] % 2 == 1: 
121 | for x in range(colors_left.shape[1]): 122 | result_lines[-1] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]" 123 | if isinstance(colors_right, np.ndarray): 124 | result_lines[-1] += " ··· " 125 | for x in range(colors_right.shape[1]): 126 | result_lines[ 127 | -1 128 | ] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]" 129 | else: 130 | result_lines = result_lines[:-1] 131 | 132 | return result_lines 133 | 134 | 135 | def _viz_image(image, legend: bool = False, thumbnail: bool = True, max_size: tuple[int, int] = None): 136 | """ 137 | A special case of _viz that does not use the ColorScheme but instead treats the tensor as RGB or greyscale values directly. 138 | 139 | Args: 140 | image (PIL.Image.Image): The image to visualize 141 | legend (bool, optional): Whether or not to include legend information (like the shape) 142 | thumbnail (bool, optional): Scales down the image size to a thumbnail that fits into the terminal window 143 | max_size (tuple[int, int], optional): The maximum size (width, height) to which the image gets downsized to. Only used if thumbnail=True. 144 | """ 145 | 146 | raise_max_size_warning = max_size and not thumbnail 147 | 148 | size = image.size 149 | mode = image.mode 150 | if max_size is None: 151 | terminal_size = get_terminal_size() 152 | else: 153 | terminal_size = os.terminal_size(max_size) 154 | max_size = (terminal_size.columns, (terminal_size.lines - 1) * 2) 155 | image = tensor_to_numpy(image, thumbnail=thumbnail, max_size=max_size) 156 | 157 | result_lines = _viz_2d(image) 158 | 159 | if legend: 160 | result_lines.append(f"[italic]size = {size}[/], [italic]mode = {mode}[/]") 161 | 162 | c = Console(log_path=False, record=False) 163 | c.print("\n".join(result_lines)) 164 | 165 | if raise_max_size_warning: 166 | warnings.warn( 167 | "You specified a max_size, but set thumbnail to False. Your max_size will be ignored unless thumbnail=True." 
168 | ) 169 | 170 | 171 | def get_terminal_size(default_width: int = 100, default_height: int = 70) -> os.terminal_size: 172 | """ 173 | Returns the terminal size if the standard output is connected to a terminal. Otherwise, returns the defined default size. 174 | 175 | Args: 176 | default_width (int, optional): The default width to use if there is no terminal connected. 177 | default_height (int, optional): The default height to use if there is no terminal connected. 178 | """ 179 | if sys.stdout.isatty(): 180 | try: 181 | return os.get_terminal_size() 182 | except OSError: 183 | return os.terminal_size((default_width, default_height)) 184 | else: 185 | return os.terminal_size((default_width, default_height)) 186 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | #init-hook= 9 | 10 | # Files or directories to be skipped. They should be base names, not 11 | # paths. 12 | ignore=CVS 13 | 14 | # Add files or directories matching the regex patterns to the ignore-list. The 15 | # regex matches against paths and can be in Posix or Windows format. 16 | ignore-paths= 17 | 18 | # Files or directories matching the regex patterns are skipped. The regex 19 | # matches against base names, not paths. 20 | ignore-patterns=^\.# 21 | 22 | # Pickle collected data for later comparisons. 23 | persistent=yes 24 | 25 | # List of plugins (as comma separated values of python modules names) to load, 26 | # usually to register additional checkers. 
27 | load-plugins= 28 | pylint.extensions.check_elif, 29 | pylint.extensions.bad_builtin, 30 | pylint.extensions.docparams, 31 | pylint.extensions.for_any_all, 32 | pylint.extensions.set_membership, 33 | pylint.extensions.code_style, 34 | pylint.extensions.overlapping_exceptions, 35 | pylint.extensions.typing, 36 | pylint.extensions.redefined_variable_type, 37 | pylint.extensions.comparison_placement, 38 | pylint.extensions.mccabe, 39 | 40 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 41 | # number of processors available to use. 42 | jobs=0 43 | 44 | # When enabled, pylint would attempt to guess common misconfiguration and emit 45 | # user-friendly hints instead of false-positive error messages. 46 | suggestion-mode=yes 47 | 48 | # Allow loading of arbitrary C extensions. Extensions are imported into the 49 | # active Python interpreter and may run arbitrary code. 50 | unsafe-load-any-extension=no 51 | 52 | # A comma-separated list of package or module names from which C extensions may 53 | # be loaded. Extensions are loaded into the active Python interpreter and may 54 | # run arbitrary code. 55 | extension-pkg-allow-list= 56 | 57 | # Minimum supported Python version 58 | py-version = 3.7.2 59 | 60 | # Control the number of potential inferred values when inferring a single 61 | # object. This can help the performance when dealing with large functions or 62 | # complex, nested conditions. 63 | limit-inference-results=100 64 | 65 | # Specify a score threshold to be exceeded before the program exits with error. 66 | fail-under=10.0 67 | 68 | # Return non-zero exit code if any of these messages/categories are detected, 69 | # even if score is above --fail-under value. Syntax same as enable. Messages 70 | # specified are enabled, while categories only check already-enabled messages. 71 | fail-on= 72 | 73 | 74 | [MESSAGES CONTROL] 75 | 76 | # Only show warnings with the listed confidence levels. Leave empty to show 77 | # all.
Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 78 | # confidence= 79 | 80 | # Enable the message, report, category or checker with the given id(s). You can 81 | # either give multiple identifiers separated by comma (,) or put this option 82 | # multiple times (only on the command line, not in the configuration file where 83 | # it should appear only once). See also the "--disable" option for examples. 84 | enable= 85 | use-symbolic-message-instead, 86 | useless-suppression, 87 | 88 | # Disable the message, report, category or checker with the given id(s). You 89 | # can either give multiple identifiers separated by comma (,) or put this 90 | # option multiple times (only on the command line, not in the configuration 91 | # file where it should appear only once). You can also use "--disable=all" to 92 | # disable everything first and then re-enable specific checks. For example, if 93 | # you want to run only the similarities checker, you can use "--disable=all 94 | # --enable=similarities". If you want to run only the classes checker, but have 95 | # no Warning level messages displayed, use "--disable=all --enable=classes 96 | # --disable=W". 97 | 98 | disable= 99 | attribute-defined-outside-init, 100 | invalid-name, 101 | missing-docstring, 102 | protected-access, 103 | too-few-public-methods, 104 | # handled by black 105 | format, 106 | # We anticipate #3512 where it will become optional 107 | fixme, 108 | cyclic-import, 109 | import-error, 110 | 111 | 112 | [REPORTS] 113 | 114 | # Set the output format. Available formats are text, parseable, colorized, msvs 115 | # (Visual Studio) and html. You can also give a reporter class, e.g. 116 | # mypackage.mymodule.MyReporterClass. 117 | output-format=text 118 | 119 | # Tells whether to display a full report or only the messages 120 | reports=no 121 | 122 | # Python expression which should return a note less than 10 (10 is the highest 123 | # note).
You have access to the variables 'fatal', 'error', 'warning', 'refactor', 'convention' 124 | # and 'info', which contain the number of messages in each category, as 125 | # well as 'statement', which is the total number of statements analyzed. This 126 | # score is used by the global evaluation report (RP0004). 127 | evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) 128 | 129 | # Template used to display messages. This is a Python new-style format string 130 | # used to format the message information. See the documentation for all details. 131 | #msg-template= 132 | 133 | # Activate the evaluation score. 134 | score=yes 135 | 136 | 137 | [LOGGING] 138 | 139 | # Logging modules to check that the string format arguments are in logging 140 | # function parameter format 141 | logging-modules=logging 142 | 143 | # The type of string formatting that logging methods do. `old` means using % 144 | # formatting, `new` is for `{}` formatting. 145 | logging-format-style=old 146 | 147 | 148 | [MISCELLANEOUS] 149 | 150 | # List of note tags to take into consideration, separated by a comma. 151 | notes=FIXME,XXX,TODO 152 | 153 | # Regular expression of note tags to take into consideration. 154 | #notes-rgx= 155 | 156 | 157 | [SIMILARITIES] 158 | 159 | # Minimum number of lines of a similarity. 160 | min-similarity-lines=6 161 | 162 | # Ignore comments when computing similarities. 163 | ignore-comments=yes 164 | 165 | # Ignore docstrings when computing similarities. 166 | ignore-docstrings=yes 167 | 168 | # Ignore imports when computing similarities. 169 | ignore-imports=yes 170 | 171 | # Signatures are removed from the similarity computation 172 | ignore-signatures=yes 173 | 174 | 175 | [VARIABLES] 176 | 177 | # Tells whether we should check for unused imports in __init__ files. 178 | init-import=no 179 | 180 | # A regular expression matching the name of dummy variables (i.e. expectedly 181 | # not used).
182 | dummy-variables-rgx=_$|dummy 183 | 184 | # List of additional names supposed to be defined in builtins. Remember that 185 | # you should avoid defining new builtins when possible. 186 | additional-builtins= 187 | 188 | # List of strings which can identify a callback function by name. A callback 189 | # name must start or end with one of those strings. 190 | callbacks=cb_,_cb 191 | 192 | # Tells whether unused global variables should be treated as a violation. 193 | allow-global-unused-variables=yes 194 | 195 | # List of names allowed to shadow builtins 196 | allowed-redefined-builtins= 197 | 198 | # Argument names that match this expression will be ignored. Default to name 199 | # with leading underscore. 200 | ignored-argument-names=_.* 201 | 202 | # List of qualified module names which can have objects that can redefine 203 | # builtins. 204 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 205 | 206 | 207 | [FORMAT] 208 | 209 | # Maximum number of characters on a single line. 210 | max-line-length=120 211 | 212 | # Regexp for a line that is allowed to be longer than the limit. 213 | ignore-long-lines=^\s*(# )??$ 214 | 215 | # Allow the body of an if to be on the same line as the test if there is no 216 | # else. 217 | single-line-if-stmt=no 218 | 219 | # Allow the body of a class to be on the same line as the declaration if body 220 | # contains single statement. 221 | single-line-class-stmt=no 222 | 223 | # Maximum number of lines in a module 224 | max-module-lines=1000 225 | 226 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 227 | # tab). 228 | indent-string=' ' 229 | 230 | # Number of spaces of indent required inside a hanging or continued line. 231 | indent-after-paren=4 232 | 233 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
234 | expected-line-ending-format= 235 | 236 | 237 | [BASIC] 238 | 239 | # Good variable names which should always be accepted, separated by a comma 240 | good-names=i,j,k,ex,Run,_ 241 | 242 | # Good variable names regexes, separated by a comma. If names match any regex, 243 | # they will always be accepted 244 | good-names-rgxs= 245 | 246 | # Bad variable names which should always be refused, separated by a comma 247 | bad-names=foo,bar,baz,toto,tutu,tata 248 | 249 | # Bad variable names regexes, separated by a comma. If names match any regex, 250 | # they will always be refused 251 | bad-names-rgxs= 252 | 253 | # Colon-delimited sets of names that determine each other's naming style when 254 | # the name regexes allow several styles. 255 | name-group= 256 | 257 | # Include a hint for the correct naming format with invalid-name 258 | include-naming-hint=no 259 | 260 | # Naming style matching correct function names. 261 | function-naming-style=snake_case 262 | 263 | # Regular expression matching correct function names 264 | function-rgx=[a-z_][a-z0-9_]{2,30}$ 265 | 266 | # Naming style matching correct variable names. 267 | variable-naming-style=snake_case 268 | 269 | # Regular expression matching correct variable names 270 | variable-rgx=[a-z_][a-z0-9_]{2,30}$ 271 | 272 | # Naming style matching correct constant names. 273 | const-naming-style=UPPER_CASE 274 | 275 | # Regular expression matching correct constant names 276 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 277 | 278 | # Naming style matching correct attribute names. 279 | attr-naming-style=snake_case 280 | 281 | # Regular expression matching correct attribute names 282 | attr-rgx=[a-z_][a-z0-9_]{2,}$ 283 | 284 | # Naming style matching correct argument names. 285 | argument-naming-style=snake_case 286 | 287 | # Regular expression matching correct argument names 288 | argument-rgx=[a-z_][a-z0-9_]{2,30}$ 289 | 290 | # Naming style matching correct class attribute names. 
291 | class-attribute-naming-style=any 292 | 293 | # Regular expression matching correct class attribute names 294 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 295 | 296 | # Naming style matching correct class constant names. 297 | class-const-naming-style=UPPER_CASE 298 | 299 | # Regular expression matching correct class constant names. Overrides class- 300 | # const-naming-style. 301 | #class-const-rgx= 302 | 303 | # Naming style matching correct inline iteration names. 304 | inlinevar-naming-style=any 305 | 306 | # Regular expression matching correct inline iteration names 307 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 308 | 309 | # Naming style matching correct class names. 310 | class-naming-style=PascalCase 311 | 312 | # Regular expression matching correct class names 313 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 314 | 315 | 316 | # Naming style matching correct module names. 317 | module-naming-style=snake_case 318 | 319 | # Regular expression matching correct module names 320 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 321 | 322 | 323 | # Naming style matching correct method names. 324 | method-naming-style=snake_case 325 | 326 | # Regular expression matching correct method names 327 | method-rgx=[a-z_][a-z0-9_]{2,}$ 328 | 329 | # Regular expression which can overwrite the naming style set by typevar-naming-style. 330 | #typevar-rgx= 331 | 332 | # Regular expression which should only match function or class names that do 333 | # not require a docstring. Use ^(?!__init__$)_ to also check __init__. 334 | no-docstring-rgx=__.*__ 335 | 336 | # Minimum line length for functions/classes that require docstrings, shorter 337 | # ones are exempt. 338 | docstring-min-length=-1 339 | 340 | # List of decorators that define properties, such as abc.abstractproperty. 
341 | property-classes=abc.abstractproperty 342 | 343 | 344 | [TYPECHECK] 345 | 346 | # Regex pattern to define which classes are considered mixins if ignore-mixin- 347 | # members is set to 'yes' 348 | mixin-class-rgx=.*MixIn 349 | 350 | # List of module names for which member attributes should not be checked 351 | # (useful for modules/projects where namespaces are manipulated during runtime 352 | # and thus existing member attributes cannot be deduced by static analysis). It 353 | # supports qualified module names, as well as Unix pattern matching. 354 | ignored-modules= 355 | 356 | # List of class names for which member attributes should not be checked (useful 357 | # for classes with dynamically set attributes). This supports the use of 358 | # qualified names. 359 | ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local 360 | 361 | # List of members which are set dynamically and missed by pylint inference 362 | # system, and so shouldn't trigger E1101 when accessed. Python regular 363 | # expressions are accepted. 364 | generated-members=REQUEST,acl_users,aq_parent,argparse.Namespace 365 | 366 | # List of decorators that create context managers from functions, such as 367 | # contextlib.contextmanager. 368 | contextmanager-decorators=contextlib.contextmanager 369 | 370 | # Tells whether to warn about missing members when the owner of the attribute 371 | # is inferred to be None. 372 | ignore-none=yes 373 | 374 | # This flag controls whether pylint should warn about no-member and similar 375 | # checks whenever an opaque object is returned when inferring. The inference 376 | # can return multiple potential results while evaluating a Python object, but 377 | # some branches might not be evaluated, which results in partial inference. In 378 | # that case, it might be useful to still emit no-member and other checks for 379 | # the rest of the inferred objects. 
380 | ignore-on-opaque-inference=yes 381 | 382 | # Show a hint with possible names when a member name was not found. The aspect 383 | # of finding the hint is based on edit distance. 384 | missing-member-hint=yes 385 | 386 | # The minimum edit distance a name should have in order to be considered a 387 | # similar match for a missing member name. 388 | missing-member-hint-distance=1 389 | 390 | # The total number of similar names that should be taken into consideration when 391 | # showing a hint for a missing member. 392 | missing-member-max-choices=1 393 | 394 | [SPELLING] 395 | 396 | # Spelling dictionary name. Available dictionaries: none. To make it work, 397 | # install the python-enchant package. 398 | spelling-dict= 399 | 400 | # List of comma-separated words that should not be checked. 401 | spelling-ignore-words= 402 | 403 | # List of comma-separated words that should be considered directives if they 404 | # appear at the beginning of a comment and should not be checked. 405 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:,pragma:,# noinspection 406 | 407 | # A path to a file that contains a private dictionary; one word per line. 408 | spelling-private-dict-file=.pyenchant_pylint_custom_dict.txt 409 | 410 | # Tells whether to store unknown words to the indicated private dictionary in 411 | # the --spelling-private-dict-file option instead of raising a message. 412 | spelling-store-unknown-words=no 413 | 414 | # Limits the count of emitted suggestions for spelling mistakes.
415 | max-spelling-suggestions=2 416 | 417 | 418 | [DESIGN] 419 | 420 | # Maximum number of arguments for function / method 421 | max-args=10 422 | 423 | # Maximum number of locals for function / method body 424 | max-locals=25 425 | 426 | # Maximum number of return / yield for function / method body 427 | max-returns=11 428 | 429 | # Maximum number of branch for function / method body 430 | max-branches=27 431 | 432 | # Maximum number of statements in function / method body 433 | max-statements=100 434 | 435 | # Maximum number of parents for a class (see R0901). 436 | max-parents=7 437 | 438 | # List of qualified class names to ignore when counting class parents (see R0901). 439 | ignored-parents= 440 | 441 | # Maximum number of attributes for a class (see R0902). 442 | max-attributes=11 443 | 444 | # Minimum number of public methods for a class (see R0903). 445 | min-public-methods=2 446 | 447 | # Maximum number of public methods for a class (see R0904). 448 | max-public-methods=25 449 | 450 | # Maximum number of boolean expressions in an if statement (see R0916). 451 | max-bool-expr=5 452 | 453 | # List of regular expressions of class ancestor names to 454 | # ignore when counting public methods (see R0903). 455 | exclude-too-few-public-methods= 456 | 457 | max-complexity=10 458 | 459 | [CLASSES] 460 | 461 | # List of method names used to declare (i.e. assign) instance attributes. 462 | defining-attr-methods=__init__,__new__,setUp,__post_init__ 463 | 464 | # List of valid names for the first argument in a class method. 465 | valid-classmethod-first-arg=cls 466 | 467 | # List of valid names for the first argument in a metaclass class method. 468 | valid-metaclass-classmethod-first-arg=mcs 469 | 470 | # List of member names, which should be excluded from the protected access 471 | # warning. 
472 | exclude-protected=_asdict,_fields,_replace,_source,_make 473 | 474 | # Warn about protected attribute access inside special methods 475 | check-protected-access-in-special-methods=no 476 | 477 | [IMPORTS] 478 | 479 | # List of modules that can be imported at any level, not just the top level 480 | # one. 481 | allow-any-import-level= 482 | 483 | # Allow wildcard imports from modules that define __all__. 484 | allow-wildcard-with-all=no 485 | 486 | # Analyse import fallback blocks. This can be used to support both Python 2 and 487 | # 3 compatible code, which means that the block might have code that exists 488 | # only in one or another interpreter, leading to false positives when analysed. 489 | analyse-fallback-blocks=no 490 | 491 | # Deprecated modules which should not be used, separated by a comma 492 | deprecated-modules=regsub,TERMIOS,Bastion,rexec 493 | 494 | # Create a graph of every (i.e. internal and external) dependencies in the 495 | # given file (report RP0402 must not be disabled) 496 | import-graph= 497 | 498 | # Create a graph of external dependencies in the given file (report RP0402 must 499 | # not be disabled) 500 | ext-import-graph= 501 | 502 | # Create a graph of internal dependencies in the given file (report RP0402 must 503 | # not be disabled) 504 | int-import-graph= 505 | 506 | # Force import order to recognize a module as part of the standard 507 | # compatibility libraries. 508 | known-standard-library= 509 | 510 | # Force import order to recognize a module as part of a third party library. 511 | known-third-party=enchant 512 | 513 | # Couples of modules and preferred modules, separated by a comma. 514 | preferred-modules= 515 | 516 | 517 | [EXCEPTIONS] 518 | 519 | # Exceptions that will emit a warning when being caught. 
Defaults to 520 | # "Exception" 521 | overgeneral-exceptions=builtins.Exception 522 | 523 | 524 | [TYPING] 525 | 526 | # Set to ``no`` if the app / library does **NOT** need to support runtime 527 | # introspection of type annotations. If you use type annotations 528 | # **exclusively** for type checking of an application, you're probably fine. 529 | # For libraries, evaluate if some users want to access the type hints at 530 | # runtime first, e.g., through ``typing.get_type_hints``. Applies to Python 531 | # versions 3.7 - 3.9 532 | runtime-typing = no 533 | 534 | 535 | [DEPRECATED_BUILTINS] 536 | 537 | # List of builtin function names that should not be used, separated by a comma 538 | bad-functions=map,input 539 | 540 | 541 | [REFACTORING] 542 | 543 | # Maximum number of nested blocks for function / method body 544 | max-nested-blocks=5 545 | 546 | # Complete name of functions that never return. When checking for 547 | # inconsistent-return-statements, if a never-returning function is called then 548 | # it will be considered as an explicit return statement and no message will be 549 | # printed. 550 | never-returning-functions=sys.exit,argparse.parse_error 551 | 552 | 553 | [STRING] 554 | 555 | # This flag controls whether inconsistent-quotes generates a warning when the 556 | # character used as a quote delimiter is used inconsistently within a module. 557 | check-quote-consistency=no 558 | 559 | # This flag controls whether the implicit-str-concat should generate a warning 560 | # on implicit string concatenation in sequences defined over several lines. 561 | check-str-concat-over-line-jumps=no 562 | 563 | 564 | [CODE_STYLE] 565 | 566 | # Max line length for which to still emit suggestions. Used to prevent optional 567 | # suggestions which would get split by a code formatter (e.g., black). Will 568 | # default to the setting for ``max-line-length``.
569 | #max-line-length-suggestions= 570 | --------------------------------------------------------------------------------