├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks ├── benchmarking_Center_Crop.png ├── benchmarking_Color_brightness_only.png ├── benchmarking_Color_constrast_and_brightness.png ├── benchmarking_Color_contrast_only.png ├── benchmarking_Color_hue_only.png ├── benchmarking_Color_saturation_only.png ├── benchmarking_Five_crop.png ├── benchmarking_Random_affine.png ├── benchmarking_Random_crop_quarter_size.png ├── benchmarking_Random_grayscale.png ├── benchmarking_Random_horizontal_flip.png ├── benchmarking_Random_resized_crop_for_Inception.png ├── benchmarking_Random_rotation_10_degrees.png ├── benchmarking_Random_vertical_flip.png ├── benchmarking_Resize.png ├── benchmarking_Resize_flip_brightness_contrast_rotate.png ├── benchmarking_Ten_crop.png └── benchmarking_Zero_padding_50x25.png ├── opencv_transforms ├── __init__.py ├── functional.py └── transforms.py ├── setup.py └── tests ├── compare_to_pil_for_testing.ipynb ├── setup_testing_directory.py ├── test_color.py ├── test_spatial.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints 3 | .vscode 4 | /dist/ 5 | /*.egg-info 6 | *.pyc 7 | /build/ 8 | 9 | tests/testing_directory.txt 10 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jim Bohnslav 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include benchmarks/*.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # opencv_transforms 2 | 3 | This repository is intended as a faster drop-in replacement for [Pytorch's Torchvision augmentations](https://github.com/pytorch/vision/). This repo uses OpenCV for fast image augmentation for PyTorch computer vision pipelines. I wrote this code because the Pillow-based Torchvision transforms was starving my GPU due to slow image augmentation. 4 | 5 | ## Requirements 6 | * A working installation of OpenCV. **Tested with OpenCV version 3.4.1, 4.1.0** 7 | * Tested on Windows 10 and Ubuntu 18.04. 
There is evidence that OpenCV doesn't work well with multithreading on Linux / macOS, for example with `num_workers > 0` in a PyTorch `DataLoader`. I haven't run into this issue yet. 8 | 9 | ## Installation 10 | opencv_transforms is now a pip package! Simply use 11 | * `pip install opencv_transforms` 12 | 13 | ## Usage 14 | **Breaking change! Please note the import syntax!** 15 | * `from opencv_transforms import transforms` 16 | * From here, almost everything should work exactly as the original `transforms`. 17 | #### Example: Image resizing 18 | ```python 19 | import numpy as np 20 | image = np.random.randint(low=0, high=255, size=(1024, 2048, 3), dtype=np.uint8) 21 | resize = transforms.Resize(size=(256, 256)) 22 | image = resize(image) 23 | ``` 24 | This should be 1.5 to 10 times faster than PIL. See the benchmarks below. 25 | 26 | ## Performance 27 | * Most transformations are between 1.5X and ~4X faster in OpenCV. Large image resizes are up to 10 times faster in OpenCV. 28 | * To reproduce the following benchmarks, download the [Cityscapes dataset](https://www.cityscapes-dataset.com/). 29 | * An example benchmarking file can be found in the notebook **bencharming_v2.ipynb**. I wrapped the Cityscapes default directories with an HDF5 file for even faster reading. 30 | 31 | ![resize](benchmarks/benchmarking_Resize.png) 32 | ![random crop](benchmarks/benchmarking_Random_crop_quarter_size.png) 33 | ![change brightness](benchmarks/benchmarking_Color_brightness_only.png) 34 | ![change brightness and contrast](benchmarks/benchmarking_Color_constrast_and_brightness.png) 35 | ![change contrast only](benchmarks/benchmarking_Color_contrast_only.png) 36 | ![random horizontal flips](benchmarks/benchmarking_Random_horizontal_flip.png) 37 | 38 | The changes start to add up when you compose multiple transformations together. 39 | ![composed transformations](benchmarks/benchmarking_Resize_flip_brightness_contrast_rotate.png) 40 | 41 | ## TODO 42 | - [x] Initial commit with all currently implemented torchvision transforms 43 | - [x] Cityscapes benchmarks 44 | - [ ] Make the `resample` flag on `RandomRotation`, `RandomAffine` actually do something 45 | - [ ] Speed up augmentation in saturation and hue.
Currently, fastest way is to convert to a PIL image, perform same augmentation as Torchvision, then convert back to np.ndarray 46 | -------------------------------------------------------------------------------- /benchmarks/benchmarking_Center_Crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Center_Crop.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Color_brightness_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Color_brightness_only.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Color_constrast_and_brightness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Color_constrast_and_brightness.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Color_contrast_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Color_contrast_only.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Color_hue_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Color_hue_only.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Color_saturation_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Color_saturation_only.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Five_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Five_crop.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_affine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_affine.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_crop_quarter_size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_crop_quarter_size.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_grayscale.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_grayscale.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_horizontal_flip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_horizontal_flip.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_resized_crop_for_Inception.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_resized_crop_for_Inception.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_rotation_10_degrees.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_rotation_10_degrees.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Random_vertical_flip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Random_vertical_flip.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Resize.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Resize.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Resize_flip_brightness_contrast_rotate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Resize_flip_brightness_contrast_rotate.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Ten_crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Ten_crop.png -------------------------------------------------------------------------------- /benchmarks/benchmarking_Zero_padding_50x25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/benchmarks/benchmarking_Zero_padding_50x25.png -------------------------------------------------------------------------------- /opencv_transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbohnslav/opencv_transforms/fd91e4987a6929be9334b40f0f809d7a2709383f/opencv_transforms/__init__.py 
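The `functional` module that follows contains the stateless, per-image operations that the `transforms` classes wrap. Below is a minimal sketch of calling it directly, for example to apply identical spatial augmentation to an image and its segmentation mask; the array shapes, label count, and variable names are illustrative assumptions, not something defined by this repository.

```python
import cv2
import numpy as np

from opencv_transforms import functional as F

# Illustrative inputs: a uint8 RGB image and a single-channel label mask.
image = np.random.randint(0, 256, size=(512, 1024, 3), dtype=np.uint8)
mask = np.random.randint(0, 19, size=(512, 1024, 1), dtype=np.uint8)

# Sample random parameters once, then apply the same spatial ops to both arrays.
if np.random.rand() < 0.5:
    image, mask = F.hflip(image), F.hflip(mask)
image = F.resize(image, (256, 512))  # cv2.INTER_LINEAR by default
mask = F.resize(mask, (256, 512), interpolation=cv2.INTER_NEAREST)  # keep labels discrete

# Photometric adjustments are applied to the image only.
image = F.adjust_brightness(image, brightness_factor=1.2)
```

Keeping the random parameter sampling in user code like this is the usual pattern when paired targets must receive exactly the same spatial transform.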
-------------------------------------------------------------------------------- /opencv_transforms/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import torch 5 | from PIL import Image, ImageEnhance, ImageOps 6 | 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import collections 12 | import numbers 13 | import types 14 | import warnings 15 | 16 | import cv2 17 | import numpy as np 18 | from PIL import Image 19 | 20 | _cv2_pad_to_str = { 21 | 'constant': cv2.BORDER_CONSTANT, 22 | 'edge': cv2.BORDER_REPLICATE, 23 | 'reflect': cv2.BORDER_REFLECT_101, 24 | 'symmetric': cv2.BORDER_REFLECT 25 | } 26 | _cv2_interpolation_to_str = { 27 | 'nearest': cv2.INTER_NEAREST, 28 | 'bilinear': cv2.INTER_LINEAR, 29 | 'area': cv2.INTER_AREA, 30 | 'bicubic': cv2.INTER_CUBIC, 31 | 'lanczos': cv2.INTER_LANCZOS4 32 | } 33 | _cv2_interpolation_from_str = {v: k for k, v in _cv2_interpolation_to_str.items()} 34 | 35 | 36 | def _is_pil_image(img): 37 | if accimage is not None: 38 | return isinstance(img, (Image.Image, accimage.Image)) 39 | else: 40 | return isinstance(img, Image.Image) 41 | 42 | 43 | def _is_tensor_image(img): 44 | return torch.is_tensor(img) and img.ndimension() == 3 45 | 46 | 47 | def _is_numpy_image(img): 48 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 49 | 50 | 51 | def to_tensor(pic): 52 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 53 | See ``ToTensor`` for more details. 54 | Args: 55 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 56 | Returns: 57 | Tensor: Converted image. 58 | """ 59 | if not (_is_numpy_image(pic)): 60 | raise TypeError('pic should be ndarray. Got {}'.format(type(pic))) 61 | 62 | # handle numpy array 63 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 64 | # backward compatibility 65 | if isinstance(img, torch.ByteTensor) or img.dtype == torch.uint8: 66 | return img.float().div(255) 67 | else: 68 | return img 69 | 70 | 71 | def normalize(tensor, mean, std): 72 | """Normalize a tensor image with mean and standard deviation. 73 | .. note:: 74 | This transform acts in-place, i.e., it mutates the input tensor. 75 | See :class:`~torchvision.transforms.Normalize` for more details. 76 | Args: 77 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 78 | mean (sequence): Sequence of means for each channel. 79 | std (sequence): Sequence of standard deviations for each channely. 80 | Returns: 81 | Tensor: Normalized Tensor image. 82 | """ 83 | if not _is_tensor_image(tensor): 84 | raise TypeError('tensor is not a torch image.') 85 | 86 | # This is faster than using broadcasting, don't change without benchmarking 87 | for t, m, s in zip(tensor, mean, std): 88 | t.sub_(m).div_(s) 89 | return tensor 90 | 91 | 92 | def resize(img, size, interpolation=cv2.INTER_LINEAR): 93 | r"""Resize the input numpy ndarray to the given size. 94 | Args: 95 | img (numpy ndarray): Image to be resized. 96 | size (sequence or int): Desired output size. If size is a sequence like 97 | (h, w), the output size will be matched to this. If size is an int, 98 | the smaller edge of the image will be matched to this number maintaing 99 | the aspect ratio. i.e, if height > width, then image will be rescaled to 100 | :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)` 101 | interpolation (int, optional): Desired interpolation. 
Default is 102 | ``cv2.INTER_LINEAR`` 103 | Returns: 104 | PIL Image: Resized image. 105 | """ 106 | if not _is_numpy_image(img): 107 | raise TypeError('img should be numpy image. Got {}'.format(type(img))) 108 | if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)): 109 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 110 | h, w = img.shape[0], img.shape[1] 111 | 112 | if isinstance(size, int): 113 | if (w <= h and w == size) or (h <= w and h == size): 114 | return img 115 | if w < h: 116 | ow = size 117 | oh = int(size * h / w) 118 | else: 119 | oh = size 120 | ow = int(size * w / h) 121 | else: 122 | ow, oh = size[1], size[0] 123 | output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation) 124 | if img.shape[2] == 1: 125 | return output[:, :, np.newaxis] 126 | else: 127 | return output 128 | 129 | 130 | def scale(*args, **kwargs): 131 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.") 132 | return resize(*args, **kwargs) 133 | 134 | 135 | def pad(img, padding, fill=0, padding_mode='constant'): 136 | r"""Pad the given numpy ndarray on all sides with specified padding mode and fill value. 137 | Args: 138 | img (numpy ndarray): image to be padded. 139 | padding (int or tuple): Padding on each border. If a single int is provided this 140 | is used to pad all borders. If tuple of length 2 is provided this is the padding 141 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 142 | this is the padding for the left, top, right and bottom borders 143 | respectively. 144 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 145 | length 3, it is used to fill R, G, B channels respectively. 146 | This value is only used when the padding_mode is constant 147 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 148 | - constant: pads with a constant value, this value is specified with fill 149 | - edge: pads with the last value on the edge of the image 150 | - reflect: pads with reflection of image (without repeating the last value on the edge) 151 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 152 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 153 | - symmetric: pads with reflection of image (repeating the last value on the edge) 154 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 155 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 156 | Returns: 157 | Numpy image: padded image. 158 | """ 159 | if not _is_numpy_image(img): 160 | raise TypeError('img should be numpy ndarray. 
Got {}'.format(type(img))) 161 | if not isinstance(padding, (numbers.Number, tuple, list)): 162 | raise TypeError('Got inappropriate padding arg') 163 | if not isinstance(fill, (numbers.Number, str, tuple)): 164 | raise TypeError('Got inappropriate fill arg') 165 | if not isinstance(padding_mode, str): 166 | raise TypeError('Got inappropriate padding_mode arg') 167 | if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]: 168 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 169 | "{} element tuple".format(len(padding))) 170 | 171 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ 172 | 'Padding mode should be either constant, edge, reflect or symmetric' 173 | 174 | if isinstance(padding, int): 175 | pad_left = pad_right = pad_top = pad_bottom = padding 176 | if isinstance(padding, collections.abc.Sequence) and len(padding) == 2: 177 | pad_left = pad_right = padding[0] 178 | pad_top = pad_bottom = padding[1] 179 | if isinstance(padding, collections.abc.Sequence) and len(padding) == 4: 180 | pad_left = padding[0] 181 | pad_top = padding[1] 182 | pad_right = padding[2] 183 | pad_bottom = padding[3] 184 | if img.shape[2] == 1: 185 | return cv2.copyMakeBorder(img, 186 | top=pad_top, 187 | bottom=pad_bottom, 188 | left=pad_left, 189 | right=pad_right, 190 | borderType=_cv2_pad_to_str[padding_mode], 191 | value=fill)[:, :, np.newaxis] 192 | else: 193 | return cv2.copyMakeBorder(img, 194 | top=pad_top, 195 | bottom=pad_bottom, 196 | left=pad_left, 197 | right=pad_right, 198 | borderType=_cv2_pad_to_str[padding_mode], 199 | value=fill) 200 | 201 | 202 | def crop(img, i, j, h, w): 203 | """Crop the given numpy ndarray. 204 | Args: 205 | img (numpy ndarray): Image to be cropped. 206 | i: Upper pixel coordinate. 207 | j: Left pixel coordinate. 208 | h: Height of the cropped image. 209 | w: Width of the cropped image. 210 | Returns: 211 | numpy ndarray: Cropped image. 212 | """ 213 | if not _is_numpy_image(img): 214 | raise TypeError('img should be numpy image. Got {}'.format(type(img))) 215 | 216 | return img[i:i + h, j:j + w, :] 217 | 218 | 219 | def center_crop(img, output_size): 220 | if isinstance(output_size, numbers.Number): 221 | output_size = (int(output_size), int(output_size)) 222 | h, w = img.shape[0:2] 223 | th, tw = output_size 224 | i = int(round((h - th) / 2.)) 225 | j = int(round((w - tw) / 2.)) 226 | return crop(img, i, j, th, tw) 227 | 228 | 229 | def resized_crop(img, i, j, h, w, size, interpolation=cv2.INTER_LINEAR): 230 | """Crop the given numpy ndarray and resize it to desired size. 231 | Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. 232 | Args: 233 | img (numpy ndarray): Image to be cropped. 234 | i: Upper pixel coordinate. 235 | j: Left pixel coordinate. 236 | h: Height of the cropped image. 237 | w: Width of the cropped image. 238 | size (sequence or int): Desired output size. Same semantics as ``scale``. 239 | interpolation (int, optional): Desired interpolation. Default is 240 | ``cv2.INTER_LINEAR``. 241 | Returns: 242 | numpy ndarray: Cropped image. 243 | """ 244 | assert _is_numpy_image(img), 'img should be numpy image' 245 | img = crop(img, i, j, h, w) 246 | img = resize(img, size, interpolation=interpolation) 247 | return img 248 | 249 | 250 | def hflip(img): 251 | """Horizontally flip the given numpy ndarray. 252 | Args: 253 | img (numpy ndarray): Image to be flipped. 254 | Returns: 255 | numpy ndarray: Horizontally flipped image.
256 | """ 257 | if not _is_numpy_image(img): 258 | raise TypeError('img should be numpy image. Got {}'.format(type(img))) 259 | # img[:,::-1] is much faster, but doesn't work with torch.from_numpy()! 260 | if img.shape[2] == 1: 261 | return cv2.flip(img, 1)[:, :, np.newaxis] 262 | else: 263 | return cv2.flip(img, 1) 264 | 265 | 266 | def vflip(img): 267 | """Vertically flip the given numpy ndarray. 268 | Args: 269 | img (numpy ndarray): Image to be flipped. 270 | Returns: 271 | numpy ndarray: Vertically flipped image. 272 | """ 273 | if not _is_numpy_image(img): 274 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 275 | if img.shape[2] == 1: 276 | return cv2.flip(img, 0)[:, :, np.newaxis] 277 | else: 278 | return cv2.flip(img, 0) 279 | # img[::-1] is much faster, but doesn't work with torch.from_numpy()! 280 | 281 | 282 | def five_crop(img, size): 283 | """Crop the given numpy ndarray into four corners and the central crop. 284 | .. Note:: 285 | This transform returns a tuple of images and there may be a 286 | mismatch in the number of inputs and targets your ``Dataset`` returns. 287 | Args: 288 | size (sequence or int): Desired output size of the crop. If size is an 289 | int instead of sequence like (h, w), a square crop (size, size) is 290 | made. 291 | Returns: 292 | tuple: tuple (tl, tr, bl, br, center) 293 | Corresponding top left, top right, bottom left, bottom right and center crop. 294 | """ 295 | if isinstance(size, numbers.Number): 296 | size = (int(size), int(size)) 297 | else: 298 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 299 | 300 | h, w = img.shape[0:2] 301 | crop_h, crop_w = size 302 | if crop_w > w or crop_h > h: 303 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w))) 304 | tl = crop(img, 0, 0, crop_h, crop_w) 305 | tr = crop(img, 0, w - crop_w, crop_h, crop_w) 306 | bl = crop(img, h - crop_h, 0, crop_h, crop_w) 307 | br = crop(img, h - crop_h, w - crop_w, crop_h, crop_w) 308 | center = center_crop(img, (crop_h, crop_w)) 309 | return tl, tr, bl, br, center 310 | 311 | 312 | def ten_crop(img, size, vertical_flip=False): 313 | r"""Crop the given numpy ndarray into four corners and the central crop plus the 314 | flipped version of these (horizontal flipping is used by default). 315 | .. Note:: 316 | This transform returns a tuple of images and there may be a 317 | mismatch in the number of inputs and targets your ``Dataset`` returns. 318 | Args: 319 | size (sequence or int): Desired output size of the crop. If size is an 320 | int instead of sequence like (h, w), a square crop (size, size) is 321 | made. 322 | vertical_flip (bool): Use vertical flipping instead of horizontal 323 | Returns: 324 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) 325 | Corresponding top left, top right, bottom left, bottom right and center crop 326 | and same for the flipped image. 327 | """ 328 | if isinstance(size, numbers.Number): 329 | size = (int(size), int(size)) 330 | else: 331 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 332 | 333 | first_five = five_crop(img, size) 334 | 335 | if vertical_flip: 336 | img = vflip(img) 337 | else: 338 | img = hflip(img) 339 | 340 | second_five = five_crop(img, size) 341 | return first_five + second_five 342 | 343 | 344 | def adjust_brightness(img, brightness_factor): 345 | """Adjust brightness of an Image. 346 | Args: 347 | img (numpy ndarray): numpy ndarray to be adjusted. 
348 | brightness_factor (float): How much to adjust the brightness. Can be 349 | any non negative number. 0 gives a black image, 1 gives the 350 | original image while 2 increases the brightness by a factor of 2. 351 | Returns: 352 | numpy ndarray: Brightness adjusted image. 353 | """ 354 | if not _is_numpy_image(img): 355 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 356 | table = np.array([i * brightness_factor for i in range(0, 256)]).clip(0, 255).astype('uint8') 357 | # same thing but a bit slower 358 | # cv2.convertScaleAbs(img, alpha=brightness_factor, beta=0) 359 | if img.shape[2] == 1: 360 | return cv2.LUT(img, table)[:, :, np.newaxis] 361 | else: 362 | return cv2.LUT(img, table) 363 | 364 | 365 | def adjust_contrast(img, contrast_factor): 366 | """Adjust contrast of an mage. 367 | Args: 368 | img (numpy ndarray): numpy ndarray to be adjusted. 369 | contrast_factor (float): How much to adjust the contrast. Can be any 370 | non negative number. 0 gives a solid gray image, 1 gives the 371 | original image while 2 increases the contrast by a factor of 2. 372 | Returns: 373 | numpy ndarray: Contrast adjusted image. 374 | """ 375 | # much faster to use the LUT construction than anything else I've tried 376 | # it's because you have to change dtypes multiple times 377 | if not _is_numpy_image(img): 378 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 379 | 380 | # input is RGB 381 | if img.ndim > 2 and img.shape[2] == 3: 382 | mean_value = round(cv2.mean(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY))[0]) 383 | elif img.ndim == 2: 384 | # grayscale input 385 | mean_value = round(cv2.mean(img)[0]) 386 | else: 387 | # multichannel input 388 | mean_value = round(np.mean(img)) 389 | 390 | table = np.array([(i - mean_value) * contrast_factor + mean_value for i in range(0, 256)]).clip(0, 391 | 255).astype('uint8') 392 | # enhancer = ImageEnhance.Contrast(img) 393 | # img = enhancer.enhance(contrast_factor) 394 | if img.ndim == 2 or img.shape[2] == 1: 395 | return cv2.LUT(img, table)[:, :, np.newaxis] 396 | else: 397 | return cv2.LUT(img, table) 398 | 399 | 400 | def adjust_saturation(img, saturation_factor): 401 | """Adjust color saturation of an image. 402 | Args: 403 | img (numpy ndarray): numpy ndarray to be adjusted. 404 | saturation_factor (float): How much to adjust the saturation. 0 will 405 | give a black and white image, 1 will give the original image while 406 | 2 will enhance the saturation by a factor of 2. 407 | Returns: 408 | numpy ndarray: Saturation adjusted image. 409 | """ 410 | # ~10ms slower than PIL! 411 | if not _is_numpy_image(img): 412 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 413 | img = Image.fromarray(img) 414 | enhancer = ImageEnhance.Color(img) 415 | img = enhancer.enhance(saturation_factor) 416 | return np.array(img) 417 | 418 | 419 | def adjust_hue(img, hue_factor): 420 | """Adjust hue of an image. 421 | The image hue is adjusted by converting the image to HSV and 422 | cyclically shifting the intensities in the hue channel (H). 423 | The image is then converted back to original image mode. 424 | `hue_factor` is the amount of shift in H channel and must be in the 425 | interval `[-0.5, 0.5]`. 426 | See `Hue`_ for more details. 427 | .. _Hue: https://en.wikipedia.org/wiki/Hue 428 | Args: 429 | img (numpy ndarray): numpy ndarray to be adjusted. 430 | hue_factor (float): How much to shift the hue channel. Should be in 431 | [-0.5, 0.5]. 
0.5 and -0.5 give complete reversal of hue channel in 432 | HSV space in positive and negative direction respectively. 433 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 434 | with complementary colors while 0 gives the original image. 435 | Returns: 436 | numpy ndarray: Hue adjusted image. 437 | """ 438 | # After testing, found that OpenCV calculates the Hue in a call to 439 | # cv2.cvtColor(..., cv2.COLOR_BGR2HSV) differently from PIL 440 | 441 | # This function takes 160ms! should be avoided 442 | if not (-0.5 <= hue_factor <= 0.5): 443 | raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) 444 | if not _is_numpy_image(img): 445 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 446 | img = Image.fromarray(img) 447 | input_mode = img.mode 448 | if input_mode in {'L', '1', 'I', 'F'}: 449 | return np.array(img) 450 | 451 | h, s, v = img.convert('HSV').split() 452 | 453 | np_h = np.array(h, dtype=np.uint8) 454 | # uint8 addition takes care of rotation across boundaries 455 | with np.errstate(over='ignore'): 456 | np_h += np.uint8(hue_factor * 255) 457 | h = Image.fromarray(np_h, 'L') 458 | 459 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 460 | return np.array(img) 461 | 462 | 463 | def adjust_gamma(img, gamma, gain=1): 464 | r"""Perform gamma correction on an image. 465 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 466 | based on the following equation: 467 | .. math:: 468 | I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} 469 | See `Gamma Correction`_ for more details. 470 | .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction 471 | Args: 472 | img (numpy ndarray): numpy ndarray to be adjusted. 473 | gamma (float): Non negative real number, same as :math:`\gamma` in the equation. 474 | gamma larger than 1 makes the shadows darker, 475 | while gamma smaller than 1 makes dark regions lighter. 476 | gain (float): The constant multiplier. 477 | """ 478 | if not _is_numpy_image(img): 479 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 480 | 481 | if gamma < 0: 482 | raise ValueError('Gamma should be a non-negative real number') 483 | # from here 484 | # https://stackoverflow.com/questions/33322488/how-to-change-image-illumination-in-opencv-python/41061351 485 | table = np.array([((i / 255.0)**gamma) * 255 * gain for i in np.arange(0, 256)]).astype('uint8') 486 | if img.shape[2] == 1: 487 | return cv2.LUT(img, table)[:, :, np.newaxis] 488 | else: 489 | return cv2.LUT(img, table) 490 | 491 | 492 | def rotate(img, angle, resample=False, expand=False, center=None): 493 | """Rotate the image by angle. 494 | Args: 495 | img (numpy ndarray): numpy ndarray to be rotated. 496 | angle (float or int): Rotation angle in degrees, counter-clockwise. 497 | resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): 498 | An optional resampling filter. See `filters`_ for more information. 499 | If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. 500 | expand (bool, optional): Optional expansion flag. 501 | If true, expands the output image to make it large enough to hold the entire rotated image. 502 | If false or omitted, make the output image the same size as the input image. 503 | Note that the expand flag assumes rotation around the center and no translation. 504 | center (2-tuple, optional): Optional center of rotation.
505 | Origin is the upper left corner. 506 | Default is the center of the image. 507 | .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters 508 | """ 509 | if not _is_numpy_image(img): 510 | raise TypeError('img should be numpy Image. Got {}'.format(type(img))) 511 | rows, cols = img.shape[0:2] 512 | if center is None: 513 | center = (cols / 2, rows / 2) 514 | M = cv2.getRotationMatrix2D(center, angle, 1) 515 | if img.shape[2] == 1: 516 | return cv2.warpAffine(img, M, (cols, rows))[:, :, np.newaxis] 517 | else: 518 | return cv2.warpAffine(img, M, (cols, rows)) 519 | 520 | 521 | def _get_affine_matrix(center, angle, translate, scale, shear): 522 | # Helper method to compute matrix for affine transformation 523 | # We need compute affine transformation matrix: M = T * C * RSS * C^-1 524 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 525 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 526 | # RSS is rotation with scale and shear matrix 527 | # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0] 528 | # [ sin(a)*scale cos(a + shear)*scale 0] 529 | # [ 0 0 1] 530 | 531 | angle = math.radians(angle) 532 | shear = math.radians(shear) 533 | # scale = 1.0 / scale 534 | 535 | T = np.array([[1, 0, translate[0]], [0, 1, translate[1]], [0, 0, 1]]) 536 | C = np.array([[1, 0, center[0]], [0, 1, center[1]], [0, 0, 1]]) 537 | RSS = np.array([[math.cos(angle) * scale, -math.sin(angle + shear) * scale, 0], 538 | [math.sin(angle) * scale, math.cos(angle + shear) * scale, 0], [0, 0, 1]]) 539 | matrix = T @ C @ RSS @ np.linalg.inv(C) 540 | 541 | return matrix[:2, :] 542 | 543 | 544 | def affine(img, angle, translate, scale, shear, interpolation=cv2.INTER_LINEAR, mode=cv2.BORDER_CONSTANT, fillcolor=0): 545 | """Apply affine transformation on the image keeping image center invariant 546 | Args: 547 | img (numpy ndarray): numpy ndarray to be transformed. 548 | angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. 549 | translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) 550 | scale (float): overall scale 551 | shear (float): shear angle value in degrees between -180 to 180, clockwise direction. 552 | interpolation (``cv2.INTER_NEAREST` or ``cv2.INTER_LINEAR`` or ``cv2.INTER_AREA``, ``cv2.INTER_CUBIC``): 553 | An optional resampling filter. 554 | See `filters`_ for more information. 555 | If omitted, it is set to ``cv2.INTER_CUBIC``, for bicubic interpolation. 556 | mode (``cv2.BORDER_CONSTANT`` or ``cv2.BORDER_REPLICATE`` or ``cv2.BORDER_REFLECT`` or ``cv2.BORDER_REFLECT_101``) 557 | Method for filling in border regions. 558 | Defaults to cv2.BORDER_CONSTANT, meaning areas outside the image are filled with a value (val, default 0) 559 | val (int): Optional fill color for the area outside the transform in the output image. Default: 0 560 | """ 561 | if not _is_numpy_image(img): 562 | raise TypeError('img should be numpy Image. 
Got {}'.format(type(img))) 563 | 564 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 565 | "Argument translate should be a list or tuple of length 2" 566 | 567 | assert scale > 0.0, "Argument scale should be positive" 568 | 569 | output_size = img.shape[0:2] 570 | center = (img.shape[1] * 0.5 + 0.5, img.shape[0] * 0.5 + 0.5) 571 | matrix = _get_affine_matrix(center, angle, translate, scale, shear) 572 | 573 | if img.shape[2] == 1: 574 | return cv2.warpAffine(img, matrix, output_size[::-1], interpolation, borderMode=mode, 575 | borderValue=fillcolor)[:, :, np.newaxis] 576 | else: 577 | return cv2.warpAffine(img, matrix, output_size[::-1], interpolation, borderMode=mode, borderValue=fillcolor) 578 | 579 | 580 | def to_grayscale(img, num_output_channels: int = 1): 581 | """Convert image to grayscale version of image. 582 | Args: 583 | img (numpy ndarray): Image to be converted to grayscale. 584 | num_output_channels: int 585 | if 1 : returned image is single channel 586 | if 3 : returned image is 3 channel with r = g = b 587 | Returns: 588 | numpy ndarray: Grayscale version of the image. 589 | """ 590 | if not _is_numpy_image(img): 591 | raise TypeError('img should be numpy ndarray. Got {}'.format(type(img))) 592 | 593 | if num_output_channels == 1: 594 | img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis] 595 | elif num_output_channels == 3: 596 | # much faster than doing cvtColor to go back to gray 597 | img = np.broadcast_to(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis], img.shape) 598 | return img 599 | -------------------------------------------------------------------------------- /opencv_transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import collections 4 | import math 5 | import numbers 6 | import random 7 | import types 8 | import warnings 9 | 10 | # from PIL import Image, ImageOps, ImageEnhance 11 | try: 12 | import accimage 13 | except ImportError: 14 | accimage = None 15 | import cv2 16 | import numpy as np 17 | import torch 18 | 19 | from . import functional as F 20 | 21 | __all__ = [ 22 | "Compose", "ToTensor", "Normalize", "Resize", "Scale", 23 | "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", 24 | "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", 25 | "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", 26 | "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", 27 | "Grayscale", "RandomGrayscale" 28 | ] 29 | 30 | _cv2_pad_to_str = { 31 | 'constant': cv2.BORDER_CONSTANT, 32 | 'edge': cv2.BORDER_REPLICATE, 33 | 'reflect': cv2.BORDER_REFLECT_101, 34 | 'symmetric': cv2.BORDER_REFLECT 35 | } 36 | _cv2_interpolation_to_str = { 37 | 'nearest': cv2.INTER_NEAREST, 38 | 'bilinear': cv2.INTER_LINEAR, 39 | 'area': cv2.INTER_AREA, 40 | 'bicubic': cv2.INTER_CUBIC, 41 | 'lanczos': cv2.INTER_LANCZOS4 42 | } 43 | _cv2_interpolation_from_str = { 44 | v: k 45 | for k, v in _cv2_interpolation_to_str.items() 46 | } 47 | 48 | 49 | class Compose(object): 50 | """Composes several transforms together. 51 | Args: 52 | transforms (list of ``Transform`` objects): list of transforms to compose. 
53 | Example: 54 | >>> transforms.Compose([ 55 | >>> transforms.CenterCrop(10), 56 | >>> transforms.ToTensor(), 57 | >>> ]) 58 | """ 59 | def __init__(self, transforms): 60 | self.transforms = transforms 61 | 62 | def __call__(self, img): 63 | for t in self.transforms: 64 | img = t(img) 65 | return img 66 | 67 | def __repr__(self): 68 | format_string = self.__class__.__name__ + '(' 69 | for t in self.transforms: 70 | format_string += '\n' 71 | format_string += ' {0}'.format(t) 72 | format_string += '\n)' 73 | return format_string 74 | 75 | 76 | class ToTensor(object): 77 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 78 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 79 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 80 | """ 81 | def __call__(self, pic): 82 | """ 83 | Args: 84 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 85 | Returns: 86 | Tensor: Converted image. 87 | """ 88 | return F.to_tensor(pic) 89 | 90 | def __repr__(self): 91 | return self.__class__.__name__ + '()' 92 | 93 | 94 | class Normalize(object): 95 | """Normalize a tensor image with mean and standard deviation. 96 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 97 | will normalize each channel of the input ``torch.*Tensor`` i.e. 98 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 99 | .. note:: 100 | This transform acts in-place, i.e., it mutates the input tensor. 101 | Args: 102 | mean (sequence): Sequence of means for each channel. 103 | std (sequence): Sequence of standard deviations for each channel. 104 | """ 105 | def __init__(self, mean, std): 106 | self.mean = mean 107 | self.std = std 108 | 109 | def __call__(self, tensor): 110 | """ 111 | Args: 112 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 113 | Returns: 114 | Tensor: Normalized Tensor image. 115 | """ 116 | return F.normalize(tensor, self.mean, self.std) 117 | 118 | def __repr__(self): 119 | return self.__class__.__name__ + '(mean={0}, std={1})'.format( 120 | self.mean, self.std) 121 | 122 | 123 | class Resize(object): 124 | """Resize the input numpy ndarray to the given size. 125 | Args: 126 | size (sequence or int): Desired output size. If size is a sequence like 127 | (h, w), output size will be matched to this. If size is an int, 128 | smaller edge of the image will be matched to this number. 129 | i.e, if height > width, then image will be rescaled to 130 | (size * height / width, size) 131 | interpolation (int, optional): Desired interpolation. Default is 132 | ``cv2.INTER_CUBIC``, bicubic interpolation 133 | """ 134 | 135 | def __init__(self, size, interpolation=cv2.INTER_LINEAR): 136 | # assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 137 | if isinstance(size, int): 138 | self.size = size 139 | elif isinstance(size, collections.abc.Iterable) and len(size) == 2: 140 | if type(size) == list: 141 | size = tuple(size) 142 | self.size = size 143 | else: 144 | raise ValueError('Unknown inputs for size: {}'.format(size)) 145 | self.interpolation = interpolation 146 | 147 | def __call__(self, img): 148 | """ 149 | Args: 150 | img (numpy ndarray): Image to be scaled. 151 | Returns: 152 | numpy ndarray: Rescaled image. 
153 | """ 154 | return F.resize(img, self.size, self.interpolation) 155 | 156 | def __repr__(self): 157 | interpolate_str = _cv2_interpolation_from_str[self.interpolation] 158 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format( 159 | self.size, interpolate_str) 160 | 161 | 162 | class Scale(Resize): 163 | """ 164 | Note: This transform is deprecated in favor of Resize. 165 | """ 166 | def __init__(self, *args, **kwargs): 167 | warnings.warn( 168 | "The use of the transforms.Scale transform is deprecated, " + 169 | "please use transforms.Resize instead.") 170 | super(Scale, self).__init__(*args, **kwargs) 171 | 172 | 173 | class CenterCrop(object): 174 | """Crops the given numpy ndarray at the center. 175 | Args: 176 | size (sequence or int): Desired output size of the crop. If size is an 177 | int instead of sequence like (h, w), a square crop (size, size) is 178 | made. 179 | """ 180 | def __init__(self, size): 181 | if isinstance(size, numbers.Number): 182 | self.size = (int(size), int(size)) 183 | else: 184 | self.size = size 185 | 186 | def __call__(self, img): 187 | """ 188 | Args: 189 | img (numpy ndarray): Image to be cropped. 190 | Returns: 191 | numpy ndarray: Cropped image. 192 | """ 193 | return F.center_crop(img, self.size) 194 | 195 | def __repr__(self): 196 | return self.__class__.__name__ + '(size={0})'.format(self.size) 197 | 198 | 199 | class Pad(object): 200 | """Pad the given numpy ndarray on all sides with the given "pad" value. 201 | Args: 202 | padding (int or tuple): Padding on each border. If a single int is provided this 203 | is used to pad all borders. If tuple of length 2 is provided this is the padding 204 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 205 | this is the padding for the left, top, right and bottom borders 206 | respectively. 207 | fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of 208 | length 3, it is used to fill R, G, B channels respectively. 209 | This value is only used when the padding_mode is constant 210 | padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. 211 | Default is constant. 212 | - constant: pads with a constant value, this value is specified with fill 213 | - edge: pads with the last value at the edge of the image 214 | - reflect: pads with reflection of image without repeating the last value on the edge 215 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 216 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 217 | - symmetric: pads with reflection of image repeating the last value on the edge 218 | For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 219 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 220 | """ 221 | def __init__(self, padding, fill=0, padding_mode='constant'): 222 | assert isinstance(padding, (numbers.Number, tuple, list)) 223 | assert isinstance(fill, (numbers.Number, str, tuple)) 224 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 225 | if isinstance(padding, 226 | collections.Sequence) and len(padding) not in [2, 4]: 227 | raise ValueError( 228 | "Padding must be an int or a 2, or 4 element tuple, not a " + 229 | "{} element tuple".format(len(padding))) 230 | 231 | self.padding = padding 232 | self.fill = fill 233 | self.padding_mode = padding_mode 234 | 235 | def __call__(self, img): 236 | """ 237 | Args: 238 | img (numpy ndarray): Image to be padded. 239 | Returns: 240 | numpy ndarray: Padded image. 
241 | """ 242 | return F.pad(img, self.padding, self.fill, self.padding_mode) 243 | 244 | def __repr__(self): 245 | return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ 246 | format(self.padding, self.fill, self.padding_mode) 247 | 248 | 249 | class Lambda(object): 250 | """Apply a user-defined lambda as a transform. 251 | Args: 252 | lambd (function): Lambda/function to be used for transform. 253 | """ 254 | def __init__(self, lambd): 255 | assert isinstance(lambd, types.LambdaType) 256 | self.lambd = lambd 257 | 258 | def __call__(self, img): 259 | return self.lambd(img) 260 | 261 | def __repr__(self): 262 | return self.__class__.__name__ + '()' 263 | 264 | 265 | class RandomTransforms(object): 266 | """Base class for a list of transformations with randomness 267 | Args: 268 | transforms (list or tuple): list of transformations 269 | """ 270 | def __init__(self, transforms): 271 | assert isinstance(transforms, (list, tuple)) 272 | self.transforms = transforms 273 | 274 | def __call__(self, *args, **kwargs): 275 | raise NotImplementedError() 276 | 277 | def __repr__(self): 278 | format_string = self.__class__.__name__ + '(' 279 | for t in self.transforms: 280 | format_string += '\n' 281 | format_string += ' {0}'.format(t) 282 | format_string += '\n)' 283 | return format_string 284 | 285 | 286 | class RandomApply(RandomTransforms): 287 | """Apply randomly a list of transformations with a given probability 288 | Args: 289 | transforms (list or tuple): list of transformations 290 | p (float): probability 291 | """ 292 | def __init__(self, transforms, p=0.5): 293 | super(RandomApply, self).__init__(transforms) 294 | self.p = p 295 | 296 | def __call__(self, img): 297 | if self.p < random.random(): 298 | return img 299 | for t in self.transforms: 300 | img = t(img) 301 | return img 302 | 303 | def __repr__(self): 304 | format_string = self.__class__.__name__ + '(' 305 | format_string += '\n p={}'.format(self.p) 306 | for t in self.transforms: 307 | format_string += '\n' 308 | format_string += ' {0}'.format(t) 309 | format_string += '\n)' 310 | return format_string 311 | 312 | 313 | class RandomOrder(RandomTransforms): 314 | """Apply a list of transformations in a random order 315 | """ 316 | def __call__(self, img): 317 | order = list(range(len(self.transforms))) 318 | random.shuffle(order) 319 | for i in order: 320 | img = self.transforms[i](img) 321 | return img 322 | 323 | 324 | class RandomChoice(RandomTransforms): 325 | """Apply single transformation randomly picked from a list 326 | """ 327 | def __call__(self, img): 328 | t = random.choice(self.transforms) 329 | return t(img) 330 | 331 | 332 | class RandomCrop(object): 333 | """Crop the given numpy ndarray at a random location. 334 | Args: 335 | size (sequence or int): Desired output size of the crop. If size is an 336 | int instead of sequence like (h, w), a square crop (size, size) is 337 | made. 338 | padding (int or sequence, optional): Optional padding on each border 339 | of the image. Default is None, i.e no padding. If a sequence of length 340 | 4 is provided, it is used to pad left, top, right, bottom borders 341 | respectively. If a sequence of length 2 is provided, it is used to 342 | pad left/right, top/bottom borders, respectively. 343 | pad_if_needed (boolean): It will pad the image if smaller than the 344 | desired size to avoid raising an exception. 345 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 346 | length 3, it is used to fill R, G, B channels respectively. 
347 | This value is only used when the padding_mode is constant 348 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 349 | - constant: pads with a constant value, this value is specified with fill 350 | - edge: pads with the last value on the edge of the image 351 | - reflect: pads with reflection of image (without repeating the last value on the edge) 352 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 353 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 354 | - symmetric: pads with reflection of image (repeating the last value on the edge) 355 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 356 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 357 | """ 358 | def __init__(self, 359 | size, 360 | padding=None, 361 | pad_if_needed=False, 362 | fill=0, 363 | padding_mode='constant'): 364 | if isinstance(size, numbers.Number): 365 | self.size = (int(size), int(size)) 366 | else: 367 | self.size = size 368 | self.padding = padding 369 | self.pad_if_needed = pad_if_needed 370 | self.fill = fill 371 | self.padding_mode = padding_mode 372 | 373 | @staticmethod 374 | def get_params(img, output_size): 375 | """Get parameters for ``crop`` for a random crop. 376 | Args: 377 | img (numpy ndarray): Image to be cropped. 378 | output_size (tuple): Expected output size of the crop. 379 | Returns: 380 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 381 | """ 382 | h, w = img.shape[0:2] 383 | th, tw = output_size 384 | if w == tw and h == th: 385 | return 0, 0, h, w 386 | 387 | i = random.randint(0, h - th) 388 | j = random.randint(0, w - tw) 389 | return i, j, th, tw 390 | 391 | def __call__(self, img): 392 | """ 393 | Args: 394 | img (numpy ndarray): Image to be cropped. 395 | Returns: 396 | numpy ndarray: Cropped image. 397 | """ 398 | if self.padding is not None: 399 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 400 | 401 | # pad the width if needed 402 | if self.pad_if_needed and img.shape[1] < self.size[1]: 403 | img = F.pad(img, (self.size[1] - img.shape[1], 0), self.fill, 404 | self.padding_mode) 405 | # pad the height if needed 406 | if self.pad_if_needed and img.shape[0] < self.size[0]: 407 | img = F.pad(img, (0, self.size[0] - img.shape[0]), self.fill, 408 | self.padding_mode) 409 | 410 | i, j, h, w = self.get_params(img, self.size) 411 | 412 | return F.crop(img, i, j, h, w) 413 | 414 | def __repr__(self): 415 | return self.__class__.__name__ + '(size={0}, padding={1})'.format( 416 | self.size, self.padding) 417 | 418 | 419 | class RandomHorizontalFlip(object): 420 | """Horizontally flip the given PIL Image randomly with a given probability. 421 | Args: 422 | p (float): probability of the image being flipped. Default value is 0.5 423 | """ 424 | def __init__(self, p=0.5): 425 | self.p = p 426 | 427 | def __call__(self, img): 428 | """random 429 | Args: 430 | img (numpy ndarray): Image to be flipped. 431 | Returns: 432 | numpy ndarray: Randomly flipped image. 433 | """ 434 | # if random.random() < self.p: 435 | # print('flip') 436 | # return F.hflip(img) 437 | if random.random() < self.p: 438 | return F.hflip(img) 439 | return img 440 | 441 | def __repr__(self): 442 | return self.__class__.__name__ + '(p={})'.format(self.p) 443 | 444 | 445 | class RandomVerticalFlip(object): 446 | """Vertically flip the given PIL Image randomly with a given probability. 447 | Args: 448 | p (float): probability of the image being flipped. 
Default value is 0.5 449 | """ 450 | def __init__(self, p=0.5): 451 | self.p = p 452 | 453 | def __call__(self, img): 454 | """ 455 | Args: 456 | img (numpy ndarray): Image to be flipped. 457 | Returns: 458 | numpy ndarray: Randomly flipped image. 459 | """ 460 | if random.random() < self.p: 461 | return F.vflip(img) 462 | return img 463 | 464 | def __repr__(self): 465 | return self.__class__.__name__ + '(p={})'.format(self.p) 466 | 467 | 468 | class RandomResizedCrop(object): 469 | """Crop the given numpy ndarray to random size and aspect ratio. 470 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 471 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 472 | is finally resized to given size. 473 | This is popularly used to train the Inception networks. 474 | Args: 475 | size: expected output size of each edge 476 | scale: range of size of the origin size cropped 477 | ratio: range of aspect ratio of the origin aspect ratio cropped 478 | interpolation: Default: cv2.INTER_CUBIC 479 | """ 480 | def __init__(self, 481 | size, 482 | scale=(0.08, 1.0), 483 | ratio=(3. / 4., 4. / 3.), 484 | interpolation=cv2.INTER_LINEAR): 485 | self.size = (size, size) 486 | self.interpolation = interpolation 487 | self.scale = scale 488 | self.ratio = ratio 489 | 490 | @staticmethod 491 | def get_params(img, scale, ratio): 492 | """Get parameters for ``crop`` for a random sized crop. 493 | Args: 494 | img (numpy ndarray): Image to be cropped. 495 | scale (tuple): range of size of the origin size cropped 496 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 497 | Returns: 498 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 499 | sized crop. 500 | """ 501 | for attempt in range(10): 502 | area = img.shape[0] * img.shape[1] 503 | target_area = random.uniform(*scale) * area 504 | aspect_ratio = random.uniform(*ratio) 505 | 506 | w = int(round(math.sqrt(target_area * aspect_ratio))) 507 | h = int(round(math.sqrt(target_area / aspect_ratio))) 508 | 509 | if random.random() < 0.5: 510 | w, h = h, w 511 | 512 | if w <= img.shape[1] and h <= img.shape[0]: 513 | i = random.randint(0, img.shape[0] - h) 514 | j = random.randint(0, img.shape[1] - w) 515 | return i, j, h, w 516 | 517 | # Fallback 518 | w = min(img.shape[0], img.shape[1]) 519 | i = (img.shape[0] - w) // 2 520 | j = (img.shape[1] - w) // 2 521 | return i, j, w, w 522 | 523 | def __call__(self, img): 524 | """ 525 | Args: 526 | img (numpy ndarray): Image to be cropped and resized. 527 | Returns: 528 | numpy ndarray: Randomly cropped and resized image. 529 | """ 530 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 531 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 532 | 533 | def __repr__(self): 534 | interpolate_str = _cv2_interpolation_from_str[self.interpolation] 535 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 536 | format_string += ', scale={0}'.format( 537 | tuple(round(s, 4) for s in self.scale)) 538 | format_string += ', ratio={0}'.format( 539 | tuple(round(r, 4) for r in self.ratio)) 540 | format_string += ', interpolation={0})'.format(interpolate_str) 541 | return format_string 542 | 543 | 544 | class RandomSizedCrop(RandomResizedCrop): 545 | """ 546 | Note: This transform is deprecated in favor of RandomResizedCrop. 
547 | """ 548 | def __init__(self, *args, **kwargs): 549 | warnings.warn( 550 | "The use of the transforms.RandomSizedCrop transform is deprecated, " 551 | + "please use transforms.RandomResizedCrop instead.") 552 | super(RandomSizedCrop, self).__init__(*args, **kwargs) 553 | 554 | 555 | class FiveCrop(object): 556 | """Crop the given numpy ndarray into four corners and the central crop 557 | .. Note:: 558 | This transform returns a tuple of images and there may be a mismatch in the number of 559 | inputs and targets your Dataset returns. See below for an example of how to deal with 560 | this. 561 | Args: 562 | size (sequence or int): Desired output size of the crop. If size is an ``int`` 563 | instead of sequence like (h, w), a square crop of size (size, size) is made. 564 | Example: 565 | >>> transform = Compose([ 566 | >>> FiveCrop(size), # this is a list of numpy ndarrays 567 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 568 | >>> ]) 569 | >>> #In your test loop you can do the following: 570 | >>> input, target = batch # input is a 5d tensor, target is 2d 571 | >>> bs, ncrops, c, h, w = input.size() 572 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 573 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 574 | """ 575 | def __init__(self, size): 576 | self.size = size 577 | if isinstance(size, numbers.Number): 578 | self.size = (int(size), int(size)) 579 | else: 580 | assert len( 581 | size 582 | ) == 2, "Please provide only two dimensions (h, w) for size." 583 | self.size = size 584 | 585 | def __call__(self, img): 586 | return F.five_crop(img, self.size) 587 | 588 | def __repr__(self): 589 | return self.__class__.__name__ + '(size={0})'.format(self.size) 590 | 591 | 592 | class TenCrop(object): 593 | """Crop the given numpy ndarray into four corners and the central crop plus the flipped version of 594 | these (horizontal flipping is used by default) 595 | .. Note:: 596 | This transform returns a tuple of images and there may be a mismatch in the number of 597 | inputs and targets your Dataset returns. See below for an example of how to deal with 598 | this. 599 | Args: 600 | size (sequence or int): Desired output size of the crop. If size is an 601 | int instead of sequence like (h, w), a square crop (size, size) is 602 | made. 603 | vertical_flip(bool): Use vertical flipping instead of horizontal 604 | Example: 605 | >>> transform = Compose([ 606 | >>> TenCrop(size), # this is a list of PIL Images 607 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 608 | >>> ]) 609 | >>> #In your test loop you can do the following: 610 | >>> input, target = batch # input is a 5d tensor, target is 2d 611 | >>> bs, ncrops, c, h, w = input.size() 612 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 613 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 614 | """ 615 | def __init__(self, size, vertical_flip=False): 616 | self.size = size 617 | if isinstance(size, numbers.Number): 618 | self.size = (int(size), int(size)) 619 | else: 620 | assert len( 621 | size 622 | ) == 2, "Please provide only two dimensions (h, w) for size." 
623 | self.size = size 624 | self.vertical_flip = vertical_flip 625 | 626 | def __call__(self, img): 627 | return F.ten_crop(img, self.size, self.vertical_flip) 628 | 629 | def __repr__(self): 630 | return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format( 631 | self.size, self.vertical_flip) 632 | 633 | 634 | class LinearTransformation(object): 635 | """Transform a tensor image with a square transformation matrix computed 636 | offline. 637 | Given transformation_matrix, will flatten the torch.*Tensor, compute the dot 638 | product with the transformation matrix and reshape the tensor to its 639 | original shape. 640 | Applications: 641 | - whitening: zero-center the data, compute the data covariance matrix 642 | [D x D] with np.dot(X.T, X), perform SVD on this matrix and 643 | pass it as transformation_matrix. 644 | Args: 645 | transformation_matrix (Tensor): tensor [D x D], D = C x H x W 646 | """ 647 | def __init__(self, transformation_matrix): 648 | if transformation_matrix.size(0) != transformation_matrix.size(1): 649 | raise ValueError("transformation_matrix should be square. Got " + 650 | "[{} x {}] rectangular matrix.".format( 651 | *transformation_matrix.size())) 652 | self.transformation_matrix = transformation_matrix 653 | 654 | def __call__(self, tensor): 655 | """ 656 | Args: 657 | tensor (Tensor): Tensor image of size (C, H, W) to be whitened. 658 | Returns: 659 | Tensor: Transformed image. 660 | """ 661 | if tensor.size(0) * tensor.size(1) * tensor.size( 662 | 2) != self.transformation_matrix.size(0): 663 | raise ValueError( 664 | "tensor and transformation matrix have incompatible shape." + 665 | "[{} x {} x {}] != ".format(*tensor.size()) + 666 | "{}".format(self.transformation_matrix.size(0))) 667 | flat_tensor = tensor.view(1, -1) 668 | transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) 669 | tensor = transformed_tensor.view(tensor.size()) 670 | return tensor 671 | 672 | def __repr__(self): 673 | format_string = self.__class__.__name__ + '(' 674 | format_string += (str(self.transformation_matrix.numpy().tolist()) + 675 | ')') 676 | return format_string 677 | 678 | 679 | class ColorJitter(object): 680 | """Randomly change the brightness, contrast and saturation of an image. 681 | Args: 682 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 683 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 684 | or the given [min, max]. Should be non negative numbers. 685 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 686 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 687 | or the given [min, max]. Should be non negative numbers. 688 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 689 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 690 | or the given [min, max]. Should be non negative numbers. 691 | hue (float or tuple of float (min, max)): How much to jitter hue. 692 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 693 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 
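Example (illustrative sketch; assumes ``image`` is an HxWx3 uint8 numpy ndarray):
>>> # jitter brightness and contrast only; saturation/hue are left off because
>>> # they are much slower (see the warnings emitted in __init__ below)
>>> jitter = ColorJitter(brightness=0.25, contrast=0.25)
>>> augmented = jitter(image)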
694 | """ 695 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 696 | self.brightness = self._check_input(brightness, 'brightness') 697 | self.contrast = self._check_input(contrast, 'contrast') 698 | self.saturation = self._check_input(saturation, 'saturation') 699 | self.hue = self._check_input(hue, 700 | 'hue', 701 | center=0, 702 | bound=(-0.5, 0.5), 703 | clip_first_on_zero=False) 704 | if self.saturation is not None: 705 | warnings.warn( 706 | 'Saturation jitter enabled. Will slow down loading immensely.') 707 | if self.hue is not None: 708 | warnings.warn( 709 | 'Hue jitter enabled. Will slow down loading immensely.') 710 | 711 | def _check_input(self, 712 | value, 713 | name, 714 | center=1, 715 | bound=(0, float('inf')), 716 | clip_first_on_zero=True): 717 | if isinstance(value, numbers.Number): 718 | if value < 0: 719 | raise ValueError( 720 | "If {} is a single number, it must be non negative.". 721 | format(name)) 722 | value = [center - value, center + value] 723 | if clip_first_on_zero: 724 | value[0] = max(value[0], 0) 725 | elif isinstance(value, (tuple, list)) and len(value) == 2: 726 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 727 | raise ValueError("{} values should be between {}".format( 728 | name, bound)) 729 | else: 730 | raise TypeError( 731 | "{} should be a single number or a list/tuple with length 2.". 732 | format(name)) 733 | 734 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 735 | # or (0., 0.) for hue, do nothing 736 | if value[0] == value[1] == center: 737 | value = None 738 | return value 739 | 740 | @staticmethod 741 | def get_params(brightness, contrast, saturation, hue): 742 | """Get a randomized transform to be applied on image. 743 | Arguments are same as that of __init__. 744 | Returns: 745 | Transform which randomly adjusts brightness, contrast and 746 | saturation in a random order. 747 | """ 748 | transforms = [] 749 | 750 | if brightness is not None: 751 | brightness_factor = random.uniform(brightness[0], brightness[1]) 752 | transforms.append( 753 | Lambda( 754 | lambda img: F.adjust_brightness(img, brightness_factor))) 755 | 756 | if contrast is not None: 757 | contrast_factor = random.uniform(contrast[0], contrast[1]) 758 | transforms.append( 759 | Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 760 | 761 | if saturation is not None: 762 | saturation_factor = random.uniform(saturation[0], saturation[1]) 763 | transforms.append( 764 | Lambda( 765 | lambda img: F.adjust_saturation(img, saturation_factor))) 766 | 767 | if hue is not None: 768 | hue_factor = random.uniform(hue[0], hue[1]) 769 | transforms.append( 770 | Lambda(lambda img: F.adjust_hue(img, hue_factor))) 771 | 772 | random.shuffle(transforms) 773 | transform = Compose(transforms) 774 | 775 | return transform 776 | 777 | def __call__(self, img): 778 | """ 779 | Args: 780 | img (numpy ndarray): Input image. 781 | Returns: 782 | numpy ndarray: Color jittered image. 
783 | """ 784 | transform = self.get_params(self.brightness, self.contrast, 785 | self.saturation, self.hue) 786 | return transform(img) 787 | 788 | def __repr__(self): 789 | format_string = self.__class__.__name__ + '(' 790 | format_string += 'brightness={0}'.format(self.brightness) 791 | format_string += ', contrast={0}'.format(self.contrast) 792 | format_string += ', saturation={0}'.format(self.saturation) 793 | format_string += ', hue={0})'.format(self.hue) 794 | return format_string 795 | 796 | 797 | class RandomRotation(object): 798 | """Rotate the image by angle. 799 | Args: 800 | degrees (sequence or float or int): Range of degrees to select from. 801 | If degrees is a number instead of sequence like (min, max), the range of degrees 802 | will be (-degrees, +degrees). 803 | resample ({cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4}, optional): 804 | An optional resampling filter. See `filters`_ for more information. 805 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 806 | expand (bool, optional): Optional expansion flag. 807 | If true, expands the output to make it large enough to hold the entire rotated image. 808 | If false or omitted, make the output image the same size as the input image. 809 | Note that the expand flag assumes rotation around the center and no translation. 810 | center (2-tuple, optional): Optional center of rotation. 811 | Origin is the upper left corner. 812 | Default is the center of the image. 813 | """ 814 | def __init__(self, degrees, resample=False, expand=False, center=None): 815 | if isinstance(degrees, numbers.Number): 816 | if degrees < 0: 817 | raise ValueError( 818 | "If degrees is a single number, it must be positive.") 819 | self.degrees = (-degrees, degrees) 820 | else: 821 | if len(degrees) != 2: 822 | raise ValueError( 823 | "If degrees is a sequence, it must be of len 2.") 824 | self.degrees = degrees 825 | 826 | self.resample = resample 827 | self.expand = expand 828 | self.center = center 829 | 830 | @staticmethod 831 | def get_params(degrees): 832 | """Get parameters for ``rotate`` for a random rotation. 833 | Returns: 834 | sequence: params to be passed to ``rotate`` for random rotation. 835 | """ 836 | angle = random.uniform(degrees[0], degrees[1]) 837 | 838 | return angle 839 | 840 | def __call__(self, img): 841 | """ 842 | img (numpy ndarray): Image to be rotated. 843 | Returns: 844 | numpy ndarray: Rotated image. 845 | """ 846 | 847 | angle = self.get_params(self.degrees) 848 | 849 | return F.rotate(img, angle, self.resample, self.expand, self.center) 850 | 851 | def __repr__(self): 852 | format_string = self.__class__.__name__ + '(degrees={0}'.format( 853 | self.degrees) 854 | format_string += ', resample={0}'.format(self.resample) 855 | format_string += ', expand={0}'.format(self.expand) 856 | if self.center is not None: 857 | format_string += ', center={0}'.format(self.center) 858 | format_string += ')' 859 | return format_string 860 | 861 | 862 | class RandomAffine(object): 863 | """Random affine transformation of the image keeping center invariant 864 | Args: 865 | degrees (sequence or float or int): Range of degrees to select from. 866 | If degrees is a number instead of sequence like (min, max), the range of degrees 867 | will be (-degrees, +degrees). Set to 0 to deactivate rotations. 868 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 869 | and vertical translations. 
For example translate=(a, b), then horizontal shift
870 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
871 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
872 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
873 | randomly sampled from the range a <= scale <= b. Will keep original scale by default.
874 | shear (sequence or float or int, optional): Range of degrees to select from.
875 | If shear is a number instead of sequence like (min, max), the range of degrees
876 | will be (-shear, +shear). Will not apply shear by default.
877 | interpolation ({cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4}, optional):
878 | An optional OpenCV interpolation flag used when warping the image.
879 | Default: cv2.INTER_LINEAR.
880 | fillcolor (int): Optional fill color for the area outside the transform in the output image.
881 | """
882 | def __init__(self,
883 | degrees,
884 | translate=None,
885 | scale=None,
886 | shear=None,
887 | interpolation=cv2.INTER_LINEAR,
888 | fillcolor=0):
889 | if isinstance(degrees, numbers.Number):
890 | if degrees < 0:
891 | raise ValueError(
892 | "If degrees is a single number, it must be positive.")
893 | self.degrees = (-degrees, degrees)
894 | else:
895 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
896 | "degrees should be a list or tuple and it must be of length 2."
897 | self.degrees = degrees
898 | 
899 | if translate is not None:
900 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
901 | "translate should be a list or tuple and it must be of length 2."
902 | for t in translate:
903 | if not (0.0 <= t <= 1.0):
904 | raise ValueError(
905 | "translation values should be between 0 and 1")
906 | self.translate = translate
907 | 
908 | if scale is not None:
909 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
910 | "scale should be a list or tuple and it must be of length 2."
911 | for s in scale:
912 | if s <= 0:
913 | raise ValueError("scale values should be positive")
914 | self.scale = scale
915 | 
916 | if shear is not None:
917 | if isinstance(shear, numbers.Number):
918 | if shear < 0:
919 | raise ValueError(
920 | "If shear is a single number, it must be positive.")
921 | self.shear = (-shear, shear)
922 | else:
923 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
924 | "shear should be a list or tuple and it must be of length 2."
925 | self.shear = shear
926 | else:
927 | self.shear = shear
928 | 
929 | # self.resample = resample
930 | self.interpolation = interpolation
931 | self.fillcolor = fillcolor
932 | 
933 | @staticmethod
934 | def get_params(degrees, translate, scale_ranges, shears, img_size):
935 | """Get parameters for affine transformation
936 | Returns:
937 | sequence: params to be passed to the affine transformation
938 | """
939 | angle = random.uniform(degrees[0], degrees[1])
940 | if translate is not None:
941 | max_dx = translate[0] * img_size[0]
942 | max_dy = translate[1] * img_size[1]
943 | translations = (np.round(random.uniform(-max_dx, max_dx)),
944 | np.round(random.uniform(-max_dy, max_dy)))
945 | else:
946 | translations = (0, 0)
947 | 
948 | if scale_ranges is not None:
949 | scale = random.uniform(scale_ranges[0], scale_ranges[1])
950 | else:
951 | scale = 1.0
952 | 
953 | if shears is not None:
954 | shear = random.uniform(shears[0], shears[1])
955 | else:
956 | shear = 0.0
957 | 
958 | return angle, translations, scale, shear
959 | 
960 | def __call__(self, img):
961 | """
962 | img (numpy ndarray): Image to be transformed.
963 | Returns:
964 | numpy ndarray: Affine transformed image.
965 | """
966 | ret = self.get_params(self.degrees, self.translate, self.scale,
967 | self.shear, (img.shape[1], img.shape[0]))
968 | return F.affine(img,
969 | *ret,
970 | interpolation=self.interpolation,
971 | fillcolor=self.fillcolor)
972 | 
973 | def __repr__(self):
974 | s = '{name}(degrees={degrees}'
975 | if self.translate is not None:
976 | s += ', translate={translate}'
977 | if self.scale is not None:
978 | s += ', scale={scale}'
979 | if self.shear is not None:
980 | s += ', shear={shear}'
981 | if self.interpolation != cv2.INTER_LINEAR:
982 | s += ', interpolation={interpolation}'
983 | if self.fillcolor != 0:
984 | s += ', fillcolor={fillcolor}'
985 | s += ')'
986 | d = dict(self.__dict__)
987 | d['interpolation'] = _cv2_interpolation_from_str[d['interpolation']]
988 | return s.format(name=self.__class__.__name__, **d)
989 | 
990 | 
991 | class Grayscale(object):
992 | """Convert image to grayscale.
993 | Args:
994 | num_output_channels (int): (1 or 3) number of channels desired for output image
995 | Returns:
996 | numpy ndarray: Grayscale version of the input.
997 | - If num_output_channels == 1 : returned image is single channel
998 | - If num_output_channels == 3 : returned image is 3 channel with r == g == b
999 | """
1000 | def __init__(self, num_output_channels=1):
1001 | self.num_output_channels = num_output_channels
1002 | 
1003 | def __call__(self, img):
1004 | """
1005 | Args:
1006 | img (numpy ndarray): Image to be converted to grayscale.
1007 | Returns:
1008 | numpy ndarray: Grayscale version of the input image.
1009 | """
1010 | return F.to_grayscale(img,
1011 | num_output_channels=self.num_output_channels)
1012 | 
1013 | def __repr__(self):
1014 | return self.__class__.__name__ + '(num_output_channels={0})'.format(
1015 | self.num_output_channels)
1016 | 
1017 | 
1018 | class RandomGrayscale(object):
1019 | """Randomly convert image to grayscale with a probability of p (default 0.1).
1020 | Args:
1021 | p (float): probability that image should be converted to grayscale.
1022 | Returns:
1023 | numpy ndarray: Grayscale version of the input image with probability p and unchanged
1024 | with probability (1-p).
1025 | - If input image is 1 channel: grayscale version is 1 channel 1026 | - If input image is 3 channel: grayscale version is 3 channel with r == g == b 1027 | """ 1028 | def __init__(self, p=0.1): 1029 | self.p = p 1030 | 1031 | def __call__(self, img): 1032 | """ 1033 | Args: 1034 | img (numpy ndarray): Image to be converted to grayscale. 1035 | Returns: 1036 | numpy ndarray: Randomly grayscaled image. 1037 | """ 1038 | num_output_channels = 3 1039 | if random.random() < self.p: 1040 | return F.to_grayscale(img, num_output_channels=num_output_channels) 1041 | return img 1042 | 1043 | def __repr__(self): 1044 | return self.__class__.__name__ + '(p={0})'.format(self.p) 1045 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('README.md', 'r') as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name='opencv_transforms', 8 | version='0.0.6', 9 | author='Jim Bohnslav', 10 | author_email='JBohnslav@gmail.com', 11 | description='A drop-in replacement for Torchvision Transforms using OpenCV', 12 | keywords='pytorch image augmentations', 13 | long_description=long_description, 14 | long_description_content_type='text/markdown', 15 | url='https://github.com/jbohnslav/opencv_transforms', 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | python_requires='>=3.6', 23 | ) -------------------------------------------------------------------------------- /tests/compare_to_pil_for_testing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import glob\n", 10 | "import numpy as np\n", 11 | "import random\n", 12 | "\n", 13 | "import cv2\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from PIL import Image\n", 16 | "\n", 17 | "from torchvision import transforms as pil_transforms\n", 18 | "from torchvision.transforms import functional as F_pil\n", 19 | "\n", 20 | "import sys\n", 21 | "sys.path.insert(0, '..')\n", 22 | "from opencv_transforms import transforms\n", 23 | "from opencv_transforms import functional as F\n", 24 | "\n", 25 | "from setup_testing_directory import get_testing_directory" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "datadir = get_testing_directory()\n", 35 | "print(datadir)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)\n", 45 | "train_images.sort()\n", 46 | "print('Number of training images: {:,}'.format(len(train_images)))" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "random.seed(1)\n", 56 | "imfile = random.choice(train_images)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "def plot_pil_and_opencv(pil_image, opencv_image, orientation='row'):\n", 66 | " if orientation == 'row':\n", 67 | " rows, cols = 1,3\n", 68 | " size = (8, 4)\n", 
69 | " else: \n", 70 | " rows, cols = 3,1\n", 71 | " size = (12, 6)\n", 72 | " fig, axes = plt.subplots(rows, cols,figsize=size)\n", 73 | " ax = axes[0]\n", 74 | " ax.imshow(pil_image)\n", 75 | " ax.set_title('PIL')\n", 76 | "\n", 77 | " ax = axes[1]\n", 78 | " ax.imshow(opencv_image)\n", 79 | " ax.set_title('opencv')\n", 80 | "\n", 81 | " ax = axes[2]\n", 82 | " l1 = np.abs(pil_image - opencv_image).mean(axis=2)\n", 83 | " ax.imshow(l1)\n", 84 | " ax.set_title('| PIL - opencv|\\nMAE:{:.4f}'.format(l1.mean()))\n", 85 | " plt.tight_layout()\n", 86 | " plt.show()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "pil_image = Image.open(imfile)\n", 96 | "image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "plot_pil_and_opencv(pil_image, image)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "pil_resized = pil_transforms.Resize((224, 224))(pil_image)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "resized = transforms.Resize(224)(image)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "plot_pil_and_opencv(pil_resized, resized)\n", 133 | "plt.show()" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "def L1(pil: Image, image: np.ndarray) -> float:\n", 143 | " return np.mean(np.abs(np.asarray(pil) - image))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "TOL = 1e-4\n", 153 | "\n", 154 | "l1 = L1(pil_resized, resized)\n", 155 | "assert l1 - 88.9559 < TOL" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "random.seed(1)\n", 165 | "pil = pil_transforms.RandomRotation(10)(pil_image)\n", 166 | "random.seed(1)\n", 167 | "np_img = transforms.RandomRotation(10)(image)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "plot_pil_and_opencv(pil, np_img)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "pil = pil_transforms.FiveCrop((224, 224))(pil_image)\n", 186 | "cv = transforms.FiveCrop((224,224))(image)" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "pil_stacked = np.hstack([np.asarray(i) for i in pil])\n", 196 | "cv_stacked = np.hstack(cv)\n", 197 | "\n", 198 | "plot_pil_and_opencv(pil_stacked, cv_stacked, orientation='col')" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "pil_stacked.shape" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "l1" 217 | ] 
218 | }
219 | ],
220 | "metadata": {
221 | "kernelspec": {
222 | "display_name": "opencv_transforms",
223 | "language": "python",
224 | "name": "opencv_transforms"
225 | },
226 | "language_info": {
227 | "codemirror_mode": {
228 | "name": "ipython",
229 | "version": 3
230 | },
231 | "file_extension": ".py",
232 | "mimetype": "text/x-python",
233 | "name": "python",
234 | "nbconvert_exporter": "python",
235 | "pygments_lexer": "ipython3",
236 | "version": "3.7.9"
237 | }
238 | },
239 | "nbformat": 4,
240 | "nbformat_minor": 4
241 | }
242 | 
-------------------------------------------------------------------------------- /tests/setup_testing_directory.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Union
4 | import warnings
5 | 
6 | 
7 | def get_testing_directory() -> str:
8 | directory_file = 'testing_directory.txt'
9 | directory_files = [directory_file, os.path.join('tests', directory_file)]
10 | 
11 | for directory_file in directory_files:
12 | if os.path.isfile(directory_file):
13 | with open(directory_file, 'r') as f:
14 | testing_directory = f.read()
15 | return testing_directory
16 | raise ValueError('please run setup_testing_directory.py before attempting to run unit tests')
17 | 
18 | 
19 | def setup_testing_directory(datadir: Union[str, os.PathLike], overwrite: bool = False) -> str:
20 | testing_path_file = 'testing_directory.txt'
21 | 
22 | should_setup = True
23 | if os.path.isfile(testing_path_file):
24 | with open(testing_path_file, 'r') as f:
25 | testing_directory = f.read()
26 | if not os.path.isdir(testing_directory):
27 | raise ValueError(
28 | 'saved testing directory {} does not exist; '.format(testing_directory) +
29 | 're-run setup_testing_directory.py with a valid --datadir')
30 | else:
31 | should_setup = False
32 | if not should_setup:
33 | return testing_directory
34 | 
35 | testing_directory = datadir
36 | assert os.path.isdir(testing_directory)
37 | assert os.path.isdir(os.path.join(testing_directory, 'train'))
38 | assert os.path.isdir(os.path.join(testing_directory, 'val'))
39 | with open('testing_directory.txt', 'w') as f:
40 | f.write(testing_directory)
41 | return testing_directory
42 | 
43 | 
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser('Setting up image directory for opencv transforms testing')
46 | parser.add_argument('-d', '--datadir', default=os.getcwd(), help='Imagenet directory')
47 | 
48 | args = parser.parse_args()
49 | 
50 | setup_testing_directory(args.datadir)
-------------------------------------------------------------------------------- /tests/test_color.py: --------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import random
4 | from typing import Union
5 | 
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | from PIL import Image
9 | from PIL.Image import Image as PIL_image # for typing
10 | import pytest
11 | from torchvision import transforms as pil_transforms
12 | from torchvision.transforms import functional as F_pil
13 | 
14 | from opencv_transforms import transforms
15 | from opencv_transforms import functional as F
16 | from setup_testing_directory import get_testing_directory
17 | 
18 | TOL = 1e-4
19 | 
20 | datadir = get_testing_directory()
21 | train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)
22 | train_images.sort()
23 | print('Number of training images: {:,}'.format(len(train_images)))
24 | 
25 | random.seed(1)
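# Seeding before sampling pins the module-level test image, so the PIL-vs-OpenCV
# comparisons below are deterministic for a given copy of the dataset.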
26 | imfile = random.choice(train_images)
27 | pil_image = Image.open(imfile)
28 | image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)
29 | 
30 | 
31 | class TestContrast:
32 | @pytest.mark.parametrize('random_seed', [1, 2, 3, 4])
33 | @pytest.mark.parametrize('contrast_factor', [0.0, 0.5, 1.0, 2.0])
34 | def test_contrast(self, contrast_factor, random_seed):
35 | random.seed(random_seed)
36 | imfile = random.choice(train_images)
37 | pil_image = Image.open(imfile)
38 | image = np.array(pil_image).copy()
39 | 
40 | pil_enhanced = F_pil.adjust_contrast(pil_image, contrast_factor)
41 | np_enhanced = F.adjust_contrast(image, contrast_factor)
42 | assert np.array_equal(np.array(pil_enhanced), np_enhanced.squeeze())
43 | 
44 | @pytest.mark.parametrize('n_images', [1, 11])
45 | def test_multichannel_contrast(self, n_images, contrast_factor=0.1):
46 | imfile = random.choice(train_images)
47 | 
48 | pil_image = Image.open(imfile)
49 | image = np.array(pil_image).copy()
50 | 
51 | multichannel_image = np.concatenate([image for _ in range(n_images)], axis=-1)
52 | # this will raise an exception in version 0.0.5
53 | np_enhanced = F.adjust_contrast(multichannel_image, contrast_factor)
54 | 
55 | @pytest.mark.parametrize('contrast_factor', [0, 0.5, 1.0])
56 | def test_grayscale_contrast(self, contrast_factor):
57 | imfile = random.choice(train_images)
58 | 
59 | pil_image = Image.open(imfile)
60 | image = np.array(pil_image).copy()
61 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
62 | 
63 | # make sure grayscale images work
64 | pil_image = pil_image.convert('L')
65 | 
66 | pil_enhanced = F_pil.adjust_contrast(pil_image, contrast_factor)
67 | np_enhanced = F.adjust_contrast(image, contrast_factor)
68 | assert np.array_equal(np.array(pil_enhanced), np_enhanced.squeeze())
69 | 
-------------------------------------------------------------------------------- /tests/test_spatial.py: --------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import random
4 | from typing import Union
5 | 
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | from PIL import Image
9 | from PIL.Image import Image as PIL_image # for typing
10 | 
11 | from torchvision import transforms as pil_transforms
12 | from torchvision.transforms import functional as F_pil
13 | from opencv_transforms import transforms
14 | from opencv_transforms import functional as F
15 | 
16 | from setup_testing_directory import get_testing_directory
17 | from utils import L1
18 | 
19 | TOL = 1e-4
20 | 
21 | datadir = get_testing_directory()
22 | train_images = glob.glob(datadir + '/**/*.JPEG', recursive=True)
23 | train_images.sort()
24 | print('Number of training images: {:,}'.format(len(train_images)))
25 | 
26 | random.seed(1)
27 | imfile = random.choice(train_images)
28 | pil_image = Image.open(imfile)
29 | image = cv2.cvtColor(cv2.imread(imfile, 1), cv2.COLOR_BGR2RGB)
30 | 
31 | 
32 | def test_resize():
33 | pil_resized = pil_transforms.Resize((224, 224))(pil_image)
34 | resized = transforms.Resize((224, 224))(image)
35 | l1 = L1(pil_resized, resized)
36 | assert abs(l1 - 88.9559) < TOL
37 | 
38 | def test_rotation():
39 | random.seed(1)
40 | pil = pil_transforms.RandomRotation(10)(pil_image)
41 | random.seed(1)
42 | np_img = transforms.RandomRotation(10)(image)
43 | l1 = L1(pil, np_img)
44 | assert abs(l1 - 86.7955) < TOL
45 | 
46 | def test_five_crop():
47 | pil = pil_transforms.FiveCrop((224, 224))(pil_image)
48 | cv = transforms.FiveCrop((224, 224))(image)
49 | pil_stacked = 
np.hstack([np.asarray(i) for i in pil])
50 | cv_stacked = np.hstack(cv)
51 | l1 = L1(pil_stacked, cv_stacked)
52 | assert abs(l1 - 22.0444) < TOL
-------------------------------------------------------------------------------- /tests/utils.py: --------------------------------------------------------------------------------
1 | from typing import Union
2 | 
3 | import numpy as np
4 | from PIL.Image import Image as PIL_image # for typing
5 | 
6 | 
7 | def L1(pil: Union[PIL_image, np.ndarray], np_image: np.ndarray) -> float:
8 | return np.abs(np.asarray(pil) - np_image).mean()
--------------------------------------------------------------------------------