├── README.md ├── auto_augment.py ├── cutmix.py ├── cutout.py ├── dataloader.py ├── evaluation.py ├── fixres ├── Res.py ├── pnasnet.py └── resnext_wsl.py ├── label_smooth.py ├── lr_scheduler.py ├── main.py ├── merge_result.py ├── model.py ├── multigrain ├── augmentations │ ├── __init__.py │ ├── autoaugment.py │ └── transforms.py ├── backbones │ ├── __init__.py │ ├── backbone.py │ ├── nasnet_mobile.py │ └── pnasnet.py ├── datasets │ ├── __init__.py │ ├── holidays-rotate.yaml │ ├── id_dataset.py │ ├── imagenet.py │ ├── list_dataset.py │ ├── loader.py │ └── retrieval.py ├── lib │ ├── __init__.py │ ├── multigrain.py │ ├── samplers.py │ └── whiten.py ├── modules │ ├── __init__.py │ ├── criterion.py │ ├── functional.py │ ├── layers.py │ ├── margin.py │ └── multioptim.py └── utils │ ├── __init__.py │ ├── arguments.py │ ├── checkpoint.py │ ├── logging.py │ ├── metrics.py │ ├── misc.py │ ├── plots.py │ ├── tictoc.py │ └── torch_utils.py ├── optimizer.py ├── opts.py ├── rand_augment.py ├── train.py ├── tta.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # WeatherClassification 2 | (8th place on leaderboard B of the China Meteorological Administration weather-recognition competition) 3 | 4 | ### Main approach 5 | - It works much like any ordinary classification competition: first try a range of pretrained SOTA models and find the ones that suit this competition's data. For me the best were resnext101 and efficientNet; I chose resnext50 as the baseline. 6 | - Then add the common tricks — label smooth, mixup (cutout, cutmix), rand_aug, a suitable optimizer + scheduler, and so on — compare each against the baseline and drop whatever does not help. 7 | - If that goes smoothly, the score should already be around the top 20. To push further, target this specific dataset. For data augmentation, note that weather images are best not flipped vertically or rotated, since that changes the overall character of the image (it is orientation-sensitive). And because this is weather classification, the mixup/cutout family has to be adapted relative to ImageNet usage, e.g. cutting out several small patches across the whole image instead of one large one (a sketch follows at the end of this README). 8 | - The next, quite important trick I used comes from the paper "Fixing the train-test resolution discrepancy". In short: first train the whole network at a small input size, then freeze all the earlier layers and train only the last few layers (usually the FC layer) plus the BN layers at a large size. My usual setting was (320/384 + 224). This also speeds up training overall, the second pass frees a lot of GPU memory, and training only the FC layer makes large-resolution input affordable. Late in the competition this trick still gained about 1.5 percentage points, very reliably (sketch below). 9 | - Next I tried teacher mode. In short: use a complex network to predict the test set, use that pseudo-labeled test set to pre-train a small network, then continue training the small network on the original training set. The small network ends up fitting most of what the complex network learned while also learning the training set's own characteristics, which improves classification. (In practice the gain for the complex network itself was small; sketch below.) 10 | - Finally, ensembling. I mostly used weighted voting. I tried stacking at the end, but training was too complex and too slow, so I only fused 2 models. 11 | 12 | Final score: 0.89347994 13 | 14 | Training data: https://www.datafountain.cn/competitions/356 15 |
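The sketch below shows the weather-safe augmentation policy wired into a torchvision pipeline — horizontal flip only, plus several small cutout holes via this repo's `Cutout(n_holes, length)`. The mean/std come from `dataloader.py`; the hole count and crop scale are illustrative:

```python
# Weather-safe augmentation: horizontal flip only (no vertical flip or
# rotation), several small cutout holes instead of one big ImageNet-style
# hole. Mean/std are the values used in dataloader.py; other numbers are
# illustrative.
import PIL
from torchvision import transforms as T
from cutout import Cutout

weather_train_tf = T.Compose([
    T.RandomResizedCrop(224, scale=(0.5, 1.0), interpolation=PIL.Image.BICUBIC),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.2, 0.2, 0.3, 0.),
    T.ToTensor(),
    T.Normalize(mean=[0.507, 0.522, 0.500], std=[0.213, 0.207, 0.212]),
    Cutout(n_holes=4, length=8),   # many small holes across the whole image
])
```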
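A minimal sketch of the FixRes second stage, assuming a torchvision-style model with a `.fc` head (as built by `model.py`); the epoch loops and DataLoaders are omitted:

```python
# Stage 2 of FixRes: keep the backbone frozen, adapt only the classifier
# and the BatchNorm layers to the larger input resolution.
# Assumes model.py's resnext() (its imports need the repo's dependencies).
import torch.nn as nn
from model import resnext

def freeze_for_fixres(model):
    for p in model.parameters():
        p.requires_grad = False          # freeze everything...
    for p in model.fc.parameters():
        p.requires_grad = True           # ...except the FC head
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            for p in m.parameters():
                p.requires_grad = True   # ...and the BN affine parameters
    return [p for p in model.parameters() if p.requires_grad]

# Stage 1: train the whole network at the small size (e.g. 224).
# Stage 2: rebuild the DataLoader at 320/384 and optimize only:
model = resnext(num_classes=9, layers=50)
stage2_params = freeze_for_fixres(model)  # pass these to the stage-2 optimizer
# keeping model.train() in stage 2 also lets BN running stats track the new resolution
```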
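And a hedged sketch of the teacher mode, assuming a trained `teacher` model and a `TestDataset` loader from `dataloader.py` (which yields `(image, name)` pairs); the remaining names are placeholders:

```python
# Teacher mode: pseudo-label the test set with a big model, pre-train a
# small model on it, then fine-tune the small model on the real train set.
# `teacher`, `device` and the loaders are illustrative placeholders.
import torch

@torch.no_grad()
def pseudo_label(teacher, test_loader, device):
    """Collect (image, argmax-label) pairs from the teacher's predictions."""
    teacher.eval()
    pairs = []
    for images, _names in test_loader:   # TestDataset yields (image, name)
        preds = teacher(images.to(device)).argmax(dim=1).cpu()
        pairs.extend(zip(images, preds))
    return pairs

# 1) pre-train the student on pseudo_label(teacher, test_loader, device)
# 2) continue training the student on the original labeled training set
```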
-------------------------------------------------------------------------------- /auto_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import scipy 4 | from scipy import ndimage 5 | from PIL import Image, ImageEnhance, ImageOps 6 | 7 | 8 | class AutoAugment(object): 9 | def __init__(self): 10 | self.policies = [ 11 | ['TranslateX', 0.3, 9, 'Contrast', 0.2, 6], 12 | # ['Rotate', 0.7, 2, 'TranslateX', 0.3, 9], 13 | ['Sharpness', 0.8, 1, 'Sharpness', 0.9, 3], 14 | ['ShearY', 0.5, 8, 'TranslateY', 0.7, 9], 15 | ['AutoContrast', 0.5, 8, 'Equalize', 0.9, 2], 16 | ['ShearY', 0.2, 7, 'Posterize', 0.3, 7], 17 | ['Color', 0.4, 3, 'Brightness', 0.6, 7], 18 | ['Sharpness', 0.3, 9, 'Brightness', 0.7, 9], 19 | ['Equalize', 0.6, 5, 'Equalize', 0.5, 1], 20 | ['Contrast', 0.6, 7, 'Sharpness', 0.6, 5], 21 | ['Color', 0.7, 7, 'TranslateX', 0.5, 8], 22 | ['Equalize', 0.3, 7, 'AutoContrast', 0.4, 8], 23 | ['TranslateY', 0.4, 3, 'Sharpness', 0.2, 6], 24 | ['Brightness', 0.9, 6, 'Color', 0.2, 8], 25 | ['Solarize', 0.5, 2, 'TranslateX', 0.3, 7], 26 | ['Equalize', 0.2, 0, 'AutoContrast', 0.6, 0], 27 | ['Equalize', 0.2, 8, 'Equalize', 0.6, 4], 28 | ['Color', 0.9, 9, 'Equalize', 0.6, 6], 29 | ['AutoContrast', 0.8, 4, 'Solarize', 0.2, 8], 30 | ['Brightness', 0.1, 3, 'Color', 0.7, 0], 31 | ['Solarize', 0.4, 5, 'AutoContrast', 0.9, 3], 32 | ['TranslateY', 0.9, 9, 'TranslateY', 0.7, 9], 33 | ['AutoContrast', 0.9, 2, 'Solarize', 0.8, 3], 34 | ['Equalize', 0.8, 8, 'TranslateX', 0.3, 9], 35 | ['TranslateY', 0.7, 9, 'AutoContrast', 0.9, 1], 36 | ] 37 | 38 | def __call__(self, img): 39 | img = apply_policy(img, self.policies[random.randrange(len(self.policies))]) 40 | return img 41 | 42 | 43 | operations = { 44 | 'ShearX': lambda img, magnitude: shear_x(img, magnitude), 45 | 'ShearY': lambda img, magnitude: shear_y(img, magnitude), 46 | 'TranslateX': lambda img, magnitude: translate_x(img, magnitude), 47 | 'TranslateY': lambda img, magnitude: translate_y(img, magnitude), 48 | # 'Rotate': lambda img, magnitude: rotate(img, magnitude), 49 | 'AutoContrast': lambda img, magnitude: auto_contrast(img, magnitude), 50 | # 'Invert': lambda img, magnitude: invert(img, magnitude), 51 | 'Equalize': lambda img, magnitude: equalize(img, magnitude), 52 | 'Solarize': lambda img, magnitude: solarize(img, magnitude), 53 | 'Posterize': lambda img, magnitude: posterize(img, magnitude), 54 | 'Contrast': lambda img, magnitude: contrast(img, magnitude), 55 | 'Color': lambda img, magnitude: color(img, magnitude), 56 | 'Brightness': lambda img, magnitude: brightness(img, magnitude), 57 | 'Sharpness': lambda img, magnitude: sharpness(img, magnitude), 58 | 'Cutout': lambda img, magnitude: cutout(img, magnitude), 59 | } 60 | 61 | 62 | def apply_policy(img, policy): 63 | if random.random() < policy[1]: 64 | img = operations[policy[0]](img, policy[2]) 65 | if random.random() < policy[4]: 66 | img = operations[policy[3]](img, policy[5]) 67 | 68 | return img 69 | 70 | 71 | def transform_matrix_offset_center(matrix, x, y): 72 | o_x = float(x) / 2 + 0.5 73 | o_y = float(y) / 2 + 0.5 74 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 75 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 76 | transform_matrix = offset_matrix @ matrix @ reset_matrix 77 | return transform_matrix 78 | 79 | 80 | def shear_x(img, magnitude): 81 | img = np.array(img) 82 | magnitudes = np.linspace(-0.3, 0.3, 11) 83 | 84 | transform_matrix = np.array([[1, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 0], 85 | [0, 1, 0], 86 | [0, 0, 1]]) 87 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 88 | affine_matrix = transform_matrix[:2, :2] 89 | offset = transform_matrix[:2, 2] 90 | img = np.stack([ndimage.interpolation.affine_transform( 91 | img[:, :, c], 92 | affine_matrix, 93 | offset) for c in range(img.shape[2])], axis=2) 94 | img = Image.fromarray(img) 95 | return img 96 | 97 | 98 | def shear_y(img, magnitude): 99 | img = np.array(img) 100 | magnitudes = np.linspace(-0.3, 0.3, 11) 101 | 102 | transform_matrix = np.array([[1, 0, 0], 103 | [random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 1, 0], 104 | [0, 0, 1]]) 105 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 106 | affine_matrix = transform_matrix[:2, :2] 107 | offset = transform_matrix[:2, 2] 108 | img = np.stack([ndimage.interpolation.affine_transform( 109 | img[:, :, c], 110 | affine_matrix, 111 | offset) for c in range(img.shape[2])], axis=2) 112 | img = Image.fromarray(img) 113 | return img 114 | 115 | 116 | def translate_x(img, magnitude): 117 | img = np.array(img) 118 | magnitudes = np.linspace(-150/331, 150/331, 11) 119 | 120 | transform_matrix = np.array([[1, 0, 0], 121 | [0, 1, img.shape[1]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])],
122 | [0, 0, 1]]) 123 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 124 | affine_matrix = transform_matrix[:2, :2] 125 | offset = transform_matrix[:2, 2] 126 | img = np.stack([ndimage.interpolation.affine_transform( 127 | img[:, :, c], 128 | affine_matrix, 129 | offset) for c in range(img.shape[2])], axis=2) 130 | img = Image.fromarray(img) 131 | return img 132 | 133 | 134 | def translate_y(img, magnitude): 135 | img = np.array(img) 136 | magnitudes = np.linspace(-150/331, 150/331, 11) 137 | 138 | transform_matrix = np.array([[1, 0, img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], 139 | [0, 1, 0], 140 | [0, 0, 1]]) 141 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 142 | affine_matrix = transform_matrix[:2, :2] 143 | offset = transform_matrix[:2, 2] 144 | img = np.stack([ndimage.interpolation.affine_transform( 145 | img[:, :, c], 146 | affine_matrix, 147 | offset) for c in range(img.shape[2])], axis=2) 148 | img = Image.fromarray(img) 149 | return img 150 | 151 | 152 | def rotate(img, magnitude): 153 | img = np.array(img) 154 | magnitudes = np.linspace(-30, 30, 11) 155 | theta = np.deg2rad(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 156 | transform_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 157 | [np.sin(theta), np.cos(theta), 0], 158 | [0, 0, 1]]) 159 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 160 | affine_matrix = transform_matrix[:2, :2] 161 | offset = transform_matrix[:2, 2] 162 | img = np.stack([ndimage.interpolation.affine_transform( 163 | img[:, :, c], 164 | affine_matrix, 165 | offset) for c in range(img.shape[2])], axis=2) 166 | img = Image.fromarray(img) 167 | return img 168 | 169 | 170 | def auto_contrast(img, magnitude): 171 | img = ImageOps.autocontrast(img) 172 | return img 173 | 174 | 175 | def invert(img, magnitude): 176 | img = ImageOps.invert(img) 177 | return img 178 | 179 | 180 | def equalize(img, magnitude): 181 | img = ImageOps.equalize(img) 182 | return img 183 | 184 | 185 | def solarize(img, magnitude): 186 | magnitudes = np.linspace(0, 256, 11) 187 | img = ImageOps.solarize(img, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 188 | return img 189 | 190 | 191 | def posterize(img, magnitude): 192 | magnitudes = np.linspace(4, 8, 11) 193 | img = ImageOps.posterize(img, int(round(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])))) 194 | return img 195 | 196 | 197 | def contrast(img, magnitude): 198 | magnitudes = np.linspace(0.1, 1.9, 11) 199 | img = ImageEnhance.Contrast(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 200 | return img 201 | 202 | 203 | def color(img, magnitude): 204 | magnitudes = np.linspace(0.1, 1.9, 11) 205 | img = ImageEnhance.Color(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 206 | return img 207 | 208 | 209 | def brightness(img, magnitude): 210 | magnitudes = np.linspace(0.1, 1.9, 11) 211 | img = ImageEnhance.Brightness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 212 | return img 213 | 214 | 215 | def sharpness(img, magnitude): 216 | magnitudes = np.linspace(0.1, 1.9, 11) 217 | img = ImageEnhance.Sharpness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 218 | return img 219 | 220 | 221 | def cutout(org_img, magnitude=None): 222 | img = np.array(org_img) 223 | 224 | magnitudes = np.linspace(0, 60/331, 11) 225 | 226 | img = np.copy(img) 227 | mask_val = img.mean() 228 | 229 | if magnitude is None: 230 | mask_size = 16 231 | else: 232 | mask_size = int(round(img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]))) 233 | top = np.random.randint(0 - mask_size//2, img.shape[0] - mask_size) 234 | left = np.random.randint(0 - mask_size//2, img.shape[1] - mask_size) 235 | bottom = top + mask_size 236 | right = left + mask_size 237 | 238 | if top < 0: 239 | top = 0 240 | if left < 0: 241 | left = 0 242 | 243 | img[top:bottom, left:right, :].fill(mask_val) 244 | 245 | img = Image.fromarray(img) 246 | 247 | return img 248 | 249 | 250 | 251 | class Cutout(object): 252 | def __init__(self, length=16): 253 | self.length = length 254 | 255 | def __call__(self, img): 256 | img = np.array(img) 257 | 258 | mask_val = img.mean() 259 | 260 | top = np.random.randint(0 - self.length//2, img.shape[0] - self.length) 261 | left = np.random.randint(0 - self.length//2, img.shape[1] - self.length) 262 | bottom = top + self.length 263 | right = left + self.length 264 | 265 | top = 0 if top < 0 else top 266 | left = 0 if left < 0 else left 267 | 268 | img[top:bottom, left:right, :] = mask_val 269 | 270 | img = Image.fromarray(img) 271 | 272 | return img -------------------------------------------------------------------------------- /cutmix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def cutmix(batch, alpha): 7 | data, targets = batch 8 | 9 | indices = torch.randperm(data.size(0)) 10 | shuffled_data = data[indices] 11 | shuffled_targets = targets[indices] 12 | 13 | lam = np.random.beta(alpha, alpha) 14 | 15 | image_h, image_w = data.shape[2:] 16 | cx = np.random.uniform(0, image_w) 17 | cy = np.random.uniform(0, image_h) 18 | w = image_w * np.sqrt(1 - lam) 19 | h = image_h * np.sqrt(1 - lam) 20 | x0 = int(np.round(max(cx - w / 2, 0))) 21 | x1 = int(np.round(min(cx + w / 2, image_w))) 22 | y0 = int(np.round(max(cy - h / 2, 0))) 23 | y1 = int(np.round(min(cy + h / 2, image_h))) 24 | 25 | data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1] 26 | targets = (targets, shuffled_targets, lam) 27 | 28 | return data, targets 29 | 30 | 31 | class CutMixCollator: 32 | def __init__(self, alpha): 33 | self.alpha = alpha 34 | 35 | def __call__(self, batch): 36 | batch = torch.utils.data.dataloader.default_collate(batch) 37 | batch = cutmix(batch, self.alpha) 38 | return batch 39 | 40 | 41 | class CutMixCriterion: 42 | def __init__(self, criterion): 43 | self.criterion = criterion 44 | 45 | def __call__(self, preds, targets): 46 | targets1, targets2, lam = targets 47 | return lam * self.criterion( 48 | preds, targets1) + (1 - lam) * self.criterion(preds, targets2) 49 | 50 | 51 | 52 | 53 |
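A hedged usage sketch for the two classes above — the collator mixes each collated batch and packs `(targets1, targets2, lam)`, which the criterion unpacks; `train_dataset`, `model` and `alpha=1.0` are illustrative:

```python
# CutMix wiring: the collator mixes each collated batch; the criterion
# computes the lam-weighted sum of the two cross-entropy terms.
# train_dataset and model are illustrative; alpha=1.0 is a common choice.
import torch
import torch.nn as nn
from cutmix import CutMixCollator, CutMixCriterion

loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=CutMixCollator(alpha=1.0),
)
criterion = CutMixCriterion(nn.CrossEntropyLoss())

for images, targets in loader:            # targets == (targets1, targets2, lam)
    loss = criterion(model(images), targets)
```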
-------------------------------------------------------------------------------- /cutout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Cutout(object): 6 | """Randomly mask out one or more patches from an image. 7 | 8 | Args: 9 | n_holes (int): Number of patches to cut out of each image. 10 | length (int): The length (in pixels) of each square patch. 11 | """ 12 | def __init__(self, n_holes=1, length=8): 13 | self.n_holes = n_holes 14 | self.length = length 15 | 16 | def __call__(self, img): 17 | """ 18 | Args: 19 | img (Tensor): Tensor image of size (C, H, W).
20 | Returns: 21 | Tensor: Image with n_holes of dimension length x length cut out of it. 22 | """ 23 | h = img.size(1) 24 | w = img.size(2) 25 | 26 | mask = np.ones((h, w), np.float32) 27 | 28 | for n in range(self.n_holes): 29 | y = np.random.randint(h) 30 | x = np.random.randint(w) 31 | 32 | y1 = np.clip(y - self.length // 2, 0, h) 33 | y2 = np.clip(y + self.length // 2, 0, h) 34 | x1 = np.clip(x - self.length // 2, 0, w) 35 | x2 = np.clip(x + self.length // 2, 0, w) 36 | 37 | mask[y1: y2, x1: x2] = 0. 38 | 39 | mask = torch.from_numpy(mask) 40 | mask = mask.expand_as(img) 41 | img = img * mask 42 | 43 | return img 44 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import glob 4 | import cv2 5 | import PIL 6 | import random 7 | 8 | from torch.utils.data import Dataset, DataLoader 9 | import torchvision as tv 10 | from torchvision import transforms as T 11 | 12 | import utils 13 | from cutout import Cutout 14 | from auto_augment import AutoAugment 15 | from rand_augment import Rand_Augment 16 | 17 | class WeatherDataset(Dataset): 18 | def __init__(self, images, labels, transforms, output_name=False): 19 | self.images = images 20 | self.labels = labels 21 | self.transforms = transforms 22 | self.output_name = output_name 23 | 24 | def __len__(self): 25 | return len(self.labels) 26 | 27 | def __getitem__(self, idx): 28 | image = utils.load_image(self.images[idx]) 29 | # image = self.images[idx] 30 | label = utils.to_tensor(self.labels[idx], torch.long) 31 | 32 | if self.transforms is not None: 33 | image = self.transforms(image) 34 | 35 | if self.output_name: 36 | return image, label, self.images[idx] 37 | 38 | return image, label 39 | 40 | class TestDataset(Dataset): 41 | def __init__(self, images, names, transforms): 42 | self.images = images 43 | self.names = names 44 | self.transforms = transforms 45 | 46 | def __len__(self): 47 | return len(self.images) 48 | 49 | def __getitem__(self, idx): 50 | # image = utils.load_image(self.images[idx]) 51 | image = self.images[idx] 52 | name = self.names[idx] 53 | if self.transforms is not None: 54 | image = self.transforms(image) 55 | return image, name 56 | 57 | class CamDataset(Dataset): 58 | def __init__(self, images, labels, transforms): 59 | self.images = images 60 | self.labels = labels 61 | self.transforms = transforms 62 | 63 | def __len__(self): 64 | return len(self.labels) 65 | 66 | def __getitem__(self, idx): 67 | # image = utils.load_image(self.images[idx]) 68 | image = self.images[idx] 69 | label = utils.to_tensor(self.labels[idx], torch.long) 70 | 71 | if self.transforms is not None: 72 | t_image = self.transforms(image) 73 | image = resize_transform(image) 74 | 75 | return t_image, label, np.array(image) 76 | 77 | 78 | class UnNormalize(object): 79 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 80 | self.mean = mean 81 | self.std = std 82 | 83 | def __call__(self, tensor): 84 | """ 85 | Args: 86 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 87 | Returns: 88 | Tensor: Normalized image. 
89 | """ 90 | for t, m, s in zip(tensor, self.mean, self.std): 91 | t.mul_(s).add_(m) 92 | # The normalize code -> t.sub_(m).div_(s) 93 | return tensor 94 | 95 | 96 | def my_transform(train=True, resize=224, use_cutout=False, n_holes=1, length=8, auto_aug=False, rand_aug=False): 97 | transforms = [] 98 | interpolations = [PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.HAMMING, PIL.Image.BICUBIC, PIL.Image.LANCZOS] 99 | 100 | if train: 101 | # transforms.append(T.RandomRotation(90)) 102 | transforms.append(T.RandomResizedCrop(resize+5, 103 | scale=(0.2, 2.0), 104 | interpolation=PIL.Image.BICUBIC)) 105 | transforms.append(T.RandomHorizontalFlip()) 106 | # transforms.append(T.RandomVerticalFlip()) 107 | transforms.append(T.ColorJitter(0.2, 0.2, 0.3, 0.)) 108 | transforms.append(T.CenterCrop(resize)) 109 | if auto_aug: 110 | transforms.append(AutoAugment()) 111 | if rand_aug: 112 | transforms.append(Rand_Augment()) 113 | else: 114 | transforms.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC)) 115 | transforms.append(T.CenterCrop(resize)) 116 | 117 | transforms.append(T.ToTensor()) 118 | transforms.append( 119 | # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) 120 | # T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])) 121 | T.Normalize(mean=[0.507, 0.522, 0.500], std=[0.213, 0.207, 0.212])) 122 | 123 | if train and use_cutout: 124 | transforms.append(Cutout()) 125 | 126 | return T.Compose(transforms) 127 | 128 | def test_transform(resize=224): 129 | transforms = [] 130 | transforms.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC)) 131 | transforms.append(T.CenterCrop(resize)) 132 | transforms.append(T.ToTensor()) 133 | 134 | return T.Compose(transforms) 135 | 136 | def resize_transform(images, resize=224): 137 | transforms = [] 138 | 139 | transforms.append(T.Resize(resize+20)) 140 | transforms.append(T.CenterCrop(resize)) 141 | 142 | # transforms.append(T.ToTensor()) 143 | 144 | return T.Compose(transforms)(images) 145 | 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import os 7 | import csv 8 | import re 9 | import json 10 | 11 | import utils 12 | import opts 13 | from train import train_model, eval_model, eval_logits, eval_model_tta, eval_logits_tta 14 | from model import * 15 | from dataloader import TestDataset, my_transform, test_transform 16 | from sync_batchnorm import convert_model 17 | 18 | def main(opt): 19 | if torch.cuda.is_available(): 20 | device = torch.device('cuda') 21 | torch.cuda.set_device(opt.gpu_id) 22 | else: 23 | device = torch.device('cpu') 24 | 25 | if opt.cadene: 26 | model = cadene_model(opt.classes, model_name=opt.network) 27 | elif opt.network == 'resnet': 28 | model = resnet(opt.classes, opt.layers) 29 | elif opt.network == 'resnext': 30 | model = resnext(opt.classes, opt.layers) 31 | elif opt.network == 'resnext_wsl': 32 | # resnext_wsl must specify the opt.battleneck_width parameter 33 | opt.network = 'resnext_wsl_32x' + str(opt.battleneck_width) +'d' 34 | model = resnext_wsl(opt.classes, opt.battleneck_width) 35 | elif opt.network == 'resnext_swsl': 36 | model = resnext_swsl(opt.classes, opt.layers, opt.battleneck_width) 37 | elif opt.network == 'vgg': 38 | model = vgg_bn(opt.classes, opt.layers) 39 | elif opt.network == 'densenet': 40 | model = 
densenet(opt.classes, opt.layers) 41 | elif opt.network == 'inception_v3': 42 | model = inception_v3(opt.classes, opt.layers) 43 | elif opt.network == 'dpn': 44 | model = dpn(opt.classes, opt.layers) 45 | elif opt.network == 'effnet': 46 | model = effnet(opt.classes, opt.layers) 47 | elif opt.network == 'pnasnet_m': 48 | model = pnasnet_m(opt.classes, opt.layers, opt.pretrained) 49 | elif opt.network == 'senet_m': 50 | model = senet_m(opt.classes, opt.layers, opt.pretrained) 51 | 52 | 53 | # model = nn.DataParallel(model, device_ids=[0, 1, 2, 3]) 54 | model = nn.DataParallel(model, device_ids=[0, 1, 2, 3, 4, 5, 6, 7]) 55 | # model = nn.DataParallel(model, device_ids=[1, 2, 3, 4, 5, 6, 7, 0]) 56 | # model = nn.DataParallel(model, device_ids=[4, 5, 6, 7]) 57 | # model = convert_model(model) 58 | model = model.to(device) 59 | 60 | # for param in model.module.model.parameters(): 61 | for param in model.parameters(): 62 | param.requires_grad = False 63 | 64 | if opt.classes > 2: 65 | images, names = utils.read_test_data(os.path.join(opt.root_dir, opt.test_dir)) 66 | else: 67 | images, names = utils.read_test_ice_snow_data( 68 | os.path.join(opt.root_dir, opt.test_dir), 69 | os.path.join(opt.results_ts, opt.res8)) 70 | 71 | dict_= {} 72 | 73 | for crop_size in [opt.crop_size+256]: 74 | if opt.tta: 75 | transforms = test_transform(crop_size) 76 | else: 77 | transforms = my_transform(False, crop_size) 78 | 79 | dataset = TestDataset(images, names, transforms) 80 | loader = torch.utils.data.DataLoader(dataset, 81 | batch_size=opt.batch_size, 82 | shuffle=False, num_workers=4) 83 | state_dict = torch.load( 84 | opt.model_dir+'/'+opt.network+'-'+str(opt.layers)+'-'+str(crop_size)+'_model.ckpt') 85 | if opt.network == 'densenet': 86 | pattern = re.compile( 87 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 88 | for key in list(state_dict.keys()): 89 | res = pattern.match(key) 90 | if res: 91 | new_key = res.group(1) + res.group(2) 92 | state_dict[new_key] = state_dict[key] 93 | del state_dict[key] 94 | model.load_state_dict(state_dict) 95 | if opt.vote: 96 | if opt.tta: 97 | im_names, labels = eval_model_tta(loader, model, device=device) 98 | else: 99 | im_names, labels = eval_model(loader, model, device=device) 100 | else: 101 | if opt.tta: 102 | im_names, labels = eval_logits_tta(loader, model, device=device) 103 | else: 104 | im_names, labels = eval_logits(loader, model, device) 105 | im_labels = [] 106 | # print(im_names) 107 | for name, label in zip(im_names, labels): 108 | if name in dict_: 109 | dict_[name].append(label) 110 | else: 111 | dict_[name] = [label] 112 | 113 | 114 | header = ['filename', 'type'] 115 | utils.mkdir(opt.results_dir) 116 | utils.mkdir(opt.results_ts) 117 | result = opt.network + '-' +str(opt.layers) + '-'+str(crop_size)+ '_result.csv' 118 | if opt.classes == 9: 119 | filename = os.path.join(opt.results_dir, result) 120 | else: 121 | result = str(opt.classes) + '-' + result 122 | filename = os.path.join(opt.results_ts, result) 123 | with open(filename, 'w', encoding='utf-8') as f: 124 | f_csv = csv.writer(f) 125 | f_csv.writerow(header) 126 | for key in dict_.keys(): 127 | # val = np.max(np.sum(np.array(dict_[key]), axis=0)) 128 | # if val > 0.5: continue 129 | # v = np.argmax(np.sum(np.array(dict_[key]), axis=0)) + 1 130 | v = list(np.sum(np.array(dict_[key]), axis=0)) 131 | # f_csv.writerow([key, val]) 132 | f_csv.writerow([key, v]) 133 | 134 | opt = opts.parse_args() 135 | main(opt) 136 | 137 | 138 | 139 | 140 | 141 
| 142 | 143 | -------------------------------------------------------------------------------- /fixres/Res.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch.nn as nn 9 | try: 10 | from torch.hub import load_state_dict_from_url 11 | except ImportError: 12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 16 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 17 | 18 | 19 | model_urls = { 20 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 21 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 25 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 26 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 45 | base_width=64, dilation=1, norm_layer=None): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = 
conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | 212 | x = self.avgpool(x) 213 | x = x.reshape(x.size(0), -1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | 219 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 220 | model = ResNet(block, layers, **kwargs) 221 | if pretrained: 222 | state_dict = load_state_dict_from_url(model_urls[arch], 223 | progress=progress) 224 | model.load_state_dict(state_dict) 225 | return model 226 | 227 | 228 | def resnet18(pretrained=False, progress=True, **kwargs): 229 | """Constructs a ResNet-18 model. 230 | Args: 231 | pretrained (bool): If True, returns a model pre-trained on ImageNet 232 | progress (bool): If True, displays a progress bar of the download to stderr 233 | """ 234 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 235 | **kwargs) 236 | 237 | 238 | def resnet34(pretrained=False, progress=True, **kwargs): 239 | """Constructs a ResNet-34 model. 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet50(pretrained=False, progress=True, **kwargs): 249 | """Constructs a ResNet-50 model. 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | """Constructs a ResNet-101 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnet152(pretrained=False, progress=True, **kwargs): 269 | """Constructs a ResNet-152 model. 
270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNeXt-50 32x4d model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | kwargs['groups'] = 32 285 | kwargs['width_per_group'] = 4 286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 287 | pretrained, progress, **kwargs) 288 | 289 | 290 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 291 | """Constructs a ResNeXt-101 32x8d model. 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | kwargs['groups'] = 32 297 | kwargs['width_per_group'] = 8 298 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 299 | pretrained, progress, **kwargs) 300 | -------------------------------------------------------------------------------- /fixres/resnext_wsl.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Optional list of dependencies required by the package 9 | 10 | ''' 11 | Code From : https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py 12 | ''' 13 | dependencies = ['torch', 'torchvision'] 14 | 15 | try: 16 | from torch.hub import load_state_dict_from_url 17 | except ImportError: 18 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 19 | 20 | from .Res import ResNet, Bottleneck 21 | 22 | 23 | model_urls = { 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 25 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 26 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 27 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 28 | } 29 | 30 | 31 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 32 | model = ResNet(block, layers, **kwargs) 33 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 34 | model.load_state_dict(state_dict) 35 | return model 36 | 37 | 38 | def resnext101_32x8d_wsl(progress=True, **kwargs): 39 | """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data 40 | and finetuned on ImageNet from Figure 5 in 41 | `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ 42 | Args: 43 | progress (bool): If True, displays a progress bar of the download to stderr. 44 | """ 45 | kwargs['groups'] = 32 46 | kwargs['width_per_group'] = 8 47 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 48 | 49 | 50 | def resnext101_32x16d_wsl(progress=True, **kwargs): 51 | """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data 52 | and finetuned on ImageNet from Figure 5 in 53 | `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ 54 | Args: 55 | progress (bool): If True, displays a progress bar of the download to stderr. 56 | """ 57 | kwargs['groups'] = 32 58 | kwargs['width_per_group'] = 16 59 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 60 | 61 | 62 | def resnext101_32x32d_wsl(progress=True, **kwargs): 63 | """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data 64 | and finetuned on ImageNet from Figure 5 in 65 | `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ 66 | Args: 67 | progress (bool): If True, displays a progress bar of the download to stderr. 68 | """ 69 | kwargs['groups'] = 32 70 | kwargs['width_per_group'] = 32 71 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 72 | 73 | 74 | def resnext101_32x48d_wsl(progress=True, **kwargs): 75 | """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data 76 | and finetuned on ImageNet from Figure 5 in 77 | `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_ 78 | Args: 79 | progress (bool): If True, displays a progress bar of the download to stderr. 80 | """ 81 | kwargs['groups'] = 32 82 | kwargs['width_per_group'] = 48 83 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 84 | -------------------------------------------------------------------------------- /label_smooth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | # from lovasz_losses import *  # lovasz_losses.py is not included in this repo; only the commented-out lovasz_loss line below needs it 6 | 7 | class LabelSmoothSoftmaxCE(nn.Module): 8 | def __init__(self, 9 | lb_pos=0.9, 10 | lb_neg=0.005, 11 | reduction='mean', 12 | lb_ignore=255, 13 | weight=None, 14 | use_focal_loss=False 15 | ): 16 | super(LabelSmoothSoftmaxCE, self).__init__() 17 | self.lb_pos = lb_pos 18 | self.lb_neg = lb_neg 19 | self.reduction = reduction 20 | self.lb_ignore = lb_ignore 21 | self.weight = weight 22 | self.use_focal_loss = use_focal_loss 23 | if use_focal_loss: 24 | self.focal_loss = FocalLoss2(weight=weight, balance_param=0.5) 25 | # self.f1_loss = F1Loss() 26 | self.triplet_loss = TripletLoss(0.3) 27 | self.log_softmax = nn.LogSoftmax(1) 28 | 29 | def forward(self, logits, label): 30 | if self.use_focal_loss: 31 | floss = self.focal_loss(logits, label) 32 | # f1loss = self.f1_loss(logits, label) 33 | # lovasz_loss = lovasz_softmax(torch.softmax(logits, dim=1), label, classes='all') 34 | triplet_loss = self.triplet_loss(logits, label) 35 | 36 | logs = self.log_softmax(logits) 37 | ignore = label.data.cpu() == self.lb_ignore 38 | n_valid = (ignore == 0).sum() 39 | label = label.clone() 40 | label[ignore] = 0 41 | lb_one_hot = logits.data.clone().zero_().scatter_(1, label.unsqueeze(1), 1) 42 | label = self.lb_pos * lb_one_hot + self.lb_neg * (1-lb_one_hot) 43 | ignore = ignore.nonzero() 44 | _, M = ignore.size() 45 | a, *b = ignore.chunk(M, dim=1) 46 | label[[a, torch.arange(label.size(1)), *b]] = 0 47 | 48 | if self.weight is not None: 49 | sum_loss =
-torch.sum(torch.sum((logs*label)*self.weight, dim=1)) 50 | else: 51 | sum_loss = -torch.sum(torch.sum(logs*label, dim=1)) 52 | 53 | if self.reduction == 'mean': 54 | loss = sum_loss / n_valid 55 | elif self.reduction == 'sum': 56 | loss = sum_loss 57 | if self.use_focal_loss: 58 | loss = 1.*loss + 0.4*floss 59 | # + 0.4*triplet_loss 60 | return loss 61 | 62 | class FocalLoss(nn.Module): 63 | def __init__(self, 64 | alpha=1., 65 | gamma=2, 66 | reduction='sum', 67 | ignore_lb=255): 68 | super(FocalLoss, self).__init__() 69 | self.alpha = alpha 70 | self.gamma = gamma 71 | self.reduction = reduction 72 | self.ignore_lb = ignore_lb 73 | 74 | def forward(self, logits, label): 75 | ''' 76 | args: logits: tensor of shape (N, C, H, W) 77 | args: label: tensor of shape(N, H, W) 78 | ''' 79 | # overcome ignored label 80 | ignore = label.data.cpu() == self.ignore_lb 81 | n_valid = (ignore == 0).sum() 82 | label[ignore] = 0 83 | 84 | ignore = ignore.nonzero() 85 | _, M = ignore.size() 86 | a, *b = ignore.chunk(M, dim=1) 87 | mask = torch.ones_like(logits) 88 | mask[[a, torch.arange(mask.size(1)), *b]] = 0 89 | 90 | # compute loss 91 | probs = torch.sigmoid(logits) 92 | lb_one_hot = logits.data.clone().zero_().scatter_(1, label.unsqueeze(1), 1) 93 | pt = torch.where(lb_one_hot == 1, probs, 1-probs) 94 | alpha = self.alpha*lb_one_hot + (1-self.alpha)*(1-lb_one_hot) 95 | loss = -alpha*((1-pt)**self.gamma)*torch.log(pt + 1e-12) 96 | loss[mask == 0] = 0 97 | if self.reduction == 'mean': 98 | loss = loss.sum(dim=1).sum()/n_valid 99 | elif self.reduction == 'sum': 100 | loss = loss.sum() 101 | return loss 102 | 103 | class FocalLoss2(nn.Module): 104 | 105 | def __init__(self, weight=None, focusing_param=2, balance_param=0.25): 106 | super(FocalLoss2, self).__init__() 107 | self.weight = weight 108 | self.focusing_param = focusing_param 109 | self.balance_param = balance_param 110 | 111 | def forward(self, output, target): 112 | 113 | cross_entropy = F.cross_entropy(output, target, weight=self.weight, reduction='sum') 114 | cross_entropy_log = torch.log(cross_entropy) 115 | logpt = - F.cross_entropy(output, target, weight=self.weight, reduction='sum') 116 | pt = torch.exp(logpt) 117 | 118 | focal_loss = -((1 - pt) ** self.focusing_param) * logpt 119 | 120 | balanced_focal_loss = self.balance_param * focal_loss 121 | 122 | return balanced_focal_loss 123 | 124 | 125 | class F1Loss(nn.Module): 126 | def __init__(self): 127 | super(F1Loss, self).__init__() 128 | 129 | def forward(self, predict, targets): 130 | return self.f1_loss(predict, targets) 131 | 132 | def f1_loss(self, predict, target): 133 | batch_size = predict.size(0) 134 | target = target.view(batch_size, 1) 135 | target = torch.zeros(batch_size, 9).cuda().scatter_(1, target, 1) 136 | predict = torch.sigmoid(predict) 137 | # print(predict.size(), target.size()) 138 | predict = torch.clamp(predict * (1-target), min=0.01) + predict * target 139 | tp = predict * target 140 | tp = tp.sum(dim=0) 141 | precision = tp / (predict.sum(dim=0) + 1e-8) 142 | recall = tp / (target.sum(dim=0) + 1e-8) 143 | f1 = 2 * (precision * recall / (precision + recall + 1e-8)) 144 | return 1 - f1.mean() 145 | 146 | class TripletLoss(nn.Module): 147 | """Triplet loss with hard positive/negative mining. 148 | Reference: 149 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 150 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 151 | Args: 152 | margin (float): margin for triplet. 
153 | """ 154 | def __init__(self, margin=0.3, mutual_flag = False): 155 | super(TripletLoss, self).__init__() 156 | self.margin = margin 157 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 158 | self.mutual = mutual_flag 159 | 160 | def forward(self, inputs, targets): 161 | """ 162 | Args: 163 | inputs: feature matrix with shape (batch_size, feat_dim) 164 | targets: ground truth labels with shape (batch_size,) 165 | """ 166 | n = inputs.size(0) 167 | # inputs = 1. * inputs / (torch.norm(inputs, 2, dim=-1, keepdim=True).expand_as(inputs) + 1e-12) 168 | # Compute pairwise distance, replace by the official when merged 169 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 170 | dist = dist + dist.t() 171 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # dist = dist - 2 * inputs @ inputs.t() 172 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 173 | # For each anchor, find the hardest positive and negative 174 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 175 | dist_ap, dist_an = [], [] 176 | for i in range(n): 177 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 178 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 179 | dist_ap = torch.cat(dist_ap) 180 | dist_an = torch.cat(dist_an) 181 | # Compute ranking hinge loss 182 | y = torch.ones_like(dist_an) 183 | loss = self.ranking_loss(dist_an, dist_ap, y) 184 | if self.mutual: 185 | return loss, dist 186 | return loss 187 | 188 | 189 | 190 |
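A hedged usage sketch for `LabelSmoothSoftmaxCE` above, with the default smoothing (0.9 positive / 0.005 negative) and this competition's 9 classes:

```python
# Drop-in replacement for nn.CrossEntropyLoss: the true class gets 0.9
# probability mass, every other class 0.005.
import torch
from label_smooth import LabelSmoothSoftmaxCE

criterion = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=0.005)
logits = torch.randn(4, 9)           # (batch, num_classes) — 9 weather types
labels = torch.tensor([0, 3, 5, 8])  # (batch,) class indices
loss = criterion(logits, labels)
```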
Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = float(self.last_epoch) / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | 58 | 59 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): 60 | """cosine annealing scheduler with warmup. 61 | Args: 62 | optimizer (Optimizer): Wrapped optimizer. 63 | T_max (int): Maximum number of iterations. 64 | eta_min (float): Minimum learning rate. Default: 0. 65 | last_epoch (int): The index of last epoch. Default: -1. 66 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 67 | https://arxiv.org/abs/1608.03983 68 | """ 69 | 70 | def __init__( 71 | self, 72 | optimizer, 73 | T_max, 74 | eta_min, 75 | warmup_factor=1.0 / 3, 76 | warmup_iters=500, 77 | warmup_method="linear", 78 | last_epoch=-1, 79 | ): 80 | if warmup_method not in ("constant", "linear"): 81 | raise ValueError( 82 | "Only 'constant' or 'linear' warmup_method accepted" 83 | "got {}".format(warmup_method) 84 | ) 85 | 86 | self.T_max = T_max 87 | self.eta_min = eta_min 88 | self.warmup_factor = warmup_factor 89 | self.warmup_iters = warmup_iters 90 | self.warmup_method = warmup_method 91 | super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 92 | 93 | def get_lr(self): 94 | if self.last_epoch < self.warmup_iters: 95 | return self.get_lr_warmup() 96 | else: 97 | return self.get_lr_cos_annealing() 98 | 99 | def get_lr_warmup(self): 100 | if self.warmup_method == "constant": 101 | warmup_factor = self.warmup_factor 102 | elif self.warmup_method == "linear": 103 | alpha = self.last_epoch / self.warmup_iters 104 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 105 | return [ 106 | base_lr * warmup_factor 107 | for base_lr in self.base_lrs 108 | ] 109 | 110 | def get_lr_cos_annealing(self): 111 | last_epoch = self.last_epoch - self.warmup_iters 112 | T_max = self.T_max - self.warmup_iters 113 | return [self.eta_min + (base_lr - self.eta_min) * 114 | (1 + math.cos(math.pi * last_epoch / T_max)) / 2 115 | for base_lr in self.base_lrs] 116 | 117 | 118 | class PiecewiseCyclicalLinearLR(torch.optim.lr_scheduler._LRScheduler): 119 | """Set the learning rate of each parameter group using piecewise 120 | cyclical linear schedule. 121 | When last_epoch=-1, sets initial lr as lr. 
122 | 123 | Args: 124 | c: cycle length 125 | alpha1: lr upper bound of cycle 126 | alpha2: lr lower bound of cycle 127 | _Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs 128 | https://arxiv.org/pdf/1802.10026 129 | _Exploring loss function topology with cyclical learning rates 130 | https://arxiv.org/abs/1702.04283 131 | """ 132 | 133 | def __init__(self, optimizer, c, alpha1=1e-2, alpha2=5e-4, last_epoch=-1): 134 | 135 | self.c = c 136 | self.alpha1 = alpha1 137 | self.alpha2 = alpha2 138 | super(PiecewiseCyclicalLinearLR, self).__init__(optimizer, last_epoch) 139 | 140 | def get_lr(self): 141 | 142 | lrs = [] 143 | for _ in range(len(self.base_lrs)): 144 | ti = ((self.last_epoch - 1) % self.c + 1) / self.c 145 | if 0 <= ti <= 0.5: 146 | lr = (1 - 2 * ti) * self.alpha1 + 2 * ti * self.alpha2 147 | elif 0.5 < ti <= 1.0: 148 | lr = (2 - 2 * ti) * self.alpha2 + (2 * ti - 1) * self.alpha1 149 | else: 150 | raise ValueError('t(i) is out of range [0,1].') 151 | lrs.append(lr) 152 | 153 | return lrs 154 | 155 | 156 | class PolyLR(torch.optim.lr_scheduler._LRScheduler): 157 | 158 | def __init__(self, optimizer, power=0.9, max_epoch=4e4, last_epoch=-1): 159 | self.power = power 160 | self.max_epoch = max_epoch 161 | self.last_epoch = last_epoch 162 | super(PolyLR, self).__init__(optimizer, last_epoch) 163 | 164 | def get_lr(self): 165 | lrs = [] 166 | for base_lr in self.base_lrs: 167 | lr = base_lr * (1.0 - (self.last_epoch / self.max_epoch)) ** self.power 168 | lrs.append(lr) 169 | 170 | return lrs 171 | 172 | 173 | class Lookahead(Optimizer): 174 | def __init__(self, optimizer, k=5, alpha=0.5): 175 | self.optimizer = optimizer 176 | self.k = k 177 | self.alpha = alpha 178 | self.param_groups = self.optimizer.param_groups 179 | self.state = defaultdict(dict) 180 | self.fast_state = self.optimizer.state 181 | for group in self.param_groups: 182 | group["counter"] = 0 183 | 184 | def update(self, group): 185 | for fast in group["params"]: 186 | param_state = self.state[fast] 187 | if "slow_param" not in param_state: 188 | param_state["slow_param"] = torch.zeros_like(fast.data) 189 | param_state["slow_param"].copy_(fast.data) 190 | slow = param_state["slow_param"] 191 | slow += (fast.data - slow) * self.alpha 192 | fast.data.copy_(slow) 193 | 194 | def update_lookahead(self): 195 | for group in self.param_groups: 196 | self.update(group) 197 | 198 | def step(self, closure=None): 199 | loss = self.optimizer.step(closure) 200 | for group in self.param_groups: 201 | if group["counter"] == 0: 202 | self.update(group) 203 | group["counter"] += 1 204 | if group["counter"] >= self.k: 205 | group["counter"] = 0 206 | return loss 207 | 208 | def state_dict(self): 209 | fast_state_dict = self.optimizer.state_dict() 210 | slow_state = { 211 | (id(k) if isinstance(k, torch.Tensor) else k): v 212 | for k, v in self.state.items() 213 | } 214 | fast_state = fast_state_dict["state"] 215 | param_groups = fast_state_dict["param_groups"] 216 | return { 217 | "fast_state": fast_state, 218 | "slow_state": slow_state, 219 | "param_groups": param_groups, 220 | } 221 | 222 | def load_state_dict(self, state_dict): 223 | slow_state_dict = { 224 | "state": state_dict["slow_state"], 225 | "param_groups": state_dict["param_groups"], 226 | } 227 | fast_state_dict = { 228 | "state": state_dict["fast_state"], 229 | "param_groups": state_dict["param_groups"], 230 | } 231 | super(Lookahead, self).load_state_dict(slow_state_dict) 232 | self.optimizer.load_state_dict(fast_state_dict) 233 | self.fast_state = 
self.optimizer.state 234 | 235 | def add_param_group(self, param_group): 236 | param_group["counter"] = 0 237 | self.optimizer.add_param_group(param_group) 238 | 239 | class CosineAnnealingWithRestartsLR(torch.optim.lr_scheduler._LRScheduler): 240 | 241 | r"""Set the learning rate of each parameter group using a cosine annealing 242 | schedule, where :math:`\eta_{max}` is set to the initial lr and 243 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 244 | .. math:: 245 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 246 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 247 | When last_epoch=-1, sets initial lr as lr. 248 | It has been proposed in 249 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. This implements 250 | the cosine annealing part of SGDR, the restarts and number of iterations multiplier. 251 | Args: 252 | optimizer (Optimizer): Wrapped optimizer. 253 | T_max (int): Maximum number of iterations. 254 | T_mult (float): Multiply T_max by this number after each restart. Default: 1. 255 | eta_min (float): Minimum learning rate. Default: 0. 256 | last_epoch (int): The index of last epoch. Default: -1. 257 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 258 | https://arxiv.org/abs/1608.03983 259 | """ 260 | 261 | def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, T_mult=1): 262 | self.T_max = T_max 263 | self.T_mult = T_mult 264 | self.restart_every = T_max 265 | self.eta_min = eta_min 266 | self.restarts = 0 267 | self.restarted_at = 0 268 | super().__init__(optimizer, last_epoch) 269 | 270 | def restart(self): 271 | self.restart_every *= self.T_mult 272 | self.restarted_at = self.last_epoch 273 | 274 | def cosine(self, base_lr): 275 | return self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.step_n / self.restart_every)) / 2 276 | 277 | @property 278 | def step_n(self): 279 | return self.last_epoch - self.restarted_at 280 | 281 | def get_lr(self): 282 | if self.step_n >= self.restart_every: 283 | self.restart() 284 | return [self.cosine(base_lr) for base_lr in self.base_lrs] 285 | 286 | -------------------------------------------------------------------------------- /merge_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | import glob 4 | import os 5 | from collections import Counter 6 | import json 7 | 8 | def merge(results, vote=True): 9 | r_dict = {} 10 | for r in results: 11 | with open(r, 'r') as f: 12 | f_csv = csv.reader(f) 13 | f_csv.__next__() 14 | for row in f_csv: 15 | if vote: 16 | if row[0] in r_dict: 17 | r_dict[row[0]].append(int(row[1])) 18 | else: 19 | r_dict[row[0]] = [int(row[1])] 20 | else: 21 | if row[0] in r_dict: 22 | r_dict[row[0]].append(json.loads(row[1])) 23 | else: 24 | r_dict[row[0]] = [] 25 | r_dict[row[0]].append(json.loads(row[1])) 26 | with open('results.csv', 'w', encoding='utf-8') as f: 27 | f_csv = csv.writer(f) 28 | header = ['filename', 'type'] 29 | f_csv.writerow(header) 30 | for key in r_dict.keys(): 31 | if vote: 32 | value = np.argmax(np.bincount(r_dict[key])) 33 | else: 34 | r_dict[key] = np.array(r_dict[key]) 35 | value = np.argmax(np.sum(r_dict[key], 0)) + 1 36 | f_csv.writerow([key, value]) 37 | 38 | print('merge finished') 39 | 40 | 41 | if __name__ == '__main__': 42 | results = [] 43 | res_dir = 'results' 44 | # res_dir = 'log' 45 | for res in glob.glob(res_dir+'/*.csv'): 46 | results.append(res) 47 | merge(results, False) 
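`merge()` above implements majority voting and logit-sum fusion; the weighted voting mentioned in the README is a small extension of the logit-sum branch. A hedged sketch with illustrative weights (e.g. proportional to each model's validation score):

```python
# Weighted voting: a per-model weight on each CSV's predictions before
# summing. Weights here are illustrative.
import numpy as np

def weighted_merge(logits_per_model, weights):
    """logits_per_model: list of (num_classes,) arrays for one image,
    one entry per model CSV; returns the 1-based predicted type."""
    stacked = np.stack(logits_per_model)                   # (n_models, n_classes)
    fused = (np.asarray(weights)[:, None] * stacked).sum(0)
    return int(np.argmax(fused)) + 1                       # +1 matches merge() above

print(weighted_merge([np.array([0.1, 0.7, 0.2]),
                      np.array([0.3, 0.4, 0.3])], weights=[0.6, 0.4]))  # -> 2
```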
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from torchvision import models
6 | import pretrainedmodels as pmodels
7 | from efficientnet_pytorch import EfficientNet
8 | from multigrain.lib import get_multigrain
9 | from fixres.pnasnet import pnasnet5large
10 | 
11 | 
12 | model_urls = {
13 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
19 |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
20 |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
21 |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
22 |     'ig_resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth',  # WSL (Instagram) weights; 'ig_' prefix avoids clobbering the torchvision key above
23 |     'ig_resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth',
24 |     'ig_resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth',
25 |     'ig_resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth',
26 |     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
27 |     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
28 |     'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
29 |     'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
30 |     'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
31 |     'inception_v3': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
32 |     'fixpnas': '/home/lzw/.cache/torch/checkpoints/PNASNet.pth'
33 | }
34 | 
35 | def resnet(num_classes=9, layers=101, state_dict=None):
36 |     if layers == 18:
37 |         model = models.resnet18()
38 |     elif layers == 34:
39 |         model = models.resnet34()
40 |     elif layers == 50:
41 |         model = models.resnet50()
42 |     elif layers == 101:
43 |         model = models.resnet101()
44 |     elif layers == 152:
45 |         model = models.resnet152()
46 | 
47 |     if state_dict is not None:
48 |         print('load_state_dict')
49 |         model.load_state_dict(state_dict)
50 | 
51 |     num_ftrs = model.fc.in_features
52 |     model.fc = nn.Linear(num_ftrs, num_classes)
53 | 
54 |     return model
55 | 
56 | def resnext(num_classes=9, layers=101, state_dict=None):
57 |     if layers == 50:
58 |         model = models.resnext50_32x4d()
59 |     elif layers == 101:
60 |         model = models.resnext101_32x8d()
61 | 
62 |     if state_dict is not None:
63 |         model.load_state_dict(state_dict)
64 | 
65 |     num_ftrs = model.fc.in_features
66 |     model.fc = nn.Linear(num_ftrs, num_classes)
67 | 
68 |     return model
69 | 
70 | def resnext_wsl(num_classes=9, bottleneck_width=8):
71 |     model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x'+str(bottleneck_width)+'d_wsl')
72 | 
73 |     num_ftrs = model.fc.in_features
74 |     model.fc = nn.Linear(num_ftrs, num_classes)
75 |     return model
76 | 
77 | def resnext_swsl(num_classes=9, layers=101, bottleneck_width=8):
78 |     model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext'+str(layers)+'_32x'+str(bottleneck_width)+'d_swsl')
79 | 
80 |     num_ftrs = model.fc.in_features
81 |     model.fc = nn.Linear(num_ftrs, num_classes)
82 |     return model
83 | 
84 | def vgg_bn(num_classes=9, layers=16, state_dict=None):
85 |     if layers == 16:
86 |         model = models.vgg16_bn()
87 |     elif layers == 19:
88 |         model = models.vgg19_bn()
89 | 
90 |     if state_dict is not None:
91 |         model.load_state_dict(state_dict)
92 | 
93 |     model.classifier[6] = nn.Linear(4096, num_classes)  # the final linear layer lives at classifier[6] in torchvision VGG
94 |     return model
95 | 
96 | def densenet(num_classes=9, layers=121, state_dict=None):
97 |     '''
98 |     layers: 121, 201, 161
99 |     '''
100 |     if layers == 121:
101 |         model = models.densenet121()
102 |     elif layers == 201:
103 |         model = models.densenet201()
104 |     elif layers == 161:
105 |         model = models.densenet161()
106 | 
107 |     if state_dict is not None:
108 |         model.load_state_dict(state_dict)
109 | 
110 |     num_ftrs = model.classifier.in_features
111 |     model.classifier = nn.Linear(num_ftrs, num_classes)
112 |     return model
113 | 
114 | def inception_v3(num_classes=9, layers=101, state_dict=None):
115 |     model = models.inception_v3()
116 |     if state_dict is not None:
117 |         model.load_state_dict(state_dict)
118 | 
119 |     aux_ftrs = model.AuxLogits.fc.in_features
120 |     model.AuxLogits.fc = nn.Linear(aux_ftrs, num_classes)
121 |     num_ftrs = model.fc.in_features
122 |     model.fc = nn.Linear(num_ftrs, num_classes)
123 |     return model
124 | 
125 | def dpn(num_classes=9, layers=92, pretrained=True):
126 |     model = torch.hub.load('rwightman/pytorch-dpn-pretrained', 'dpn'+str(layers), pretrained=pretrained)
127 | 
128 |     in_chs = model.classifier.in_channels
129 |     model.classifier = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)
130 |     return model
131 | 
132 | class EffNet(nn.Module):
133 |     def __init__(self, num_classes=9, layers=0, pretrained=False):
134 |         super(EffNet, self).__init__()
135 |         if pretrained:
136 |             self.model = EfficientNet.from_pretrained('efficientnet-b'+str(layers))
137 |         else:
138 |             self.model = EfficientNet.from_name('efficientnet-b'+str(layers))
139 | 
140 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
141 |         self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
142 | 
143 |         num_ftrs = self.model._fc.in_features
144 |         self._fc = nn.Sequential(
145 |             nn.BatchNorm1d(num_ftrs*2),
146 |             nn.Dropout(inplace=True),
147 |             nn.Linear(num_ftrs*2, num_ftrs, bias=False),
148 |             nn.ReLU(inplace=True),
149 |             nn.BatchNorm1d(num_ftrs),
150 |             nn.Dropout(inplace=True),
151 |             nn.Linear(num_ftrs, num_classes, bias=False)
152 |         )
153 | 
154 |     def forward(self, x):
155 |         x = self.model.extract_features(x)
156 |         avgfeature = torch.flatten(self.avgpool(x), 1)
157 |         maxfeature = torch.flatten(self.maxpool(x), 1)
158 |         x = torch.cat([avgfeature, maxfeature], 1)
159 |         return self._fc(x)
160 | 
161 |     def extract_features(self, x):
162 |         x = self.model.extract_features(x)
163 |         avgfeature = torch.flatten(self.avgpool(x), 1)
164 |         maxfeature = torch.flatten(self.maxpool(x), 1)
165 |         x = torch.cat([avgfeature, maxfeature], 1)
166 |         return x
167 | 
168 | def effnet(num_classes=9, layers=0, pretrained=False):
169 |     return EffNet(num_classes, layers, pretrained)
170 | 
171 | # def effnet(num_classes=9, layers=0, pretrained=False):
172 | #     if pretrained:
173 | #         model = EfficientNet.from_pretrained('efficientnet-b'+str(layers))
174 | #     else:
175 | #         model = EfficientNet.from_name('efficientnet-b'+str(layers))
176 | #     num_ftrs = model._fc.in_features
177 | 
178 | #     model._fc = nn.Linear(num_ftrs, 
num_classes) 179 | # return model 180 | 181 | def pnasnet_m(num_classes=9, layers=5, pretrained=False): 182 | model = get_multigrain(backbone='pnasnet5large', include_sampling=False, learn_p=False, p=1.7) 183 | if pretrained: 184 | model.load_state_dict( 185 | torch.load('/home/lzw/.cache/torch/checkpoints/pnasnet5large-finetune500.pth')['model_state']) 186 | num_ftrs = model.classifier.in_features 187 | model.classifier = nn.Linear(num_ftrs, num_classes) 188 | return model 189 | 190 | def senet_m(num_classes=9, layers=154, pretrained=False): 191 | model = get_multigrain(backbone='senet154', include_sampling=False, learn_p=False, p=1.6) 192 | if pretrained: 193 | model.load_state_dict( 194 | torch.load('/home/lzw/.cache/torch/checkpoints/senet154-finetune400.pth')['model_state']) 195 | num_ftrs = model.classifier.in_features 196 | model.classifier = nn.Linear(num_ftrs, num_classes) 197 | return model 198 | 199 | def cadene_model(num_classes=9, model_name='inceptionresnetv2'): 200 | model = pmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet') 201 | 202 | if model_name == 'inceptionresnetv2': 203 | model.avgpool_1a = nn.AdaptiveAvgPool2d((1, 1)) 204 | elif model_name[:5] in ['resne', 'senet', 'pnasn', 'nasne', 'polyn', 'se_re']: 205 | model.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 206 | 207 | num_ftrs = model.last_linear.in_features 208 | model.last_linear = nn.Linear(num_ftrs, num_classes) 209 | return model 210 | 211 | def fixpnas(num_classes=9, pretrained=False): 212 | model = pnasnet5large(pretrained=None) 213 | if pretrained: 214 | pretrained_dict=torch.load(model_urls['fixpnas'],map_location='cpu')['model'] 215 | model_dict = model.state_dict() 216 | for k in model_dict.keys(): 217 | if(('module.'+k) in pretrained_dict.keys()): 218 | model_dict[k]=pretrained_dict.get(('module.'+k)) 219 | model.load_state_dict(model_dict) 220 | else: 221 | model = pnasnet5large(pretrained=None) 222 | num_ftrs = model.last_linear.in_features 223 | model.last_linear = nn.Linear(num_ftrs, num_classes) 224 | return model 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /multigrain/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import get_transforms, transforms_list -------------------------------------------------------------------------------- /multigrain/augmentations/autoaugment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | #!/usr/bin/env python3 8 | 9 | from enum import Enum, auto 10 | from typing import Tuple, Any 11 | from abc import ABC, abstractmethod 12 | 13 | from PIL import Image, ImageEnhance, ImageOps 14 | import numpy as np 15 | import random 16 | 17 | # from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 18 | 19 | RGBColor = Tuple[int, int, int] 20 | MIDDLE_GRAY = (128, 128, 128) 21 | 22 | 23 | class ImageOp(Enum): 24 | SHEAR_X = auto() 25 | SHEAR_Y = auto() 26 | TRANSLATE_X = auto() 27 | TRANSLATE_Y = auto() 28 | ROTATE = auto() 29 | AUTO_CONTRAST = auto() 30 | INVERT = auto() 31 | EQUALIZE = auto() 32 | SOLARIZE = auto() 33 | POSTERIZE = auto() 34 | CONTRAST = auto() 35 | COLOR = auto() 36 | BRIGHTNESS = auto() 37 | SHARPNESS = auto() 38 | 39 | 40 | class AutoAugmentPolicy(ABC): 41 | 42 | def __init__(self) -> None: 43 | pass 44 | 45 | @abstractmethod 46 | def __call__(self, img: Any) -> Any: 47 | pass 48 | 49 | 50 | class ImageNetPolicy(AutoAugmentPolicy): 51 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 52 | 53 | Example: 54 | >>> policy = ImageNetPolicy() 55 | >>> transformed = policy(image) 56 | 57 | Example as a PyTorch Transform: 58 | >>> transform=transforms.Compose([ 59 | >>> transforms.Resize(256), 60 | >>> ImageNetPolicy(), 61 | >>> transforms.ToTensor()]) 62 | """ 63 | 64 | def __init__(self, fillcolor: RGBColor = MIDDLE_GRAY) -> None: 65 | self.policies = [ 66 | SubPolicy(ImageOp.POSTERIZE, 8, 0.4, ImageOp.ROTATE, 9, 0.6, fillcolor), 67 | SubPolicy( 68 | ImageOp.SOLARIZE, 5, 0.6, ImageOp.AUTO_CONTRAST, 5, 0.6, fillcolor 69 | ), 70 | SubPolicy(ImageOp.EQUALIZE, 8, 0.8, ImageOp.EQUALIZE, 3, 0.6, fillcolor), 71 | SubPolicy(ImageOp.POSTERIZE, 7, 0.6, ImageOp.POSTERIZE, 6, 0.6, fillcolor), 72 | SubPolicy(ImageOp.EQUALIZE, 7, 0.4, ImageOp.SOLARIZE, 4, 0.2, fillcolor), 73 | SubPolicy(ImageOp.EQUALIZE, 4, 0.4, ImageOp.ROTATE, 8, 0.8, fillcolor), 74 | SubPolicy(ImageOp.SOLARIZE, 3, 0.6, ImageOp.EQUALIZE, 7, 0.6, fillcolor), 75 | SubPolicy(ImageOp.POSTERIZE, 5, 0.8, ImageOp.EQUALIZE, 2, 1.0, fillcolor), 76 | SubPolicy(ImageOp.ROTATE, 3, 0.2, ImageOp.SOLARIZE, 8, 0.6, fillcolor), 77 | SubPolicy(ImageOp.EQUALIZE, 8, 0.6, ImageOp.POSTERIZE, 6, 0.4, fillcolor), 78 | SubPolicy(ImageOp.ROTATE, 8, 0.8, ImageOp.COLOR, 0, 0.4, fillcolor), 79 | SubPolicy(ImageOp.ROTATE, 9, 0.4, ImageOp.EQUALIZE, 2, 0.6, fillcolor), 80 | SubPolicy(ImageOp.EQUALIZE, 7, 0.0, ImageOp.EQUALIZE, 8, 0.8, fillcolor), 81 | SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.EQUALIZE, 8, 1.0, fillcolor), 82 | SubPolicy(ImageOp.COLOR, 4, 0.6, ImageOp.CONTRAST, 8, 1.0, fillcolor), 83 | SubPolicy(ImageOp.ROTATE, 8, 0.8, ImageOp.COLOR, 2, 1.0, fillcolor), 84 | SubPolicy(ImageOp.COLOR, 8, 0.8, ImageOp.SOLARIZE, 7, 0.8, fillcolor), 85 | SubPolicy(ImageOp.SHARPNESS, 7, 0.4, ImageOp.INVERT, 8, 0.6, fillcolor), 86 | SubPolicy(ImageOp.SHEAR_X, 5, 0.6, ImageOp.EQUALIZE, 9, 1.0, fillcolor), 87 | SubPolicy(ImageOp.COLOR, 0, 0.4, ImageOp.EQUALIZE, 3, 0.6, fillcolor), 88 | SubPolicy(ImageOp.EQUALIZE, 7, 0.4, ImageOp.SOLARIZE, 4, 0.2, fillcolor), 89 | SubPolicy( 90 | ImageOp.SOLARIZE, 5, 0.6, ImageOp.AUTO_CONTRAST, 5, 0.6, fillcolor 91 | ), 92 | SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.EQUALIZE, 8, 1.0, fillcolor), 93 | SubPolicy(ImageOp.COLOR, 4, 0.6, ImageOp.CONTRAST, 8, 1.0, fillcolor), 94 | ] 95 | 96 | def __call__(self, img: Any) -> Any: 97 | policy_idx = random.randint(0, len(self.policies) - 1) 98 | return self.policies[policy_idx](img) 99 | 100 | def __repr__(self) -> str: 101 | return 
"AutoAugment ImageNet Policy" 102 | 103 | 104 | class CIFAR10Policy(AutoAugmentPolicy): 105 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 106 | 107 | Example: 108 | >>> policy = CIFAR10Policy() 109 | >>> transformed = policy(image) 110 | 111 | Example as a PyTorch Transform: 112 | >>> transform=transforms.Compose([ 113 | >>> transforms.Resize(256), 114 | >>> CIFAR10Policy(), 115 | >>> transforms.ToTensor()]) 116 | """ 117 | 118 | def __init__(self, fillcolor: RGBColor = MIDDLE_GRAY) -> None: 119 | self.policies = [ 120 | SubPolicy(ImageOp.INVERT, 7, 0.1, ImageOp.CONTRAST, 6, 0.2, fillcolor), 121 | SubPolicy(ImageOp.ROTATE, 2, 0.7, ImageOp.TRANSLATE_X, 9, 0.3, fillcolor), 122 | SubPolicy(ImageOp.SHARPNESS, 1, 0.8, ImageOp.SHARPNESS, 3, 0.9, fillcolor), 123 | SubPolicy(ImageOp.SHEAR_Y, 8, 0.5, ImageOp.TRANSLATE_Y, 9, 0.7, fillcolor), 124 | SubPolicy( 125 | ImageOp.AUTO_CONTRAST, 8, 0.5, ImageOp.EQUALIZE, 2, 0.9, fillcolor 126 | ), 127 | SubPolicy(ImageOp.SHEAR_Y, 7, 0.2, ImageOp.POSTERIZE, 7, 0.3, fillcolor), 128 | SubPolicy(ImageOp.COLOR, 3, 0.4, ImageOp.BRIGHTNESS, 7, 0.6, fillcolor), 129 | SubPolicy(ImageOp.SHARPNESS, 9, 0.3, ImageOp.BRIGHTNESS, 9, 0.7, fillcolor), 130 | SubPolicy(ImageOp.EQUALIZE, 5, 0.6, ImageOp.EQUALIZE, 1, 0.5, fillcolor), 131 | SubPolicy(ImageOp.CONTRAST, 7, 0.6, ImageOp.SHARPNESS, 5, 0.6, fillcolor), 132 | SubPolicy(ImageOp.COLOR, 7, 0.7, ImageOp.TRANSLATE_X, 8, 0.5, fillcolor), 133 | SubPolicy( 134 | ImageOp.EQUALIZE, 7, 0.3, ImageOp.AUTO_CONTRAST, 8, 0.4, fillcolor 135 | ), 136 | SubPolicy( 137 | ImageOp.TRANSLATE_Y, 3, 0.4, ImageOp.SHARPNESS, 6, 0.2, fillcolor 138 | ), 139 | SubPolicy(ImageOp.BRIGHTNESS, 6, 0.9, ImageOp.COLOR, 8, 0.2, fillcolor), 140 | SubPolicy(ImageOp.SOLARIZE, 2, 0.5, ImageOp.INVERT, 3, 0.0, fillcolor), 141 | SubPolicy( 142 | ImageOp.EQUALIZE, 0, 0.2, ImageOp.AUTO_CONTRAST, 0, 0.6, fillcolor 143 | ), 144 | SubPolicy(ImageOp.EQUALIZE, 8, 0.2, ImageOp.EQUALIZE, 4, 0.8, fillcolor), 145 | SubPolicy(ImageOp.COLOR, 9, 0.9, ImageOp.EQUALIZE, 6, 0.6, fillcolor), 146 | SubPolicy( 147 | ImageOp.AUTO_CONTRAST, 4, 0.8, ImageOp.SOLARIZE, 8, 0.2, fillcolor 148 | ), 149 | SubPolicy(ImageOp.BRIGHTNESS, 3, 0.1, ImageOp.COLOR, 0, 0.7, fillcolor), 150 | SubPolicy( 151 | ImageOp.SOLARIZE, 5, 0.4, ImageOp.AUTO_CONTRAST, 3, 0.9, fillcolor 152 | ), 153 | SubPolicy( 154 | ImageOp.TRANSLATE_Y, 9, 0.9, ImageOp.TRANSLATE_Y, 9, 0.7, fillcolor 155 | ), 156 | SubPolicy( 157 | ImageOp.AUTO_CONTRAST, 2, 0.9, ImageOp.SOLARIZE, 3, 0.8, fillcolor 158 | ), 159 | SubPolicy(ImageOp.EQUALIZE, 8, 0.8, ImageOp.INVERT, 3, 0.1, fillcolor), 160 | SubPolicy( 161 | ImageOp.TRANSLATE_Y, 9, 0.7, ImageOp.AUTO_CONTRAST, 1, 0.9, fillcolor 162 | ), 163 | ] 164 | 165 | def __call__(self, img: Any) -> Any: 166 | policy_idx = random.randint(0, len(self.policies) - 1) 167 | return self.policies[policy_idx](img) 168 | 169 | def __repr__(self) -> str: 170 | return "AutoAugment CIFAR10 Policy" 171 | 172 | 173 | class SVHNPolicy(AutoAugmentPolicy): 174 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
175 | 
176 |     Example:
177 |         >>> policy = SVHNPolicy()
178 |         >>> transformed = policy(image)
179 | 
180 |     Example as a PyTorch Transform:
181 |         >>> transform=transforms.Compose([
182 |         >>>     transforms.Resize(256),
183 |         >>>     SVHNPolicy(),
184 |         >>>     transforms.ToTensor()])
185 |     """
186 | 
187 |     def __init__(self, fillcolor: RGBColor = MIDDLE_GRAY) -> None:
188 |         self.policies = [
189 |             SubPolicy(ImageOp.SHEAR_X, 4, 0.9, ImageOp.INVERT, 3, 0.2, fillcolor),
190 |             SubPolicy(ImageOp.SHEAR_Y, 8, 0.9, ImageOp.INVERT, 5, 0.7, fillcolor),
191 |             SubPolicy(ImageOp.EQUALIZE, 5, 0.6, ImageOp.SOLARIZE, 6, 0.6, fillcolor),
192 |             SubPolicy(ImageOp.INVERT, 3, 0.9, ImageOp.EQUALIZE, 3, 0.6, fillcolor),
193 |             SubPolicy(ImageOp.EQUALIZE, 1, 0.6, ImageOp.ROTATE, 3, 0.9, fillcolor),
194 |             SubPolicy(
195 |                 ImageOp.SHEAR_X, 4, 0.9, ImageOp.AUTO_CONTRAST, 3, 0.8, fillcolor
196 |             ),
197 |             SubPolicy(ImageOp.SHEAR_Y, 8, 0.9, ImageOp.INVERT, 5, 0.4, fillcolor),
198 |             SubPolicy(ImageOp.SHEAR_Y, 5, 0.9, ImageOp.SOLARIZE, 6, 0.2, fillcolor),
199 |             SubPolicy(ImageOp.INVERT, 6, 0.9, ImageOp.AUTO_CONTRAST, 1, 0.8, fillcolor),
200 |             SubPolicy(ImageOp.EQUALIZE, 3, 0.6, ImageOp.ROTATE, 3, 0.9, fillcolor),
201 |             SubPolicy(ImageOp.SHEAR_X, 4, 0.9, ImageOp.SOLARIZE, 3, 0.3, fillcolor),
202 |             SubPolicy(ImageOp.SHEAR_Y, 8, 0.8, ImageOp.INVERT, 4, 0.7, fillcolor),
203 |             SubPolicy(ImageOp.EQUALIZE, 5, 0.9, ImageOp.TRANSLATE_Y, 6, 0.6, fillcolor),
204 |             SubPolicy(ImageOp.INVERT, 4, 0.9, ImageOp.EQUALIZE, 7, 0.6, fillcolor),
205 |             SubPolicy(ImageOp.CONTRAST, 3, 0.3, ImageOp.ROTATE, 4, 0.8, fillcolor),
206 |             SubPolicy(ImageOp.INVERT, 5, 0.8, ImageOp.TRANSLATE_Y, 2, 0.0, fillcolor),
207 |             SubPolicy(ImageOp.SHEAR_Y, 6, 0.7, ImageOp.SOLARIZE, 8, 0.4, fillcolor),
208 |             SubPolicy(ImageOp.INVERT, 4, 0.6, ImageOp.ROTATE, 4, 0.8, fillcolor),
209 |             SubPolicy(ImageOp.SHEAR_Y, 7, 0.3, ImageOp.TRANSLATE_X, 3, 0.9, fillcolor),
210 |             SubPolicy(ImageOp.SHEAR_X, 6, 0.1, ImageOp.INVERT, 5, 0.6, fillcolor),
211 |             SubPolicy(ImageOp.SOLARIZE, 2, 0.7, ImageOp.TRANSLATE_Y, 7, 0.6, fillcolor),
212 |             SubPolicy(ImageOp.SHEAR_Y, 4, 0.8, ImageOp.INVERT, 8, 0.8, fillcolor),
213 |             SubPolicy(ImageOp.SHEAR_X, 9, 0.7, ImageOp.TRANSLATE_Y, 3, 0.8, fillcolor),
214 |             SubPolicy(
215 |                 ImageOp.SHEAR_Y, 5, 0.8, ImageOp.AUTO_CONTRAST, 3, 0.7, fillcolor
216 |             ),
217 |             SubPolicy(ImageOp.SHEAR_X, 2, 0.7, ImageOp.INVERT, 5, 0.1, fillcolor),
218 |         ]
219 | 
220 |     def __call__(self, img: Any) -> Any:
221 |         policy_idx = random.randint(0, len(self.policies) - 1)
222 |         return self.policies[policy_idx](img)
223 | 
224 |     def __repr__(self) -> str:
225 |         return "AutoAugment SVHN Policy"
226 | 
227 | 
228 | class SubPolicy(object):
229 | 
230 |     ranges = {
231 |         ImageOp.SHEAR_X: np.linspace(0, 0.3, 10),
232 |         ImageOp.SHEAR_Y: np.linspace(0, 0.3, 10),
233 |         ImageOp.TRANSLATE_X: np.linspace(0, 150 / 331, 10),
234 |         ImageOp.TRANSLATE_Y: np.linspace(0, 150 / 331, 10),
235 |         ImageOp.ROTATE: np.linspace(0, 30, 10),
236 |         ImageOp.COLOR: np.linspace(0.0, 0.9, 10),
237 |         ImageOp.POSTERIZE: np.round(np.linspace(8, 4, 10), 0).astype(int),  # np.int was removed in NumPy >= 1.24
238 |         ImageOp.SOLARIZE: np.linspace(256, 0, 10),
239 |         ImageOp.CONTRAST: np.linspace(0.0, 0.9, 10),
240 |         ImageOp.SHARPNESS: np.linspace(0.0, 0.9, 10),
241 |         ImageOp.BRIGHTNESS: np.linspace(0.0, 0.9, 10),
242 |         ImageOp.AUTO_CONTRAST: [0] * 10,
243 |         ImageOp.EQUALIZE: [0] * 10,
244 |         ImageOp.INVERT: [0] * 10,
245 |     }
246 | 
247 |     def __init__(
248 |         self,
249 |         operation1: ImageOp,
250 |         magnitude_idx1: int,
251 |         p1: float,
252 |         operation2: ImageOp,
253 |         magnitude_idx2: 
int, 254 | p2: float, 255 | fillcolor: RGBColor = MIDDLE_GRAY, 256 | ) -> None: 257 | 258 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 259 | def rotate_with_fill(img: Any, magnitude: int) -> Any: 260 | rot = img.convert("RGBA").rotate(magnitude) 261 | return Image.composite( 262 | rot, Image.new("RGBA", rot.size, (128,) * 4), rot 263 | ).convert(img.mode) 264 | 265 | func = { 266 | ImageOp.SHEAR_X: lambda img, magnitude: img.transform( 267 | img.size, 268 | Image.AFFINE, 269 | (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 270 | Image.BICUBIC, 271 | fillcolor=fillcolor, 272 | ), 273 | ImageOp.SHEAR_Y: lambda img, magnitude: img.transform( 274 | img.size, 275 | Image.AFFINE, 276 | (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 277 | Image.BICUBIC, 278 | fillcolor=fillcolor, 279 | ), 280 | ImageOp.TRANSLATE_X: lambda img, magnitude: img.transform( 281 | img.size, 282 | Image.AFFINE, 283 | (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 284 | fillcolor=fillcolor, 285 | ), 286 | ImageOp.TRANSLATE_Y: lambda img, magnitude: img.transform( 287 | img.size, 288 | Image.AFFINE, 289 | (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 290 | fillcolor=fillcolor, 291 | ), 292 | ImageOp.ROTATE: lambda img, magnitude: rotate_with_fill(img, magnitude), 293 | ImageOp.COLOR: lambda img, magnitude: ImageEnhance.Color(img).enhance( 294 | 1 + magnitude * random.choice([-1, 1]) 295 | ), 296 | ImageOp.POSTERIZE: lambda img, magnitude: ImageOps.posterize( 297 | img, magnitude 298 | ), 299 | ImageOp.SOLARIZE: lambda img, magnitude: ImageOps.solarize(img, magnitude), 300 | ImageOp.CONTRAST: lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 301 | 1 + magnitude * random.choice([-1, 1]) 302 | ), 303 | ImageOp.SHARPNESS: lambda img, magnitude: ImageEnhance.Sharpness( 304 | img 305 | ).enhance(1 + magnitude * random.choice([-1, 1])), 306 | ImageOp.BRIGHTNESS: lambda img, magnitude: ImageEnhance.Brightness( 307 | img 308 | ).enhance(1 + magnitude * random.choice([-1, 1])), 309 | ImageOp.AUTO_CONTRAST: lambda img, magnitude: ImageOps.autocontrast(img), 310 | ImageOp.EQUALIZE: lambda img, magnitude: ImageOps.equalize(img), 311 | ImageOp.INVERT: lambda img, magnitude: ImageOps.invert(img), 312 | } 313 | 314 | self.operation1 = func[operation1] 315 | self.magnitude1 = self.ranges[operation1][magnitude_idx1] 316 | self.p1 = p1 317 | self.operation2 = func[operation2] 318 | self.magnitude2 = self.ranges[operation2][magnitude_idx2] 319 | self.p2 = p2 320 | 321 | def __call__(self, img: Any) -> Any: 322 | if random.random() < self.p1: 323 | img = self.operation1(img, self.magnitude1) 324 | if random.random() < self.p2: 325 | img = self.operation2(img, self.magnitude2) 326 | return img 327 | 328 | -------------------------------------------------------------------------------- /multigrain/augmentations/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 | import torch
8 | import torchvision.transforms.functional as F
9 | from torchvision import transforms
10 | 
11 | from multigrain.datasets import IN1K
12 | from .autoaugment import ImageNetPolicy
13 | 
14 | 
15 | class Resize(transforms.Resize):
16 |     """
17 |     Resize with a ``largest=False'' argument
18 |     allowing images to be resized to a common largest side without cropping
19 |     """
20 | 
21 | 
22 |     def __init__(self, size, largest=False, **kwargs):
23 |         super().__init__(size, **kwargs)
24 |         self.largest = largest
25 | 
26 |     @staticmethod
27 |     def target_size(w, h, size, largest=False):
28 |         if (h < w) == largest:  # match `size` to the largest or the smallest side
29 |             w, h = size, int(size * h / w)
30 |         else:
31 |             w, h = int(size * w / h), size
32 |         size = (h, w)
33 |         return size
34 | 
35 |     def __call__(self, img):
36 |         size = self.size
37 |         w, h = img.size
38 |         target_size = self.target_size(w, h, size, self.largest)
39 |         return F.resize(img, target_size, self.interpolation)
40 | 
41 |     def __repr__(self):
42 |         r = super().__repr__()
43 |         return r[:-1] + ', largest={})'.format(self.largest)
44 | 
45 | 
46 | class Lighting(object):
47 |     """
48 |     PCA jitter transform on tensors
49 |     """
50 |     def __init__(self, alpha_std, eig_val, eig_vec):
51 |         self.alpha_std = alpha_std
52 |         self.eig_val = torch.as_tensor(eig_val, dtype=torch.float).view(1, 3)
53 |         self.eig_vec = torch.as_tensor(eig_vec, dtype=torch.float)
54 | 
55 |     def __call__(self, data):
56 |         if self.alpha_std == 0:
57 |             return data
58 |         alpha = torch.empty(1, 3).normal_(0, self.alpha_std)
59 |         rgb = ((self.eig_vec * alpha) * self.eig_val).sum(1)
60 |         data += rgb.view(3, 1, 1)
61 |         data /= 1. + self.alpha_std
62 |         return data
63 | 
64 | 
65 | class Bound(object):
66 |     def __init__(self, lower=0., upper=1.):
67 |         self.lower = lower
68 |         self.upper = upper
69 | 
70 |     def __call__(self, data):
71 |         return data.clamp_(self.lower, self.upper)
72 | 
73 | 
74 | def get_transforms(Dataset=IN1K, input_size=224, kind='full', crop=True, need=('train', 'val'), backbone=None):
75 |     mean, std = Dataset.MEAN, Dataset.STD
76 |     if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']:
77 |         mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
78 |     transformations = {}
79 |     if 'train' in need:
80 |         if kind == 'torch':
81 |             transformations['train'] = transforms.Compose([
82 |                 transforms.RandomResizedCrop(input_size),
83 |                 transforms.RandomHorizontalFlip(),
84 |                 transforms.ToTensor(),
85 |                 transforms.Normalize(mean, std),
86 |             ])
87 |         elif kind == 'full':
88 |             transformations['train'] = transforms.Compose([
89 |                 transforms.RandomResizedCrop(input_size),
90 |                 transforms.RandomHorizontalFlip(),
91 |                 transforms.ColorJitter(0.3, 0.3, 0.3),
92 |                 transforms.ToTensor(),
93 |                 Lighting(0.1, Dataset.EIG_VALS, Dataset.EIG_VECS),
94 |                 Bound(0., 1.),
95 |                 transforms.Normalize(mean, std),
96 |             ])
97 |         elif kind == 'senet':
98 |             transformations['train'] = transforms.Compose([
99 |                 transforms.RandomResizedCrop(input_size),
100 |                 transforms.RandomHorizontalFlip(),
101 |                 transforms.ColorJitter(0.2, 0.2, 0.2),
102 |                 transforms.ToTensor(),
103 |                 Lighting(0.1, Dataset.EIG_VALS, Dataset.EIG_VECS),
104 |                 Bound(0., 1.),
105 |                 transforms.Normalize(mean, std),
106 |             ])
107 |         elif kind == 'AA':
108 |             transformations['train'] = transforms.Compose([
109 |                 transforms.RandomResizedCrop(input_size),
110 |                 transforms.RandomHorizontalFlip(),
111 |                 ImageNetPolicy(),
112 |                 transforms.ToTensor(),
113 |                 transforms.Normalize(mean, std),
114 |             ])
115 |         else:
116 |             raise ValueError('Transforms kind {} unknown'.format(kind))
117 |     if 'val' in need:
118 |         if crop:
119 |             transformations['val'] = transforms.Compose(
120 |                 [Resize(int((256 / 224) * input_size)),  # to maintain same ratio w.r.t. 224 images
121 |                  transforms.CenterCrop(input_size),
122 |                  transforms.ToTensor(),
123 |                  transforms.Normalize(mean, std)])
124 |         else:
125 |             transformations['val'] = transforms.Compose(
126 |                 [Resize(input_size, largest=True),  # to maintain same ratio w.r.t. 224 images
127 |                  transforms.ToTensor(),
128 |                  transforms.Normalize(mean, std)])
129 |     return transformations
130 | 
131 | transforms_list = ['torch', 'full', 'senet', 'AA']
--------------------------------------------------------------------------------
/multigrain/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbone import BackBone, backbone_list
2 | 
3 | 
--------------------------------------------------------------------------------
/multigrain/backbones/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the CC-by-NC license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import torch
8 | from torch import nn
9 | import torch.utils.checkpoint
10 | import multigrain
11 | from torchvision.models import resnet18, resnet50, resnet101, resnet152
12 | from pretrainedmodels.models import senet154
13 | from .pnasnet import pnasnet5large
14 | from .nasnet_mobile import nasnetamobile
15 | from collections import OrderedDict as OD
16 | from multigrain.modules.layers import Layer
17 | # torch.utils.checkpoint.preserve_rng_state=False
18 | 
19 | 
20 | backbone_list = ['resnet18', 'resnet50', 'resnet101', 'resnet152', 'senet154', 'pnasnet5large', 'nasnetamobile']
21 | 
22 | 
23 | class Features(nn.Module):
24 |     def __init__(self, net):
25 |         super().__init__()
26 |         self.base_net = net
27 | 
28 |     def forward(self, x):
29 |         return self.base_net.features(x)
30 | 
31 | 
32 | class BackBone(nn.Module):
33 |     """
34 |     Base networks with output dict and standardized structure
35 |     Returns embedding, classifier_output
36 |     """
37 |     def __init__(self, net, **kwargs):
38 |         super().__init__()
39 |         if isinstance(net, str):
40 |             if net not in backbone_list:
41 |                 raise ValueError('Available backbones:', ', '.join(backbone_list))
42 |             net = multigrain.backbones.backbone.__dict__[net](**kwargs)
43 |         children = list(net.named_children())
44 |         self.pre_classifier = None
45 |         if type(net).__name__ == 'ResNet':
46 |             self.features = nn.Sequential(OD(children[:-2]))
47 |             self.pool = children[-2][1]
48 |             self.classifier = children[-1][1]
49 |         elif type(net).__name__ == 'SENet':
50 |             self.features = nn.Sequential(OD(children[:-3]))
51 |             self.pool = children[-3][1]
52 |             self.pre_classifier = children[-2][1]
53 |             self.classifier = children[-1][1]
54 |         elif type(net).__name__ in ['PNASNet5Large', 'NASNetAMobile']:
55 |             self.features = nn.Sequential(Features(net), nn.ReLU())
56 |             self.pool = children[-3][1]
57 |             self.pre_classifier = children[-2][1]
58 |             self.classifier = children[-1][1]
59 |         else:
60 |             raise NotImplementedError('Unknown base net', type(net).__name__)
61 |         self.whitening = None
62 | 
63 |     def forward(self, input):
64 |         output = {}
65 |         if isinstance(input, list):
66 |             # for lists of tensors of unequal input size
67 |             features = map(self.features, [i.unsqueeze(0) for i in input])
68 |             embedding = torch.cat([self.pool(f) for f in features], 0)
69 |         else:
70 |             features = 
self.features(input) 71 | embedding = self.pool(features) 72 | if self.whitening is not None: 73 | embedding = self.whitening(embedding) 74 | 75 | classifier_input = embedding 76 | if self.pre_classifier is not None: 77 | classifier_input = self.pre_classifier(classifier_input) 78 | classifier_output = self.classifier(classifier_input) 79 | return embedding, classifier_output 80 | 81 | -------------------------------------------------------------------------------- /multigrain/backbones/nasnet_mobile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Monkey-patches https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/nasnet_mobile.py 9 | The source implementation cannot be applied to images of arbitrary shape 10 | Here we add cropping operations before additions and concatenations to address this. 11 | """ 12 | from .pnasnet import shrink_sum, shrink_cat 13 | import pretrainedmodels 14 | 15 | 16 | def CellStem0_forward(self, x): 17 | x1 = self.conv_1x1(x) 18 | 19 | x_comb_iter_0_left = self.comb_iter_0_left(x1) 20 | x_comb_iter_0_right = self.comb_iter_0_right(x) 21 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 22 | 23 | x_comb_iter_1_left = self.comb_iter_1_left(x1) 24 | x_comb_iter_1_right = self.comb_iter_1_right(x) 25 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 26 | 27 | x_comb_iter_2_left = self.comb_iter_2_left(x1) 28 | x_comb_iter_2_right = self.comb_iter_2_right(x) 29 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_comb_iter_2_right) 30 | 31 | x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) 32 | x_comb_iter_3 = shrink_sum(x_comb_iter_3_right, x_comb_iter_1) 33 | 34 | x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) 35 | x_comb_iter_4_right = self.comb_iter_4_right(x1) 36 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_comb_iter_4_right) 37 | 38 | x_out = shrink_cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) 39 | return x_out 40 | 41 | 42 | def CellStem1_forward(self, x_conv0, x_stem_0): 43 | x_left = self.conv_1x1(x_stem_0) 44 | 45 | x_relu = self.relu(x_conv0) 46 | # path 1 47 | x_path1 = self.path_1(x_relu) 48 | # path 2 49 | x_path2 = self.path_2.pad(x_relu) 50 | x_path2 = x_path2[:, :, 1:, 1:] 51 | x_path2 = self.path_2.avgpool(x_path2) 52 | x_path2 = self.path_2.conv(x_path2) 53 | # final path 54 | x_right = self.final_path_bn(shrink_cat([x_path1, x_path2], 1)) 55 | 56 | x_comb_iter_0_left = self.comb_iter_0_left(x_left) 57 | x_comb_iter_0_right = self.comb_iter_0_right(x_right) 58 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 59 | 60 | x_comb_iter_1_left = self.comb_iter_1_left(x_left) 61 | x_comb_iter_1_right = self.comb_iter_1_right(x_right) 62 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 63 | 64 | x_comb_iter_2_left = self.comb_iter_2_left(x_left) 65 | x_comb_iter_2_right = self.comb_iter_2_right(x_right) 66 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_comb_iter_2_right) 67 | 68 | x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) 69 | x_comb_iter_3 = shrink_sum(x_comb_iter_3_right, x_comb_iter_1) 70 | 71 | x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) 72 | x_comb_iter_4_right = 
self.comb_iter_4_right(x_left) 73 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_comb_iter_4_right) 74 | 75 | x_out = shrink_cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) 76 | return x_out 77 | 78 | 79 | def FirstCell_forward(self, x, x_prev): 80 | x_relu = self.relu(x_prev) 81 | # path 1 82 | x_path1 = self.path_1(x_relu) 83 | # path 2 84 | x_path2 = self.path_2.pad(x_relu) 85 | x_path2 = x_path2[:, :, 1:, 1:] 86 | x_path2 = self.path_2.avgpool(x_path2) 87 | x_path2 = self.path_2.conv(x_path2) 88 | # final path 89 | x_left = self.final_path_bn(shrink_cat([x_path1, x_path2], 1)) 90 | 91 | x_right = self.conv_1x1(x) 92 | 93 | x_comb_iter_0_left = self.comb_iter_0_left(x_right) 94 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 95 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 96 | 97 | x_comb_iter_1_left = self.comb_iter_1_left(x_left) 98 | x_comb_iter_1_right = self.comb_iter_1_right(x_left) 99 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 100 | 101 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 102 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_left) 103 | 104 | x_comb_iter_3_left = self.comb_iter_3_left(x_left) 105 | x_comb_iter_3_right = self.comb_iter_3_right(x_left) 106 | x_comb_iter_3 = shrink_sum(x_comb_iter_3_left, x_comb_iter_3_right) 107 | 108 | x_comb_iter_4_left = self.comb_iter_4_left(x_right) 109 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_right) 110 | 111 | x_out = shrink_cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) 112 | return x_out 113 | 114 | 115 | def NormalCell_forward(self, x, x_prev): 116 | x_left = self.conv_prev_1x1(x_prev) 117 | x_right = self.conv_1x1(x) 118 | 119 | x_comb_iter_0_left = self.comb_iter_0_left(x_right) 120 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 121 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 122 | 123 | x_comb_iter_1_left = self.comb_iter_1_left(x_left) 124 | x_comb_iter_1_right = self.comb_iter_1_right(x_left) 125 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 126 | 127 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 128 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_left) 129 | 130 | x_comb_iter_3_left = self.comb_iter_3_left(x_left) 131 | x_comb_iter_3_right = self.comb_iter_3_right(x_left) 132 | x_comb_iter_3 = shrink_sum(x_comb_iter_3_left, x_comb_iter_3_right) 133 | 134 | x_comb_iter_4_left = self.comb_iter_4_left(x_right) 135 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_right) 136 | 137 | x_out = shrink_cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) 138 | return x_out 139 | 140 | 141 | def ReductionCell0_forward(self, x, x_prev): 142 | x_left = self.conv_prev_1x1(x_prev) 143 | x_right = self.conv_1x1(x) 144 | 145 | x_comb_iter_0_left = self.comb_iter_0_left(x_right) 146 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 147 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 148 | 149 | x_comb_iter_1_left = self.comb_iter_1_left(x_right) 150 | x_comb_iter_1_right = self.comb_iter_1_right(x_left) 151 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 152 | 153 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 154 | x_comb_iter_2_right = self.comb_iter_2_right(x_left) 155 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_comb_iter_2_right) 156 | 157 | x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) 158 | x_comb_iter_3 = 
shrink_sum(x_comb_iter_3_right, x_comb_iter_1) 159 | 160 | x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) 161 | x_comb_iter_4_right = self.comb_iter_4_right(x_right) 162 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_comb_iter_4_right) 163 | 164 | x_out = shrink_cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) 165 | return x_out 166 | 167 | 168 | def ReductionCell1_forward(self, x, x_prev): 169 | x_left = self.conv_prev_1x1(x_prev) 170 | x_right = self.conv_1x1(x) 171 | 172 | x_comb_iter_0_left = self.comb_iter_0_left(x_right) 173 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 174 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 175 | 176 | x_comb_iter_1_left = self.comb_iter_1_left(x_right) 177 | x_comb_iter_1_right = self.comb_iter_1_right(x_left) 178 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 179 | 180 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 181 | x_comb_iter_2_right = self.comb_iter_2_right(x_left) 182 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_comb_iter_2_right) 183 | 184 | x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) 185 | x_comb_iter_3 = shrink_sum(x_comb_iter_3_right, x_comb_iter_1) 186 | 187 | x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) 188 | x_comb_iter_4_right = self.comb_iter_4_right(x_right) 189 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_comb_iter_4_right) 190 | 191 | x_out = shrink_cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) 192 | return x_out 193 | 194 | 195 | def nasnetamobile(*args, **kwargs): 196 | pretrainedmodels.models.nasnet_mobile.CellStem0.forward = CellStem0_forward 197 | pretrainedmodels.models.nasnet_mobile.CellStem1.forward = CellStem1_forward 198 | pretrainedmodels.models.nasnet_mobile.FirstCell.forward = FirstCell_forward 199 | pretrainedmodels.models.nasnet_mobile.NormalCell.forward = NormalCell_forward 200 | pretrainedmodels.models.nasnet_mobile.ReductionCell0.forward = ReductionCell0_forward 201 | pretrainedmodels.models.nasnet_mobile.ReductionCell1.forward = ReductionCell1_forward 202 | model = pretrainedmodels.models.nasnetamobile(*args, **kwargs) 203 | return model 204 | -------------------------------------------------------------------------------- /multigrain/backbones/pnasnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Monkey-patches https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py 9 | The source implementation cannot be applied to images of arbitrary shape 10 | Here we add cropping operations before additions and concatenations to address this. 
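For example, if one branch of a cell yields a 15x15 feature map while a parallel branch yields 16x16 (odd input sizes round differently through stride-2 paths), shrink_sum below first crops both to the common 15x15 shape before adding, and shrink_cat crops every dimension except the concatenation one.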
11 | """ 12 | import pretrainedmodels 13 | 14 | import torch 15 | 16 | 17 | def equal_except(a, b, avoid=None): 18 | for i, (ai, bi) in enumerate(zip(a, b)): 19 | if ai != bi and (avoid is None or i != avoid): 20 | return False 21 | return True 22 | 23 | 24 | def shrink_common(*tensors, avoid=None): 25 | sizes = [tuple(t.size()) for t in tensors] 26 | st = tuple(min(*dims) for dims in zip(*sizes)) 27 | out_tensors = [] 28 | for t, s in zip(tensors, sizes): 29 | if not equal_except(s, st, avoid): 30 | dest_size = list(st) 31 | if avoid is not None: 32 | dest_size[avoid] = s[avoid] 33 | t = t.__getitem__(list(slice(si) for si in dest_size)) 34 | out_tensors.append(t) 35 | return out_tensors 36 | 37 | 38 | def shrink_sum(*tensors): 39 | tensors = shrink_common(*tensors) 40 | return sum(tensors) 41 | 42 | def shrink_cat(tensors, dim=1): 43 | tensors = shrink_common(*tensors, avoid=dim) 44 | return torch.cat(tensors, dim=1) 45 | 46 | 47 | def cell_forward(self, x_left, x_right): 48 | x_comb_iter_0_left = self.comb_iter_0_left(x_left) 49 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 50 | x_comb_iter_0 = shrink_sum(x_comb_iter_0_left, x_comb_iter_0_right) 51 | 52 | x_comb_iter_1_left = self.comb_iter_1_left(x_right) 53 | x_comb_iter_1_right = self.comb_iter_1_right(x_right) 54 | x_comb_iter_1 = shrink_sum(x_comb_iter_1_left, x_comb_iter_1_right) 55 | 56 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 57 | x_comb_iter_2_right = self.comb_iter_2_right(x_right) 58 | x_comb_iter_2 = shrink_sum(x_comb_iter_2_left, x_comb_iter_2_right) 59 | 60 | x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) 61 | x_comb_iter_3_right = self.comb_iter_3_right(x_right) 62 | x_comb_iter_3 = shrink_sum(x_comb_iter_3_left, x_comb_iter_3_right) 63 | 64 | x_comb_iter_4_left = self.comb_iter_4_left(x_left) 65 | if self.comb_iter_4_right: 66 | x_comb_iter_4_right = self.comb_iter_4_right(x_right) 67 | else: 68 | x_comb_iter_4_right = x_right 69 | x_comb_iter_4 = shrink_sum(x_comb_iter_4_left, x_comb_iter_4_right) 70 | 71 | x_out = shrink_cat( 72 | [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, 73 | x_comb_iter_4], 1) 74 | return x_out 75 | 76 | 77 | def pnasnet5large(*args, num_classes=1000, **kwargs): 78 | pretrainedmodels.models.pnasnet.CellBase.cell_forward = cell_forward 79 | model = pretrainedmodels.models.pnasnet5large(*args, num_classes=num_classes, **kwargs) 80 | return model 81 | -------------------------------------------------------------------------------- /multigrain/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .id_dataset import IdDataset 2 | from .imagenet import IN1K 3 | from .loader import loader as default_loader 4 | from .loader import preloader 5 | from .list_dataset import ListDataset 6 | 7 | -------------------------------------------------------------------------------- /multigrain/datasets/holidays-rotate.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "100600.jpg": 90, 3 | "100700.jpg": 270, 4 | "100800.jpg": 270, 5 | "100801.jpg": 270, 6 | "100900.jpg": 270, 7 | "100902.jpg": 270, 8 | "101000.jpg": 90, 9 | "101001.jpg": 90, 10 | "101100.jpg": 90, 11 | "101200.jpg": 90, 12 | "101201.jpg": 90, 13 | "101301.jpg": 90, 14 | "102200.jpg": 90, 15 | "102201.jpg": 90, 16 | "103200.jpg": 90, 17 | "103201.jpg": 90, 18 | "103202.jpg": 90, 19 | "109301.jpg": 90, 20 | "110100.jpg": 90, 21 | "110101.jpg": 90, 22 | "110800.jpg": 90, 23 | "110802.jpg": 90, 24 | 
"111500.jpg": 90, 25 | "112100.jpg": 90, 26 | "112101.jpg": 90, 27 | "112200.jpg": 90, 28 | "112201.jpg": 90, 29 | "114701.jpg": 90, 30 | "114901.jpg": 90, 31 | "114902.jpg": 90, 32 | "115001.jpg": 90, 33 | "116300.jpg": 90, 34 | "116501.jpg": 90, 35 | "117700.jpg": 90, 36 | "118300.jpg": 90, 37 | "118302.jpg": 90, 38 | "118803.jpg": 90, 39 | "119601.jpg": 90, 40 | "119602.jpg": 90, 41 | "119604.jpg": 90, 42 | "119703.jpg": 90, 43 | "120000.jpg": 90, 44 | "120001.jpg": 90, 45 | "120100.jpg": 90, 46 | "120500.jpg": 90, 47 | "120501.jpg": 90, 48 | "120601.jpg": 90, 49 | "121000.jpg": 90, 50 | "121001.jpg": 90, 51 | "121200.jpg": 90, 52 | "121201.jpg": 90, 53 | "121300.jpg": 90, 54 | "121301.jpg": 90, 55 | "121400.jpg": 90, 56 | "121401.jpg": 90, 57 | "121402.jpg": 90, 58 | "121403.jpg": 90, 59 | "121500.jpg": 90, 60 | "121501.jpg": 90, 61 | "121700.jpg": 90, 62 | "121701.jpg": 90, 63 | "121800.jpg": 90, 64 | "121901.jpg": 90, 65 | "122600.jpg": 90, 66 | "122601.jpg": 90, 67 | "122602.jpg": 90, 68 | "122702.jpg": 90, 69 | "122707.jpg": 270, 70 | "122708.jpg": 90, 71 | "122800.jpg": 90, 72 | "122801.jpg": 90, 73 | "122900.jpg": 90, 74 | "122901.jpg": 90, 75 | "124701.jpg": 270, 76 | "124802.jpg": 90, 77 | "125400.jpg": 90, 78 | "126000.jpg": 90, 79 | "126701.jpg": 90, 80 | "126801.jpg": 90, 81 | "126803.jpg": 90, 82 | "126807.jpg": 270, 83 | "127700.jpg": 90, 84 | "127701.jpg": 90, 85 | "127800.jpg": 90, 86 | "127802.jpg": 90, 87 | "127900.jpg": 270, 88 | "128200.jpg": 90, 89 | "128500.jpg": 90, 90 | "128502.jpg": 90, 91 | "129700.jpg": 90, 92 | "132000.jpg": 90, 93 | "132001.jpg": 270, 94 | "132100.jpg": 90, 95 | "132101.jpg": 90, 96 | "132102.jpg": 90, 97 | "132200.jpg": 270, 98 | "132201.jpg": 270, 99 | "132301.jpg": 90, 100 | "132401.jpg": 90, 101 | "132700.jpg": 90, 102 | "132701.jpg": 90, 103 | "132800.jpg": 90, 104 | "132801.jpg": 90, 105 | "133000.jpg": 90, 106 | "133001.jpg": 90, 107 | "133002.jpg": 90, 108 | "133003.jpg": 90, 109 | "133100.jpg": 90, 110 | "133101.jpg": 90, 111 | "133200.jpg": 90, 112 | "133201.jpg": 90, 113 | "133202.jpg": 90, 114 | "133802.jpg": 90, 115 | "134001.jpg": 90, 116 | "134002.jpg": 90, 117 | "134003.jpg": 90, 118 | "134503.jpg": 90, 119 | "134504.jpg": 90, 120 | "135401.jpg": 90, 121 | "135503.jpg": 270, 122 | "135600.jpg": 90, 123 | "135601.jpg": 90, 124 | "135700.jpg": 90, 125 | "136300.jpg": 90, 126 | "136301.jpg": 90, 127 | "136503.jpg": 90, 128 | "136900.jpg": 90, 129 | "136901.jpg": 90, 130 | "137004.jpg": 90, 131 | "137005.jpg": 90, 132 | "137006.jpg": 90, 133 | "137100.jpg": 90, 134 | "137101.jpg": 90, 135 | "137102.jpg": 90, 136 | "137103.jpg": 90, 137 | "137302.jpg": 90, 138 | "138000.jpg": 90, 139 | "138001.jpg": 90, 140 | "138504.jpg": 270, 141 | "138507.jpg": 90, 142 | "138703.jpg": 90, 143 | "138704.jpg": 90, 144 | "138705.jpg": 90, 145 | "138902.jpg": 90, 146 | "138903.jpg": 90, 147 | "138904.jpg": 90, 148 | "139001.jpg": 270, 149 | "139002.jpg": 270, 150 | "139003.jpg": 90, 151 | "139004.jpg": 90, 152 | "139005.jpg": 90, 153 | "139102.jpg": 270, 154 | "139103.jpg": 270, 155 | "139105.jpg": 90, 156 | "139300.jpg": 270, 157 | "139301.jpg": 270, 158 | "139302.jpg": 180, 159 | "139303.jpg": 270, 160 | "139400.jpg": 270, 161 | "139404.jpg": 90, 162 | "139502.jpg": 270, 163 | "139600.jpg": 90, 164 | "139601.jpg": 90, 165 | "139602.jpg": 90, 166 | "139603.jpg": 90, 167 | "139701.jpg": 90, 168 | "140500.jpg": 90, 169 | "140503.jpg": 90, 170 | "140800.jpg": 90, 171 | "141500.jpg": 90, 172 | "141501.jpg": 90, 173 | "141601.jpg": 270, 174 | 
"141702.jpg": 270, 175 | "141704.jpg": 90, 176 | "141801.jpg": 270, 177 | "142501.jpg": 90, 178 | "142600.jpg": 90, 179 | "144300.jpg": 90, 180 | "144900.jpg": 90, 181 | "145202.jpg": 90, 182 | "145301.jpg": 90, 183 | "145700.jpg": 90, 184 | "145800.jpg": 90, 185 | "145801.jpg": 90, 186 | "145901.jpg": 90, 187 | "146001.jpg": 90, 188 | "146002.jpg": 90, 189 | "146003.jpg": 90, 190 | "146102.jpg": 90, 191 | "146202.jpg": 90, 192 | "146207.jpg": 90, 193 | "146301.jpg": 90, 194 | "146303.jpg": 90, 195 | "146400.jpg": 90, 196 | "146402.jpg": 90, 197 | "146403.jpg": 90, 198 | "146404.jpg": 90, 199 | "146600.jpg": 90, 200 | "146700.jpg": 90, 201 | "146800.jpg": 90, 202 | "146801.jpg": 90, 203 | "146901.jpg": 90, 204 | "147201.jpg": 90, 205 | "147302.jpg": 90, 206 | "147304.jpg": 90, 207 | "147600.jpg": 90, 208 | "147601.jpg": 90, 209 | "147602.jpg": 90, 210 | "147603.jpg": 90, 211 | "147900.jpg": 90, 212 | "147901.jpg": 90, 213 | "148003.jpg": 90, 214 | "148004.jpg": 90, 215 | "148401.jpg": 90, 216 | "148501.jpg": 90, 217 | "148800.jpg": 90, 218 | "148900.jpg": 90, 219 | "148901.jpg": 90, 220 | "149000.jpg": 90, 221 | "149001.jpg": 90, 222 | "149201.jpg": 90, 223 | "149301.jpg": 90, 224 | "149400.jpg": 90, 225 | "127200.jpg": 90, 226 | "127202.jpg": 90, 227 | "127201.jpg": 90, 228 | "129401.jpg": 270, 229 | "148902.jpg": 90, 230 | } -------------------------------------------------------------------------------- /multigrain/datasets/id_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch.utils.data as data 8 | 9 | 10 | class IdDataset(data.Dataset): 11 | """ 12 | Return image id with getitem in dataset 13 | """ 14 | 15 | def __init__(self, dataset): 16 | self.dataset = dataset 17 | 18 | def __getitem__(self, index): 19 | returns = self.dataset[index] 20 | return_dict = {} 21 | if not isinstance(returns, dict): 22 | return_dict['input'], return_dict['classifier_target'] = returns 23 | return_dict['instance_target'] = index 24 | return return_dict 25 | 26 | def __len__(self): 27 | return len(self.dataset) 28 | 29 | def __repr__(self): 30 | return "IdDataset(" + repr(self.dataset) + ")" -------------------------------------------------------------------------------- /multigrain/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Imagenet is either distributed along with a devkit to get the validation labels, 9 | or with the validation set reorganized into different subsets. 10 | Here we support both. 11 | Keeps an index of the images for fast initialization. 
12 | """ 13 | 14 | import torch 15 | from torch.utils import data 16 | import os 17 | from os import path as osp 18 | import numpy as np 19 | from .loader import loader as default_loader 20 | from multigrain.utils import ifmakedirs 21 | 22 | 23 | class IN1K(data.Dataset): 24 | """ 25 | ImageNet 1K dataset 26 | Classes numbered from 0 to 999 inclusive 27 | Can deal both with ImageNet original structure and the common "reorganized" validation dataset 28 | Caches list of files for faster reloading. 29 | """ 30 | NUM_CLASSES = 1000 31 | MEAN = [0.485, 0.456, 0.406] 32 | STD = [0.229, 0.224, 0.225] 33 | EIG_VALS = [0.2175, 0.0188, 0.0045] 34 | EIG_VECS = np.array([ 35 | [-0.5675, 0.7192, 0.4009], 36 | [-0.5808, -0.0045, -0.8140], 37 | [-0.5836, -0.6948, 0.4203] 38 | ]) 39 | 40 | def __init__(self, root, split='train', transform=None, force_reindex=False, loader=default_loader): 41 | self.root = root 42 | self.transform = transform 43 | self.split = split 44 | cachefile = 'data/IN1K-' + split + '-cached-list.pth' 45 | self.classes, self.class_to_idx, self.imgs, self.labels, self.images_subdir = self.get_dataset(cachefile, force_reindex) 46 | self.loader = loader 47 | 48 | def get_dataset(self, cachefile=None, force_reindex=False): 49 | if osp.isfile(cachefile) and not force_reindex: 50 | print('Loaded IN1K {} dataset from cache: {}...'.format(self.split, cachefile)) 51 | return torch.load(cachefile) 52 | 53 | print('Indexing IN1K {} dataset...'.format(self.split), end=' ') 54 | for images_subdir in [self.split, 'ILSVRC2012_img_' + self.split]: 55 | if osp.isdir(osp.join(self.root, images_subdir)): 56 | break 57 | else: 58 | raise ValueError('Split {} not found'.format(self.split)) 59 | self.images_subdir = images_subdir 60 | subfiles = os.listdir(osp.join(self.root, images_subdir)) 61 | if osp.isdir(osp.join(self.root, images_subdir, subfiles[0])): # ImageFolder 62 | classes = [folder for folder in subfiles if folder.startswith('n')] 63 | classes.sort() 64 | class_to_idx = {c: i for (i, c) in enumerate(classes)} 65 | imgs = [] 66 | labels = [] 67 | for label in classes: 68 | label_images = os.listdir(osp.join(self.root, images_subdir, label)) 69 | label_images.sort() 70 | imgs.extend([osp.join(label, i) for i in label_images]) 71 | labels.extend([class_to_idx[label] for _ in label_images]) 72 | else: # DevKit 73 | try: 74 | import mat4py 75 | except ImportError: 76 | print('Install mat4py to discover classes from meta.mat') 77 | raise 78 | synsets = mat4py.loadmat(osp.join(self.root, 79 | 'ILSVRC2012_devkit_t12', 80 | 'data', 81 | 'meta.mat'))['synsets'] 82 | 83 | ilsvrc_label_to_wnid = {label: wn 84 | for (wn, label) in zip(synsets['WNID'], 85 | synsets['ILSVRC2012_ID']) 86 | if label <= self.NUM_CLASSES} 87 | classes = sorted(ilsvrc_label_to_wnid.values()) 88 | class_to_idx = {c: i for (i, c) in enumerate(classes) if i < self.NUM_CLASSES} 89 | imgs = sorted(subfiles) 90 | ilsvrc_labels = np.loadtxt(osp.join(self.root, 91 | 'ILSVRC2012_devkit_t12', 92 | 'data', 93 | 'ILSVRC2012_validation_ground_truth.txt' 94 | ), dtype=int) 95 | labels = [class_to_idx[ilsvrc_label_to_wnid[l]] for l in ilsvrc_labels] 96 | 97 | sort_by_label = sorted(zip(labels, imgs)) 98 | labels, imgs = list(zip(*sort_by_label)) 99 | print('OK!', end='') 100 | returns = (classes, class_to_idx, imgs, labels, images_subdir) 101 | if cachefile is not None: 102 | ifmakedirs(osp.dirname(cachefile)) 103 | torch.save(returns, cachefile) 104 | print(' cached to', cachefile) 105 | print() 106 | return returns 107 | 108 | def 
__getitem__(self, idx): 109 | image = self.loader(osp.join(self.root, self.images_subdir, self.imgs[idx])) 110 | if self.transform is not None: 111 | image = self.transform(image) 112 | return (image, self.labels[idx]) 113 | 114 | def __len__(self): 115 | return len(self.imgs) 116 | 117 | def __repr__(self): 118 | return "IN1K(root='{}', split='{}')".format(self.root, self.split) 119 | -------------------------------------------------------------------------------- /multigrain/datasets/list_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch.utils import data 8 | import os.path as osp 9 | from .loader import loader as default_loader 10 | 11 | 12 | class ListDataset(data.Dataset): 13 | """ 14 | Unlabelled images dataset from list of images and root 15 | """ 16 | 17 | def __init__(self, root, imagelist, transform=None, loader=default_loader): 18 | self.root = root 19 | self.transform = transform 20 | self.imgs = imagelist 21 | if isinstance(imagelist, str): 22 | self.imgs = [] 23 | if not osp.isfile(imagelist): 24 | raise FileNotFoundError('Image list not found at {}'.format(imagelist)) 25 | with open(imagelist) as f: 26 | for im in f: 27 | im = im.strip() 28 | if not im: continue 29 | self.imgs.append(im) 30 | self.loader = loader 31 | 32 | def __getitem__(self, idx): 33 | image = self.loader(osp.join(self.root, self.imgs[idx])) 34 | if self.transform is not None: 35 | image = self.transform(image) 36 | return image 37 | 38 | def __len__(self): 39 | return len(self.imgs) 40 | -------------------------------------------------------------------------------- /multigrain/datasets/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from shutil import copyfile 8 | import os.path as osp 9 | from PIL import Image 10 | from PIL import ImageFile 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | from multigrain.utils import ifmakedirs 13 | 14 | 15 | def loader(path): 16 | return Image.open(path).convert('RGB') 17 | 18 | 19 | def preloader(dataset_root, preload_dir): 20 | def this_loader(path): 21 | dest_path = osp.join(preload_dir, osp.relpath(path, dataset_root)) 22 | ifmakedirs(osp.dirname(dest_path)) 23 | 24 | if not osp.isfile(dest_path): 25 | copyfile(path, dest_path) 26 | 27 | image = loader(dest_path) 28 | return image 29 | return this_loader -------------------------------------------------------------------------------- /multigrain/datasets/retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import torch.utils.data as data 8 | from glob import glob 9 | import os 10 | import os.path as osp 11 | from .loader import loader 12 | from collections import OrderedDict as OD 13 | from torchvision.datasets.utils import download_url 14 | from multigrain.utils.misc import ifmakedirs 15 | import zipfile 16 | import tarfile 17 | import yaml 18 | 19 | 20 | class DownloadableDataset(data.Dataset): 21 | URLS = [] 22 | NUM_FILES = [] 23 | 24 | def __init__(self, root): 25 | self.root = root 26 | 27 | def _check_exists(self): 28 | for pattern, number in self.NUM_FILES: 29 | if len(glob(osp.join(self.root, pattern))) != number: 30 | return False 31 | return True 32 | 33 | def download(self, remove_finished=True): 34 | if self._check_exists(): 35 | return 36 | 37 | ifmakedirs(self.root) 38 | 39 | for url in self.URLS: 40 | subdest = '' 41 | if isinstance(url, tuple): 42 | subdest, url = url 43 | # download file 44 | filename = url.rpartition('/')[2] 45 | file_path = osp.join(self.root, filename) 46 | download_url(url, root=self.root, filename=filename, md5=None) 47 | dest = osp.join(self.root, subdest) 48 | ifmakedirs(dest) 49 | if filename.endswith('.zip'): 50 | self.extract_zip(zip_path=file_path, dest=dest, remove_finished=remove_finished) 51 | elif filename.endswith('.tar') or filename.endswith('.tar.gz'): 52 | self.extract_tar(file_path, dest=dest, remove_finished=remove_finished) 53 | else: 54 | raise ValueError('File {}: has unknown extension'.format(filename)) 55 | print('Done!') 56 | 57 | @staticmethod 58 | def extract_zip(zip_path, dest, remove_finished=True): 59 | zip_ref = zipfile.ZipFile(zip_path, 'r') 60 | zip_ref.extractall(dest) 61 | zip_ref.close() 62 | if remove_finished: 63 | os.unlink(zip_path)  # remove the archive file, not the ZipFile object 64 | 65 | @staticmethod 66 | def extract_tar(fname, dest, remove_finished=True): 67 | if (fname.endswith("tar.gz")): 68 | tar = tarfile.open(fname, "r:gz") 69 | tar.extractall(path=dest) 70 | tar.close() 71 | elif (fname.endswith("tar")): 72 | tar = tarfile.open(fname, "r:") 73 | tar.extractall(path=dest) 74 | tar.close() 75 | if remove_finished: 76 | os.unlink(fname) 77 | 78 | 79 | class UKBench(DownloadableDataset): 80 | """UKBench dataset.""" 81 | 82 | URLS = ['https://archive.org/download/ukbench/ukbench.zip'] 83 | NUM_FILES = [('*.jpg', 10200)] 84 | 85 | def __init__(self, root, transform=None, download=False): 86 | self.root = root 87 | if download: 88 | self.download() 89 | images = sorted(glob(osp.join(self.root, '*.jpg')))  # glob order is arbitrary; grouping by 4 requires sorted names 90 | grouped = [] 91 | for i in range(0, len(images), 4): 92 | grouped.append([images[i], images[i + 1], images[i + 2], images[i + 3]]) 93 | self.imgs = [] 94 | self.class_groups = {} 95 | for c, G in enumerate(grouped): 96 | for im in G: 97 | self.imgs.append((im, c)) 98 | self.class_groups.setdefault(c, []).append(len(self.imgs) - 1) 99 | self.loader = loader 100 | self.transform = transform 101 | 102 | def __len__(self): 103 | return len(self.imgs) 104 | 105 | def __getitem__(self, index): 106 | path, target = self.imgs[index] 107 | sample = self.loader(path) 108 | if self.transform is not None: 109 | sample = self.transform(sample) 110 | 111 | return sample, target 112 | 113 | 114 | class Holidays(DownloadableDataset): 115 | """Holidays dataset.""" 116 | URLS = ['ftp://ftp.inrialpes.fr/pub/lear/douze/data/jpg1.tar.gz', 117 | 'ftp://ftp.inrialpes.fr/pub/lear/douze/data/jpg2.tar.gz'] 118 | NUM_FILES = [(osp.join('jpg', '*.jpg'), 1491)] 119 | 120 | def __init__(self, root, transform=None, rotated=True, download=False): 121 | self.root = root 122 | if
download: 123 | self.download() 124 | images = sorted(glob(osp.join(self.root, 'jpg', '*.jpg'))) 125 | cur_group = [images[0]] 126 | grouped = [] 127 | for i in images[1:]: 128 | if int(osp.basename(i[:-len('.jpg')])) % 100: 129 | cur_group.append(i) 130 | else: 131 | grouped.append(cur_group) 132 | cur_group = [i] 133 | grouped.append(cur_group)  # keep the final group as well 134 | self.imgs = [] 135 | self.class_groups = {} 136 | for c, G in enumerate(grouped): 137 | for im in G: 138 | self.imgs.append((im, c)) 139 | self.class_groups.setdefault(c, []).append(len(self.imgs) - 1) 140 | self.loader = loader 141 | self.transform = transform 142 | self.rotated = None 143 | if rotated: 144 | self.rotated = yaml.safe_load(open(osp.join(osp.dirname(__file__), 'holidays-rotate.yaml'))) 145 | 146 | def __len__(self): 147 | return len(self.imgs) 148 | 149 | def __getitem__(self, index): 150 | path, target = self.imgs[index] 151 | sample = self.loader(path) 152 | if self.rotated is not None: 153 | rotation = self.rotated.get(osp.basename(path), 0) 154 | if rotation != 0: 155 | sample = sample.rotate(-rotation, expand=True) 156 | if self.transform is not None: 157 | sample = self.transform(sample) 158 | 159 | return sample, target 160 | 161 | class CopyDays(DownloadableDataset): 162 | """CopyDays dataset.""" 163 | URLS = ['http://pascal.inrialpes.fr/data/holidays/copydays_original.tar.gz', 164 | 'http://pascal.inrialpes.fr/data/holidays/copydays_crop.tar.gz', 165 | 'http://pascal.inrialpes.fr/data/holidays/copydays_jpeg.tar.gz', 166 | ('strong', 'http://pascal.inrialpes.fr/data/holidays/copydays_strong.tar.gz')] 167 | NUM_FILES = [('*.jpg', 157), ('*/*.jpg', 229), ('*/*/*.jpg', 2826)] 168 | 169 | def __init__(self, root, distractors=None, num_distractors=0, subset=None, transform=None): 170 | self.root = root 171 | self.download() 172 | 173 | self.distractors = distractors  # optional root directory of distractor images 174 | self.num_distractors = num_distractors 175 | self.subset = subset 176 | 177 | avail = OD() 178 | for x in os.walk(root): 179 | for filename in x[2]: 180 | id = int(filename.split('.')[0]) 181 | transf = osp.relpath(x[0], root) 182 | id = (id // 100) * 100 183 | if transf == '.': 184 | transf = '' 185 | avail.setdefault(id, OD())[transf] = osp.join(x[0], filename) 186 | 187 | if not avail: 188 | raise ValueError("Dataset not found in {}".format(root)) 189 | 190 | transfs = [] 191 | self.images = [] 192 | self.class_groups = [] 193 | for i, id in enumerate(avail): 194 | im = {'variant': -1, 195 | 'input': avail[id][''], 196 | 'target': i} 197 | self.images.append(im) 198 | cur_group = [len(self.images) - 1] 199 | for transf in avail[id]: 200 | if not transf: 201 | continue 202 | if self.subset is not None and transf not in self.subset: 203 | continue 204 | if transf not in transfs: 205 | transfs.append(transf) 206 | im = {'variant': transfs.index(transf), 207 | 'input': avail[id][transf], 208 | 'target': i} 209 | self.images.append(im) 210 | cur_group.append(len(self.images) - 1) 211 | self.class_groups.append(cur_group) 212 | 213 | self.gen_distractors() 214 | # if num_distractors is not None and len(self.distractors) > num_distractors: 215 | # self.distractors = self.distractors[:num_distractors] 216 | self.transfs = transfs 217 | self.loader = loader 218 | self.transform = transform 219 | 220 | def gen_distractors(self): 221 | self.distractor_list = []  # glob(distractors) 222 | if self.distractors is None: 223 | return 224 | for dirpath, subdirs, files in os.walk(self.distractors): 225 | for x in files: 226 | if x.endswith('.jpg'): 227 | self.distractor_list.append(osp.join(dirpath, x)) 228 | if len(self.distractor_list) >= self.num_distractors: 229 | return 230 | 231 | def __len__(self): 232 | return len(self.images) + len(self.distractor_list) 233 | 234 | def __getitem__(self, index): 235 | if index < len(self.images): 236 | return_dict = self.images[index].copy() 237 | else: 238 | index -= len(self.images) 239 | return_dict = {'variant': -1,  # same key as regular images 240 | 'input': self.distractor_list[index], 241 | 'target': -1} 242 | im = self.loader(return_dict['input']) 243 | if self.transform is not None: 244 | im = self.transform(im) 245 | return_dict['input'] = im 246 | return return_dict 247 | 248 | -------------------------------------------------------------------------------- /multigrain/lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .multigrain import get_multigrain, MultiGrain 2 | from .samplers import RASampler, list_collate-------------------------------------------------------------------------------- /multigrain/lib/multigrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | from torch import nn 9 | from multigrain.backbones import BackBone 10 | import torch.utils.model_zoo as model_zoo 11 | from multigrain.modules.layers import Layer, Select 12 | from multigrain.modules import DistanceWeightedSampling 13 | from collections import OrderedDict as OD 14 | 15 | 16 | __all__ = ['multigrain'] 17 | 18 | 19 | model_urls = { 20 | ('multigrain_resnet50'): '', 21 | ('multigrain_pnasnet5large'): 'https://dl.fbaipublicfiles.com/multigrain/finetuned_models/pnasnet5large-finetune500.pth', 22 | ('multigrain_senet154'): 'https://dl.fbaipublicfiles.com/multigrain/finetuned_models/senet154-finetune400.pth', 23 | } 24 | 25 | 26 | class MultiGrain(BackBone): 27 | """ 28 | Implement MultiGrain by changing the pooling layer of the backbone into GeM pooling with exponent p, 29 | and adding DistanceWeightedSampling for the margin loss.
30 | """ 31 | def __init__(self, backbone, p=3.0, include_sampling=True, learn_p=False, **kwargs): 32 | super().__init__(backbone, **kwargs) 33 | if not torch.is_tensor(p): 34 | p = torch.tensor(p) 35 | if learn_p: 36 | p.requires_grad = True 37 | self.pool = Layer('gem', p=p) 38 | self.normalize = Layer('l2n') 39 | if include_sampling: 40 | self.weighted_sampling = DistanceWeightedSampling() 41 | self.whitening = None 42 | 43 | def load_state_dict(self, D, *args, **kwargs): 44 | # adjust whitening and bias during load_state_dict 45 | for (k, v) in D.items(): 46 | parts = k.split('.') 47 | if parts[-1] in ('pca_P', 'pca_m'): 48 | if self.whitening is None: 49 | self.init_whitening(loading_checkpoint=True) 50 | getattr(self.whitening.pca, parts[-1]).resize_(v.size()) 51 | super().load_state_dict(D, *args, **kwargs) 52 | 53 | def init_whitening(self, loading_checkpoint=False): 54 | """ 55 | Initialize whitening operation (see scripts/whiten.py) 56 | """ 57 | self.whitening = nn.Sequential(OD(normalize=Layer('l2n'), 58 | pca=Layer('apply_pca', pca_P=torch.tensor([]), 59 | pca_m=torch.tensor([])))) 60 | # integrate bias in classifier to make it invariant to the input normalization 61 | self.pool.kwargs['add_bias'] = True 62 | if self.pre_classifier is not None: 63 | self.pre_classifier = Select(self.pre_classifier, -1) 64 | W, b = self.classifier.weight, self.classifier.bias 65 | W.data = torch.cat((W.data, b.data.view(-1, 1)), 1) 66 | if not loading_checkpoint: 67 | self.classifier.bias = None 68 | 69 | def integrate_whitening(self, m, P): 70 | """ 71 | Set whitening parameters and add their reverse in classifier (see scripts/whiten.py) 72 | """ 73 | Pinv = P.t().double().inverse() 74 | self.whitening.pca.pca_m.data.resize_(m.size()).copy_(m) 75 | self.whitening.pca.pca_P.data.resize_(P.size()).copy_(P) 76 | W = self.classifier.weight 77 | self.classifier.bias = nn.Parameter(m.to(W.device).matmul(W.data.t())) 78 | W.data = torch.matmul(W.data.double(), Pinv.to(W.device)).float() 79 | 80 | def forward(self, input, instance_target=None, **kwargs): 81 | if isinstance(instance_target, list): 82 | instance_target = torch.stack(instance_target) 83 | output_dict = {'instance_target': instance_target} 84 | output_dict['embedding'], output_dict['classifier_output'] = super().forward(input, **kwargs) 85 | output_dict['normalized_embedding'] = self.normalize(output_dict['embedding']) 86 | 87 | if hasattr(self, 'weighted_sampling') and instance_target is not None: 88 | sampled = self.weighted_sampling(output_dict['normalized_embedding'], instance_target) 89 | output_dict.update(sampled) 90 | 91 | return output_dict['classifier_output'] 92 | 93 | 94 | def get_multigrain(backbone='resnet50', pretrained=None, pretrained_backbone=None, **kwargs): 95 | kwargs['pretrained'] = pretrained_backbone 96 | model = MultiGrain(backbone, **kwargs) 97 | if pretrained: 98 | model.load_state_dict(model_zoo.load_url(model_urls['multigrain_' + backbone])) 99 | return model 100 | -------------------------------------------------------------------------------- /multigrain/lib/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from torch.utils.data.sampler import BatchSampler 8 | import torch 9 | import numpy as np 10 | from torch.utils.data.dataloader import default_collate 11 | from collections.abc import Mapping, Sequence 12 | 13 | 14 | class RASampler(torch.utils.data.Sampler): 15 | """ 16 | Batch Sampler with Repeated Augmentations (RA) 17 | - dataset_len: original length of the dataset 18 | - batch_size 19 | - repetitions: instances per image 20 | - len_factor: multiplicative factor for epoch size 21 | """ 22 | 23 | def __init__(self, dataset_len, batch_size, repetitions=1, len_factor=1.0, shuffle=False, drop_last=False): 24 | self.dataset_len = dataset_len 25 | self.batch_size = batch_size 26 | self.repetitions = repetitions 27 | self.len_images = int(dataset_len * len_factor) 28 | self.shuffle = shuffle 29 | self.drop_last = drop_last 30 | 31 | def shuffler(self): 32 | if self.shuffle: 33 | new_perm = lambda: iter(np.random.permutation(self.dataset_len)) 34 | else: 35 | new_perm = lambda: iter(np.arange(self.dataset_len)) 36 | shuffle = new_perm() 37 | while True: 38 | try: 39 | index = next(shuffle) 40 | except StopIteration: 41 | shuffle = new_perm() 42 | index = next(shuffle) 43 | for repetition in range(self.repetitions): 44 | yield index 45 | 46 | def __iter__(self): 47 | shuffle = iter(self.shuffler()) 48 | seen = 0 49 | batch = [] 50 | for _ in range(self.len_images): 51 | index = next(shuffle) 52 | batch.append(index) 53 | if len(batch) == self.batch_size: 54 | yield batch 55 | batch = [] 56 | if batch and not self.drop_last: 57 | yield batch 58 | 59 | def __len__(self): 60 | if self.drop_last: 61 | return self.len_images // self.batch_size 62 | else: 63 | return (self.len_images + self.batch_size - 1) // self.batch_size 64 | 65 | 66 | def list_collate(batch): 67 | """ 68 | Collate into a list instead of a tensor to deal with variable-sized inputs 69 | """ 70 | elem_type = type(batch[0]) 71 | if isinstance(batch[0], torch.Tensor): 72 | return batch 73 | elif elem_type.__module__ == 'numpy': 74 | if elem_type.__name__ == 'ndarray': 75 | return list_collate([torch.from_numpy(b) for b in batch]) 76 | elif isinstance(batch[0], Mapping): 77 | return {key: list_collate([d[key] for d in batch]) for key in batch[0]} 78 | elif isinstance(batch[0], Sequence): 79 | transposed = zip(*batch) 80 | return [list_collate(samples) for samples in transposed] 81 | return default_collate(batch) 82 | -------------------------------------------------------------------------------- /multigrain/lib/whiten.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
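A usage sketch for the `RASampler` above, with a synthetic dataset. Note that it is a *batch* sampler, so it goes into `DataLoader(batch_sampler=...)`; with `repetitions=3` each drawn index is emitted three times in a row, so every batch contains several augmented copies of the same image, and `list_collate` keeps variable-sized inputs as lists instead of stacking them:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from multigrain.lib import RASampler, list_collate

train_set = TensorDataset(torch.randn(256, 3, 8, 8), torch.randint(0, 9, (256,)))
batch_sampler = RASampler(len(train_set), batch_size=32, repetitions=3,
                          len_factor=2.0, shuffle=True, drop_last=True)
loader = DataLoader(train_set, batch_sampler=batch_sampler, collate_fn=list_collate)
images, labels = next(iter(loader))  # 32 items as lists; indices repeat in runs of 3
```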
6 | # 7 | from sklearn.decomposition import PCA 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def get_whiten(X): 13 | pca = PCA(whiten=True, n_components=X.size(1)) 14 | pca.fit(X.detach().cpu().numpy()) 15 | m = torch.tensor(pca.mean_, dtype=torch.float) 16 | P = torch.tensor(pca.components_.T / np.sqrt(pca.explained_variance_), dtype=torch.float) 17 | return m, P 18 | -------------------------------------------------------------------------------- /multigrain/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .margin import DistanceWeightedSampling, MarginLoss, SampledMarginLoss 2 | from .criterion import MultiCriterion 3 | from .multioptim import MultiOptim 4 | -------------------------------------------------------------------------------- /multigrain/modules/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch import nn 8 | from collections import OrderedDict as OD 9 | 10 | 11 | class MultiCriterion(nn.Module): 12 | """ 13 | Holds a dict of multiple losses with a weighting factor for each loss. 14 | - losses_dict: should be a dict with name as key and (loss, input_keys, weight) as values. 15 | - skip_zero: skip the computation of losses with 0 weight 16 | """ 17 | def __init__(self, losses_dict, skip_zeros=False): 18 | super().__init__() 19 | self.losses = OD() 20 | self.input_keys = OD() 21 | self.weights = OD() 22 | for name, (loss, input_keys, weight) in losses_dict.items(): 23 | self.losses[name] = loss 24 | self.input_keys[name] = input_keys 25 | self.weights[name] = weight 26 | self.losses = nn.ModuleDict(self.losses) 27 | self.skip_zeros = skip_zeros 28 | 29 | def forward(self, input_dict): 30 | return_dict = {} 31 | loss = 0.0 32 | for name, module in self.losses.items(): 33 | for k in self.input_keys[name]: 34 | if k not in input_dict: 35 | raise ValueError('Element {} not found in input.'.format(k)) 36 | if self.weights[name] == 0.0 and self.skip_zeros: 37 | continue 38 | this_loss = module(*[input_dict[k] for k in self.input_keys[name]]) 39 | return_dict[name] = this_loss 40 | loss = loss + self.weights[name] * this_loss 41 | return_dict['loss'] = loss 42 | return return_dict 43 | 44 | -------------------------------------------------------------------------------- /multigrain/modules/functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import math 9 | from torch.nn import functional as F 10 | 11 | 12 | def add_bias_channel(x, dim=1): 13 | one_size = list(x.size()) 14 | one_size[dim] = 1 15 | one = x.new_ones(one_size) 16 | return torch.cat((x, one), dim) 17 | 18 | 19 | def flatten(x, keepdims=False): 20 | """ 21 | Flattens B C H W input to B C*H*W output, optionally retains trailing dimensions. 
22 | """ 23 | y = x.view(x.size(0), -1) 24 | if keepdims: 25 | for d in range(y.dim(), x.dim()): 26 | y = y.unsqueeze(-1) 27 | return y 28 | 29 | 30 | def gem(x, p=3, eps=1e-6, clamp=True, add_bias=False, keepdims=False): 31 | if p == math.inf or p is 'inf': 32 | x = F.max_pool2d(x, (x.size(-2), x.size(-1))) 33 | elif p == 1 and not (torch.is_tensor(p) and p.requires_grad): 34 | x = F.avg_pool2d(x, (x.size(-2), x.size(-1))) 35 | else: 36 | if clamp: 37 | x = x.clamp(min=eps) 38 | x = F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p) 39 | if add_bias: 40 | x = add_bias_channel(x) 41 | if not keepdims: 42 | x = flatten(x) 43 | return x 44 | 45 | 46 | def apply_pca(vt, pca_P=None, pca_m=None): 47 | do_rotation = torch.is_tensor(pca_P) and pca_P.numel() > 0 48 | do_shift = torch.is_tensor(pca_P) and pca_P.numel() > 0 49 | 50 | if do_rotation or do_shift: 51 | if do_shift: 52 | vt = vt - pca_m 53 | if do_rotation: 54 | vt = torch.matmul(vt, pca_P) 55 | return vt 56 | 57 | 58 | def l2n(x, eps=1e-6, dim=1): 59 | x = x / (torch.norm(x, p=2, dim=dim, keepdim=True) + eps).expand_as(x) 60 | return x -------------------------------------------------------------------------------- /multigrain/modules/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from . import functional as LF 11 | 12 | 13 | class Select(nn.Module): 14 | """ 15 | Apply module on subset of the input channels. 16 | Optionally drops other channels 17 | """ 18 | 19 | def __init__(self, module, end, start=0, drop_other=False): 20 | super().__init__() 21 | self.module = module 22 | self.drop_other = drop_other 23 | self.end = end 24 | self.start = start 25 | 26 | def forward(self, x): 27 | bits = [] 28 | bits.append(x[:, :self.start, ...]) 29 | bits.append(x[:, self.start:self.end, ...]) 30 | bits.append(x[:, self.end:, ...]) 31 | bits[1] = self.module(bits[1]) 32 | if self.drop_other: 33 | return bits[1] 34 | return torch.cat([b for b in bits if b.numel()], 1) 35 | 36 | 37 | class Layer(nn.Module): 38 | """ 39 | General module wrapper for a functional layer. 40 | """ 41 | def __init__(self, name, **kwargs): 42 | super().__init__() 43 | self.name = name 44 | for n, v in kwargs.items(): 45 | if torch.is_tensor(v): 46 | if v.requires_grad: 47 | setattr(self, n, nn.Parameter(v)) 48 | else: 49 | self.register_buffer(n, v) 50 | kwargs[n] = 'self.' 
+ n 51 | self.kwargs = kwargs 52 | 53 | def forward(self, input): 54 | kwargs = self.kwargs.copy() 55 | for (n, v) in kwargs.items(): 56 | if isinstance(v, str) and v.startswith('self.'): 57 | kwargs[n] = getattr(self, v[len('self.'):]) 58 | out = getattr(LF, self.name)(input, **kwargs) 59 | return out 60 | 61 | def __repr__(self): 62 | kwargs = [] 63 | for (left, right) in self.kwargs.items(): 64 | rt = repr(right) 65 | if isinstance(right, str) and right.startswith('self.'): 66 | vs = right[len('self.'):] 67 | v = getattr(self, vs) 68 | if vs in self._buffers and v.numel() <= 1: 69 | rt = v 70 | kwargs.append('{}={}'.format(left, rt)) 71 | kwargs = ', '.join(kwargs) 72 | if kwargs: 73 | kwargs = ', ' + kwargs 74 | return 'Layer(name=' + repr(self.name) + kwargs + ')' 75 | 76 | -------------------------------------------------------------------------------- /multigrain/modules/margin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch import nn 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class DistanceWeightedSampling(nn.Module): 13 | r"""Distance weighted sampling. 14 | See "sampling matters in deep embedding learning" paper for details. 15 | Implementation similar to https://github.com/chaoyuaw/sampling_matters 16 | """ 17 | def __init__(self, cutoff=0.5, nonzero_loss_cutoff=1.4): 18 | super().__init__() 19 | self.cutoff = cutoff 20 | self.nonzero_loss_cutoff = nonzero_loss_cutoff 21 | 22 | @staticmethod 23 | def get_distance(x): 24 | """ 25 | Helper function for margin-based loss. Return a distance matrix given a matrix. 26 | Returns 1 on the diagonal (prevents numerical errors) 27 | """ 28 | n = x.size(0) 29 | square = torch.sum(x ** 2.0, dim=1, keepdim=True) 30 | distance_square = square + square.t() - (2.0 * torch.matmul(x, x.t())) 31 | return torch.sqrt(distance_square + torch.eye(n, dtype=x.dtype, device=x.device)) 32 | 33 | def forward(self, embedding, target): 34 | """ 35 | Inputs: 36 | - embedding: embeddings of images in batch 37 | - target: id of instance targets 38 | 39 | Outputs: 40 | - a dict with 41 | * 'anchor_embeddings' 42 | * 'negative_embeddings' 43 | * 'positive_embeddings' 44 | with sampled embeddings corresponding to anchors, negatives, positives 45 | """ 46 | 47 | B, C = embedding.size()[:2] 48 | embedding = embedding.view(B, C) 49 | 50 | distance = self.get_distance(embedding) 51 | distance = torch.clamp(distance, min=self.cutoff) 52 | 53 | # Subtract max(log(distance)) for stability. 
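# For unit-normalized embeddings in C dimensions, pairwise distances have
# density q(d) proportional to d^(C-2) * (1 - d^2/4)^((C-3)/2), so the
# expression below is log(1/q(d)) up to a constant: weighting negatives by
# 1/q(d) flattens the distance distribution of sampled pairs instead of
# concentrating around sqrt(2) ("Sampling Matters in Deep Embedding
# Learning", Wu et al. 2017).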
54 | log_weights = ((2.0 - float(C)) * torch.log(distance) 55 | - (float(C - 3) / 2) * torch.log(1.0 - 0.25 * (distance ** 2.0))) 56 | weights = torch.exp(log_weights - log_weights.max()) 57 | 58 | unequal = target.view(-1, 1) 59 | unequal = (unequal != unequal.t()) 60 | 61 | weights = weights * (unequal & (distance < self.nonzero_loss_cutoff)).float() 62 | weights = weights / torch.sum(weights, dim=1, keepdim=True) 63 | 64 | a_indices = [] 65 | p_indices = [] 66 | n_indices = [] 67 | 68 | np_weights = weights.detach().cpu().numpy() 69 | unequal_np = unequal.cpu().numpy() 70 | 71 | for i in range(B): 72 | same = (1 - unequal_np[i]).nonzero()[0] 73 | 74 | if np.isnan(np_weights[i].sum()): # 0 samples within cutoff, sample uniformly 75 | np_weights_ = unequal_np[i].astype(float) 76 | np_weights_ /= np_weights_.sum() 77 | else: 78 | np_weights_ = np_weights[i] 79 | try: 80 | n_indices += np.random.choice(B, len(same) - 1, p=np_weights_, replace=False).tolist() 81 | except ValueError: # cannot always sample without replacement 82 | n_indices += np.random.choice(B, len(same) - 1, p=np_weights_).tolist() 83 | 84 | for j in same: 85 | if j != i: 86 | a_indices.append(i) 87 | p_indices.append(j) 88 | 89 | return {'anchor_embeddings': embedding[a_indices], 90 | 'negative_embeddings': embedding[n_indices], 91 | 'positive_embeddings': embedding[p_indices]} 92 | 93 | 94 | class MarginLoss(nn.Module): 95 | r"""Margin based loss. 96 | 97 | Parameters 98 | ---------- 99 | beta_init: float 100 | Initial beta 101 | margin : float 102 | Margin between positive and negative pairs. 103 | """ 104 | def __init__(self, beta_init=1.2, margin=0.2): 105 | super().__init__() 106 | self.beta = nn.Parameter(torch.tensor(beta_init)) 107 | self._margin = margin 108 | 109 | def forward(self, anchor_embeddings, negative_embeddings, positive_embeddings, eps=1e-8): 110 | """ 111 | 112 | Inputs: 113 | - input_dict: 'anchor_embeddings', 'negative_embeddings', 'positive_embeddings' 114 | 115 | Outputs: 116 | - Loss. 117 | """ 118 | 119 | d_ap = torch.sqrt(torch.sum((positive_embeddings - anchor_embeddings) ** 2, dim=1) + eps) 120 | d_an = torch.sqrt(torch.sum((negative_embeddings - anchor_embeddings) ** 2, dim=1) + eps) 121 | 122 | pos_loss = torch.clamp(d_ap - self.beta + self._margin, min=0.0) 123 | neg_loss = torch.clamp(self.beta - d_an + self._margin, min=0.0) 124 | 125 | pair_cnt = float(torch.sum((pos_loss > 0.0) + (neg_loss > 0.0)).item()) 126 | 127 | # Normalize based on the number of pairs 128 | loss = (torch.sum(pos_loss + neg_loss)) / max(pair_cnt, 1.0) 129 | 130 | return loss 131 | 132 | 133 | class SampledMarginLoss(nn.Module): 134 | """ 135 | Combines DistanceWeightedSampling + Margin Loss 136 | """ 137 | def __init__(self, sampling_args={}, margin_args={}): 138 | super().__init__() 139 | self.sampling = DistanceWeightedSampling(**sampling_args) 140 | self.margin = MarginLoss(**margin_args) 141 | 142 | def forward(self, embedding, target): 143 | sampled_dict = self.sampling(embedding, target) 144 | loss = self.margin(**sampled_dict) 145 | return loss 146 | 147 | -------------------------------------------------------------------------------- /multigrain/modules/multioptim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
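A self-contained sketch exercising the `DistanceWeightedSampling` module above on synthetic data. The one requirement is that instance targets repeat within the batch (exactly what the RASampler arranges), so that every anchor has positives:

```python
import torch
import torch.nn.functional as F
from multigrain.modules import DistanceWeightedSampling

sampler = DistanceWeightedSampling(cutoff=0.5, nonzero_loss_cutoff=1.4)
embeddings = F.normalize(torch.randn(16, 128), dim=1)  # unit-norm, as MultiGrain produces
targets = torch.arange(4).repeat_interleave(4)         # 4 instances x 4 repeated augmentations
triplets = sampler(embeddings, targets)
print({k: tuple(v.shape) for k, v in triplets.items()})
# e.g. {'anchor_embeddings': (48, 128), 'negative_embeddings': (48, 128), 'positive_embeddings': (48, 128)}
```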
6 | # 7 | from collections import OrderedDict as OD 8 | 9 | 10 | class MultiOptim(OD): 11 | """ 12 | Holds dict of optimizers 13 | """ 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.base_lr = None 17 | 18 | def state_dict(self): 19 | D = OD() 20 | for k, v in self.items(): 21 | for u, w in v.state_dict().items(): 22 | D[k + '.' + u] = w 23 | return D 24 | 25 | def load_state_dict(self, D): 26 | for opt in self: 27 | local = {} 28 | for k, v in D.items(): 29 | u, k2 = k.split('.', 1) 30 | if u == opt: 31 | local[k2] = v 32 | self[opt].load_state_dict(local) 33 | return self 34 | 35 | def zero_grad(self): 36 | for opt in self.values(): 37 | opt.zero_grad() 38 | 39 | def parameters(self): 40 | P = [] 41 | for opt in self.values(): 42 | for G in opt.param_groups: 43 | for p in G["params"]: 44 | P.append(p) 45 | return P 46 | 47 | def step(self): 48 | for name, O in self.items(): 49 | O.step() 50 | 51 | def set_base_lr(self): 52 | """ 53 | Remember base learning rates to easily apply learning rate drops. 54 | """ 55 | self.base_lr = {} 56 | for name, O in self.items(): 57 | for i, G in enumerate(O.param_groups): 58 | self.base_lr[(name, i)] = G["lr"] 59 | 60 | def lr_multiply(self, multiplier): 61 | """ 62 | Change lr multiplicatively relative to base_lr captured with self.set_base_lr(). 63 | """ 64 | for name, O in self.items(): 65 | for i, G in enumerate(O.param_groups): 66 | G["lr"] = self.base_lr[(name, i)] * multiplier 67 | -------------------------------------------------------------------------------- /multigrain/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import ifmakedirs 2 | from multigrain.utils.logging import print_file 3 | from .torch_utils import cuda 4 | from .metrics import accuracy, AverageMeter, HistoryMeter 5 | from .checkpoint import CheckpointHandler 6 | from .plots import make_plots 7 | from .logging import num_fmt 8 | from .tictoc import Tictoc 9 | from . import arguments -------------------------------------------------------------------------------- /multigrain/utils/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
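`MultiOptim` fans the standard optimizer calls out to a dict of named optimizers, which lets MultiGrain-style training use separate settings for different parameter groups. A minimal sketch with a single, hypothetically named 'backbone' entry:

```python
import torch
from multigrain.modules import MultiOptim

model = torch.nn.Linear(10, 2)
optims = MultiOptim(backbone=torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9))
optims.set_base_lr()     # remember the base learning rates
optims.lr_multiply(0.1)  # e.g. a 10x drop, applied relative to the remembered base
loss = model(torch.randn(4, 10)).sum()
optims.zero_grad()
loss.backward()
optims.step()
```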
6 | # 7 | import argparse 8 | 9 | 10 | def comma_separated(type, separator=','): 11 | def parser(inp): 12 | out = tuple(type(i) for i in inp.split(separator)) 13 | if out == ('',): 14 | out = () 15 | return out 16 | return parser 17 | 18 | 19 | def float_in_range(begin, end): 20 | def parser(inp): 21 | inp = float(inp) 22 | if not begin <= inp <= end: 23 | raise argparse.ArgumentTypeError('Argument should be between {} and {}'.format(begin, end)) 24 | return inp 25 | return parser 26 | 27 | 28 | def compare_dicts(dict1, dict2, verbose=True): 29 | removed = [] 30 | added = [] 31 | changed = [] 32 | for k in dict1: 33 | if k not in dict2: 34 | removed.append((k, dict1[k])) 35 | elif dict2[k] != dict1[k]: 36 | changed.append((k, dict1[k], dict2[k])) 37 | for k in dict2: 38 | if k not in dict1:  # keys present only in dict2 are additions 39 | added.append((k, dict2[k])) 40 | if verbose: 41 | if removed: 42 | print('removed keys:', ', '.join('{} ({})'.format(k, v) for (k, v) in removed)) 43 | if added: 44 | print('added keys:', ', '.join('{} ({})'.format(k, v) for (k, v) in added)) 45 | if changed: 46 | print('changed keys:', ', '.join('{} ({} -> {})'.format(k, v1, v2) for (k, v1, v2) in changed)) 47 | return removed, added, changed 48 | -------------------------------------------------------------------------------- /multigrain/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import division, print_function, absolute_import 8 | import torch 9 | from torch import nn 10 | from collections import OrderedDict as OD 11 | from glob import glob 12 | import os 13 | import os.path as osp 14 | from .logging import ordered_load, ordered_dump 15 | 16 | 17 | class CheckpointHandler(object): 18 | """ 19 | Save checkpoint and metric history in directory 20 | Remove old checkpoints, keep every save_every (none if 0) 21 | """ 22 | 23 | def __init__(self, expdir, save_every=0, verbose=True, prefix='checkpoint_', 24 | metrics_file='metrics'): 25 | self.expdir = expdir 26 | self.save_every = save_every 27 | self.verbose = verbose 28 | self.prefix = prefix 29 | self.metrics = metrics_file 30 | 31 | def available(self, dir=None): 32 | if dir is None: 33 | dir = self.expdir 34 | avail = OD() 35 | for checkpoint in glob(osp.join(dir, self.prefix + '*.pth')): 36 | epoch = int(osp.basename(checkpoint)[len(self.prefix):-len('.pth')]) 37 | avail[epoch] = checkpoint 38 | return avail 39 | 40 | def exists(self, resume, dir=None): 41 | if dir is None: 42 | dir = self.expdir 43 | if resume in (-1, 0): 44 | return True 45 | avail = self.available(dir) 46 | return (resume in avail) 47 | 48 | def delete_old_checkpoints(self, epoch): 49 | avail = self.available() 50 | for k in avail: 51 | if k != epoch and (self.save_every == 0 or (k % self.save_every) != 0): 52 | os.remove(avail[k]) 53 | 54 | def save_metrics(self, metrics_history): 55 | ordered_dump(metrics_history, osp.join(self.expdir, self.metrics + '.yaml')) 56 | 57 | def save(self, model, epoch, optimizer=None, metrics_history=None, extra=None): 58 | module = model.module if isinstance(model, nn.DataParallel) else model 59 | check = dict(model_state=module.state_dict()) 60 | if optimizer is not None: 61 | check['optimizer_state'] = optimizer.state_dict() 62 | if extra is not None: 63 | check['extra'] = extra 64 | torch.save(check, osp.join(self.expdir, self.prefix + '{:d}.pth'.format(epoch))) 65 | if metrics_history is not None: 66 | self.save_metrics(metrics_history) 67 | self.delete_old_checkpoints(epoch) 68 | if self.verbose: 69 | print('Saved checkpoint in', self.expdir) 70 | 71 | def load_state_dict(self, model, state_dict): 72 | module = model.module if isinstance(model, nn.DataParallel) else model 73 | module.load_state_dict(state_dict) 74 | 75 | def resume(self, model, optimizer=None, metrics_history={}, resume_epoch=-1, resume_from=None, return_extra=True): 76 | """ 77 | Restore model state dict and metrics. 78 | """ 79 | if not resume_from: 80 | resume_from = self.expdir 81 | 82 | if osp.isdir(resume_from): 83 | avail = self.available(resume_from) 84 | if resume_epoch == -1: 85 | avail_keys = sorted(avail.keys()) 86 | resume_epoch = avail_keys[-1] if avail_keys else 0 87 | 88 | if resume_epoch != 0: 89 | if resume_epoch not in avail: 90 | raise ValueError('Epoch {} not found in {}'.format(resume_epoch, resume_from)) 91 | resume_from = avail[resume_epoch] 92 | 93 | metrics_file = osp.join(self.expdir, self.metrics + '.yaml') 94 | 95 | if osp.isfile(metrics_file): 96 | metrics_history.clear() 97 | metrics_history.update(ordered_load(metrics_file)) 98 | else: 99 | print('Reinitializing metrics, metrics file: {}'.format(metrics_file)) 100 | 101 | if resume_epoch == 0: 102 | if self.verbose: 103 | print('Initialized model, optimizer') 104 | 105 | return (0, {}) if return_extra else 0  # parentheses matter: without them a tuple is returned unconditionally 106 | 107 | checkpoint = torch.load(resume_from, map_location=torch.device('cpu')) 108 | self.load_state_dict(model, checkpoint['model_state']) 109 | if self.verbose: 110 | print('Model state loaded from', resume_from) 111 | 112 | if optimizer is not None: 113 | if 'optimizer_state' in checkpoint: 114 | optimizer.load_state_dict(checkpoint['optimizer_state']) 115 | if self.verbose: 116 | print('Optimizer state loaded from', resume_from) 117 | elif self.verbose: 118 | print('No optimizer state found in', resume_from) 119 | 120 | if return_extra: 121 | extra = checkpoint.get('extra', {}) 122 | return resume_epoch, extra 123 | return resume_epoch 124 | -------------------------------------------------------------------------------- /multigrain/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import yaml 8 | from collections import OrderedDict as OD 9 | import os 10 | import os.path as osp 11 | 12 | 13 | def num_fmt(num, n=1): 14 | """format digits with n-significant digits""" 15 | if isinstance(num, int): 16 | return str(num) 17 | # round to n significant digits using scientific notation 18 | num = float(('{:.'
+ str(n - 1) + 'e}').format(num)) 19 | return str(int(num) if num.is_integer() else num) 20 | 21 | 22 | # https://stackoverflow.com/a/21912744/805502 23 | def ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OD): 24 | if not hasattr(stream, 'read'): # filename instead of stream 25 | stream = open(stream, 'r') 26 | class OrderedLoader(Loader): 27 | pass 28 | def construct_mapping(loader, node): 29 | loader.flatten_mapping(node) 30 | return object_pairs_hook(loader.construct_pairs(node)) 31 | OrderedLoader.add_constructor( 32 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 33 | construct_mapping) 34 | loaded = yaml.load(stream, OrderedLoader) 35 | stream.close() 36 | return loaded 37 | 38 | 39 | def ordered_dump(data, stream=None, Dumper=yaml.Dumper, **kwds): 40 | class OrderedDumper(Dumper): 41 | pass 42 | def _dict_representer(dumper, data): 43 | return dumper.represent_mapping( 44 | yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, 45 | data.items()) 46 | OrderedDumper.add_representer(OD, _dict_representer) 47 | if hasattr(stream, 'write'): 48 | dumped = yaml.dump(data, stream, OrderedDumper, **kwds) 49 | else: # filename instead of file 50 | dumped = yaml.dump(data, None, OrderedDumper, **kwds) 51 | print_file(dumped, stream) 52 | return dumped 53 | 54 | 55 | def str_metrics(metrics, epoch, num_epochs, iter=None, num_iters=None): 56 | str = '[Ep {}/{}] '.format(epoch, num_epochs) 57 | if iter is not None: 58 | str += '({}/{}) '.format(iter, num_iters) 59 | metricstr = [] 60 | count = len(str) 61 | for name, value in metrics.items(): 62 | if iter is not None: 63 | new = "{} {} ({})".format(name, num_fmt(value.val, 3), num_fmt(value.avg, 3)) 64 | else: 65 | new = "{}_avg {}".format(name, num_fmt(value.avg, 3)) 66 | if count + len(new) > 90: 67 | count = len(new) 68 | new = '\n' + new 69 | else: 70 | count += len(new) 71 | metricstr.append(new) 72 | str += ', '.join(metricstr) 73 | return str 74 | 75 | 76 | def print_file(str, filename, safe_overwrite=True): 77 | """ 78 | Write a string to a file; 79 | if the file exists and safe_overwrite is true, do a safe overwriting. 80 | """ 81 | tmp = None 82 | if osp.isfile(filename) and safe_overwrite: 83 | tmp = osp.join(osp.dirname(filename), osp.basename(filename) + '.old') 84 | os.rename(filename, tmp) 85 | with open(filename, 'w') as f: 86 | f.write(str) 87 | if tmp is not None: 88 | os.remove(tmp) 89 | -------------------------------------------------------------------------------- /multigrain/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
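A round-trip sketch of the YAML helpers above (the filename is arbitrary). `ordered_dump` goes through `print_file`, so an existing file is overwritten safely via a temporary `.old` copy, and key order survives reloading:

```python
from collections import OrderedDict
from multigrain.utils.logging import ordered_dump, ordered_load

metrics = OrderedDict([(1, OrderedDict(train_loss=0.91, val_top1=64.2))])
ordered_dump(metrics, 'metrics.yaml')
assert ordered_load('metrics.yaml') == metrics  # OrderedDict in, OrderedDict out
```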
6 | # 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | # Classification metrics 12 | 13 | def accuracy(output, target, topk=(1,)): 14 | """Computes the precision@k for the specified values of k""" 15 | maxk = max(topk) 16 | batch_size = target.size(0) 17 | 18 | _, pred = output.topk(maxk, 1, True, True) 19 | pred = pred.t() 20 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 21 | 22 | res = [] 23 | for k in topk: 24 | correct_k = correct[:k].view(-1).float().sum(0).item() 25 | res.append(correct_k * (100.0 / batch_size)) 26 | return res 27 | 28 | 29 | class AverageMeter(object): 30 | """Computes and stores the average and current value""" 31 | def __init__(self): 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | self.avg = self.sum / self.count 45 | 46 | 47 | class HistoryMeter(object): 48 | """Remember all values""" 49 | def __init__(self): 50 | self.reset() 51 | 52 | def reset(self): 53 | self.hist = [] 54 | self.partials = [] 55 | self.count = 0 56 | self.val = 0 57 | 58 | def update(self, x, n=1): 59 | self.val = x 60 | self.hist.append(x) 61 | x = n * x 62 | self.count += n 63 | # full precision summation based on http://code.activestate.com/recipes/393090/ 64 | i = 0 65 | for y in self.partials: 66 | if abs(x) < abs(y): 67 | x, y = y, x 68 | hi = x + y 69 | lo = y - (hi - x) 70 | if lo: 71 | self.partials[i] = lo 72 | i += 1 73 | x = hi 74 | self.partials[i:] = [x] 75 | 76 | @property 77 | def avg(self): 78 | """ 79 | Alternative to AverageMeter without floating point errors 80 | """ 81 | return sum(self.partials, 0.0) / self.count if self.partials else 0 82 | 83 | 84 | # Retrieval metrics 85 | 86 | def score_ap(ranks, nres): 87 | """ 88 | Compute the average precision of one search. 89 | ranks = ordered list of ranks of true positives 90 | nres = total number of positives in dataset 91 | """ 92 | 93 | # accumulate trapezoids in PR-plot 94 | ap = 0.0 95 | 96 | # All have an x-size of: 97 | recall_step = 1.0 / nres 98 | 99 | for ntp, rank in enumerate(ranks): 100 | 101 | # y-size on left side of trapezoid: 102 | # ntp = nb of true positives so far 103 | # rank = nb of retrieved items so far 104 | if rank == 0: 105 | precision_0 = 1.0 106 | else: 107 | precision_0 = ntp / float(rank) 108 | 109 | # y-size on right side of trapezoid: 110 | # ntp and rank are increased by one 111 | precision_1 = (ntp + 1) / float(rank + 1) 112 | 113 | ap += (precision_1 + precision_0) * recall_step / 2.0 114 | 115 | return ap 116 | 117 | 118 | def get_distance_matrix(outputs): 119 | """Get distance matrix given all embeddings.""" 120 | square = torch.sum(outputs ** 2.0, dim=1, keepdim=True) 121 | distance_square = square + square.t() - (2.0 * torch.matmul(outputs, outputs.t())) 122 | return F.relu(distance_square) ** 0.5 123 | 124 | 125 | def retrieval_acc(output, target, instances=4): 126 | """ 127 | UKB-like accuracy criterion. 128 | Must be applied to the whole dataset. 
129 | """ 130 | _, pred = output.topk(instances, 1, True, False) 131 | d_mat -= torch.eye(d_mat.size(0)) 132 | # d_mat_ic -= torch.eye(d_mat.size(0)) 133 | _, pred = torch.sort(d_mat, dim=1) 134 | 135 | d_mats = [get_distance_matrix(f) for f in features] 136 | preds = [torch.sort(d.cpu(), dim=1)[1] for d in d_mats] -------------------------------------------------------------------------------- /multigrain/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | 9 | 10 | def ifmakedirs(path): 11 | if not os.path.exists(path): 12 | os.makedirs(path) 13 | 14 | -------------------------------------------------------------------------------- /multigrain/utils/plots.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import matplotlib as mpl 8 | mpl.use('Agg') 9 | import matplotlib.pyplot as plt 10 | from collections import defaultdict 11 | import os.path as osp 12 | 13 | 14 | def make_plots(metrics_history, destdir): 15 | keys = set() 16 | for metrics in metrics_history.values(): 17 | keys.update(metrics.keys()) 18 | 19 | groups = defaultdict(list) 20 | for k in keys: 21 | split = k.split('_', 1) 22 | if len(split) == 1: 23 | split = [''] + split 24 | subk, g = split 25 | groups[g].append((subk, k)) 26 | 27 | for g in groups: 28 | plt.figure() 29 | plt.title(g) 30 | for k, kg in groups[g]: 31 | epochs = [] 32 | values = [] 33 | for epoch, metrics in metrics_history.items(): 34 | if kg in metrics: 35 | if isinstance(metrics[kg], list): 36 | for i, v in enumerate(metrics[kg]): 37 | epochs.append(epoch - 1 + (i + 1)/len(metrics[kg])) 38 | values.append(v) 39 | else: 40 | epochs.append(epoch) 41 | values.append(metrics[kg]) 42 | plt.plot(epochs, values, 'o-', label=k if k else None) 43 | if len(groups[g]) > 1: 44 | plt.legend() 45 | plt.xlabel("epochs") 46 | plt.tight_layout() 47 | plt.savefig(osp.join(destdir, g + '.pdf')) 48 | -------------------------------------------------------------------------------- /multigrain/utils/tictoc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from time import time 8 | 9 | 10 | def Tictoc(): 11 | start_stack = [] 12 | start_named = {} 13 | 14 | def tic(name=None): 15 | if name is None: 16 | start_stack.append(time()) 17 | else: 18 | start_named[name] = time() 19 | 20 | def toc(name=None): 21 | if name is None: 22 | start = start_stack.pop() 23 | else: 24 | start = start_named.pop(name) 25 | elapsed = time() - start 26 | return elapsed 27 | return tic, toc 28 | -------------------------------------------------------------------------------- /multigrain/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from collections.abc import Mapping, Sequence 8 | 9 | 10 | def cuda(o, id=0): 11 | """ 12 | Applies cuda recursively to modules and tensors. 13 | """ 14 | if isinstance(o, Mapping): 15 | return type(o)((k, cuda(v)) for (k, v) in o.items()) 16 | if isinstance(o, Sequence): 17 | return type(o)(cuda(v) for v in o) 18 | if hasattr(o, 'cuda'): 19 | return o.cuda(id) 20 | return o 21 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | # Data input settings 6 | parser.add_argument('--root_dir', type=str, default='/home/lzw/datasets/air', 7 | help='Path to root directory of datasets') 8 | parser.add_argument('--train_dir', type=str, default='train', 9 | help='Path to training dataset') 10 | parser.add_argument('--test_dir', type=str, default='test', 11 | help='Path to test dataset') 12 | parser.add_argument('--train_label', type=str, default='Train_label.csv', 13 | help='Path to train labels') 14 | parser.add_argument('--val_label', type=str, default='val_label.csv', 15 | help='Path to val labels') 16 | 17 | parser.add_argument('--model_dir', type=str, default='./model') 18 | parser.add_argument('--results_dir', type=str, default='./results') 19 | parser.add_argument('--log_dir', type=str, default='./log') 20 | parser.add_argument('--results_ts', type=str, default='./results_ts/') 21 | parser.add_argument('--res8', type=str, default='8-resnet-50-480_result.csv') 22 | 23 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') 24 | parser.add_argument('--gpu_id', type=int, default=0) 25 | parser.add_argument('--seed', type=int, default=924) 26 | parser.add_argument('--crop_size', type=int, default=224, help='Training crop size') 27 | parser.add_argument('--batch_size', type=int, default=128) 28 | parser.add_argument('--k_folds', type=int, default=6) 29 | parser.add_argument('--num_val', type=int, default=400) 30 | parser.add_argument('--num_epochs', type=int, default=120) 31 | parser.add_argument('--fore', type=int, default=1) 32 | parser.add_argument('--train_less', type=int, default=0) 33 | parser.add_argument('--clean_data', type=int, default=0) 34 | parser.add_argument('--weight', type=int, default=1) 35 | 36 | parser.add_argument('--optimizer', type=str, default='sgd', 37 | help='sgd, adam, adamw, radam, novograd') 38 | parser.add_argument('--scheduler', type=str, default='multistep', 39 | help='multistep, cycle, plateau, warmup') 40 | parser.add_argument('--lookahead', type=int, default=0) 41 | 42 | parser.add_argument('--cadene', type=int, default=0) 43 | parser.add_argument('--classes', type=int, default=9) 44 | parser.add_argument('--layers', type=int, default=101, 45 | help='layer nums: 0-7, 18, 34, 50, 101, 152, 16, 19, 121, 161, 201, 48') 46 | parser.add_argument('--pretrained', type=int, default=1, help='pretrained 1=true, 0=false') 47 | parser.add_argument('--network', type=str, default='resnet', 48 | help='network: resnet, resnext, resnext_wsl(with battleneck_width arg), vgg, inception_v3') 49 | parser.add_argument('--battleneck_width', type=int, default=8, help='8, 16, 32, 48') 50 | parser.add_argument('--is_inception', type=int, default=0) 51 | parser.add_argument('--retrain', type=int, default=0) 52 | 
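# The flags below gate the training tricks described in the README: mixup draws
# lam ~ Beta(alpha, alpha) and trains on lam*x_i + (1-lam)*x_j with the loss
# lam*CE(pred, y_i) + (1-lam)*CE(pred, y_j); cutmix instead pastes a random
# box from x_j into x_i (see mixup_data/cutmix_data in train.py).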
53 | ########### mixup ################# 54 | parser.add_argument('--mixup', type=int, default=0, help='use mixup could set alpha, cutmix') 55 | parser.add_argument('--alpha', type=float, default=1.0, help='for mixup alpha') 56 | parser.add_argument('--cutmix', type=int, default=0) 57 | 58 | ########## cutout ################# 59 | parser.add_argument('--cutout', type=int, default=0, help='cutout need n_holes and length') 60 | parser.add_argument('--n_holes', type=int, default=1) 61 | parser.add_argument('--length', type=int, default=16) 62 | 63 | ########## auto_aug ############### 64 | parser.add_argument('--auto_aug', type=int, default=0) 65 | parser.add_argument('--rand_aug', type=int, default=0) 66 | 67 | ########## loss ################## 68 | parser.add_argument('--criterion', type=str, default='lsr', help='criterion: lsr(label smooth), focal, ce') 69 | parser.add_argument('--use_focal', type=int, default=0) 70 | ##### test ##### 71 | parser.add_argument('--vote', type=int, default=0) 72 | parser.add_argument('--tta', type=int, default=1) 73 | 74 | ########## teacher ######### 75 | parser.add_argument('--teacher_mode', type=int, default=0) 76 | parser.add_argument('--cu_mode', type=int, default=0) 77 | 78 | args = parser.parse_args() 79 | return args 80 | -------------------------------------------------------------------------------- /rand_augment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from PIL import Image, ImageEnhance, ImageOps 5 | import numpy as np 6 | import random 7 | 8 | 9 | class Rand_Augment(): 10 | def __init__(self, Numbers=None, max_Magnitude=None): 11 | self.transforms = ['autocontrast', 'equalize', 'solarize', 'color', 'posterize', # 'rotate', 12 | 'contrast', 'brightness', 'sharpness', 'shearX', 'shearY', 'translateX', 'translateY'] 13 | if Numbers is None: 14 | self.Numbers = len(self.transforms) // 2 15 | else: 16 | self.Numbers = Numbers 17 | if max_Magnitude is None: 18 | self.max_Magnitude = 10 19 | else: 20 | self.max_Magnitude = max_Magnitude 21 | fillcolor = 128 22 | self.ranges = { 23 | # these Magnitude range , you must test it yourself , see what will happen after these operation , 24 | # it is no need to obey the value in autoaugment.py 25 | "shearX": np.linspace(0, 0.3, 10), 26 | "shearY": np.linspace(0, 0.3, 10), 27 | "translateX": np.linspace(0, 0.2, 10), 28 | "translateY": np.linspace(0, 0.2, 10), 29 | # "rotate": np.linspace(0, 360, 10), 30 | "color": np.linspace(0.0, 0.9, 10), 31 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 32 | "solarize": np.linspace(256, 231, 10), 33 | "contrast": np.linspace(0.0, 0.5, 10), 34 | "sharpness": np.linspace(0.0, 0.9, 10), 35 | "brightness": np.linspace(0.0, 0.3, 10), 36 | "autocontrast": [0] * 10, 37 | "equalize": [0] * 10 38 | # "invert": [0] * 10 39 | } 40 | self.func = { 41 | "shearX": lambda img, magnitude: img.transform( 42 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 43 | Image.BICUBIC, fill=fillcolor), 44 | "shearY": lambda img, magnitude: img.transform( 45 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 46 | Image.BICUBIC, fill=fillcolor), 47 | "translateX": lambda img, magnitude: img.transform( 48 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 49 | fill=fillcolor), 50 | "translateY": lambda img, magnitude: img.transform( 51 | img.size, 
Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 52 | fill=fillcolor), 53 | # "rotate": lambda img, magnitude: self.rotate_with_fill(img, magnitude), 54 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 55 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 56 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 57 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 58 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 59 | 1 + magnitude * random.choice([-1, 1])), 60 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 61 | 1 + magnitude * random.choice([-1, 1])), 62 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 63 | 1 + magnitude * random.choice([-1, 1])), 64 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 65 | "equalize": lambda img, magnitude: img 66 | # "invert": lambda img, magnitude: ImageOps.invert(img) 67 | } 68 | 69 | def rand_augment(self): 70 | """Generate a set of distortions. 71 | Args: 72 | N: Number of augmentation transformations to apply sequentially. N is len(transforms)/2 will be best 73 | M: Max_Magnitude for all the transformations. should be <= self.max_Magnitude """ 74 | 75 | M = np.random.randint(0, self.max_Magnitude, self.Numbers) 76 | 77 | sampled_ops = np.random.choice(self.transforms, self.Numbers) 78 | return [(op, Magnitude) for (op, Magnitude) in zip(sampled_ops, M)] 79 | 80 | def __call__(self, image): 81 | operations = self.rand_augment() 82 | for (op_name, M) in operations: 83 | operation = self.func[op_name] 84 | mag = self.ranges[op_name][M] 85 | image = operation(image, mag) 86 | return image 87 | 88 | def rotate_with_fill(self, img, magnitude): 89 | # I don't know why rotate must change to RGBA , it is copy from Autoaugment - pytorch 90 | rot = img.convert("RGBA").rotate(magnitude) 91 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 92 | 93 | def test_single_operation(self, image, op_name, M=-1): 94 | ''' 95 | :param image: image 96 | :param op_name: operation name in self.transforms 97 | :param M: -1 stands for the max Magnitude in there operation 98 | :return: 99 | ''' 100 | operation = self.func[op_name] 101 | mag = self.ranges[op_name][M] 102 | image = operation(image, mag) 103 | return image 104 | 105 | 106 | if __name__ == '__main__': 107 | # # this is for call the whole fun 108 | # img_augment = Rand_Augment() 109 | # img_origal = Image.open(r'0a38b552372d.png') 110 | # img_final = img_augment(img_origal) 111 | # plt.imshow(img_final) 112 | # plt.show() 113 | # print('how to call') 114 | 115 | # this is for a single fun you want to test 116 | img_augment = Rand_Augment() 117 | img_origal = Image.open(r'0bfdedaa60b54078ab0fc3bc6582aa90.jpg') 118 | for i in range(0, 10): 119 | img_final = img_augment.test_single_operation(img_origal, 'invert', M=i) 120 | plt.subplot(5, 2, i + 1) 121 | plt.imshow(img_final) 122 | plt.show() 123 | print('how to test') -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import time 6 | import copy 7 | from tqdm import tqdm 8 | import numpy as np 9 | import copy 10 | from sklearn.metrics 
import f1_score 11 | 12 | import utils 13 | from grad_cam import GradCam, show_cam_on_image 14 | from dataloader import UnNormalize 15 | from tta import * 16 | 17 | tta_list = [NoneAug(), 18 | Hflip(), 19 | # Vflip(), 20 | # Resize(256), 21 | # Resize(288), 22 | # Resize(320), 23 | # Resize(352), 24 | # Resize(388), 25 | Resize(400), 26 | # Resize(544), 27 | # Resize(608), 28 | # Adjustcontrast(0.2), 29 | # Resize(640), 30 | # Resize(663), 31 | # Resize(736), 32 | ] 33 | 34 | def cutmix_data(data, targets, alpha=1.0, device=None): 35 | indices = torch.randperm(data.size(0)).to(device) 36 | shuffled_data = data[indices] 37 | shuffled_targets = targets[indices] 38 | 39 | lam = np.random.beta(alpha, alpha) 40 | 41 | image_h, image_w = data.shape[2:] 42 | cx = np.random.uniform(0, image_w) 43 | cy = np.random.uniform(0, image_h) 44 | w = image_w * np.sqrt(1 - lam) 45 | h = image_h * np.sqrt(1 - lam) 46 | x0 = int(np.round(max(cx - w / 2, 0))) 47 | x1 = int(np.round(min(cx + w / 2, image_w))) 48 | y0 = int(np.round(max(cy - h / 2, 0))) 49 | y1 = int(np.round(min(cy + h / 2, image_h))) 50 | 51 | data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1] 52 | # targets = (targets, shuffled_targets, lam) 53 | 54 | return data, targets, shuffled_targets, lam 55 | 56 | def mixup_data(x, y, alpha=1.0, device=None): 57 | '''Returns mixed inputs, pairs of targets, and lambda''' 58 | if alpha > 0: 59 | lam = np.random.beta(alpha, alpha) 60 | else: 61 | lam = 1 62 | 63 | batch_size = x.size()[0] 64 | index = torch.randperm(batch_size).to(device) 65 | 66 | mixed_x = lam * x + (1 - lam) * x[index, :] 67 | y_a, y_b = y, y[index] 68 | return mixed_x, y_a, y_b, lam 69 | 70 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 71 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 72 | 73 | def train_model(dataloaders, model, criterion, optimizer, summary_writer, 74 | scheduler=None, scheduler_name='multistep', num_epochs=20, device=None, 75 | is_inception=False, mixup=False, cutmix=False, alpha=1.0, val_dis=[]): 76 | tic = time.time() 77 | 78 | acc_history = [] 79 | best_acc = 0 80 | best_model_wgt = None 81 | 82 | if len(dataloaders['val']) == 0: 83 | phases = ['train'] 84 | else: 85 | phases = ['train', 'val'] 86 | 87 | step_per_epoch = len(dataloaders['train'].dataset) / dataloaders['train'].batch_size 88 | 89 | for epoch in range(num_epochs): 90 | print('epoch {}/{}, lr: {}'.format(epoch + 1, num_epochs, scheduler.get_lr()[0])) 91 | 92 | for phase in phases: 93 | if phase == 'train': 94 | model.train() 95 | running_loss = 0.0 96 | else: 97 | model.eval() 98 | running_correct = 0.0 99 | err = [] 100 | p, l = [], [] 101 | 102 | for i, (inputs, labels) in enumerate(dataloaders[phase]): 103 | inputs = inputs.to(device) 104 | labels = labels.to(device) 105 | 106 | if phase == 'train' and mixup: 107 | if cutmix: 108 | inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha, device) 109 | else: 110 | inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha, device) 111 | 112 | if phase == 'train' and is_inception: 113 | logits, aux_logits = model(inputs) 114 | if mixup: 115 | loss1 = mixup_criterion(criterion, logits, targets_a, targets_b, lam) 116 | loss2 = mixup_criterion(criterion, aux_logits, targets_a, targets_b, lam) 117 | loss = (loss1 + 0.4*loss2) 118 | else: 119 | loss1 = criterion(logits, labels) 120 | loss2 = criterion(aux_logits, labels) 121 | loss = (loss1 + 0.4*loss2) 122 | else: 123 | logits = model(inputs) 124 | if phase == 'train' and mixup: 125 | 
73 | def train_model(dataloaders, model, criterion, optimizer, summary_writer,
74 |                 scheduler=None, scheduler_name='multistep', num_epochs=20, device=None,
75 |                 is_inception=False, mixup=False, cutmix=False, alpha=1.0, val_dis=[]):
76 |     tic = time.time()
77 | 
78 |     acc_history = []
79 |     best_acc = 0
80 |     best_model_wgt = None
81 | 
82 |     if len(dataloaders['val']) == 0:
83 |         phases = ['train']
84 |     else:
85 |         phases = ['train', 'val']
86 | 
87 |     step_per_epoch = len(dataloaders['train'].dataset) / dataloaders['train'].batch_size
88 | 
89 |     for epoch in range(num_epochs):
90 |         print('epoch {}/{}, lr: {}'.format(epoch + 1, num_epochs, optimizer.param_groups[0]['lr']))  # read the lr from the optimizer so this also works when scheduler is None
91 | 
92 |         for phase in phases:
93 |             if phase == 'train':
94 |                 model.train()
95 |                 running_loss = 0.0
96 |             else:
97 |                 model.eval()
98 |             running_correct = 0.0
99 |             err = []
100 |             p, l = [], []
101 | 
102 |             for i, (inputs, labels) in enumerate(dataloaders[phase]):
103 |                 inputs = inputs.to(device)
104 |                 labels = labels.to(device)
105 | 
106 |                 if phase == 'train' and mixup:
107 |                     if cutmix:
108 |                         inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha, device)
109 |                     else:
110 |                         inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha, device)
111 | 
112 |                 if phase == 'train' and is_inception:
113 |                     logits, aux_logits = model(inputs)  # inception-style nets return an auxiliary head during training
114 |                     if mixup:
115 |                         loss1 = mixup_criterion(criterion, logits, targets_a, targets_b, lam)
116 |                         loss2 = mixup_criterion(criterion, aux_logits, targets_a, targets_b, lam)
117 |                         loss = (loss1 + 0.4*loss2)
118 |                     else:
119 |                         loss1 = criterion(logits, labels)
120 |                         loss2 = criterion(aux_logits, labels)
121 |                         loss = (loss1 + 0.4*loss2)
122 |                 else:
123 |                     logits = model(inputs)
124 |                     if phase == 'train' and mixup:
125 |                         loss = mixup_criterion(criterion, logits, targets_a, targets_b, lam)
126 |                     else:
127 |                         loss = criterion(logits, labels)
128 | 
129 |                 if phase == 'train':
130 |                     optimizer.zero_grad()
131 |                     loss.backward()
132 |                     optimizer.step()
133 |                     summary_writer.add_scalar(tag='loss', scalar_value=loss.item(),
134 |                                               global_step=int(step_per_epoch*epoch) + i)
135 | 
136 |                     running_loss += loss.item()
137 |                     if scheduler is not None and (scheduler_name == 'cycle' or
138 |                                                   scheduler_name == 'warmup' or
139 |                                                   scheduler_name == 'cos' or
140 |                                                   scheduler_name == 'cosw' or
141 |                                                   scheduler_name == 'sgdr'):
142 |                         scheduler.step()  # these schedulers are stepped once per batch
143 | 
144 |                 _, preds = logits.max(1)
145 |                 if phase == 'train' and mixup:
146 |                     correct = (lam * preds.eq(targets_a.data).sum().float()
147 |                                + (1 - lam) * preds.eq(targets_b.data).sum().float())
148 |                 else:
149 |                     correct = (preds == labels).sum()
150 |                     p.extend(preds.detach().cpu().numpy())
151 |                     l.extend(labels.cpu().numpy())
152 | 
153 |                 running_correct += correct
154 |                 if phase == 'val':
155 |                     # print('label:', labels)
156 |                     # print('preds:', preds)
157 |                     error_label = (preds != labels).long() * (labels + 1)  # 0 for correct samples, 1-indexed class for errors
158 |                     error_label = error_label[error_label>0]
159 |                     err.append(error_label)
160 | 
161 |             epoch_acc = running_correct.double() / len(dataloaders[phase].dataset)
162 |             if phase == 'train':
163 |                 epoch_loss = running_loss / len(dataloaders[phase].dataset)  # sums batch-mean losses, so this is only proportional to the true per-sample loss
164 |                 print('train --> loss: {:.4f}, acc: {:.4f}'.format(epoch_loss, epoch_acc))
165 |             else:
166 |                 # acc, pre, val_f1 = utils.f1score(l, p, 9)
167 |                 val_f1 = f1_score(l, p, average='macro')
168 | 
169 |                 print('val --> acc: {:.4f}, f1: {:.4f}'.format(epoch_acc, val_f1))
170 |                 if len(val_dis) >= 1:
171 |                     errors = torch.cat(err, 0)
172 |                     # total_err = errors.cpu().numpy().shape[0]
173 |                     val_len = len(val_dis)+1
174 |                     err_rate = np.bincount(errors.cpu().numpy(), minlength=val_len)[1:] / val_dis
175 |                     p_rate = np.bincount(p, minlength=len(val_dis))[:] / val_dis
176 |                     for i, rate in enumerate(err_rate): print('{}: {:.2f}, {:.2f}'.format(utils.weather_classes[i], rate, p_rate[i]))
177 | 
178 |             acc_history.append(epoch_acc.item())
179 |             summary_writer.add_scalar(tag='correct', scalar_value=epoch_acc, global_step=epoch)
180 |             if epoch_acc >= best_acc and epoch > num_epochs//2:  # only checkpoint in the second half of training
181 |                 best_acc = epoch_acc.item()
182 |                 best_model_wgt = copy.deepcopy(model.state_dict())
183 |                 torch.save(best_model_wgt, 'model/temp.ckpt')
184 |         print()
185 |         if scheduler is not None:
186 |             if scheduler_name in ['multistep', 'step', 'exponential']:
187 |                 scheduler.step()
188 |             elif scheduler_name == 'plateau':
189 |                 scheduler.step(epoch_loss)
190 | 
191 |     summary_writer.close()
192 |     toc = time.time()
193 |     time_elapsed = toc - tic
194 |     print('training time -> %d:%.2f' % (time_elapsed // 60, time_elapsed % 60))
195 |     print('best_acc:', best_acc)
196 | 
197 |     if best_model_wgt is not None:
198 |         model.load_state_dict(best_model_wgt)
199 | 
200 |     return model, best_acc
201 | 
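A hedged sketch of wiring train_model together — the synthetic tensors, toy CNN, and run name are placeholders; it assumes tensorboard is installed and pre-creates the model/ directory that the checkpoint line above writes into:

```python
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

os.makedirs('model', exist_ok=True)   # train_model checkpoints to model/temp.ckpt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataloaders = {
    'train': DataLoader(TensorDataset(torch.randn(32, 3, 64, 64),
                                      torch.randint(0, 9, (32,))), batch_size=8, shuffle=True),
    'val': DataLoader(TensorDataset(torch.randn(8, 3, 64, 64),
                                    torch.randint(0, 9, (8,))), batch_size=8),
}
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 9)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 'cos' schedulers are stepped per batch inside train_model, so T_max counts batches
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=4 * 20)

model, best_acc = train_model(dataloaders, model, criterion, optimizer,
                              SummaryWriter('runs/demo'), scheduler=scheduler,
                              scheduler_name='cos', num_epochs=20, device=device,
                              mixup=True, cutmix=True, alpha=1.0)
```
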
202 | def clean_data(loader, model, device):
203 |     model.eval()
204 |     # collects (name, pred, label, score) for samples the model disagrees with, as candidates for label cleaning
205 |     im_names = []
206 |     err = []
207 |     not_correct = 0
208 |     for images, labels, names in tqdm(loader):
209 |         images = images.to(device)
210 |         labels = labels.to(device)
211 |         logits = torch.softmax(model(images), 1)
212 | 
213 |         score, preds = logits.max(1)
214 |         not_correct += (preds != labels).sum()
215 |         # score_ = ((score >= 0.9) | (0.4 < score))
216 |         error_label = (preds != labels).long() * (labels + 1)  # 0 for correct samples, 1-indexed class for errors
217 | 
218 |         for i, lab in enumerate(error_label):
219 |             # if lab > 0 and (preds[i] in [5, 7] or labels[i] in [5, 7]):
220 |             if lab > 0:
221 |                 im_names.append((names[i].split('/')[-1], preds[i].item()+1, labels[i].item()+1, score[i].item()))
222 | 
223 |     print("%d/%d" % (len(im_names), not_correct))
224 |     return im_names
225 | 
226 | 
227 | def eval_model(loader, model, device):
228 |     model.eval()
229 |     labels = []
230 |     im_names = []
231 |     for images, names in tqdm(loader):
232 |         images = images.to(device)
233 |         logits = model(images)
234 | 
235 |         _, preds = logits.max(1)
236 |         for p, n in zip(preds.cpu().numpy().tolist(), names):
237 |             im_names.append(n)
238 |             labels.append(p+1)  # submission labels are 1-indexed
239 |     return im_names, labels
240 | 
241 | 
242 | def eval_logits(loader, model, device=None):
243 |     model.eval()
244 |     labels = []
245 |     im_names = []
246 |     for images, names in tqdm(loader):
247 |         images = images.to(device)
248 |         logits = model(images)
249 | 
250 |         for p, n in zip(logits.detach().cpu().numpy().tolist(), names):
251 |             im_names.append(n)
252 |             labels.append(p)
253 |     return im_names, labels
254 | 
255 | def eval_model_tta(loader, model, tta_augs=tta_list, device=None):
256 |     model.eval()
257 |     labels = []
258 |     im_names = []
259 |     for images, names in tqdm(loader):
260 |         images = TensorToPILs(images)
261 |         logits = []
262 |         for aug in tta_augs:
263 |             aug_imgs = PILsToTensor(aug(images)).to(device)
264 |             outputs = model(aug_imgs).detach().cpu().numpy().tolist()
265 |             logits.append(outputs)
266 |         # print(np.shape(logits))
267 |         logits = np.mean(np.array(logits), axis=0)  # average the raw logits over all TTA views
268 | 
269 |         preds = np.argmax(logits, axis=1)
270 |         for p, n in zip(preds, names):
271 |             im_names.append(n)
272 |             labels.append(p+1)
273 |     return im_names, labels
274 | 
275 | def eval_logits_tta(loader, model, tta_augs=tta_list, device=None):
276 |     model.eval()
277 |     labels = []
278 |     im_names = []
279 |     for images, names in tqdm(loader):
280 |         images = TensorToPILs(images)
281 |         logits = []
282 |         for aug in tta_augs:
283 |             aug_imgs = PILsToTensor(aug(images)).to(device)
284 |             outputs = torch.softmax(model(aug_imgs), dim=1).detach().cpu().numpy().tolist()
285 |             logits.append(outputs)
286 |         # print(np.shape(logits))
287 |         logits = np.mean(np.array(logits), axis=0)  # average the softmax probabilities over all TTA views
288 | 
289 |         for p, n in zip(logits, names):
290 |             im_names.append(n)
291 |             labels.append(p)
292 |     return im_names, labels
293 | 
294 | def extract_features(loader, model, device=None):
295 |     _features = []
296 |     for images, labels, names in tqdm(loader):
297 |         images = images.to(device)
298 |         labels = labels.to(device)
299 |         features = model.module.extract_features(images).cpu().numpy().tolist()  # assumes a DataParallel-wrapped net with extract_features()
300 | 
301 |         for feat in features:
302 |             _features.append(feat)
303 | 
304 |     return np.array(_features)
305 | 
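A sketch of the TTA entry points above — the dataset is synthetic, and the loader deliberately yields un-normalized [0, 1] tensors, since eval_logits_tta converts them to PIL, applies each transform, and re-normalizes through PILsToTensor:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

class FakeTestSet(Dataset):
    """Synthetic stand-in for the real test set: (raw image tensor, filename)."""
    def __init__(self, n=4):
        self.imgs = torch.rand(n, 3, 224, 224)   # pixel values in [0, 1], no normalization
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, i):
        return self.imgs[i], 'img_%d.jpg' % i

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(8, 9)).to(device)   # toy classifier
loader = DataLoader(FakeTestSet(), batch_size=2)

# averaged softmax outputs over the original and horizontally flipped views
names, probs = eval_logits_tta(loader, model, tta_augs=[NoneAug(), Hflip()], device=device)
```
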
306 | def grad_cam(loader, model, device):
307 |     model.eval()
308 |     # writes Grad-CAM overlays for correctly classified samples into ./cam
309 |     # (expects the loader to yield batches of size 1)
310 |     use_cuda = device is not None
311 |     model_no_fc = copy.deepcopy(model)
312 |     del model_no_fc.fc
313 |     cam = GradCam(model=model_no_fc, target_layer_names=['layer4'], use_cuda=use_cuda, org_model=model)
314 |     utils.mkdir('cam')
315 |     unorm = UnNormalize()
316 |     for i, (images, labels, org_imgs) in enumerate(tqdm(loader)):
317 |         images = images.to(device)
318 |         labels = labels.to(device)
319 | 
320 |         logits = model(images)
321 |         _, preds = logits.max(1)
322 | 
323 |         right = (preds == labels)
324 |         masks = cam(images, None)
325 |         # print(org_imgs.shape)
326 |         if right:
327 |             name = './cam/{}-{}-{}-{}.jpg'.format(i, utils.weather_classes[preds[0]], utils.weather_classes[labels[0]], right[0])
328 |             img = unorm(images[0]).cpu().numpy().transpose(1, 2, 0)
329 |             show_cam_on_image(img, masks, name)
330 |             # print(org_imgs[0].shape)
331 |             utils.save_image(org_imgs[0].numpy(), './cam/{}-{}.jpg'.format(i, utils.weather_classes[labels[0]]))
332 | 
333 |     print('cam finished!')
334 | 
--------------------------------------------------------------------------------
/tta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import torchvision.transforms.functional as F
5 | from PIL import Image, ImageOps, ImageEnhance
6 | from torchvision import transforms
7 | 
8 | # mean = [0.485, 0.456, 0.406]
9 | # std = [0.229, 0.224, 0.225]
10 | 
11 | # mean = [0.5, 0.5, 0.5]
12 | # std = [0.5, 0.5, 0.5]
13 | 
14 | mean = [0.507, 0.522, 0.500]
15 | std = [0.213, 0.207, 0.212]
16 | 
17 | def visualizationImage(imgs):
18 |     for index, img in enumerate(imgs):
19 |         plt.subplot(1, len(imgs), index + 1)
20 |         plt.imshow(img)
21 |         plt.axis('off')
22 | 
23 | class NormalizeInverse(transforms.Normalize):
24 |     """
25 |     Undoes the normalization and returns the reconstructed images in the input domain.
26 |     """
27 |     def __init__(self, mean, std):
28 |         mean = torch.as_tensor(mean)
29 |         std = torch.as_tensor(std)
30 |         std_inv = 1 / (std + 1e-7)
31 |         # std_inv = 1.0 / std
32 |         mean_inv = -mean * std_inv
33 |         super().__init__(mean=mean_inv, std=std_inv)
34 | 
35 |     def __call__(self, tensor):
36 |         return super().__call__(tensor.clone())
37 | 
38 | def TensorToPILs(inputs):
39 |     # unNorm=NormalizeInverse(mean=mean, std=std)
40 |     # imgs = [F.to_pil_image(unNorm(inputs[i])) for i in range(inputs.shape[0])]
41 |     imgs = [F.to_pil_image(inputs[i]) for i in range(inputs.shape[0])]
42 |     return imgs
43 | 
44 | def PILsToTensor(imgs):
45 |     Norm = transforms.Normalize(mean=mean, std=std)
46 |     tensors = [Norm(F.to_tensor(img)) for img in imgs]
47 |     # tensors = [F.to_tensor(img) for img in imgs]
48 |     return torch.stack(tensors)
49 | 
50 | class NoneAug():
51 |     def __call__(self, imgs):
52 |         return imgs
53 | 
54 | class Resize():
55 |     def __init__(self, size):
56 |         self.size = size
57 | 
58 |     def __call__(self, imgs):
59 |         return [F.resize(img=img, size=self.size) for img in imgs]
60 | 
61 | class Hflip():
62 |     def __call__(self, imgs):
63 |         return [F.hflip(img=img) for img in imgs]
64 | 
65 | class Vflip():
66 |     def __call__(self, imgs):
67 |         return [F.vflip(img=img) for img in imgs]
68 | 
69 | class Rotate():
70 |     def __init__(self, angle):
71 |         self.angle = angle
72 |     def __call__(self, imgs):
73 |         return [F.rotate(img=img, angle=self.angle) for img in imgs]
74 | 
75 | class Grayscale():
76 |     def __init__(self, output_channels=1):
77 |         self.output_channels = output_channels
78 |     def __call__(self, imgs):
79 |         return [F.to_grayscale(img=img, num_output_channels=self.output_channels) for img in imgs]
80 | 
81 | class Adjustbright():
82 |     def __init__(self, bright_factor):
83 |         self.bright_factor = bright_factor
84 |     def __call__(self, imgs):
85 |         return [F.adjust_brightness(img=img, brightness_factor=self.bright_factor) for img in imgs]
86 | 
87 | class Adjustcontrast():
88 |     def __init__(self, contrast_factor):
89 |         self.contrast_factor = contrast_factor
90 |     def __call__(self, imgs):
91 |         return [F.adjust_contrast(img=img, contrast_factor=self.contrast_factor) for img in imgs]
92 | 
93 | class Adjustsaturation():
94 |     def __init__(self, saturation_factor):
95 |         self.saturation_factor = saturation_factor
96 |     def __call__(self, imgs):
97 |         return [F.adjust_saturation(img=img, saturation_factor=self.saturation_factor) for img in imgs]
98 | 
99 | class Adjustgamma():
100 |     def __init__(self, gamma, gain=1):
101 |         self.gamma = gamma
102 |         self.gain = gain
103 |     def __call__(self, imgs):
104 |         return [F.adjust_gamma(img=img, gamma=self.gamma, gain=self.gain) for img in imgs]
105 | 
106 | 
107 | 
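Each transform above maps a list of PIL images to a new list, so they chain naturally; a small sketch with a synthetic image (the transform choices here are arbitrary):

```python
from PIL import Image
from tta import Hflip, Resize, Adjustgamma, PILsToTensor

imgs = [Image.new('RGB', (256, 256), (120, 130, 140))]   # synthetic gray-blue image
for aug in [Hflip(), Resize(400), Adjustgamma(0.8)]:
    imgs = aug(imgs)

batch = PILsToTensor(imgs)   # normalized tensor batch
print(batch.shape)           # torch.Size([1, 3, 400, 400])
```
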
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from PIL import ImageFile
6 | from tqdm import tqdm
7 | import csv
8 | import glob
9 | import os
10 | import time
11 | from multiprocessing import Pool
12 | 
13 | import smote_variants as sv
14 | import warnings
15 | warnings.filterwarnings("ignore")
16 | 
17 | ImageFile.LOAD_TRUNCATED_IMAGES = True
18 | 
19 | weather_classes = ['雨凇', '雾凇', '雾霾', '霜', '露', '结冰', '降雨', '降雪', '冰雹']
20 | #                     1      2       3      4     5     6       7       8      9
21 | def classes_num(filename):
22 |     labels = [0]*9
23 |     with open(filename) as f:
24 |         f_csv = csv.reader(f)
25 |         f_csv.__next__()
26 |         for row in f_csv:
27 |             labels[int(row[1])-1] += 1
28 |     return labels
29 | 
30 | def image_info(dir):
31 |     min_w, min_h = 1000, 1000
32 |     max_w, max_h = 0, 0
33 |     w, h = [], []
34 |     ex = []
35 |     means = [0, 0, 0]
36 |     std = [0, 0, 0]
37 |     for img_path in tqdm(glob.glob(dir+'/*')):
38 |         try:
39 |             img = Image.open(img_path)
40 |             if img.size[0] < min_w:
41 |                 min_w = img.size[0]
42 |             if img.size[0] > max_w:
43 |                 max_w = img.size[0]
44 |             if img.size[1] < min_h:
45 |                 min_h = img.size[1]
46 |             if img.size[1] > max_h:
47 |                 max_h = img.size[1]
48 |             w.append(img.size[0])
49 |             h.append(img.size[1])
50 | 
51 |         except(OSError, NameError):
52 |             ex.append(img_path)
53 |             # img = cv2.imread(img_path)
54 |             continue  # skip unreadable images so the stats below do not crash
55 | 
56 |         img = np.array(img).astype(np.float32)
57 |         img = img / 255.0
58 |         # print(img_path, img.shape)
59 |         if len(img.shape) == 2: continue
60 |         for i in range(3):
61 |             means[i] += img[:, :, i].mean()
62 |             std[i] += img[:, :, i].std()
63 | 
64 |     means.reverse()  # note: reverses channel order (RGB -> BGR)
65 |     std.reverse()
66 | 
67 |     means = np.asarray(means) / len(w)
68 |     std = np.asarray(std) / len(w)
69 | 
70 |     print("max_w:{}, max_h:{}".format(max_w, max_h))
71 |     print("min_w:{}, min_h:{}".format(min_w, min_h))
72 |     print('len:{}, mean_w:{}, mean_h:{}'.format(len(w), np.mean(w), np.mean(h)))
73 |     print(ex)
74 |     print("normMean = {}".format(means))
75 |     print("normStd = {}".format(std))
76 | 
77 | 
78 | def read_test_data(fdir):
79 |     images = []
80 |     im_names = []
81 |     i = 0
82 |     for img_path in tqdm(glob.glob(fdir+'/*')):
83 |         im = load_image(img_path)  # glob already returns paths with fdir prefixed
84 |         im_names.append(img_path.split('/')[-1])
85 |         images.append(im)
86 |         i += 1
87 |         if i > 10000:
88 |             break
89 |     return images, im_names
90 | 
91 | def read_test_ice_snow_data(fdir, filename):
92 |     images = []
93 |     im_names = []
94 | 
95 |     with open(filename) as f:
96 |         f_csv = csv.reader(f)
97 |         f_csv.__next__()
98 |         for row in tqdm(f_csv):
99 |             if int(row[1]) == 6:
100 |                 img = load_image(os.path.join(fdir, row[0]))
101 |                 images.append(img)
102 |                 im_names.append(row[0])
103 | 
104 |     return images, im_names
105 | 
106 | def read_non_ice_snow_data(fdir, filename):
107 |     images = []
108 |     labels = []
109 | 
110 |     with open(filename) as f:
111 |         f_csv = csv.reader(f)
112 |         f_csv.__next__()
113 |         print('loading image...')
114 |         i = 0
115 |         for row in tqdm(f_csv):
116 |             if row[0] == 'cad097b0899f45bcba277adf5344097e.png':
117 |                 continue
118 |             if int(row[1]) in [6, 8]:
119 |                 continue
120 |             elif int(row[1]) in [7]:
121 |                 labels.append(5)
122 |             elif int(row[1]) in [9]:
123 |                 labels.append(6)
124 |                 # labels.append(int(row[1]) - 3)
125 |             else:
126 |                 labels.append(int(row[1])-1)
127 |             images.append(os.path.join(fdir, row[0]))
128 | 
129 |             i += 1
130 |             if i > 7000:
131 |                 break
132 | 
133 |     return images, labels
134 | 
135 | def read_ice_snow_data(fdir, filename):
136 |     images = []
137 |     labels = []
138 |     # names = []
139 |     # path = os.path.join(fdir, filename)
140 | 
141 |     with open(filename) as f:
142 |         f_csv = csv.reader(f)
143 |         f_csv.__next__()
144 |         print('loading image...')
145 | 
146 |         for row in tqdm(f_csv):
147 |             if row[0] == 'cad097b0899f45bcba277adf5344097e.png':
148 |                 continue
149 |             if int(row[1]) not in [6, 8]:
150 |                 continue
151 |             # img = load_image(os.path.join(fdir, row[0]))
152 |             images.append(os.path.join(fdir, row[0]))
153 | 
154 |             # 6 -> 1, 8 -> 0
155 |             labels.append(int(int(row[1])-1 == 5))
156 |             # names.append(row[0])
157 | 
158 |     return images, labels
159 | 
160 | def read_data(fdir, filename, train_less=False, clean_data=False):
161 |     images = []
162 |     labels = []
163 |     # names = []
164 |     # path = os.path.join(fdir, filename)
165 |     need_cleans = []
166 |     if clean_data:
167 |         with open('./err.csv', 'r') as f:
168 |             f_csv = csv.reader(f)
169 |             for row in tqdm(f_csv):
170 |                 need_cleans.append(row[0])
171 | 
172 |     with open(filename) as f:
173 |         f_csv = csv.reader(f)
174 |         f_csv.__next__()
175 |         print('loading image...')
176 |         i = 0
177 |         for row in tqdm(f_csv):
178 |             if row[0] == 'cad097b0899f45bcba277adf5344097e.png':
179 |                 continue
180 |             if clean_data:
181 |                 if row[0] in need_cleans:
182 |                     continue
183 |             if train_less:
184 |                 images.append(os.path.join(fdir, row[0]))
185 |                 if int(row[1]) in [6, 8]:
186 |                     labels.append(5)
187 |                 elif int(row[1]) in [9]:
188 |                     labels.append(7)
189 |                 else:
190 |                     labels.append(int(row[1])-1)
191 |                     # labels.append(1)
192 | 
193 |                 continue
194 | 
195 |             # img = load_image(os.path.join(fdir, row[0]))
196 |             images.append(os.path.join(fdir, row[0]))
197 |             labels.append(int(row[1])-1)
198 |             # names.append(row[0])
199 |             i += 1
200 |             if i > 70000:
201 |                 break
202 | 
203 |     return images, labels
204 | 
205 | 
206 | def read_smote_data(fdir, filename, val_num=500):
207 |     images = []
208 |     labels = []
209 |     # path = os.path.join(fdir, filename)
210 |     with open(filename) as f:
211 |         f_csv = csv.reader(f)
212 |         f_csv.__next__()
213 |         print('loading image...')
214 |         for row in tqdm(f_csv):
215 |             if row[0] == 'cad097b0899f45bcba277adf5344097e.png':
216 |                 continue
217 |             images.append(os.path.join(fdir, row[0]))
218 |             labels.append(int(row[1])-1)
219 | 
220 |     train_data = images[val_num:], labels[val_num:]
221 |     val_data = images[:val_num], labels[:val_num]
222 |     return train_data, val_data
223 | 
224 | 
225 | def smote_data(images, labels):
226 |     # images = images[:50]
227 |     # labels = labels[:50]
228 |     shape = np.shape(images)
229 |     nums = shape[0] // 2
230 |     oversampler = sv.MulticlassOversampling(sv.Borderline_SMOTE2(proportion=0.7, n_neighbors=3, k_neighbors=3, n_jobs=12))  # MDO
231 |     X, y = oversampler.sample(np.reshape(images, (len(images), -1)), labels)
232 |     X = X.reshape((len(y), shape[1], shape[2], shape[3])).astype(np.uint8)
233 |     mkdir('new_train')
234 |     with open('new_train_label.csv', 'a', encoding='utf-8') as f:
235 |         f_csv = csv.writer(f)
236 |         for i, x in enumerate(X):
237 |             im = Image.fromarray(x)
238 |             im.save('./new_train/'+str(i)+'.jpg', 'jpeg')
239 |             f_csv.writerow([str(i)+'.jpg', y[i]+1])
240 | 
241 |     print('org: %d -> x: %d' % (len(labels), len(y)))
242 |     ys = [0]*10
243 |     for i in y:
244 |         ys[i+1] += 1
245 |     print(ys)
246 | 
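read_data returns image paths plus 0-indexed labels; the repo's real Dataset lives in dataloader.py, but a minimal hypothetical wrapper shows the intended shape of the data:

```python
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class WeatherSet(Dataset):
    """Hypothetical minimal Dataset over read_data() output (see dataloader.py for the real one)."""
    def __init__(self, paths, labels, size=224):
        self.paths, self.labels = paths, labels
        self.tf = transforms.Compose([transforms.Resize((size, size)),
                                      transforms.ToTensor()])
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, i):
        return self.tf(load_image(self.paths[i])), self.labels[i]

# images, labels = read_data('/path/to/train', '/path/to/Train_label.csv')
# train_set = WeatherSet(images, labels)
```
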
load_image(filename):
248 |     try:
249 |         img = Image.open(filename)
250 |     except(OSError, NameError):
251 |         # print('cv opened image')
252 |         cv_img = cv2.imread(filename)  # fall back to cv2 for files PIL cannot open
253 |         img = Image.fromarray(cv_img)
254 | 
255 |     img = img.convert("RGB")
256 |     # print(filename)
257 |     return img
258 | 
259 | def load_image_label(params, resize=600):
260 |     try:
261 |         img = Image.open(params[0])
262 |     except(OSError, NameError):
263 |         print('cv opened image')
264 |         cv_img = cv2.imread(params[0])
265 |         img = Image.fromarray(cv_img)
266 | 
267 |     img = img.convert("RGB")
268 |     img = img.resize((resize, resize), Image.ANTIALIAS)  # Image.LANCZOS in newer Pillow
269 |     # print(filename)
270 |     return np.array(img), params[1]
271 | 
272 | def to_tensor(data, dtype=torch.float16, device=None):
273 |     return torch.as_tensor(data, dtype=dtype, device=device)
274 | 
275 | def mkdir(path):
276 |     # create the folder if it does not already exist
277 |     folder = os.path.exists(path)
278 | 
279 |     if not folder:
280 |         os.makedirs(path)
281 | 
282 | def save_image(img, name):
283 |     cv2.imwrite(name, img)
284 | 
285 | def add_weight_decay(net, l2_value, skip_list=()):
286 |     decay, no_decay = [], []
287 |     for name, param in net.named_parameters():
288 |         if not param.requires_grad: continue  # frozen weights
289 |         if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: no_decay.append(param)  # biases and BN weights skip L2
290 |         else: decay.append(param)
291 |     return [{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': l2_value}]
292 | 
293 | def f1score(y_true, y_pred, num_classes):
294 |     # calculates accuracy, weighted precision, and weighted f1-score for n-class classification for n>=3
295 |     # note that weighted recall is the same as accuracy
296 | 
297 |     N = len(y_true)
298 |     confusion_matrix = [[0 for _ in range(num_classes)] for _ in range(num_classes)]
299 | 
300 |     for i in range(0, N):
301 |         confusion_matrix[y_true[i]][y_pred[i]] += 1
302 | 
303 |     sum_diagonal = 0
304 | 
305 |     for i in range(0, num_classes):
306 |         sum_diagonal += confusion_matrix[i][i]
307 | 
308 |     precision = 0.0
309 |     f1score = 0.0
310 | 
311 |     for i in range(0, num_classes):
312 |         support = 0
313 |         sum_column = 0
314 | 
315 |         for j in range(0, num_classes):
316 |             support += confusion_matrix[i][j]
317 |             sum_column += confusion_matrix[j][i]
318 | 
319 |         g = confusion_matrix[i][i] * support  # defined before both branches so the precision term below never hits an unbound name
320 |         if support != 0:
321 |             f1score += g / (support + sum_column)
322 | 
323 |         if sum_column != 0:
324 |             precision += g / sum_column
325 | 
326 |     accuracy = sum_diagonal / N
327 |     precision /= N
328 |     f1score = 2 * f1score / N
329 | 
330 |     return accuracy, precision, f1score
331 | 
332 | 
333 | 
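f1score mirrors sklearn's weighted precision and F1 (and weighted recall equals accuracy, as the comment notes); a quick cross-check on synthetic labels, assuming scikit-learn is installed as train.py already uses it:

```python
from sklearn.metrics import f1_score as sk_f1, precision_score

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 0, 0]
acc, prec, f1 = f1score(y_true, y_pred, num_classes=3)
print('ours    -> acc: %.4f, precision: %.4f, f1: %.4f' % (acc, prec, f1))
print('sklearn -> precision: %.4f, f1: %.4f' % (
    precision_score(y_true, y_pred, average='weighted'),
    sk_f1(y_true, y_pred, average='weighted')))
# the precision and f1 values on the two lines should agree
```
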
334 | if __name__ == "__main__":
335 |     file_dir = "/home/lzw/datasets/air"
336 |     filename = "Train_label.csv"
337 |     label_file = os.path.join(file_dir, filename)
338 | 
339 |     nums = classes_num(label_file)
340 |     for i, cl in enumerate(weather_classes):
341 |         print('{}:{}'.format(cl, nums[i]))
342 |     print('nums:', np.sum(nums))
343 | 
344 |     image_info(os.path.join(file_dir, 'train'))
345 | 
346 |     # train_data, val_data = read_smote_data(
347 |     #     os.path.join(file_dir, 'train'), os.path.join(file_dir, filename), val_num=0)
348 |     # images, labels = train_data
349 |     # tic = time.time()
350 |     # pool = Pool(48)
351 |     # img_names = pool.map(load_image_label, list(zip(images, labels)))
352 |     # pool.close()
353 |     # pool.join()
354 |     # toc = time.time()
355 |     # imgs = []
356 |     # labs = []
357 |     # for im, name in img_names:
358 |     #     imgs.append(im)
359 |     #     labs.append(name)
360 |     # print('load image: ', toc-tic)
361 |     # print(imgs[0].shape)
362 | 
363 |     # smote_data(imgs, labs)
364 | 
--------------------------------------------------------------------------------