├── LICENSE
├── README.md
├── imgs
│   └── README.md
├── pascal_voc.py
├── test_bounding_boxes.py
├── test_keypoints.py
├── test_segmentation_mask.py
└── transforms.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 uoip
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # transforms
2 | 
3 | This project is used for image augmentation, featuring simultaneous transformation of images, keypoints, bounding boxes, and segmentation masks. It's extended from [torchvision](https://github.com/pytorch/vision) and the project [imageUtils](https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0). Currently the implemented transformations include ColorJitter, RandomErasing, Expand, Scale, Resize, Crop, ElasticTransform, Rotate, Shift, and Flip. If you need more image augmentation types, take a look at [imgaug](https://github.com/aleju/imgaug), a very comprehensive library.
4 | 
5 | Image transformations can be divided into two categories:
6 | * geometric transformations
7 | * photometric transformations
8 | 
9 | Geometric transformations alter the geometry of the image, with the aim of making the algorithm invariant to changes in position/orientation and to image deformation. Photometric transformations amend the color channels, with the objective of making the algorithm invariant to changes in lighting and color.
10 | 
11 | For computer vision problems other than image classification, transforming the image alone is often not enough. For object detection, we should transform the image and bounding boxes simultaneously, and for image segmentation, we should also transform the segmentation mask (a mask should not be treated the same as an image: for geometric transformations of images we often use interpolation to make the result visually pleasing,
12 | but interpolating the label values of a mask is meaningless). This project concentrates on these problems.
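To make the mask caveat concrete, here is a standalone sketch (an editor's illustration, not part of this library): a label mask must be resized with nearest-neighbor interpolation, while the image itself can be interpolated smoothly.

```python
import numpy as np
import cv2

# a tiny 4x4 label mask with a single class id in one corner
mask = np.zeros((4, 4), dtype=np.uint8)
mask[2:, 2:] = 21

linear = cv2.resize(mask, (8, 8), interpolation=cv2.INTER_LINEAR)
nearest = cv2.resize(mask, (8, 8), interpolation=cv2.INTER_NEAREST)

print(np.unique(linear))   # interpolation invents values that are not valid class ids
print(np.unique(nearest))  # [ 0 21] -- only the original class ids survive
```

This is why the transforms below internally process the mask channels rigidly (nearest-neighbor, constant zero border) while interpolating the image channels; see `HalfBlood` in transforms.py.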
13 | 
14 | 
15 | 
16 | ### Keypoints
17 | ```python
18 | PRNG = RandomState()
19 | transform = Compose([
20 |     [ColorJitter(), None],  # or write [ColorJitter()]
21 |     Expand((0.8, 1.5)),
22 |     RandomCompose([
23 |         RandomRotate(360),
24 |         RandomShift(0.2)]),
25 |     Scale(512),
26 |     RandomCrop(512),
27 |     HorizontalFlip(),
28 |     ],
29 |     PRNG,
30 |     border='constant',
31 |     fillval=0,
32 |     outside_points='inf')
33 | 
34 | # image: np.ndarray of shape (h, w, 3), RGB format
35 | # pts: np.ndarray of shape (N, 2), e.g. [[x1, y1], [x2, y2], ...]
36 | transformed_image, transformed_pts = transform(image, pts)
37 | ```
38 | ![](https://i.loli.net/2018/01/06/5a5005a552e3b.gif)
39 | 
40 | ### Bounding Boxes
41 | Bounding boxes are transformed via their vertices: bounding boxes -> vertex coordinates -> transformed coordinates -> transformed bounding boxes.
42 | Below is the augmentation used by [SSD](https://arxiv.org/abs/1512.02325).
43 | ```python
44 | PRNG = RandomState()
45 | transform = Compose([
46 |     [ColorJitter(prob=0.5)],
47 |     BoxesToCoords(relative=False),
48 |     HorizontalFlip(),
49 |     Expand((1, 4), prob=0.5),
50 |     ObjectRandomCrop(),
51 |     Resize(300),
52 |     CoordsToBoxes(relative=False),
53 |     ],
54 |     PRNG,
55 |     mode='linear',
56 |     border='constant',
57 |     fillval=0,
58 |     outside_points='clamp')
59 | 
60 | # image: np.ndarray of shape (h, w, 3), RGB format
61 | # bboxes: np.ndarray of shape (N, 4), e.g. [[xmin, ymin, xmax, ymax], ...]
62 | # note that the bboxes can be normalized to [0, 1] (you should set
63 | # relative=True accordingly) or use pixel values directly (you should set
64 | # relative=False).
65 | transformed_image, transformed_bboxes = transform(image, bboxes)
66 | ```
67 | ![](https://i.loli.net/2018/01/06/5a5006787251c.gif)
68 | 
69 | ### Segmentation Mask
70 | ```python
71 | transform = Compose([
72 |     [ColorJitter(), None],
73 |     Merge(),
74 |     Expand((0.7, 1.4)),
75 |     RandomCompose([
76 |         RandomResize(1, 1.5),
77 |         RandomRotate(5),
78 |         RandomShift(0.1)]),
79 |     Scale(512),
80 |     ElasticTransform(150),
81 |     RandomCrop(512),
82 |     HorizontalFlip(),
83 |     Split([0, 3], [3, 6]),
84 |     ],
85 |     PRNG,
86 |     border='constant',
87 |     fillval=0,
88 |     anchor_index=3)
89 | # image: np.ndarray of shape (h, w, 3), RGB format
90 | # target: np.ndarray of shape (h, w, c)
91 | transformed_image, transformed_target = transform(image, target)
92 | ```
93 | ![](https://i.loli.net/2018/01/06/5a5006c0d99b1.gif)
94 | Note that the example augmentations above are for demonstration only; there is no guarantee that they are useful for your task.
95 | 
96 | ### More
97 | * For randomness, we pass a `RandomState` instance as a class/function argument instead of using global seeds; this is thread-safe.
98 | * This project implements only a few photometric transformations, but it provides a `Lambda` class, so you can insert third-party photometric transformation functions into the pipeline (see the sketch at the end of this README).
99 | 
100 | 
101 | ### TODO
102 | * More transformations
103 | * Docstring or documentation
104 | 
105 | ### Contact
106 | If you have problems related to this project, you can open an issue, or email me (qihang@outlook.com).
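As an editor's addendum to the `Lambda` note above, a hypothetical sketch of wrapping a third-party photometric function (the gamma adjustment here is made up for illustration; only `Compose`, `Lambda`, `Expand`, and `HorizontalFlip` come from this project):

```python
from numpy.random import RandomState
from transforms import Compose, Lambda, Expand, HorizontalFlip

PRNG = RandomState()

# wrap any img -> img function; this gamma adjustment is a stand-in
# for any third-party photometric transform
gamma = Lambda(lambda img: ((img / 255.) ** 0.8 * 255).astype('uint8'))

transform = Compose([
    [gamma, None],       # apply to the image only, leave the keypoints untouched
    Expand((0.8, 1.5)),
    HorizontalFlip(),
    ], PRNG)

# used exactly like the keypoint pipeline above:
# transformed_image, transformed_pts = transform(image, pts)
```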
107 | 
--------------------------------------------------------------------------------
/imgs/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/pascal_voc.py:
--------------------------------------------------------------------------------
1 | # https://github.com/amdegroot/ssd.pytorch/blob/master/data/voc0712.py
2 | # https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
3 | 
4 | import os
5 | import cv2
6 | import numpy as np
7 | # import torch.utils.data
8 | 
9 | import sys
10 | import xml.etree.ElementTree as ET
11 | 
12 | 
13 | 
14 | # VOC dataset file system:
15 | # VOCdevkit
16 | #   -| VOC2007
17 | #      -| Annotations
18 | #      -| ImageSets
19 | #      -| JPEGImages
20 | #      -| SegmentationClass
21 | #      -| SegmentationObject
22 | #   -| VOC2012
23 | #      -| Annotations
24 | #      -| ImageSets
25 | #      -| JPEGImages
26 | #      -| SegmentationClass
27 | #      -| SegmentationObject
28 | 
29 | 
30 | class VOC(object):
31 |     # ROOT = 'path/to/your/VOCdevkit'
32 |     N_CLASSES = 20
33 |     CLASSES = (
34 |         'aeroplane', 'bicycle', 'bird', 'boat',
35 |         'bottle', 'bus', 'car', 'cat', 'chair',
36 |         'cow', 'diningtable', 'dog', 'horse',
37 |         'motorbike', 'person', 'pottedplant',
38 |         'sheep', 'sofa', 'train', 'tvmonitor',
39 |     )
40 | 
41 |     MEAN = [123.68, 116.779, 103.939]  # R, G, B
42 | 
43 |     label_to_id = dict(map(reversed, enumerate(CLASSES)))
44 |     id_to_label = dict(enumerate(CLASSES))
45 | 
46 | 
47 | 
48 | 
49 | class Viz(object):
50 |     def __init__(self):
51 |         voc = VOC()
52 |         classes = voc.CLASSES
53 | 
54 |         self.id_to_label = voc.id_to_label
55 |         self.label_to_id = voc.label_to_id
56 | 
57 |         colors = {}
58 |         for label in classes:
59 |             id = self.label_to_id[label]
60 |             color = self._to_color(id, len(classes))
61 |             colors[id] = color
62 |             colors[label] = color
63 |         self.colors = colors
64 | 
65 |     def _to_color(self, indx, n_classes):
66 |         base = int(np.ceil(pow(n_classes, 1. / 3)))
67 |         base2 = base * base
68 |         b = 2 - indx // base2          # integer division (the Python 2 `/` semantics intended here)
69 |         r = 2 - (indx % base2) // base
70 |         g = 2 - (indx % base2) % base
71 |         return (r * 127, g * 127, b * 127)
72 | 
73 |     def draw_bbox(self, img, bboxes, labels, relative=False):
74 |         if len(labels) == 0:
75 |             return img
76 |         img = img.copy()
77 |         h, w = img.shape[:2]
78 | 
79 |         if relative:
80 |             bboxes = bboxes * [w, h, w, h]
81 | 
82 |         bboxes = bboxes.astype(int)  # np.int is removed in recent NumPy
83 |         labels = labels.astype(int)
84 | 
85 |         for bbox, label in zip(bboxes, labels):
86 |             left, top, right, bot = bbox
87 |             color = self.colors[label]
88 |             label = self.id_to_label[label]
89 |             cv2.rectangle(img, (left, top), (right, bot), color, 2)
90 |             cv2.putText(img, label, (left + 1, top - 5), cv2.FONT_HERSHEY_DUPLEX,
91 |                         0.4, color, 1, cv2.LINE_AA)
92 | 
93 |         return img
94 | 
95 |     def blend_segmentation(self, img, target):
96 |         mask = (target.max(axis=2) > 0)[..., np.newaxis] * 1.
97 |         blend = img * 0.3 + target * 0.7
98 | 
99 |         img = (1 - mask) * img + mask * blend
100 |         return img.astype('uint8')
101 | 
102 | 
103 | 
104 | class ParseAnnotation(object):
105 |     def __init__(self, keep_difficult=True):
106 |         self.keep_difficult = keep_difficult
107 | 
108 |         voc = VOC()
109 |         self.label_to_id = voc.label_to_id
110 |         self.classes = voc.CLASSES
111 | 
112 |     def __call__(self, target):
113 |         tree = ET.parse(target).getroot()
114 | 
115 |         bboxes = []
116 |         labels = []
117 |         for obj in tree.iter('object'):
118 |             difficult = int(obj.find('difficult').text) == 1
119 |             if not self.keep_difficult and difficult:
120 |                 continue
121 | 
122 |             label = obj.find('name').text.lower().strip()
123 |             if label not in self.classes:
124 |                 continue
125 |             label = self.label_to_id[label]
126 | 
127 |             bndbox = obj.find('bndbox')
128 |             bbox = [int(bndbox.find(_).text) - 1 for _ in ['xmin', 'ymin', 'xmax', 'ymax']]
129 | 
130 |             bboxes.append(bbox)
131 |             labels.append(label)
132 | 
133 |         return np.array(bboxes), np.array(labels)
134 | 
135 | 
136 | 
137 | class VOCDetection(object):  # torch.utils.data.Dataset
138 |     def __init__(self, root, image_set, keep_difficult=False, transform=None,
139 |                  target_transform=None):
140 |         self.root = root
141 |         self.image_set = image_set
142 |         self.transform = transform
143 |         self.target_transform = target_transform
144 | 
145 |         self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
146 |         self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
147 | 
148 |         self.parse_annotation = ParseAnnotation(keep_difficult=keep_difficult)
149 | 
150 |         self.ids = []
151 |         for year, split in image_set:
152 |             basepath = os.path.join(self.root, 'VOC' + str(year))
153 |             path = os.path.join(basepath, 'ImageSets', 'Main')
154 |             for file in os.listdir(path):
155 |                 if not file.endswith('_' + split + '.txt'):
156 |                     continue
157 |                 with open(os.path.join(path, file)) as f:
158 |                     for line in f:
159 |                         self.ids.append((basepath, line.strip()[:-3]))
160 | 
161 |         self.ids = sorted(list(set(self.ids)), key=lambda _: _[0] + _[1])  # deterministic
162 | 
163 |     def __getitem__(self, index):
164 |         img_id = self.ids[index]
165 | 
166 |         img = cv2.imread(self._imgpath % img_id)[:, :, ::-1]
167 |         bboxes, labels = self.parse_annotation(self._annopath % img_id)
168 | 
169 |         if self.transform is not None:
170 |             img, bboxes = self.transform(img, bboxes)
171 | 
172 |         # bboxes, labels = self.filter(img, bboxes, labels)
173 |         if self.target_transform is not None:
174 |             bboxes, labels = self.target_transform(bboxes, labels)
175 |         return img, bboxes, labels
176 | 
177 | 
178 |     def __len__(self):
179 |         return len(self.ids)
180 | 
181 |     def filter(self, img, boxes, labels):
182 |         shape = img.shape
183 |         if len(shape) == 2:
184 |             h, w = shape
185 |         else:  # !! ambiguous: guess HWC vs CHW from the shape
186 |             if shape[0] > shape[2]:  # HWC
187 |                 h, w = img.shape[:2]
188 |             else:  # CHW
189 |                 h, w = img.shape[1:]
190 | 
191 |         boxes_ = []
192 |         labels_ = []
193 |         for box, label in zip(boxes, labels):
194 |             if min(box[2] - box[0], box[3] - box[1]) <= 0:
195 |                 continue
196 |             if np.max(boxes) < 1 and np.sqrt((box[2] - box[0]) * w * (box[3] - box[1]) * h) < 8:
197 |                 #if np.max(boxes) < 1 and min((box[2] - box[0]) * w, (box[3] - box[1]) * h) < 5:
198 |                 continue
199 |             boxes_.append(box)
200 |             labels_.append(label)
201 |         return np.array(boxes_), np.array(labels_)
202 | 
203 | 
204 | 
205 | class VOCSegmentation(object):  # torch.utils.data.Dataset
206 |     def __init__(self, root, image_set, instance=False, transform=None):
207 |         self.root = root
208 |         self.image_set = image_set
209 |         self.instance = instance
210 |         self.transform = transform
211 | 
212 |         self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
213 | 
214 |         if self.instance:  # instance segmentation
215 |             self._segpath = os.path.join('%s', 'SegmentationObject', '%s.png')
216 |         else:  # semantic segmentation
217 |             self._segpath = os.path.join('%s', 'SegmentationClass', '%s.png')
218 | 
219 |         self.ids = []
220 |         for year, split in image_set:
221 |             basepath = os.path.join(root, 'VOC' + str(year))
222 |             path = os.path.join(basepath, 'ImageSets', 'Segmentation')
223 |             for file in os.listdir(path):
224 |                 if (split + '.txt') != file:
225 |                     continue
226 |                 with open(os.path.join(path, file)) as f:
227 |                     for line in f:
228 |                         self.ids.append((basepath, line.strip()))
229 | 
230 | 
231 |     def __getitem__(self, index):
232 |         img_id = self.ids[index]
233 | 
234 |         img = cv2.imread(self._imgpath % img_id)[:, :, ::-1]
235 |         target = cv2.imread(self._segpath % img_id)[:, :, ::-1]
236 | 
237 |         if self.transform is not None:
238 |             img, target = self.transform(img, target)
239 | 
240 |         return img, target
241 | 
242 |     def __len__(self):
243 |         return len(self.ids)
--------------------------------------------------------------------------------
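An editor's usage sketch for the dataset classes above (the root path is a placeholder; everything else uses the classes exactly as defined in pascal_voc.py):

```python
import cv2
from pascal_voc import VOC, VOCDetection, Viz

# hypothetical root; the expected directory layout is sketched at the top of pascal_voc.py
voc = VOCDetection(
    root='path/to/your/VOCdevkit',
    image_set=[('2007', 'trainval')],
    keep_difficult=True)

img, bboxes, labels = voc[0]   # RGB image, (N, 4) pixel boxes, (N,) label ids
canvas = Viz().draw_bbox(img, bboxes, labels, relative=False)
cv2.imshow('boxes', canvas[:, :, ::-1])  # Viz draws on RGB; cv2.imshow expects BGR
cv2.waitKey(0)
```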
/test_bounding_boxes.py:
--------------------------------------------------------------------------------
1 | from numpy.random import RandomState
2 | # import imageio
3 | 
4 | def test_bboxes():
5 |     PRNG = RandomState()
6 |     PRNG2 = RandomState()
7 |     if args.seed > 0:
8 |         PRNG.seed(args.seed)
9 |         PRNG2.seed(args.seed)
10 | 
11 |     transform = Compose([
12 |         [ColorJitter(prob=0.5)],  # or write [ColorJitter(), None]
13 |         BoxesToCoords(),
14 |         HorizontalFlip(),
15 |         Expand((1, 4), prob=0.5),
16 |         ObjectRandomCrop(),
17 |         Resize(300),
18 |         CoordsToBoxes(),
19 |         #[SubtractMean(mean=VOC.MEAN)],
20 |         ],
21 |         PRNG,
22 |         mode=None,
23 |         fillval=VOC.MEAN,
24 |         outside_points='clamp')
25 | 
26 |     viz = Viz()
27 |     voc_dataset = VOCDetection(
28 |         root=args.root,
29 |         image_set=[('2007', 'trainval')],
30 |         keep_difficult=True,
31 |         transform=transform)
32 | 
33 |     results = []
34 |     count = 0
35 |     i = PRNG2.choice(len(voc_dataset))
36 |     for _ in range(100):
37 |         img, boxes, labels = voc_dataset[i]
38 |         if len(labels) == 0:
39 |             continue
40 | 
41 |         img = viz.draw_bbox(img, boxes, labels, True)
42 |         results.append(img)
43 |         cv2.imshow('0', img[:, :, ::-1])
44 |         c = cv2.waitKey(500)
45 |         if c == 27 or c == ord('q'):  # ESC / 'q'
46 |             break
47 |         elif c == ord('c') or count >= 5:
48 |             count = 0
49 |             i = PRNG2.choice(len(voc_dataset))
50 |         count += 1
51 | 
52 |     # imageio.mimsave('bboxes.gif', results, duration=0.5)
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     from transforms import *
57 |     from pascal_voc import VOC, VOCDetection, Viz
58 | 
59 |     import argparse
60 |     parser = argparse.ArgumentParser()
61 |     parser.add_argument('--root', type=str, help='voc dataset root path', default='path/to/your/VOCdevkit')
62 |     parser.add_argument('--seed', type=int, help='random seed', default=0)
63 |     args = parser.parse_args()
64 | 
65 |     test_bboxes()
--------------------------------------------------------------------------------
/test_keypoints.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | 
4 | from scipy import misc  # note: scipy.misc.face() is removed in recent SciPy; use scipy.datasets.face() there
5 | from numpy.random import RandomState
6 | 
7 | from transforms import *
8 | # import imageio
9 | 
10 | def test_keypoints():
11 |     feature_detector = cv2.ORB_create(
12 |         nfeatures=500, scaleFactor=1.2, nlevels=1, edgeThreshold=31)
13 | 
14 |     image = misc.face()  # RGB
15 |     image = cv2.resize(image, None, fx=0.5, fy=0.5)
16 |     print('image shape', image.shape)
17 | 
18 |     keypoints = feature_detector.detect(image[..., ::-1])
19 |     points = [kp.pt for kp in keypoints]
20 |     print('num of keypoints', len(keypoints))
21 | 
22 | 
23 |     PRNG = RandomState()
24 | 
25 |     transform = Compose([
26 |         [ColorJitter(prob=0.75), None],
27 |         Expand((0.8, 1.5)),
28 |         RandomCompose([
29 |             RandomRotate(360),
30 |             RandomShift(0.2)]),
31 |         Scale(512),
32 |         # ElasticTransform(300),
33 |         RandomCrop(512),
34 |         HorizontalFlip(),
35 |         ],
36 |         PRNG,
37 |         border='constant',
38 |         fillval=0,
39 |         outside_points='inf')
40 | 
41 |     results = []
42 | 
43 |     for _ in range(100):
44 |         img, pts = transform(image, points)
45 | 
46 |         filtered = []
47 |         for pt in pts:
48 |             # NaN != NaN, so a `np.nan not in x` membership test can never detect NaN
49 |             if np.all(np.isfinite(pt)):
50 |                 filtered.append(pt)
51 | 
52 |         kps = [cv2.KeyPoint(*pt, 1) for pt in filtered]
53 |         print('num of keypoints', len(kps))
54 | 
55 |         img = cv2.drawKeypoints(img[..., ::-1], kps, None, flags=0)
56 |         results.append(img[..., ::-1])
57 |         cv2.imshow('keypoints', img)
58 |         c = cv2.waitKey(600)
59 |         if c == 27 or c == ord('q'):  # ESC / 'q'
60 |             break
61 | 
62 |     # imageio.mimsave('keypoints.gif', results, duration=0.5)
63 | 
64 | 
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     test_keypoints()
--------------------------------------------------------------------------------
/test_segmentation_mask.py:
--------------------------------------------------------------------------------
1 | from numpy.random import RandomState
2 | # import imageio
3 | 
4 | def test_segmentation():
5 |     PRNG = RandomState()
6 |     PRNG2 = RandomState()
7 |     if args.seed > 0:
8 |         PRNG.seed(args.seed)
9 |         PRNG2.seed(args.seed)
10 | 
11 |     transform = Compose([
12 |         [ColorJitter(prob=0.75), None],
13 |         Merge(),
14 |         Expand((0.8, 1.5)),
15 |         RandomCompose([
16 |             # RandomResize(1, 1.5),
17 |             RandomRotate(10),
18 |             RandomShift(0.1)]),
19 |         Scale(300),
20 |         # ElasticTransform(100),
21 |         RandomCrop(300),
22 |         HorizontalFlip(),
23 |         Split([0, 3], [3, 6]),
24 |         #[SubtractMean(mean=VOC.MEAN), None],
25 |         ],
26 |         PRNG,
27 |         border='constant',
28 |         fillval=VOC.MEAN,
29 |         anchor_index=3)
30 | 
31 |     voc_dataset = VOCSegmentation(
32 |         root=args.root,
33 |         image_set=[('2007', 'trainval')],
34 |         transform=transform,
35 |         instance=False)
36 |     viz = Viz()
37 | 
38 |     results = []
39 |     count = 0
40 |     i = PRNG2.choice(len(voc_dataset))
41 |     for _ in range(1000):
42 |         img, target = voc_dataset[i]
43 |         img2 = viz.blend_segmentation(img, target)
44 | 
45 |         con = np.hstack([img, target, img2])
46 |         results.append(con)
47 |         cv2.imshow('result', con[..., ::-1])
48 |         c = cv2.waitKey(500)
49 | 
50 |         if c == 27 or c == ord('q'):  # ESC / 'q'
51 |             break
52 |         elif c == ord('c') or count >= 3:
53 |             count = 0
54 |             i = PRNG2.choice(len(voc_dataset))
55 |         count += 1
56 | 
57 |     # imageio.mimsave('mask.gif', results, duration=0.5)
58 | 
59 | 
60 | if __name__ == '__main__':
61 |     from transforms import *
62 |     from pascal_voc import VOC, VOCSegmentation, Viz
63 | 
64 |     import argparse
65 |     parser = argparse.ArgumentParser()
66 |     parser.add_argument('--root', type=str, help='voc dataset root path', default='path/to/your/VOCdevkit')
67 |     parser.add_argument('--seed', type=int, help='random seed', default=0)
68 |     args = parser.parse_args()
69 | 
70 |     test_segmentation()
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2  # much faster than scipy.ndimage/skimage
3 | 
4 | import collections.abc  # collections.Sequence was removed in Python 3.10
5 | import numbers
6 | import types
7 | 
8 | 
9 | 
10 | InterpolationFlags = {
11 |     'nearest': cv2.INTER_NEAREST, 'linear': cv2.INTER_LINEAR,
12 |     'cubic': cv2.INTER_CUBIC, 'area': cv2.INTER_AREA,
13 |     'lanczos': cv2.INTER_LANCZOS4}
14 | 
15 | BorderTypes = {
16 |     'constant': cv2.BORDER_CONSTANT,
17 |     'replicate': cv2.BORDER_REPLICATE, 'nearest': cv2.BORDER_REPLICATE,
18 |     'reflect': cv2.BORDER_REFLECT, 'mirror': cv2.BORDER_REFLECT,
19 |     'wrap': cv2.BORDER_WRAP, 'reflect_101': cv2.BORDER_REFLECT_101}
20 | 
21 | 
22 | 
23 | def _loguniform(interval, random_state=np.random):
24 |     low, high = interval
25 |     return np.exp(random_state.uniform(np.log(low), np.log(high)))
26 | 
27 | 
28 | def _clamp(img, low=None, high=None, dtype='uint8'):
29 |     if low is None and high is None:
30 |         if dtype == 'uint8':
31 |             low, high = 0, 255
32 |         elif dtype == 'uint16':
33 |             low, high = 0, 65535
34 |         else:
35 |             low, high = -np.inf, np.inf
36 |     img = np.clip(img, low, high)
37 |     return img.astype(dtype)
38 | 
39 | 
40 | def _jaccard(boxes, rect):
41 |     def _intersect(boxes, rect):
42 |         lt = np.maximum(boxes[:, :2], rect[:2])
43 |         rb = np.minimum(boxes[:, 2:], rect[2:])
44 |         inter = np.clip(rb - lt, 0, None)
45 |         return inter[:, 0] * inter[:, 1]
46 |     inter = _intersect(boxes, rect)
47 |     area1 = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
48 |     area2 = (rect[2] - rect[0]) * (rect[3] - rect[1])
49 |     union = area1 + area2 - inter
50 | 
51 |     jaccard = inter / np.clip(union, 1e-10, None)
52 |     coverage = inter / np.clip(area1, 1e-10, None)
53 |     return jaccard, coverage, inter
54 | 
55 | 
56 | def _coords_clamp(cds, shape, outside='clamp'):  # outside: keep, clamp, inf
57 |     w, h = shape[1] - 1, shape[0] - 1
58 |     if outside == 'keep':
59 |         return np.array(cds, dtype=np.float32)
60 |     elif outside == 'inf':
61 |         cds_ = []
62 |         for x, y in cds:
63 |             x_ = x if 0 <= x <= w else np.sign(x) * np.inf
64 |             y_ = y if 0 <= y <= h else np.sign(y) * np.inf
65 |             cds_.append([x_, y_])
66 |         return np.array(cds_, dtype=np.float32)
67 |     elif outside == 'clamp':  # default
68 |         return np.array(
69 |             [[np.clip(cd[0], 0, w), np.clip(cd[1], 0, h)] for cd in cds],
70 |             dtype=np.float32)
71 |     else:
72 |         raise NotImplementedError
73 | 
74 | 
75 | def _to_bboxes(cds, img_shape=None):
76 |     assert len(cds) % 4 == 0
77 | 
78 |     h, w = img_shape if img_shape is not None else (np.inf, np.inf)
79 |     boxes = []
80 |     cds = np.array(cds)
81 |     for i in range(0, len(cds), 4):
82 |         xmin = np.clip(cds[i:i+4, 0].min(), 0, w - 1)
83 |         xmax = np.clip(cds[i:i+4, 0].max(), 0, w - 1)
84 |         ymin = np.clip(cds[i:i+4, 1].min(), 0, h - 1)
85 |         ymax = np.clip(cds[i:i+4, 1].max(), 0, h - 1)
86 |         boxes.append([xmin, ymin, xmax, ymax])
87 |     return np.array(boxes)
88 | 
89 | 
90 | def _to_coords(boxes):
91 |     cds = []
92 |     for box in boxes:
93 |         xmin, ymin, xmax, ymax = box
94 |         cds += [
95 |             [xmin, ymin],
96 |             [xmax, ymin],
97 |             [xmax, ymax],
98 |             [xmin, ymax],
99 |         ]
100 |     return np.array(cds)
101 | 
102 | 
103 | # recursively reset a transform's state
104 | def transform_state(t, **kwargs):
105 |     if callable(t):
106 |         t_vars = vars(t)
107 | 
108 |         if 'random_state' in kwargs and 'random' in t_vars:
109 |             t.__dict__['random'] = kwargs['random_state']
110 | 
111 |         support = ['fillval', 'anchor_index', 'prob', 'mean', 'std', 'outside_points']
112 |         for arg in kwargs:
113 |             if arg in t_vars and arg in support:
114 |                 t.__dict__[arg] = kwargs[arg]
115 | 
116 |         if 'mode' in kwargs and 'mode' in t_vars:
117 |             t.__dict__['mode'] = kwargs['mode']
118 |         if 'border' in kwargs and 'border' in t_vars:
119 |             t.__dict__['border'] = BorderTypes.get(kwargs['border'], cv2.BORDER_REPLICATE)
120 | 
121 |         if 'transforms' in t_vars:
122 |             t.__dict__['transforms'] = transforms_state(t.transforms, **kwargs)
123 |     return t
124 | 
125 | 
126 | def transforms_state(ts, **kwargs):
127 |     assert isinstance(ts, collections.abc.Sequence)
128 | 
129 |     transforms = []
130 |     for t in ts:
131 |         if isinstance(t, collections.abc.Sequence):
132 |             transforms.append(transforms_state(t, **kwargs))
133 |         else:
134 |             transforms.append(transform_state(t, **kwargs))
135 |     return transforms
136 | 
137 | 
138 | 
139 | # Operators
140 | '''
141 | class Clamp(object):
142 |     def __init__(self, min=0, max=255, soft=True, dtype='uint8'):
143 |         self.min, self.max = min, max
144 |         self.dtype = dtype
145 |         self.soft = soft
146 |         self.thresh =
147 | 
148 |     def __call__(self, img):
149 |         if self.soft is None:
150 |             return _clamp(img, low=self.min, high=self.max, dtype=self.dtype)
151 |         else:
152 | '''
153 | 
154 | 
155 | class Unsqueeze(object):
156 |     def __call__(self, img):
157 |         if img.ndim == 2:
158 |             return img[..., np.newaxis]
159 |         elif img.ndim == 3:
160 |             return img
161 |         else:
162 |             raise ValueError('input must be an image')
163 | 
164 | 
165 | 
166 | class Normalize(object):
167 |     def __init__(self, mean, std):
168 |         self.mean = mean
169 |         self.std = std
170 | 
171 |     def __call__(self, img):
172 |         if isinstance(img, np.ndarray):
173 |             return (img - self.mean) / self.std
174 |         # elif isinstance(img, torch.FloatTensor):
175 |         #     tensor = img
176 |         #     for t, m, s in zip(tensor, self.mean, self.std):
177 |         #         t.sub_(m).div_(s)
178 |         #     return tensor
179 |         else:
180 |             raise Exception('invalid input type')
181 | 
182 | 
183 | class SubtractMean(object):
184 |     def __init__(self, mean):
185 |         self.mean = mean
186 | 
187 |     def __call__(self, img):
188 |         return img.astype(np.float32) - self.mean
189 | 
190 | class DivideBy(object):
191 |     def __init__(self, divisor):
192 |         self.divisor = divisor
193 | 
194 |     def __call__(self, img):
195 |         return img.astype(np.float32) / self.divisor
196 | 
197 | 
198 | def HalfBlood(img, anchor_index, f1, f2):  # apply f1 to the first anchor_index channels (image part), f2 (rigid) to the rest (mask part)
199 |     # assert isinstance(f1, types.LambdaType) and isinstance(f2, types.LambdaType)
200 | 
201 |     if isinstance(anchor_index, numbers.Number):
202 |         anchor_index = int(np.ceil(anchor_index))
203 | 
204 |     if isinstance(anchor_index, int) and img.ndim == 3 and 0 < anchor_index < img.shape[2]:
205 |         img1, img2 = img[:, :, :anchor_index], img[:, :, anchor_index:]
206 | 
207 |         if img1.shape[2] == 1:
208 |             img1 = img1[:, :, 0]
209 |         if img2.shape[2] == 1:
210 |             img2 = img2[:, :, 0]
211 | 
212 |         img1 = f1(img1)
213 |         img2 = f2(img2)
214 | 
215 |         if img1.ndim == 2:
216 |             img1 = img1[..., np.newaxis]
217 |         if img2.ndim == 2:
218 |             img2 = img2[..., np.newaxis]
219 |         return np.concatenate((img1, img2), axis=2)
220 |     elif anchor_index == 0:
221 |         img = f2(img)
222 |         if img.ndim == 2:
223 |             img = img[..., np.newaxis]
224 |         return img
225 |     else:
226 |         img = f1(img)
227 |         if img.ndim == 2:
228 |             img = img[..., np.newaxis]
229 |         return img
230 | 
231 | 
232 | 
233 | 
234 | 
235 | # Photometric Transform
236 | 
237 | class RGB2BGR(object):
238 |     def __call__(self, img):
239 |         assert img.ndim == 3 and img.shape[2] == 3
240 |         return img[:, :, ::-1]
241 | 
242 | class BGR2RGB(object):
243 |     def __call__(self, img):
244 |         assert img.ndim == 3 and img.shape[2] == 3
245 |         return img[:, :, ::-1]
246 | 
247 | 
248 | class GrayScale(object):
249 |     # RGB to Gray
250 |     def __call__(self, img):
251 |         if img.ndim == 3 and img.shape[2] == 1:
252 |             return img
253 |         assert img.ndim == 3 and img.shape[2] == 3
254 |         dtype = img.dtype
255 | 
256 |         # 5x slower than cv2.cvtColor
257 |         gray = np.sum(img * [0.299, 0.587, 0.114], axis=2).astype(dtype)
258 |         #gray = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_RGB2GRAY)
259 |         return gray[..., np.newaxis]
260 | 
261 | 
262 | class Hue(object):
263 |     # skimage.color.rgb2hsv/hsv2rgb is almost 100x slower than cv2.cvtColor
264 |     def __init__(self, var=0.05, prob=0.5, random_state=np.random):
265 |         self.var = var
266 |         self.prob = prob
267 |         self.random = random_state
268 | 
269 |     def __call__(self, img):
270 |         assert img.ndim == 3 and img.shape[2] == 3
271 | 
272 |         if self.random.random_sample() >= self.prob:
273 |             return img
274 | 
275 |         var = self.random.uniform(-self.var, self.var)
276 | 
277 |         to_HSV, from_HSV = [
278 |             (cv2.COLOR_RGB2HSV, cv2.COLOR_HSV2RGB),
279 |             (cv2.COLOR_BGR2HSV, cv2.COLOR_HSV2BGR)][self.random.randint(2)]
280 | 
281 |         hsv = cv2.cvtColor(img, to_HSV).astype(np.float32)
282 | 
283 |         hue = hsv[:, :, 0] / 179. + var
284 |         hue = hue - np.floor(hue)
285 |         hsv[:, :, 0] = hue * 179.
286 | 
287 |         img = cv2.cvtColor(hsv.astype('uint8'), from_HSV)
288 |         return img
289 | 
290 | 
291 | class Saturation(object):
292 |     def __init__(self, var=0.3, prob=0.5, random_state=np.random):
293 |         self.var = var
294 |         self.prob = prob
295 |         self.random = random_state
296 | 
297 |         self.grayscale = GrayScale()
298 | 
299 |     def __call__(self, img):
300 |         if self.random.random_sample() >= self.prob:
301 |             return img
302 | 
303 |         dtype = img.dtype
304 |         gs = self.grayscale(img)
305 | 
306 |         alpha = 1.0 + self.random.uniform(-self.var, self.var)
307 |         img = alpha * img.astype(np.float32) + (1 - alpha) * gs.astype(np.float32)
308 |         return _clamp(img, dtype=dtype)
309 | 
310 | 
311 | 
312 | class Brightness(object):
313 |     def __init__(self, delta=32, prob=0.5, random_state=np.random):
314 |         self.delta = delta
315 |         self.prob = prob
316 |         self.random = random_state
317 | 
318 |     def __call__(self, img):
319 |         if self.random.random_sample() >= self.prob:
320 |             return img
321 | 
322 |         dtype = img.dtype
323 |         #alpha = 1.0 + self.random.uniform(-self.var, self.var)
324 |         #img = alpha * img.astype(np.float32)
325 |         img = img.astype(np.float32) + self.random.uniform(-self.delta, self.delta)
326 |         return _clamp(img, dtype=dtype)
327 | 
328 | 
329 | 
330 | class Contrast(object):
331 |     def __init__(self, var=0.3, prob=0.5, random_state=np.random):
332 |         self.var = var
333 |         self.prob = prob
334 |         self.random = random_state
335 | 
336 |         self.grayscale = GrayScale()
337 | 
338 |     def __call__(self, img):
339 |         if self.random.random_sample() >= self.prob:
340 |             return img
341 | 
342 |         dtype = img.dtype
343 |         gs = self.grayscale(img).mean()
344 | 
345 |         alpha = 1.0 + self.random.uniform(-self.var, self.var)
346 |         img = alpha * img.astype(np.float32) + (1 - alpha) * gs
347 |         return _clamp(img, dtype=dtype)
348 | 
349 | 
350 | class RandomOrder(object):
351 |     def __init__(self, transforms, random_state=None):
352 |         if random_state is None:
353 |             self.random = np.random
354 |             self.transforms = transforms_state(transforms)
355 |         else:
356 |             self.random = random_state
357 |             # the keyword must be `random_state`: `random=...` is ignored by transform_state
358 |             self.transforms = transforms_state(transforms, random_state=random_state)
359 | 
360 |     def __call__(self, img):
361 |         if self.transforms is None:
362 |             return img
363 |         order = self.random.permutation(len(self.transforms))
364 |         for i in order:
365 |             img = self.transforms[i](img)
366 |         return img
367 | 
368 | 
369 | class ColorJitter(RandomOrder):
370 |     def __init__(self, brightness=32, contrast=0.5, saturation=0.5, hue=0.1,
371 |                  prob=0.5, random_state=np.random):
372 |         self.transforms = []
373 |         self.random = random_state
374 | 
375 |         if brightness != 0:
376 |             self.transforms.append(
377 |                 Brightness(brightness, prob=prob, random_state=random_state))
378 |         if contrast != 0:
379 |             self.transforms.append(
380 |                 Contrast(contrast, prob=prob, random_state=random_state))
381 |         if saturation != 0:
382 |             self.transforms.append(
383 |                 Saturation(saturation, prob=prob, random_state=random_state))
384 |         if hue != 0:
385 |             self.transforms.append(
386 |                 Hue(hue, prob=prob, random_state=random_state))
387 | 
388 | 
389 | 
390 | # "ImageNet Classification with Deep Convolutional Neural Networks"
391 | # looks inferior to ColorJitter
392 | class FancyPCA(object):
393 |     def __init__(self, var=0.2, random_state=np.random):
394 |         self.var = var
395 |         self.random = random_state
396 | 
397 |         self.pca = None  # shape (channels, channels)
398 | 
399 |     def __call__(self, img):
400 |         dtype = img.dtype
401 |         channels = img.shape[2]
402 |         alpha = self.random.randn(channels) * self.var
403 | 
404 |         if self.pca is None:
405 |             pca = self._pca(img)
406 |         else:
407 |             pca = self.pca
408 | 
409 |         img = img + (pca * alpha).sum(axis=1)
410 |         return _clamp(img, dtype=dtype)
411 | 
412 |     def _pca(self, img):  # single image (hwc), or a batch (nhwc)
413 |         assert img.ndim >= 3
414 |         channels = img.shape[-1]
415 |         X = img.reshape(-1, channels)
416 | 
417 |         cov = np.cov(X.T)
418 |         evals, evecs = np.linalg.eigh(cov)
419 |         pca = np.sqrt(evals) * evecs
420 |         return pca
421 | 
422 |     def fit(self, imgs):  # training
423 |         self.pca = self._pca(imgs)
424 |         print(self.pca)
425 | 
426 | 
427 | class ShuffleChannels(object):
428 |     def __init__(self, prob=1., random_state=np.random):
429 |         self.prob = prob
430 |         self.random = random_state
431 | 
432 |     def __call__(self, img):
433 |         if self.prob < 1 and self.random.random_sample() >= self.prob:
434 |             return img
435 | 
436 |         assert img.ndim == 3
437 |         permut = self.random.permutation(img.shape[2])
438 |         img = img[:, :, permut]
439 | 
440 |         return img
441 | 
442 | 
443 | # "Improved Regularization of Convolutional Neural Networks with Cutout".
444 | # (arXiv:1708.04552)
445 | # fill with 0 (if the image is normalized) or the dataset's per-channel mean.
446 | class Cutout(object):
447 |     def __init__(self, size, fillval=0, prob=0.5, random_state=np.random):
448 |         if isinstance(size, numbers.Number):
449 |             size = (int(size), int(size))
450 |         self.size = size
451 | 
452 |         self.fillval = fillval
453 |         self.prob = prob
454 |         self.random = random_state
455 | 
456 |     def __call__(self, img):
457 |         if self.random.random_sample() >= self.prob:
458 |             return img
459 | 
460 |         h, w = img.shape[:2]
461 |         tw, th = self.size
462 | 
463 |         cx = self.random.randint(0, w)
464 |         cy = self.random.randint(0, h)
465 | 
466 |         x1 = int(np.clip(cx - tw / 2, 0, w - 1))
467 |         x2 = int(np.clip(cx + (tw + 1) / 2, 0, w))
468 |         y1 = int(np.clip(cy - th / 2, 0, h - 1))
469 |         y2 = int(np.clip(cy + (th + 1) / 2, 0, h))
470 | 
471 |         img[y1:y2, x1:x2] = self.fillval
472 | 
473 |         return img
474 | 
475 | 
476 | # "Random Erasing Data Augmentation". (arXiv:1708.04896).
477 | # fill with random value
478 | class RandomErasing(object):
479 |     def __init__(self, area_range=(0.02, 0.2), ratio_range=[0.3, 1/0.3], fillval=None,
480 |                  prob=0.5, num=1, anchor_index=None, random_state=np.random):
481 |         self.area_range = area_range
482 |         self.ratio_range = ratio_range
483 |         self.fillval = fillval
484 |         self.prob = prob
485 |         self.num = num
486 |         self.anchor_index = anchor_index
487 |         self.random = random_state
488 | 
489 |     def __call__(self, img):
490 |         if self.random.random_sample() >= self.prob:
491 |             return img
492 | 
493 |         h, w = img.shape[:2]
494 | 
495 |         num = self.random.randint(self.num) + 1
496 |         count = 0
497 |         for _ in range(10):
498 |             area = h * w
499 |             target_area = _loguniform(self.area_range, self.random) * area
500 |             aspect_ratio = _loguniform(self.ratio_range, self.random)
501 | 
502 |             tw = int(round(np.sqrt(target_area * aspect_ratio)))
503 |             th = int(round(np.sqrt(target_area / aspect_ratio)))
504 | 
505 |             if tw <= w and th <= h:
506 | 
507 |                 x1 = self.random.randint(0, w - tw + 1)
508 |                 y1 = self.random.randint(0, h - th + 1)
509 | 
510 |                 fillval = self.random.randint(0, 256) if self.fillval is None else self.fillval
511 | 
512 |                 erase = lambda im: self._fill(im, (x1, y1, x1+tw, y1+th), fillval)
513 |                 cut = lambda im: self._fill(im, (x1, y1, x1+tw, y1+th), 0)
514 |                 img = HalfBlood(img, self.anchor_index, erase, cut)
515 | 
516 |                 count += 1
517 |                 if count >= num:
518 |                     return img
519 | 
520 |         # Fallback
521 |         return img
522 | 
523 |     def _fill(self, img, rect, val):
524 |         l, t, r, b = rect
525 |         img[t:b, l:r] = val
526 |         return img
527 | 
528 | 
529 | #GaussianBlur
530 | #MotionBlur
531 | #RadialBlur
532 | #ResizeBlur
533 | #Sharpen
534 | 
535 | 
536 | 
537 | 
538 | # Geometric Transform
539 | 
540 | def _expand(img, size, lt, val):
541 |     h, w = img.shape[:2]
542 |     nw, nh = size
543 |     x1, y1 = lt
544 |     expand = np.zeros([nh, nw] + list(img.shape[2:]), dtype=img.dtype)
545 |     expand[...] = val
546 |     expand[y1: h + y1, x1: w + x1] = img
547 |     #expand = cv2.copyMakeBorder(img, y1, nh-h-y1, x1, nw-w-x1,
548 |     #                            cv2.BORDER_CONSTANT, value=val)  # slightly faster
549 |     return expand
550 | 
551 | 
552 | class Pad(object):
553 |     def __init__(self, padding, fillval=0, anchor_index=None):
554 |         if isinstance(padding, numbers.Number):
555 |             padding = (padding, padding)
556 |         assert len(padding) == 2
557 | 
558 |         self.padding = [int(np.clip(_, 0, None)) for _ in padding]  # clamp each padding to be non-negative
559 |         self.fillval = fillval
560 |         self.anchor_index = anchor_index
561 | 
562 |     def __call__(self, img, cds=None):
563 |         if max(self.padding) == 0:
564 |             return img if cds is None else (img, cds)
565 | 
566 |         h, w = img.shape[:2]
567 |         pw, ph = self.padding
568 | 
569 |         pad = lambda im: _expand(im, (w + pw*2, h + ph*2), (pw, ph), self.fillval)
570 |         rigid = lambda im: _expand(im, (w + pw*2, h + ph*2), (pw, ph), 0)
571 |         img = HalfBlood(img, self.anchor_index, pad, rigid)
572 | 
573 |         if cds is not None:
574 |             return img, np.array([[x + pw, y + ph] for x, y in cds])
575 |         else:
576 |             return img
577 | 
578 | 
579 | # "SSD: Single Shot MultiBox Detector". generate multi-resolution images / multi-scale objects
580 | class Expand(object):
581 |     def __init__(self, scale_range=(1, 4), fillval=0, prob=1.0,
582 |                  anchor_index=None, outside_points='clamp', random_state=np.random):
583 |         if isinstance(scale_range, numbers.Number):
584 |             scale_range = (1, scale_range)
585 |         assert max(scale_range) <= 5
586 | 
587 |         self.scale_range = scale_range
588 |         self.fillval = fillval
589 |         self.prob = prob
590 |         self.anchor_index = anchor_index
591 |         self.outside_points = outside_points
592 |         self.random = random_state
593 | 
594 |     def __call__(self, img, cds=None):
595 |         if self.prob < 1 and self.random.random_sample() >= self.prob:
596 |             return img if cds is None else (img, cds)
597 | 
598 |         #multiple = _loguniform(self.scale_range, self.random)
599 |         multiple = self.random.uniform(*self.scale_range)
600 | 
601 |         h, w = img.shape[:2]
602 |         nh, nw = int(multiple * h), int(multiple * w)
603 | 
604 |         if multiple < 1:
605 |             return RandomCrop(
606 |                 size=(nw, nh), fillval=self.fillval,
607 |                 outside_points=self.outside_points,
608 |                 random_state=self.random)(img, cds)
609 | 
610 |         y1 = self.random.randint(0, nh - h + 1)
611 |         x1 = self.random.randint(0, nw - w + 1)
612 | 
613 |         expand = lambda im: _expand(im, (nw, nh), (x1, y1), self.fillval)
614 |         rigid = lambda im: _expand(im, (nw, nh), (x1, y1), 0)
615 |         img = HalfBlood(img, self.anchor_index, expand, rigid)
616 | 
617 |         if cds is not None:
618 |             return img, np.array([[x + x1, y + y1] for x, y in cds])
619 |         else:
620 |             return img
621 | 
622 | 
623 | # scales the smaller edge to the given size
624 | class Scale(object):
625 |     def __init__(self, size, mode='linear', lazy=False, anchor_index=None,
626 |                  random_state=np.random):
627 |         assert isinstance(size, int)
628 | 
629 |         self.size = int(size)
630 |         self.mode = mode
631 |         self.lazy = lazy
632 |         self.anchor_index = anchor_index
633 |         self.random = random_state
634 | 
635 |     def __call__(self, img, cds=None):
636 |         interp_mode = (
637 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
638 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
639 | 
640 |         h, w = img.shape[:2]
641 | 
642 |         if self.lazy and min(h, w) >= self.size:
643 |             return img if cds is None else (img, cds)
644 | 
645 |         if h < w:
646 |             tw, th = int(self.size / float(h) * w), self.size
647 |         else:
648 |             th, tw = int(self.size / float(w) * h), self.size
649 | 
650 |         # skimage.transform.resize is 10x slower than cv2.resize
651 |         resize = lambda im: cv2.resize(im, (tw, th), interpolation=interp_mode)
652 |         rigid = lambda im: cv2.resize(im, (tw, th), interpolation=cv2.INTER_NEAREST)
653 |         img = HalfBlood(img, self.anchor_index, resize, rigid)
654 | 
655 |         if cds is not None:
656 |             s_x, s_y = tw / float(w), th / float(h)
657 |             return img, np.array([[x * s_x, y * s_y] for x, y in cds])
658 |         else:
659 |             return img
660 | 
661 | 
662 | class RandomScale(object):
663 |     def __init__(self, size_range, mode='linear', anchor_index=None, random_state=np.random):
664 |         assert isinstance(size_range, collections.abc.Sequence) and len(size_range) == 2
665 | 
666 |         self.size_range = size_range
667 |         self.mode = mode
668 |         self.anchor_index = anchor_index
669 |         self.random = random_state
670 | 
671 |     def __call__(self, img, cds=None):
672 |         interp_mode = (
673 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
674 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
675 | 
676 |         h, w = img.shape[:2]
677 |         size = int(self.random.uniform(*self.size_range))
678 | 
679 |         if h < w:
680 |             tw, th = int(size / float(h) * w), size
681 |         else:
682 |             th, tw = int(size / float(w) * h), size
683 | 
684 |         resize = lambda im: cv2.resize(im, (tw, th), interpolation=interp_mode)
685 |         rigid = lambda im: cv2.resize(im, (tw, th), interpolation=cv2.INTER_NEAREST)
686 |         img = HalfBlood(img, self.anchor_index, resize, rigid)
687 | 
688 |         if cds is not None:
689 |             s_x, s_y = tw / float(w), th / float(h)
690 |             return img, np.array([[x * s_x, y * s_y] for x, y in cds])
691 |         else:
692 |             return img
693 | 
694 | 
695 | class CenterCrop(object):
696 |     def __init__(self, size, outside_points='clamp'):
697 |         if isinstance(size, numbers.Number):
698 |             size = (int(size), int(size))
699 |         self.size = size
700 |         self.outside_points = outside_points
701 | 
702 |     def __call__(self, img, cds=None):
703 |         h, w = img.shape[:2]
704 |         tw, th = self.size
705 | 
706 |         if h == th and w == tw:
707 |             return img if cds is None else (img, cds)
708 |         elif h < th or w < tw:
709 |             raise Exception('invalid crop size')
710 | 
711 |         x1 = int(round((w - tw) / 2.))
712 |         y1 = int(round((h - th) / 2.))
713 |         img = img[y1:y1 + th, x1:x1 + tw]
714 | 
715 |         if cds is not None:
716 |             return img, _coords_clamp(
717 |                 [[x - x1, y - y1] for x, y in cds], img.shape, self.outside_points)
718 |         else:
719 |             return img
720 | 
721 | 
722 | class RandomCrop(object):
723 |     def __init__(self, size, fillval=0, outside_points='clamp', random_state=np.random):
724 |         if isinstance(size, numbers.Number):
725 |             size = (int(size), int(size))
726 |         self.size = size
727 |         self.outside_points = outside_points
728 |         self.random = random_state
729 | 
730 |     def __call__(self, img, cds=None):
731 |         h, w = img.shape[:2]
732 |         tw, th = self.size
733 | 
734 |         assert h >= th and w >= tw
735 | 
736 |         x1 = self.random.randint(0, w - tw + 1)
737 |         y1 = self.random.randint(0, h - th + 1)
738 |         img = img[y1:y1 + th, x1:x1 + tw]
739 | 
740 |         if cds is not None:
741 |             return img, _coords_clamp(
742 |                 [[x - x1, y - y1] for x, y in cds], img.shape, self.outside_points)
743 |         else:
744 |             return img
745 | 
746 | 
747 | # "SSD: Single Shot MultiBox Detector".
748 | # object-aware RandomCrop, crop multi-scale objects
749 | class ObjectRandomCrop(object):
750 |     def __init__(self, prob=1., random_state=np.random):
751 |         self.prob = prob
752 |         self.random = random_state
753 | 
754 |         self.options = [
755 |             #(0, None),
756 |             (0.1, None),
757 |             (0.3, None),
758 |             (0.5, None),
759 |             (0.7, None),
760 |             (0.9, None),
761 |             (None, 1), ]
762 | 
763 | 
764 |     def __call__(self, img, cbs):
765 |         h, w = img.shape[:2]
766 | 
767 |         if len(cbs) == 0:
768 |             return img, cbs
769 | 
770 |         if len(cbs[0]) == 4:  # boxes
771 |             boxes = cbs
772 |         elif len(cbs[0]) == 2:  # points
773 |             boxes = _to_bboxes(cbs, img.shape[:2])
774 |         else:
775 |             raise Exception('invalid input')
776 | 
777 |         params = [(np.array([0, 0, w, h]), None)]
778 | 
779 |         for min_iou, max_iou in self.options:
780 |             if min_iou is None:
781 |                 min_iou = 0
782 |             if max_iou is None:
783 |                 max_iou = 1
784 | 
785 |             for _ in range(50):
786 |                 scale = self.random.uniform(0.3, 1)
787 |                 aspect_ratio = self.random.uniform(
788 |                     max(1 / 2., scale * scale),
789 |                     min(2., 1 / (scale * scale)))
790 |                 th = int(h * scale / np.sqrt(aspect_ratio))
791 |                 tw = int(w * scale * np.sqrt(aspect_ratio))
792 | 
793 |                 x1 = self.random.randint(0, w - tw + 1)
794 |                 y1 = self.random.randint(0, h - th + 1)
795 |                 rect = np.array([x1, y1, x1 + tw, y1 + th])
796 | 
797 |                 iou, coverage, _ = _jaccard(boxes, rect)
798 |                 #m1 = coverage > 0.1
799 |                 #m2 = coverage < 0.45
800 |                 #if (m1 * m2).any():
801 |                 #    continue
802 | 
803 |                 center = (boxes[:, :2] + boxes[:, 2:]) / 2
804 |                 mask = np.logical_and(rect[:2] <= center, center < rect[2:]).all(axis=1)
805 |                 #mask = coverage >= 0.45
806 |                 if not mask.any():
807 |                     continue
808 | 
809 |                 if min_iou <= iou.max() and iou.min() <= max_iou:
810 |                     params.append((rect, mask))
811 |                     break
812 |         rect, mask = params[self.random.randint(len(params))]
813 | 
814 |         img = img[rect[1]:rect[3], rect[0]:rect[2]]
815 |         boxes[:, :2] = np.clip(boxes[:, :2], rect[:2], rect[2:])
816 |         boxes[:, :2] = boxes[:, :2] - rect[:2]
817 |         boxes[:, 2:] = np.clip(boxes[:, 2:], rect[:2], rect[2:])
818 |         boxes[:, 2:] = boxes[:, 2:] - rect[:2]
819 |         if mask is not None:
820 |             boxes[np.logical_not(mask), :] = 0
821 | 
822 |         if len(cbs[0]) == 4:
823 |             return img, boxes
824 |         else:
825 |             return img, _to_coords(boxes)
826 | 
827 | 
828 | 
829 | 
830 | 
831 | # Random crop with size 8%-100% and aspect ratio 3/4 - 4/3. (Inception-style)
832 | class RandomSizedCrop(object):
833 |     def __init__(self, size, mode='linear', anchor_index=None,
834 |                  outside_points='clamp', random_state=np.random):
835 |         self.size = size
836 |         self.mode = mode
837 |         self.anchor_index = anchor_index
838 |         self.outside_points = outside_points
839 |         self.random = random_state
840 | 
841 |         self.scale = Scale(size, mode=mode, anchor_index=anchor_index)
842 |         self.crop = CenterCrop(size)
843 | 
844 |     def __call__(self, img, cds=None):
845 |         interp_mode = (
846 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
847 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
848 | 
849 |         h, w = img.shape[:2]
850 | 
851 |         for _ in range(10):
852 |             area = h * w
853 |             target_area = self.random.uniform(0.16, 1.0) * area  # the paper uses 0.08~1.0
854 |             aspect_ratio = self.random.uniform(3. / 4, 4. / 3)
855 | 
856 |             tw = int(round(np.sqrt(target_area * aspect_ratio)))
857 |             th = int(round(np.sqrt(target_area / aspect_ratio)))
858 | 
859 |             if self.random.random_sample() < 0.5:
860 |                 tw, th = th, tw
861 | 
862 |             if tw <= w and th <= h:
863 |                 x1 = self.random.randint(0, w - tw + 1)
864 |                 y1 = self.random.randint(0, h - th + 1)
865 | 
866 |                 img = img[y1:y1 + th, x1:x1 + tw]
867 | 
868 |                 resize = lambda im: cv2.resize(im, (self.size, self.size),
869 |                                                interpolation=interp_mode)
870 |                 rigid = lambda im: cv2.resize(im, (self.size, self.size),
871 |                                               interpolation=cv2.INTER_NEAREST)
872 |                 img = HalfBlood(img, self.anchor_index, resize, rigid)
873 | 
874 |                 if cds is not None:
875 |                     scale_x = self.size / float(tw)
876 |                     scale_y = self.size / float(th)
877 | 
878 |                     return img, _coords_clamp(
879 |                         [[scale_x*(x-x1), scale_y*(y-y1)] for x, y in cds],
880 |                         img.shape, self.outside_points)
881 |                 else:
882 |                     return img
883 | 
884 |         # Fallback (Scale returns a (img, cds) tuple when cds is given, so it must be unpacked before cropping)
885 |         return self.crop(self.scale(img)) if cds is None else self.crop(*self.scale(img, cds=cds))
886 | 
887 | 
888 | class GridCrop(object):
889 |     def __init__(self, size, grid=5, outside_points='clamp', random_state=np.random):
890 |         # 4 grids, 5 grids or 9 grids
891 |         if isinstance(size, numbers.Number):
892 |             size = (int(size), int(size))
893 |         self.size = size
894 | 
895 |         self.grid = grid
896 |         self.outside_points = outside_points
897 |         self.random = random_state
898 |         self.lookup = {
899 |             0: lambda w, h, tw, th: (0, 0),
900 |             1: lambda w, h, tw, th: (w - tw, 0),
901 |             2: lambda w, h, tw, th: (w - tw, h - th),
902 |             3: lambda w, h, tw, th: (0, h - th),
903 |             4: lambda w, h, tw, th: ((w - tw) // 2, (h - th) // 2),
904 |             5: lambda w, h, tw, th: ((w - tw) // 2, 0),
905 |             6: lambda w, h, tw, th: (w - tw, (h - th) // 2),
906 |             7: lambda w, h, tw, th: ((w - tw) // 2, h - th),
907 |             8: lambda w, h, tw, th: (0, (h - th) // 2),
908 |         }
909 | 
910 |     def __call__(self, img, cds=None, index=None):
911 |         h, w = img.shape[:2]
912 |         tw, th = self.size
913 |         if index is None:
914 |             index = self.random.randint(0, self.grid)
915 |         if index not in self.lookup:
916 |             raise Exception('invalid index')
917 | 
918 |         x1, y1 = self.lookup[index](w, h, tw, th)
919 |         img = img[y1:y1 + th, x1:x1 + tw]
920 | 
921 |         if cds is not None:
922 |             return img, _coords_clamp(
923 |                 [[x - x1, y - y1] for x, y in cds], img.shape, self.outside_points)
924 |         else:
925 |             return img
926 | 
927 | 
928 | 
929 | class Resize(object):
930 |     def __init__(self, size, mode='linear', anchor_index=None, random_state=np.random):
931 |         if isinstance(size, numbers.Number):
932 |             size = (int(size), int(size))
933 |         self.size = size
934 | 
935 |         self.mode = mode
936 |         self.anchor_index = anchor_index
937 |         self.random = random_state
938 | 
939 |     def __call__(self, img, cds=None):
940 |         interp_mode = (
941 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
942 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
943 | 
944 |         h, w = img.shape[:2]
945 |         tw, th = self.size
946 | 
947 |         resize = lambda im: cv2.resize(im, (tw, th), interpolation=interp_mode)
948 |         rigid = lambda im: cv2.resize(im, (tw, th), interpolation=cv2.INTER_NEAREST)
949 |         img = HalfBlood(img, self.anchor_index, resize, rigid)
950 | 
951 |         if cds is not None:
952 |             s_x = tw / float(w)
953 |             s_y = th / float(h)
954 |             return img, np.array([[s_x * x, s_y * y] for x, y in cds])
955 |         else:
956 |             return img
957 | 
958 | 
959 | class RandomResize(object):
960 |     def __init__(self, scale_range=(0.8, 1.2), ratio_range=1., mode='linear',
961 |                  anchor_index=None, random_state=np.random):
962 | 
963 |         sr = scale_range
964 |         if isinstance(sr, numbers.Number):
965 |             sr = (min(sr, 1. / sr), max(sr, 1. / sr))
966 |         assert max(sr) <= 5
967 |         self.sr = sr
968 | 
969 |         rr = ratio_range
970 |         if isinstance(rr, numbers.Number):
971 |             rr = (min(rr, 1. / rr), max(rr, 1. / rr))
972 |         assert max(rr) <= 5
973 |         self.rr = rr
974 | 
975 |         self.mode = mode
976 |         self.anchor_index = anchor_index
977 |         self.random = random_state
978 | 
979 |     def __call__(self, img, cds=None):
980 |         interp_mode = (
981 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
982 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
983 | 
984 |         h, w = img.shape[:2]
985 | 
986 |         scale_factor = _loguniform(self.sr, self.random)
987 |         ratio_factor = _loguniform(self.rr, self.random)
988 | 
989 |         th = int(h * scale_factor)
990 |         tw = int(w * scale_factor * ratio_factor)
991 | 
992 |         resize = lambda im: cv2.resize(im, (tw, th), interpolation=interp_mode)
993 |         rigid = lambda im: cv2.resize(im, (tw, th), interpolation=cv2.INTER_NEAREST)
994 |         img = HalfBlood(img, self.anchor_index, resize, rigid)
995 | 
996 |         if cds is not None:
997 |             s_x = tw / float(w)
998 |             s_y = th / float(h)
999 |             return img, np.array([[s_x * x, s_y * y] for x, y in cds])
1000 |         else:
1001 |             return img
1002 | 
1003 | 
1004 | class ElasticTransform(object):
1005 |     def __init__(self, alpha=1000, sigma=40, mode='linear', border='constant', fillval=0,
1006 |                  outside_points='clamp', anchor_index=None, random_state=np.random):
1007 | 
1008 |         if isinstance(fillval, numbers.Number):
1009 |             fillval = [fillval] * 3
1010 | 
1011 |         self.alpha, self.sigma = alpha, sigma
1012 |         self.mode = mode
1013 |         self.border = BorderTypes.get(border, cv2.BORDER_REPLICATE)
1014 |         self.fillval = fillval
1015 |         self.anchor_index = anchor_index
1016 |         self.outside_points = outside_points
1017 |         self.random = random_state
1018 | 
1019 | 
1020 |     def __call__(self, img, cds=None):
1021 |         interp_mode = (
1022 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
1023 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
1024 | 
1025 |         shape = img.shape[:2]
1026 | 
1027 |         ksize = self.sigma * 4 + 1
1028 |         dx = cv2.GaussianBlur((
1029 |             self.random.rand(*img.shape[:2]) * 2 - 1).astype(np.float32),
1030 |             (ksize, ksize), 0) * self.alpha
1031 |         dy = cv2.GaussianBlur((
1032 |             self.random.rand(*img.shape[:2]) * 2 - 1).astype(np.float32),
1033 |             (ksize, ksize), 0) * self.alpha
1034 | 
1035 |         y, x = np.meshgrid(
1036 |             np.arange(img.shape[0]), np.arange(img.shape[1]), indexing='ij')
1037 |         mapy, mapx = (y + dy).astype(np.float32), (x + dx).astype(np.float32)
1038 | 
1039 |         elastic = lambda im: cv2.remap(im, mapx, mapy, interpolation=interp_mode,
1040 |                                        borderMode=self.border, borderValue=self.fillval)
1041 |         rigid = lambda im: cv2.remap(im, mapx, mapy, interpolation=cv2.INTER_NEAREST,
1042 |                                      borderMode=cv2.BORDER_CONSTANT)
1043 |         img = HalfBlood(img, self.anchor_index, elastic, rigid)
1044 | 
1045 |         if cds is None:
1046 |             return img
1047 |         else:
1048 |             cds_from = np.hstack([mapx.reshape(-1, 1), mapy.reshape(-1, 1)])
1049 |             cds_to = np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)])
1050 |             cds_ = []
1051 |             for coord in cds:
1052 |                 ind = np.argmin(np.sum((coord - cds_from)**2, axis=1))
1053 |                 cds_.append(cds_to[ind])
1054 |             return img, _coords_clamp(cds_, img.shape, self.outside_points)
1055 | 
1056 | 
1057 | class RandomRotate(object):
1058 |     def __init__(self, angle_range=(-30.0, 30.0), mode='linear',
1059 |                  border='constant', fillval=0, outside_points='clamp',
1060 |                  anchor_index=None, random_state=np.random):
1061 | 
1062 |         if isinstance(angle_range, numbers.Number):
1063 |             angle_range = (-angle_range, angle_range)
1064 |         self.angle_range = angle_range
1065 | 
1066 |         if isinstance(fillval, numbers.Number):
1067 |             fillval = [fillval] * 3
1068 | 
1069 |         self.mode = mode
1070 |         self.border = BorderTypes.get(border, cv2.BORDER_REPLICATE)
1071 |         self.fillval = fillval
1072 |         self.anchor_index = anchor_index
1073 |         self.outside_points = outside_points
1074 |         self.random = random_state
1075 | 
1076 |     def __call__(self, img, cds=None):
1077 |         interp_mode = (
1078 |             self.random.choice(list(InterpolationFlags.values())) if self.mode
1079 |             is None else InterpolationFlags.get(self.mode, cv2.INTER_LINEAR))
1080 | 
1081 |         h, w = img.shape[:2]
1082 |         angle = self.random.uniform(*self.angle_range)
1083 | 
1084 |         M = cv2.getRotationMatrix2D((w/2., h/2.), angle, 1)
1085 | 
1086 |         rotate = lambda im: cv2.warpAffine(im, M, dsize=(w, h),
1087 |             flags=interp_mode, borderMode=self.border, borderValue=self.fillval)
1088 |         rigid = lambda im: cv2.warpAffine(im, M, dsize=(w, h),
1089 |             flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT)
1090 |         img = HalfBlood(img, self.anchor_index, rotate, rigid)
1091 | 
1092 |         if cds is not None:
1093 |             cos = np.cos(angle * np.pi / 180.)
1094 |             sin = np.sin(angle * np.pi / 180.)
1095 |             cds_ = []
1096 |             for x, y in cds:
1097 |                 x, y = x - w/2., -(y - h/2.)
1098 |                 x, y = cos*x - sin*y, sin*x + cos*y
1099 |                 x, y = x + w/2., -y + h/2.
1100 |                 cds_.append([x, y])
1101 |             return img, _coords_clamp(cds_, img.shape, self.outside_points)
1102 |         else:
1103 |             return img
1104 | 
1105 | 
1106 | class Rotate90(object):
1107 |     def __init__(self, random_state=np.random):
1108 |         # 4 directions
1109 |         self.random = random_state
1110 | 
1111 |         self.lookup = {
1112 |             0: lambda x, y, w, h: (x, y),
1113 |             1: lambda x, y, w, h: (y, w-1-x),
1114 |             2: lambda x, y, w, h: (w-1-x, h-1-y),
1115 |             3: lambda x, y, w, h: (h-1-y, x),
1116 |         }
1117 | 
1118 |     def __call__(self, img, cds=None, index=None):
1119 |         h, w = img.shape[:2]
1120 |         if index is None:
1121 |             index = self.random.randint(0, 4)
1122 |         if index not in self.lookup:
1123 |             raise Exception('invalid index')
1124 | 
1125 |         img = np.rot90(img, index)
1126 | 
1127 |         if cds is not None:
1128 |             return img, np.array([self.lookup[index](x, y, w, h) for x, y in cds])
1129 |         else:
1130 |             return img
1131 | 
1132 | 
1133 | class RandomShift(object):
1134 |     def __init__(self, tx=(-0.1, 0.1), ty=None, border='constant', fillval=0,
1135 |                  outside_points='clamp', anchor_index=None, random_state=np.random):
1136 | 
1137 |         if isinstance(tx, numbers.Number):
1138 |             tx = (-abs(tx), abs(tx))
1139 |         assert isinstance(tx, tuple) and np.abs(tx).max() < 1
1140 |         if ty is None:
1141 |             ty = tx
1142 |         elif isinstance(ty, numbers.Number):
1143 |             ty = (-abs(ty), abs(ty))
1144 |         assert isinstance(ty, tuple) and np.abs(ty).max() < 1
1145 |         self.tx, self.ty = tx, ty
1146 | 
1147 |         if isinstance(fillval, numbers.Number):
1148 |             fillval = [fillval] * 3
1149 | 
1150 |         self.border = BorderTypes.get(border, cv2.BORDER_REPLICATE)
1151 |         self.fillval = fillval
1152 |         self.anchor_index = anchor_index
1153 |         self.outside_points = outside_points
1154 |         self.random = random_state
1155 | 
1156 |     def __call__(self, img, cds=None):
1157 |         h, w = img.shape[:2]
1158 |         tx = self.random.uniform(*self.tx) * w
1159 |         ty = self.random.uniform(*self.ty) * h
1160 | 
1161 |         M = np.float32([[1, 0, tx], [0, 1, ty]])
1162 | 
1163 |         shift = lambda im: cv2.warpAffine(im, M, dsize=(w, h),
1164 |                                           flags=cv2.INTER_LINEAR, borderMode=self.border,  # linear interpolation for the image half; the rigid mask half below stays nearest
1165 |                                           borderValue=self.fillval)
1166 |         rigid = lambda im: cv2.warpAffine(im, M, dsize=(w, h),
1167 |                                           flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT)
1168 |         img = HalfBlood(img, self.anchor_index, shift, rigid)
1169 | 
1170 |         if cds is not None:
1171 |             return img, _coords_clamp(
1172 |                 [[x + tx, y + ty] for x, y in cds], img.shape, self.outside_points)
1173 |         else:
1174 |             return img
1175 | 
1176 | 
1177 | class HorizontalFlip(object):
1178 |     def __init__(self, prob=0.5, random_state=np.random):
1179 |         self.prob = prob
1180 |         self.random = random_state
1181 | 
1182 |     def __call__(self, img, cds=None, flip=None):
1183 |         if flip is None:
1184 |             flip = self.random.random_sample() < self.prob
1185 | 
1186 |         if flip:
1187 |             img = img[:, ::-1]
1188 | 
1189 |         if cds is not None:
1190 |             h, w = img.shape[:2]
1191 |             t = lambda x, y: [w-1-x, y] if flip else [x, y]
1192 |             return img, np.array([t(x, y) for x, y in cds])
1193 |         else:
1194 |             return img
1195 | 
1196 | 
1197 | class VerticalFlip(object):
1198 |     def __init__(self, prob=0.5, random_state=np.random):
1199 |         self.prob = prob
1200 |         self.random = random_state
1201 | 
1202 |     def __call__(self, img, cds=None, flip=None):
1203 |         if flip is None:
1204 |             flip = self.random.random_sample() < self.prob
1205 | 
1206 |         if flip:
1207 |             img = img[::-1, :]
1208 | 
1209 |         if cds is not None:
1210 |             h, w = img.shape[:2]
1211 |             t = lambda x, y: [x, h-1-y] if flip else [x, y]
1212 |             return img, np.array([t(x, y) for x, y in cds])
1213 |         else:
1214 |             return img
1215 | 
1216 | # TODO: more homography transformations
1217 | 
1218 | 
1219 | # Pipeline
1220 | 
1221 | class Lambda(object):
1222 |     def __init__(self, lambd):
1223 |         assert isinstance(lambd, types.LambdaType)
1224 |         self.lambd = lambd
1225 | 
1226 |     def __call__(self, *args):
1227 |         return self.lambd(*args)
1228 | 
1229 | 
1230 | class Merge(object):
1231 |     def __init__(self, axis=-1):
1232 |         self.axis = axis
1233 | 
1234 |     def __call__(self, *imgs):
1235 |         # ad-hoc
1236 |         if len(imgs) > 1 and not isinstance(imgs[0], collections.abc.Sequence):
1237 |             pass
1238 |         elif len(imgs) == 1 and isinstance(imgs[0], collections.abc.Sequence):  # unreliable
1239 |             imgs = imgs[0]
1240 |         elif len(imgs) == 1:
1241 |             return imgs[0]
1242 |         else:
1243 |             raise Exception('input must be a sequence (list, tuple, etc.)')
1244 | 
1245 |         assert len(imgs) > 0 and all([isinstance(_, np.ndarray)
1246 |                                       for _ in imgs]), 'only support numpy array'
1247 | 
1248 |         shapes = []
1249 |         imgs_ = []
1250 |         for i, img in enumerate(imgs):
1251 |             if img.ndim == 2:
1252 |                 img = np.expand_dims(img, axis=self.axis)
1253 |             imgs_.append(img)
1254 |             shape = list(img.shape)
1255 |             shape[self.axis] = None
1256 |             shapes.append(shape)
1257 |         assert all([_ == shapes[0] for _ in shapes]), 'shapes must match'
1258 |         return np.concatenate(imgs_, axis=self.axis)
1259 | 
1260 | 
1261 | class Split(object):
1262 |     def __init__(self, *slices, **kwargs):
1263 |         slices_ = []
1264 |         for s in slices:
1265 |             if isinstance(s, collections.abc.Sequence):
1266 |                 slices_.append(slice(*s))
1267 |             else:
1268 |                 slices_.append(s)
1269 |         assert all([isinstance(s, slice) for s in slices_]), (
1270 |             'slices must consist of slice instances')
1271 | 
1272 |         self.slices = slices_
1273 |         self.axis = kwargs.get('axis', -1)
1274 | 
1275 |     def __call__(self, img):
1276 |         if isinstance(img, np.ndarray):
1277 |             result = []
1278 |             for s in self.slices:
1279 |                 sl = [slice(None)] * img.ndim
1280 |                 sl[self.axis] = s
1281 |                 result.append(img[tuple(sl)])  # index with a tuple; indexing with a list of slices is an error in modern NumPy


class Branching(object):
    # TODO
    pass

class Bracket(object):
    # TODO
    pass

class Flatten(object):
    # TODO
    pass

class Permute(object):
    # TODO
    pass


class Compose(object):
    def __init__(self, transforms, random_state=None, **kwargs):
        if random_state is not None:
            kwargs['random_state'] = random_state
        self.transforms = transforms_state(transforms, **kwargs)

    def __call__(self, *data):
        # ad-hoc
        if len(data) >= 1 and not isinstance(data[0], collections.Sequence):
            pass
        elif len(data) == 1 and isinstance(data[0], collections.Sequence) and len(data[0]) > 0:  # unreliable
            data = list(data[0])
        else:
            raise Exception('invalid input')

        for t in self.transforms:
            if not isinstance(data, collections.Sequence):  # unreliable
                data = [data]

            # a list entry like [ColorJitter(), None] applies its i-th
            # member to the i-th input (None leaves that input untouched)
            if isinstance(t, collections.Sequence):
                if len(t) > 1:
                    assert isinstance(data, collections.Sequence) and len(data) == len(t)
                    ds = []
                    for i, d in enumerate(data):
                        if callable(t[i]):
                            ds.append(t[i](d))
                        else:
                            ds.append(d)
                    data = ds
                elif len(t) == 1:
                    if callable(t[0]):
                        data = [t[0](data[0])] + list(data)[1:]
            elif callable(t):
                data = t(*data)
            elif t is not None:
                raise Exception('invalid transform type')

        if isinstance(data, collections.Sequence) and len(data) == 1:  # unreliable
            return data[0]
        else:
            return data

    def set_random_state(self, random_state):
        self.transforms = transforms_state(self.transforms, random=random_state)


class RandomCompose(Compose):
    def __init__(self, transforms, random_state=None, **kwargs):
        if random_state is None:
            random_state = np.random
        else:
            kwargs['random_state'] = random_state

        self.transforms = transforms_state(transforms, **kwargs)
        self.random = random_state

    def __call__(self, *data):
        # apply the sub-transforms in a fresh random order on every call
        self.random.shuffle(self.transforms)
        return super(RandomCompose, self).__call__(*data)


class BoxesToCoords(object):
    def __init__(self, relative=False):
        self.relative = relative

    def bbox2coords(self, bbox):
        xmin, ymin, xmax, ymax = bbox
        return np.array([
            [xmin, ymin],
            [xmax, ymin],
            [xmax, ymax],
            [xmin, ymax],
        ])

    def __call__(self, img, boxes):
        if len(boxes) == 0:
            return img, np.array([])

        h, w = img.shape[:2]
        if self.relative:
            boxes[:, 0] *= w
            boxes[:, 2] *= w
            boxes[:, 1] *= h
            boxes[:, 3] *= h
        return img, np.vstack([self.bbox2coords(_) for _ in boxes])


class CoordsToBoxes(object):
    def __init__(self, relative=True):
        self.relative = relative

    def coords2bbox(self, cds, w, h):
        xmin = np.clip(cds[:, 0].min(), 0, w - 1)
        xmax = np.clip(cds[:, 0].max(), 0, w - 1)
        ymin = np.clip(cds[:, 1].min(), 0, h - 1)
        ymax = np.clip(cds[:, 1].max(), 0, h - 1)
        return np.array([xmin, ymin, xmax, ymax])

    def __call__(self, img, cds):
        if len(cds) == 0:
            return img, np.array([])

        assert len(cds) % 4 == 0
        num = len(cds) // 4

        h, w = img.shape[:2]
        # regroup the flat (4*num, 2) vertex array into one (4, 2) block per
        # box, then fit an axis-aligned box around each block
        boxcds = np.split(np.array(cds), np.arange(1, num) * 4)
        boxes = np.array(
            [self.coords2bbox(_, w, h) for _ in boxcds], dtype=float)

        if self.relative:
            boxes[:, 0] /= float(w)
            boxes[:, 2] /= float(w)
            boxes[:, 1] /= float(h)
            boxes[:, 3] /= float(h)

        return img, boxes
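

# Example (illustrative sketch): boxes ride through geometric transforms as
# their four corner points. BoxesToCoords flattens (N, 4) boxes into (4N, 2)
# vertices and CoordsToBoxes re-fits axis-aligned boxes afterwards. `image`
# and `bboxes` are made-up pixel-coordinate placeholders; uncomment to run.
#
# image = np.zeros((300, 300, 3), dtype=np.uint8)
# bboxes = np.array([[30., 40., 120., 200.]])
#
# img, coords = BoxesToCoords(relative=False)(image, bboxes)
# img, coords = HorizontalFlip()(img, coords, flip=True)   # force the flip
# img, boxes = CoordsToBoxes(relative=False)(img, coords)
# # boxes == [[179., 40., 269., 200.]], i.e. x mirrored about w-1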


class OneHotMask(object):
    def __init__(self, n_classes):
        self.n_classes = n_classes

    def __call__(self, mask):
        if mask.ndim == 3 and mask.shape[2] == 1:
            mask = mask[:, :, 0]
        assert mask.ndim == 2 and mask.max() < self.n_classes

        onehot_mask = np.zeros(
            (mask.shape[0], mask.shape[1], self.n_classes), dtype=np.uint8)
        for i in range(self.n_classes):
            onehot_mask[:, :, i] = mask == i
        return onehot_mask


# if __name__ == '__main__':
#     from numpy.random import RandomState

#     PRNG = RandomState()

#     transform = Compose([
#         [ColorJitter(), None],
#         Merge(),
#         HorizontalFlip(),
#         RandomResize(1, 1.2),
#         Expand((1., 3), prob=0.8),
#         # ObjectRandomCrop(),
#         Resize(300),
#         Split([0, 3], [3, 6]),
#         # [RandomErasing(), None]
#         # [Cutout(size=150), None]
#         # [None, Squeeze()]
#     ], PRNG)

#     # image = skimage.data.astronaut()
#     image = cv2.imread('/home/qihang/datasets/VOCdevkit/VOC2007/JPEGImages/000030.jpg')[:, :, ::-1]

#     for _ in range(8):
#         # PRNG.seed(90)

#         img, original = transform([image, image])

#         img = np.hstack([img, original])

#         cv2.imshow('compare', img[:, :, ::-1])
#         cv2.waitKey(1000)
--------------------------------------------------------------------------------