├── imgs ├── video.gif ├── flow_ts.gif ├── overlay_1.png ├── overlay_2.png ├── rich_info.gif ├── torchshow.png ├── axes_title.jpg ├── batch_imgs.gif ├── cat_mask_ts.gif ├── featuremap.gif ├── video_grid.gif ├── RGB_image_plt.gif ├── RGB_image_ts.gif ├── cat_mask_plt.gif └── custom_layout.gif ├── torchshow ├── __init__.py ├── config.py ├── flow.py ├── utils.py ├── torchshow.py └── visualization.py ├── .gitignore ├── setup.py ├── LICENSE ├── changelogs.md ├── API.md ├── README.md └── tests └── tests.py /imgs/video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/video.gif -------------------------------------------------------------------------------- /imgs/flow_ts.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/flow_ts.gif -------------------------------------------------------------------------------- /imgs/overlay_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/overlay_1.png -------------------------------------------------------------------------------- /imgs/overlay_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/overlay_2.png -------------------------------------------------------------------------------- /imgs/rich_info.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/rich_info.gif -------------------------------------------------------------------------------- /imgs/torchshow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/torchshow.png 
-------------------------------------------------------------------------------- /imgs/axes_title.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/axes_title.jpg -------------------------------------------------------------------------------- /imgs/batch_imgs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/batch_imgs.gif -------------------------------------------------------------------------------- /imgs/cat_mask_ts.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/cat_mask_ts.gif -------------------------------------------------------------------------------- /imgs/featuremap.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/featuremap.gif -------------------------------------------------------------------------------- /imgs/video_grid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/video_grid.gif -------------------------------------------------------------------------------- /imgs/RGB_image_plt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/RGB_image_plt.gif -------------------------------------------------------------------------------- /imgs/RGB_image_ts.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/RGB_image_ts.gif -------------------------------------------------------------------------------- /imgs/cat_mask_plt.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/cat_mask_plt.gif -------------------------------------------------------------------------------- /imgs/custom_layout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwying/torchshow/HEAD/imgs/custom_layout.gif -------------------------------------------------------------------------------- /torchshow/__init__.py: -------------------------------------------------------------------------------- 1 | from .torchshow import show, show_video, save, overlay 2 | from .config import set_image_std, set_image_mean, set_color_mode, show_rich_info -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | _torchshow/ 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.ipynb 6 | *.flo 7 | test_data/ 8 | ._* 9 | .DS_Store 10 | 11 | # Distribution / packaging 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | eggs/ 16 | lib/ 17 | *.egg-info/ 18 | .installed.cfg 19 | *.egg 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="torchshow", # Replace with your own username 8 | version="0.5.2", 9 | author="Xiaowen Ying", 10 | author_email="shawnying.inbox@gmail.com", 11 | description="Visualizing PyTorch tensors with a single line of code.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/xwying/torchshow", 15 | packages=setuptools.find_packages(), 16 | install_requires=[ 17 | 'numpy', 18 | 
'matplotlib'], 19 | classifiers=[ 20 | "Programming Language :: Python :: 3", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | ], 24 | python_requires='>=3.6', 25 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2021] [Xiaowen Ying] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /torchshow/config.py: -------------------------------------------------------------------------------- 1 | class Config(): 2 | def __init__(self): 3 | self.config_dict = {'backend': 'matplotlib', # May support more backend in the future 4 | 'image_mean': None, 5 | 'image_std': None, 6 | 'color_mode': 'rgb', 7 | 'show_rich_info': True, 8 | } 9 | 10 | def set(self, key, value): 11 | if key in self.config_dict.keys(): 12 | self.config_dict[key] = value 13 | 14 | def get(self, key): 15 | return self.config_dict.get(key, None) 16 | 17 | config = Config() 18 | 19 | def set_image_mean(mean: list): 20 | assert len(mean)==3 21 | config.set('image_mean', mean) 22 | 23 | def set_image_std(std: list): 24 | assert len(std)==3 25 | config.set('image_std', std) 26 | 27 | def set_color_mode(value='rgb'): 28 | assert value in ['rgb', 'bgr'] 29 | config.set('color_mode', value) 30 | 31 | def show_rich_info(flag: bool): 32 | if flag: 33 | config.set('show_rich_info', True) 34 | else: 35 | config.set('show_rich_info', False) -------------------------------------------------------------------------------- /changelogs.md: -------------------------------------------------------------------------------- 1 | # Changelogs 2 | 3 | ## [2025-03-19] v0.5.2 4 | - Modify set_window_title function to avoid version comparison bugs ([#24](https://github.com/xwying/torchshow/issues/24)). 5 | - Support more data dtype for pytorch tensors. 6 | 7 | ## [2023-07-02] v0.5.1 8 | - Fix `np.int` depreciation issues ([#13](https://github.com/xwying/torchshow/pull/13)). 9 | - Allow specifying `nrows` and `ncols` when visualizing a list of tensors ([#17](https://github.com/xwying/torchshow/pull/17)). 10 | - Fix unexpected white spaces when saving figures ([#19](https://github.com/xwying/torchshow/pull/19)). 11 | 12 | ## [2022-11-07] v0.5.0 13 | - Support specifying the color map for grayscale image. 14 | - Support PIL Image. 
15 | - Support filenames. 16 | - Add `ts.overlay()` API which can be used to blend multiple plots. 17 | - Fix bugs when unnormalizing with custom mean and std. 18 | 19 | ## [2022-06-30] v0.4.2 20 | - You can specify the `figsize`, `dpi`, and `suptitle` parameters when calling ts.show(). 21 | - Add some missing APIs to `ts.save()`. 22 | - Revisit the option to add axes titles. 23 | - Add tight_layout option to `ts.show_video` (enabled by default). 24 | - Fix some bugs. 25 | - Create API Reference Page. 26 | 27 | ## [2022-05-21] v0.4.1 28 | - Now you can get richer information from a pixel (e.g. raw pixel value) by hovering the mouse over the pixels. 29 | - Fix the unexpected colors around the contour while visualizing categorical masks. 30 | 31 | ## [2022-05-19] v0.4.0 32 | - TorchShow will now automatically check if running in an ipython environment (e.g. jupyter notebook). Remove `ts.inline()` since it is no longer needed. 33 | - Fix a bug where binary masks were inferred as categorical masks. 34 | - Optimize the logic to handle a few corner cases. 35 | 36 | 37 | ## [2021-08-23] v0.3.2 38 | - Adding `ts.save(tensor)` API for saving figs instead of showing them. This is more convenient compared to the headless mode. - Remove surrounding white spaces of the saved figures. 39 | - ts.headless() has been removed. Use ts.save() instead. 40 | 41 | ## [2021-06-14] v0.3.1 42 | - Fixes some bugs. 43 | - Now supports a headless mode, useful for running on a server without a display. After setting `ts.headless(True)`, calling `ts.show(tensor)` will save the figure under `./_torchshow/`. 44 | 45 | ## [2021-04-25] v0.3.0 46 | - Adding optical flow support.
-------------------------------------------------------------------------------- /torchshow/flow.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-08-03 14 | 15 | import numpy as np 16 | 17 | def make_colorwheel(): 18 | """ 19 | Generates a color wheel for optical flow visualization as presented in: 20 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 21 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 22 | Code follows the original C++ source code of Daniel Scharstein. 23 | Code follows the the Matlab source code of Deqing Sun. 
24 | Returns: 25 | np.ndarray: Color wheel 26 | """ 27 | 28 | RY = 15 29 | YG = 6 30 | GC = 4 31 | CB = 11 32 | BM = 13 33 | MR = 6 34 | 35 | ncols = RY + YG + GC + CB + BM + MR 36 | colorwheel = np.zeros((ncols, 3)) 37 | col = 0 38 | 39 | # RY 40 | colorwheel[0:RY, 0] = 255 41 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 42 | col = col+RY 43 | # YG 44 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 45 | colorwheel[col:col+YG, 1] = 255 46 | col = col+YG 47 | # GC 48 | colorwheel[col:col+GC, 1] = 255 49 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 50 | col = col+GC 51 | # CB 52 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 53 | colorwheel[col:col+CB, 2] = 255 54 | col = col+CB 55 | # BM 56 | colorwheel[col:col+BM, 2] = 255 57 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 58 | col = col+BM 59 | # MR 60 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 61 | colorwheel[col:col+MR, 0] = 255 62 | return colorwheel 63 | 64 | 65 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 66 | """ 67 | Applies the flow color wheel to (possibly clipped) flow components u and v. 68 | According to the C++ source code of Daniel Scharstein 69 | According to the Matlab source code of Deqing Sun 70 | Args: 71 | u (np.ndarray): Input horizontal flow of shape [H,W] 72 | v (np.ndarray): Input vertical flow of shape [H,W] 73 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
74 | Returns: 75 | np.ndarray: Flow visualization image of shape [H,W,3] 76 | """ 77 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 78 | colorwheel = make_colorwheel() # shape [55x3] 79 | ncols = colorwheel.shape[0] 80 | rad = np.sqrt(np.square(u) + np.square(v)) 81 | a = np.arctan2(-v, -u)/np.pi 82 | fk = (a+1) / 2*(ncols-1) 83 | k0 = np.floor(fk).astype(np.int32) 84 | k1 = k0 + 1 85 | k1[k1 == ncols] = 0 86 | f = fk - k0 87 | for i in range(colorwheel.shape[1]): 88 | tmp = colorwheel[:,i] 89 | col0 = tmp[k0] / 255.0 90 | col1 = tmp[k1] / 255.0 91 | col = (1-f)*col0 + f*col1 92 | idx = (rad <= 1) 93 | col[idx] = 1 - rad[idx] * (1-col[idx]) 94 | col[~idx] = col[~idx] * 0.75 # out of range 95 | # Note the 2-i => BGR instead of RGB 96 | ch_idx = 2-i if convert_to_bgr else i 97 | flow_image[:,:,ch_idx] = np.floor(255 * col) 98 | return flow_image 99 | 100 | 101 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 102 | """ 103 | Expects a two dimensional flow image of shape. 104 | Args: 105 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 106 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 107 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
108 | Returns: 109 | np.ndarray: Flow visualization image of shape [H,W,3] 110 | """ 111 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 112 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 113 | if clip_flow is not None: 114 | flow_uv = np.clip(flow_uv, 0, clip_flow) 115 | u = flow_uv[:,:,0] 116 | v = flow_uv[:,:,1] 117 | rad = np.sqrt(np.square(u) + np.square(v)) 118 | rad_max = np.max(rad) 119 | epsilon = 1e-5 120 | u = u / (rad_max + epsilon) 121 | v = v / (rad_max + epsilon) 122 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /torchshow/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import importlib 5 | import warnings 6 | import numbers 7 | import logging 8 | 9 | _EXIF_ORIENT = 274 10 | 11 | logger = logging.getLogger('TorchShow') 12 | 13 | def isnumber(x): 14 | return isinstance(x, numbers.Number) 15 | 16 | 17 | def isinteger(x): 18 | """ 19 | Function to check if np array has integer. Work regardless of type. 20 | e.g. isinteger(np.array([0., 1.5. 1.])) >>> array([ True, False, True], dtype=bool) 21 | 22 | """ 23 | return np.equal(np.mod(x, 1), 0) 24 | 25 | 26 | def within_0_1(x): 27 | return (x.min() >= 0) and (x.max() <= 1) 28 | 29 | 30 | def within_0_255(x): 31 | return (x.min() >= 0) and (x.max() <= 255) 32 | 33 | 34 | def calculate_grid_layout(N, img_H, img_W, nrow=None, ncol=None): 35 | """ 36 | Function to calculate grid_layout 37 | """ 38 | if (nrow != None): # `nrow` has higher priority than `ncol`. 39 | ncol = int(np.ceil(N / nrow)) 40 | elif (ncol != None): 41 | nrow = int(np.ceil(N / ncol)) 42 | else: # If both nrow and ncols are not set, perform automatic calculation. 
43 | N_sqrt = np.sqrt(N) 44 | if img_H >= img_W: 45 | nrow = int(np.floor(N_sqrt)) 46 | ncol = int(np.ceil(N/nrow)) 47 | else: 48 | ncol = int(np.floor(N_sqrt)) 49 | nrow = int(np.ceil(N/ncol)) 50 | return nrow, ncol 51 | 52 | 53 | def tensor_to_array(x): 54 | # Recursively perform tensor to array conversion 55 | # ====== PyTorch Tensor ======= 56 | try: 57 | torch = import_if_not_yet("torch") 58 | except ImportError: 59 | pass 60 | else: 61 | if isinstance(x, torch.Tensor): 62 | x = x.detach().clone().cpu() 63 | if x.dtype in [torch.bfloat16, torch.float16, torch.bool]: 64 | logger.warning("The tensor with type \"{}\" was automatically converted to float32 for visualization.".format(x.dtype)) 65 | x = x.float() 66 | return x.numpy() 67 | 68 | # ====== PIL Image ======= 69 | try: 70 | Image = import_if_not_yet(module_name="Image", package_name="PIL") 71 | except ImportError: 72 | pass 73 | else: 74 | if isinstance(x, Image.Image): 75 | return np.asarray(x) 76 | 77 | # ====== Numpy Array ======= 78 | if isinstance(x, np.ndarray): 79 | return x.copy() 80 | 81 | # ====== Filename ======= 82 | elif isinstance(x, str): # Handling filename 83 | if os.path.exists(x): 84 | if x.split('.')[-1].lower() == 'flo': # Handling optical flow files. 85 | return read_flow(x) 86 | else: 87 | image = read_image_PIL(x) 88 | return np.asarray(image) 89 | else: 90 | raise FileNotFoundError(f"{x} is not a file.") 91 | # ====== Recursively processing list of inputs 92 | elif isinstance(x, list): 93 | return [tensor_to_array(e) for e in x] 94 | else: 95 | raise TypeError('Found unsupported type ', type(x)) 96 | 97 | 98 | def isnotebook(): 99 | # Check if running in ipython/jupyter notebook environment 100 | try: 101 | shell = get_ipython().__class__.__name__ 102 | if shell == 'ZMQInteractiveShell': 103 | return True # Jupyter notebook or qtconsole 104 | elif shell == 'TerminalInteractiveShell': 105 | return False # Terminal running IPython 106 | else: 107 | return False # Other type (?) 
108 | except NameError: 109 | return False # Probably standard Python interpreter 110 | 111 | def import_if_not_yet(module_name, package_name=""): 112 | module_key = f"{package_name}.{module_name}".lstrip('.') 113 | if module_key not in sys.modules: 114 | return importlib.import_module(module_name, package_name) 115 | else: 116 | return sys.modules[module_key] 117 | 118 | def read_image_PIL(filename): 119 | try: 120 | Image = import_if_not_yet(module_name="Image", package_name="PIL") 121 | except ImportError: 122 | raise ImportError("TorchShow opens image files using PIL which is not yet installed.") 123 | else: 124 | image = Image.open(filename) 125 | mode = image.mode 126 | if mode == 'RGBA': # Convert RGBA to RGB since torchshow will handle 4 channel images differently. 127 | mode = 'RGB' 128 | 129 | """ 130 | The following code is for correctly applying the exif orientation. 131 | This code is modified based on https://github.com/facebookresearch/detectron2/blob/2b98c273b240b54d2d0ee6853dc331c4f2ca87b9/detectron2/data/detection_utils.py#L119 132 | """ 133 | 134 | if not hasattr(image, "getexif"): 135 | return image 136 | 137 | try: 138 | exif = image.getexif() 139 | except Exception: # https://github.com/facebookresearch/detectron2/issues/1885 140 | exif = None 141 | 142 | if exif is None: 143 | return image 144 | 145 | orientation = exif.get(_EXIF_ORIENT) 146 | method = { 147 | 2: Image.FLIP_LEFT_RIGHT, 148 | 3: Image.ROTATE_180, 149 | 4: Image.FLIP_TOP_BOTTOM, 150 | 5: Image.TRANSPOSE, 151 | 6: Image.ROTATE_270, 152 | 7: Image.TRANSVERSE, 153 | 8: Image.ROTATE_90, 154 | }.get(orientation) 155 | 156 | if method is not None: 157 | warnings.warn(f"TorchShow has detected orientation information in the EXIF of file {filename} and the corresponding transformed has been applied. 
Please be aware of this issue if you want to load this file with PIL in your code.") 158 | image = image.transpose(method) 159 | 160 | return image.convert(mode) 161 | 162 | 163 | def read_flow(filename): 164 | with open(filename, 'rb') as f: 165 | magic = np.fromfile(f, np.float32, count=1) 166 | if 202021.25 != magic: 167 | print('Magic number incorrect. Invalid .flo file') 168 | else: 169 | w = np.fromfile(f, np.int32, count=1)[0] 170 | h = np.fromfile(f, np.int32, count=1)[0] 171 | print('Reading %d x %d flo file' % (w, h)) 172 | data = np.fromfile(f, np.float32, count=2*w*h) 173 | # Reshape data into 3D array (columns, rows, bands) 174 | data2D = np.resize(data, (h, w, 2)) 175 | return data2D 176 | 177 | 178 | if __name__ == '__main__': 179 | for N in range(1,100): 180 | row, col = calculate_grid_layout(N, 10, 10) 181 | print(N, row, col, row*col) -------------------------------------------------------------------------------- /API.md: -------------------------------------------------------------------------------- 1 | # TorchShow API References 2 | 3 | ## torchshow.show 4 | 5 | ```python 6 | torchshow.show(x, 7 | mode='auto', 8 | auto_permute=True, 9 | display=True, 10 | nrows=None, 11 | ncols=None, 12 | channel_mode='auto', 13 | show_axis=False, 14 | tight_layout=True, 15 | suptitle=None, 16 | axes_title=None, 17 | figsize=None, 18 | dpi=None, 19 | cmap='gray') 20 | ``` 21 | 22 | ### Parameters: 23 | 24 | * **x**: *tensor-like (support both `torch.Tensor`, `np.ndarray` and `PIL Image`) or List of tensor-like. * The tensor data that we want to visualize. Filename and list of filenames are also supported, for example: `ts.show("my_image.jpg")`. 25 | 26 | * **mode**: *str*. The visualize mode. The default value is `"auto"` where TorchShow will automatically infer the mode. Available options are: `"image"`, `"flow"`, `"grayscale"`, `"categorical_mask"`. 27 | 28 | * **auto_permute**: *bool*. If enable, TorchShow will automatically convert `CHW` to `HWC` format. 
29 | 30 | * **display**: *bool*. If set to false, TorchShow will not display the data but return the list of processed data. Use it when you want to visualize them using other libraries such as OpenCV. 31 | 32 | * **nrows**: *Int*. The number of rows to plot in a grid layout. If not specified it will be automatically inferred by TorchShow. 33 | 34 | * **ncols**: *Int*. The number of columns to plot in a grid layout. If not specified it will be automatically inferred by TorchShow. 35 | 36 | * **channel_mode**: *Str*. The channel mode of your input data. Available options are `"auto"`, `"channel_last"` and `"channel_first"`. The default value is `"auto"` and it will be automatically inferred by TorchShow. 37 | 38 | * **show_axis**: *Bool*. Whether to show the axis in the plot. 39 | 40 | * **tight_layout**: *Bool*. Whether to adjust subplot params so that subplots fit nicely in the figure. Corresponds to `fig.tight_layout()` in matplotlib. 41 | 42 | * **suptitle**: *Str*. Add a centered suptitle to the figure. 43 | 44 | * **axes_title**: *Str*. Add titles to each of the axes. It can be used with predefined placeholders. Available placeholders are: `{img_id}`, `{img_id_from_1}`, `{row}`, `{column}`. 45 | 46 | Below is an example that shows the image id on top of each image: 47 | 48 | ```python 49 | batch = torch.rand(8, 3, 100, 100) 50 | ts.show(batch, axes_title="Image ID: {img_id_from_1}") 51 | ``` 52 | 53 | ![](./imgs/axes_title.jpg) 54 | 55 | * **figsize**: *2-tuple of floats*. Figure dimension `(width, height)` in inches. 56 | 57 | * **dpi**: *float*. Dots per inch. 58 | 59 | * **cmap**: *str*. Specifies the [color map](https://matplotlib.org/stable/tutorials/colors/colormaps.html) for grayscale images.
60 | 61 | --- 62 | 63 | ## torchshow.save 64 | 65 | ```python 66 | torchshow.save(x, 67 | path=None, 68 | **kwargs) 69 | ``` 70 | 71 | ### Parameters: 72 | 73 | * **x**: *tensor-like (supports both `torch.Tensor` and `np.ndarray`) or List of tensor-like.* The tensor data that we want to visualize. 74 | * **path**: *str*. The path to save the figure. 75 | * **kwargs**: You can pass in any other parameters available in `torchshow.show()`. 76 | 77 | --- 78 | 79 | ## torchshow.overlay 80 | 81 | ```python 82 | torchshow.overlay(x, 83 | alpha=None, 84 | extent=None, 85 | save_as=None, 86 | **kwargs) 87 | ``` 88 | 89 | A function used to overlay multiple visualizations. 90 | 91 | ### Parameters 92 | 93 | * **x**: *list of tensor-like.* A list of tensor data whose visualizations will be overlaid. Filenames are also supported. 94 | * **alpha**: *list of (number or array-like)*. (Optional) The list of alpha values for blending, each alpha value is between 0 (transparent) and 1 (opaque). If alpha is an array-like, the alpha blending values are applied pixel by pixel, and alpha must have the same shape as X. 95 | * **extent**: *tuple*. (Optional) Format: `(x_min, x_max, y_min, y_max)`. The extent defines the size of the rendering area which will be used to render all plots. If unspecified TorchShow will use the extent of the first visualization. 96 | * **save_as**: *str*. (Optional) A filepath to save the plot. If specified TorchShow will save the result to this file.
97 | * **kwargs**: You can pass in any other parameters available in `torchshow.show()`. 98 | 99 | ### Examples: 100 | 101 | ```python 102 | ts.overlay([tensor1, tensor2, tensor3], alpha=[0.5, 0.5]) 103 | ts.overlay(["example_rgb.jpg", "example_category_mask.png"], alpha=[1, 0.5]) 104 | ``` 105 | 106 | --- 107 | 108 | ## torchshow.show_video 109 | 110 | ```python 111 | torchshow.show_video(x, 112 | display=True, 113 | show_axis=False, 114 | tight_layout=False, 115 | suptitle=None, 116 | figsize=None, 117 | dpi=None) 118 | ``` 119 | 120 | * **x**: *tensor-like (supports both `torch.Tensor` and `np.ndarray`) or List of tensor-like.* The tensor data that we want to visualize. 121 | 122 | * **display**: *bool*. If set to false, TorchShow will not display the data but return the list of processed data. Use it when you want to visualize them using other libraries such as OpenCV. 123 | 124 | * **show_axis**: *Bool*. Whether to show the axis in the plot. 125 | 126 | * **tight_layout**: *Bool*. Whether to adjust subplot params so that subplots fit nicely in the figure. Corresponds to `fig.tight_layout()` in matplotlib. 127 | 128 | * **suptitle**: *Str*. Add a centered suptitle to the figure. 129 | 130 | * **figsize**: *2-tuple of floats*. Figure dimension `(width, height)` in inches. 131 | 132 | * **dpi**: *float*. Dots per inch. 133 | 134 | --- 135 | 136 | 137 | ## torchshow.set_color_mode 138 | ```python 139 | torchshow.set_color_mode(value='rgb') 140 | ``` 141 | 142 | * **value**: *str*. `"rgb"` or `"bgr"`. Sets the channel order of color images. The default is `"rgb"`. 143 | 144 | --- 145 | 146 | ## torchshow.set_image_mean 147 | ```python 148 | torchshow.set_image_mean(mean) 149 | ``` 150 | 151 | * **mean**: *list of number*: Set the channel-wise mean used when unnormalizing the image. The default mean is `[0., 0., 0.]`.
152 | 153 | --- 154 | 155 | ## torchshow.set_image_std 156 | ```python 157 | torchshow.set_image_std(std) 158 | ``` 159 | 160 | * **std**: *list of number*: Set the channel-wise std when unnormalize the image. The default std is `[1., 1., 1.]`. 161 | 162 | --- 163 | 164 | ## torchshow.show_rich_info 165 | ```python 166 | torchshow.show_rich_info(flag) 167 | ``` 168 | 169 | * **flag**: *bool*: Whether to show rich info in the interactive window. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ![TorchShow Logo](https://raw.githubusercontent.com/xwying/torchshow/master/imgs/torchshow.png) 4 | 5 | [![PyPI version](https://badge.fury.io/py/torchshow.svg)](https://badge.fury.io/py/torchshow) 6 | [![Downloads](https://static.pepy.tech/personalized-badge/torchshow?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=Downloads)](https://pepy.tech/project/torchshow) 7 | ![License](https://img.shields.io/github/license/xwying/torchshow?color=brightgreen) 8 | 9 |
10 | 11 | ---- 12 | 13 | Torchshow visualizes your data in one line of code. It is designed to help debug Computer Vision projects. 14 | 15 | Torchshow automatically infers the type of a tensor (RGB images, grayscale images, binary masks, categorical masks with an automatically applied color palette, etc.) and performs unnormalization when needed. 16 | 17 | **Supported Types:** 18 | 19 | - [x] RGB Images 20 | - [x] Grayscale Images 21 | - [x] Binary Mask 22 | - [x] Categorical Mask (Integer Labels) 23 | - [x] Multiple Images 24 | - [x] Videos 25 | - [x] Multiple Videos 26 | - [x] Optical Flows (powered by [flow_vis](https://github.com/tomrunia/OpticalFlow_Visualization)) 27 | 28 | 29 | 30 | ## What's New in v0.5.2 31 | - Fixed the version comparison bugs in `set_window_title` function ([#24](https://github.com/xwying/torchshow/issues/24)). 32 | - Support more data dtypes for PyTorch tensors. 33 | 34 | See the complete [changelogs](changelogs.md). 35 | 36 | 37 | ## Installation 38 | Install from [PyPI](https://pypi.org/project/torchshow/): 39 | 40 | ```bash 41 | pip install torchshow 42 | ``` 43 | 44 | Alternatively, you can install directly from this repo to test the latest features. 45 | 46 | ```bash 47 | pip install git+https://github.com/xwying/torchshow.git@master 48 | ``` 49 | 50 | 51 | ## Basic Usage 52 | 53 | The usage of TorchShow is extremely simple. Simply import the package and visualize your data in one line: 54 | 55 | ```python 56 | import torchshow as ts 57 | ts.show(tensor) 58 | ``` 59 | 60 | If you work on a headless server without a display, you can use the `ts.save(tensor)` command (since version 0.3.2). 61 | 62 | ```python 63 | import torchshow as ts 64 | ts.save(tensor) # Figure will be saved under ./_torchshow/***.png 65 | ts.save(tensor, './vis/test.jpg') # You can specify the save path. 66 | ``` 67 | 68 | ## API References 69 | 70 | Please check [this page](./API.md) for detailed API references.
71 | 72 | 73 | ## Examples 74 | 75 | ### Table of Contents 76 | - [Visualizing Image Tensor](#1-visualizing-image-tensor) 77 | - [Visualizing Mask Tensors](#2-visualizing-mask-tensors) 78 | - [Visualizing Batch of Tensors](#3-visualizing-batch-of-tensors) 79 | - [Visualizing Channels in Feature Maps](#4-visualizing-feature-maps) 80 | - [Visualizing Multiple Tensors with Custom Layout.](#5-visualizing-multiple-tensors-with-custom-layout) 81 | - [Examine the pixel with rich information.](#6-examine-the-pixel-with-richer-information) 82 | - [Visualizing Tensors as Video Clip](#7-visualizing-tensors-as-video-clip) 83 | - [Display Video Animation in Jupyter Notebook](#8-display-video-animation-in-jupyter-notebook) 84 | - [Visualizing Optical Flows](#9-visualizing-optical-flows) 85 | - [Change Channel Order (RGB/BGR)](#10-change-channel-order-rgbbgr) 86 | - [Change Unnormalization Presets](#11-change-unnormalization-presets) 87 | - [Overlay Visualizations](#12-overlay-visualizations) 88 | 89 | ### 1. Visualizing Image Tensor 90 | Visualizing an image-like tensor is not difficult but can be very cumbersome. You usually need to convert the tensor to a numpy array with the proper shape. In many cases images are normalized in the dataloader, which means that you have to unnormalize them before they can be displayed correctly. 91 | 92 | If you need to frequently verify what your tensors look like, TorchShow is a very helpful tool. 93 | 94 | Using Matplotlib | Using TorchShow 95 | :-------------------------:|:-------------------------: 96 | ![](./imgs/RGB_image_plt.gif) | ![](./imgs/RGB_image_ts.gif) 97 | |The image tensor has been normalized so Matplotlib cannot display it correctly. | TorchShow does the conversion automatically.| 98 | 99 | ### 2. Visualizing Mask Tensors 100 | For projects related to Semantic Segmentation or Instance Segmentation, we often need to visualize mask tensors -- either ground-truth annotations or model predictions. This can be easily done using TorchShow.
101 | 102 | Using Matplotlib | Using TorchShow 103 | :-------------------------:|:-------------------------: 104 | ![](./imgs/cat_mask_plt.gif) | ![](./imgs/cat_mask_ts.gif) 105 | | Different instances have the same color. Some categories are missing. | TorchShow automatically applies color palettes during visualization.| 106 | 107 | ### 3. Visualizing Batch of Tensors 108 | When the tensor is a batch of images, TorchShow will automatically create a grid layout to visualize them. It is also possible to manually control the number of rows and columns. 109 | 110 | ![](./imgs/batch_imgs.gif) 111 | 112 | ### 4. Visualizing Feature Maps 113 | If the input tensor has more than 3 channels, TorchShow will visualize each channel separately, similar to batch visualization. This is useful for visualizing feature maps. 114 | 115 | ![](./imgs/featuremap.gif) 116 | 117 | ### 5. Visualizing Multiple Tensors with Custom Layout. 118 | TorchShow can also visualize multiple tensors using a custom layout. 119 | 120 | To control the layout, put the tensors in a list of lists, arranged like a 2D array. The following example will create a 2 x 3 grid layout. 121 | 122 | ``` 123 | ts.show([[tensor1, tensor2, tensor3], 124 | [tensor4, tensor5, tensor6]]) 125 | ``` 126 | 127 | It is worth mentioning that there is no need to fill up all the places in the grid. The following example visualizes 5 tensors in a 2 x 3 grid layout. 128 | 129 | ``` 130 | ts.show([[tensor1, tensor2], 131 | [tensor3, tensor4, tensor5]]) 132 | ``` 133 | 134 | ![](./imgs/custom_layout.gif) 135 | 136 | 137 | ### 6. Examine the pixel with richer information. 138 | Since `v0.4.1`, TorchShow allows you to get richer information about a pixel you are interested in by simply hovering your mouse over that pixel. This is very helpful for some types of tensors, such as Categorical Masks and Optical Flows. 139 | 140 | Currently, TorchShow displays the following information: 141 | 142 | - `Mode`: Visualization Mode.
143 | - `Shape`: Shape of the tensor. 144 | - `X`, `Y`: The pixel location of the mouse cursor. 145 | - `Raw`: The raw tensor value at (X, Y). 146 | - `Disp`: The display value at (X, Y). 147 | 148 | ![](./imgs/rich_info.gif) 149 | 150 | **Note: if the information is not showing on the status bar, try resizing the window to make it wider.** 151 | 152 | This feature can be turned off by `ts.show_rich_info(False)`. 153 | 154 | 155 | ### 7. Visualizing Tensors as Video Clip 156 | Tensors can be visualized as video clips, which is very helpful if the tensor is a sequence of frames. This can be done using the `show_video` function. 157 | 158 | ```python 159 | ts.show_video(video_tensor) 160 | ``` 161 | 162 | ![](./imgs/video.gif) 163 | 164 | It is also possible to visualize multiple videos in a custom grid layout. 165 | 166 | ![](./imgs/video_grid.gif) 167 | 168 | ### 8. Display Video Animation in Jupyter Notebook 169 | TorchShow visualizes video clips as a `matplotlib.animation.FuncAnimation` object, which may not display in a notebook by default. The following example shows a simple trick to display it. 170 | 171 | ```python 172 | import torchshow as ts 173 | from IPython.display import HTML 174 | 175 | ani = ts.show_video(video_tensor) 176 | HTML(ani.to_jshtml()) 177 | ``` 178 | 179 | ### 9. Visualizing Optical Flows 180 | TorchShow supports visualizing optical flow (powered by [flow_vis](https://github.com/tomrunia/OpticalFlow_Visualization)). Below is a demonstration using a VSCode debugger remotely attached to an SSH server (with X-server configured). Running in a Jupyter Notebook is also supported. 181 | 182 | ![](./imgs/flow_ts.gif) 183 | 184 | ### 10. Change Channel Order (RGB/BGR) 185 | By default, TorchShow visualizes image tensors in RGB mode. You can switch the setting to BGR if you are using OpenCV to load the image. 186 | ```python 187 | ts.set_color_mode('bgr') 188 | ``` 189 | 190 | ### 11.
Change Unnormalization Presets 191 | The image tensor may have been preprocessed with a normalization function. If no preset is specified, TorchShow will automatically rescale it to 0-1. 192 | 193 | 194 | To change the preset to ImageNet normalization, use the following code. 195 | ```python 196 | ts.show(tensor, unnormalize='imagenet') 197 | ``` 198 | 199 | To use custom mean and std values, use the following commands. 200 | ```python 201 | ts.set_image_mean([0., 0., 0.]) 202 | ts.set_image_std([1., 1., 1.]) 203 | ``` 204 | Note that once these are set, TorchShow will use these values for all subsequent visualizations. This is useful because usually only a single normalization preset is used for the entire project. 205 | 206 | 207 | ### 12. Overlay Visualizations 208 | In computer vision projects, we often deal with different representations of the scene, including but not limited to RGB images, depth images, infrared images, semantic masks, instance masks, etc. It is often very helpful to overlay these different kinds of data for visualization. Since `v0.5.0`, TorchShow provides a very useful API, `ts.overlay()`, for this purpose. 209 | 210 | In the example below, we have an RGB image and its corresponding semantic mask. Let's first check what they look like using TorchShow. 211 | 212 | ```python 213 | import torchshow as ts 214 | ts.show(["example_rgb.jpg", "example_category_mask.png"]) 215 | ``` 216 | 217 | ![](./imgs/overlay_1.png) 218 | 219 | Now suppose we want to overlay the mask on top of the RGB image to gain more insight. With TorchShow, this can be done with a single line of code.
220 | 221 | ```python 222 | import torchshow as ts 223 | ts.overlay(["example_rgb.jpg", "example_category_mask.png"], alpha=[1, 0.6]) 224 | ``` 225 | 226 | ![](./imgs/overlay_2.png) 227 | -------------------------------------------------------------------------------- /torchshow/torchshow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import logging 4 | import warnings 5 | 6 | from .visualization import vis_image, vis_grayscale, vis_categorical_mask, vis_flow, display_plt, animate_plt, overlay_plt 7 | from .utils import isinteger, calculate_grid_layout, tensor_to_array 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | 11 | logger = logging.getLogger('TorchShow') 12 | 13 | vis_func_dict = dict(image=vis_image, 14 | grayscale=vis_grayscale, 15 | categorical_mask=vis_categorical_mask, 16 | flow=vis_flow) 17 | 18 | def save(x, path=None, **kwargs): 19 | show(x, save=True, file_path=path, **kwargs) 20 | 21 | 22 | def show(x, display=True, **kwargs): 23 | vis_list = None 24 | 25 | x = tensor_to_array(x) 26 | 27 | if isinstance(x, (np.ndarray)): 28 | x = x.copy() 29 | nrows = kwargs.get('nrows', None) 30 | ncols = kwargs.get('ncols', None) 31 | channel_mode = kwargs.get('channel_mode', 'auto') 32 | if channel_mode == 'auto': 33 | if x.shape[-1] in [1,2,3]: 34 | channel_mode = 'channel_last' 35 | else: 36 | channel_mode = 'channel_first' 37 | 38 | if x.ndim == 4: # (N, C, H, W) like array 39 | if channel_mode == 'channel_first': 40 | N, _, H, W = x.shape 41 | elif channel_mode == 'channel_last': 42 | N, H, W, _ = x.shape 43 | 44 | nrows, ncols = calculate_grid_layout(N, H, W, nrows, ncols) 45 | assert (nrows * ncols >= N) 46 | vis_list = [list(x[i:i + ncols]) for i in range(0, N, ncols)] # vis_list is now an list of list 47 | 48 | 49 | elif x.ndim == 3: # (C, H, W) like array 50 | if channel_mode == 'channel_first': # C, H, W 51 | C, H, W = x.shape 52 | elif channel_mode == 'channel_last': 53 | H, W, C 
= x.shape 54 | 55 | if C <=3: 56 | vis_list = [[x]] # if C is in [1,2,3], visualize it as single image 57 | else: # when C is greater than 3 (e.g. feature maps), visualize it in grid layout 58 | if channel_mode == 'channel_last': 59 | x = np.transpose(x, (2,0,1)) # Transpose to C, H, W because we will visualize each individual channel 60 | nrows, ncols = calculate_grid_layout(C, H, W, nrows, ncols) 61 | assert (nrows * ncols >= C) 62 | vis_list = [list(x[i:i + ncols]) for i in range(0, C, ncols)] 63 | 64 | elif x.ndim == 2: # (H, W) 65 | vis_list = [[x]] 66 | 67 | else: 68 | raise TypeError("Unsupported shape of numpy array {} .".format(x.shape)) 69 | 70 | elif isinstance(x, list): 71 | if isinstance(x[0], np.ndarray): # if the input is list of images [img1, img2, ....] 72 | nrows = kwargs.get('nrows', None) 73 | ncols = kwargs.get('ncols', None) 74 | if (nrows is not None) or (ncols is not None): # If user specified the grid layout 75 | N = len(x) 76 | # Here we assume either `nrows` or `ncols` must given so we do not have to specify H, W 77 | nrows, ncols = calculate_grid_layout(N, None, None, nrows, ncols) 78 | assert (nrows * ncols >= N) 79 | vis_list = [list(x[i:i + ncols]) for i in range(0, N, ncols)] # vis_list is now a list of list 80 | else: 81 | vis_list = [x] # Make it [[img1, img2, ...]] if `nrows` or `ncols` are not specified 82 | else: 83 | vis_list = x 84 | 85 | else: 86 | raise NotImplementedError("Does not support input type \"{}\"".format(type(x))) 87 | 88 | 89 | # vis_list: list of list. Outer list is for rows and inner list is the images in each row. 
90 | # e.g.[[img1, img2], 91 | # [img3, img4]] 92 | 93 | assert isinstance(vis_list, list) 94 | assert np.array([isinstance(l, list) for l in vis_list]).all() # Now the input should be list of list 95 | 96 | plot_list = [] 97 | 98 | for row in vis_list: 99 | list_per_row = [] 100 | for img in row: 101 | vis = visualize(img, **kwargs) 102 | list_per_row.append(vis) 103 | plot_list.append(list_per_row) 104 | 105 | if display: 106 | display_plt(plot_list, **kwargs) 107 | else: 108 | return plot_list 109 | 110 | 111 | def show_video(x, display=True, **kwargs): 112 | video_list = None 113 | 114 | x = tensor_to_array(x) 115 | 116 | if isinstance(x, (np.ndarray)): 117 | x = x.copy() 118 | assert x.ndim in [3,4], "only support 3D array (N, H, W) or 4D array (N, C, H, W) in video mode" 119 | print(x.shape) 120 | video_list = [[x]] # for a single video, make it [[vid]] 121 | 122 | elif isinstance(x, list): 123 | if isinstance(x[0], np.ndarray): # if the input is list of array [vid1, vid2, ...] 124 | nrows = kwargs.get('nrows', None) 125 | ncols = kwargs.get('ncols', None) 126 | if (nrows is not None) or (ncols is not None): # If user specified the grid layout 127 | N = len(x) 128 | # Here we assume either `nrows` or `ncols` must given so we do not have to specify H, W. 129 | nrows, ncols = calculate_grid_layout(N, None, None, nrows, ncols) 130 | assert (nrows * ncols >= N) 131 | video_list = [list(x[i:i + ncols]) for i in range(0, N, ncols)] # video_list is now a list of list 132 | else: 133 | video_list = [x] # Make it [[vid1, vid2, ...]] if `nrows` or `ncols` are not specified 134 | else: 135 | video_list = x 136 | 137 | else: 138 | raise NotImplementedError("Does not support input type \"{}\"".format(type(x))) 139 | 140 | 141 | # video_list: list of list. Outer list is for rows and inner list is the images in each row. 
142 | # e.g.[[img1, img2], 143 | # [img3, img4]] 144 | 145 | assert isinstance(video_list, list) 146 | assert np.array([isinstance(l, list) for l in video_list]).all() # Now the input should be list of list 147 | 148 | video_length = max([len(vid) for l in video_list for vid in l]) 149 | 150 | video_vis_list = [] # Reorganize frames into [t, row, col, img] 151 | 152 | for t in range(video_length): 153 | frames_at_t = [] # [[frame_t_video1, frame_t_video2], 154 | # [frame_t_video3, frame_t_video4]] 155 | for row in video_list: 156 | frames_at_t_per_row = [] # [frame_t_video1, frame_t_video2] 157 | for video in row: 158 | if t < len(video): 159 | img = video[t] 160 | vis = visualize(img, **kwargs) 161 | else: 162 | vis = None 163 | 164 | frames_at_t_per_row.append(vis) # 165 | frames_at_t.append(frames_at_t_per_row) # 166 | video_vis_list.append(frames_at_t) 167 | 168 | if display: 169 | return animate_plt(video_vis_list, **kwargs) 170 | else: 171 | return video_vis_list 172 | 173 | def overlay(x, alpha=None, extent=None, save_as=None, **kwargs): 174 | """ Overlay a list of inputs. Useful for comparing two images or putting masks over images. 175 | 176 | Args: 177 | x (list): a list of tensor-like. 178 | alpha (list) (Optional) : list of alpha for the overlay factor. Each alpha value could be a float or array-like. 179 | save_as (str) (Optional) : A filepath. If specified, save the result to this file. 180 | """ 181 | assert isinstance(x, list) and len(x)>1, "ts.overlay() expect a list with at least two tensor-like as inputs." 182 | if alpha is None: 183 | alpha = [1] + [0.7] * (len(x)-1) # Default blending factor. 184 | if not isinstance(alpha, list): 185 | alpha = [alpha] 186 | 187 | if len(alpha) < len(x): 188 | alpha += [0.7] * (len(x) - len(alpha)) # fill the list of alpha so it has the same length as the list of tensor. 
189 | 190 | vis_list = show(x, display=False, **kwargs)[0] 191 | 192 | if extent is None: 193 | h, w = vis_list[0]["disp"].shape[:2] 194 | extent = 0, w, 0, h 195 | 196 | return overlay_plt(vis_list, alpha=alpha, save_as=save_as, extent=extent, **kwargs) 197 | 198 | 199 | def visualize(x, 200 | mode='auto', 201 | auto_permute=True, 202 | **kwargs): 203 | 204 | assert isinstance(x, np.ndarray) 205 | 206 | shape = x.shape 207 | ndim = len(shape) 208 | assert ndim <= 3 209 | 210 | if auto_permute: 211 | if (ndim == 3) and (shape[0] in [1, 2, 3]): # For C, H, W kind of array. 212 | logger.debug('Detected input shape {} is in CHW format, TorchShow will automatically convert it to HWC format'.format(shape)) 213 | x = np.transpose(x, (1,2,0)) 214 | 215 | if ndim == 2: 216 | x = np.expand_dims(x, axis=-1) 217 | 218 | if mode=='auto': 219 | mode = infer_mode(x) 220 | 221 | vis_func = vis_func_dict.get(mode, None) 222 | 223 | if vis_func == None: 224 | raise ValueError("mode {} is not supported.".format(mode)) 225 | 226 | return vis_func(x, **kwargs) 227 | 228 | 229 | def infer_mode(x): 230 | shape = x.shape 231 | ndim = len(shape) 232 | if shape[-1] == 3: 233 | mode = 'image' 234 | elif shape[-1] == 2: 235 | mode = 'flow' 236 | elif shape[-1] == 1: 237 | if (x.min() >= 0) and (x.max() <= 1): 238 | mode = 'grayscale' 239 | elif isinteger(np.unique(x)).all(): # If values are all integer 240 | mode = 'categorical_mask' 241 | else: 242 | mode = 'grayscale' 243 | else: 244 | raise NotImplementedError("Does support auto infer for shape {} .".format(shape)) 245 | logger.debug("Auto Infer: {}".format(mode)) 246 | return mode 247 | 248 | -------------------------------------------------------------------------------- /tests/tests.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torchshow as ts 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | import logging 7 | from PIL import Image 8 | 9 | def 
read_flow(filename): 10 | with open(filename, 'rb') as f: 11 | magic = np.fromfile(f, np.float32, count=1) 12 | if 202021.25 != magic: 13 | print('Magic number incorrect. Invalid .flo file') 14 | else: 15 | w = np.fromfile(f, np.int32, count=1)[0] 16 | h = np.fromfile(f, np.int32, count=1)[0] 17 | print('Reading %d x %d flo file' % (w, h)) 18 | data = np.fromfile(f, np.float32, count=2*w*h) 19 | # Reshape data into 3D array (columns, rows, bands) 20 | data2D = np.resize(data, (h, w, 2)) 21 | return data2D 22 | 23 | 24 | def test(section): 25 | rgb_img = torch.rand((3, 100, 100)) 26 | if section <= 1: 27 | print("1.1 Single 3-channel image between 0-1") 28 | rgb_img = torch.rand((3, 100, 100)) 29 | print(rgb_img.min(), rgb_img.max()) 30 | ts.show(rgb_img) 31 | ts.save(rgb_img) 32 | 33 | print("1.2 Single 3-channel image between 0-255") 34 | rgb_img_2 = rgb_img * 100 35 | print(rgb_img_2.min(), rgb_img_2.max()) 36 | ts.show(rgb_img_2) 37 | ts.save(rgb_img_2) 38 | 39 | print("1.3 Single 3-channel image with value larger than 255") 40 | rgb_img_3 = rgb_img * 300 41 | print(rgb_img_3.min(), rgb_img_3.max()) 42 | ts.show(rgb_img_3) 43 | ts.save(rgb_img_3) 44 | 45 | print("1.4 Single 3-channel image with value smaller than 0") 46 | rgb_img_4 = rgb_img - 0.5 47 | print(rgb_img_4.min(), rgb_img_4.max()) 48 | ts.show(rgb_img_4) 49 | ts.save(rgb_img_4) 50 | 51 | print("1.5 Single 3-channel image with value smaller than 0 and larger than 255") 52 | rgb_img_5 = rgb_img * 500 - 100 53 | print(rgb_img_5.min(), rgb_img_5.max()) 54 | ts.show(rgb_img_5) 55 | ts.save(rgb_img_5) 56 | 57 | gray_img = torch.rand((1, 100, 100)) 58 | category_mask = np.array(Image.open('test_data/example_category_mask.png')) 59 | depth_image = np.array(Image.open('test_data/example_depth.png')) / 1000.0 60 | if section <=2: 61 | print("2.1 Single 1-Channel image between 0-1") 62 | print(gray_img.min(), gray_img.max()) 63 | ts.show(gray_img) 64 | ts.save(gray_img) 65 | 66 | print("2.2 Single 1-Channel 
image between 0-255") 67 | gray_img_2 = gray_img * 100 68 | print(gray_img_2.min(), gray_img_2.max()) 69 | ts.show(gray_img_2) 70 | ts.save(gray_img_2) 71 | 72 | print("2.3 Single 1-Channel image with value smaller than 0 and larger than 255") 73 | gray_img_3 = gray_img * 500 - 100 74 | print(gray_img_3.min(), gray_img_3.max()) 75 | ts.show(gray_img_3) 76 | ts.save(gray_img_3) 77 | 78 | print("2.4 Single 1-Channel image with binary value") 79 | gray_img_4 = torch.eye(100).unsqueeze(0) 80 | gray_img_4[:, 20:40, 20:40] = 1 81 | gray_img_4[:, 60:80, 60:80] = 1 82 | print(gray_img_4.unique()) 83 | ts.show(gray_img_4) 84 | ts.save(gray_img_4) 85 | 86 | print("2.5 Single 1-Channel image with integer value >= 0") 87 | gray_img_5 = torch.randint(0, 100, (1, 100, 100)) 88 | print(gray_img_5.unique()) 89 | ts.show(gray_img_5) 90 | ts.save(gray_img_5) 91 | 92 | print("2.6 Single 1-Channel image with both positive and negative integer value") 93 | gray_img_6 = torch.randint(-50, 100, (1, 100, 100)) 94 | print(gray_img_6.unique()) 95 | ts.show(gray_img_6) 96 | ts.save(gray_img_6) 97 | 98 | print("2.7 Single 1-Channel normal categorical mask") 99 | print(np.unique(category_mask)) 100 | ts.show(category_mask) 101 | ts.save(category_mask) 102 | 103 | print("2.8 Single channel image with custom cmap") 104 | ts.show(depth_image, cmap="magma") 105 | ts.show(depth_image, cmap="inferno") 106 | ts.show(depth_image, cmap="viridis") 107 | ts.save(depth_image, cmap="magma") 108 | ts.save(depth_image, cmap="inferno") 109 | ts.save(depth_image, cmap="viridis") 110 | 111 | flow = read_flow("./test_data/example_flow.flo") 112 | if section <= 3: 113 | print("3.1 Single 2-Channel normal optical flow") 114 | print(flow.shape, flow.min(), flow.max()) 115 | ts.show(flow) 116 | ts.save(flow) 117 | 118 | print("3.2 Single 2-Channel with random values") 119 | flow_2 = torch.rand((2, 100, 100)) * 200 - 100 120 | print(flow_2.min(), flow_2.max()) 121 | ts.show(flow_2) 122 | ts.save(flow_2) 123 | 124 | 
img_n = torch.rand(16, 100, 100) 125 | if section <=4: 126 | print("4.1 Single n-channel tensor (n>3") 127 | print(img_n.min(), img_n.max()) 128 | ts.show(img_n) 129 | ts.save(img_n) 130 | ts.show(img_n, ncols=3) 131 | ts.save(img_n, ncols=3) 132 | ts.show(img_n, nrows=5) 133 | ts.save(img_n, nrows=5) 134 | 135 | print("4.2 4D tensors") 136 | batch = torch.rand(8, 3, 100, 100) * 500 - 250 137 | print(batch.min(), batch.max()) 138 | ts.show(batch) 139 | ts.save(batch) 140 | ts.show(batch, ncols=3) 141 | ts.save(batch, ncols=3) 142 | ts.show(batch, nrows=4) 143 | ts.save(batch, nrows=4) 144 | ts.show(batch, axes_title="Image ID: {img_id_from_1}") 145 | ts.show(batch, nrows=4, axes_title="{img_id}-{img_id_from_1}-{row}-{column}") 146 | ts.show(batch, nrows=4, axes_title="{img_id}-{img_id_from_1}-{row}-{column}", suptitle="Figure 1") 147 | 148 | print("4.3 List of tensors") 149 | list_tensors = list(batch) 150 | ts.show(list_tensors) 151 | ts.save(list_tensors) 152 | ts.show(list_tensors, ncols=3) 153 | ts.save(list_tensors, ncols=3) 154 | ts.show(list_tensors, nrows=4) 155 | ts.save(list_tensors, nrows=4) 156 | ts.show(list_tensors, nrows=3, ncols=5) 157 | ts.save(list_tensors, nrows=3, ncols=5) 158 | ts.show(list_tensors, axes_title="Image ID: {img_id_from_1}") 159 | ts.show(list_tensors, nrows=4, axes_title="{img_id}-{img_id_from_1}-{row}-{column}") 160 | ts.show(list_tensors, nrows=4, axes_title="{img_id}-{img_id_from_1}-{row}-{column}", suptitle="List of Tensors") 161 | 162 | if section <=5: 163 | print("5.1 Custom Layout") 164 | grid = [[rgb_img, gray_img, flow]] 165 | ts.show(grid) 166 | ts.save(grid) 167 | grid2 = [[rgb_img, gray_img], 168 | [flow]] 169 | ts.show(grid2) 170 | ts.save(grid2) 171 | 172 | if section <=6: 173 | print("6.1 Video Clip") 174 | video = torch.rand(16, 3, 100, 100) 175 | print(video.min(), video.max()) 176 | ts.show_video(video) 177 | video2 = torch.rand(8, 1, 100, 100) 178 | video3 = torch.rand(13, 2, 100, 100) 179 | 
ts.show_video([[video, video2], 180 | [video3]]) 181 | video4 = torch.rand(13,5, 100, 100) 182 | list_of_videos = [video, video2, video3, video, video2, video3] 183 | ts.show_video(list_of_videos) 184 | ts.show_video(list_of_videos, ncols=3) 185 | ts.show_video(list_of_videos, nrows=3) 186 | ts.show_video(list_of_videos, nrows=2, ncols=8) 187 | print("6.2 ts.show_video with image") 188 | # This test produces unwanted results. Ignore it at this moment unless requested. 189 | ts.show_video(rgb_img) 190 | ts.show_video(flow) 191 | ts.show_video([[video, video2], 192 | [rgb_img, flow]], suptitle="Video Example") 193 | vis_list = ts.show_video([[video, video2], 194 | [rgb_img, flow]], display=False) 195 | print(len(vis_list), len(vis_list[0]), len(vis_list[0][0])) 196 | print("6.3 Video Clip Edge Cases") 197 | try: 198 | ts.show_video(video4) 199 | except Exception as e: # This should raise an error 200 | print(e) 201 | try: 202 | ts.show_video([[video, video2], 203 | [video3, video4]]) 204 | except Exception as e: # This should raise an error 205 | print(e) 206 | 207 | if section <=7: 208 | print("7 Return vis_list if display=False") 209 | vis_list = ts.show(img_n, display=False) 210 | print(len(vis_list), len(vis_list[0])) 211 | vis_list = ts.show(img_n, display=False, nrows=3) 212 | print(len(vis_list), len(vis_list[0])) 213 | 214 | if section <=8: 215 | print("8 Change Unnormalization Presets") 216 | rgb_img_numpy = rgb_img.permute(1,2,0).numpy() 217 | def test_normalize(MEAN, STD): 218 | transform = torchvision.transforms.Normalize(MEAN, STD) 219 | rgb_img_0 = transform(rgb_img) 220 | ts.set_image_mean(MEAN) 221 | ts.set_image_std(STD) 222 | rgb_img_1 = ts.show(rgb_img_0, display=False)[0][0]['disp'] 223 | assert np.allclose(rgb_img_1, rgb_img_numpy, atol=1e-7) # will be False if atol=1e-8 224 | 225 | test_normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 226 | test_normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 227 | test_normalize([0.4851231531212364564, 
0.4561231523135436, 0.406123412312452343], [0.2293453455673435, 0.22434531445342, 0.22534534472423]) 228 | 229 | ts.set_image_mean([0,0,0]) 230 | ts.set_image_std([1,1,1]) 231 | 232 | if section <=9: 233 | print("9 Testing PIL Image") 234 | pil_rgb = Image.open("test_data/example_category_mask.png") 235 | pil_depth = Image.open("test_data/example_depth.png") 236 | ts.show([pil_rgb, pil_depth]) 237 | ts.save([pil_rgb, pil_depth]) 238 | 239 | if section <=10: 240 | print("10 Testing filename as input") 241 | ts.show(["test_data/example_image.jpeg", "test_data/example_depth.png", "test_data/example_category_mask.png"]) 242 | try: 243 | ts.show("test_data/a_file_that_does_not_exists.jpg") 244 | except Exception as e : 245 | print(e) 246 | pass 247 | ts.show("test_data/example_image_rotated_by_exif.jpeg") 248 | ts.show("test_data/example_flow.flo") 249 | 250 | if section <= 11: 251 | print("11 Testing Overlay FUnction") 252 | ts.overlay(["test_data/example_rgb.jpg", "test_data/example_category_mask.png"], alpha=[0.8]) 253 | ts.overlay(["test_data/example_rgb.jpg", "test_data/example_category_mask.png"], alpha=[0.5, 0.5]) 254 | ts.overlay(["test_data/example_rgb.jpg", "test_data/example_category_mask.png", "test_data/example_flow.flo"], alpha=[0.5, 0.5]) 255 | ts.overlay(["test_data/example_rgb.jpg", "test_data/example_category_mask.png"], save_as="_torchshow/test_overlay_save.png") 256 | # ts.overlay(["test_data/example_rgb.jpg", "test_data/example_category_mask.png"], alpha=[0.5, 1.5]) 257 | 258 | if __name__ == "__main__": 259 | if len(sys.argv) > 1: 260 | section = sys.argv[1] 261 | else: 262 | section = 0 263 | test(int(section)) -------------------------------------------------------------------------------- /torchshow/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | from matplotlib import colors, animation, pyplot as plt, rcParams 4 | # from PIL import Image 5 | from 
.utils import isinteger, within_0_1, within_0_255, isnotebook 6 | from .config import config 7 | from .flow import flow_to_color 8 | import warnings 9 | import logging 10 | from datetime import datetime 11 | import os 12 | import copy 13 | from numbers import Number 14 | 15 | 16 | logger = logging.getLogger('TorchShow') 17 | 18 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 19 | IMAGENET_STD = [0.229, 0.224, 0.225] 20 | 21 | 22 | def create_color_map(N=256, normalized=False): 23 | 24 | def bitget(byteval, idx): 25 | return (byteval & (1 << idx)) != 0 26 | 27 | dtype = 'float32' if normalized else 'uint8' 28 | cmap = np.zeros((N, 3), dtype=dtype) 29 | for i in range(N): 30 | r = g = b = 0 31 | c = i 32 | for j in range(8): 33 | r = r | (bitget(c, 0) << 7-j) 34 | g = g | (bitget(c, 1) << 7-j) 35 | b = b | (bitget(c, 2) << 7-j) 36 | c = c >> 3 37 | 38 | cmap[i] = np.array([r, g, b]) 39 | 40 | cmap = cmap/255 if normalized else cmap 41 | return cmap 42 | 43 | def set_window_title(fig, title): 44 | """ 45 | Set the title of the figure window (effective when using a interactive backend.) 46 | """ 47 | try: 48 | fig.canvas.set_window_title(title) 49 | except: 50 | try: 51 | fig.canvas.manager.set_window_title(title) 52 | except: 53 | logger.info("Seting window title failed.") 54 | 55 | 56 | def imshow(ax, vis, alpha=None, extent=None, show_rich_info=True): 57 | max_rows, max_cols = vis['raw'].shape[:2] 58 | 59 | def format_coord(x, y): 60 | """ 61 | We display x-y coordinate as integer. 
62 | """ 63 | col = int(x + 0.5) 64 | row = int(y + 0.5) 65 | if not (0<=col 1 or ncols > 1 121 | if has_multiple_axes: 122 | axes = fig.subplots(nrows=nrows, ncols=ncols, squeeze=False) 123 | else: 124 | axes = np.array([[plt.Axes(fig, [0., 0., 1., 1.])]]) 125 | fig.add_axes(axes[0, 0]) 126 | show_rich_info = config.get('show_rich_info') 127 | # fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False) 128 | set_window_title(fig, 'TorchShow') 129 | if suptitle: 130 | fig.suptitle(suptitle) 131 | 132 | for i, plots_per_row in enumerate(vis_list): 133 | for j, vis in enumerate(plots_per_row): 134 | # axes[i, j].imshow(vis, **plot_cfg) 135 | imshow(axes[i,j], vis, show_rich_info=show_rich_info) 136 | title_namespace["img_id"] = i*ncols+j 137 | title_namespace["img_id_from_1"] = title_namespace["img_id"] + 1 138 | title_namespace["row"] = i 139 | title_namespace["column"] = j 140 | if axes_title is not None: 141 | axes[i, j].set_title(axes_title.format(**title_namespace)) 142 | 143 | # Delete empty axes 144 | for ax in axes.ravel(): 145 | if not ax.has_data(): 146 | fig.delaxes(ax) 147 | 148 | if not show_axis: 149 | for ax in axes.ravel(): 150 | ax.axis('off') 151 | 152 | if has_multiple_axes and tight_layout: 153 | fig.tight_layout() 154 | 155 | if kwargs.get('save', False): 156 | file_path = kwargs.get('file_path', None) 157 | if file_path is None: 158 | os.makedirs('_torchshow', exist_ok=True) 159 | cur_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f") 160 | file_path = '_torchshow/'+cur_time+'.png' 161 | dirname = os.path.dirname(file_path) 162 | if dirname != '': 163 | os.makedirs(dirname, exist_ok=True) 164 | fig.savefig(file_path, bbox_inches = 'tight', pad_inches=0) 165 | plt.close(fig) 166 | else: # If the image is saved by ts.save() it will not call plt.show() 167 | if not isnotebook(): 168 | plt.show() 169 | 170 | 171 | def overlay_plt(vis_list, alpha, save_as, extent, **kwargs): 172 | show_axis = kwargs.get('show_axis', False) 173 | 
tight_layout = kwargs.get('tight_layout', True) 174 | suptitle = kwargs.get('suptitle', None) 175 | axes_title = kwargs.get('axes_title', None) 176 | 177 | figsize = kwargs.get('figsize', None) 178 | dpi = kwargs.get('dpi', None) 179 | fig = plt.figure(figsize=figsize, dpi=dpi) 180 | ax = fig.add_subplot() 181 | set_window_title(fig, 'TorchShow') 182 | if suptitle: 183 | fig.suptitle(suptitle) 184 | 185 | assert len(vis_list) == len(alpha) 186 | 187 | for vis, a in zip(vis_list, alpha): 188 | imshow(ax, vis, alpha=a, extent=extent, show_rich_info=False) 189 | 190 | if not show_axis: 191 | ax.axis('off') 192 | 193 | if tight_layout: 194 | fig.tight_layout() 195 | 196 | if save_as != None: 197 | assert isinstance(save_as, str) 198 | dirname = os.path.dirname(save_as) 199 | if dirname!='': 200 | os.makedirs(dirname, exist_ok=True) 201 | fig.savefig(save_as, bbox_inches = 'tight', pad_inches=0) 202 | plt.close(fig) 203 | else: 204 | if not isnotebook(): 205 | plt.show() 206 | 207 | 208 | 209 | 210 | def animate_plt(video_list, **kwargs): 211 | """ 212 | : video_list: [t, row, col, image] 213 | """ 214 | nrows = len(video_list[0]) 215 | ncols = max([len(l) for l in video_list[0]]) 216 | 217 | show_axis = kwargs.get('show_axis', False) 218 | tight_layout = kwargs.get('tight_layout', True) 219 | suptitle = kwargs.get('suptitle', None) 220 | # show_title = kwargs.get('show_title', False) 221 | # title_pattern = kwargs.get('title_pattern', "{img_id}") 222 | # title_namespace = {} 223 | figsize = kwargs.get('figsize', None) 224 | dpi = kwargs.get('dpi', None) 225 | 226 | fig = plt.figure(figsize=figsize, dpi=dpi) 227 | axes = fig.subplots(nrows=nrows, ncols=ncols, squeeze=False) 228 | # fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False) 229 | set_window_title(fig, 'TorchShow') 230 | if suptitle: 231 | fig.suptitle(suptitle) 232 | 233 | plots = [] 234 | 235 | # Initialization 236 | for i, plots_per_row in enumerate(video_list[0]): 237 | for j, vis in 
enumerate(plots_per_row): 238 | if vis is not None: 239 | plot = axes[i, j].imshow(vis['disp'], **vis['plot_cfg']) 240 | plots.append(plot) 241 | else: 242 | plots.append(None) 243 | 244 | # Delete empty axes 245 | for ax in axes.ravel(): 246 | if not ax.has_data(): 247 | fig.delaxes(ax) 248 | 249 | if not show_axis: 250 | for ax in axes.ravel(): 251 | ax.axis('off') 252 | 253 | if tight_layout: 254 | fig.tight_layout() 255 | 256 | def run(frames_at_t): 257 | for i, plots_per_row in enumerate(frames_at_t): 258 | for j, vis in enumerate(plots_per_row): 259 | # axes[i, j].imshow(vis, **plot_cfg) 260 | if vis is not None: 261 | # axes[i,j].figure.canvas.draw() 262 | plots[i*ncols+j].set_data(vis['disp']) 263 | fig.canvas.draw() 264 | return plots 265 | 266 | # if tight_layout: 267 | # fig.tight_layout() 268 | 269 | ani = animation.FuncAnimation(fig, run, video_list, blit=True, interval=5, repeat=True) 270 | 271 | if not config.get('inline'): 272 | plt.show() 273 | return ani 274 | 275 | 276 | def auto_unnormalize_image(x): 277 | all_int = isinteger(np.unique(x)).all() 278 | range_0_1 = within_0_1(x) 279 | range_0_255 = within_0_255(x) 280 | has_negative = (x.min() < 0) 281 | 282 | if has_negative: 283 | logger.debug('Detects input has negative values, auto rescaling input to 0-1.') 284 | return rescale_0_1(x) 285 | if range_0_255 and all_int and (not range_0_1): # if image is all integer and between 0 - 255. Normalize it to 0-1. 286 | logger.debug('Detects input are all integers within range 0-255. Divided all values by 255.') 287 | return x / 255. 
288 | if range_0_1: 289 | logger.debug('Inputs already within 0-1, no unnormalization is performed.') 290 | return x 291 | logger.debug('Auto rescaling input to 0-1.') 292 | return rescale_0_1(x) 293 | 294 | 295 | def rescale_0_1(x): 296 | """ 297 | Rescaling tensor to 0-1 using min-max normalization 298 | """ 299 | return (x - x.min()) / (x.max() - x.min()) 300 | 301 | 302 | def unnormalize_with_mean_and_std(x, mean, std): 303 | """ 304 | General Channel-wise mean std unnormalization. Expect input to be (H, W, C) 305 | """ 306 | x = x.copy() 307 | assert (len(x.shape) == 3), "Unnormalization only support (H, W, C) format input, got {}".format(x.shape) 308 | C = x.shape[-1] 309 | assert (len(mean) == C) and (len(std) == C), "Number of mean and std values must equals to number of channels." 310 | 311 | for i in range(C): 312 | x[:,:,i] = x[:,:,i] * std[i] + mean[i] 313 | 314 | return x 315 | 316 | 317 | def vis_image(x, unnormalize='auto', **kwargs): 318 | """ 319 | : x: ndarray (H, W, 3). 320 | : unnormalize: 'auto', 'imagenet', 'imagenet_scaled' 321 | : mean: image mean for unnormalization. 322 | : std: image std for unnormalization 323 | : display: whether to display image using matplotlib 324 | """ 325 | vis = dict() 326 | vis['raw'] = copy.deepcopy(x) 327 | vis['shape'] = str(x.shape) 328 | shape = x.shape 329 | ndim = len(shape) 330 | assert ndim == 3, "vis_image only support 3D array in (H, W, C) format." 331 | 332 | user_mean = config.get('image_mean') 333 | user_std = config.get('image_std') 334 | 335 | if (user_mean is not None) or (user_std is not None): 336 | if user_mean == None: 337 | user_mean = [0] * x.shape[-1] # Initialize mean to 0 if not specified. 338 | if user_std == None: 339 | user_std = [1.] * x.shape[-1] # Initialize std to 1 if not specified. 
340 | x = unnormalize_with_mean_and_std(x, user_mean, user_std) 341 | 342 | elif unnormalize=='auto': 343 | x = auto_unnormalize_image(x) 344 | 345 | elif unnormalize=='imagenet': 346 | if (x.max() > 2.66) or (x.min() < -2.17): 347 | # A quick validation to check if the image was normalized to 0-1 348 | # before substracting imagenet mean and std 349 | x = x / 255. 350 | x = unnormalize_with_mean_and_std(x, IMAGENET_MEAN, IMAGENET_STD) 351 | 352 | else: 353 | raise NotImplementedError("Unsupported unnormalization profile \"{}\"".format(unnormalize)) 354 | 355 | assert x is not None 356 | 357 | if config.get('color_mode') == 'bgr': 358 | x = x[:,:,::-1] 359 | vis['mode'] = 'Image(BGR)' 360 | else: 361 | vis['mode'] = 'Image(RGB)' 362 | 363 | plot_cfg = dict() 364 | 365 | vis['disp'] = x 366 | vis['plot_cfg'] = plot_cfg 367 | return vis 368 | 369 | 370 | def vis_flow(x, **kwargs): 371 | vis = dict() 372 | vis['raw'] = copy.deepcopy(x) 373 | vis['shape'] = str(x.shape) 374 | x = flow_to_color(x) 375 | plot_cfg = dict() 376 | vis['disp'] = x 377 | vis['plot_cfg'] = plot_cfg 378 | vis['mode'] = 'Flow' 379 | return vis 380 | 381 | 382 | def vis_grayscale(x, **kwargs): 383 | vis = dict() 384 | assert (len(x.shape) == 3) and (x.shape[-1] == 1) 385 | x = np.squeeze(x, -1) 386 | vis['raw'] = copy.deepcopy(x) 387 | vis['shape'] = str(x.shape) 388 | # rescale to [0-1] 389 | if not within_0_1(x): 390 | warnings.warn('Original input range is not 0-1 when using grayscale mode. 
Auto-rescaling it to 0-1 by default.') 391 | x = rescale_0_1(x) 392 | vis['disp'] = x 393 | cmap = kwargs.get("cmap", "gray") 394 | plot_cfg = dict(cmap=cmap) 395 | 396 | plot_cfg['vmin'] = kwargs.get("vmin", None) 397 | plot_cfg['vmax'] = kwargs.get("vmax", None) 398 | 399 | if isinteger(np.unique(x)).all(): 400 | plot_cfg['interpolation'] = 'nearest' 401 | vis['mode'] = 'Binary' 402 | else: 403 | vis['mode'] = 'Gray' 404 | vis['plot_cfg'] = plot_cfg 405 | 406 | return vis 407 | 408 | 409 | def vis_categorical_mask(x, max_N=256, **kwargs): 410 | assert (len(x.shape) == 3) and (x.shape[-1] == 1) 411 | assert isinteger(np.unique(x)).all(), "Input has to contain only integers in categorical mask mode." 412 | vis = dict() 413 | 414 | x = np.squeeze(x, -1) 415 | vis['raw'] = copy.deepcopy(x) 416 | vis['shape'] = str(x.shape) 417 | N = int(x.max()) + 1 418 | 419 | if x.max() > max_N: 420 | warnings.warn('The maximum value in input is {} which is greater than the default max_N ({}). Automatically adjust max_N to {}.'.format(x.max(), max_N, x.max())) 421 | max_N = x.max() + 1 422 | 423 | color_list = create_color_map(N=max_N, normalized=True) 424 | color_list = np.concatenate([color_list, np.ones((max_N, 1)).astype(np.float32)], axis=1) 425 | 426 | if x.min() < 0: 427 | warnings.warn('Input has negative value when trying to visualize as categorical mask, which will all be converted to -1 and displayed in white.') 428 | x[x<0] = -1 # Map all negative value to -1 429 | color_list = np.concatenate([np.ones((1,4)).astype(np.float32), color_list], axis=0) # appending an extra value. 
430 | N = N + 1 431 | 432 | cmap = colors.ListedColormap(color_list, N=N) 433 | 434 | x = cmap(x.astype(int), alpha=None, bytes=True)[:,:,:3] 435 | # print(x.shape) 436 | plot_cfg = dict( interpolation="nearest") 437 | 438 | vis['disp'] = x 439 | vis['plot_cfg'] = plot_cfg 440 | vis['mode'] = 'Categorical' 441 | return vis 442 | 443 | 444 | if __name__ == "__main__": 445 | print(create_color_map()) 446 | --------------------------------------------------------------------------------