├── .gitignore ├── AircraftsDataset.py ├── BCNN.py ├── CNN.py ├── CUBDataset.py ├── CarsDataset.py ├── LICENSE ├── MITIndoorDataset.py ├── README.md ├── compact_bilinear_pooling └── __init__.py ├── config.py ├── feature_extractor.py ├── find_best_model.py ├── finetune.py ├── iNatDataset.py ├── inv_images ├── 001.Black_footed_Albatross.png ├── 002.Laysan_Albatross.png ├── 005.Crested_Auklet.png ├── 006.Least_Auklet.png ├── 007.Parakeet_Auklet.png ├── 009.Brewer_Blackbird.png ├── 010.Red_winged_Blackbird.png ├── 011.Rusty_Blackbird.png ├── 012.Yellow_headed_Blackbird.png ├── 013.Bobolink.png ├── 014.Indigo_Bunting.png ├── 015.Lazuli_Bunting.png ├── 016.Painted_Bunting.png ├── 017.Cardinal.png ├── 018.Spotted_Catbird.png ├── 019.Gray_Catbird.png ├── 020.Yellow_breasted_Chat.png ├── 024.Red_faced_Cormorant.png ├── 025.Pelagic_Cormorant.png ├── 026.Bronzed_Cowbird.png ├── 027.Shiny_Cowbird.png ├── 028.Brown_Creeper.png ├── 029.American_Crow.png └── 030.Fish_Crow.png ├── inversion.py ├── matrixSquareRoot.py ├── plot_curve.py ├── requirements.txt ├── resize_inat.py ├── test.py ├── train.py └── train_cnn.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | */__pycache__ 3 | .ipynb_checkpoints/ 4 | -------------------------------------------------------------------------------- /AircraftsDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | import os 5 | from torchvision.datasets import folder as dataset_parser 6 | import json 7 | 8 | def make_dataset(dataset_root, imageRoot, split, level='variant', 9 | subset=False): 10 | if level == 'variant': 11 | class_meta = 'variants' 12 | elif level == 'manufacturer': 13 | class_meta = 'manufacturers' 14 | elif level == 'family': 15 | class_meta = 'families' 16 | 17 | if split == 'train': 18 | split_suffix = '_train' 19 | elif split == 'val': 20 | split_suffix = '_val' 21 | elif split == 'train_val': 22 | split_suffix = '_trainval' 23 | elif split == 'test': 24 | split_suffix = '_test' 25 | else: 26 | raise ValueError('Unknown split: %s' % split) 27 | 28 | with open(os.path.join(dataset_root, 'data', 29 | 'images_' + level + split_suffix + '.txt')) as f: 30 | imgAnnoList = f.readlines() 31 | with open(os.path.join(dataset_root, 'data', class_meta + '.txt')) as f: 32 | classes = f.readlines() 33 | 34 | class_dict = {x.rstrip():idx for idx, x in enumerate(classes)} 35 | 36 | count = [0]*len(classes) 37 | img = [] 38 | for x in imgAnnoList: 39 | imgName = os.path.join(dataset_root, 'data', 'images', x.split()[0]+'.jpg') 40 | anno = class_dict[x.rstrip().split(' ', 1)[1]] 41 | if subset: 42 | if count[anno] >= 5: 43 | continue 44 | count[anno] += 1 45 | img.append((imgName, anno)) 46 | 47 | return img, classes 48 | 49 | class AircraftsDataset(data.Dataset): 50 | def __init__(self, dataset_root, split, subset=False, level='variant', 51 | transform=None, target_transform=None, 52 | loader=dataset_parser.default_loader): 53 | self.loader = loader 54 | self.dataset_root = dataset_root 55 | self.imageRoot = os.path.join(dataset_root, 'data', 'images') 56 | 57 | self.imgs, self.classes = make_dataset(self.dataset_root, 58 | self.imageRoot, split, level, subset) 59 | 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | self.dataset_name = 'aircrafts' 64 | self.load_images = True 65 | self.feat_root = None 66 | 67 | def __getitem__(self, index): 68 | 69 | if
self.load_images: 70 | path, target = self.imgs[index] 71 | img = self.loader(path) 72 | if self.transform is not None: 73 | img = [x(img) for x in self.transform] 74 | 75 | if self.target_transform is not None: 76 | target = self.target_transform(target) 77 | 78 | else: 79 | path, target = self.imgs[index] 80 | path = os.path.join(self.feat_root, path[len(self.imageRoot)+1:-3]) 81 | path = path + 'pt' 82 | img = torch.load(path) 83 | if self.target_transform is not None: 84 | target = self.target_transform(target) 85 | 86 | return (*img, target, path) 87 | 88 | def get_num_classes(self): 89 | return len(self.classes) 90 | 91 | def __len__(self): 92 | return len(self.imgs) 93 | 94 | def set_to_load_features(self, feat_root): 95 | self.load_images = False 96 | self.feat_root = feat_root 97 | 98 | def set_to_load_images(self): 99 | self.load_images = True 100 | self.feat_root = None 101 | -------------------------------------------------------------------------------- /BCNN.py: -------------------------------------------------------------------------------- 1 | import feature_extractor as fe 2 | import torchvision 3 | import torch 4 | import torch.nn as nn 5 | import functools 6 | import operator 7 | from compact_bilinear_pooling import CountSketch 8 | from torch.autograd import Function 9 | from matrixSquareRoot import MatrixSquareRoot 10 | import torch.nn.functional as F 11 | 12 | matrix_sqrt = MatrixSquareRoot.apply 13 | 14 | def create_backbone(model_name, finetune_model=True, use_pretrained=True): 15 | model_ft = None 16 | input_size = 0 17 | 18 | if model_name == 'vgg': 19 | """ VGG 20 | """ 21 | model_ft = fe.VGG() 22 | set_parameter_requires_grad(model_ft, finetune_model) 23 | 24 | output_dim = 512 25 | 26 | elif model_name == "resnet": 27 | """ Resnet101 28 | """ 29 | model_ft = fe.ResNet() 30 | set_parameter_requires_grad(model_ft, finetune_model) 31 | # num_ftrs = model_ft.fc.in_features 32 | # model_ft.fc = nn.Linear(num_ftrs, num_classes) 33 | # input_size = 224 34 | 35 | output_dim = 2048 36 | 37 | elif model_name == "densenet": 38 | """ Densenet 39 | """ 40 | model_ft = fe.DenseNet() 41 | set_parameter_requires_grad(model_ft, finetune_model) 42 | # num_ftrs = model_ft.classifier.in_features 43 | # model_ft.classifier = nn.Linear(num_ftrs, num_classes) 44 | # input_size = 224 45 | 46 | output_dim = 1920 47 | 48 | elif model_name == "inception": 49 | """ Inception v3 50 | Be careful, expects (299,299) sized images and has auxiliary output 51 | """ 52 | model_ft = fe.Inception() 53 | set_parameter_requires_grad(model_ft, finetune_model) 54 | 55 | output_dim = 2048 56 | else: 57 | # print("Invalid model name, exiting...") 58 | # logger.debug("Invalid mode name") 59 | exit() 60 | 61 | return model_ft, output_dim 62 | 63 | def set_parameter_requires_grad(model, requires_grad): 64 | if requires_grad: 65 | for param in model.parameters(): 66 | param.requires_grad = True 67 | 68 | # This implementation only works for VGG networks 69 | class MultiHeadsBCNN(nn.Module): 70 | def __init__(self, num_classes, feature_extractors=None): 71 | super(MultiHeadsBCNN, self).__init__() 72 | self.feature_extractors = feature_extractors 73 | dim_all_layers = feature_extractors.get_feature_dims() 74 | self.pooling_fn_list = nn.ModuleList( 75 | [TensorProduct([dim] * 2) for dim in dim_all_layers] 76 | ) 77 | self.feature_dim_list = [ 78 | pooling_fn.get_output_dim() 79 | for pooling_fn in self.pooling_fn_list 80 | ] 81 | self.fc_list = nn.ModuleList( 82 | [nn.Linear(feature_dim, num_classes, bias=True) 
83 | for feature_dim in self.feature_dim_list] 84 | ) 85 | 86 | def forward(self, x): 87 | relu_acts = self.feature_extractors(x) 88 | 89 | bs, _, h1, w1 = x.shape 90 | 91 | bcnn_list = [ 92 | pooling_fn(z, z) 93 | for z, pooling_fn in zip(relu_acts, self.pooling_fn_list) 94 | ] 95 | bcnn_list = [ 96 | z.view(bs, feature_dim) 97 | for z, feature_dim in zip(bcnn_list, self.feature_dim_list) 98 | ] 99 | 100 | bcnn_list = [ 101 | torch.sqrt(F.relu(z) + 1e-5) - torch.sqrt(F.relu(-z) + 1e-5) 102 | for z in bcnn_list 103 | ] 104 | 105 | bcnn_list = [ 106 | torch.nn.functional.normalize(z) for 107 | z in bcnn_list 108 | ] 109 | 110 | y = [fc(z) for z, fc in zip(bcnn_list, self.fc_list)] 111 | 112 | return y 113 | 114 | 115 | class BCNNModule(nn.Module): 116 | def __init__(self, num_classes, feature_extractors=None, 117 | pooling_fn=None, order=2, m_sqrt_iter=0, demo_agg=False, 118 | fc_bottleneck=False, learn_proj=False): 119 | super(BCNNModule, self).__init__() 120 | 121 | assert feature_extractors is not None 122 | assert pooling_fn is not None 123 | 124 | self.feature_extractors = feature_extractors 125 | self.pooling_fn = pooling_fn 126 | 127 | self.feature_dim = self.pooling_fn.get_output_dim() 128 | if fc_bottleneck: 129 | self.fc = nn.Sequential(nn.Linear(self.feature_dim, 1024, bias=True), 130 | nn.Linear(1024, num_classes, bias=True)) 131 | else: 132 | self.fc = nn.Linear(self.feature_dim, num_classes, bias=True) 133 | 134 | # TODO assert m_sqrt is not used together with tensor sketch nor 135 | # the BCNN models without sharing 136 | if m_sqrt_iter > 0: 137 | self.m_sqrt = MatrixSquareRoot( 138 | m_sqrt_iter, 139 | int(self.feature_dim ** 0.5), 140 | backwardIter=5 141 | ) 142 | else: 143 | self.m_sqrt = None 144 | 145 | self.demo_agg = demo_agg 146 | self.order = order 147 | self.learn_proj = learn_proj 148 | 149 | def get_order(self): 150 | return self.order 151 | 152 | def forward(self, *args): 153 | x = self.feature_extractors(*args) 154 | 155 | bs, _, h1, w1 = x[0].shape 156 | for i in range(1, len(args)): 157 | _, _, h2, w2 = x[i].shape 158 | if h1 != h2 or w1 != w2: 159 | x[i] = torch.nn.functional.interpolate(x[i], size=(h1, w1), 160 | mode='bilinear') 161 | z = self.pooling_fn(*x) 162 | 163 | # TODO improve coding style, modulize normlaization operations 164 | # use a list of normalization operations 165 | # normalization 166 | 167 | if self.m_sqrt is not None: 168 | z = self.m_sqrt(z) 169 | z = z.view(bs, self.feature_dim) 170 | z = torch.sqrt(F.relu(z) + 1e-5) - torch.sqrt(F.relu(-z) + 1e-5) 171 | z = torch.nn.functional.normalize(z) 172 | 173 | # linear classifier 174 | y = self.fc(z) 175 | 176 | return y 177 | 178 | 179 | class MultiStreamsCNNExtractors(nn.Module): 180 | def __init__(self, backbones_list, dim_list, proj_dim=0): 181 | super(MultiStreamsCNNExtractors, self).__init__() 182 | self.feature_extractors = nn.ModuleList(backbones_list) 183 | if proj_dim > 0: 184 | temp = [nn.Sequential(x, \ 185 | nn.Conv2d(fe_dim, proj_dim, 1, 1, bias=False)) \ 186 | for x, fe_dim in zip(self.feature_extractors, dim_list)] 187 | self.feature_extractors = nn.ModuleList(temp) 188 | 189 | class BCNN_sharing(MultiStreamsCNNExtractors): 190 | def __init__(self, backbones_list, dim_list, proj_dim=0, order=2): 191 | super(BCNN_sharing, self).__init__(backbones_list, dim_list, proj_dim) 192 | 193 | # one backbone network for sharing parameters 194 | assert len(backbones_list) == 1 195 | 196 | self.order = order 197 | 198 | def get_number_output(self): 199 | return self.order 200 | 201 | def 
forward(self, *args): 202 | # y = self.feature_extractors[0](x) 203 | y = [self.feature_extractors[0](x) for x in args] 204 | 205 | if len(args) == 1: 206 | # out = y * self.order 207 | # y[0].register_hook(lambda grad: print(grad[0,0,:3,:3])) 208 | 209 | # return out 210 | return y * self.order 211 | # return [y for z in range(self.order)] 212 | else: 213 | return y 214 | 215 | class BCNN_no_sharing(MultiStreamsCNNExtractors): 216 | def __init__(self, backbones_list, dim_list, proj_dim=0): 217 | super(BCNN_no_sharing, self).__init__(backbones_list, dim_list, proj_dim) 218 | 219 | # two networks for the model without sharing 220 | assert len(backbones_list) >= 2 221 | self.order = len(backbones_list) 222 | 223 | def get_number_output(self): 224 | return self.order 225 | 226 | def forward(self, *args): 227 | y = [fe(x) for x, fe in zip(args, self.feature_extractors)] 228 | 229 | return y 230 | 231 | class TensorProduct(nn.Module): 232 | def __init__(self, dim_list): 233 | super(TensorProduct, self).__init__() 234 | self.output_dim = functools.reduce(operator.mul, dim_list) 235 | 236 | # Use tensor sketch for the order greater than 2 237 | assert len(dim_list) == 2 238 | 239 | def get_output_dim(self): 240 | return self.output_dim 241 | 242 | def forward(self, *args): 243 | (x1, x2) = args 244 | [bs, c1, h1, w1] = x1.size() 245 | [bs, c2, h2, w2] = x2.size() 246 | 247 | x1 = x1.view(bs, c1, h1*w1) 248 | x2 = x2.view(bs, c2, h2*w2) 249 | y = torch.bmm(x1, torch.transpose(x2, 1, 2)) 250 | 251 | # return y.view(bs, c1*c2) / (h1 * w1) 252 | return y / (h1 * w1) 253 | 254 | 255 | class TensorSketch(nn.Module): 256 | def __init__(self, dim_list, embedding_dim=4096, pooling=True, 257 | update_sketch=False): 258 | super(TensorSketch, self).__init__() 259 | 260 | 261 | self.output_dim = embedding_dim 262 | 263 | self.count_sketch = nn.ModuleList( 264 | [CountSketch(dim, embedding_dim, update_proj=update_sketch) \ 265 | for dim in dim_list]) 266 | self.pooling = pooling 267 | 268 | def get_output_dim(self): 269 | return self.output_dim 270 | 271 | def forward(self, *args): 272 | y = [sketch_fn(x.permute(0,2,3,1)) \ 273 | for x, sketch_fn in zip(args, self.count_sketch)] 274 | 275 | z = ApproxTensorProduct.apply(self.output_dim, *y) 276 | _, h, w, _ = z.shape 277 | 278 | if self.pooling: 279 | return torch.squeeze( 280 | torch.nn.functional.avg_pool2d(z.permute(0,3,1,2), (h, w))) 281 | else: 282 | return z.permute(0, 3, 1, 2) 283 | 284 | class SketchGammaDemocratic(nn.ModuleList): 285 | def __init__(self, dim_list, embedding_dim=4096, 286 | gamma=0, sinkhorn_t=0.5, sinkhorn_iter=10, update_sketch=False): 287 | super(SketchGammaDemocratic, self).__init__() 288 | self.sketch = TensorSketch(dim_list, embedding_dim, False, update_sketch) 289 | output_dim = self.sketch.get_output_dim() 290 | self.gamma_demo = GammaDemocratic(output_dim, gamma, sinkhorn_t, sinkhorn_iter) 291 | 292 | def forward(self, *args): 293 | x = self.sketch(*args) 294 | x = self.gamma_demo(x) 295 | 296 | return x 297 | 298 | def get_output_dim(self): 299 | return self.sketch.get_output_dim() 300 | 301 | class GammaDemocratic(nn.ModuleList): 302 | def __init__(self, output_dim, gamma=0, sinkhorn_t=0.5, sinkhorn_iter=10): 303 | super(GammaDemocratic, self).__init__() 304 | self.sinkhorn_t = sinkhorn_t # dampening parameter 305 | self.gamma = gamma 306 | self.sinkhorn_iter = sinkhorn_iter 307 | self.output_dim = output_dim 308 | 309 | def forward(self, x): 310 | [bs, ch, h, w] = x.shape 311 | x = x.view(bs, ch, -1).transpose(2, 1) 312 | # 
x.register_hook(self.save_grad('x')) 313 | 314 | K = x.bmm(x.transpose(2, 1)) 315 | K = (K + torch.abs(K)) / 2 316 | 317 | # alpha = torch.autograd.Variable(torch.ones(bs, h*w, 1)).cuda() 318 | alpha = torch.ones_like(x[:,:,[0]]) 319 | Ci = torch.sum(K, 2, keepdim=True) 320 | Ci = torch.pow(Ci, self.gamma).detach() 321 | 322 | for _ in range(self.sinkhorn_iter): 323 | # alpha = torch.pow(alpha + 1e-10, 1-self.sinkhorn_t) * \ 324 | # torch.pow(Ci + 1e-10, self.sinkhorn_t) / \ 325 | # (torch.pow(K.bmm(alpha) + 1e-10, self.sinkhorn_t) + 1e-10) 326 | alpha = torch.pow(Ci + 1e-10, self.sinkhorn_t) * \ 327 | torch.pow(alpha + 1e-10, 1-self.sinkhorn_t) / \ 328 | (torch.pow(K.bmm(alpha) + 1e-10, self.sinkhorn_t) + 1e-10) 329 | 330 | x = torch.sum(x * alpha, dim=1, keepdim=False) 331 | 332 | return x 333 | 334 | def get_output_dim(self): 335 | return self.output_dim 336 | 337 | class SecondOrderGammaDemocratic(nn.Module): 338 | def __init__(self, output_dim, gamma=0, sinkhorn_t=0.5, sinkhorn_iter=10): 339 | super(SecondOrderGammaDemocratic, self).__init__() 340 | self.sinkhorn_t = sinkhorn_t # dampening parameter 341 | self.sinkhorn_iter = sinkhorn_iter 342 | self.gamma = gamma 343 | self.iter = sinkhorn_iter 344 | # self.grad = {} 345 | self.output_dim = output_dim 346 | 347 | def forward(self, *args): 348 | # The forward assume args[0] == args[1]. This should be asserted during 349 | # model creation 350 | 351 | x = args[0] 352 | [bs, ch, h, w] = x.shape 353 | x = x.view(bs, ch, -1).transpose(2, 1) 354 | 355 | K = x.bmm(x.transpose(2, 1)) 356 | K = K * K; 357 | 358 | alpha = torch.ones_like(x[:,:,[0]]) 359 | Ci = torch.sum(K, 2, keepdim=True) 360 | Ci = torch.pow(Ci, self.gamma).detach() 361 | 362 | for _ in range(self.sinkhorn_iter): 363 | alpha = torch.pow(Ci + 1e-10, self.sinkhorn_t) * \ 364 | torch.pow(alpha + 1e-10, 1-self.sinkhorn_t) / \ 365 | (torch.pow(K.bmm(alpha) + 1e-10, self.sinkhorn_t) + 1e-10) 366 | 367 | x = x * torch.pow(alpha + 1e-8, 0.5) 368 | x = x.transpose(1, 2).bmm(x).view(bs, -1) 369 | 370 | return x 371 | 372 | def get_output_dim(self): 373 | return self.output_dim 374 | 375 | 376 | class ApproxTensorProduct(Function): 377 | 378 | @staticmethod 379 | def forward(ctx, embedding_dim, *args): 380 | fx = [torch.rfft(x, 1) for x in args] 381 | 382 | re_fx1 = fx[0].select(-1, 0) 383 | im_fx1 = fx[0].select(-1, 1) 384 | for i in range(1, len(fx)): 385 | re_fx2 = fx[i].select(-1, 0) 386 | im_fx2 = fx[i].select(-1, 1) 387 | 388 | # complex number multiplication 389 | Z_re = torch.addcmul(re_fx1*re_fx2, -1, im_fx1, im_fx2) 390 | Z_im = torch.addcmul(re_fx1*im_fx2, 1, im_fx1, re_fx2) 391 | re_fx1 = Z_re 392 | im_fx1 = Z_im 393 | 394 | ctx.save_for_backward(re_fx1, im_fx1, *fx) 395 | # ctx.save_for_backward(*fx) 396 | re = torch.irfft(torch.stack((re_fx1, im_fx1), re_fx1.dim()), 1, 397 | signal_sizes=(embedding_dim,)) 398 | 399 | ctx.embedding_dim = embedding_dim 400 | return re 401 | 402 | @staticmethod 403 | def backward(ctx, grad_output): 404 | 405 | grad_output = grad_output.contiguous() 406 | grad_prod = torch.rfft(grad_output, 1) 407 | grad_re_prod = grad_prod.select(-1, 0) 408 | grad_im_prod = grad_prod.select(-1, 1) 409 | 410 | re_fout = ctx.saved_tensors[0] 411 | im_fout = ctx.saved_tensors[1] 412 | 413 | fx = ctx.saved_tensors[2:] 414 | grad = [] 415 | for fi in fx: 416 | re_fi = fi.select(-1, 0) 417 | im_fi = fi.select(-1, 1) 418 | 419 | temp_norm = (re_fi**2 + im_fi**2 + 1e-8) 420 | temp_re = torch.addcmul(re_fout * re_fi, 1, im_fout, im_fi) \ 421 | / temp_norm 422 | 
temp_im = torch.addcmul(im_fout * re_fi, -1, re_fout, im_fi) \ 423 | /temp_norm 424 | grad_re = torch.addcmul(grad_re_prod * temp_re, 1, 425 | temp_im, grad_im_prod) 426 | grad_im = torch.addcmul(grad_im_prod * temp_re, -1, 427 | grad_re_prod, temp_im) 428 | grad_fi = torch.irfft( 429 | torch.stack((grad_re, grad_im), grad_re.dim()), 1, 430 | signal_sizes=(ctx.embedding_dim,)) 431 | grad.append(grad_fi) 432 | 433 | return (None, *grad) 434 | 435 | 436 | def create_bcnn_model(model_names_list, num_classes, 437 | pooling_method='outer_product', fine_tune=True, pre_train=True, 438 | embedding_dim=8192, order=2, m_sqrt_iter=0, 439 | fc_bottleneck=False, proj_dim=0, update_sketch=False, 440 | gamma=0.5): 441 | 442 | temp_list = [create_backbone(model_name, finetune_model=fine_tune, \ 443 | use_pretrained=pre_train) for model_name in model_names_list] 444 | 445 | temp_list = list(map(list, zip(*temp_list))) 446 | backbones_list = temp_list[0] 447 | 448 | # list of feature dimensions of the backbone networks 449 | dim_list = temp_list[1] 450 | 451 | # BCNN mdoels with sharing parameters. The computation of the two backbone 452 | # networks are shared resulting in a symmetric BCNN 453 | if len(backbones_list) == 1: 454 | assert order >= 2 455 | dim_list = dim_list * order 456 | feature_extractors = BCNN_sharing( 457 | backbones_list, 458 | dim_list, 459 | proj_dim, order 460 | ) 461 | else: 462 | feature_extractors = BCNN_no_sharing(backbones_list, dim_list, proj_dim) 463 | 464 | # update the reduced feature dimension in dim_list if there is 465 | # dimensionality reduction 466 | if proj_dim > 0: 467 | dim_list = [proj_dim for x in dim_list] 468 | 469 | if pooling_method == 'outer_product': 470 | pooling_fn = TensorProduct(dim_list) 471 | elif pooling_method == 'sketch': 472 | pooling_fn = TensorSketch(dim_list, embedding_dim, True, update_sketch) 473 | elif pooling_method == 'gamma_demo': 474 | assert isinstance(feature_extractors, BCNN_sharing) 475 | pooling_fn = SecondOrderGammaDemocratic(dim_list[0] ** 2, gamma=gamma, sinkhorn_t=0.5, 476 | sinkhorn_iter=10) 477 | elif pooling_method == 'sketch_gamma_demo': 478 | pooling_fn = SketchGammaDemocratic( 479 | dim_list, 480 | embedding_dim, 481 | gamma=gamma, 482 | sinkhorn_t=0.5, 483 | sinkhorn_iter=10, 484 | update_sketch=update_sketch 485 | ) 486 | else: 487 | raise ValueError('Unknown pooling method: %s' % pooling_method) 488 | 489 | learn_proj = True if proj_dim > 0 else False 490 | return BCNNModule( 491 | num_classes, 492 | feature_extractors, 493 | pooling_fn, 494 | order, 495 | m_sqrt_iter=m_sqrt_iter, 496 | fc_bottleneck=fc_bottleneck, 497 | learn_proj=learn_proj 498 | ) 499 | 500 | def create_multi_heads_bcnn(num_classes): 501 | 502 | backbone = fe.VGG_all_conv_features() 503 | return MultiHeadsBCNN(num_classes, backbone) 504 | 505 | -------------------------------------------------------------------------------- /CNN.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class DenseNet(nn.Module): 7 | def __init__(self, input_size): 8 | super(DenseNet, self).__init__() 9 | self.model = models.densenet201(pretrained=True) 10 | self.input = input_size 11 | self.output_dim = 1920 12 | 13 | def forward(self, x): 14 | x = self.model.features(x) 15 | x = F.relu(x, inplace=True) 16 | x = F.adaptive_avg_pool2d(x, (1, 1)) 17 | 18 | return x 19 | 20 | def get_output_dim(self): 21 | return 
self.output_dim 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, input_size): 25 | super(ResNet, self).__init__() 26 | 27 | self.model = models.resnet101(pretrained=True) 28 | self.input_size = input_size 29 | ''' 30 | if input_size == 448: 31 | kernel_size = 2 * self.model.avgpool.kernel_size 32 | self.model.avgpool = nn.AvgPool2d(kernel_size) 33 | ''' 34 | self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 35 | self.output_dim = 2048 36 | delattr(self.model, 'fc') 37 | 38 | def forward(self, x): 39 | x = self.model.conv1(x) 40 | x = self.model.bn1(x) 41 | x = self.model.relu(x) 42 | x = self.model.maxpool(x) 43 | 44 | x = self.model.layer1(x) 45 | x = self.model.layer2(x) 46 | x = self.model.layer3(x) 47 | x = self.model.layer4(x) 48 | x = self.model.avgpool(x) 49 | 50 | return x 51 | 52 | def get_output_dim(self): 53 | return self.output_dim 54 | 55 | class AlexNet(nn.Module): 56 | def __init__(self): 57 | super(AlexNet, self).__init__() 58 | self.model = models.alexnet(pretrained=True) 59 | temp = list(self.model.classifier.children()) 60 | self.model.classifier = nn.Sequential(*temp[:-1]) 61 | self.input_size = 224 62 | self.output_dim = 4096 63 | 64 | def forward(self, x): 65 | x = self.model.features(x) 66 | x = x.view(x.size(0), -1) 67 | x = self.model.classifier(x) 68 | return x 69 | 70 | def get_output_dim(self): 71 | return self.output_dim 72 | 73 | class VGG(nn.Module): 74 | def __init__(self): 75 | super(VGG, self).__init__() 76 | self.model = models.vgg16(pretrained=True) 77 | temp = list(self.model.classifier.children()) 78 | self.model.classifier = nn.Sequential(*temp[:-1]) 79 | # self.model = models.vgg16(pretrained=True).features 80 | # self.model = torch.nn.Sequential(*list(self.model.children())[:-1]) 81 | self.input_size = 224 82 | self.output_dim = 4096 83 | 84 | def forward(self, x): 85 | x = self.model.features(x) 86 | x = x.view(x.size(0), -1) 87 | x = self.model.classifier(x) 88 | # x = self.model.features(x) 89 | return x 90 | 91 | def get_output_dim(self): 92 | return self.output_dim 93 | 94 | 95 | class Inception(nn.Module): 96 | def __init__(self): 97 | super(Inception, self).__init__() 98 | self.model = models.inception_v3(pretrained=True) 99 | self.input_size = 299 100 | 101 | def forward(self, x): 102 | if self.model.transform_input: 103 | ''' 104 | x = x.clone() 105 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 106 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 107 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 108 | ''' 109 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 110 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 111 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 112 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 113 | 114 | x = self.model.Conv2d_1a_3x3(x) 115 | x = self.model.Conv2d_2a_3x3(x) 116 | x = self.model.Conv2d_2b_3x3(x) 117 | x = F.max_pool2d(x, kernel_size=3, stride=2) 118 | x = self.model.Conv2d_3b_1x1(x) 119 | x = self.model.Conv2d_4a_3x3(x) 120 | x = F.max_pool2d(x, kernel_size=3, stride=2) 121 | x = self.model.Mixed_5b(x) 122 | x = self.model.Mixed_5c(x) 123 | x = self.model.Mixed_5d(x) 124 | x = self.model.Mixed_6a(x) 125 | x = self.model.Mixed_6b(x) 126 | x = self.model.Mixed_6c(x) 127 | x = self.model.Mixed_6d(x) 128 | x = self.model.Mixed_6e(x) 129 | x = self.model.Mixed_7a(x) 130 | x = self.model.Mixed_7b(x) 131 | x = self.model.Mixed_7c(x) 132 | 133 | return x 134 | 135 | class CNN_Model(nn.Module): 136 
| def __init__(self, num_classes, feature_dim, feature_extractors=None): 137 | super(CNN_Model, self).__init__() 138 | self.feature_extractors = feature_extractors 139 | self.feature_dim = feature_dim 140 | self.fc = nn.Linear(self.feature_dim, num_classes, bias=True) 141 | 142 | def forward(self, x): 143 | x = self.feature_extractors(x) 144 | x = x.view(x.size(0), -1) 145 | y = self.fc(x) 146 | 147 | return y 148 | 149 | def create_cnn_model(model_name, num_classes, input_size=224, 150 | fine_tune=True, pre_train=True): 151 | ''' 152 | if input_size != 224: 153 | assert model_name == 'resnet' 154 | ''' 155 | if model_name == 'vgg': 156 | feature_extractors = VGG() 157 | elif model_name == 'resnet': 158 | feature_extractors = ResNet(input_size) 159 | elif model_name == 'densenet': 160 | feature_extractors = DenseNet(input_size) 161 | else: 162 | raise ValueError('Unknown model name: %s' % model_name) 163 | 164 | output_dim = feature_extractors.get_output_dim() 165 | return CNN_Model(num_classes, output_dim, feature_extractors) 166 | 167 | -------------------------------------------------------------------------------- /CUBDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy 4 | import os 5 | from torchvision.datasets import folder as dataset_parser 6 | import json 7 | 8 | def make_dataset(dataset_root, imageRoot, split, classes, subset=False, 9 | create_val=True): 10 | with open(os.path.join(dataset_root, 'train_test_split.txt'), 'r') as f: 11 | setList = f.readlines() 12 | with open(os.path.join(dataset_root, 'images.txt'), 'r') as f: 13 | imgList = f.readlines() 14 | with open(os.path.join(dataset_root, 'image_class_labels.txt'), 'r') as f: 15 | annoList = f.readlines() 16 | 17 | if split == 'train': 18 | setIdx = [1] 19 | elif split == 'val': 20 | setIdx = [0] 21 | elif split == 'test': 22 | setIdx = [-1] 23 | elif split == 'train_val': 24 | setIdx = [1, 0] 25 | else: 26 | raise ValueError('Unknown split: %s' % split) 27 | 28 | setDict = [x.split() for x in setList] 29 | setDict = {x[0]:int(x[1]) for x in setDict} 30 | 31 | if create_val: 32 | numpy.random.seed(0) 33 | trainList = [k for k, v in setDict.items() if v == 1] 34 | trainList = numpy.random.permutation(trainList) 35 | valNum = numpy.ceil(0.333*len(trainList)).astype(int) 36 | valList = trainList[:valNum] 37 | trainList = trainList[valNum:] 38 | 39 | for k, v in setDict.items(): 40 | if setDict[k] == 0: 41 | setDict[k] = -1 42 | for k in valList: 43 | setDict[k] = 0 44 | 45 | numpy.random.seed() 46 | 47 | imgDict = [x.split() for x in imgList] 48 | imgDict = {x[0]:x[1] for x in imgDict} 49 | 50 | img = [] 51 | count = [0]*len(classes) 52 | for anno in annoList: 53 | temp = anno.split() 54 | label = int(temp[1]) - 1 55 | imgKey = temp[0] 56 | if setDict[imgKey] not in setIdx: 57 | continue 58 | if subset: 59 | if count[label] >= 5: 60 | continue 61 | count[label] += 1 62 | imageName = os.path.join(imageRoot, imgDict[imgKey]) 63 | img.append((imageName, label)) 64 | 65 | return img 66 | 67 | class CUBDataset(data.Dataset): 68 | def __init__(self, dataset_root, split, subset=False, transform=None, 69 | create_val=True, target_transform=None, 70 | loader=dataset_parser.default_loader): 71 | self.loader = loader 72 | self.dataset_root = dataset_root 73 | self.imageRoot = os.path.join(dataset_root, 'images') 74 | self.split = split 75 | 76 | with open(os.path.join(dataset_root, 'classes.txt'), 'r') as f: 77 | clsList = f.readlines() 78 | 79 | self.classes = [x.split()[1] for x
in clsList] 80 | 81 | self.imgs = make_dataset(self.dataset_root, self.imageRoot, split, 82 | self.classes, subset, create_val=create_val) 83 | self.transform = transform 84 | self.target_transform = target_transform 85 | 86 | self.dataset_name = 'cub' 87 | self.load_images = True 88 | self.feat_root = None 89 | 90 | def __getitem__(self, index): 91 | 92 | if self.load_images: 93 | path, target = self.imgs[index] 94 | img = self.loader(path) 95 | if self.transform is not None: 96 | img = [x(img) for x in self.transform] 97 | 98 | if self.target_transform is not None: 99 | target = self.target_transform(target) 100 | 101 | else: 102 | path, target = self.imgs[index] 103 | path = os.path.join(self.feat_root, path[len(self.imageRoot)+1:-3]) 104 | path = path + 'pt' 105 | img = torch.load(path) 106 | if self.target_transform is not None: 107 | target = self.target_transform(target) 108 | 109 | return (*img, target, path) 110 | 111 | def get_num_classes(self): 112 | return len(self.classes) 113 | 114 | def __len__(self): 115 | return len(self.imgs) 116 | 117 | def set_to_load_features(self, feat_root): 118 | self.load_images = False 119 | self.feat_root = feat_root 120 | 121 | def set_to_load_images(self): 122 | self.load_images = True 123 | self.feat_root = None 124 | -------------------------------------------------------------------------------- /CarsDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | from scipy.io import loadmat 5 | 6 | import os 7 | from torchvision.datasets import folder as dataset_parser 8 | import json 9 | 10 | def make_dataset(meta, split, dataset_root, classes, subset=False, 11 | create_val=True): 12 | 13 | imgList = [str(x[0][0]) for x in meta['annotations'][0]] 14 | setList = [int(np.squeeze(x[6])) for x in meta['annotations'][0]] 15 | annoList = [int(np.squeeze(x[5])) for x in meta['annotations'][0]] 16 | 17 | if split == 'train': 18 | setIdx = [0] 19 | elif split == 'val': 20 | setIdx = [1] 21 | elif split == 'test': 22 | setIdx = [-1] 23 | elif split == 'train_val': 24 | setIdx = [1, 0] 25 | else: 26 | raise ValueError('Unknown split: %s' % split) 27 | 28 | if create_val: 29 | np.random.seed(0) 30 | trainList = [idx for idx, v in enumerate(setList) if v==0] 31 | trainList = np.random.permutation(trainList) 32 | valNum = np.ceil(0.333*len(trainList)).astype(int) 33 | valList = trainList[:valNum] 34 | trainList = trainList[valNum:] 35 | 36 | for idx, v in enumerate(setList): 37 | if v == 1: 38 | setList[idx] = -1 39 | for k in valList: 40 | setList[k] = 1 41 | 42 | np.random.seed() 43 | 44 | img = [] 45 | count = [0]*len(classes) 46 | for idx, anno in enumerate(annoList): 47 | label = anno - 1 48 | if setList[idx] not in setIdx: 49 | continue 50 | if subset: 51 | if count[label] >= 5: 52 | continue 53 | count[label] += 1 54 | imageName = os.path.join(dataset_root, imgList[idx]) 55 | img.append((imageName, label)) 56 | 57 | return img 58 | 59 | 60 | class CarsDataset(data.Dataset): 61 | def __init__(self, dataset_root, split, subset=False, transform=None, 62 | create_val=True, target_transform=None, 63 | loader=dataset_parser.default_loader): 64 | self.loader = loader 65 | self.dataset_root = dataset_root 66 | self.imageRoot = os.path.join(dataset_root, 'car_ims') 67 | 68 | meta = loadmat(os.path.join(dataset_root, 'cars_annos.mat')) 69 | class_meta = meta['class_names'][0] 70 | self.classes = [np.array_str(x) for x in class_meta] 71 | 72 |
self.imgs = make_dataset(meta, split, self.dataset_root, 73 | self.classes, subset, create_val=create_val) 74 | self.transform = transform 75 | self.target_transform = target_transform 76 | 77 | self.dataset_name = 'cars' 78 | self.load_images = True 79 | self.feat_root = None 80 | 81 | def __getitem__(self, index): 82 | 83 | if self.load_images: 84 | path, target = self.imgs[index] 85 | img = self.loader(path) 86 | if self.transform is not None: 87 | img = [x(img) for x in self.transform] 88 | 89 | if self.target_transform is not None: 90 | target = self.target_transform(target) 91 | 92 | else: 93 | path, target = self.imgs[index] 94 | path = os.path.join(self.feat_root, path[len(self.imageRoot)+1:-3]) 95 | path = path + 'pt' 96 | img = torch.load(path) 97 | if self.target_transform is not None: 98 | target = self.target_transform(target) 99 | 100 | return (*img, target, path) 101 | 102 | def get_num_classes(self): 103 | return len(self.classes) 104 | 105 | def __len__(self): 106 | return len(self.imgs) 107 | 108 | def set_to_load_features(self, feat_root): 109 | self.load_images = False 110 | self.feat_root = feat_root 111 | 112 | def set_to_load_images(self): 113 | self.load_images = True 114 | self.feat_root = None 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Tsung-Yu Lin and Subhransu Maji 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | 21 | ************************************************************************* 22 | 23 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION 24 | This package incorporates material from the pytorch_compact_bilineaar_pooling 25 | projects (https://github.com/gdlg/pytorch_compact_bilinear_pooling) 26 | listed below. The original copyright notice and license are included below. 27 | 28 | Copyright (c) 2017, Grégoire Payen de La Garanderie, Durham University 29 | All rights reserved. 30 | 31 | Redistribution and use in source and binary forms, with or without 32 | modification, are permitted provided that the following conditions are met: 33 | 34 | * Redistributions of source code must retain the above copyright notice, this 35 | list of conditions and the following disclaimer. 
36 | 37 | * Redistributions in binary form must reproduce the above copyright notice, 38 | this list of conditions and the following disclaimer in the documentation 39 | and/or other materials provided with the distribution. 40 | 41 | * Neither the name of the copyright holder nor the names of its 42 | contributors may be used to endorse or promote products derived from 43 | this software without specific prior written permission. 44 | 45 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 46 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 47 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 48 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 49 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 50 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 51 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 52 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 53 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 54 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /MITIndoorDataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy 3 | import os 4 | from torchvision.datasets import folder as dataset_parser 5 | 6 | 7 | def make_dataset( 8 | dataset_root, 9 | imageRoot, 10 | split, 11 | classes, 12 | class_to_anno, 13 | subset=False, 14 | create_val=True, 15 | ): 16 | 17 | if split == "test": 18 | with open(os.path.join(dataset_root, "TestImages.txt"), "r") as f: 19 | imgList = f.readlines() 20 | else: 21 | if split == "val": 22 | assert create_val 23 | with open(os.path.join(dataset_root, "TrainImages.txt"), "r") as f: 24 | imgList = f.readlines() 25 | 26 | imgList = [x.rstrip("\n") for x in imgList] 27 | 28 | if split in ["train", "val"] and create_val: 29 | valNum = numpy.ceil(0.333 * len(imgList)).astype(int) 30 | numpy.random.seed(0) 31 | numpy.random.shuffle(imgList) 32 | if split == "train": 33 | imgList = imgList[valNum:] 34 | else: 35 | imgList = imgList[:valNum] 36 | 37 | numpy.random.seed() 38 | 39 | annoList = [class_to_anno[x.split("/")[0]] for x in imgList] 40 | img = [] 41 | for img_name, anno in zip(imgList, annoList): 42 | img.append((os.path.join(imageRoot, img_name), anno)) 43 | 44 | return img 45 | 46 | 47 | class MITIndoorDataset(data.Dataset): 48 | def __init__( 49 | self, 50 | dataset_root, 51 | split, 52 | subset=False, 53 | transform=None, 54 | create_val=True, 55 | target_transform=None, 56 | loader=dataset_parser.default_loader, 57 | ): 58 | self.loader = loader 59 | self.dataset_root = dataset_root 60 | self.imageRoot = os.path.join(dataset_root, "Images") 61 | self.split = split 62 | 63 | self.classes = list(os.listdir(self.imageRoot)) 64 | self.classes.sort() 65 | self.class_to_anno = {x: i for i, x in enumerate(self.classes)} 66 | 67 | self.imgs = make_dataset( 68 | self.dataset_root, 69 | self.imageRoot, 70 | split, 71 | self.classes, 72 | self.class_to_anno, 73 | subset, 74 | create_val=create_val, 75 | ) 76 | self.transform = transform 77 | self.target_transform = target_transform 78 | 79 | self.dataset_name = "mit_indoor" 80 | 81 | def __getitem__(self, index): 82 | 83 | path, target = self.imgs[index] 84 
| img = self.loader(path) 85 | if self.transform is not None: 86 | img = [x(img) for x in self.transform] 87 | 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | 91 | return (*img, target, path) 92 | 93 | def get_num_classes(self): 94 | return len(self.classes) 95 | 96 | def __len__(self): 97 | return len(self.imgs) 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | This repository contains PyTorch implementations for training bilinear 3 | (second-order) CNNs. 4 | The series of works listed below investigates bilinear pooling of 5 | convolutional features for fine-grained recognition. 6 | This repository constructs symmetric BCNNs, which represent images as 7 | covariance matrices of CNN activations. More details can be found in 8 | the PAMI 2017 paper and [Tsung-Yu Lin's PhD 9 | thesis](http://vis-www.cs.umass.edu/papers/tsungyu_thesis.pdf). 10 | 11 | 1. [Bilinear CNN Models for Fine-grained Visual 12 | Recognition](http://vis-www.cs.umass.edu/bcnn/), Tsung-Yu Lin, 13 | Aruni RoyChowdhury and Subhransu Maji, ICCV 2015 14 | 1. [Visualizing and Understanding Deep Texture 15 | Representations](http://vis-www.cs.umass.edu/bcnn/), Tsung-Yu Lin 16 | and Subhransu Maji, CVPR 2016 17 | 1. [Improved Bilinear Pooling with 18 | CNNs](http://vis-www.cs.umass.edu/bcnn/), Tsung-Yu Lin and 19 | Subhransu Maji, BMVC 2017 20 | 1. [Bilinear Convolutional Neural Networks for Fine-grained Visual 21 | Recognition](http://vis-www.cs.umass.edu/bcnn/), Tsung-Yu Lin, 22 | Aruni RoyChowdhury and Subhransu Maji, PAMI 2017 23 | 1. [Second-order Democratic 24 | Aggregation](http://vis-www.cs.umass.edu/o2dp/), Tsung-Yu Lin, 25 | Subhransu Maji and Piotr Koniusz, ECCV 2018 26 | 27 | 28 | 29 | In particular, we provide the code for: 30 | 1. training BCNN models 31 | 2. training improved BCNN models (via matrix square-root normalization) 32 | 3. training CNN models with second-order democratic aggregation 33 | 4. inverting fine-grained categories with BCNN representations 34 | 35 | The prerequisites can be installed with `pip install -r requirements.txt`. 36 | 37 | Links to the original implementations in Matlab and MatConvNet can 38 | be found on the project webpages. 39 | Please cite the appropriate papers if you find this code 40 | useful. 41 | 42 | 43 | ## Datasets 44 | To get started, download the following datasets and point the 45 | corresponding entries in `config.py` to the locations where you 46 | downloaded the data (you can start with the CUB dataset first). 47 | * Caltech-UCSD Birds: [CUB-200-2011 dataset](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html). 48 | * FGVC Aircrafts: [FGVC aircraft dataset](http://www.robots.ox.ac.uk/~vgg/data/oid/) 49 | * Stanford Cars: [Stanford cars dataset](http://ai.stanford.edu/~jkrause/cars/car_dataset.html) 50 | * MIT Indoor: [MIT indoor scenes dataset](http://web.mit.edu/torralba/www/indoor.html) 51 | 52 | The results obtained from using the code in this repository are 53 | summarized in the following table. Note that the accuracies reported here are obtained with 54 | a softmax classifier instead of SVMs (unlike the ICCV15 paper).
55 | 56 | | Datasets | BCNN [VGG-D] | Improved BCNN [VGG-D] | Improved BCNN [DenseNet201] | 57 | | :--- | :----: | :---: | :--: | 58 | | Birds | 84.1% | 85.5% | 87.5% | 59 | | Cars | 90.5% | 92.5% | 92.9% | 60 | | Aircrafts | 87.5% | 90.7% | 90.6% | 61 | 62 | 63 | ## Pre-trained models 64 | We provide fine-tuned [BCNN](http://maxwell.cs.umass.edu/bcnn/pytorch_models/bcnn_vgg/) and [Improved BCNN](http://maxwell.cs.umass.edu/bcnn/pytorch_models/impbcnn_vgg/) models on VGG. In addition, we also provide fine-tuned [Improved BCNN](http://maxwell.cs.umass.edu/bcnn/pytorch_models/impbcnn_densenet/) models on DenseNet. All the models can be downloaded together as a [tar.gz file](http://maxwell.cs.umass.edu/bcnn/pytorch_models/pytorch_bcnn_pretrained_models.tar.gz). We provide the code to evaluate the pre-trained models. The models are assumed to be in the folder `pretrained_models`. You can run the following commands to evaluate the models: 65 | 66 | python test.py --pretrained_filename bcnn_vgg/cub_bcnn_vgg_models.pth.tar --dataset cub 67 | python test.py --pretrained_filename impbcnn_vgg/cub_impbcnn_vgg_models.pth.tar --dataset cub --matrix_sqrt_iter 5 68 | python test.py --model_names_list densenet --proj_dim 128 --pretrained_filename impbcnn_densenet/cub_impbcnn_densenet_models.pth.tar --dataset cub --matrix_sqrt_iter 5 69 | 70 | ## Training bilinear CNN 71 | The following command is used to train the BCNN model with VGG-D as 72 | the backbone on the CUB dataset: 73 | 74 | python train.py --lr 1e-4 --optimizer adam --exp bcnn_vgg --dataset cub --batch_size 16 --model_names_list vgg 75 | 76 | This will construct a bilinear model with an ImageNet-pretrained VGG-D 77 | as the backbone network and start training a linear + softmax 78 | layer to predict the 200 categories. 79 | You will see output as follows: 80 | 81 | 82 | Iteration 624/9365 83 | ---------- 84 | Train Loss: 4.3640 Acc: 0.1890 85 | Validation Loss: 3.2716 Acc: 0.3643 86 | Iteration 1249/9365 87 | ---------- 88 | Train Loss: 2.3581 Acc: 0.5791 89 | Validation Loss: 2.0965 Acc: 0.5865 90 | Iteration 1874/9365 91 | ---------- 92 | Train Loss: 1.4570 Acc: 0.7669 93 | Validation Loss: 1.6335 Acc: 0.6717 94 | 95 | After the linear classifier is trained, end-to-end fine-tuning will 96 | start automatically (as a comparison point, the validation accuracy 97 | reaches 76.2% on CUB after 40 mins on an NVIDIA TitanX GPU). 98 | The intermediate checkpoints, models, and the results can be found in the folder 99 | `../exp/cub/bcnn_vgg`. 100 | We use the standard training split for Birds and 101 | Cars, and the training + val split for Aircrafts. 102 | The validation set is set to the `test` split. See the `config.py` file 103 | and the corresponding dataset loaders for details. 104 | The test accuracy can be read off 105 | directly from the log file `train_history.txt` (see the table above 106 | for the final accuracy). 107 | 108 | 109 | ## Training improved bilinear CNN 110 | These models incorporate matrix normalization layers. In particular, the 111 | covariance representations are normalized with the matrix square-root function, 112 | implemented efficiently using the iterative methods presented in the 113 | BMVC17 paper. 114 | The following command is used to train the model: 115 | 116 | python train.py --lr 1e-4 --optimizer adam --matrix_sqrt_iter 5 --exp impbcnn_vgg --batch_size 16 --dataset cub --model_names_list vgg 117 | 118 | The intermediate checkpoints, models, and the results can be found in 119 | the folder `../exp/cub/impbcnn_vgg`.
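As a point of reference, the matrix square root can be computed with a Newton–Schulz iteration; the sketch below is illustrative only (following the iterative methods of the BMVC17 paper) and is not the exact implementation in `matrixSquareRoot.py`:

    import torch

    def newton_schulz_sqrt(A, num_iters=5, eps=1e-8):
        # A: a (d, d) symmetric positive semi-definite covariance matrix.
        # Pre-scale A so the iteration converges, then undo the scaling.
        norm_a = A.norm() + eps
        Y = A / norm_a
        I = torch.eye(A.size(0), device=A.device, dtype=A.dtype)
        Z = I.clone()
        for _ in range(num_iters):
            T = 0.5 * (3.0 * I - Z @ Y)  # Y -> sqrt(A / norm_a), Z -> its inverse
            Y = Y @ T
            Z = T @ Z
        return Y * norm_a.sqrt()

Only matrix-matrix products are involved, which is why the normalization adds little overhead on a GPU.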
Adding the matrix normalization 120 | layer adds little overhead to the feed-forward computation but 121 | provides consistent improvements in accuracy, as seen in the table above. 122 | 123 | 124 | Replacing the VGG-D networks with a DenseNet or ResNet provides 125 | further gains. 126 | However, extracting second-order representations requires significantly 127 | more memory since the embedding dimension of the activations is 1920 (thus the 128 | covariance matrix is 1920x1920). 129 | The code performs a low-rank projection on the features before 130 | computing the outer product (the dimension can be controlled with the 131 | `proj_dim` argument). 132 | The following command is used to train a DenseNet-based model: 133 | 134 | 135 | python train.py --lr 1e-4 --optimizer adam --matrix_sqrt_iter 5 --exp impbcnn_densenet --batch_size 16 --dataset cub --model_names_list densenet --proj_dim 128 136 | 137 | | Datasets | Birds | Cars | Aircrafts | 138 | | :--- | :----: | :---: | :--: | 139 | | Accuracy | 87.5% | 92.9% | 90.6% | 140 | 141 | ## Training second-order democratic aggregation 142 | Democratic aggregation provides an alternative way to reweight feature importance, and it can be combined with Tensor Sketch to reduce the feature dimension. The following command reproduces the result reported in the ECCV'18 paper (84.3% accuracy) on the MIT Indoor dataset using ResNet-101 without end-to-end fine-tuning: 143 | 144 | python train.py --init_lr 1 --init_wd 1e-5 --optimizer sgd --exp democratic_resnet_sketch --dataset mit_indoor --pooling_method sketch_gamma_demo --model_names_list resnet --no_finetune --init_epoch 70 145 | 146 | The accuracy can be read off from the log file `train_init_history.txt` located in `../exp/mit_indoor/democratic_resnet_sketch`. You can also train the model end-to-end. The following command is used to train the model with VGG-D as the backbone: 147 | 148 | python train.py --lr 1e-4 --optimizer adam --exp democratic_vgg --dataset cub --batch_size 16 --pooling_method gamma_demo --model_names_list vgg 149 | 150 | The intermediate checkpoints, models, and the results can be found in the folder `../exp/cub/democratic_vgg`. 151 | 152 | ## Visualizing fine-grained categories as textures 153 | Second-order representations are known to capture texture properties 154 | (also see the work on style transfer [Gatys et al.] and covariance 155 | representations [Portilla and Simoncelli]). 156 | Hence, by visualizing maximal images for each category according to the 157 | bilinear model, we can gain some insight into what textures characterize 158 | different categories. 159 | The approach is based on gradient-based optimization of the inputs to the 160 | model to maximize the scores of target categories. 161 | The following command computes inverses for all categories in the 162 | birds dataset: 163 | 164 | 165 | python inversion.py --exp_dir invert_categories 166 | 167 | The code starts by training softmax classifiers on top of BCNN 168 | representations extracted from the layers `{relu2_2, relu3_3, relu4_3, 169 | relu5_3}` and then finds images that maximize the prediction score 170 | for each category (plus a regularization term) via gradient ascent. 171 | Using intermediate layers results in better multi-scale texture 172 | representations (color and small-scale details are better preserved). 173 | You can find the output images in the folder 174 | `../exp_inversion/cub/invert_categories/inv_output`.
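For intuition, here is a minimal sketch of this kind of gradient-ascent inversion; the classifier `model`, the total-variation regularizer, and the hyper-parameters are illustrative assumptions, not the exact recipe in `inversion.py`:

    import torch

    def invert_category(model, target_class, image_size=224,
                        steps=200, lr=0.05, tv_weight=1e-4):
        # Optimize the input image directly, starting from random noise.
        x = torch.randn(1, 3, image_size, image_size, requires_grad=True)
        optimizer = torch.optim.Adam([x], lr=lr)
        for _ in range(steps):
            optimizer.zero_grad()
            score = model(x)[0, target_class]
            # Total-variation penalty keeps the image piecewise smooth.
            tv = ((x[..., 1:, :] - x[..., :-1, :]).abs().mean() +
                  (x[..., :, 1:] - x[..., :, :-1]).abs().mean())
            loss = -score + tv_weight * tv  # minimizing -score maximizes the class score
            loss.backward()
            optimizer.step()
        return x.detach()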
Some examples below: 175 | 176 | ![example-1](inv_images/002.Laysan_Albatross.png) 177 | ![example-2](inv_images/005.Crested_Auklet.png) 178 | ![example-3](inv_images/018.Spotted_Catbird.png) 179 | ![example-4](inv_images/010.Red_winged_Blackbird.png) 180 | ![example-5](inv_images/012.Yellow_headed_Blackbird.png) 181 | ![example-6](inv_images/014.Indigo_Bunting.png) 182 | ![example-7](inv_images/017.Cardinal.png) 183 | ![example-8](inv_images/019.Gray_Catbird.png) 184 | ![example-9](inv_images/024.Red_faced_Cormorant.png) 185 | -------------------------------------------------------------------------------- /compact_bilinear_pooling/__init__.py: -------------------------------------------------------------------------------- 1 | import types 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Function 5 | 6 | 7 | def CountSketchFn_forward(h, s, output_size, x, force_cpu_scatter_add=False): 8 | x_size = tuple(x.size()) 9 | 10 | s_view = (1,) * (len(x_size)-1) + (x_size[-1],) 11 | 12 | out_size = x_size[:-1] + (output_size,) 13 | 14 | # Broadcast s and compute x * s 15 | s = s.view(s_view) 16 | xs = x * s 17 | 18 | # Broadcast h then compute h: 19 | # out[h_i] += x_i * s_i 20 | h = h.view(s_view).expand(x_size) 21 | 22 | if force_cpu_scatter_add: 23 | out = x.new(*out_size).zero_().cpu() 24 | return out.scatter_add_(-1, h.cpu(), xs.cpu()).cuda() 25 | else: 26 | out = x.new(*out_size).zero_() 27 | return out.scatter_add_(-1, h, xs) 28 | 29 | 30 | def CountSketchFn_backward(h, s, x, x_size, grad_output): 31 | s_view = (1,) * (len(x_size)-1) + (x_size[-1],) 32 | 33 | s = s.view(s_view) 34 | h = h.view(s_view).expand(x_size) 35 | 36 | grad_x = grad_output.gather(-1, h) 37 | grad_x = grad_x * s 38 | 39 | grad_s = torch.sum(grad_output.gather(-1, h) * x, dim=(0,1,2)) 40 | return grad_x, grad_s 41 | 42 | class CountSketchFn(Function): 43 | 44 | @staticmethod 45 | def forward(ctx, h, s, output_size, x, force_cpu_scatter_add=False): 46 | x_size = tuple(x.size()) 47 | 48 | ctx.save_for_backward(h, s, x) 49 | ctx.x_size = tuple(x.size()) 50 | 51 | return CountSketchFn_forward(h, s, output_size, x, force_cpu_scatter_add) 52 | 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | h, s, x = ctx.saved_variables 57 | 58 | grad_x, grad_s = CountSketchFn_backward(h, s, x, ctx.x_size,grad_output) 59 | return None, grad_s, None, grad_x 60 | 61 | class CountSketch(nn.Module): 62 | r"""Compute the count sketch over an input signal. 63 | 64 | .. math:: 65 | 66 | out_j = \sum_{i : j = h_i} s_i x_i 67 | 68 | Args: 69 | input_size (int): Number of channels in the input array 70 | output_size (int): Number of channels in the output sketch 71 | h (array, optional): Optional array of size input_size of indices in the range [0,output_size] 72 | s (array, optional): Optional array of size input_size of -1 and 1. 73 | 74 | .. note:: 75 | 76 | If h and s are None, they will be automatically be generated using LongTensor.random_. 77 | 78 | Shape: 79 | - Input: (...,input_size) 80 | - Output: (...,output_size) 81 | 82 | References: 83 | Yang Gao et al. "Compact Bilinear Pooling" in Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (2016). 84 | Akira Fukui et al. "Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding", arXiv:1606.01847 (2016). 
85 | """ 86 | 87 | def __init__(self, input_size, output_size, h=None, s=None, update_proj=False): 88 | super(CountSketch, self).__init__() 89 | 90 | self.input_size = input_size 91 | self.output_size = output_size 92 | 93 | if h is None: 94 | h = torch.LongTensor(input_size).random_(0, output_size) 95 | if s is None: 96 | s = 2 * torch.Tensor(input_size).random_(0,2) - 1 97 | 98 | # The Variable h being a list of indices, 99 | # If the type of this module is changed (e.g. float to double), 100 | # the variable h should remain a LongTensor 101 | # therefore we force float() and double() to be no-ops on the variable h. 102 | def identity(self): 103 | return self 104 | 105 | h.float = types.MethodType(identity,h) 106 | h.double = types.MethodType(identity,h) 107 | 108 | self.register_buffer('h', h) 109 | if not update_proj: 110 | self.register_buffer('s', s) 111 | else: 112 | # self.register_parameter('s', s) 113 | self.s = nn.Parameter(s) 114 | 115 | def forward(self, x): 116 | x_size = list(x.size()) 117 | 118 | assert(x_size[-1] == self.input_size) 119 | 120 | return CountSketchFn.apply(self.h, self.s, self.output_size, x) 121 | 122 | def ComplexMultiply_forward(X_re, X_im, Y_re, Y_im): 123 | Z_re = torch.addcmul(X_re*Y_re, -1, X_im, Y_im) 124 | Z_im = torch.addcmul(X_re*Y_im, 1, X_im, Y_re) 125 | return Z_re,Z_im 126 | 127 | def ComplexMultiply_backward(X_re, X_im, Y_re, Y_im, grad_Z_re, grad_Z_im): 128 | grad_X_re = torch.addcmul(grad_Z_re * Y_re, 1, grad_Z_im, Y_im) 129 | grad_X_im = torch.addcmul(grad_Z_im * Y_re, -1, grad_Z_re, Y_im) 130 | grad_Y_re = torch.addcmul(grad_Z_re * X_re, 1, grad_Z_im, X_im) 131 | grad_Y_im = torch.addcmul(grad_Z_im * X_re, -1, grad_Z_re, X_im) 132 | return grad_X_re,grad_X_im,grad_Y_re,grad_Y_im 133 | 134 | class ComplexMultiply(torch.autograd.Function): 135 | 136 | @staticmethod 137 | def forward(ctx, X_re, X_im, Y_re, Y_im): 138 | ctx.save_for_backward(X_re,X_im,Y_re,Y_im) 139 | return ComplexMultiply_forward(X_re, X_im, Y_re, Y_im) 140 | 141 | @staticmethod 142 | def backward(ctx,grad_Z_re, grad_Z_im): 143 | X_re,X_im,Y_re,Y_im = ctx.saved_tensors 144 | return ComplexMultiply_backward(X_re,X_im,Y_re,Y_im, grad_Z_re, grad_Z_im) 145 | 146 | class CompactBilinearPoolingFn(Function): 147 | 148 | @staticmethod 149 | def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False): 150 | ctx.save_for_backward(h1,s1,h2,s2,x,y) 151 | ctx.x_size = tuple(x.size()) 152 | ctx.y_size = tuple(y.size()) 153 | ctx.force_cpu_scatter_add = force_cpu_scatter_add 154 | ctx.output_size = output_size 155 | 156 | # Compute the count sketch of each input 157 | px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add) 158 | fx = torch.rfft(px,1) 159 | re_fx = fx.select(-1, 0) 160 | im_fx = fx.select(-1, 1) 161 | del px 162 | py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add) 163 | fy = torch.rfft(py,1) 164 | re_fy = fy.select(-1,0) 165 | im_fy = fy.select(-1,1) 166 | del py 167 | 168 | # Convolution of the two sketch using an FFT. 
169 | # Compute the FFT of each sketch 170 | 171 | 172 | # Complex multiplication 173 | re_prod, im_prod = ComplexMultiply_forward(re_fx,im_fx,re_fy,im_fy) 174 | 175 | # Back to real domain 176 | # The imaginary part should be zero's 177 | re = torch.irfft(torch.stack((re_prod, im_prod), re_prod.dim()), 1, signal_sizes=(output_size,)) 178 | 179 | return re 180 | 181 | @staticmethod 182 | def backward(ctx,grad_output): 183 | h1,s1,h2,s2,x,y = ctx.saved_tensors 184 | 185 | # Recompute part of the forward pass to get the input to the complex product 186 | # Compute the count sketch of each input 187 | px = CountSketchFn_forward(h1, s1, ctx.output_size, x, ctx.force_cpu_scatter_add) 188 | py = CountSketchFn_forward(h2, s2, ctx.output_size, y, ctx.force_cpu_scatter_add) 189 | 190 | # Then convert the output to Fourier domain 191 | grad_output = grad_output.contiguous() 192 | grad_prod = torch.rfft(grad_output, 1) 193 | grad_re_prod = grad_prod.select(-1, 0) 194 | grad_im_prod = grad_prod.select(-1, 1) 195 | 196 | # Compute the gradient of x first then y 197 | 198 | # Gradient of x 199 | # Recompute fy 200 | fy = torch.rfft(py,1) 201 | re_fy = fy.select(-1,0) 202 | im_fy = fy.select(-1,1) 203 | del py 204 | # Compute the gradient of fx, then back to temporal space 205 | grad_re_fx = torch.addcmul(grad_re_prod * re_fy, 1, grad_im_prod, im_fy) 206 | grad_im_fx = torch.addcmul(grad_im_prod * re_fy, -1, grad_re_prod, im_fy) 207 | grad_fx = torch.irfft(torch.stack((grad_re_fx,grad_im_fx), grad_re_fx.dim()), 1, signal_sizes=(ctx.output_size,)) 208 | # Finally compute the gradient of x 209 | grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx) 210 | del re_fy,im_fy,grad_re_fx,grad_im_fx,grad_fx 211 | 212 | # Gradient of y 213 | # Recompute fx 214 | fx = torch.rfft(px,1) 215 | re_fx = fx.select(-1, 0) 216 | im_fx = fx.select(-1, 1) 217 | del px 218 | # Compute the gradient of fy, then back to temporal space 219 | grad_re_fy = torch.addcmul(grad_re_prod * re_fx, 1, grad_im_prod, im_fx) 220 | grad_im_fy = torch.addcmul(grad_im_prod * re_fx, -1, grad_re_prod, im_fx) 221 | grad_fy = torch.irfft(torch.stack((grad_re_fy,grad_im_fy), grad_re_fy.dim()), 1, signal_sizes=(ctx.output_size,)) 222 | # Finally compute the gradient of y 223 | grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy) 224 | del re_fx,im_fx,grad_re_fy,grad_im_fy,grad_fy 225 | 226 | return None, None, None, None, None, grad_x, grad_y, None 227 | 228 | class CompactBilinearPooling(nn.Module): 229 | r"""Compute the compact bilinear pooling between two input array x and y 230 | 231 | .. math:: 232 | 233 | out = \Psi (x,h_1,s_1) \ast \Psi (y,h_2,s_2) 234 | 235 | Args: 236 | input_size1 (int): Number of channels in the first input array 237 | input_size2 (int): Number of channels in the second input array 238 | output_size (int): Number of channels in the output array 239 | h1 (array, optional): Optional array of size input_size of indices in the range [0,output_size] 240 | s1 (array, optional): Optional array of size input_size of -1 and 1. 241 | h2 (array, optional): Optional array of size input_size of indices in the range [0,output_size] 242 | s2 (array, optional): Optional array of size input_size of -1 and 1. 243 | force_cpu_scatter_add (boolean, optional): Force the scatter_add operation to run on CPU for testing purposes 244 | 245 | .. note:: 246 | 247 | If h1, s1, s2, h2 are None, they will be automatically be generated using LongTensor.random_. 
248 | 249 | Shape: 250 | - Input 1: (...,input_size1) 251 | - Input 2: (...,input_size2) 252 | - Output: (...,output_size) 253 | 254 | References: 255 | Yang Gao et al. "Compact Bilinear Pooling" in Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (2016). 256 | Akira Fukui et al. "Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding", arXiv:1606.01847 (2016). 257 | """ 258 | def __init__(self, input1_size, input2_size, output_size, h1 = None, s1 = None, h2 = None, s2 = None, force_cpu_scatter_add=False): 259 | super(CompactBilinearPooling, self).__init__() 260 | self.add_module('sketch1', CountSketch(input1_size, output_size, h1, s1)) 261 | self.add_module('sketch2', CountSketch(input2_size, output_size, h2, s2)) 262 | self.output_size = output_size 263 | self.force_cpu_scatter_add = force_cpu_scatter_add 264 | 265 | def forward(self, x, y = None): 266 | if y is None: 267 | y = x 268 | 269 | return CompactBilinearPoolingFn.apply(self.sketch1.h, self.sketch1.s, self.sketch2.h, self.sketch2.s, self.output_size, x, y, self.force_cpu_scatter_add) 270 | 271 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import copy 3 | from shutil import copytree 4 | import os 5 | import time 6 | 7 | 8 | dset_root = {} 9 | dset_root['cub'] = '/scratch1/tsungyulin/dataset/cub' 10 | dset_root['cars'] = '/scratch1/tsungyulin/dataset/cars' 11 | dset_root['aircrafts'] = '/scratch1/tsungyulin/dataset/fgvc-aircraft-2013b' 12 | dset_root['inat'] = '/scratch1/tsungyulin/dataset/inat_2018_448' 13 | dset_root['mit_indoor'] = '/scratch1/tsungyulin/dataset/mit_indoor' 14 | 15 | test_code = False 16 | 17 | 18 | if 'node' in socket.gethostname() or test_code: 19 | nfs_dset = copy.deepcopy(dset_root) 20 | if test_code: 21 | local_path = os.path.join(os.getenv("HOME"), 'my_local_test') 22 | else: 23 | local_path = '/local/image_datasets' 24 | if not os.path.isdir(local_path): 25 | os.makedirs(local_path) 26 | for x in dset_root.items(): 27 | folder_name = os.path.basename(x[1]) 28 | dset_root[x[0]] = os.path.join(local_path, folder_name) 29 | 30 | def wait_dataset_copy_finish(dataset): 31 | flag_file = os.path.join(dset_root[dataset] + '_flag', 32 | 'flag_ready.txt') 33 | while True: 34 | with open(flag_file, 'r') as f: 35 | status = f.readline() 36 | if status == 'True': 37 | break 38 | time.sleep(600) 39 | 40 | 41 | def setup_dataset(dataset): 42 | my_tmp = os.path.join(os.getenv("HOME"), 'tmp') 43 | if not os.path.isdir(my_tmp): 44 | os.makedirs(my_tmp) 45 | os.environ["TMPDIR"] = my_tmp 46 | if 'node' in socket.gethostname(): 47 | if not os.path.isdir(dset_root[dataset]): 48 | if os.path.isdir(os.path.join(dset_root[dataset] + '_flag')): 49 | wait_dataset_copy_finish(dataset) 50 | else: 51 | gypsum_copy_data_to_local(dataset) 52 | else: 53 | wait_dataset_copy_finish(dataset) 54 | 55 | 56 | def gypsum_copy_data_to_local(dataset): 57 | flag_file = os.path.join(dset_root[dataset] + '_flag', 'flag_ready.txt') 58 | 59 | os.makedirs(dset_root[dataset] + '_flag') 60 | with open(flag_file, 'w') as f: 61 | f.write('False') 62 | if test_code: 63 | import pdb 64 | pdb.set_trace() 65 | pass 66 | copytree(nfs_dset[dataset], dset_root[dataset]) 67 | 68 | if test_code: 69 | pdb.set_trace() 70 | pass 71 | with open(flag_file, 'w') as f: 72 | f.write('True') 73 | 74 | 
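A minimal usage sketch for config.py above (hedged: the dset_root paths are
site-specific, setup_dataset performs the copy-to-local-disk handshake only on
hosts whose name contains 'node', and 'cub' is just an example key):

    from config import dset_root, setup_dataset

    setup_dataset('cub')           # stage the dataset locally on cluster nodes
    data_dir = dset_root['cub']    # then read images from this root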
--------------------------------------------------------------------------------
/feature_extractor.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch
5 | 
6 | class ResNet(nn.Module):
7 |     def __init__(self):
8 |         super(ResNet, self).__init__()
9 | 
10 |         self.model = models.resnet101(pretrained=True)
11 |         self.input_size = 224
12 |         delattr(self.model, 'fc')
13 |         delattr(self.model, 'avgpool')
14 | 
15 |     def forward(self, x):
16 |         x = self.model.conv1(x)
17 |         x = self.model.bn1(x)
18 |         x = self.model.relu(x)
19 |         x = self.model.maxpool(x)
20 | 
21 |         x = self.model.layer1(x)
22 |         x = self.model.layer2(x)
23 |         x = self.model.layer3(x)
24 |         x = self.model.layer4(x)
25 | 
26 |         return x
27 | 
28 | class AlexNet(nn.Module):
29 |     def __init__(self):
30 |         super(AlexNet, self).__init__()
31 |         self.model = models.alexnet(pretrained=True)
32 |         self.input_size = 224
33 | 
34 |     def forward(self, x):
35 |         x = self.model.features(x)
36 |         return x
37 | 
38 | class VGG(nn.Module):
39 |     def __init__(self):
40 |         super(VGG, self).__init__()
41 |         # self.model = models.vgg16(pretrained=True)
42 |         self.model = models.vgg16(pretrained=True).features
43 |         self.model = torch.nn.Sequential(*list(self.model.children())[:-1])
44 |         self.input_size = 224
45 | 
46 |     def forward(self, x):
47 |         x = self.model(x)
48 |         # x = self.model.features(x)
49 |         return x
50 | 
51 | class DenseNet(nn.Module):
52 |     def __init__(self):
53 |         super(DenseNet, self).__init__()
54 |         self.model = models.densenet201(pretrained=True)
55 |         self.input_size = 224
56 | 
57 |     def forward(self, x):
58 |         x = self.model.features(x)
59 |         x = F.relu(x, inplace=True)
60 |         return x
61 | 
62 | class Inception(nn.Module):
63 |     def __init__(self):
64 |         super(Inception, self).__init__()
65 |         self.model = models.inception_v3(pretrained=True)
66 |         self.input_size = 299
67 | 
68 |     def forward(self, x):
69 |         if self.model.transform_input:
70 |             '''
71 |             x = x.clone()
72 |             x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
73 |             x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
74 |             x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
75 |             '''
76 |             x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
77 |             x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
78 |             x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
79 |             x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
80 | 
81 |         x = self.model.Conv2d_1a_3x3(x)
82 |         x = self.model.Conv2d_2a_3x3(x)
83 |         x = self.model.Conv2d_2b_3x3(x)
84 |         x = F.max_pool2d(x, kernel_size=3, stride=2)
85 |         x = self.model.Conv2d_3b_1x1(x)
86 |         x = self.model.Conv2d_4a_3x3(x)
87 |         x = F.max_pool2d(x, kernel_size=3, stride=2)
88 |         x = self.model.Mixed_5b(x)
89 |         x = self.model.Mixed_5c(x)
90 |         x = self.model.Mixed_5d(x)
91 |         x = self.model.Mixed_6a(x)
92 |         x = self.model.Mixed_6b(x)
93 |         x = self.model.Mixed_6c(x)
94 |         x = self.model.Mixed_6d(x)
95 |         x = self.model.Mixed_6e(x)
96 |         x = self.model.Mixed_7a(x)
97 |         x = self.model.Mixed_7b(x)
98 |         x = self.model.Mixed_7c(x)
99 | 
100 |         return x
101 | 
102 | # Conv1 features are not returned
103 | class VGG_all_conv_features(nn.Module):
104 |     def __init__(self):
105 |         super(VGG_all_conv_features, self).__init__()
106 |         # default ceil_mode for MaxPool2d is False; not sure if I should change it
107 |         # to True
108 |         vgg_pretrained = models.vgg16(pretrained=True)
109 |         # add -1 to the index to remove the pooling layer
110 |         self.block1 = nn.Sequential(*list(vgg_pretrained.features.children())[:5-1])
111 |         self.block2 = nn.Sequential(*list(vgg_pretrained.features.children())[5:10-1])
112 |         self.block3 = nn.Sequential(*list(vgg_pretrained.features.children())[10:17-1])
113 |         self.block4 = nn.Sequential(*list(vgg_pretrained.features.children())[17:24-1])
114 |         self.block5 = nn.Sequential(*list(vgg_pretrained.features.children())[24:-1])
115 | 
116 |         self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
117 | 
118 |     def get_feature_dims(self):
119 |         return (128, 256, 512, 512)
120 | 
121 |     def forward(self, x):
122 |         x1 = self.block1(x)
123 |         x1_ = self.pooling(x1)
124 | 
125 |         x2 = self.block2(x1_)
126 |         x2_ = self.pooling(x2)
127 | 
128 |         x3 = self.block3(x2_)
129 |         x3_ = self.pooling(x3)
130 | 
131 |         x4 = self.block4(x3_)
132 |         x4_ = self.pooling(x4)
133 | 
134 |         x5 = self.block5(x4_)
135 | 
136 |         return x2, x3, x4, x5
137 | 
--------------------------------------------------------------------------------
/find_best_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | 
5 | def main(args):
6 |     folder_path = os.path.join('..', 'exp', args.dataset)
7 |     res = re.compile(r'Best val accuracy: ([.\d]+)')
8 |     best_acc = 0
9 |     best_model = None
10 | 
11 |     for folder in os.listdir(folder_path):
12 |         train_hist_file = os.path.join(folder_path, folder, 'train_history.txt')
13 |         if not os.path.isfile(train_hist_file):
14 |             continue  # skip folders without a training history
15 |         with open(train_hist_file, 'r') as f:
16 |             lines = f.readlines()
17 |         for l in lines:
18 |             if 'Best val accuracy:' in l:
19 |                 m = res.match(l)
20 |                 acc = float(m.groups()[0])
21 |                 if best_acc < acc:
22 |                     best_acc = acc
23 |                     best_model = folder
24 | 
25 |     print('Best model folder: %s \n Best accuracy: %f\n'%(best_model, best_acc))
26 | 
27 | if __name__ == '__main__':
28 |     parser = argparse.ArgumentParser()
29 |     parser.add_argument('--dataset', default='cub', type=str)
30 |     args = parser.parse_args()
31 | 
32 |     main(args)
33 | 
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import numpy as np
5 | import torchvision
6 | from torchvision import datasets, models, transforms
7 | import os
8 | from config import dset_root
9 | import random
10 | import argparse
11 | import copy
12 | import logging
13 | import sys
14 | import time
15 | import shutil
16 | from tensorboardX import SummaryWriter
17 | from stn import STNet
18 | 
19 | def initializeLogging(log_filename, logger_name):
20 |     log = logging.getLogger(logger_name)
21 |     log.setLevel(logging.DEBUG)
22 |     log.addHandler(logging.StreamHandler(sys.stdout))
23 |     log.addHandler(logging.FileHandler(log_filename, mode='a'))
24 | 
25 |     return log
26 | 
27 | def save_checkpoint(state, is_best, checkpoint_folder='exp',
28 |         filename='checkpoint.pth.tar'):
29 |     filename = os.path.join(checkpoint_folder, filename)
30 |     best_model_filename = os.path.join(checkpoint_folder, 'model_best.pth.tar')
31 |     torch.save(state, filename)
32 |     if is_best:
33 |         shutil.copyfile(filename, best_model_filename)
34 | 
35 | def set_parameter_requires_grad(model, feature_extract):
36 |     if feature_extract:
37 |         for param in model.parameters():
38 |             param.requires_grad = False
39 | 
40 | def initialize_optimizer(model_ft, feature_extract=False, stn=False):
41 |     params_to_update = 
model_ft.parameters() 42 | if feature_extract: 43 | params_to_update = [] 44 | for name,param in model_ft.named_parameters(): 45 | if param.requires_grad == True: 46 | params_to_update.append(param) 47 | 48 | # Observe that all parameters are being optimized 49 | # optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9) 50 | # optimizer_ft = optim.Adam(params_to_update, lr=1e-4, weight_decay=0, betas=(0.9, 0.999)) 51 | if stn is False: 52 | optimizer_ft = optim.Adam(params_to_update, lr=1e-4, weight_decay=0, betas=(0.9, 0.999)) 53 | else: 54 | params_to_update = [] 55 | # params_to_update_name = [] 56 | for name,param in model_ft.model_ft.named_parameters(): 57 | if param.requires_grad == True: 58 | params_to_update.append(param) 59 | # params_to_update_name.append(name) 60 | params_to_update_stn = [] 61 | # params_to_update_stn_name = [] 62 | for name,param in model_ft.fc_loc.named_parameters(): 63 | if param.requires_grad == True: 64 | params_to_update_stn.append(param) 65 | # params_to_update_stn_name.append(name) 66 | for name,param in model_ft.localization.named_parameters(): 67 | if param.requires_grad == True: 68 | params_to_update_stn.append(param) 69 | # params_to_update_stn_name.append(name) 70 | 71 | optimizer_ft = optim.Adam([ {'params':params_to_update}, 72 | {'params':params_to_update_stn, 'lr':1e-8, 'weight_decay':1e-5}], 73 | lr=1e-4, weight_decay=0, betas=(0.9, 0.999)) 74 | return optimizer_ft 75 | 76 | def initialize_model(model_name, num_classes, feature_extract=False, 77 | use_pretrained=True): 78 | model_ft = None 79 | input_size = 0 80 | 81 | if model_name == "resnet": 82 | """ Resnet101 83 | """ 84 | model_ft = models.resnet101(pretrained=use_pretrained) 85 | set_parameter_requires_grad(model_ft, feature_extract) 86 | num_ftrs = model_ft.fc.in_features 87 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 88 | input_size = 224 89 | 90 | elif model_name == "resnet50": 91 | """ Resnet50 92 | """ 93 | model_ft = models.resnet50(pretrained=use_pretrained) 94 | set_parameter_requires_grad(model_ft, feature_extract) 95 | num_ftrs = model_ft.fc.in_features 96 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 97 | input_size = 224 98 | 99 | elif model_name == "alexnet": 100 | """ Alexnet 101 | """ 102 | model_ft = models.alexnet(pretrained=use_pretrained) 103 | set_parameter_requires_grad(model_ft, feature_extract) 104 | num_ftrs = model_ft.classifier[6].in_features 105 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 106 | input_size = 224 107 | 108 | elif model_name == "vgg": 109 | """ VGG11_bn 110 | """ 111 | model_ft = models.vgg11_bn(pretrained=use_pretrained) 112 | set_parameter_requires_grad(model_ft, feature_extract) 113 | num_ftrs = model_ft.classifier[6].in_features 114 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 115 | input_size = 224 116 | 117 | elif model_name == "squeezenet": 118 | """ Squeezenet 119 | """ 120 | model_ft = models.squeezenet1_0(pretrained=use_pretrained) 121 | set_parameter_requires_grad(model_ft, feature_extract) 122 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 123 | model_ft.num_classes = num_classes 124 | input_size = 224 125 | 126 | elif model_name == "densenet": 127 | """ Densenet 128 | """ 129 | model_ft = models.densenet201(pretrained=use_pretrained) 130 | set_parameter_requires_grad(model_ft, feature_extract) 131 | num_ftrs = model_ft.classifier.in_features 132 | model_ft.classifier = nn.Linear(num_ftrs, num_classes) 133 | input_size = 224 134 | 135 | elif 
model_name == "inception": 136 | """ Inception v3 137 | Be careful, expects (299,299) sized images and has auxiliary output 138 | """ 139 | model_ft = models.inception_v3(pretrained=use_pretrained) 140 | set_parameter_requires_grad(model_ft, feature_extract) 141 | # Handle the auxilary net 142 | num_ftrs = model_ft.AuxLogits.fc.in_features 143 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes) 144 | # Handle the primary net 145 | num_ftrs = model_ft.fc.in_features 146 | model_ft.fc = nn.Linear(num_ftrs,num_classes) 147 | input_size = 299 148 | 149 | else: 150 | # print("Invalid model name, exiting...") 151 | logger.debug("Invalid mode name") 152 | exit() 153 | 154 | return model_ft, input_size 155 | 156 | def train_model(model, dataloaders, criterion, optimizer, num_epochs=35, 157 | is_inception=False, logger_name='train_logger', checkpoint_folder='exp', 158 | start_epoch=0, writer=None): 159 | 160 | logger = logging.getLogger(logger_name) 161 | 162 | device = next(model.parameters()).device 163 | since = time.time() 164 | 165 | val_acc_history = [] 166 | 167 | best_model_wts = copy.deepcopy(model.state_dict()) 168 | best_acc = 0.0 169 | 170 | for epoch in range(start_epoch, num_epochs): 171 | logger.info('Epoch {}/{}'.format(epoch + 1, num_epochs)) 172 | logger.info('-' * 10) 173 | # print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 174 | # print('-' * 10) 175 | 176 | # Each epoch has a training and validation phase 177 | for phase in ['train', 'val']: 178 | if phase == 'train': 179 | model.train() # Set model to training mode 180 | else: 181 | model.eval() # Set model to evaluate mode 182 | 183 | running_loss = 0.0 184 | running_corrects = 0 185 | 186 | # Iterate over data. 187 | for inputs, labels, _ in dataloaders[phase]: 188 | inputs = inputs.to(device) 189 | labels = labels.to(device) 190 | 191 | # zero the parameter gradients 192 | optimizer.zero_grad() 193 | 194 | # forward 195 | # track history if only in train 196 | with torch.set_grad_enabled(phase == 'train'): 197 | # Get model outputs and calculate loss 198 | # Special case for inception because in training it has an auxiliary output. In train 199 | # mode we calculate the loss by summing the final output and the auxiliary output 200 | # but in testing we only consider the final output. 
201 |                     if is_inception and phase == 'train':
202 |                         outputs, aux_outputs = model(inputs)
203 |                         loss1 = criterion(outputs, labels)
204 |                         loss2 = criterion(aux_outputs, labels)
205 |                         loss = loss1 + 0.4*loss2
206 |                     else:
207 |                         outputs = model(inputs)
208 |                         loss = criterion(outputs, labels)
209 | 
210 |                     _, preds = torch.max(outputs, 1)
211 | 
212 |                 # backward + optimize only if in training phase
213 |                 if phase == 'train':
214 |                     loss.backward()
215 |                     optimizer.step()
216 | 
217 |                 # statistics
218 |                 running_loss += loss.item() * inputs.size(0)
219 |                 running_corrects += torch.sum(preds == labels.data)
220 | 
221 |             epoch_loss = running_loss / len(dataloaders[phase].dataset)
222 |             epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
223 | 
224 |             # print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
225 |             logger.info('{} Loss: {:.4f} Acc: {:.2f}'.format(phase, epoch_loss, epoch_acc*100))
226 | 
227 |             writer.add_scalar(phase+'/loss', epoch_loss, epoch+1)
228 |             writer.add_scalar(phase+'/acc', epoch_acc*100, epoch+1)
229 | 
230 |             # deep copy the model
231 |             is_best = epoch_acc > best_acc
232 |             if phase == 'val' and epoch_acc > best_acc:
233 |                 best_acc = epoch_acc
234 |                 best_model_wts = copy.deepcopy(model.state_dict())
235 |             if phase == 'val':
236 |                 val_acc_history.append(epoch_acc)
237 | 
238 |         save_checkpoint({
239 |             'epoch': epoch + 1,
240 |             'model': args.model,
241 |             'state_dict': model.state_dict(),
242 |             'best_acc': best_acc,
243 |             'optimizer': optimizer.state_dict(),
244 |         }, is_best, checkpoint_folder=checkpoint_folder)
245 | 
246 |         if epoch > 0 and (epoch+1) % 15 == 0:
247 |             for param_group in optimizer.param_groups:
248 |                 param_group['lr'] = param_group['lr']*0.5
249 | 
250 | 
251 |     time_elapsed = time.time() - since
252 |     logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
253 |     logger.info('Best val Acc: {:4f}'.format(best_acc))
254 |     # print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
255 |     # print('Best val Acc: {:4f}'.format(best_acc))
256 | 
257 |     # load best model weights
258 |     model.load_state_dict(best_model_wts)
259 | 
260 |     writer.close()
261 |     return model, val_acc_history
262 | 
263 | def main(args):
264 | 
265 |     log_dir = args.exp_dir+'/log'
266 |     if os.path.exists(log_dir):
267 |         shutil.rmtree(log_dir)
268 |     writer = SummaryWriter(log_dir)
269 | 
270 |     batch_size = 32
271 |     maxIter = 10000
272 |     split = 'val'
273 |     input_size = 224
274 | 
275 |     if not os.path.isdir(args.exp_dir):
276 |         os.makedirs(args.exp_dir)
277 |     if not os.path.isdir(os.path.join(args.exp_dir, args.task)):
278 |         os.makedirs(os.path.join(args.exp_dir, args.task))
279 |     checkpoint_folder = os.path.join(args.exp_dir, args.task, 'checkpoints')
280 |     if not os.path.isdir(checkpoint_folder):
281 |         os.makedirs(checkpoint_folder)
282 | 
283 |     logger_name = 'train_logger'
284 |     logger = initializeLogging(os.path.join(args.exp_dir, args.task,
285 |         'train_history.txt'), logger_name)
286 | 
287 |     # ================== Create data loader ==================================
288 |     data_transforms = {
289 |         'train': transforms.Compose([
290 |             transforms.RandomResizedCrop(input_size),
291 |             transforms.RandomHorizontalFlip(),
292 |             transforms.ToTensor(),
293 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
294 |         ]),
295 |         'val': transforms.Compose([
296 |             transforms.Resize(input_size),
297 |             transforms.CenterCrop(input_size),
298 |             transforms.ToTensor(),
299 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
300 |         ])
301 |     }
302 | 
303 |     if args.task == 'cub':
304 |         from CUBDataset import CUBDataset
305 |         image_datasets = {split: CUBDataset(dset_root['cub'], split,
306 |             create_val=True, transform=data_transforms[split]) \
307 |             for split in ['train', 'val']}
308 |     elif args.task == 'cars':
309 |         from CarsDataset import CarsDataset
310 |         image_datasets = {split: CarsDataset(dset_root['cars'], split,
311 |             create_val=True, transform=data_transforms[split]) \
312 |             for split in ['train', 'val']}
313 |     elif args.task == 'aircrafts':
314 |         from AircraftsDataset import AircraftsDataset
315 |         image_datasets = {split: AircraftsDataset(dset_root['aircrafts'], split,
316 |             transform=data_transforms[split]) \
317 |             for split in ['train', 'val']}
318 |     elif args.task[:len('inat_')] == 'inat_':
319 |         from iNatDataset import iNatDataset
320 |         task = args.task
321 |         subtask = task[len('inat_'):]
322 |         subtask = subtask[0].upper() + subtask[1:]
323 |         image_datasets = {split: iNatDataset(dset_root['inat'], split, subtask,
324 |             transform=data_transforms[split]) \
325 |             for split in ['train', 'val']}
326 |     else:
327 |         raise ValueError('Unknown dataset: %s' % args.task)
328 | 
329 | 
330 |     num_classes = image_datasets['train'].get_num_classes()
331 | 
332 |     dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x],
333 |         batch_size=args.batch_size, shuffle=True, num_workers=4) \
334 |         for x in ['train', 'val']}
335 | 
336 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
337 | 
338 |     #======================= Initialize the model =============================
339 |     model_ft, input_size = initialize_model(args.model, num_classes,
340 |         feature_extract=False, use_pretrained=True)
341 |     if args.stn:
342 |         model_ft = STNet(model_ft)
343 |     model_ft = model_ft.to(device)
344 | 
345 |     #====================== Initialize optimizer ==============================
346 |     optim = initialize_optimizer(model_ft, feature_extract=False, stn=args.stn)
347 | 
348 |     # Setup the loss fxn
349 |     criterion = nn.CrossEntropyLoss()
350 | 
351 |     start_epoch = 0
352 |     # load from checkpoint if it exists
353 |     if not args.train_from_beginning:
354 |         checkpoint_filename = os.path.join(checkpoint_folder,
355 |             'checkpoint.pth.tar')
356 |         if os.path.isfile(checkpoint_filename):
357 |             print("=> loading checkpoint '{}'".format(checkpoint_filename))
358 |             checkpoint = torch.load(checkpoint_filename)
359 |             start_epoch = checkpoint['epoch']
360 |             best_acc = checkpoint['best_acc']
361 |             model_ft.load_state_dict(checkpoint['state_dict'])
362 |             optim.load_state_dict(checkpoint['optimizer'])
363 |             print("=> loaded checkpoint '{}' (epoch {})"
364 |                 .format(checkpoint_filename, checkpoint['epoch']))
365 | 
366 |     # parallelize the model if using multiple gpus
367 |     if torch.cuda.device_count() > 1:
368 |         model_ft = torch.nn.DataParallel(model_ft)
369 | 
370 |     # Train the model
371 |     model_ft = train_model(model_ft, dataloaders_dict, criterion, optim,
372 |         num_epochs=args.num_epochs, is_inception=(args.model=="inception"),
373 |         logger_name=logger_name, checkpoint_folder=checkpoint_folder,
374 |         start_epoch=start_epoch, writer=writer)
375 | 
376 | if __name__ == '__main__':
377 |     parser = argparse.ArgumentParser()
378 |     parser.add_argument('--task', default='cub', type=str,
379 |         help='the name of the task|dataset')
380 |     parser.add_argument('--model', default='resnet50', type=str,
381 |         help='resnet|densenet')
382 |     parser.add_argument('--batch_size', default=32, type=int,
383 |         help='size of mini-batch')
384 |     parser.add_argument('--num_epochs', default=35, type=int,
385 |         help='number of epochs')
386 |     parser.add_argument('--exp_dir', default='exp', type=str,
387 |         help='path to the checkpoint folder for the experiment')
388 |     parser.add_argument('--train_from_beginning', action='store_true',
389 |         help='train the model from the first epoch, i.e. ignore the checkpoint')
390 |     parser.add_argument('--stn', dest='stn', action='store_true',
391 |         help='use STN')
392 |     args = parser.parse_args()
393 |     main(args)
394 | 
395 | 
396 | 
397 | 
--------------------------------------------------------------------------------
/iNatDataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import numpy as np
4 | import os
5 | from torchvision.datasets import folder as dataset_parser
6 | import json
7 | 
8 | def make_dataset(dataset_root, split, subset=None):
9 | 
10 |     with open(os.path.join(dataset_root, '%s2018.json'%split)) as f:
11 |         data = json.load(f)
12 |     if split != 'test':
13 |         if subset is not None:
14 |             # select the images whose annotations belong to one of the classes in subset
15 |             data['categories'] = [x for x in data['categories'] \
16 |                 if x['supercategory'] == subset]
17 |             subset_cid = [x['id'] for x in data['categories']]
18 |             select_images = [(data['images'][idx], x) \
19 |                 for idx, x in enumerate(data['annotations']) \
20 |                 if x['category_id'] in subset_cid]
21 |             data['images'], data['annotations'] = zip(*select_images)
22 | 
23 |             # re-index the categories
24 |             cls_mapping = {x['id']: idx \
25 |                 for idx, x in enumerate(data['categories'])}
26 |             for idx, x in enumerate(data['categories']):
27 |                 data['categories'][idx]['id'] = cls_mapping[x['id']]
28 |             for idx, x in enumerate(data['annotations']):
29 |                 data['annotations'][idx]['category_id'] = \
30 |                     cls_mapping[x['category_id']]
31 | 
32 |         num_classes = len(data['categories'])
33 |         img = [(im['file_name'], annot['category_id']) \
34 |             for im, annot in zip(data['images'], data['annotations'])]
35 |         classes = [x['name'] for x in data['categories']]
36 |     else:
37 |         num_classes = -1
38 |         img = [(im['file_name'], -1) for im in data['images']]
39 |         classes = []
40 | 
41 |     return img, num_classes, classes
42 | 
43 | 
44 | class iNatDataset(data.Dataset):
45 |     def __init__(self, dataset_root, split, subset=None, transform=None,
46 |             target_transform=None, loader=dataset_parser.default_loader):
47 | 
48 |         assert subset in ['Plantae', 'Insecta', 'Aves', \
49 |             'Actinopterygii', 'Fungi', 'Reptilia', 'Mollusca', 'Mammalia', \
50 |             'Animalia', 'Amphibia', 'Arachnida', None]
51 |         self.subset = subset
52 | 
53 |         self.loader = loader
54 |         self.dataset_root = dataset_root
55 | 
56 |         if split == 'train_val':
57 |             self.imgs, self.num_classes, self.classes = make_dataset(
58 |                 self.dataset_root, 'train', subset)
59 |             self.imgs2, _, _ = make_dataset(self.dataset_root, 'val', subset)
60 |             self.imgs = self.imgs + self.imgs2
61 |         else:
62 |             self.imgs, self.num_classes, self.classes = make_dataset(
63 |                 self.dataset_root, split, subset)
64 |         self.transform = transform
65 |         self.target_transform = target_transform
66 |         self.dataset_root = dataset_root
67 | 
68 |     def __getitem__(self, index):
69 |         path, target = self.imgs[index]
70 |         path = os.path.join(self.dataset_root, path)
71 |         img = self.loader(path)
72 |         if self.transform is not None:
73 |             img = [x(img) for x in self.transform]
74 | 
75 |         if self.target_transform is not None:
76 |             target = self.target_transform(target)
77 | 
78 |         return (*img,
target, path) 79 | 80 | def __len__(self): 81 | return len(self.imgs) 82 | 83 | def get_num_classes(self): 84 | return self.num_classes 85 | -------------------------------------------------------------------------------- /inv_images/001.Black_footed_Albatross.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/001.Black_footed_Albatross.png -------------------------------------------------------------------------------- /inv_images/002.Laysan_Albatross.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/002.Laysan_Albatross.png -------------------------------------------------------------------------------- /inv_images/005.Crested_Auklet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/005.Crested_Auklet.png -------------------------------------------------------------------------------- /inv_images/006.Least_Auklet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/006.Least_Auklet.png -------------------------------------------------------------------------------- /inv_images/007.Parakeet_Auklet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/007.Parakeet_Auklet.png -------------------------------------------------------------------------------- /inv_images/009.Brewer_Blackbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/009.Brewer_Blackbird.png -------------------------------------------------------------------------------- /inv_images/010.Red_winged_Blackbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/010.Red_winged_Blackbird.png -------------------------------------------------------------------------------- /inv_images/011.Rusty_Blackbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/011.Rusty_Blackbird.png -------------------------------------------------------------------------------- /inv_images/012.Yellow_headed_Blackbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/012.Yellow_headed_Blackbird.png -------------------------------------------------------------------------------- /inv_images/013.Bobolink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/013.Bobolink.png 
-------------------------------------------------------------------------------- /inv_images/014.Indigo_Bunting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/014.Indigo_Bunting.png -------------------------------------------------------------------------------- /inv_images/015.Lazuli_Bunting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/015.Lazuli_Bunting.png -------------------------------------------------------------------------------- /inv_images/016.Painted_Bunting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/016.Painted_Bunting.png -------------------------------------------------------------------------------- /inv_images/017.Cardinal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/017.Cardinal.png -------------------------------------------------------------------------------- /inv_images/018.Spotted_Catbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/018.Spotted_Catbird.png -------------------------------------------------------------------------------- /inv_images/019.Gray_Catbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/019.Gray_Catbird.png -------------------------------------------------------------------------------- /inv_images/020.Yellow_breasted_Chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/020.Yellow_breasted_Chat.png -------------------------------------------------------------------------------- /inv_images/024.Red_faced_Cormorant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/024.Red_faced_Cormorant.png -------------------------------------------------------------------------------- /inv_images/025.Pelagic_Cormorant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/025.Pelagic_Cormorant.png -------------------------------------------------------------------------------- /inv_images/026.Bronzed_Cowbird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/026.Bronzed_Cowbird.png -------------------------------------------------------------------------------- /inv_images/027.Shiny_Cowbird.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/027.Shiny_Cowbird.png -------------------------------------------------------------------------------- /inv_images/028.Brown_Creeper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/028.Brown_Creeper.png -------------------------------------------------------------------------------- /inv_images/029.American_Crow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/029.American_Crow.png -------------------------------------------------------------------------------- /inv_images/030.Fish_Crow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvl-umass/bilinear-cnn/d6a05aa200774c97b38f753feed9032aa3764cce/inv_images/030.Fish_Crow.png -------------------------------------------------------------------------------- /inversion.py: -------------------------------------------------------------------------------- 1 | from config import dset_root, setup_dataset 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import numpy as np 6 | import torchvision 7 | from torchvision import datasets, models, transforms 8 | import os 9 | import argparse 10 | import sys 11 | from BCNN import create_multi_heads_bcnn 12 | import json 13 | import logging 14 | import copy 15 | import shutil 16 | import scipy.misc 17 | 18 | def initializeLogging(log_filename, logger_name): 19 | log = logging.getLogger(logger_name) 20 | log.setLevel(logging.DEBUG) 21 | log.addHandler(logging.StreamHandler(sys.stdout)) 22 | log.addHandler(logging.FileHandler(log_filename, mode='a')) 23 | 24 | return log 25 | 26 | def train_model(model, dset_loader, criterion, 27 | optimizer, batch_size_update=256, 28 | # maxItr=50000, logger_name='train_logger', checkpoint_folder='exp', 29 | epoch=45, logger_name='train_logger', checkpoint_folder='exp', 30 | start_itr=0, clip_grad=-1, scheduler=None, fine_tune=True): 31 | 32 | maxItr = epoch * len(dset_loader['train'].dataset) // \ 33 | dset_loader['train'].batch_size + 1 34 | 35 | val_every_number_examples = max(10000, 36 | len(dset_loader['train'].dataset) // 5) 37 | val_frequency = val_every_number_examples // dset_loader['train'].batch_size 38 | checkpoint_frequency = 5 * len(dset_loader['train'].dataset) / \ 39 | dset_loader['train'].batch_size 40 | last_checkpoint = start_itr - 1 41 | logger = logging.getLogger(logger_name) 42 | 43 | device = next(model.parameters()).device 44 | 45 | running_num_data = 0 46 | # Train the fc classifier for the features from 4 layers 47 | # {relu2_2, relu3_3, relur4_3, relu5_3} 48 | running_loss = [0.0] * 4 49 | running_corrects = [0] * 4 50 | 51 | best_acc = [0.0] * 4 52 | 53 | dset_iter = {x:iter(dset_loader[x]) for x in ['train', 'val']} 54 | bs = dset_loader['train'].batch_size 55 | update_frequency = batch_size_update // bs 56 | 57 | model.module.fc_list.train() 58 | 59 | last_epoch = 0 60 | for itr in range(start_itr, maxItr): 61 | # at the end of validation set model.train() 62 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 63 | logger.info('Iteration {}/{}'.format(itr, maxItr - 1)) 64 | logger.info('-' * 10) 65 | 66 | try: 67 | all_fields = 
next(dset_iter['train']) 68 | labels = all_fields[-2] 69 | inputs = all_fields[:-2] 70 | # inputs, labels, _ = next(dset_iter['train']) 71 | except StopIteration: 72 | dset_iter['train'] = iter(dset_loader['train']) 73 | all_fields = next(dset_iter['train']) 74 | labels = all_fields[-2] 75 | inputs = all_fields[:-2] 76 | 77 | inputs = [x.to(device) for x in inputs] 78 | labels = labels.to(device) 79 | 80 | with torch.set_grad_enabled(True): 81 | outputs = model(*inputs) 82 | loss_list = [criterion(output, labels) for output in outputs] 83 | loss = torch.sum(torch.stack(loss_list)) 84 | 85 | preds = [] 86 | for output in outputs: 87 | _, pred = torch.max(output, 1) 88 | preds.append(pred) 89 | 90 | loss.backward() 91 | 92 | if (itr + 1) % update_frequency == 0: 93 | optimizer.step() 94 | optimizer.zero_grad() 95 | 96 | epoch = ((itr + 1) * bs) // len(dset_loader['train'].dataset) 97 | 98 | running_num_data += inputs[0].size(0) 99 | for idx, loss_ in enumerate(loss_list): 100 | running_loss[idx] += loss_.item() * inputs[0].size(0) 101 | running_corrects[idx] += torch.sum(preds[idx] == labels.data) 102 | 103 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 104 | running_loss = [ 105 | r_loss / running_num_data for r_loss in running_loss 106 | ] 107 | running_acc = [ 108 | r_corrects.double() / running_num_data 109 | for r_corrects in running_corrects 110 | ] 111 | logger.info( 112 | '{} Loss: {:.4f} {:.4f} {:.4f} {:.4f} Acc: {:.4f} {:.4f} {:.4f} {:.4f}'.format( \ 113 | 'Train - relu2_2, relu3_3, relu4_3, relu5_3', 114 | *running_loss, *running_acc) 115 | ) 116 | running_num_data = 0 117 | running_loss = [0.0] * 4 118 | running_corrects = [0] * 4 119 | 120 | model.eval() 121 | 122 | val_running_loss = [0.0] * 4 123 | val_running_corrects = [0] * 4 124 | 125 | for all_fields in dset_loader['val']: 126 | labels = all_fields[-2] 127 | inputs = all_fields[:-2] 128 | inputs = [x.to(device) for x in inputs] 129 | labels = labels.to(device) 130 | 131 | with torch.set_grad_enabled(False): 132 | outputs = model(*inputs) 133 | loss_list = [criterion(output, labels) for output in outputs] 134 | loss = torch.sum(torch.stack(loss_list)) 135 | 136 | preds = [] 137 | for output in outputs: 138 | _, pred = torch.max(output, 1) 139 | preds.append(pred) 140 | 141 | for idx, loss_ in enumerate(loss_list): 142 | val_running_loss[idx] += loss_.item() * inputs[0].size(0) 143 | val_running_corrects[idx] += torch.sum(preds[idx] == labels.data) 144 | 145 | val_loss = [ 146 | r_loss / len(dset_loader['val'].dataset) 147 | for r_loss in val_running_loss 148 | ] 149 | val_acc = [ 150 | r_corrects.double() / len(dset_loader['val'].dataset) 151 | for r_corrects in val_running_corrects 152 | ] 153 | logger.info( 154 | '{} Loss: {:.4f} {:.4f} {:.4f} {:.4f} Acc: {:.4f} {:.4f} {:.4f} {:.4f}'.format( \ 155 | 'Validation - relu2_2, relu3_3, relu4_3, relu5_3', 156 | *val_loss, *val_acc) 157 | ) 158 | 159 | model.module.fc_list.train() 160 | 161 | # checkpoint 162 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 163 | do_checkpoint = (itr - last_checkpoint) >= checkpoint_frequency 164 | if do_checkpoint or itr == maxItr - 1: 165 | last_checkpoint = itr 166 | checkpoint_dict = { 167 | 'itr': itr + 1, 168 | 'state_dict': model.module.fc_list.state_dict(), 169 | 'optimizer' : optimizer.state_dict(), 170 | 'best_acc': best_acc 171 | } 172 | save_checkpoint( 173 | checkpoint_dict, 174 | checkpoint_folder=checkpoint_folder 175 | ) 176 | 177 | best_model_path = os.path.join(checkpoint_folder, 'model_best.pth.tar') 178 | 
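            # The two loops below maintain a single 'model_best.pth.tar' whose fc
            # parameters are, independently for each of the four heads, the best
            # seen so far on the validation set.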
179 | 180 | update_best_model = False 181 | for v_idx, val_acc_ in enumerate(val_acc): 182 | is_best = val_acc_ > best_acc[v_idx] 183 | if is_best: 184 | if os.path.isfile(best_model_path): 185 | best_fc = torch.load(best_model_path) 186 | best_fc = best_fc['state_dict'] 187 | else: 188 | best_fc = copy.deepcopy( 189 | model.module.fc_list.state_dict() 190 | ) 191 | update_best_model = True 192 | break 193 | 194 | for v_idx, val_acc_ in enumerate(val_acc): 195 | is_best = val_acc_ > best_acc[v_idx] 196 | param_names = ['%d.weight'%v_idx, '%d.bias'%v_idx] 197 | if is_best: 198 | best_acc[v_idx] = val_acc_ 199 | for name in param_names: 200 | best_fc[name] = model.module.fc_list.state_dict()[name] 201 | 202 | if update_best_model: 203 | torch.save({'state_dict': best_fc}, best_model_path) 204 | 205 | logger.info('Best val accuracy: {:4f} {:4f} {:4f} {:4f}'.format(*best_acc)) 206 | 207 | # load best model weights 208 | best_model_wts = torch.load(os.path.join(checkpoint_folder, 'model_best.pth.tar')) 209 | model.module.fc_list.load_state_dict(best_model_wts['state_dict']) 210 | 211 | return model 212 | 213 | def save_checkpoint( 214 | state, 215 | checkpoint_folder='exp', 216 | filename='checkpoint.pth.tar' 217 | ): 218 | filename = os.path.join(checkpoint_folder, filename) 219 | torch.save(state, filename) 220 | 221 | 222 | def initialize_optimizer(model_ft, lr, wd=0): 223 | 224 | fc_params_to_update = [] 225 | fc_params_group_2 = [] 226 | fc_params_group_3 = [] 227 | for name, param in model_ft.named_parameters(): 228 | # if name == 'module.fc.bias' or name == 'module.fc.weight': 229 | if 'module.fc_list' in name: 230 | param.requires_grad = True 231 | if '0' in name: 232 | fc_params_group_3.append(param) 233 | elif '1' in name: 234 | fc_params_group_2.append(param) 235 | else: 236 | fc_params_to_update.append(param) 237 | else: 238 | param.requires_grad = False 239 | 240 | ''' 241 | optimizer_ft = optim.SGD(fc_params_to_update, lr=lr, momentum=0.9, 242 | weight_decay=wd) 243 | ''' 244 | optimizer_ft = optim.SGD([ 245 | {'params': fc_params_to_update}, 246 | {'params': fc_params_group_2, 'lr': lr * 1}, 247 | {'params': fc_params_group_3, 'lr': lr * 1}], 248 | lr=lr, momentum=0.9, weight_decay=wd) 249 | 250 | return optimizer_ft 251 | 252 | 253 | def inverting_categories( 254 | classes, 255 | model, 256 | criterion, 257 | input_size, 258 | tv_beta = 2, 259 | num_steps=200, 260 | logger_name='inv_logger', 261 | ): 262 | logger = logging.getLogger(logger_name) 263 | device = next(model.parameters()).device 264 | output_imgs = [] 265 | for i in range(len(classes)): 266 | target_label = torch.tensor([i], dtype=torch.int64, device=device) 267 | logger.info('=' * 80 + '\nClass {}:'.format(classes[i])) 268 | img = torch.randn( 269 | [1, 3, *input_size], 270 | dtype=torch.float32, 271 | device=device, 272 | requires_grad=True 273 | ) 274 | optimizer = optim.LBFGS([img]) 275 | 276 | itr = 0 277 | while itr < num_steps: 278 | 279 | cache_loss = [0.0] 280 | def closure(): 281 | optimizer.zero_grad() 282 | preds_softmax = model(img) 283 | 284 | inv_loss = [] 285 | for output in preds_softmax: 286 | loss_ = criterion(output, target_label) 287 | inv_loss.append(loss_) 288 | loss = torch.sum(torch.stack(inv_loss)) 289 | 290 | d1 = img[:,:,1:,:] - img[:,:,:-1,:] 291 | d2 = img[:,:,:,1:] - img[:,:,:,:-1] 292 | tv = torch.sum( 293 | ( 294 | torch.sqrt( 295 | d1.view(-1) ** 2 + 296 | d2.view(-1) ** 2 297 | ) ** 298 | tv_beta 299 | ) 300 | ) 301 | 302 | loss += 1e-9 * tv 303 | 304 | loss.backward() 305 | 306 | 
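                # (Note: optimizer.step(closure) may re-evaluate the closure
                # several times per LBFGS step; the closure therefore recomputes
                # the loss and gradients on every call and returns the loss.)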
                # logger.info('Loss: {:.4f}'.format(loss.item()))
307 |                 # current_loss[0] = loss.item()
308 |                 cache_loss[0] = loss.item()
309 |                 return loss
310 | 
311 |             # optimizer.step(lambda : closure(cache_loss))
312 |             optimizer.step(closure)
313 |             logger.info('Step {} Loss: {}'.format(itr, cache_loss[0]))
314 |             itr += 1
315 | 
316 |         output_imgs.append(torch.squeeze(img))
317 | 
318 |     return output_imgs
319 | 
320 | def save_outputs(output_imgs, classes, output_folder):
321 |     device = output_imgs[0].device
322 |     img_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device)
323 |     img_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(device)
324 |     from PIL import Image  # scipy.misc.imsave was removed in SciPy >= 1.3 (requirements pin 1.4.1)
325 |     for img, c_name in zip(output_imgs, classes):
326 |         img = img * img_std + img_mean
327 |         img.data.clamp_(0, 1)
328 |         img = img.permute(1, 2, 0).cpu().detach().numpy()  # CHW -> HWC
329 |         x_range = np.percentile(img, [1, 99])
330 |         img = np.clip(img, x_range[0], x_range[1])
331 | 
332 |         img = (img - x_range[0]) / (x_range[1] - x_range[0])
333 | 
334 |         output_file_name = os.path.join(output_folder, c_name + '.png')
335 |         Image.fromarray((img * 255).astype(np.uint8)).save(output_file_name)
336 | 
337 | def main(args):
338 |     lr = args.lr
339 |     input_size = args.input_size
340 | 
341 |     args.exp_dir = os.path.join(args.dataset, args.exp_dir)
342 | 
343 |     if args.dataset in ['cars', 'aircrafts']:
344 |         keep_aspect = False
345 |     else:
346 |         keep_aspect = True
347 | 
348 |     if args.dataset in ['aircrafts']:
349 |         crop_from_size = [(x * 256) // 224 for x in input_size]
350 |     else:
351 |         crop_from_size = input_size
352 | 
353 |     if 'inat' in args.dataset:
354 |         split = {'train': 'train', 'val': 'val'}
355 |     else:
356 |         split = {'train': 'train_val', 'val': 'test'}
357 | 
358 |     if len(input_size) > 1:
359 |         raise ValueError('inversion.py only supports a single --input_size')
360 | 
361 |     if not keep_aspect:
362 |         input_size = [(x, x) for x in input_size]
363 |         crop_from_size = [(x, x) for x in crop_from_size]
364 | 
365 |     exp_root = '../exp_inversion'
366 |     checkpoint_folder = os.path.join(exp_root, args.exp_dir, 'checkpoints')
367 | 
368 |     if not os.path.isdir(checkpoint_folder):
369 |         os.makedirs(checkpoint_folder)
370 | 
371 |     # log the setup for the experiments
372 |     args_dict = vars(args)
373 |     with open(os.path.join(exp_root, args.exp_dir, 'args.txt'), 'a') as f:
374 |         f.write(json.dumps(args_dict, sort_keys=True, indent=4))
375 | 
376 |     # make sure the dataset is ready
377 |     if 'inat' in args.dataset:
378 |         setup_dataset('inat')
379 |     else:
380 |         setup_dataset(args.dataset)
381 | 
382 |     # ================== Create data loader ==================================
383 |     data_transforms = {
384 |         'train': [transforms.Compose([
385 |             transforms.Resize(x[0]),
386 |             # transforms.CenterCrop(x[1]),
387 |             transforms.RandomCrop(x[1]),
388 |             transforms.RandomHorizontalFlip(),
389 |             transforms.ToTensor(),
390 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) \
391 |             for x in zip(crop_from_size, input_size)],
392 |         'val': [transforms.Compose([
393 |             transforms.Resize(x[0]),
394 |             transforms.CenterCrop(x[1]),
395 |             transforms.ToTensor(),
396 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) \
397 |             for x in zip(crop_from_size, input_size)],
398 |     }
399 | 
400 | 
401 |     if args.dataset == 'cub':
402 |         from CUBDataset import CUBDataset as dataset
403 |     elif args.dataset == 'cars':
404 |         from CarsDataset import CarsDataset as dataset
405 |     elif args.dataset == 'aircrafts':
406 |         from AircraftsDataset import AircraftsDataset as dataset
407 |     elif 'inat' in args.dataset:
408 |         from iNatDataset import iNatDataset as dataset
409 |         if args.dataset == 'inat':
410 |             subset = None
411 |         else:
412 |             subset = args.dataset[len('inat_'):]
413 |             subset = subset[0].upper() + subset[1:]
414 |     else:
415 |         raise ValueError('Unknown dataset: %s' % args.dataset)
416 | 
417 |     if 'inat' in args.dataset:
418 |         dset = {x: dataset(dset_root['inat'], split[x], subset, \
419 |             transform=data_transforms[x]) for x in ['train', 'val']}
420 |     else:
421 |         dset = {x: dataset(dset_root[args.dataset], split[x], \
422 |             transform=data_transforms[x]) for x in ['train', 'val']}
423 | 
424 |     dset_loader = {x: torch.utils.data.DataLoader(dset[x],
425 |         batch_size=32, shuffle=True, num_workers=8,
426 |         drop_last=drop_last) \
427 |         for x, drop_last in zip(['train', 'val'], [True, False])}
428 | 
429 | 
430 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
431 | 
432 |     #======================= Initialize the model =========================
433 |     model = create_multi_heads_bcnn(len(dset['train'].classes))
434 |     model = model.to(device)
435 |     model = torch.nn.DataParallel(model)
436 | 
437 |     # Setup the loss fxn
438 |     criterion = nn.CrossEntropyLoss()
439 | 
440 |     #====================== Initialize optimizer ==============================
441 |     model_checkpoint = os.path.join(checkpoint_folder, 'checkpoint.pth.tar')
442 |     start_itr = 0
443 |     optim_fc = initialize_optimizer(
444 |         model,
445 |         args.lr,
446 |         wd=args.wd,
447 |     )
448 | 
449 |     logger_name = 'train_logger'
450 |     logger = initializeLogging(
451 |         os.path.join(exp_root, args.exp_dir, 'train_fc_history.txt'),
452 |         logger_name
453 |     )
454 | 
455 |     model_train_fc = False
456 |     fc_model_path = os.path.join(exp_root, args.exp_dir, 'fc_params.pth.tar')
457 |     if not args.train_from_beginning:
458 |         if os.path.isfile(fc_model_path):
459 |             # load the fc parameters if they are already trained
460 |             print("=> loading fc parameters '{}'".format(fc_model_path))
461 |             checkpoint = torch.load(fc_model_path)
462 |             model.module.fc_list.load_state_dict(checkpoint['state_dict'])
463 |             print("=> loaded fc initialization parameters")
464 |         else:
465 |             if os.path.isfile(model_checkpoint):
466 |                 # load the checkpoint if it exists
467 |                 print("=> loading checkpoint '{}'".format(model_checkpoint))
468 |                 checkpoint = torch.load(model_checkpoint)
469 |                 start_itr = checkpoint['itr']
470 |                 model.module.fc_list.load_state_dict(checkpoint['state_dict'])
471 |                 optim_fc.load_state_dict(checkpoint['optimizer'])
472 |                 print("=> loaded checkpoint for the fc initialization")
473 | 
474 |             # resume training
475 |             model_train_fc = True
476 |     else:
477 |         # Training everything from the beginning
478 |         model_train_fc = True
479 |         start_itr = 0
480 | 
481 |     if model_train_fc:
482 |         # do the training
483 |         model.eval()
484 | 
485 |         model = train_model(model, dset_loader, criterion, optim_fc,
486 |             batch_size_update=256,
487 |             epoch=args.epoch, logger_name=logger_name, start_itr=start_itr,
488 |             checkpoint_folder=checkpoint_folder, fine_tune=False)
489 |         shutil.copyfile(
490 |             os.path.join(checkpoint_folder, 'model_best.pth.tar'),
491 |             fc_model_path)
492 | 
493 |     logger_inv = initializeLogging(
494 |         os.path.join(exp_root, args.exp_dir, 'inv_history.txt'),
495 |         'inv_logger'
496 |     )
497 | 
498 |     output_images = inverting_categories(
499 |         dset['train'].classes,
500 |         model,
501 |         criterion,
502 |         [224, 224],
503 |         logger_name='inv_logger',
504 |     )
505 | 
506 |     inv_folder = os.path.join(exp_root, args.exp_dir, 'inv_outputs')
507 |     if not os.path.isdir(inv_folder):
508 |         os.makedirs(inv_folder)
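    # save_outputs (defined above) undoes the ImageNet normalization, contrast-
    # stretches each image to its [1, 99] percentile range, and writes one PNG
    # per class into inv_folder.
509 |     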
save_outputs(output_images, dset['train'].classes, inv_folder) 510 | 511 | if __name__ == '__main__': 512 | parser = argparse.ArgumentParser() 513 | parser.add_argument('--epoch', default=45, type=int, 514 | help='number of epochs') 515 | parser.add_argument('--lr', default=1, type=float, 516 | help='learning rate') 517 | parser.add_argument('--wd', default=1e-8, type=float, 518 | help='weight decay') 519 | parser.add_argument('--exp_dir', default='inv', type=str, 520 | help='foldername where to save the results for the experiment') 521 | parser.add_argument('--train_from_beginning', action='store_true', 522 | help='train the model from first epoch, i.e. ignore the checkpoint') 523 | parser.add_argument('--dataset', default='cub', type=str, 524 | help='cub | cars | aircrafts') 525 | parser.add_argument('--input_size', nargs='+', default=[448], type=int, 526 | help='input size as a list of sizes') 527 | args = parser.parse_args() 528 | 529 | main(args) 530 | 531 | -------------------------------------------------------------------------------- /matrixSquareRoot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | 5 | class MatrixSquareRootFun(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, A, numIters, I, backwardIter): 9 | bs, dim, _ = A.shape 10 | normA = A.norm('fro', dim=[1, 2], keepdim=True) 11 | Y = A.div(normA) 12 | 13 | Z = I.clone() 14 | Z = Z.unsqueeze(0).repeat(bs, 1, 1) 15 | I = I.unsqueeze(0).expand(bs, dim, dim) 16 | 17 | for i in range(numIters): 18 | T = 0.5 * (3.0 * I - Z.bmm(Y)) 19 | Y = Y.bmm(T) 20 | Z = T.bmm(Z) 21 | 22 | sA = Y.mul(torch.sqrt(normA)) 23 | ctx.save_for_backward(sA, I) 24 | ctx.backwardIter = backwardIter 25 | 26 | return sA 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | z = ctx.saved_tensors[0] 31 | I = ctx.saved_tensors[1] 32 | bs, dim, _ = z.shape 33 | normz = z.norm('fro', dim=[1, 2], keepdim=True) 34 | a = z.div(normz) 35 | q = grad_output.div(normz) 36 | 37 | for i in range(ctx.backwardIter): 38 | q = 0.5 * (q.bmm(3.0 * I - a.bmm(a)) - \ 39 | a.transpose(1, 2).bmm(a.transpose(1,2).bmm(q) - q.bmm(a))) 40 | a = 0.5 * a.bmm(3.0 * I - a.bmm(a)) 41 | 42 | dlda = 0.5 * q 43 | return (dlda, None, None, None) 44 | 45 | 46 | class MatrixSquareRoot(nn.Module): 47 | def __init__(self, numIter, dim, backwardIter=0): 48 | super(MatrixSquareRoot, self).__init__() 49 | self.numIter = numIter 50 | self.dim = dim 51 | self.register_buffer('I', torch.eye(dim, dim)) 52 | 53 | if backwardIter < 1: 54 | self.backwardIter = numIter 55 | else: 56 | self.backwardIter = backwardIter 57 | 58 | def forward(self, x): 59 | return MatrixSquareRootFun.apply(x, self.numIter, self.I, self.backwardIter) 60 | -------------------------------------------------------------------------------- /plot_curve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import re 7 | import argparse 8 | 9 | 10 | res = [re.compile('Iteration (\d+)/(\d+)'), 11 | re.compile('Train Loss: ([.\d]+) Acc: ([.\d]+)'), 12 | re.compile('Validation Loss: ([.\d]+) Acc: ([.\d]+)')] 13 | 14 | def plot_acc(log_name): 15 | 16 | data = {} 17 | with open(log_name) as f: 18 | lines = f.readlines() 19 | for l in lines: 20 | i = 0 21 | for r in res: 22 | m = r.match(l) 23 | if m is not None: 24 | break 25 | i += 1 26 | if m is 
None: 27 | continue 28 | if i == 0: 29 | iteration = int(m.groups()[0]) 30 | total_iter = int(m.groups()[1]) 31 | if iteration not in data: 32 | data[iteration] = [0] * 4 33 | else: 34 | loss = float(m.groups()[0]) 35 | acc = float(m.groups()[1]) 36 | 37 | data[iteration][(i-1)*2] = loss 38 | data[iteration][(i-1)*2 +1] = acc 39 | 40 | 41 | train_acc = [] 42 | train_loss = [] 43 | val_acc = [] 44 | val_loss = [] 45 | 46 | for k, v in data.items(): 47 | train_loss.append(v[0]) 48 | train_acc.append(v[1]) 49 | val_loss.append(v[2]) 50 | val_acc.append(v[3]) 51 | 52 | 53 | iter_list = [int(x) + 1 for x in data.keys()] 54 | x_train = iter_list 55 | x_val = iter_list 56 | # x_train = np.arange(len(train_acc)) 57 | # x_val = np.arange(len(val_acc)) 58 | plt.subplot(1, 2, 1) 59 | plt.plot(x_train, train_acc, '-', linestyle='-', color='r', linewidth=2, 60 | label='train_top1') 61 | plt.plot(x_val, val_acc, '-', linestyle='-', color='b', linewidth=2, 62 | label='val_top1') 63 | plt.legend(loc="best") 64 | plt.xticks(np.arange(0, iter_list[-1], iter_list[-1]//10)) 65 | plt.yticks(np.arange(0.1, 1, 0.05)) 66 | plt.xlim([0, iter_list[-1]]) 67 | min_y = min([min(train_acc), min(val_acc)]) - 0.05 68 | max_y = max([max(train_acc), max(val_acc)]) + 0.05 69 | if max_y - min_y < 0.1: 70 | min_y = max(0, min_y - 0.05) 71 | max_y = min(1, max_y + 0.05) 72 | plt.ylim(min_y, max_y) 73 | 74 | 75 | # plt.ylim([min([min(train_acc), min(val_acc)]), 76 | # max([max(train_acc), max(val_acc)])]) 77 | plt.grid(True) 78 | 79 | plt.subplot(1, 2, 2) 80 | plt.semilogy(x_train, train_loss, '-', linestyle='-', color='r', linewidth=2, 81 | label='train_loss') 82 | plt.semilogy(x_val, val_loss, '-', linestyle='-', color='b', linewidth=2, 83 | label='val_loss') 84 | plt.legend(loc="best") 85 | plt.xticks(np.arange(0, iter_list[-1], iter_list[-1]//10)) 86 | plt.yticks(np.arange(0.1, 1, 0.05)) 87 | plt.xlim([0, iter_list[-1]]) 88 | # plt.ylim([min([min(train_loss), min(val_loss)]), 89 | # max([max(train_loss), max(val_loss)])]) 90 | plt.yscale('log') 91 | plt.grid(True) 92 | 93 | return max(val_acc) 94 | 95 | 96 | def plot_log(log_path, save_path, close_fig=True): 97 | plt.figure(figsize=(14, 8)) 98 | plt.xlabel("Iterations") 99 | plt.ylabel("Accuracy") 100 | 101 | max_acc = plot_acc(log_path) 102 | plt.grid(True) 103 | plt.savefig(save_path) 104 | if close_fig: 105 | plt.close() 106 | 107 | return max_acc 108 | 109 | def main(args): 110 | _ = plot_log(os.path.join('../exp', args.exp_dir, args.logs), 111 | os.path.join('../exp', args.exp_dir, args.output_filename)) 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--logs', type=str, default='train_history.txt') 117 | parser.add_argument('--exp_dir', type=str, default='exp') 118 | parser.add_argument('--output_filename', type=str, 119 | default='train_curve.png') 120 | 121 | args = parser.parse_args() 122 | main(args) 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | future==0.18.2 3 | kiwisolver==1.2.0 4 | matplotlib==3.2.1 5 | numpy==1.18.4 6 | Pillow==7.1.2 7 | pkg-resources==0.0.0 8 | pyparsing==2.4.7 9 | python-dateutil==2.8.1 10 | scipy==1.4.1 11 | six==1.15.0 12 | torch==1.5.0+cu101 13 | torchvision==0.6.0+cu101 14 | -------------------------------------------------------------------------------- /resize_inat.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | # import cv2 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | # import lmdb 6 | import os 7 | import json 8 | import math 9 | import PIL 10 | from config import dset_root 11 | from shutil import copyfile 12 | from random import shuffle 13 | 14 | ''' 15 | base_file_list = ['./filelists/metaiNat/base.json', 16 | './filelists/metaiNat/val.json', 17 | './filelists/metaiNat/novel.json', 18 | './filelists/metaiNat/novel_train.json', 19 | './filelists/metaiNat/novel_test.json'] 20 | ''' 21 | 22 | all_json = ['train2018.json', 'val2018.json', 'test2018.json'] 23 | 24 | to_size = 448 25 | output_root = dset_root['inat'].replace('inat_2018', 'inat_2018_%d'%to_size) 26 | 27 | for json_file in all_json: 28 | # for base_file in base_file_list: 29 | 30 | base_file = os.path.join(dset_root['inat'], json_file) 31 | with open(base_file, 'r') as f: 32 | meta = json.load(f) 33 | 34 | folder_name = set([os.path.dirname(x['file_name']) for x in meta['images']]) 35 | 36 | for x in folder_name: 37 | fpath = os.path.join(output_root, x) 38 | if not os.path.isdir(fpath): 39 | os.makedirs(fpath) 40 | 41 | ''' 42 | lmdb_file, _ = os.path.splitext(base_file) 43 | base_name = os.path.basename(lmdb_file) 44 | dataset = SimpleDataset(base_file, None) 45 | ''' 46 | 47 | num_img = len(meta['images']) 48 | shuffle(meta['images']) 49 | 50 | for idx, x in enumerate(meta['images']): 51 | 52 | if (idx + 1) % 1000 == 0: 53 | print('%s: %d / %d'%(json_file.split('.')[0], idx+1, num_img)) 54 | 55 | out_name = os.path.join(output_root, x['file_name']) 56 | if os.path.isfile(out_name): 57 | continue 58 | im = PIL.Image.open(os.path.join(dset_root['inat'], x['file_name'])) 59 | ratio = to_size / min(im.size) 60 | resize_to = tuple([math.ceil(y*ratio) for y in im.size]) 61 | resizeImg = im.resize(resize_to, resample=PIL.Image.BILINEAR) 62 | 63 | resizeImg.save(out_name) 64 | 65 | 66 | 67 | # copyfile(base_file, os.path.join(output_root, json_file)) 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import transforms 4 | import os 5 | from config import dset_root 6 | import argparse 7 | import logging 8 | from BCNN import create_bcnn_model 9 | import sys 10 | 11 | pretrain_folder = "pretrained_models" 12 | 13 | 14 | def initializeLogging(logger_name): 15 | log = logging.getLogger(logger_name) 16 | log.setLevel(logging.DEBUG) 17 | log.addHandler(logging.StreamHandler(sys.stdout)) 18 | 19 | return log 20 | 21 | 22 | def test_model(model, criterion, dset_loader, logger_name=None): 23 | 24 | if logger_name is not None: 25 | logger = logging.getLogger(logger_name) 26 | device = next(model.parameters()).device 27 | model.eval() 28 | 29 | running_corrects = 0 30 | for idx, all_fields in enumerate(dset_loader): 31 | if logger_name is not None and (idx + 1) % 10 == 0: 32 | logger.info("%d / %d" % (idx + 1, len(dset_loader))) 33 | labels = all_fields[-2] 34 | inputs = all_fields[:-2] 35 | inputs = [x.to(device) for x in inputs] 36 | labels = labels.to(device) 37 | 38 | with torch.set_grad_enabled(False): 39 | outputs = model(*inputs) 40 | _, preds = torch.max(outputs, 1) 41 | 42 | running_corrects += torch.sum(preds == labels.data) 43 | 44 | test_acc = running_corrects.double() / 
len(dset_loader.dataset) 45 | 46 | if logger_name is not None: 47 | logger.info("Test accuracy: {:.3f}".format(test_acc)) 48 | 49 | 50 | def main(args): 51 | model_path = os.path.join(pretrain_folder, args.pretrained_filename) 52 | input_size = args.input_size 53 | 54 | _ = initializeLogging("mylogger") 55 | 56 | if args.dataset in ["cars", "aircrafts"]: 57 | keep_aspect = False 58 | else: 59 | keep_aspect = True 60 | 61 | if args.dataset in ["aircrafts"]: 62 | crop_from_size = [(x * 256) // 224 for x in input_size]  # e.g. 448 -> resize to 512, then crop 448 63 | else: 64 | crop_from_size = input_size 65 | 66 | if not keep_aspect: 67 | input_size = [(x, x) for x in input_size] 68 | crop_from_size = [(x, x) for x in crop_from_size] 69 | 70 | data_transforms = [ 71 | transforms.Compose( 72 | [ 73 | transforms.Resize(x[0]), 74 | transforms.CenterCrop(x[1]), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 77 | ] 78 | ) 79 | for x in zip(crop_from_size, input_size) 80 | ] 81 | 82 | if args.dataset == "cub": 83 | from CUBDataset import CUBDataset as Dataset 84 | elif args.dataset == "cars": 85 | from CarsDataset import CarsDataset as Dataset 86 | elif args.dataset == "aircrafts": 87 | from AircraftsDataset import AircraftsDataset as Dataset 88 | else: 89 | raise ValueError("Unknown dataset: %s" % args.dataset) 90 | 91 | # TODO: check the split name 92 | dset_test = Dataset(dset_root[args.dataset], "test", transform=data_transforms) 93 | test_loader = torch.utils.data.DataLoader( 94 | dset_test, 95 | batch_size=args.batch_size, 96 | shuffle=False, 97 | num_workers=8, 98 | drop_last=False, 99 | ) 100 | 101 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 102 | model = create_bcnn_model( 103 | args.model_names_list, 104 | len(dset_test.classes), 105 | args.pooling_method, 106 | False, 107 | True, 108 | args.embedding_dim, 109 | 2, 110 | m_sqrt_iter=args.matrix_sqrt_iter, 111 | proj_dim=args.proj_dim, 112 | ) 113 | model = model.to(device) 114 | model = torch.nn.DataParallel(model) 115 | criterion = nn.CrossEntropyLoss() 116 | if os.path.isfile(model_path): 117 | print("=> loading checkpoint '{}'".format(model_path)) 118 | checkpoint = torch.load(model_path) 119 | model.load_state_dict(checkpoint["state_dict"]) 120 | print("=> loaded checkpoint '{}'".format(model_path)) 121 | else: 122 | raise ValueError("pretrained model %s does not exist" % (model_path)) 123 | 124 | test_model(model, criterion, test_loader, "mylogger") 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument( 130 | "--batch_size", 131 | default=32, 132 | type=int, 133 | help="size of mini-batch that can fit into gpus", 134 | ) 135 | parser.add_argument( 136 | "--pretrained_filename", type=str, help="file name of pretrained model", 137 | ) 138 | parser.add_argument( 139 | "--dataset", default="cub", type=str, help="cub | cars | aircrafts" 140 | ) 141 | parser.add_argument( 142 | "--input_size", 143 | nargs="+", 144 | default=[448], 145 | type=int, 146 | help="input size as a list of sizes", 147 | ) 148 | parser.add_argument( 149 | "--model_names_list", 150 | nargs="+", 151 | default=["vgg"], 152 | type=str, 153 | help="list of backbone CNN architectures", 154 | ) 155 | parser.add_argument( 156 | "--pooling_method", 157 | default="outer_product", 158 | type=str, 159 | help="outer_product | sketch | gamma_demo | sketch_gamma_demo", 160 | ) 161 | parser.add_argument( 162 | "--embedding_dim", 163 | type=int, 164 | default=8192, 165 | help="the 
dimension for the tensor sketch approximation", 166 | ) 167 | parser.add_argument( 168 | "--matrix_sqrt_iter", 169 | type=int, 170 | default=0, 171 | help="number of iterations of Newton's method for approximating" 172 | + " the matrix square root. Default=0 [no matrix square root]", 173 | ) 174 | parser.add_argument( 175 | "--proj_dim", 176 | type=int, 177 | default=0, 178 | help="project the dimension of cnn features to lower " 179 | + "dimensionality before computing tensor product", 180 | ) 181 | parser.add_argument( 182 | "--gamma", 183 | default=0.5, 184 | type=float, 185 | help="the value of gamma for gamma democratic aggregation", 186 | ) 187 | args = parser.parse_args() 188 | 189 | main(args) 190 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from config import dset_root, setup_dataset 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import numpy as np 6 | import torchvision 7 | from torchvision import datasets, models, transforms 8 | import os 9 | import random 10 | import argparse 11 | import copy 12 | import logging 13 | import sys 14 | import time 15 | import shutil 16 | from BCNN import create_bcnn_model 17 | from test import test_model 18 | from plot_curve import plot_log 19 | import json 20 | 21 | def initializeLogging(log_filename, logger_name): 22 | log = logging.getLogger(logger_name) 23 | log.setLevel(logging.DEBUG) 24 | log.addHandler(logging.StreamHandler(sys.stdout)) 25 | log.addHandler(logging.FileHandler(log_filename, mode='a')) 26 | 27 | return log 28 | 29 | def save_checkpoint(state, is_best, checkpoint_folder='exp', 30 | filename='checkpoint.pth.tar'): 31 | filename = os.path.join(checkpoint_folder, filename) 32 | best_model_filename = os.path.join(checkpoint_folder, 'model_best.pth.tar') 33 | torch.save(state, filename) 34 | if is_best: 35 | shutil.copyfile(filename, best_model_filename) 36 | 37 | def initialize_optimizer(model_ft, lr, optimizer='sgd', wd=0, finetune_model=True, 38 | proj_lr=1e-3, proj_wd=1e-5, beta1=0.9, beta2=0.999): 39 | fc_params_to_update = [] 40 | params_to_update = [] 41 | proj_params_to_update = [] 42 | if finetune_model: 43 | for name, param in model_ft.named_parameters(): 44 | # if name == 'module.fc.bias' or name == 'module.fc.weight': 45 | if 'module.fc' in name: 46 | fc_params_to_update.append(param) 47 | else: 48 | if model_ft.module.learn_proj and \ 49 | 'feature_extractors.0.1.weight' in name: 50 | proj_params_to_update.append(param) 51 | else: 52 | params_to_update.append(param) 53 | param.requires_grad = True 54 | 55 | # Observe that all parameters are being optimized 56 | if optimizer == 'sgd': 57 | optimizer_ft = optim.SGD([ 58 | {'params': params_to_update}, 59 | {'params': proj_params_to_update, 60 | 'weight_decay': proj_wd, 'lr': proj_lr}, 61 | {'params': fc_params_to_update, 'weight_decay': 1e-5, 'lr': 1e-2}],  # fc group uses hard-coded lr/wd 62 | lr=lr, momentum=0.9, weight_decay=wd) 63 | elif optimizer == 'adam': 64 | optimizer_ft = optim.Adam([ 65 | {'params': params_to_update}, 66 | {'params': proj_params_to_update, 67 | 'weight_decay': proj_wd, 'lr': proj_lr}, 68 | {'params': fc_params_to_update, 'weight_decay': 1e-5, 'lr': 1e-2}], 69 | lr=lr, weight_decay=wd, 70 | betas=(beta1, beta2)) 71 | else: 72 | raise ValueError('Unknown optimizer: %s' % optimizer) 73 | else: 74 | for name, param in model_ft.named_parameters(): 75 | # if name == 'module.fc.bias' or name == 'module.fc.weight': 76 | if 
'module.fc' in name: 77 | param.requires_grad = True 78 | fc_params_to_update.append(param) 79 | else: 80 | if model_ft.module.learn_proj and \ 81 | 'feature_extractors.0.1.weight' in name: 82 | param.requires_grad = True 83 | proj_params_to_update.append(param) 84 | else: 85 | param.requires_grad = False 86 | 87 | # Only the fc (and projection) parameters are optimized 88 | if optimizer == 'sgd': 89 | if len(proj_params_to_update) == 0: 90 | optimizer_ft = optim.SGD(fc_params_to_update, lr=lr, momentum=0.9, 91 | weight_decay=wd) 92 | else: 93 | optimizer_ft = optim.SGD( 94 | [{'params': fc_params_to_update}, 95 | {'params': proj_params_to_update, 96 | 'weight_decay': proj_wd, 'lr': proj_lr}], 97 | lr=lr, momentum=0.9, weight_decay=wd) 98 | elif optimizer == 'adam': 99 | optimizer_ft = optim.Adam(fc_params_to_update, lr=lr, weight_decay=wd, 100 | betas=(beta1, beta2)) 101 | else: 102 | raise ValueError('Unknown optimizer: %s' % optimizer) 103 | 104 | return optimizer_ft 105 | 106 | def train_model(model, dset_loader, criterion, 107 | optimizer, batch_size_update=256, 108 | # maxItr=50000, logger_name='train_logger', checkpoint_folder='exp', 109 | epoch=45, logger_name='train_logger', checkpoint_folder='exp', 110 | start_itr=0, clip_grad=-1, scheduler=None, fine_tune=True): 111 | 112 | maxItr = epoch * len(dset_loader['train'].dataset) // \ 113 | dset_loader['train'].batch_size + 1 114 | 115 | val_every_number_examples = max(10000, 116 | len(dset_loader['train'].dataset) // 5) 117 | val_frequency = val_every_number_examples // dset_loader['train'].batch_size 118 | checkpoint_frequency = 5 * len(dset_loader['train'].dataset) // \ 119 | dset_loader['train'].batch_size 120 | last_checkpoint = start_itr - 1 121 | logger = logging.getLogger(logger_name) 122 | logger_filename = logger.handlers[1].stream.name 123 | 124 | device = next(model.parameters()).device 125 | since = time.time() 126 | 127 | running_loss = 0.0; running_num_data = 0 128 | running_corrects = 0 129 | val_loss_history = []; best_acc = 0.0 130 | val_acc = 0.0 131 | # best_model_wts = copy.deepcopy(model.state_dict()) 132 | 133 | dset_iter = {x: iter(dset_loader[x]) for x in ['train', 'val']} 134 | bs = dset_loader['train'].batch_size 135 | update_frequency = batch_size_update // bs 136 | 137 | if fine_tune: 138 | model.train() 139 | else: 140 | model.module.fc.train() 141 | 142 | last_epoch = 0 143 | for itr in range(start_itr, maxItr): 144 | # at the end of validation, model.train() is restored below 145 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 146 | logger.info('Iteration {}/{}'.format(itr, maxItr - 1)) 147 | logger.info('-' * 10) 148 | 149 | try: 150 | all_fields = next(dset_iter['train']) 151 | labels = all_fields[-2] 152 | inputs = all_fields[:-2] 153 | # inputs, labels, _ = next(dset_iter['train']) 154 | except StopIteration: 155 | dset_iter['train'] = iter(dset_loader['train']) 156 | all_fields = next(dset_iter['train']) 157 | labels = all_fields[-2] 158 | inputs = all_fields[:-2] 159 | # inputs, labels, _ = next(dset_iter['train']) 160 | 161 | inputs = [x.to(device) for x in inputs] 162 | labels = labels.to(device) 163 | 164 | with torch.set_grad_enabled(True): 165 | outputs = model(*inputs) 166 | loss = criterion(outputs, labels) 167 | 168 | _, preds = torch.max(outputs, 1) 169 | 170 | loss.backward() 171 | 172 | if (itr + 1) % update_frequency == 0: 173 | if clip_grad > 0: 174 | torch.nn.utils.clip_grad_norm_(model.parameters(), 175 | clip_grad) 176 | optimizer.step() 177 | optimizer.zero_grad() 178 | 
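# (editor's note) optimizer.step() above fires once every update_frequency
# iterations, so gradients from consecutive mini-batches accumulate into an
# effective batch of roughly batch_size_update examples. The loop is
# iteration-based, so the nominal epoch below is derived from the number of
# examples seen so far.
179 | epoch = ((itr + 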
1) * bs) // len(dset_loader['train'].dataset) 180 | 181 | running_num_data += inputs[0].size(0) 182 | running_loss += loss.item() * inputs[0].size(0) 183 | running_corrects += torch.sum(preds == labels.data) 184 | 185 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 186 | running_loss = running_loss / running_num_data 187 | running_acc = running_corrects.double() / running_num_data 188 | # print('{} Loss: {:.4f} Acc: {:.4f}'.format('Train', 189 | # running_loss, running_acc)) 190 | logger.info('{} Loss: {:.4f} Acc: {:.4f}'.format( \ 191 | 'Train', running_loss, running_acc)) 192 | running_loss = 0.0; running_num_data = 0; running_corrects = 0 193 | 194 | model.eval() 195 | val_running_loss = 0.0; val_running_corrects = 0 196 | 197 | # for inputs, labels, _ in dset_loader['val']: 198 | for all_fields in dset_loader['val']: 199 | labels = all_fields[-2] 200 | inputs = all_fields[:-2] 201 | inputs = [x.to(device) for x in inputs] 202 | labels = labels.to(device) 203 | 204 | with torch.set_grad_enabled(False): 205 | outputs = model(*inputs) 206 | loss = criterion(outputs, labels) 207 | 208 | _, preds = torch.max(outputs, 1) 209 | 210 | val_running_loss += loss.item() * inputs[0].size(0) 211 | val_running_corrects += torch.sum(preds == labels.data) 212 | val_loss = val_running_loss / len(dset_loader['val'].dataset) 213 | val_acc = val_running_corrects.double() / len(dset_loader['val'].dataset) 214 | # print('{} Loss: {:.4f} Acc: {:.4f}'.format('Validation', 215 | # val_loss, val_acc)) 216 | logger.info('{} Loss: {:.4f} Acc: {:.4f}'.format( \ 217 | 'Validation', val_loss, val_acc)) 218 | 219 | plot_log(logger_filename, 220 | logger_filename.replace('history.txt', 'curve.png'), True) 221 | 222 | if fine_tune: 223 | model.train() 224 | else: 225 | model.module.fc.train() 226 | 227 | 228 | # update scheduler 229 | if scheduler is not None: 230 | if isinstance(scheduler, \ 231 | torch.optim.lr_scheduler.ReduceLROnPlateau): 232 | if (itr + 1) % val_frequency == 0: 233 | scheduler.step(val_acc) 234 | else: 235 | if epoch > last_epoch and scheduler is not None: 236 | last_epoch = epoch 237 | scheduler.step() 238 | # checkpoint 239 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 240 | is_best = val_acc > best_acc 241 | if is_best: 242 | best_acc = val_acc 243 | # best_model_wts = copy.deepcopy(model.state_dict()) 244 | 245 | do_checkpoint = (itr - last_checkpoint) >= checkpoint_frequency 246 | if is_best or itr == maxItr - 1 or do_checkpoint: 247 | last_checkpoint = itr 248 | checkpoint_dict = { 249 | 'itr': itr + 1, 250 | 'state_dict': model.state_dict(), 251 | 'optimizer': optimizer.state_dict(), 252 | 'best_acc': best_acc 253 | } 254 | if scheduler is not None: 255 | checkpoint_dict['scheduler'] = scheduler.state_dict() 256 | save_checkpoint(checkpoint_dict, 257 | is_best, checkpoint_folder=checkpoint_folder) 258 | 259 | time_elapsed = time.time() - since 260 | logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 261 | logger.info('Best val accuracy: {:.4f}'.format(best_acc)) 262 | 263 | # load best model weights 264 | best_model_wts = torch.load(os.path.join(checkpoint_folder, 265 | 'model_best.pth.tar')) 266 | model.load_state_dict(best_model_wts['state_dict']) 267 | 268 | return model 269 | 270 | def main(args): 271 | fine_tune = not args.no_finetune 272 | pre_train = True 273 | 274 | lr = args.lr 275 | input_size = args.input_size 276 | 277 | order = 2  # bilinear pooling is an order-2 tensor product 278 | embedding = args.embedding_dim 
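# (editor's note) main() below follows the common two-stage B-CNN training
# recipe: the fc classifier is first trained with the backbone frozen (the
# init_* arguments and the init_checkpoints folder), and the whole network
# is then fine-tuned end-to-end unless --no_finetune is given.
279 | model_names_list = 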
args.model_names_list 280 | 281 | args.exp_dir = os.path.join(args.dataset, args.exp_dir) 282 | 283 | if args.dataset in ['cars', 'aircrafts', 'mit_indoor']: 284 | keep_aspect = False 285 | else: 286 | keep_aspect = True 287 | 288 | if args.dataset in ['aircrafts']: 289 | crop_from_size = [(x * 256) // 224 for x in input_size] 290 | else: 291 | crop_from_size = input_size 292 | 293 | if 'inat' in args.dataset: 294 | split = {'train': 'train', 'val': 'val'} 295 | else: 296 | split = {'train': 'train_val', 'val': 'test'} 297 | 298 | if len(input_size) > 1: 299 | assert order == len(input_size) 300 | 301 | if not keep_aspect: 302 | input_size = [(x, x) for x in input_size] 303 | crop_from_size = [(x, x) for x in crop_from_size] 304 | 305 | exp_root = '../exp' 306 | checkpoint_folder = os.path.join(exp_root, args.exp_dir, 'checkpoints') 307 | 308 | if not os.path.isdir(checkpoint_folder): 309 | os.makedirs(checkpoint_folder) 310 | 311 | init_checkpoint_folder = os.path.join( 312 | exp_root, args.exp_dir, 'init_checkpoints' 313 | ) 314 | 315 | if not os.path.isdir(init_checkpoint_folder): 316 | os.makedirs(init_checkpoint_folder) 317 | 318 | # log the setup for the experiments 319 | args_dict = vars(args) 320 | with open(os.path.join(exp_root, args.exp_dir, 'args.txt'), 'a') as f: 321 | f.write(json.dumps(args_dict, sort_keys=True, indent=4)) 322 | 323 | # make sure the dataset is ready 324 | if 'inat' in args.dataset: 325 | setup_dataset('inat') 326 | else: 327 | setup_dataset(args.dataset) 328 | 329 | # ================== Create data loader ================================== 330 | data_transforms = { 331 | 'train': [transforms.Compose([ 332 | transforms.Resize(x[0]), 333 | # transforms.CenterCrop(x[1]), 334 | transforms.RandomCrop(x[1]), 335 | transforms.RandomHorizontalFlip(), 336 | transforms.ToTensor(), 337 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) \ 338 | for x in zip(crop_from_size, input_size)], 339 | 'val': [transforms.Compose([ 340 | transforms.Resize(x[0]), 341 | transforms.CenterCrop(x[1]), 342 | transforms.ToTensor(), 343 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) \ 344 | for x in zip(crop_from_size, input_size)], 345 | } 346 | 347 | 348 | if args.dataset == 'cub': 349 | from CUBDataset import CUBDataset as dataset 350 | elif args.dataset == 'cars': 351 | from CarsDataset import CarsDataset as dataset 352 | elif args.dataset == 'aircrafts': 353 | from AircraftsDataset import AircraftsDataset as dataset 354 | elif args.dataset == 'mit_indoor': 355 | from MITIndoorDataset import MITIndoorDataset as dataset 356 | elif 'inat' in args.dataset: 357 | from iNatDataset import iNatDataset as dataset 358 | if args.dataset == 'inat': 359 | subset = None 360 | else: 361 | subset = args.dataset[len('inat_'):] 362 | subset = subset[0].upper() + subset[1:] 363 | else: 364 | raise ValueError('Unknown dataset: %s' % args.dataset) 365 | 366 | if 'inat' in args.dataset: 367 | dset = {x: dataset(dset_root['inat'], split[x], subset, \ 368 | transform=data_transforms[x]) for x in ['train', 'val']} 369 | dset_test = dataset(dset_root['inat'], 'test', subset, \ 370 | transform=data_transforms['val']) 371 | else: 372 | dset = {x: dataset(dset_root[args.dataset], split[x], \ 373 | transform=data_transforms[x]) for x in ['train', 'val']} 374 | dset_test = dataset(dset_root[args.dataset], 'test', \ 375 | transform=data_transforms['val']) 376 | 377 | 378 | dset_loader = {x: torch.utils.data.DataLoader(dset[x], 379 | batch_size=args.batch_size, shuffle=True, 
num_workers=8, 380 | drop_last=drop_last) \ 381 | for x, drop_last in zip(['train', 'val'], [True, False])} 382 | 383 | 384 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 385 | 386 | #======================= Initialize the model ========================= 387 | 388 | # The argument embedding is used only when tensor_sketch is True 389 | # The argument order is used only when the model parameters are shared 390 | # between feature extractors 391 | model = create_bcnn_model(model_names_list, len(dset['train'].classes), 392 | args.pooling_method, fine_tune, pre_train, embedding, order, 393 | m_sqrt_iter=args.matrix_sqrt_iter, 394 | fc_bottleneck=args.fc_bottleneck, proj_dim=args.proj_dim, 395 | update_sketch=args.update_sketch, gamma=args.gamma) 396 | model = model.to(device) 397 | model = torch.nn.DataParallel(model) 398 | 399 | # Set up the loss function 400 | criterion = nn.CrossEntropyLoss() 401 | 402 | #====================== Initialize optimizer ============================== 403 | init_model_checkpoint = os.path.join(init_checkpoint_folder, 404 | 'checkpoint.pth.tar') 405 | start_itr = 0 406 | optim_fc = initialize_optimizer(  # stage 1: train only the fc classifier 407 | model, 408 | args.init_lr, 409 | optimizer='sgd', 410 | wd=args.init_wd, 411 | finetune_model=False, 412 | proj_lr=args.proj_lr, 413 | proj_wd=args.proj_wd, 414 | ) 415 | 416 | logger_name = 'train_init_logger' 417 | logger = initializeLogging(os.path.join(exp_root, args.exp_dir, 418 | 'train_init_history.txt'), logger_name) 419 | 420 | model_train_fc = False 421 | fc_model_path = os.path.join(exp_root, args.exp_dir, 'fc_params.pth.tar') 422 | if not args.train_from_beginning: 423 | if os.path.isfile(fc_model_path): 424 | # load the fc parameters if they are already trained 425 | print("=> loading fc parameters '{}'".format(fc_model_path)) 426 | checkpoint = torch.load(fc_model_path) 427 | model.load_state_dict(checkpoint['state_dict']) 428 | print("=> loaded fc initialization parameters") 429 | else: 430 | if os.path.isfile(init_model_checkpoint): 431 | # load the checkpoint if it exists 432 | print("=> loading checkpoint '{}'".format(init_model_checkpoint)) 433 | checkpoint = torch.load(init_model_checkpoint) 434 | start_itr = checkpoint['itr'] 435 | model.load_state_dict(checkpoint['state_dict']) 436 | optim_fc.load_state_dict(checkpoint['optimizer']) 437 | print("=> loaded checkpoint for the fc initialization") 438 | 439 | # resume training 440 | model_train_fc = True 441 | else: 442 | # Training everything from the beginning 443 | model_train_fc = True 444 | start_itr = 0 445 | 446 | if model_train_fc: 447 | # do the training 448 | if not fine_tune: 449 | model.eval() 450 | 451 | model = train_model(model, dset_loader, criterion, optim_fc, 452 | batch_size_update=256, 453 | epoch=args.init_epoch, logger_name=logger_name, start_itr=start_itr, 454 | checkpoint_folder=init_checkpoint_folder, fine_tune=fine_tune) 455 | shutil.copyfile( 456 | os.path.join(init_checkpoint_folder, 'model_best.pth.tar'), 457 | fc_model_path) 458 | 459 | if fine_tune:  # stage 2: fine-tune the whole network end-to-end 460 | optim = initialize_optimizer(model, args.lr, optimizer=args.optimizer, 461 | wd=args.wd, finetune_model=fine_tune, 462 | beta1=args.beta1, beta2=args.beta2) 463 | 464 | scheduler = torch.optim.lr_scheduler.LambdaLR(optim, 465 | lr_lambda=lambda epoch: 0.1 ** (epoch // 25))  # decay lr by 10x every 25 epochs 466 | 467 | logger_name = 'train_logger' 468 | logger = initializeLogging(os.path.join(exp_root, args.exp_dir, 469 | 'train_history.txt'), logger_name) 470 | 471 | start_itr = 0 472 | # load from checkpoint if one exists 473 | 
if not args.train_from_beginning: 474 | checkpoint_filename = os.path.join(checkpoint_folder, 475 | 'checkpoint.pth.tar') 476 | if os.path.isfile(checkpoint_filename): 477 | print("=> loading checkpoint '{}'".format(checkpoint_filename)) 478 | checkpoint = torch.load(checkpoint_filename) 479 | start_itr = checkpoint['itr'] 480 | model.load_state_dict(checkpoint['state_dict']) 481 | optim.load_state_dict(checkpoint['optimizer']) 482 | scheduler.load_state_dict(checkpoint['scheduler']) 483 | print("=> loaded checkpoint '{}' (iteration {})" 484 | .format(checkpoint_filename, checkpoint['itr'])) 485 | 486 | # parallelize the model if using multiple gpus 487 | # if torch.cuda.device_count() > 1: 488 | 489 | # Train the model 490 | model = train_model(model, dset_loader, criterion, optim, 491 | batch_size_update=args.batch_size_update_model, 492 | # maxItr=args.iteration, logger_name=logger_name, 493 | epoch=args.epoch, logger_name=logger_name, 494 | checkpoint_folder=checkpoint_folder, 495 | start_itr=start_itr, scheduler=scheduler) 496 | 497 | 498 | if __name__ == '__main__': 499 | parser = argparse.ArgumentParser() 500 | parser.add_argument('--batch_size_update_model', default=128, type=int, 501 | help='the optimizer updates the model after seeing this number \ 502 | of inputs') 503 | parser.add_argument('--batch_size', default=32, type=int, 504 | help='size of mini-batch that can fit into gpus (sub-batch size)') 505 | parser.add_argument('--epoch', default=45, type=int, 506 | help='number of epochs') 507 | parser.add_argument('--init_epoch', default=25, type=int, 508 | help='number of epochs for initializing fc layer') 509 | parser.add_argument('--init_lr', default=1.0, type=float, 510 | help='learning rate') 511 | parser.add_argument('--lr', default=1e-4, type=float, 512 | help='learning rate') 513 | parser.add_argument('--wd', default=1e-5, type=float, 514 | help='weight decay') 515 | parser.add_argument('--init_wd', default=1e-8, type=float, 516 | help='weight decay for initializing fc layer') 517 | parser.add_argument('--optimizer', default='adam', type=str, 518 | help='optimizer sgd|adam') 519 | parser.add_argument('--exp_dir', default='exp', type=str, 520 | help='folder name where to save the results for the experiment') 521 | parser.add_argument('--train_from_beginning', action='store_true', 522 | help='train the model from the first epoch, i.e. ignore the checkpoint') 523 | parser.add_argument('--dataset', default='cub', type=str, 524 | help='cub | cars | aircrafts') 525 | parser.add_argument('--input_size', nargs='+', default=[448], type=int, 526 | help='input size as a list of sizes') 527 | parser.add_argument('--model_names_list', nargs='+', default=['vgg'], 528 | type=str, help='list of backbone CNN architectures') 529 | parser.add_argument('--pooling_method', default='outer_product', type=str, 530 | help='outer_product | sketch | gamma_demo | sketch_gamma_demo') 531 | parser.add_argument('--embedding_dim', type=int, default=8192, 532 | help='the dimension for the tensor sketch approximation') 533 | parser.add_argument('--matrix_sqrt_iter', type=int, default=0, 534 | help='number of iterations of the Newton method approximating' + \ 535 | ' the matrix square root. 
Default=0 [no matrix square root]') 536 | parser.add_argument('--fc_bottleneck', action='store_true', 537 | help='add a bottleneck to the fc layers') 538 | parser.add_argument('--proj_dim', type=int, default=0, 539 | help='project the dimension of cnn features to lower ' + \ 540 | 'dimensionality before computing tensor product') 541 | parser.add_argument('--proj_lr', default=1e-3, type=float, 542 | help='learning rate') 543 | parser.add_argument('--proj_wd', default=1e-5, type=float, 544 | help='weight decay') 545 | parser.add_argument('--update_sketch', action='store_true', 546 | help='update the tensor sketch parameters during training') 547 | parser.add_argument('--gamma', default=0.5, type=float, 548 | help='the value of gamma for gamma democratic aggregation') 549 | parser.add_argument('--beta1', default=0.99, type=float, 550 | help='the value of beta1 for adam') 551 | parser.add_argument('--beta2', default=0.999, type=float, 552 | help='the value of beta2 for adam') 553 | parser.add_argument('--no_finetune', action='store_true', 554 | help='skip fine-tuning, i.e. train only the fc classifier') 555 | 556 | args = parser.parse_args() 557 | 558 | main(args) 559 | 560 | -------------------------------------------------------------------------------- /train_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import torchvision 6 | from torchvision import datasets, models, transforms 7 | import os 8 | from config import dset_root, setup_dataset 9 | import random 10 | import argparse 11 | import copy 12 | import logging 13 | import sys 14 | import time 15 | import shutil 16 | from CNN import create_cnn_model 17 | from test import test_model 18 | from plot_curve import plot_log 19 | import json 20 | 21 | def initializeLogging(log_filename, logger_name): 22 | log = logging.getLogger(logger_name) 23 | log.setLevel(logging.DEBUG) 24 | log.addHandler(logging.StreamHandler(sys.stdout)) 25 | log.addHandler(logging.FileHandler(log_filename, mode='a')) 26 | 27 | return log 28 | 29 | def save_checkpoint(state, is_best, checkpoint_folder='exp', 30 | filename='checkpoint.pth.tar'): 31 | filename = os.path.join(checkpoint_folder, filename) 32 | best_model_filename = os.path.join(checkpoint_folder, 'model_best.pth.tar') 33 | torch.save(state, filename) 34 | if is_best: 35 | shutil.copyfile(filename, best_model_filename) 36 | 37 | # def initialize_optimizer(model_ft, lr, optimizer='sgd', finetune_model=True): 38 | def initialize_optimizer(model_ft, lr, optimizer='sgd', wd=0, finetune_model=True, 39 | proj_lr=1e-3, proj_wd=1e-5, beta1=0.9, beta2=0.999): 40 | fc_params_to_update = [] 41 | params_to_update = [] 42 | if finetune_model: 43 | for name, param in model_ft.named_parameters(): 44 | # if name == 'module.fc.bias' or name == 'module.fc.weight': 45 | if 'module.fc' in name: 46 | fc_params_to_update.append(param) 47 | else: 48 | params_to_update.append(param) 49 | param.requires_grad = True 50 | 51 | # Observe that all parameters are being optimized 52 | if optimizer == 'sgd': 53 | ''' 54 | optimizer_ft = optim.SGD([ 55 | {'params': params_to_update}, 56 | {'params': fc_params_to_update, 'weight_decay': 1e-5, 'lr': 1e-2}], 57 | lr=lr, momentum=0.9, weight_decay=wd) 58 | ''' 59 | optimizer_ft = optim.SGD([ 60 | {'params': params_to_update}, 61 | {'params': fc_params_to_update}], 62 | lr=lr, momentum=0.9, weight_decay=wd) 63 | elif optimizer == 'adam': 64 | optimizer_ft = optim.Adam([ 65 | {'params': params_to_update}, 66 
| {'params': fc_params_to_update}], 67 | lr=lr, weight_decay=wd, 68 | betas=(beta1, beta2)) 69 | else: 70 | raise ValueError('Unknown optimizer: %s' % optimizer) 71 | else: 72 | for name, param in model_ft.named_parameters(): 73 | # if name == 'module.fc.bias' or name == 'module.fc.weight': 74 | if 'module.fc' in name: 75 | param.requires_grad = True 76 | fc_params_to_update.append(param) 77 | else: 78 | param.requires_grad = False 79 | 80 | # Only the fc parameters are optimized 81 | if optimizer == 'sgd': 82 | optimizer_ft = optim.SGD(fc_params_to_update, lr=lr, momentum=0.9, 83 | weight_decay=wd) 84 | elif optimizer == 'adam': 85 | optimizer_ft = optim.Adam(fc_params_to_update, lr=lr, weight_decay=wd, 86 | betas=(beta1, beta2)) 87 | else: 88 | raise ValueError('Unknown optimizer: %s' % optimizer) 89 | 90 | return optimizer_ft 91 | 92 | def train_model(model, dset_loader, criterion, 93 | optimizer, batch_size_update=256, 94 | # maxItr=50000, logger_name='train_logger', checkpoint_folder='exp', 95 | epoch=45, logger_name='train_logger', checkpoint_folder='exp', 96 | start_itr=0, clip_grad=-1, scheduler=None, fine_tune=True): 97 | 98 | maxItr = epoch * len(dset_loader['train'].dataset) // \ 99 | dset_loader['train'].batch_size + 1 100 | 101 | val_every_number_examples = max(10000, 102 | len(dset_loader['train'].dataset) // 5) 103 | val_frequency = val_every_number_examples // dset_loader['train'].batch_size 104 | checkpoint_frequency = 5 * len(dset_loader['train'].dataset) // \ 105 | dset_loader['train'].batch_size 106 | last_checkpoint = start_itr - 1 107 | # val_frequency = 10000 // dset_loader['train'].batch_size 108 | logger = logging.getLogger(logger_name) 109 | logger_filename = logger.handlers[1].stream.name 110 | 111 | device = next(model.parameters()).device 112 | since = time.time() 113 | 114 | running_loss = 0.0; running_num_data = 0 115 | running_corrects = 0 116 | val_loss_history = []; best_acc = 0.0 117 | val_acc = 0.0 118 | # best_model_wts = copy.deepcopy(model.state_dict()) 119 | 120 | dset_iter = {x: iter(dset_loader[x]) for x in ['train', 'val']} 121 | bs = dset_loader['train'].batch_size 122 | update_frequency = batch_size_update // bs 123 | 124 | if fine_tune: 125 | model.train() 126 | else: 127 | model.module.fc.train() 128 | 129 | last_epoch = 0 130 | for itr in range(start_itr, maxItr): 131 | # at the end of validation, model.train() is restored below 132 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 133 | logger.info('Iteration {}/{}'.format(itr, maxItr - 1)) 134 | logger.info('-' * 10) 135 | 136 | try: 137 | all_fields = next(dset_iter['train']) 138 | labels = all_fields[-2] 139 | inputs = all_fields[:-2] 140 | # inputs, labels, _ = next(dset_iter['train']) 141 | except StopIteration: 142 | dset_iter['train'] = iter(dset_loader['train']) 143 | all_fields = next(dset_iter['train']) 144 | labels = all_fields[-2] 145 | inputs = all_fields[:-2] 146 | # inputs, labels, _ = next(dset_iter['train']) 147 | 148 | inputs = inputs[0].to(device) 149 | # inputs = [x.to(device) for x in inputs] 150 | labels = labels.to(device) 151 | 152 | with torch.set_grad_enabled(True): 153 | ''' 154 | torch.cuda.synchronize() 155 | torch.cuda.synchronize() 156 | ta = time.perf_counter() 157 | ''' 158 | outputs = model(inputs) 159 | # outputs = model(*inputs) 160 | loss = criterion(outputs, labels) 161 | 162 | _, preds = torch.max(outputs, 1) 163 | 164 | loss.backward() 165 | ''' 166 | torch.cuda.synchronize() 167 | tb = time.perf_counter() 168 | print('time: {:.02e}s'.format((tb - 
ta)/outputs.shape[0])) 169 | ''' 170 | 171 | if (itr + 1) % update_frequency == 0: 172 | if clip_grad > 0: 173 | torch.nn.utils.clip_grad_norm_(model.parameters(), 174 | clip_grad) 175 | optimizer.step() 176 | optimizer.zero_grad() 177 | 178 | epoch = ((itr + 1) * bs) // len(dset_loader['train'].dataset) 179 | 180 | running_num_data += inputs.size(0) 181 | running_loss += loss.item() * inputs.size(0) 182 | running_corrects += torch.sum(preds == labels.data) 183 | 184 | # evaluate the current model on val 185 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 186 | running_loss = running_loss / running_num_data 187 | running_acc = running_corrects.double() / running_num_data 188 | # print('{} Loss: {:.4f} Acc: {:.4f}'.format('Train', 189 | # running_loss, running_acc)) 190 | logger.info('{} Loss: {:.4f} Acc: {:.4f}'.format( \ 191 | 'Train', running_loss, running_acc)) 192 | running_loss = 0.0; running_num_data = 0; running_corrects = 0 193 | 194 | model.eval() 195 | val_running_loss = 0.0; val_running_corrects = 0 196 | 197 | # for inputs, labels, _ in dset_loader['val']: 198 | for all_fields in dset_loader['val']: 199 | labels = all_fields[-2] 200 | inputs = all_fields[:-2] 201 | 202 | inputs = inputs[0].to(device) 203 | # inputs = [x.to(device) for x in inputs] 204 | labels = labels.to(device) 205 | 206 | with torch.set_grad_enabled(False): 207 | outputs = model(inputs) 208 | # outputs = model(*inputs) 209 | loss = criterion(outputs, labels) 210 | 211 | _, preds = torch.max(outputs, 1) 212 | 213 | # val_running_loss += loss.item() * inputs[0].size(0) 214 | val_running_loss += loss.item() * inputs.size(0) 215 | val_running_corrects += torch.sum(preds == labels.data) 216 | val_loss = val_running_loss / len(dset_loader['val'].dataset) 217 | val_acc = val_running_corrects.double() / len(dset_loader['val'].dataset) 218 | # print('{} Loss: {:.4f} Acc: {:.4f}'.format('Validation', 219 | # val_loss, val_acc)) 220 | logger.info('{} Loss: {:.4f} Acc: {:.4f}'.format( \ 221 | 'Validation', val_loss, val_acc)) 222 | 223 | plot_log(logger_filename, 224 | logger_filename.replace('history.txt', 'curve.png'), True) 225 | 226 | if fine_tune: 227 | model.train() 228 | else: 229 | model.module.fc.train() 230 | 231 | # update scheduler 232 | if scheduler is not None: 233 | if isinstance(scheduler, \ 234 | torch.optim.lr_scheduler.ReduceLROnPlateau): 235 | if (itr + 1) % val_frequency == 0: 236 | scheduler.step(val_acc) 237 | else: 238 | if epoch > last_epoch and scheduler is not None: 239 | last_epoch = epoch 240 | scheduler.step() 241 | # checkpoint 242 | if (itr + 1) % val_frequency == 0 or itr == maxItr - 1: 243 | is_best = val_acc > best_acc 244 | if is_best: 245 | best_acc = val_acc 246 | # best_model_wts = copy.deepcopy(model.state_dict()) 247 | 248 | 249 | do_checkpoint = (itr - last_checkpoint) >= checkpoint_frequency 250 | if is_best or itr == maxItr - 1 or do_checkpoint: 251 | last_checkpoint = itr 252 | checkpoint_dict = { 253 | 'itr': itr + 1, 254 | 'state_dict': model.state_dict(), 255 | 'optimizer': optimizer.state_dict(), 256 | 'best_acc': best_acc 257 | } 258 | if scheduler is not None: 259 | checkpoint_dict['scheduler'] = scheduler.state_dict() 260 | save_checkpoint(checkpoint_dict, 261 | is_best, checkpoint_folder=checkpoint_folder) 262 | 263 | 264 | time_elapsed = time.time() - since 265 | logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 266 | logger.info('Best val accuracy: {:.4f}'.format(best_acc)) 267 | 268 | # load best model 
weights 269 | best_model_wts = torch.load(os.path.join(checkpoint_folder, 270 | 'model_best.pth.tar')) 271 | model.load_state_dict(best_model_wts['state_dict']) 272 | # model.load_state_dict(best_model_wts) 273 | 274 | # return model, val_acc_history 275 | return model 276 | 277 | def main(args): 278 | fine_tune = not args.no_finetune 279 | pre_train = True 280 | 281 | lr = args.lr 282 | input_size = args.input_size 283 | # input_size = [448] 284 | # keep_aspect = True 285 | # model_names_list = ['vgg'] 286 | # tensor_sketch = False 287 | # embedding = 8192 288 | 289 | model_names = args.model_names 290 | 291 | args.exp_dir = os.path.join(args.dataset, args.exp_dir) 292 | 293 | if args.dataset in ['cars', 'aircrafts']: 294 | keep_aspect = False 295 | else: 296 | keep_aspect = True 297 | 298 | if args.dataset in ['aircrafts']: 299 | crop_from_size = [(x * 256) // 224 for x in input_size] 300 | else: 301 | crop_from_size = input_size 302 | 303 | if 'inat' in args.dataset: 304 | split = {'train': 'train', 'val': 'val'} 305 | else: 306 | split = {'train': 'train_val', 'val': 'test'} 307 | 308 | if not keep_aspect: 309 | input_size = [(x, x) for x in input_size] 310 | crop_from_size = [(x, x) for x in crop_from_size] 311 | 312 | exp_root = '../exp' 313 | checkpoint_folder = os.path.join(exp_root, args.exp_dir, 'checkpoints') 314 | if not os.path.isdir(checkpoint_folder): 315 | os.makedirs(checkpoint_folder) 316 | 317 | args_dict = vars(args) 318 | with open(os.path.join(exp_root, args.exp_dir, 'args.txt'), 'a') as f: 319 | f.write(json.dumps(args_dict, sort_keys=True, indent=4)) 320 | 321 | # make sure the dataset is ready 322 | if 'inat' in args.dataset: 323 | setup_dataset('inat') 324 | else: 325 | setup_dataset(args.dataset) 326 | 327 | # ================== Create data loader ================================== 328 | data_transforms = { 329 | 'train': [transforms.Compose([ 330 | transforms.Resize(x[0]), 331 | # transforms.CenterCrop(x[1]), 332 | transforms.RandomCrop(x[1]), 333 | transforms.RandomHorizontalFlip(), 334 | transforms.ToTensor(), 335 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) \ 336 | for x in zip(crop_from_size, input_size)], 337 | 'val': [transforms.Compose([ 338 | transforms.Resize(x[0]), 339 | transforms.CenterCrop(x[1]), 340 | transforms.ToTensor(), 341 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) \ 342 | for x in zip(crop_from_size, input_size)], 343 | } 344 | 345 | if args.dataset == 'cub': 346 | from CUBDataset import CUBDataset as dataset 347 | elif args.dataset == 'cars': 348 | from CarsDataset import CarsDataset as dataset 349 | elif args.dataset == 'aircrafts': 350 | from AircraftsDataset import AircraftsDataset as dataset 351 | elif 'inat' in args.dataset: 352 | from iNatDataset import iNatDataset as dataset 353 | if args.dataset == 'inat': 354 | subset = None 355 | else: 356 | subset = args.dataset[len('inat_'):] 357 | subset = subset[0].upper() + subset[1:] 358 | else: 359 | raise ValueError('Unknown dataset: %s' % args.dataset) 360 | 361 | if 'inat' in args.dataset: 362 | dset = {x: dataset(dset_root['inat'], split[x], subset, \ 363 | transform=data_transforms[x]) for x in ['train', 'val']} 364 | dset_test = dataset(dset_root['inat'], 'test', subset, \ 365 | transform=data_transforms['val']) 366 | else: 367 | dset = {x: dataset(dset_root[args.dataset], split[x], 368 | transform=data_transforms[x]) for x in ['train', 'val']} 369 | dset_test = dataset(dset_root[args.dataset], 'test', 370 | transform=data_transforms['val']) 
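    # (editor's note) each dataset's __getitem__ returns (*inputs, target, path),
    # so the loaders below yield batches whose last two fields are the labels
    # and the sample paths; this is why train_model() and test_model() unpack
    # labels as all_fields[-2] and inputs as all_fields[:-2].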
371 | 372 | dset_loader = {x: torch.utils.data.DataLoader(dset[x], 373 | batch_size=args.batch_size, shuffle=True, 374 | num_workers=4, drop_last=drop_last) \ 375 | for x, drop_last in zip(['train', 'val'], [True, False])} 376 | 377 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 378 | 379 | #======================= Initialize the model ========================= 380 | 381 | # create_cnn_model builds a single plain CNN classifier; unlike 382 | # create_bcnn_model there is no pooling method, order, or embedding 383 | # argument here 384 | model = create_cnn_model(model_names, len(dset['train'].classes), 385 | input_size[0], fine_tune, pre_train) 386 | model = model.to(device) 387 | model = torch.nn.DataParallel(model) 388 | 389 | # Set up the loss function 390 | criterion = nn.CrossEntropyLoss() 391 | 392 | #====================== Initialize optimizer ============================== 393 | start_itr = 0 394 | 395 | optim = initialize_optimizer(model, args.lr, optimizer=args.optimizer, 396 | wd=args.wd, finetune_model=fine_tune) 397 | 398 | if 'inat' not in args.dataset: 399 | scheduler = torch.optim.lr_scheduler.LambdaLR(optim, 400 | lr_lambda=lambda epoch: 0.1 ** (epoch // 25))  # decay lr by 10x every 25 epochs 401 | else: 402 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, 'max') 403 | logger_name = 'train_logger' 404 | logger = initializeLogging(os.path.join(exp_root, args.exp_dir, 405 | 'train_history.txt'), logger_name) 406 | 407 | start_itr = 0 408 | # load from checkpoint if one exists 409 | if not args.train_from_beginning: 410 | checkpoint_filename = os.path.join(checkpoint_folder, 411 | 'checkpoint.pth.tar') 412 | if os.path.isfile(checkpoint_filename): 413 | print("=> loading checkpoint '{}'".format(checkpoint_filename)) 414 | checkpoint = torch.load(checkpoint_filename) 415 | start_itr = checkpoint['itr'] 416 | model.load_state_dict(checkpoint['state_dict']) 417 | optim.load_state_dict(checkpoint['optimizer']) 418 | scheduler.load_state_dict(checkpoint['scheduler']) 419 | print("=> loaded checkpoint '{}' (iteration {})" 420 | .format(checkpoint_filename, checkpoint['itr'])) 421 | 422 | 423 | # parallelize the model if using multiple gpus 424 | # if torch.cuda.device_count() > 1: 425 | 426 | if fine_tune: 427 | model.train() 428 | else: 429 | # fix the batchnorm statistics with model.eval() 430 | model.eval() 431 | model.module.fc.train() 432 | 433 | # Train the model 434 | model = train_model(model, dset_loader, criterion, optim, 435 | batch_size_update=args.batch_size_update_model, 436 | # maxItr=args.iteration, logger_name=logger_name, 437 | epoch=args.epoch, logger_name=logger_name, 438 | checkpoint_folder=checkpoint_folder, 439 | start_itr=start_itr, scheduler=scheduler, 440 | fine_tune=fine_tune) 441 | 442 | if 'inat' not in args.dataset: 443 | # do test 444 | test_loader = torch.utils.data.DataLoader(dset_test, 445 | batch_size=args.batch_size, shuffle=False, 446 | num_workers=8, drop_last=False) 447 | print('evaluating test data') 448 | test_model(model, criterion, test_loader, logger_name) 449 | 450 | 451 | 452 | if __name__ == '__main__': 453 | parser = argparse.ArgumentParser() 454 | parser.add_argument('--batch_size_update_model', default=128, type=int, 455 | help='the optimizer updates the model after seeing this number \ 456 | of inputs') 457 | parser.add_argument('--batch_size', default=32, type=int, 458 | help='size of mini-batch that can fit into gpus (sub-batch size)') 459 | parser.add_argument('--epoch', default=45, type=int, 460 | help='number of epochs') 461 
| parser.add_argument('--init_epoch', default=55, type=int, 462 | help='number of epochs for initializing fc layer') 463 | # parser.add_argument('--iteration', default=20000, type=int, 464 | # help='number of iterations') 465 | parser.add_argument('--init_lr', default=1.0, type=float, 466 | help='learning rate') 467 | parser.add_argument('--lr', default=1e-4, type=float, 468 | help='learning rate') 469 | parser.add_argument('--wd', default=1e-5, type=float, 470 | help='weight decay') 471 | parser.add_argument('--init_wd', default=1e-8, type=float, 472 | help='weight decay for initializing fc layer') 473 | parser.add_argument('--optimizer', default='adam', type=str, 474 | help='optimizer sgd|adam') 475 | parser.add_argument('--exp_dir', default='exp', type=str, 476 | help='folder name where to save the results for the experiment') 477 | parser.add_argument('--train_from_beginning', action='store_true', 478 | help='train the model from the first epoch, i.e. ignore the checkpoint') 479 | # parser.add_argument('--train_split', default='train_val', type=str, 480 | # help='split used to train augmentor') 481 | parser.add_argument('--dataset', default='cub', type=str, 482 | help='cub | cars | aircrafts') 483 | parser.add_argument('--input_size', nargs='+', default=[448], type=int, 484 | help='input size as a list of sizes') 485 | parser.add_argument('--model_names', default='vgg', 486 | type=str, help='backbone CNN architecture') 487 | parser.add_argument('--fc_bottleneck', action='store_true', 488 | help='add a bottleneck to the fc layers') 489 | parser.add_argument('--beta1', default=0.99, type=float, 490 | help='the value of beta1 for adam') 491 | parser.add_argument('--beta2', default=0.999, type=float, 492 | help='the value of beta2 for adam') 493 | parser.add_argument('--no_finetune', action='store_true', 494 | help='skip fine-tuning, i.e. train only the fc classifier') 495 | args = parser.parse_args() 496 | 497 | main(args) 498 | 499 | --------------------------------------------------------------------------------
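Editor's note: the following minimal sketch is not part of the repository; it sanity-checks the MatrixSquareRoot module defined in matrixSquareRoot.py above. The batch size, matrix dimension, and iteration count are arbitrary assumptions.

import torch
from matrixSquareRoot import MatrixSquareRoot

# Build a batch of symmetric positive-definite matrices; after the module's
# internal Frobenius normalization, the coupled Newton-Schulz iteration
# converges for such inputs.
msqrt = MatrixSquareRoot(numIter=10, dim=64)
A = torch.randn(4, 64, 64)
A = A.bmm(A.transpose(1, 2)) + 1e-2 * torch.eye(64)

sA = msqrt(A)  # approximate principal square roots, shape (4, 64, 64)
err = (sA.bmm(sA) - A).norm() / A.norm()
print('relative reconstruction error: {:.2e}'.format(err.item()))

A small relative error indicates the forward iteration has converged; increasing numIter tightens the approximation.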