├── README.md
├── dataset.py
├── increase new 1 class each time.png
├── main.py
├── new class test accuracy.jpg
├── old class test accuracy.jpg
├── transforms.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Learning-without-Forgetting-using-Pytorch
This is a PyTorch implementation of Learning without Forgetting (LwF).

In my experiments, the baseline is the pretrained AlexNet from PyTorch, whose top-1 accuracy is 56.518% and top-5 accuracy is 79.070% (only top-1 accuracy is reported below). I use the CUB-200-2011 dataset. `dataset.py` is adapted from: https://github.com/weiaicunzai/Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks

The LwF paper is available at: https://ieeexplore.ieee.org/document/8107520

I suggest setting `lr` to 0.001 and `alpha` below 0.3.
If `alpha` is too large, the accuracy on the old classes drops very quickly; if it is too small, the accuracy on the new classes is lower.
I set the temperature `T` to 2 and did not try other values, so I don't know whether another value would perform better. For details on this hyperparameter, see the paper that introduced knowledge distillation (Hinton et al., "Distilling the Knowledge in a Neural Network"): https://arxiv.org/pdf/1503.02531.pdf
Some results:

number of new classes | alpha | accuracy of old 1000 classes | accuracy of new classes
---- | ----- | ------ | -------
50 | 0.1 | 55.120% | 47.948%
50 | 0.3 | 51.512% | 54.664%
100 | 0.1 | 54.208% | 39.525%
100 | 0.3 | 50.002% | 48.092%
150 | 0.1 | 53.284% | 36.573%
150 | 0.3 | 49.116% | 45.284%
200 | 0.1 | 53.370% | 34.413%
200 | 0.3 | 49.342% | 44.492%


If I add only one class, the accuracy on the new class is very high, with only a small drop in accuracy on the old 1000 classes.
26 | I tried this in CUB data set and the ave-acc of new 200 classes is 80.98% and the accuracy of old 1000 classes is 55.035% (-1.465%). 27 | 28 | I holp you can check your curves after implementing this code. You will find that, at the begining, the accuracy of new classes will be always 0 for some epoches but the accuracy of old classes will decrease quickly. I don't know why. If you have any interesting finding, please contact me. 29 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | 8 | class CUB_200_2011_Train(Dataset): 9 | 10 | def __init__(self, path, transform=None, target_transform=None): 11 | 12 | self.root = path 13 | self.transform = transform 14 | self.target_transform = target_transform 15 | self.images_path = {} 16 | with open(os.path.join(self.root, 'images.txt')) as f: 17 | for line in f: 18 | image_id, path = line.split() 19 | self.images_path[image_id] = path 20 | 21 | self.class_ids = {} 22 | with open(os.path.join(self.root, 'image_class_labels.txt')) as f: 23 | for line in f: 24 | image_id, class_id = line.split() 25 | self.class_ids[image_id] = class_id 26 | 27 | self.train_id = [] 28 | with open(os.path.join(self.root, 'train_test_split.txt')) as f: 29 | for line in f: 30 | image_id, is_train = line.split() 31 | if int(is_train) and self.class_ids[image_id]=='1':# only use the first class 32 | self.train_id.append(image_id) 33 | 34 | def __len__(self): 35 | return len(self.train_id) 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index: index of training dataset 41 | Returns: 42 | image and its corresponding label 43 | """ 44 | image_id = self.train_id[index] 45 | class_id = int(self._get_class_by_id(image_id)) - 1 46 | path = self._get_path_by_id(image_id) 47 | 
cv2.setNumThreads(0) 48 | cv2.ocl.setUseOpenCL(False) 49 | image = cv2.imread(os.path.join(self.root, 'images', path)) 50 | 51 | if self.transform: 52 | image = self.transform(image) 53 | 54 | if self.target_transform: 55 | class_id = self.target_transform(class_id) 56 | return image, class_id 57 | 58 | def _get_path_by_id(self, image_id): 59 | 60 | return self.images_path[image_id] 61 | 62 | def _get_class_by_id(self, image_id): 63 | 64 | return self.class_ids[image_id] 65 | 66 | 67 | class CUB_200_2011_Test(Dataset): 68 | 69 | def __init__(self, path, transform=None, target_transform=None): 70 | 71 | self.root = path 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | self.images_path = {} 75 | with open(os.path.join(self.root, 'images.txt')) as f: 76 | for line in f: 77 | image_id, path = line.split() 78 | self.images_path[image_id] = path 79 | 80 | self.class_ids = {} 81 | with open(os.path.join(self.root, 'image_class_labels.txt')) as f: 82 | for line in f: 83 | image_id, class_id = line.split() 84 | self.class_ids[image_id] = class_id 85 | 86 | self.train_id = [] 87 | with open(os.path.join(self.root, 'train_test_split.txt')) as f: 88 | for line in f: 89 | image_id, is_train = line.split() 90 | if not int(is_train) and self.class_ids[image_id]=='1':# only use the first class 91 | self.train_id.append(image_id) 92 | 93 | def __len__(self): 94 | return len(self.train_id) 95 | 96 | def __getitem__(self, index): 97 | """ 98 | Args: 99 | index: index of training dataset 100 | Returns: 101 | image and its corresponding label 102 | """ 103 | image_id = self.train_id[index] 104 | class_id = int(self._get_class_by_id(image_id)) - 1 105 | cv2.setNumThreads(0) 106 | cv2.ocl.setUseOpenCL(False) 107 | path = self._get_path_by_id(image_id) 108 | image = cv2.imread(os.path.join(self.root, 'images', path)) 109 | 110 | if self.transform: 111 | image = self.transform(image) 112 | 113 | if self.target_transform: 114 | class_id = 
self.target_transform(class_id) 115 | 116 | return image, class_id 117 | 118 | def _get_path_by_id(self, image_id): 119 | 120 | return self.images_path[image_id] 121 | 122 | def _get_class_by_id(self, image_id): 123 | 124 | return self.class_ids[image_id] 125 | 126 | 127 | def compute_mean_and_std(dataset): 128 | """Compute dataset mean and std, and normalize it 129 | 130 | Args: 131 | dataset: instance of CUB_200_2011_Train, CUB_200_2011_Test 132 | 133 | Returns: 134 | return: mean and std of this dataset 135 | """ 136 | 137 | mean_r = 0 138 | mean_g = 0 139 | mean_b = 0 140 | 141 | for img, _ in dataset: 142 | mean_b += np.mean(img[:, :, 0]) 143 | mean_g += np.mean(img[:, :, 1]) 144 | mean_r += np.mean(img[:, :, 2]) 145 | 146 | mean_b /= len(dataset) 147 | mean_g /= len(dataset) 148 | mean_r /= len(dataset) 149 | 150 | diff_r = 0 151 | diff_g = 0 152 | diff_b = 0 153 | 154 | N = 0 155 | 156 | for img, _ in dataset: 157 | 158 | diff_b += np.sum(np.power(img[:, :, 0] - mean_b, 2)) 159 | diff_g += np.sum(np.power(img[:, :, 1] - mean_g, 2)) 160 | diff_r += np.sum(np.power(img[:, :, 2] - mean_r, 2)) 161 | 162 | N += np.prod(img[:, :, 0].shape) 163 | 164 | std_b = np.sqrt(diff_b / N) 165 | std_g = np.sqrt(diff_g / N) 166 | std_r = np.sqrt(diff_r / N) 167 | 168 | mean = (mean_b.item() / 255.0, mean_g.item() / 255.0, mean_r.item() / 255.0) 169 | std = (std_b.item() / 255.0, std_g.item() / 255.0, std_r.item() / 255.0) 170 | return mean, std 171 | -------------------------------------------------------------------------------- /increase new 1 class each time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasLiang/Learning-without-Forgetting-using-Pytorch/fc9bfdb35075da94aec060cb89fd7d480cdbd978/increase new 1 class each time.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 
'''Train CIFAR10 with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import scipy.io as scio 8 | from scipy.io import loadmat 9 | import torchvision 10 | import os 11 | import argparse 12 | from utils import progress_bar 13 | import numpy as np 14 | import random 15 | 16 | import matplotlib.pyplot as plt 17 | import pdb 18 | import transforms 19 | from dataset import CUB_200_2011_Train, CUB_200_2011_Test 20 | import torchvision.transforms as tfs 21 | import torchvision.datasets as datasets 22 | 23 | def kaiming_normal_init(m): 24 | if isinstance(m, nn.Conv2d): 25 | nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 26 | elif isinstance(m, nn.Linear): 27 | nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid') 28 | 29 | np.random.seed(1) 30 | torch.manual_seed(1) 31 | torch.cuda.manual_seed_all(1) 32 | random.seed(1) 33 | torch.backends.cudnn.benchmark = False 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.enabled = True 36 | parser = argparse.ArgumentParser(description='Learning Without Forgetting') 37 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 38 | parser.add_argument('--resume', '-r', action='store_true', 39 | help='resume from checkpoint') 40 | args = parser.parse_args() 41 | 42 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | best_acc = 0 # best test accuracy 44 | start_epoch = 1 # start from epoch 0 or last checkpoint epoch 45 | 46 | # New Data 47 | print('==> Preparing data..') 48 | train_transforms = transforms.Compose([ 49 | transforms.ToCVImage(), 50 | transforms.RandomResizedCrop(224), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4), 53 | transforms.ToTensor(), 54 | transforms.Normalize( 55 | [0.48560741861744905, 0.49941626449353244, 0.43237713785804116], 56 | [0.2321024260764962, 
0.22770540015765814, 0.2665100547329813]) 57 | ]) 58 | 59 | test_transforms = transforms.Compose([ 60 | transforms.ToCVImage(), 61 | transforms.CenterCrop(224), 62 | transforms.ToTensor(), 63 | transforms.Normalize( 64 | [0.4862169586881995, 0.4998156522834164, 0.4311430419332438], 65 | [0.23264268069040475, 0.22781080253662814, 0.26667253517177186]) 66 | ]) 67 | 68 | trainset = CUB_200_2011_Train( 69 | '~/CUB_200_2011', 70 | transform=train_transforms, 71 | ) 72 | trainloader = torch.utils.data.DataLoader( 73 | trainset, batch_size=384, shuffle=True, num_workers=2) 74 | 75 | testset = CUB_200_2011_Test( 76 | '~/CUB_200_2011', 77 | transform=test_transforms, 78 | ) 79 | testloader = torch.utils.data.DataLoader( 80 | testset, batch_size=100, shuffle=False, num_workers=2) 81 | 82 | # Used to test the performance on old dataset 83 | normalize = tfs.Normalize(mean=[0.485, 0.456, 0.406], 84 | std=[0.229, 0.224, 0.225]) 85 | valloader = torch.utils.data.DataLoader( 86 | datasets.ImageFolder(root='~/Imagenet2012/ILSVRC2012_img_val', transform=tfs.Compose([ 87 | tfs.Resize(256), 88 | tfs.CenterCrop(224), 89 | tfs.ToTensor(), 90 | normalize, 91 | ])), 92 | batch_size=256, shuffle=False, 93 | num_workers=4, pin_memory=True) 94 | # Model 95 | print('==> Building model..') 96 | net = torchvision.models.alexnet(pretrained=False) 97 | net11 = torchvision.models.alexnet(pretrained=False) 98 | 99 | # Use the pretrained model from Pytorch 100 | oor = torch.load('~/pytorch_pretrained_modle/alexnet_pretrained.pth') 101 | # 'Net' is the new model to learn new classes 102 | net.load_state_dict(oor) 103 | # 'Net11' is the old model to learn old classes 104 | net11.load_state_dict(oor) 105 | # Number of new class, if this number is changed, you should change 'dataset.py' 106 | # You can chech the 31th line and the 90th line of dataset.py 107 | # In my experiment, I test each 200 classes on by on and retain this value as '1' 108 | # If you want to add more classes, you should change the 
code I said before 109 | num_new_class = 1 110 | # Old number of input/output channel of the last FC layer in old model 111 | in_features = net.classifier[6].in_features 112 | out_features = net.classifier[6].out_features 113 | # Old weight/bias of the last FC layer 114 | weight = net.classifier[6].weight.data 115 | bias = net.classifier[6].bias.data 116 | # New number of output channel of the last FC layer in new model 117 | new_out_features = num_new_class+out_features 118 | # Creat a new FC layer and initial it's weight/bias 119 | new_fc = nn.Linear(in_features, new_out_features) 120 | kaiming_normal_init(new_fc.weight) 121 | new_fc.weight.data[:out_features] = weight 122 | new_fc.bias.data[:out_features] = bias 123 | # Replace the old FC layer 124 | net.classifier[6] = new_fc 125 | # CUDA 126 | net = net.to(device) 127 | net11 = net11.to(device) 128 | if device == 'cuda': 129 | net = torch.nn.DataParallel(net) 130 | net11 = torch.nn.DataParallel(net11) 131 | cudnn.benchmark = True 132 | 133 | # Loss function 134 | criterion = nn.CrossEntropyLoss() 135 | # Temperature of the new softmax proposed in 'Distillation of Knowledge' 136 | T=2 137 | # Used to balance the new class loss1 and the old class loss2 138 | # Loss1 is the cross entropy between output of the new task and label 139 | # Loss2 is the cross entropy between output of the old task and output of the old model 140 | # It should be noticed that before calculating loss2, the output of each model should- 141 | # -be handled by the new softmax 142 | alpha = 0.01 143 | def train(epoch): 144 | print('\nEpoch: %d' % epoch) 145 | net.eval() 146 | train_loss = 0 147 | correct = 0 148 | total = 0 149 | for batch_idx, (inputs, targets) in enumerate(trainloader): 150 | inputs, targets = inputs.to(device), targets.to(device) 151 | targets += out_features 152 | optimizer.zero_grad() 153 | outputs = net(inputs) 154 | soft_target = net11(inputs) 155 | # Cross entropy between output of the new task and label 156 | loss1 
= criterion(outputs,targets) 157 | # Using the new softmax to handle outputs 158 | outputs_S = F.softmax(outputs[:,:out_features]/T,dim=1) 159 | outputs_T = F.softmax(soft_target[:,:out_features]/T,dim=1) 160 | # Cross entropy between output of the old task and output of the old model 161 | loss2 = outputs_T.mul(-1*torch.log(outputs_S)) 162 | loss2 = loss2.sum(1) 163 | loss2 = loss2.mean()*T*T 164 | loss = loss1*alpha+loss2*(1-alpha) 165 | loss.backward(retain_graph=True) 166 | optimizer.step() 167 | 168 | train_loss += loss.item() 169 | _, predicted = outputs.max(1) 170 | total += targets.size(0) 171 | correct += predicted.eq(targets).sum().item() 172 | 173 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 174 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 175 | 176 | return train_loss/(batch_idx+1) 177 | def test(epoch): 178 | global best_acc 179 | net.eval() 180 | test_loss = 0 181 | correct = 0 182 | total = 0 183 | with torch.no_grad(): 184 | for batch_idx, (inputs, targets) in enumerate(testloader): 185 | inputs, targets = inputs.to(device), targets.to(device) 186 | targets += out_features 187 | outputs = net(inputs) 188 | soft_target = net11(inputs) 189 | loss1 = criterion(outputs,targets) 190 | loss = loss1 191 | outputs_S = F.softmax(outputs[:,:out_features]/T,dim=1) 192 | outputs_T = F.softmax(soft_target[:,:out_features]/T,dim=1) 193 | loss2 = outputs_T.mul(-1*torch.log(outputs_S)) 194 | loss2 = loss2.sum(1) 195 | loss2 = loss2.mean()*T*T 196 | loss = loss1*alpha+loss2*(1-alpha) 197 | 198 | test_loss += loss.item() 199 | _, predicted = outputs.max(1) 200 | total += targets.size(0) 201 | correct += predicted.eq(targets).sum().item() 202 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 203 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 204 | # Save checkpoint. 
205 | acc = 100.*correct/total 206 | if acc > best_acc: 207 | print('Saving..') 208 | state = { 209 | 'net': net.state_dict(), 210 | 'acc': acc, 211 | 'epoch': epoch, 212 | } 213 | if not os.path.isdir('checkpoint'): 214 | os.mkdir('checkpoint') 215 | torch.save(state, './checkpoint/ckpt.pth') 216 | best_acc = acc 217 | return acc 218 | def val(epoch): 219 | net.eval() 220 | correct = 0 221 | total = 0 222 | with torch.no_grad(): 223 | for batch_idx, (inputs, targets) in enumerate(valloader): 224 | inputs, targets = inputs.to(device), targets.to(device) 225 | outputs = net(inputs) 226 | _, predicted_old = outputs.max(1) 227 | total += targets.size(0) 228 | correct += predicted_old.eq(targets).sum().item() 229 | progress_bar(batch_idx, len(valloader), 'Acc: %.3f%% (%d/%d)' 230 | % (100.*correct/total, correct, total)) 231 | return 100.*correct/total 232 | 233 | epochs = [] 234 | test_new_accs = [] 235 | test_old_accs = [] 236 | train_losses = [] 237 | layer_num = [0,3,7,10,14,17,20,24,27,30,34,37,40] 238 | ct = 0 239 | 240 | 241 | # Ensure that the old model don't train 242 | for param in net11.module.parameters(): 243 | param.requires_grad = False 244 | 245 | ## Warm up step 246 | ## In my experiment, I didn't use this step, and the paper said that this step is not necessary 247 | # 248 | # optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, 249 | # momentum=0.9, weight_decay=5e-4) 250 | # 251 | ## Frozen the model 252 | # for param in net11.module.parameters(): 253 | # param.requires_grad = False 254 | ## Thaw the last FC layer 255 | # for param in net.module.classifier[6].parameters(): 256 | # param.requires_grad = True 257 | # 258 | # for epoch in range(start_epoch, start_epoch+20): 259 | # train_loss = train(epoch) 260 | # # Make sure there are only weights/bias corresponding to the new task being trained 261 | # net.module.classifier[6].weight.data[:out_features] = net11.module.classifier[6].weight.data 262 | # 
net.module.classifier[6].bias.data[:out_features] = net11.module.classifier[6].bias.data 263 | # acc,test_loss = test(epoch) 264 | ## Thaw the model 265 | # for param in net.module.parameters(): 266 | # param.requires_grad = True 267 | 268 | ## train step 269 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr, 270 | momentum=0.9, weight_decay=5e-4) 271 | 272 | for epoch in range(start_epoch, start_epoch+200): 273 | train_loss = train(epoch) 274 | acc_new = test(epoch) 275 | acc_old = val(epoch) 276 | test_new_accs.append(acc_new) 277 | test_old_accs.append(acc_old) 278 | train_losses.append(train_loss) 279 | epochs.append(epoch) 280 | ## Save the final model 281 | # torch.save(net.state_dict(), 'temp.pkl') 282 | 283 | # Picture of the new class test accuracy changing with training 284 | plt.figure() 285 | plt.plot(epochs,test_new_accs) 286 | plt.xlabel("epochs") 287 | plt.ylabel("accuracy") 288 | plt.title('new class test accuracy') 289 | plt.savefig('./acc_new_class.jpg') 290 | 291 | # Picture of the old class test accuracy changing with training 292 | plt.figure() 293 | plt.plot(epochs,test_old_accs) 294 | plt.xlabel("epochs") 295 | plt.ylabel("accuracy") 296 | plt.title('old class test accuracy') 297 | plt.savefig('./acc_old_class.jpg') 298 | 299 | # Picture of the training loss changing with training 300 | plt.figure() 301 | plt.plot(epochs,train_losses) 302 | plt.xlabel('epochs') 303 | plt.ylabel('loss') 304 | plt.title('train loss') 305 | plt.savefig('./train_loss.jpg') 306 | 307 | -------------------------------------------------------------------------------- /new class test accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasLiang/Learning-without-Forgetting-using-Pytorch/fc9bfdb35075da94aec060cb89fd7d480cdbd978/new class test accuracy.jpg -------------------------------------------------------------------------------- /old class test 
accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasLiang/Learning-without-Forgetting-using-Pytorch/fc9bfdb35075da94aec060cb89fd7d480cdbd978/old class test accuracy.jpg -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import math 4 | import numbers 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | import torch 10 | 11 | class Compose: 12 | """Composes several transforms together. 13 | 14 | Args: 15 | transforms(list of 'Transform' object): list of transforms to compose 16 | 17 | """ 18 | 19 | def __init__(self, transforms): 20 | self.transforms = transforms 21 | 22 | def __call__(self, img): 23 | 24 | for trans in self.transforms: 25 | img = trans(img) 26 | 27 | return img 28 | 29 | def __repr__(self): 30 | format_string = self.__class__.__name__ + '(' 31 | for t in self.transforms: 32 | format_string += '\n' 33 | format_string += ' {0}'.format(t) 34 | format_string += '\n)' 35 | return format_string 36 | 37 | 38 | class ToCVImage: 39 | """Convert an Opencv image to a 3 channel uint8 image 40 | """ 41 | 42 | def __call__(self, image): 43 | """ 44 | Args: 45 | image (numpy array): Image to be converted to 32-bit floating point 46 | 47 | Returns: 48 | image (numpy array): Converted Image 49 | """ 50 | if len(image.shape) == 2: 51 | image = cv2.cvtColor(iamge, cv2.COLOR_GRAY2BGR) 52 | 53 | image = image.astype('uint8') 54 | 55 | return image 56 | 57 | 58 | class RandomResizedCrop: 59 | """Randomly crop a rectangle region whose aspect ratio is randomly sampled 60 | in [3/4, 4/3] and area randomly sampled in [8%, 100%], then resize the cropped 61 | region into a 224-by-224 square image. 
62 | 63 | Args: 64 | size: expected output size of each edge 65 | scale: range of size of the origin size cropped 66 | ratio: range of aspect ratio of the origin aspect ratio cropped (w / h) 67 | interpolation: Default: cv2.INTER_LINEAR: 68 | """ 69 | 70 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation='linear'): 71 | 72 | self.methods={ 73 | "area":cv2.INTER_AREA, 74 | "nearest":cv2.INTER_NEAREST, 75 | "linear" : cv2.INTER_LINEAR, 76 | "cubic" : cv2.INTER_CUBIC, 77 | "lanczos4" : cv2.INTER_LANCZOS4 78 | } 79 | 80 | self.size = (size, size) 81 | self.interpolation = self.methods[interpolation] 82 | self.scale = scale 83 | self.ratio = ratio 84 | 85 | def __call__(self, img): 86 | h, w, _ = img.shape 87 | 88 | area = w * h 89 | 90 | for attempt in range(10): 91 | target_area = random.uniform(*self.scale) * area 92 | target_ratio = random.uniform(*self.ratio) 93 | 94 | output_h = int(round(math.sqrt(target_area * target_ratio))) 95 | output_w = int(round(math.sqrt(target_area / target_ratio))) 96 | 97 | if random.random() < 0.5: 98 | output_w, output_h = output_h, output_w 99 | 100 | if output_w <= w and output_h <= h: 101 | topleft_x = random.randint(0, w - output_w) 102 | topleft_y = random.randint(0, h - output_h) 103 | break 104 | 105 | if output_w > w or output_h > h: 106 | output_w = min(w, h) 107 | output_h = output_w 108 | topleft_x = random.randint(0, w - output_w) 109 | topleft_y = random.randint(0, h - output_w) 110 | 111 | cropped = img[topleft_y : topleft_y + output_h, topleft_x : topleft_x + output_w] 112 | 113 | resized = cv2.resize(cropped, self.size, interpolation=self.interpolation) 114 | 115 | return resized 116 | 117 | def __repr__(self): 118 | for name, inter in self.methods.items(): 119 | if inter == self.interpolation: 120 | inter_name = name 121 | 122 | interpolate_str = inter_name 123 | format_str = self.__class__.__name__ + '(size={0}'.format(self.size) 124 | format_str += ', 
scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 125 | format_str += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 126 | format_str += ', interpolation={0})'.format(interpolate_str) 127 | 128 | return format_str 129 | 130 | 131 | class RandomHorizontalFlip: 132 | """Horizontally flip the given opencv image with given probability p. 133 | 134 | Args: 135 | p: probability of the image being flipped 136 | """ 137 | def __init__(self, p=0.5): 138 | self.p = p 139 | 140 | def __call__(self, img): 141 | """ 142 | Args: 143 | the image to be flipped 144 | Returns: 145 | flipped image 146 | """ 147 | if random.random() < self.p: 148 | img = cv2.flip(img, 1) 149 | 150 | return img 151 | 152 | class ColorJitter: 153 | 154 | """Randomly change the brightness, contrast and saturation of an image 155 | 156 | Args: 157 | brightness: (float or tuple of float(min, max)): how much to jitter 158 | brightness, brightness_factor is choosen uniformly from[max(0, 1-brightness), 159 | 1 + brightness] or the given [min, max], Should be non negative numbe 160 | contrast: same as brightness 161 | saturation: same as birghtness 162 | hue: same as brightness 163 | """ 164 | 165 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 166 | self.brightness = self._check_input(brightness) 167 | self.contrast = self._check_input(contrast) 168 | self.saturation = self._check_input(saturation) 169 | self.hue = self._check_input(hue) 170 | 171 | def _check_input(self, value): 172 | 173 | if isinstance(value, numbers.Number): 174 | assert value >= 0, 'value should be non negative' 175 | value = [max(0, 1 - value), 1 + value] 176 | 177 | elif isinstance(value, (list, tuple)): 178 | assert len(value) == 2, 'brightness should be a tuple/list with 2 elements' 179 | assert 0 <= value[0] <= value[1], 'max should be larger than or equal to min,\ 180 | and both larger than 0' 181 | 182 | else: 183 | raise TypeError('need to pass int, float, list or tuple, instead 
got{}'.format(type(value).__name__)) 184 | 185 | return value 186 | 187 | def __call__(self, img): 188 | """ 189 | Args: 190 | img to be jittered 191 | Returns: 192 | jittered img 193 | """ 194 | 195 | img_dtype = img.dtype 196 | h_factor = random.uniform(*self.hue) 197 | b_factor = random.uniform(*self.brightness) 198 | s_factor = random.uniform(*self.saturation) 199 | c_factor = random.uniform(*self.contrast) 200 | 201 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 202 | img = img.astype('float32') 203 | 204 | #h 205 | img[:, :, 0] *= h_factor 206 | img[:, :, 0] = np.clip(img[:, :, 0], 0, 179) 207 | 208 | #s 209 | img[:, :, 1] *= s_factor 210 | img[:, :, 1] = np.clip(img[:, :, 1], 0, 255) 211 | 212 | #v 213 | img[:, :, 2] *= b_factor 214 | img[:, :, 2] = np.clip(img[:, :, 2], 0, 255) 215 | 216 | img = img.astype(img_dtype) 217 | img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR) 218 | 219 | #c 220 | img = img * c_factor 221 | img = img.astype(img_dtype) 222 | img = np.clip(img, 0, 255) 223 | 224 | return img 225 | 226 | class ToTensor: 227 | """convert an opencv image (h, w, c) ndarray range from 0 to 255 to a pytorch 228 | float tensor (c, h, w) ranged from 0 to 1 229 | """ 230 | 231 | def __call__(self, img): 232 | """ 233 | Args: 234 | a numpy array (h, w, c) range from [0, 255] 235 | 236 | Returns: 237 | a pytorch tensor 238 | """ 239 | #convert format H W C to C H W 240 | img = img.transpose(2, 0, 1) 241 | img = torch.from_numpy(img) 242 | img = img.float() / 255.0 243 | 244 | return img 245 | 246 | class Normalize: 247 | """Normalize a torch tensor (H, W, BGR order) with mean and standard deviation 248 | 249 | for each channel in torch tensor: 250 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 251 | 252 | Args: 253 | mean: sequence of means for each channel 254 | std: sequence of stds for each channel 255 | """ 256 | 257 | def __init__(self, mean, std, inplace=False): 258 | self.mean = mean 259 | self.std = std 260 | self.inplace = inplace 261 
| 262 | def __call__(self, img): 263 | """ 264 | Args: 265 | (H W C) format numpy array range from [0, 255] 266 | Returns: 267 | (H W C) format numpy array in float32 range from [0, 1] 268 | """ 269 | assert torch.is_tensor(img) and img.ndimension() == 3, 'not an image tensor' 270 | 271 | if not self.inplace: 272 | img = img.clone() 273 | 274 | mean = torch.tensor(self.mean, dtype=torch.float32) 275 | std = torch.tensor(self.std, dtype=torch.float32) 276 | img.sub_(mean[:, None, None]).div_(std[:, None, None]) 277 | 278 | return img 279 | 280 | class CenterCrop: 281 | """resize each image’s shorter edge to r pixels while keeping its aspect ratio. 282 | Next, we crop out the cropped region in the center 283 | Args: 284 | resized: resize image' shorter edge to resized pixels while keeping the aspect ratio 285 | cropped: output image size(h, w), if cropped is an int, then output cropped * cropped size 286 | image 287 | """ 288 | 289 | def __init__(self, cropped, resized=256, interpolation='linear'): 290 | 291 | methods = { 292 | "area":cv2.INTER_AREA, 293 | "nearest":cv2.INTER_NEAREST, 294 | "linear" : cv2.INTER_LINEAR, 295 | "cubic" : cv2.INTER_CUBIC, 296 | "lanczos4" : cv2.INTER_LANCZOS4 297 | } 298 | self.interpolation = methods[interpolation] 299 | 300 | self.resized = resized 301 | 302 | if isinstance(cropped, numbers.Number): 303 | cropped = (cropped, cropped) 304 | 305 | self.cropped = cropped 306 | 307 | def __call__(self, img): 308 | 309 | shorter = min(*img.shape[:2]) 310 | 311 | scaler = float(self.resized) / shorter 312 | 313 | img = cv2.resize(img, (0, 0), fx=scaler, fy=scaler, interpolation=self.interpolation) 314 | 315 | h, w, _ = img.shape 316 | 317 | topleft_x = int((w - self.cropped[1]) / 2) 318 | topleft_y = int((h - self.cropped[0]) / 2) 319 | 320 | center_cropped = img[topleft_y : topleft_y + self.cropped[0], 321 | topleft_x : topleft_x + self.cropped[1]] 322 | 323 | return center_cropped 324 | 325 | class RandomErasing: 326 | """Random erasing 
the an rectangle region in Image. 327 | Class that performs Random Erasing in Random Erasing Data Augmentation by Zhong et al. 328 | 329 | Args: 330 | sl: min erasing area region 331 | sh: max erasing area region 332 | r1: min aspect ratio range of earsing region 333 | p: probability of performing random erasing 334 | """ 335 | 336 | def __init__(self, p=0.5, sl=0.02, sh=0.4, r1=0.3): 337 | 338 | self.p = p 339 | self.s = (sl, sh) 340 | self.r = (r1, 1/r1) 341 | 342 | 343 | def __call__(self, img): 344 | """ 345 | perform random erasing 346 | Args: 347 | img: opencv numpy array in form of [w, h, c] range 348 | from [0, 255] 349 | 350 | Returns: 351 | erased img 352 | """ 353 | 354 | assert len(img.shape) == 3, 'image should be a 3 dimension numpy array' 355 | 356 | if random.random() > self.p: 357 | return img 358 | 359 | else: 360 | while True: 361 | Se = random.uniform(*self.s) * img.shape[0] * img.shape[1] 362 | re = random.uniform(*self.r) 363 | 364 | He = int(round(math.sqrt(Se * re))) 365 | We = int(round(math.sqrt(Se / re))) 366 | 367 | xe = random.randint(0, img.shape[1]) 368 | ye = random.randint(0, img.shape[0]) 369 | 370 | if xe + We <= img.shape[1] and ye + He <= img.shape[0]: 371 | img[ye : ye + He, xe : xe + We, :] = np.random.randint(low=0, high=255, size=(He, We, img.shape[2])) 372 | 373 | return img 374 | 375 | class CutOut: 376 | """Randomly mask out one or more patches from an image. An image 377 | is a opencv format image (h,w,c numpy array) 378 | 379 | Args: 380 | n_holes (int): Number of patches to cut out of each image. 381 | length (int): The length (in pixels) of each square patch. 
382 | """ 383 | 384 | def __init__(self, length, n_holes=1): 385 | self.n_holes = n_holes 386 | self.length = length 387 | 388 | def __call__(self, img): 389 | 390 | while self.n_holes: 391 | 392 | y = random.randint(0, img.shape[0] - 1) 393 | x = random.randint(0, img.shape[1] - 1) 394 | 395 | tl_x = int(max(0, x - self.length / 2)) 396 | tl_y = int(max(0, y - self.length / 2)) 397 | 398 | img[tl_y : tl_y + self.length, tl_x : tl_x + self.length, :] = 0 399 | 400 | self.n_holes -= 1 401 | 402 | return img 403 | 404 | 405 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | 14 | 15 | def get_mean_and_std(dataset): 16 | '''Compute the mean and std value of dataset.''' 17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 18 | mean = torch.zeros(3) 19 | std = torch.zeros(3) 20 | print('==> Computing mean and std..') 21 | for inputs, targets in dataloader: 22 | for i in range(3): 23 | mean[i] += inputs[:,i,:,:].mean() 24 | std[i] += inputs[:,i,:,:].std() 25 | mean.div_(len(dataset)) 26 | std.div_(len(dataset)) 27 | return mean, std 28 | 29 | def init_params(net): 30 | '''Init layer parameters.''' 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.kaiming_normal(m.weight, mode='fan_out') 34 | if m.bias: 35 | init.constant(m.bias, 0) 36 | elif isinstance(m, nn.BatchNorm2d): 37 | init.constant(m.weight, 1) 38 | init.constant(m.bias, 0) 39 | elif isinstance(m, nn.Linear): 40 | init.normal(m.weight, std=1e-3) 41 | if m.bias: 42 | 
init.constant(m.bias, 0) 43 | 44 | 45 | #_, term_width = os.popen('stty size', 'r').read().split() 46 | #term_width = int(term_width) 47 | 48 | TOTAL_BAR_LENGTH = 65. 49 | last_time = time.time() 50 | begin_time = last_time 51 | def progress_bar(current, total, msg=None): 52 | global last_time, begin_time 53 | if current == 0: 54 | begin_time = time.time() # Reset for new bar. 55 | 56 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 57 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 58 | 59 | sys.stdout.write(' [') 60 | for i in range(cur_len): 61 | sys.stdout.write('=') 62 | sys.stdout.write('>') 63 | for i in range(rest_len): 64 | sys.stdout.write('.') 65 | sys.stdout.write(']') 66 | 67 | cur_time = time.time() 68 | step_time = cur_time - last_time 69 | last_time = cur_time 70 | tot_time = cur_time - begin_time 71 | 72 | L = [] 73 | L.append(' Step: %s' % format_time(step_time)) 74 | L.append(' | Tot: %s' % format_time(tot_time)) 75 | if msg: 76 | L.append(' | ' + msg) 77 | 78 | msg = ''.join(L) 79 | sys.stdout.write(msg) 80 | #for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 81 | # sys.stdout.write(' ') 82 | 83 | # Go back to the center of the bar. 
84 | #for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 85 | # sys.stdout.write('\b') 86 | sys.stdout.write(' %d/%d ' % (current+1, total)) 87 | 88 | if current < total-1: 89 | sys.stdout.write('\r') 90 | else: 91 | sys.stdout.write('\n') 92 | sys.stdout.flush() 93 | 94 | def format_time(seconds): 95 | days = int(seconds / 3600/24) 96 | seconds = seconds - days*3600*24 97 | hours = int(seconds / 3600) 98 | seconds = seconds - hours*3600 99 | minutes = int(seconds / 60) 100 | seconds = seconds - minutes*60 101 | secondsf = int(seconds) 102 | seconds = seconds - secondsf 103 | millis = int(seconds*1000) 104 | 105 | f = '' 106 | i = 1 107 | if days > 0: 108 | f += str(days) + 'D' 109 | i += 1 110 | if hours > 0 and i <= 2: 111 | f += str(hours) + 'h' 112 | i += 1 113 | if minutes > 0 and i <= 2: 114 | f += str(minutes) + 'm' 115 | i += 1 116 | if secondsf > 0 and i <= 2: 117 | f += str(secondsf) + 's' 118 | i += 1 119 | if millis > 0 and i <= 2: 120 | f += str(millis) + 'ms' 121 | i += 1 122 | if f == '': 123 | f = '0ms' 124 | return f 125 | --------------------------------------------------------------------------------