├── Figures ├── bbox_ga_0.png ├── bbox_ga_1.png ├── color_0.png ├── color_1.png ├── etc_0.png ├── etc_1.png ├── ga_0.png ├── ga_1.png ├── mag_0.png ├── mag_1.png ├── policy_0.png └── policy_1.png ├── README.md ├── Visualize - Bounding Box Geometric Augmentation.ipynb ├── Visualize - Color Augmentation.ipynb ├── Visualize - Geometric Augmentation.ipynb ├── Visualize - Magnitude Check.ipynb ├── Visualize - Other Augmentation.ipynb ├── Visualize - Policy.ipynb ├── augmentation.py ├── dataset.py ├── functional.py └── policy.py /Figures/bbox_ga_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/bbox_ga_0.png -------------------------------------------------------------------------------- /Figures/bbox_ga_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/bbox_ga_1.png -------------------------------------------------------------------------------- /Figures/color_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/color_0.png -------------------------------------------------------------------------------- /Figures/color_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/color_1.png -------------------------------------------------------------------------------- /Figures/etc_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/etc_0.png -------------------------------------------------------------------------------- /Figures/etc_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/etc_1.png -------------------------------------------------------------------------------- /Figures/ga_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/ga_0.png -------------------------------------------------------------------------------- /Figures/ga_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/ga_1.png -------------------------------------------------------------------------------- /Figures/mag_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/mag_0.png -------------------------------------------------------------------------------- /Figures/mag_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/mag_1.png
--------------------------------------------------------------------------------
/Figures/policy_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/policy_0.png
--------------------------------------------------------------------------------
/Figures/policy_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jasonlee1995/AutoAugment_Detection/c09aceff631085d6992aff3297c0a803aa82b332/Figures/policy_1.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoAugment for Detection Implementation with Pytorch
2 | - Unofficial implementation of the paper *Learning Data Augmentation Strategies for Object Detection*
3 | 
4 | 
5 | ## 0. Development Environment
6 | ```
7 | Docker Image
8 | - pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel
9 | ```
10 | - Uses a single GPU
11 | 
12 | 
13 | ## 1. Implementation Details
14 | - augmentation.py : augmentation classes with apply probability included
15 | - dataset.py : COCO pytorch dataset
16 | - functional.py : augmentation functions used by the augmentation classes
17 | - policy.py : augmentation policies v0, v1, v2, v3, vtest
18 | - Visualize - Bounding Box Geometric Augmentation.ipynb : experiments on bounding box geometric augmentation
19 | - Visualize - Color Augmentation.ipynb : experiments on color augmentation
20 | - Visualize - Geometric Augmentation.ipynb : experiments on geometric augmentation
21 | - Visualize - Magnitude Check.ipynb : experiments for checking that the magnitudes are correct
22 | - Visualize - Other Augmentation.ipynb : experiments on the remaining augmentations
23 | - Visualize - Policy.ipynb : experiments on the policies
24 | - Details
25 |   * the magnitude ranges differ between the paper and the official code, so this repo follows the official code, not the paper
26 |   * some ranges are adjusted because they do not match the stated magnitudes
27 |   * the color operations (Color, Contrast, Brightness, Sharpness) use the range 0.1 ~ 1.9
28 |   * where 1 is the default (original image)
29 |   * so this repo implements them as follows
30 |   * instead of using 0.1 ~ 1.9 directly, use 0 ~ 0.9 with a random sign flip (0.5 probability)
31 |   * e.g.) if 0.9 is chosen, randomly negate it (0.9 or -0.9) and add 1, giving 1.9 or 0.1
32 |   * numpy and opencv are not used, both for speed and to avoid version conflicts
33 |   * the design pattern follows the torchvision transforms code
34 |   * some of the code could be improved but this is not addressed in this repo (e.g. TranslateX_Only_BBoxes could translate relative to the bbox size instead of by a fixed number of pixels)
35 | 
36 | 
37 | ## 2. Results
38 | #### 2.1. Color Augmentation
39 | ![Color Augmentation](./Figures/color_0.png)
40 | ![Color Augmentation](./Figures/color_1.png)
41 | 
42 | #### 2.2. Geometric Augmentation
43 | ![Geometric Augmentation](./Figures/ga_0.png)
44 | ![Geometric Augmentation](./Figures/ga_1.png)
45 | 
46 | #### 2.3. Bounding Box Augmentation
47 | ![Bounding Box Augmentation](./Figures/bbox_ga_0.png)
48 | ![Bounding Box Augmentation](./Figures/bbox_ga_1.png)
49 | 
50 | #### 2.4. Other Augmentation
51 | ![Other Augmentation](./Figures/etc_0.png)
52 | ![Other Augmentation](./Figures/etc_1.png)
53 | 
54 | #### 2.5. Policy
55 | ![Policy](./Figures/policy_0.png)
56 | ![Policy](./Figures/policy_1.png)
57 | 
58 | #### 2.6. Magnitude
59 | ![Magnitude](./Figures/mag_0.png)
60 | ![Magnitude](./Figures/mag_1.png)
61 | 
62 | 
63 | ## 3. Reference
64 | - Learning Data Augmentation Strategies for Object Detection [[paper]](https://arxiv.org/pdf/1906.11172.pdf) [[official code]](https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/autoaugment_utils.py)
65 | 
--------------------------------------------------------------------------------
/augmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Color Augmentation: AutoContrast, Brightness, Color, Contrast, Equalize, Posterize, Sharpness, Solarize, SolarizeAdd
3 | 
4 | Geometric Augmentation: Rotate_BBox, ShearX_BBox, ShearY_BBox, TranslateX_BBox, TranslateY_BBox
5 | 
6 | Mask Augmentation: Cutout
7 | 
8 | Color Augmentation based on BBoxes: Equalize_Only_BBoxes, Solarize_Only_BBoxes
9 | 
10 | Geometric Augmentation based on BBoxes: Rotate_Only_BBoxes, ShearX_Only_BBoxes, ShearY_Only_BBoxes,
11 |                                         TranslateX_Only_BBoxes, TranslateY_Only_BBoxes, Flip_Only_BBoxes
12 | 
13 | Mask Augmentation based on BBoxes: BBox_Cutout, Cutout_Only_BBoxes
14 | 
15 | """
16 | 
17 | 
18 | import torch, torchvision, functional
19 | import torchvision.transforms.functional as F
20 | 
21 | from PIL import Image, ImageOps
22 | 
23 | 
24 | ### Basic Augmentation
25 | class Compose:
26 |     """
27 |     Composes several transforms together.
28 |     """
29 |     def __init__(self, transforms):
30 |         self.transforms = transforms
31 | 
32 |     def __call__(self, image, bboxs):
33 |         for t in self.transforms:
34 |             image, bboxs = t(image, bboxs)
35 |         return image, bboxs
36 | 
37 | 
38 | class ToTensor:
39 |     """
40 |     Converts a PIL Image or numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
41 |     Only applied to image, not bboxes.
42 |     """
43 |     def __call__(self, image, bboxs):
44 |         return F.to_tensor(image), bboxs
45 | 
46 | 
47 | class Normalize(torch.nn.Module):
48 |     """
49 |     Normalize a tensor image with mean and standard deviation.
50 |     Only applied to image, not bboxes.
51 |     """
52 |     def __init__(self, mean, std, inplace=False):
53 |         super().__init__()
54 |         self.mean = mean
55 |         self.std = std
56 |         self.inplace = inplace
57 | 
58 |     def forward(self, image, bboxs):
59 |         return F.normalize(image, self.mean, self.std, self.inplace), bboxs
60 | 
61 | 
62 | ### Color Augmentation
63 | class AutoContrast(torch.nn.Module):
64 |     """
65 |     Autocontrast the pixels of the given image.
66 |     Only applied to image, not bboxes.
67 |     """
68 |     def __init__(self, p):
69 |         super().__init__()
70 |         self.p = p
71 | 
72 |     def forward(self, image, bboxs):
73 |         if torch.rand(1) < self.p:
74 |             autocontrast_image = ImageOps.autocontrast(image)
75 |             return autocontrast_image, bboxs
76 |         else:
77 |             return image, bboxs
78 | 
79 | 
80 | class Brightness(torch.nn.Module):
81 |     """
82 |     Adjust image brightness using magnitude.
83 |     Only applied to image, not bboxes.
84 | """ 85 | def __init__(self, p, magnitude, minus=True): 86 | super().__init__() 87 | self.p = p 88 | self.magnitude = magnitude 89 | self.minus = minus 90 | 91 | def forward(self, image, bboxs): 92 | if self.minus and (torch.rand(1) < 0.5): self.magnitude *= -1 93 | if torch.rand(1) < self.p: 94 | brightness_image = functional.brightness(image, 1+self.magnitude) 95 | return brightness_image, bboxs 96 | else: 97 | return image, bboxs 98 | 99 | 100 | class Color(torch.nn.Module): 101 | """ 102 | Adjust image color balance using magnitude. 103 | Only applied to image, not bboxes. 104 | """ 105 | def __init__(self, p, magnitude, minus=True): 106 | super().__init__() 107 | self.p = p 108 | self.magnitude = magnitude 109 | self.minus = minus 110 | 111 | def forward(self, image, bboxs): 112 | if self.minus and (torch.rand(1) < 0.5): self.magnitude *= -1 113 | if torch.rand(1) < self.p: 114 | color_image = functional.color(image, 1+self.magnitude) 115 | return color_image, bboxs 116 | else: 117 | return image, bboxs 118 | 119 | 120 | class Contrast(torch.nn.Module): 121 | """ 122 | Adjust image contrast using magnitude. 123 | Only applied to image, not bboxes. 124 | """ 125 | def __init__(self, p, magnitude, minus=True): 126 | super().__init__() 127 | self.p = p 128 | self.magnitude = magnitude 129 | self.minus = minus 130 | 131 | def forward(self, image, bboxs): 132 | if self.minus and (torch.rand(1) < 0.5): self.magnitude *= -1 133 | if torch.rand(1) < self.p: 134 | contrast_image = functional.contrast(image, 1+self.magnitude) 135 | return contrast_image, bboxs 136 | else: 137 | return image, bboxs 138 | 139 | 140 | class Equalize(torch.nn.Module): 141 | """ 142 | Equalize the histogram of the given image. 143 | Only applied to image, not bboxes. 144 | """ 145 | def __init__(self, p): 146 | super().__init__() 147 | self.p = p 148 | 149 | def forward(self, image, bboxs): 150 | if torch.rand(1) < self.p: 151 | equalize_image = ImageOps.equalize(image) 152 | return equalize_image, bboxs 153 | else: 154 | return image, bboxs 155 | 156 | 157 | class Posterize(torch.nn.Module): 158 | """ 159 | Posterize the image by reducing the number of bits for each color channel. 160 | Only applied to image, not bboxes. 161 | """ 162 | def __init__(self, p, bits): 163 | super().__init__() 164 | self.p = p 165 | self.bits = int(bits) 166 | 167 | def forward(self, image, bboxs): 168 | if torch.rand(1) < self.p: 169 | posterize_image = ImageOps.posterize(image, self.bits) 170 | return posterize_image, bboxs 171 | else: 172 | return image, bboxs 173 | 174 | 175 | class Sharpness(torch.nn.Module): 176 | """ 177 | Adjust image sharpness using magnitude. 178 | Only applied to image, not bboxes. 179 | """ 180 | def __init__(self, p, magnitude, minus=True): 181 | super().__init__() 182 | self.p = p 183 | self.magnitude = magnitude 184 | self.minus = minus 185 | 186 | def forward(self, image, bboxs): 187 | if self.minus and (torch.rand(1) < 0.5): self.magnitude *= -1 188 | if torch.rand(1) < self.p: 189 | sharpness_image = functional.sharpness(image, 1+self.magnitude) 190 | return sharpness_image, bboxs 191 | else: 192 | return image, bboxs 193 | 194 | 195 | class Solarize(torch.nn.Module): 196 | """ 197 | Solarize the image by inverting all pixel values above a threshold. 198 | Only applied to image, not bboxes. 
199 | """ 200 | def __init__(self, p, threshold): 201 | super().__init__() 202 | self.p = p 203 | self.threshold = int(threshold) 204 | 205 | def forward(self, image, bboxs): 206 | if torch.rand(1) < self.p: 207 | solarize_image = ImageOps.solarize(image, self.threshold) 208 | return solarize_image, bboxs 209 | else: 210 | return image, bboxs 211 | 212 | 213 | class SolarizeAdd(torch.nn.Module): 214 | """ 215 | Solarize the image by added image below a threshold. 216 | Add addition amount to image and then clip the pixel value to 0~255 or 0~1. 217 | Parameter addition must be integer. 218 | Only applied to image, not bboxes. 219 | """ 220 | def __init__(self, p, addition, threshold=128, minus=True): 221 | super().__init__() 222 | self.p = p 223 | self.addition = int(addition) 224 | self.threshold = int(threshold) 225 | self.minus = minus 226 | 227 | def forward(self, image, bboxs): 228 | if self.minus and (torch.rand(1) < 0.5): self.addition *= -1 229 | if torch.rand(1) < self.p: 230 | solarize_add_image = functional.solarize_add(image, self.addition, self.threshold) 231 | return solarize_add_image, bboxs 232 | else: 233 | return image, bboxs 234 | 235 | 236 | ### Geometric Augmentation 237 | class Rotate_BBox(torch.nn.Module): 238 | """ 239 | Rotate image by degrees and change bboxes according to rotated image. 240 | The pixel values filled in will be of the value replace. 241 | Assume the coords are given min_x, min_y, max_x, max_y. 242 | Both applied to image and bboxes. 243 | """ 244 | def __init__(self, p, degrees, replace=128, minus=True): 245 | super().__init__() 246 | self.p = p 247 | self.degrees = degrees 248 | self.replace = replace 249 | self.minus = minus 250 | 251 | def forward(self, image, bboxs): 252 | if self.minus and (torch.rand(1) < 0.5): self.degrees *= -1 253 | if torch.rand(1) < self.p: 254 | rotate_image = image.rotate(self.degrees, fillcolor=(self.replace, self.replace, self.replace)) 255 | if bboxs == None: 256 | return rotate_image, bboxs 257 | else: 258 | rotate_bbox = functional._rotate_bbox(image, bboxs, self.degrees) 259 | return rotate_image, rotate_bbox 260 | else: 261 | return image, bboxs 262 | 263 | 264 | class ShearX_BBox(torch.nn.Module): 265 | """ 266 | Shear image and change bboxes on X-axis. 267 | The pixel values filled in will be of the value replace. 268 | Level is usually between -0.3~0.3. 269 | Assume the coords are given min_x, min_y, max_x, max_y. 270 | Both applied to image and bboxes. 271 | """ 272 | def __init__(self, p, level, replace=128, minus=True): 273 | super().__init__() 274 | self.p = p 275 | self.level = level 276 | self.replace = replace 277 | self.minus = minus 278 | 279 | def forward(self, image, bboxs): 280 | if self.minus and (torch.rand(1) < 0.5): self.level *= -1 281 | if torch.rand(1) < self.p: 282 | shear_image = image.transform(image.size, Image.AFFINE, (1, self.level, 0, 0, 1, 0), fillcolor=(self.replace, self.replace, self.replace)) 283 | if bboxs == None: 284 | return shear_image, bboxs 285 | else: 286 | shear_bbox = functional.shear_with_bboxes(image, bboxs, self.level, self.replace, shift_horizontal=True) 287 | return shear_image, shear_bbox 288 | else: 289 | return image, bboxs 290 | 291 | 292 | class ShearY_BBox(torch.nn.Module): 293 | """ 294 | Shear image and change bboxes on Y-axis. 295 | The pixel values filled in will be of the value replace. 296 | Level is usually between -0.3~0.3. 297 | Assume the coords are given min_x, min_y, max_x, max_y. 298 | Both applied to image and bboxes. 
299 | """ 300 | def __init__(self, p, level, replace=128, minus=True): 301 | super().__init__() 302 | self.p = p 303 | self.level = level 304 | self.replace = replace 305 | self.minus = minus 306 | 307 | def forward(self, image, bboxs): 308 | if self.minus and (torch.rand(1) < 0.5): self.level *= -1 309 | if torch.rand(1) < self.p: 310 | shear_image = image.transform(image.size, Image.AFFINE, (1, 0, 0, self.level, 1, 0), fillcolor=(self.replace, self.replace, self.replace)) 311 | if bboxs == None: 312 | return shear_image, bboxs 313 | else: 314 | shear_bbox = functional.shear_with_bboxes(image, bboxs, self.level, self.replace, shift_horizontal=False) 315 | return shear_image, shear_bbox 316 | else: 317 | return image, bboxs 318 | 319 | 320 | class TranslateX_BBox(torch.nn.Module): 321 | """ 322 | Translate image and bboxes on X-axis. 323 | The pixel values filled in will be of the value replace. 324 | Assume the coords are given min_x, min_y, max_x, max_y. 325 | Both applied to image and bboxes. 326 | """ 327 | def __init__(self, p, pixels, replace=128, minus=True): 328 | super().__init__() 329 | self.p = p 330 | self.pixels = int(pixels) 331 | self.replace = replace 332 | self.minus = minus 333 | 334 | def forward(self, image, bboxs): 335 | if self.minus and (torch.rand(1) < 0.5): self.pixels *= -1 336 | if torch.rand(1) < self.p: 337 | translate_image = image.transform(image.size, Image.AFFINE, (1, 0, -self.pixels, 0, 1, 0), fillcolor=(self.replace, self.replace, self.replace)) 338 | if bboxs == None: 339 | return translate_image, bboxs 340 | else: 341 | translate_bbox = functional.translate_bbox(image, bboxs, self.pixels, self.replace, shift_horizontal=True) 342 | return translate_image, translate_bbox 343 | else: 344 | return image, bboxs 345 | 346 | 347 | class TranslateY_BBox(torch.nn.Module): 348 | """ 349 | Translate image and bboxes on Y-axis. 350 | The pixel values filled in will be of the value replace. 351 | Assume the coords are given min_x, min_y, max_x, max_y. 352 | Both applied to image and bboxes. 353 | """ 354 | def __init__(self, p, pixels, replace=128, minus=True): 355 | super().__init__() 356 | self.p = p 357 | self.pixels = int(pixels) 358 | self.replace = replace 359 | self.minus = minus 360 | 361 | def forward(self, image, bboxs): 362 | if self.minus and (torch.rand(1) < 0.5): self.pixels *= -1 363 | if torch.rand(1) < self.p: 364 | translate_image = image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, -self.pixels), fillcolor=(self.replace, self.replace, self.replace)) 365 | if bboxs == None: 366 | return translate_image, bboxs 367 | else: 368 | translate_bbox = functional.translate_bbox(image, bboxs, self.pixels, self.replace, shift_horizontal=False) 369 | return translate_image, translate_bbox 370 | else: 371 | return image, bboxs 372 | 373 | 374 | ### Mask Augmentation 375 | class Cutout(torch.nn.Module): 376 | """ 377 | Apply cutout (https://arxiv.org/abs/1708.04552) to the image. 378 | This operation applies a (2*pad_size, 2*pad_size) mask of zeros to a random location within image. 379 | The pixel values filled in will be of the value replace. 380 | Only applied to image, not bboxes. 
381 | """ 382 | def __init__(self, p, pad_size, replace=128): 383 | super().__init__() 384 | self.p = p 385 | self.pad_size = int(pad_size) 386 | self.replace = replace 387 | 388 | def forward(self, image, bboxs): 389 | if torch.rand(1) < self.p: 390 | cutout_image = functional.cutout(image, self.pad_size, self.replace) 391 | return cutout_image, bboxs 392 | else: 393 | return image, bboxs 394 | 395 | 396 | ### Color Augmentation based on BBoxes 397 | class Equalize_Only_BBoxes(torch.nn.Module): 398 | """ 399 | Apply equalize to each bboxes in the image with probability. 400 | Assume the coords are given min_x, min_y, max_x, max_y. 401 | Only applied to image not bboxes. 402 | """ 403 | def __init__(self, p): 404 | super().__init__() 405 | self.p = p/3 406 | 407 | def forward(self, image, bboxs): 408 | if bboxs == None: 409 | return image, bboxs 410 | else: 411 | equalize_image = functional.equalize_only_bboxes(image, bboxs, self.p) 412 | return equalize_image, bboxs 413 | 414 | 415 | class Solarize_Only_BBoxes(torch.nn.Module): 416 | """ 417 | Apply solarize to each bboxes in the image with probability. 418 | Assume the coords are given min_x, min_y, max_x, max_y. 419 | Only applied to image not bboxes. 420 | """ 421 | def __init__(self, p, threshold): 422 | super().__init__() 423 | self.p = p/3 424 | self.threshold = int(threshold) 425 | 426 | def forward(self, image, bboxs): 427 | if bboxs == None: 428 | return image, bboxs 429 | else: 430 | solarize_image = functional.solarize_only_bboxes(image, bboxs, self.p, self.threshold) 431 | return solarize_image, bboxs 432 | 433 | 434 | ### Geometric Augmentation based on BBoxes 435 | class Rotate_Only_BBoxes(torch.nn.Module): 436 | """ 437 | Apply rotation to each bboxes in the image with probability. 438 | Assume the coords are given min_x, min_y, max_x, max_y. 439 | Only applied to image not bboxes. 440 | """ 441 | def __init__(self, p, degrees, replace=128, minus=True): 442 | super().__init__() 443 | self.p = p/3 444 | self.degrees = degrees 445 | self.replace = replace 446 | self.minus = minus 447 | 448 | def forward(self, image, bboxs): 449 | if self.minus and (torch.rand(1) < 0.5): self.degrees *= -1 450 | if bboxs == None: 451 | return image, bboxs 452 | else: 453 | rotate_image = functional.rotate_only_bboxes(image, bboxs, self.p, self.degrees, self.replace) 454 | return rotate_image, bboxs 455 | 456 | 457 | class ShearX_Only_BBoxes(torch.nn.Module): 458 | """ 459 | Apply shear to each bboxes in the image with probability only on X-axis. 460 | Assume the coords are given min_x, min_y, max_x, max_y. 461 | Only applied to image not bboxes. 462 | """ 463 | def __init__(self, p, level, replace=128, minus=True): 464 | super().__init__() 465 | self.p = p/3 466 | self.level = level 467 | self.replace = replace 468 | self.minus = minus 469 | 470 | def forward(self, image, bboxs): 471 | if self.minus and (torch.rand(1) < 0.5): self.level *= -1 472 | if bboxs == None: 473 | return image, bboxs 474 | else: 475 | shear_image = functional.shear_only_bboxes(image, bboxs, self.p, self.level, self.replace, shift_horizontal=True) 476 | return shear_image, bboxs 477 | 478 | 479 | class ShearY_Only_BBoxes(torch.nn.Module): 480 | """ 481 | Apply shear to each bboxes in the image with probability only on Y-axis. 482 | Assume the coords are given min_x, min_y, max_x, max_y. 483 | Only applied to image not bboxes. 
484 | """ 485 | def __init__(self, p, level, replace=128, minus=True): 486 | super().__init__() 487 | self.p = p/3 488 | self.level = level 489 | self.replace = replace 490 | self.minus = minus 491 | 492 | def forward(self, image, bboxs): 493 | if self.minus and (torch.rand(1) < 0.5): self.level *= -1 494 | if bboxs == None: 495 | return image, bboxs 496 | else: 497 | shear_image = functional.shear_only_bboxes(image, bboxs, self.p, self.level, self.replace, shift_horizontal=False) 498 | return shear_image, bboxs 499 | 500 | 501 | class TranslateX_Only_BBoxes(torch.nn.Module): 502 | """ 503 | Apply translation to each bboxes in the image with probability only on X-axis. 504 | Assume the coords are given min_x, min_y, max_x, max_y. 505 | Only applied to image not bboxes. 506 | """ 507 | def __init__(self, p, pixels, replace=128, minus=True): 508 | super().__init__() 509 | self.p = p/3 510 | self.pixels = int(pixels) 511 | self.replace = replace 512 | self.minus = minus 513 | 514 | def forward(self, image, bboxs): 515 | if self.minus and (torch.rand(1) < 0.5): self.pixels *= -1 516 | if bboxs == None: 517 | return image, bboxs 518 | else: 519 | translate_image = functional.translate_only_bboxes(image, bboxs, self.p, self.pixels, self.replace, shift_horizontal=True) 520 | return translate_image, bboxs 521 | 522 | 523 | class TranslateY_Only_BBoxes(torch.nn.Module): 524 | """ 525 | Apply transloation to each bboxes in the image with probability only on Y-axis. 526 | Assume the coords are given min_x, min_y, max_x, max_y. 527 | Only applied to image not bboxes. 528 | """ 529 | def __init__(self, p, pixels, replace=128, minus=True): 530 | super().__init__() 531 | self.p = p/3 532 | self.pixels = int(pixels) 533 | self.replace = replace 534 | self.minus = minus 535 | 536 | def forward(self, image, bboxs): 537 | if self.minus and (torch.rand(1) < 0.5): self.pixels *= -1 538 | if bboxs == None: 539 | return image, bboxs 540 | else: 541 | translate_image = functional.translate_only_bboxes(image, bboxs, self.p, self.pixels, self.replace, shift_horizontal=False) 542 | return translate_image, bboxs 543 | 544 | 545 | class Flip_Only_BBoxes(torch.nn.Module): 546 | """ 547 | Apply horizontal flip to each bboxes in the image with probability. 548 | Assume the coords are given min_x, min_y, max_x, max_y. 549 | Only applied to image not bboxes. 550 | """ 551 | def __init__(self, p): 552 | super().__init__() 553 | self.p = p/3 554 | 555 | def forward(self, image, bboxs): 556 | if bboxs == None: 557 | return image, bboxs 558 | else: 559 | flip_image = functional.flip_only_bboxes(image, bboxs, self.p) 560 | return flip_image, bboxs 561 | 562 | 563 | ### Mask Augmentation based on BBoxes 564 | class BBox_Cutout(torch.nn.Module): 565 | """ 566 | Apply cutout to the image according to bbox information. 567 | Assume the coords are given min_x, min_y, max_x, max_y. 568 | Only applied to image, not bboxes. 569 | """ 570 | def __init__(self, p, pad_fraction, replace_with_mean=False): 571 | super().__init__() 572 | self.p = p 573 | self.pad_fraction = pad_fraction 574 | self.replace_with_mean = replace_with_mean 575 | 576 | def forward(self, image, bboxs): 577 | if (torch.rand(1) < self.p) and (bboxs != None): 578 | cutout_image = functional.bbox_cutout(image, bboxs, self.pad_fraction, self.replace_with_mean) 579 | return cutout_image, bboxs 580 | else: 581 | return image, bboxs 582 | 583 | 584 | class Cutout_Only_BBoxes(torch.nn.Module): 585 | """ 586 | Apply cutout to each bboxes in the image with probability. 
587 |     Assume the coords are given min_x, min_y, max_x, max_y.
588 |     Only applied to image not bboxes.
589 |     """
590 |     def __init__(self, p, pad_size, replace=128):
591 |         super().__init__()
592 |         self.p = p/3
593 |         self.pad_size = int(pad_size)
594 |         self.replace = replace
595 | 
596 |     def forward(self, image, bboxs):
597 |         if bboxs == None:
598 |             return image, bboxs
599 |         else:
600 |             cutout_image = functional.cutout_only_bboxes(image, bboxs, self.p, self.pad_size, self.replace)
601 |             return cutout_image, bboxs
602 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os, torch, torchvision
2 | from PIL import Image
3 | from pycocotools.coco import COCO
4 | from torch.utils.data import Dataset
5 | 
6 | 
7 | class COCO_detection(Dataset):
8 |     def __init__(self, img_dir, ann, transforms=None):
9 |         super(COCO_detection, self).__init__()
10 |         self.img_dir = img_dir
11 |         self.transforms = transforms
12 |         self.coco = COCO(ann)
13 |         self.ids = list(sorted(self.coco.imgs.keys()))
14 |         self.label_map = {raw_label:i for i, raw_label in enumerate(self.coco.getCatIds())}
15 | 
16 |     def _load_image(self, id_):
17 |         img = self.coco.loadImgs(id_)[0]['file_name']
18 |         return Image.open(os.path.join(self.img_dir, img)).convert('RGB')
19 | 
20 |     def _load_target(self, id_):
21 |         if len(self.coco.loadAnns(self.coco.getAnnIds(id_))) == 0: return None, None
22 |         bboxs, labels = [], []
23 |         for ann in self.coco.loadAnns(self.coco.getAnnIds(id_)):
24 |             min_x, min_y, w, h = ann['bbox']
25 |             bboxs.append(torch.FloatTensor([min_x, min_y, min_x+w, min_y+h]))
26 |             labels.append(self.label_map[ann['category_id']])
27 |         bboxs, labels = torch.stack(bboxs, 0), torch.LongTensor(labels)
28 |         return bboxs, labels
29 | 
30 |     def __getitem__(self, index):
31 |         id_ = self.ids[index]
32 |         image, (bboxs, labels) = self._load_image(id_), self._load_target(id_)
33 |         if self.transforms is not None:
34 |             image, bboxs = self.transforms(image, bboxs)
35 | 
36 |         return image, bboxs, labels
37 | 
38 |     def __len__(self):
39 |         return len(self.ids)
40 | 
41 | 
42 | class COCO_detection_raw(Dataset):
43 |     def __init__(self, img_dir, ann, transforms=None):
44 |         super(COCO_detection_raw, self).__init__()
45 |         self.img_dir = img_dir
46 |         self.transforms = transforms
47 |         self.coco = COCO(ann)
48 |         self.ids = list(sorted(self.coco.imgs.keys()))
49 | 
50 |     def _load_image(self, id_):
51 |         img = self.coco.loadImgs(id_)[0]['file_name']
52 |         return Image.open(os.path.join(self.img_dir, img)).convert('RGB')
53 | 
54 |     def _load_target(self, id_):
55 |         return self.coco.loadAnns(self.coco.getAnnIds(id_))
56 | 
57 |     def __getitem__(self, index):
58 |         id_ = self.ids[index]
59 |         image, target = self._load_image(id_), self._load_target(id_)
60 |         if self.transforms is not None: image, target = self.transforms(image, target)
61 | 
62 |         return image, target
63 | 
64 |     def __len__(self):
65 |         return len(self.ids)
--------------------------------------------------------------------------------
/functional.py:
--------------------------------------------------------------------------------
1 | import math, torch, torchvision
2 | import torchvision.transforms.functional as F
3 | from PIL import Image, ImageEnhance, ImageOps
4 | 
5 | 
6 | def solarize_add(img, addition, threshold):
7 |     img = F.pil_to_tensor(img)
8 |     added_img = img + addition
9 |     added_img = torch.clamp(added_img, 0, 255)
10 |     return F.to_pil_image(torch.where(img < threshold,
added_img, img)) 11 | 12 | 13 | def color(img, magnitude): 14 | return ImageEnhance.Color(img).enhance(magnitude) 15 | 16 | 17 | def contrast(img, magnitude): 18 | return ImageEnhance.Contrast(img).enhance(magnitude) 19 | 20 | 21 | def brightness(img, magnitude): 22 | return ImageEnhance.Brightness(img).enhance(magnitude) 23 | 24 | 25 | def sharpness(img, magnitude): 26 | return ImageEnhance.Sharpness(img).enhance(magnitude) 27 | 28 | 29 | def cutout(img, pad_size, replace): 30 | img = F.pil_to_tensor(img) 31 | _, h, w = img.shape 32 | center_h, center_w = torch.randint(high=h, size=(1,)), torch.randint(high=w, size=(1,)) 33 | low_h, high_h = torch.clamp(center_h-pad_size, 0, h).item(), torch.clamp(center_h+pad_size, 0, h).item() 34 | low_w, high_w = torch.clamp(center_w-pad_size, 0, w).item(), torch.clamp(center_w+pad_size, 0, w).item() 35 | cutout_img = img.clone() 36 | cutout_img[:, low_h:high_h, low_w:high_w] = replace 37 | return F.to_pil_image(cutout_img) 38 | 39 | 40 | def bbox_cutout(img, bboxs, pad_fraction, replace_with_mean): 41 | img = F.pil_to_tensor(img) 42 | _, h, w = img.shape 43 | random_index = torch.randint(bboxs.size(0), size=(1,)).item() 44 | chosen_bbox = bboxs[random_index] 45 | min_x, min_y, max_x, max_y = chosen_bbox 46 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 47 | 48 | if (min_x == max_x) or (min_y == max_y): return F.to_pil_image(img) 49 | 50 | mask_x, mask_y = torch.randint(low=min_x, high=max_x, size=(1,)), torch.randint(low=min_y, high=max_y, size=(1,)) 51 | mask_w, mask_h = pad_fraction * w / 2, pad_fraction * h / 2 52 | 53 | x_min, x_max = int(torch.clamp(mask_x-mask_w, 0, w).item()), int(torch.clamp(mask_x+mask_w, 0, w).item()) 54 | y_min, y_max = int(torch.clamp(mask_y-mask_h, 0, h).item()), int(torch.clamp(mask_y+mask_h, 0, h).item()) 55 | 56 | if replace_with_mean == True: replace = torch.mean(img[:, min_y:max_y, min_x:max_x]).item() 57 | else: replace = 128 58 | 59 | cutout_img = img.clone() 60 | cutout_img[:, y_min:y_max, x_min:x_max] = replace 61 | return F.to_pil_image(cutout_img) 62 | 63 | 64 | def _rotate_bbox(img, bboxs, degrees): 65 | img = F.pil_to_tensor(img) 66 | _, h, w = img.shape 67 | 68 | rotate_bboxs = [] 69 | rotate_matrix = torch.FloatTensor([[math.cos(degrees*math.pi/180), math.sin(degrees*math.pi/180)], 70 | [-math.sin(degrees*math.pi/180), math.cos(degrees*math.pi/180)]]) 71 | for bbox in bboxs: 72 | min_x, min_y, max_x, max_y = bbox 73 | rel_min_x, rel_max_x, rel_min_y, rel_max_y = min_x-w/2, max_x-w/2, min_y-h/2, max_y-h/2 74 | coords = torch.FloatTensor([[rel_min_x, rel_min_y], 75 | [rel_min_x, rel_max_y], 76 | [rel_max_x, rel_max_y], 77 | [rel_max_x, rel_min_y]]) 78 | rotate_coords = torch.matmul(rotate_matrix, coords.t()).t() 79 | x_min, y_min = torch.min(rotate_coords, dim=0)[0] 80 | x_max, y_max = torch.max(rotate_coords, dim=0)[0] 81 | 82 | rotate_min_x, rotate_max_x = torch.clamp(x_min+w/2, 0, w),torch.clamp(x_max+w/2, 0, w) 83 | rotate_min_y, rotate_max_y = torch.clamp(y_min+h/2, 0, h),torch.clamp(y_max+h/2, 0, h) 84 | rotate_bboxs.append(torch.FloatTensor([rotate_min_x, rotate_min_y, rotate_max_x, rotate_max_y])) 85 | return torch.stack(rotate_bboxs) 86 | 87 | 88 | def translate_bbox(img, bboxs, pixels, replace, shift_horizontal): 89 | img = F.pil_to_tensor(img) 90 | _, h, w = img.shape 91 | 92 | translate_bboxs = [] 93 | if shift_horizontal: 94 | for bbox in bboxs: 95 | min_x, min_y, max_x, max_y = bbox 96 | translate_min_x, translate_max_x = 
torch.clamp(min_x+pixels, 0, w), torch.clamp(max_x+pixels, 0, w) 97 | translate_min_x, translate_max_x = int(translate_min_x.item()), int(translate_max_x.item()) 98 | translate_bboxs.append(torch.FloatTensor([translate_min_x, min_y, translate_max_x, max_y])) 99 | else: 100 | for bbox in bboxs: 101 | min_x, min_y, max_x, max_y = bbox 102 | translate_min_y, translate_max_y = torch.clamp(min_y+pixels, 0, h), torch.clamp(max_y+pixels, 0, h) 103 | translate_min_y, translate_max_y = int(translate_min_y.item()), int(translate_max_y.item()) 104 | translate_bboxs.append(torch.FloatTensor([min_x, translate_min_y, max_x, translate_max_y])) 105 | return torch.stack(translate_bboxs) 106 | 107 | 108 | def shear_with_bboxes(img, bboxs, level, replace, shift_horizontal): 109 | img = F.pil_to_tensor(img) 110 | _, h, w = img.shape 111 | 112 | shear_bboxs = [] 113 | if shift_horizontal: 114 | shear_matrix = torch.FloatTensor([[1, -level], 115 | [0, 1]]) 116 | for bbox in bboxs: 117 | min_x, min_y, max_x, max_y = bbox 118 | coords = torch.FloatTensor([[min_x, min_y], 119 | [min_x, max_y], 120 | [max_x, max_y], 121 | [max_x, min_y]]) 122 | shear_coords = torch.matmul(shear_matrix, coords.t()).t() 123 | x_min, y_min = torch.min(shear_coords, dim=0)[0] 124 | x_max, y_max = torch.max(shear_coords, dim=0)[0] 125 | shear_min_x, shear_max_x = torch.clamp(x_min, 0, w), torch.clamp(x_max, 0, w) 126 | shear_min_y, shear_max_y = torch.clamp(y_min, 0, h), torch.clamp(y_max, 0, h) 127 | shear_bboxs.append(torch.FloatTensor([shear_min_x, shear_min_y, shear_max_x, shear_max_y])) 128 | else: 129 | shear_matrix = torch.FloatTensor([[1, 0], 130 | [-level, 1]]) 131 | for bbox in bboxs: 132 | min_x, min_y, max_x, max_y = bbox 133 | coords = torch.FloatTensor([[min_x, min_y], 134 | [min_x, max_y], 135 | [max_x, max_y], 136 | [max_x, min_y]]) 137 | shear_coords = torch.matmul(shear_matrix, coords.t()).t() 138 | x_min, y_min = torch.min(shear_coords, dim=0)[0] 139 | x_max, y_max = torch.max(shear_coords, dim=0)[0] 140 | shear_min_x, shear_max_x = torch.clamp(x_min, 0, w), torch.clamp(x_max, 0, w) 141 | shear_min_y, shear_max_y = torch.clamp(y_min, 0, h), torch.clamp(y_max, 0, h) 142 | shear_bboxs.append(torch.FloatTensor([shear_min_x, shear_min_y, shear_max_x, shear_max_y])) 143 | return torch.stack(shear_bboxs) 144 | 145 | 146 | def rotate_only_bboxes(img, bboxs, p, degrees, replace): 147 | img = F.pil_to_tensor(img) 148 | rotate_img = torch.zeros_like(img) 149 | 150 | for bbox in bboxs: 151 | if torch.rand(1) < p: 152 | min_x, min_y, max_x, max_y = bbox 153 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 154 | bbox_rotate_img = F.to_pil_image(img[:, min_y:max_y+1, min_x:max_x+1]).rotate(degrees, fillcolor=(replace,replace,replace)) 155 | rotate_img[:, min_y:max_y+1, min_x:max_x+1] = F.pil_to_tensor(bbox_rotate_img) 156 | return F.to_pil_image(torch.where(rotate_img != 0, rotate_img, img)) 157 | 158 | 159 | def shear_only_bboxes(img, bboxs, p, level, replace, shift_horizontal): 160 | img = F.pil_to_tensor(img) 161 | shear_img = torch.zeros_like(img) 162 | 163 | for bbox in bboxs: 164 | if torch.rand(1) < p: 165 | min_x, min_y, max_x, max_y = bbox 166 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 167 | 168 | bbox_shear_img = F.to_pil_image(img[:, min_y:max_y+1, min_x:max_x+1]) 169 | if shift_horizontal: 170 | bbox_shear_img = bbox_shear_img.transform(bbox_shear_img.size, Image.AFFINE, (1,level,0,0,1,0), 
fillcolor=(replace,replace,replace)) 171 | else: 172 | bbox_shear_img = bbox_shear_img.transform(bbox_shear_img.size, Image.AFFINE, (1,0,0,level,1,0), fillcolor=(replace,replace,replace)) 173 | shear_img[:, min_y:max_y+1, min_x:max_x+1] = F.pil_to_tensor(bbox_shear_img) 174 | 175 | return F.to_pil_image(torch.where(shear_img != 0, shear_img, img)) 176 | 177 | 178 | def translate_only_bboxes(img, bboxs, p, pixels, replace, shift_horizontal): 179 | img = F.pil_to_tensor(img) 180 | translate_img = torch.zeros_like(img) 181 | 182 | for bbox in bboxs: 183 | if torch.rand(1) < p: 184 | min_x, min_y, max_x, max_y = bbox 185 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 186 | 187 | bbox_tran_img = F.to_pil_image(img[:, min_y:max_y+1, min_x:max_x+1]) 188 | if shift_horizontal: 189 | bbox_tran_img = bbox_tran_img.transform(bbox_tran_img.size, Image.AFFINE, (1,0,-pixels,0,1,0), fillcolor=(replace,replace,replace)) 190 | else: 191 | bbox_tran_img = bbox_tran_img.transform(bbox_tran_img.size, Image.AFFINE, (1,0,0,0,1,-pixels), fillcolor=(replace,replace,replace)) 192 | translate_img[:, min_y:max_y+1, min_x:max_x+1] = F.pil_to_tensor(bbox_tran_img) 193 | 194 | return F.to_pil_image(torch.where(translate_img != 0, translate_img, img)) 195 | 196 | 197 | def flip_only_bboxes(img, bboxs, p): 198 | img = F.pil_to_tensor(img) 199 | flip_img = torch.zeros_like(img) 200 | 201 | for bbox in bboxs: 202 | if torch.rand(1) < p: 203 | min_x, min_y, max_x, max_y = bbox 204 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 205 | flip_img[:, min_y:max_y+1, min_x:max_x+1] = F.hflip(img[:, min_y:max_y+1, min_x:max_x+1]) 206 | 207 | return F.to_pil_image(torch.where(flip_img != 0, flip_img, img)) 208 | 209 | 210 | def solarize_only_bboxes(img, bboxs, p, threshold): 211 | img = F.pil_to_tensor(img) 212 | for bbox in bboxs: 213 | if torch.rand(1) < p: 214 | min_x, min_y, max_x, max_y = bbox 215 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 216 | solarize_img = img[:, min_y:max_y+1, min_x:max_x+1] 217 | solarize_img = F.to_pil_image(solarize_img) 218 | solarize_img = ImageOps.solarize(solarize_img, threshold=threshold) 219 | solarize_img = F.pil_to_tensor(solarize_img) 220 | img[:, min_y:max_y+1, min_x:max_x+1] = solarize_img 221 | return F.to_pil_image(img) 222 | 223 | 224 | def equalize_only_bboxes(img, bboxs, p): 225 | img = F.pil_to_tensor(img) 226 | for bbox in bboxs: 227 | if torch.rand(1) < p: 228 | min_x, min_y, max_x, max_y = bbox 229 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 230 | equalize_img = img[:, min_y:max_y+1, min_x:max_x+1] 231 | equalize_img = F.to_pil_image(equalize_img) 232 | equalize_img = ImageOps.equalize(equalize_img) 233 | equalize_img = F.pil_to_tensor(equalize_img) 234 | img[:, min_y:max_y+1, min_x:max_x+1] = equalize_img 235 | return F.to_pil_image(img) 236 | 237 | 238 | def cutout_only_bboxes(img, bboxs, p, pad_size, replace): 239 | img = F.pil_to_tensor(img) 240 | cutout_img = img.clone() 241 | 242 | for bbox in bboxs: 243 | if torch.rand(1) < p: 244 | min_x, min_y, max_x, max_y = bbox 245 | min_x, min_y, max_x, max_y = int(min_x.item()), int(min_y.item()), int(max_x.item()), int(max_y.item()) 246 | 247 | cutout_x, cutout_y = torch.randint(low=min_x, high=max_x, size=(1,)), torch.randint(low=min_y, high=max_y, size=(1,)) 248 | 249 | y_min, y_max = 
int(torch.clamp(cutout_y-pad_size, min_y, max_y).item()), int(torch.clamp(cutout_y+pad_size, min_y, max_y).item()) 250 | x_min, x_max = int(torch.clamp(cutout_x-pad_size, min_x, max_x).item()), int(torch.clamp(cutout_x+pad_size, min_x, max_x).item()) 251 | 252 | cutout_img[:, y_min:y_max, x_min:x_max] = replace 253 | 254 | return F.to_pil_image(cutout_img) -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | import torch, random 2 | from augmentation import * 3 | 4 | 5 | M = 10 6 | 7 | color_range = torch.arange(0, 0.9+1e-8, (0.9-0)/M).tolist() 8 | rotate_range = torch.arange(0, 30+1e-8, (30-0)/M).tolist() 9 | shear_range = torch.arange(0, 0.3+1e-8, (0.3-0)/M).tolist() 10 | translate_range = torch.arange(0, 250+1e-8, (250-0)/M).tolist() 11 | translate_bbox_range = torch.arange(0, 120+1e-8, (120-0)/M).tolist() 12 | 13 | 14 | Mag = {'Brightness' : color_range, 'Color' : color_range, 'Contrast' : color_range, 15 | 'Posterize' : torch.arange(4, 8+1e-8, (8-4)/M).tolist()[::-1], 'Sharpness' : color_range, 16 | 'Solarize' : torch.arange(0, 256+1e-8, (256-0)/M).tolist()[::-1], 'SolarizeAdd' : torch.arange(0, 110+1e-8, (110-0)/M).tolist(), 17 | 18 | 'Cutout' : torch.arange(0, 100+1e-8, (100-0)/M).tolist(), 19 | 20 | 'Rotate_BBox' : rotate_range, 'ShearX_BBox' : shear_range, 'ShearY_BBox' : shear_range, 21 | 'TranslateX_BBox' : translate_range, 'TranslateY_BBox' : translate_range, 22 | 23 | 'Rotate_Only_BBoxes' : rotate_range, 'ShearX_Only_BBoxes' : shear_range, 'ShearY_Only_BBoxes' : shear_range, 24 | 'TranslateX_Only_BBoxes' : translate_bbox_range, 'TranslateY_Only_BBoxes' : translate_bbox_range, 25 | 26 | 'Solarize_Only_BBoxes' : torch.arange(0, 256+1e-8, (256-0)/M).tolist()[::-1], 27 | 28 | 'BBox_Cutout' : torch.arange(0, 0.75+1e-8, (0.75-0)/M).tolist(), 'Cutout_Only_BBoxes' : torch.arange(0, 50+1e-8, (50-0)/M).tolist() 29 | } 30 | 31 | 32 | Fun = {'AutoContrast' : AutoContrast, 'Brightness' : Brightness, 'Color' : Color, 'Contrast' : Contrast, 'Equalize' : Equalize, 33 | 'Posterize' : Posterize, 'Sharpness' : Sharpness, 'Solarize' : Solarize, 'SolarizeAdd' : SolarizeAdd, 34 | 35 | 'Cutout' : Cutout, 36 | 37 | 'Rotate_BBox' : Rotate_BBox, 'ShearX_BBox' : ShearX_BBox, 'ShearY_BBox' : ShearY_BBox, 38 | 'TranslateX_BBox' : TranslateX_BBox, 'TranslateY_BBox' : TranslateY_BBox, 39 | 40 | 'Rotate_Only_BBoxes' : Rotate_Only_BBoxes, 'ShearX_Only_BBoxes' : ShearX_Only_BBoxes, 'ShearY_Only_BBoxes' : ShearY_Only_BBoxes, 41 | 'TranslateX_Only_BBoxes' : TranslateX_Only_BBoxes, 'TranslateY_Only_BBoxes' : TranslateY_Only_BBoxes, 'Flip_Only_BBoxes' : Flip_Only_BBoxes, 42 | 43 | 'Equalize_Only_BBoxes' : Equalize_Only_BBoxes, 'Solarize_Only_BBoxes' : Solarize_Only_BBoxes, 44 | 45 | 'BBox_Cutout' : BBox_Cutout, 'Cutout_Only_BBoxes' : Cutout_Only_BBoxes 46 | } 47 | 48 | 49 | class Policy(torch.nn.Module): 50 | def __init__(self, policy, pre_transform, post_transform): 51 | super().__init__() 52 | self.pre_transform = pre_transform 53 | self.post_transform = post_transform 54 | 55 | if policy == 'policy_v0': self.policy = policy_v0() 56 | elif policy == 'policy_v1': self.policy = policy_v1() 57 | elif policy == 'policy_v2': self.policy = policy_v2() 58 | elif policy == 'policy_v3': self.policy = policy_v3() 59 | elif policy == 'policy_vtest': self.policy = policy_vtest() 60 | 61 | def forward(self, image, bboxs): 62 | policy_idx = random.randint(0, len(self.policy)-1) 63 | policy_transform = 
self.pre_transform + self.policy[policy_idx] + self.post_transform 64 | policy_transform = Compose(policy_transform) 65 | image, bboxs = policy_transform(image, bboxs) 66 | return image, bboxs 67 | 68 | 69 | def SubPolicy(f1, p1, m1, f2, p2, m2): 70 | subpolicy = [] 71 | if f1 in ['AutoContrast', 'Equalize', 'Equalize_Only_BBoxes', 'Flip_Only_BBoxes']: subpolicy.append(Fun[f1](p1)) 72 | else: subpolicy.append(Fun[f1](p1, Mag[f1][m1])) 73 | 74 | if f2 in ['AutoContrast', 'Equalize', 'Equalize_Only_BBoxes', 'Flip_Only_BBoxes']: subpolicy.append(Fun[f2](p2)) 75 | else: subpolicy.append(Fun[f2](p2, Mag[f2][m2])) 76 | 77 | return subpolicy 78 | 79 | 80 | def SubPolicy3(f1, p1, m1, f2, p2, m2, f3, p3, m3): 81 | subpolicy = [] 82 | if f1 in ['AutoContrast', 'Equalize', 'Equalize_Only_BBoxes', 'Flip_Only_BBoxes']: subpolicy.append(Fun[f1](p1)) 83 | else: subpolicy.append(Fun[f1](p1, Mag[f1][m1])) 84 | 85 | if f2 in ['AutoContrast', 'Equalize', 'Equalize_Only_BBoxes', 'Flip_Only_BBoxes']: subpolicy.append(Fun[f2](p2)) 86 | else: subpolicy.append(Fun[f2](p2, Mag[f2][m2])) 87 | 88 | if f3 in ['AutoContrast', 'Equalize', 'Equalize_Only_BBoxes', 'Flip_Only_BBoxes']: subpolicy.append(Fun[f3](p3)) 89 | else: subpolicy.append(Fun[f3](p3, Mag[f3][m3])) 90 | 91 | return subpolicy 92 | 93 | 94 | def policy_v0(): 95 | policy = [SubPolicy('TranslateX_BBox', 0.6, 4, 'Equalize', 0.8, None), 96 | SubPolicy('TranslateY_Only_BBoxes', 0.2, 2, 'Cutout', 0.8, 8), 97 | SubPolicy('Sharpness', 0.0, 8, 'ShearX_BBox', 0.4, 0), 98 | SubPolicy('ShearY_BBox', 1.0, 2, 'TranslateY_Only_BBoxes', 0.6, 6), 99 | SubPolicy('Rotate_BBox', 0.6, 10, 'Color', 1.0, 6)] 100 | return policy 101 | 102 | 103 | def policy_v1(): 104 | policy = [SubPolicy('TranslateX_BBox', 0.6, 4, 'Equalize', 0.8, None), 105 | SubPolicy('TranslateY_Only_BBoxes', 0.2, 2, 'Cutout', 0.8, 8), 106 | SubPolicy('Sharpness', 0, 8, 'ShearX_BBox', 0.4, 0), 107 | SubPolicy('ShearY_BBox', 1.0, 2, 'TranslateY_Only_BBoxes', 0.6, 6), 108 | SubPolicy('Rotate_BBox', 0.6, 10, 'Color', 1.0, 6), 109 | SubPolicy('Color', 0.0, 0, 'ShearX_Only_BBoxes', 0.8, 4), 110 | SubPolicy('ShearY_Only_BBoxes', 0.8, 2, 'Flip_Only_BBoxes', 0.0, None), 111 | SubPolicy('Equalize', 0.6, None, 'TranslateX_BBox', 0.2, 2), 112 | SubPolicy('Color', 1.0, 10, 'TranslateY_Only_BBoxes', 0.4, 6), 113 | SubPolicy('Rotate_BBox', 0.8, 10, 'Contrast', 0.0, 10), 114 | SubPolicy('Cutout', 0.2, 2, 'Brightness', 0.8, 10), 115 | SubPolicy('Color', 1.0, 6, 'Equalize', 1.0, None), 116 | SubPolicy('Cutout_Only_BBoxes', 0.4, 6, 'TranslateY_Only_BBoxes', 0.8, 2), 117 | SubPolicy('Color', 0.2, 8, 'Rotate_BBox', 0.8, 10), 118 | SubPolicy('Sharpness', 0.4, 4, 'TranslateY_Only_BBoxes', 0.0, 4), 119 | SubPolicy('Sharpness', 1.0, 4, 'SolarizeAdd', 0.4, 4), 120 | SubPolicy('Rotate_BBox', 1.0, 8, 'Sharpness', 0.2, 8), 121 | SubPolicy('ShearY_BBox', 0.6, 10, 'Equalize_Only_BBoxes', 0.6, None), 122 | SubPolicy('ShearX_BBox', 0.2, 6, 'TranslateY_Only_BBoxes', 0.2, 10), 123 | SubPolicy('SolarizeAdd', 0.6, 8, 'Brightness', 0.8, 10)] 124 | return policy 125 | 126 | 127 | def policy_vtest(): 128 | policy = [SubPolicy('TranslateX_BBox', 1.0, 4, 'Equalize', 1.0, None)] 129 | return policy 130 | 131 | 132 | def policy_v2(): 133 | policy = [SubPolicy3('Color', 0.0, 6, 'Cutout', 0.6, 8, 'Sharpness', 0.4, 8), 134 | SubPolicy3('Rotate_BBox', 0.4, 8, 'Sharpness', 0.4, 2, 'Rotate_BBox', 0.8, 10), 135 | SubPolicy('TranslateY_BBox', 1.0, 8, 'AutoContrast', 0.8, None), 136 | SubPolicy3('AutoContrast', 0.4, None, 'ShearX_BBox', 0.8, 8, 'Brightness', 
0.0, 10), 137 | SubPolicy3('SolarizeAdd', 0.2, 6, 'Contrast', 0.0, 10, 'AutoContrast', 0.6, None), 138 | SubPolicy3('Cutout', 0.2, 0, 'Solarize', 0.8, 8, 'Color', 1.0, 4), 139 | SubPolicy3('TranslateY_BBox', 0.0, 4, 'Equalize', 0.6, None, 'Solarize', 0.0, 10), 140 | SubPolicy3('TranslateY_BBox', 0.2, 2, 'ShearY_BBox', 0.8, 8, 'Rotate_BBox', 0.8, 8), 141 | SubPolicy3('Cutout', 0.8, 8, 'Brightness', 0.8, 8, 'Cutout', 0.2, 2), 142 | SubPolicy3('Color', 0.8, 4, 'TranslateY_BBox', 1.0, 6, 'Rotate_BBox', 0.6, 6), 143 | SubPolicy3('Rotate_BBox', 0.6, 10, 'BBox_Cutout', 1.0, 4, 'Cutout', 0.2, 8), 144 | SubPolicy3('Rotate_BBox', 0.0, 0, 'Equalize', 0.6, None, 'ShearY_BBox', 0.6, 8), 145 | SubPolicy3('Brightness', 0.8, 8, 'AutoContrast', 0.4, None, 'Brightness', 0.2, 2), 146 | SubPolicy3('TranslateY_BBox', 0.4, 8, 'Solarize', 0.4, 6, 'SolarizeAdd', 0.2, 10), 147 | SubPolicy3('Contrast', 1.0, 10, 'SolarizeAdd', 0.2, 8, 'Equalize', 0.2, None)] 148 | return policy 149 | 150 | 151 | def policy_v3(): 152 | policy = [SubPolicy('Posterize', 0.8, 2, 'TranslateX_BBox', 1.0, 8), 153 | SubPolicy('BBox_Cutout', 0.2, 10, 'Sharpness', 1.0, 8), 154 | SubPolicy('Rotate_BBox', 0.6, 8, 'Rotate_BBox', 0.8, 10), 155 | SubPolicy('Equalize', 0.8, None, 'AutoContrast', 0.2, None), 156 | SubPolicy('SolarizeAdd', 0.2, 2, 'TranslateY_BBox', 0.2, 8), 157 | SubPolicy('Sharpness', 0.0, 2, 'Color', 0.4, 8), 158 | SubPolicy('Equalize', 1.0, None, 'TranslateY_BBox', 1.0, 8), 159 | SubPolicy('Posterize', 0.6, 2, 'Rotate_BBox', 0.0, 10), 160 | SubPolicy('AutoContrast', 0.6, None, 'Rotate_BBox', 1.0, 6), 161 | SubPolicy('Equalize', 0.0, None, 'Cutout', 0.8, 10), 162 | SubPolicy('Brightness', 1.0, 2, 'TranslateY_BBox', 1.0, 6), 163 | SubPolicy('Contrast', 0.0, 2, 'ShearY_BBox', 0.8, 0), 164 | SubPolicy('AutoContrast', 0.8, None, 'Contrast', 0.2, 10), 165 | SubPolicy('Rotate_BBox', 1.0, 10, 'Cutout', 1.0, 10), 166 | SubPolicy('SolarizeAdd', 0.8, 6, 'Equalize', 0.8, None)] 167 | return policy --------------------------------------------------------------------------------
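
The files above compose as follows: dataset.py loads an image and its boxes, policy.py samples one sub-policy per call, and augmentation.py applies it. The sketch below is a minimal usage example and not part of the repo; the COCO paths and the ImageNet mean/std values are placeholder assumptions.

```python
# Minimal usage sketch (illustration only) wiring dataset.py, policy.py and augmentation.py together.
from augmentation import ToTensor, Normalize
from dataset import COCO_detection
from policy import Policy

# pre_transform runs before the sampled sub-policy (on the PIL image),
# post_transform runs after it; both must be lists because Policy concatenates them.
pre_transform = []
post_transform = [ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
transforms = Policy('policy_v0', pre_transform, post_transform)

# Hypothetical COCO 2017 paths - adjust to the local setup.
dataset = COCO_detection(img_dir='./coco/train2017',
                         ann='./coco/annotations/instances_train2017.json',
                         transforms=transforms)

# Tensor image, (N, 4) float bboxs and (N,) long labels (both None when the image has no annotations).
image, bboxs, labels = dataset[0]
```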
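
The magnitude convention described in the README's Details section (store 0 ~ 0.9, randomly flip the sign, then add 1 to get an enhancement factor in 0.1 ~ 1.9) can be written out explicitly. A small sketch follows; the helper name magnitude_to_factor is made up for illustration and only mirrors what the Brightness, Color, Contrast and Sharpness classes do internally before calling ImageEnhance.

```python
# Sketch of the magnitude convention (illustration only): policy.py stores
# color-operation magnitudes in [0, 0.9]; at apply time the augmentation classes
# optionally flip the sign and add 1, recovering the official 0.1 ~ 1.9 range
# with 1.0 (identity) in the middle.
import torch

M = 10
color_range = torch.arange(0, 0.9 + 1e-8, (0.9 - 0) / M).tolist()  # same 11-entry table as policy.py

def magnitude_to_factor(magnitude, minus=True):
    # hypothetical helper: optional sign flip with probability 0.5, then shift by 1
    if minus and (torch.rand(1) < 0.5):
        magnitude = -magnitude
    return 1 + magnitude

factor = magnitude_to_factor(color_range[10])  # magnitude 0.9 -> factor 1.9 or 0.1
```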
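
Since COCO_detection.__getitem__ returns a variable number of boxes and labels per image (or None when an image has no annotations), PyTorch's default batch collation cannot stack the outputs directly. A possible list-based collate function is sketched below; the function name, paths and batching strategy are assumptions, as the files above do not include a training loop.

```python
# Sketch only: a simple collate function for COCO_detection outputs, keeping
# per-sample bbox and label tensors as Python lists of varying length.
from torch.utils.data import DataLoader
from dataset import COCO_detection

def detection_collate(batch):
    images, bboxs, labels = zip(*batch)
    return list(images), list(bboxs), list(labels)

# Hypothetical local COCO paths - adjust to the actual data location.
dataset = COCO_detection(img_dir='./coco/val2017',
                         ann='./coco/annotations/instances_val2017.json')
loader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=detection_collate)

images, bboxs, labels = next(iter(loader))  # each a list of length batch_size
```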