├── .gitignore ├── .vscode └── settings.json ├── README.md ├── evaluation.py ├── extract_features.py ├── main.py ├── misc ├── custom_loss.py ├── transforms.py └── utils.py ├── models ├── alexnet.py ├── resnet101.py ├── resnet50.py └── vgg11_bn.py ├── options.py ├── run_image_based.sh └── run_sketch_based.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/usr/bin/python" 3 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Triplet Center Loss (TCL) Solution for SHREC2018 IBR & SBR 2 | This repo holds the code for our method _Triple Center Loss(TCL)_ at the [SHREC 2018](http://www2.projects.science.uu.nl/shrec/index2018-cfparticipation.html) challenges: 3 | - [Sketch-Based 3D Scene Retrieval(SBR)](http://orca.st.usm.edu/~bli/SceneSBR2018/) 4 | - [Image-Based 3D Scene Retrieval(IBR)](http://orca.st.usm.edu/~bli/SceneIBR2018/). 5 | 6 | The TCL method is based on the CVPR 2018 work 7 | > Xinwei He, Yang Zhou, Zhichao Zhou, Song Bai, Xiang Bai. **Triplet-Center Loss for Multi-View 3D Object Retrieval**. _CVPR 2018_. [[pdf]](http://openaccess.thecvf.com/content_cvpr_2018/papers/He_Triplet-Center_Loss_for_CVPR_2018_paper.pdf) 8 | 9 | For a detailed description of our submitted method, you can also refer to the technical report and find our method scription in Sec 4.2. 10 | > SHREC’18 Track: 2D Image-Based 3D Scene Retrieval. 
[[pdf]](http://orca.st.usm.edu/~bli/SceneIBR2018/SHREC18_Track_2D_Scene_Image-Based_3D_Scene_Retrieval.pdf) 11 | 12 | > SHREC’18 Track: 2D Scene Sketch-Based 3D Scene Retrieval. [[pdf]](http://orca.st.usm.edu/~bli/SceneSBR2018/SHREC18_Track_2D_Scene_Sketch-Based_3D_Scene_Retrieval.pdf) 13 | 14 | ## Prerequisites 15 | Our code has been tested with Python2 + PyTorch 0.3. It should work with higher versions after minor modifications. 16 | 17 | ## Usage 18 | The code works for both the IBR and the SBR tasks. 19 | 20 | For the IBR task, run the following command 21 | ``` 22 | bash run_image_based.sh $gpu $backbone $action $suffix $output_dir 23 | ``` 24 | Arguments: 25 | - gpu: the GPU ID 26 | - backbone: the backbone architecture to use. e.g. vgg11_bn, resnet50. 27 | - action: 1 or 2. choose **1 for training** and **2 for inference**. 28 | - output_dir: where to save the outputs 29 | 30 | For the SBR task, use the script `run_sketch_based.sh` instead. The arguments are the same. 31 | 32 | ## Related Projects 33 | - [cvpr_2018_TCL.pytorch](https://github.com/eriche2016/cvpr_2018_TCL.pytorch) 34 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics.pairwise import euclidean_distances 3 | import torch 4 | from rank_metrics import mean_average_precision 5 | # import ipdb 6 | def map_and_auc(label_q, label_d, d): 7 | rs = convert_rank_gt(label_q, label_d, d) 8 | trec_precisions = [] 9 | mrecs = [] 10 | mpres = [] 11 | aps = [] 12 | for i, r in enumerate(rs): 13 | #ipdb.set_trace() 14 | res = precision_and_recall(rs[i]) 15 | trec_precisions.append(res[0]) 16 | mrecs.append(res[1]) 17 | mpres.append(res[2]) 18 | aps.append(res[3]) 19 | 20 | trec_precisions = np.stack(trec_precisions) 21 | mrecs = np.stack(mrecs) 22 | mpres = np.stack(mpres) 23 | aps = np.stack(aps) 24 | AUC = np.mean(aps) 25 | mAP = np.mean(trec_precisions) 26 | return AUC, mAP 27 | 28 | def compute_map(label_q, label_d, d): 29 | rs = convert_rank_gt(label_q, label_d, d) 30 | return mean_average_precision(rs) 31 | 32 | def convert_rank_gt(label_q, label_d, d): 33 | idx = d.argsort(axis=1) 34 | label_q.resize(label_q.size, 1) 35 | label_d.resize(1, label_d.size) 36 | gt = (label_q == label_d) 37 | rs = [gt[i][idx[i]] for i in range(gt.shape[0])] # rank ground truth 38 | return rs 39 | 40 | 41 | def precision_and_recall(r): 42 | num_gt = np.sum(r) 43 | trec_precision = np.array([np.mean(r[:i+1]) for i in range(r.size) if r[i]]) 44 | recall = [np.sum(r[:i+1])/num_gt for i in range(r.size)] 45 | precision = [np.mean(r[:i+1]) for i in range(r.size)] 46 | 47 | # interpolate it 48 | mrec = np.array([0.] + recall + [1.]) 49 | mpre = np.array([0.] 
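# The loop below replaces mpre with its monotonically non-increasing envelope
# (each precision becomes the max of all precisions at equal or higher recall);
# ap is then the area under this interpolated PR curve, accumulated at the
# recall levels where mrec changes.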
+ precision + [0.]) 50 | 51 | for i in range(len(mpre)-2, -1, -1): 52 | mpre[i] = max(mpre[i], mpre[i+1]) 53 | 54 | i = np.where(mrec[1:] != mrec[:-1])[0]+1 55 | ap = np.sum((mrec[i]-mrec[i-1]) * mpre[i]) 56 | return trec_precision, mrec, mpre, ap 57 | 58 | 59 | def plot_pr_cure(mpres, mrecs): 60 | pr_curve = np.zeros(mpres.shape[0], 10) 61 | for r in range(mpres.shape[0]): 62 | this_mprec = mpres[r] 63 | for c in range(10): 64 | pr_curve[r, c] = np.max(this_mprec[mrecs[r]>(c-1)*0.1]) 65 | return pr_curve 66 | 67 | def l2_normalize(features): 68 | # features: num * ndim 69 | features_c = features.copy() 70 | features_c /= np.sqrt((features_c * features_c).sum(axis=1))[:, None] 71 | return features_c 72 | 73 | def compute_distance(x, y, l2=True): 74 | if l2: 75 | x = l2_normalize(x) 76 | y = l2_normalize(y) 77 | return euclidean_distances(x, y) 78 | 79 | 80 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | ## code on 1030 3 | from __future__ import print_function, absolute_import 4 | import torch 5 | # import torchvision.models as models 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.autograd import Variable 9 | from torch.backends import cudnn 10 | import torch.optim as optim 11 | 12 | import os 13 | import argparse 14 | import sys 15 | import time 16 | import numpy as np 17 | import scipy.io as scio 18 | 19 | # defined by zczhou 20 | # from models.sketch2shape import sketch_net, shape_net 21 | import misc.custom_loss as custom_loss 22 | 23 | import dataset.sk_dataset as sk_dataset 24 | import dataset.sh_views_dataset as sh_views_dataset 25 | 26 | import misc.transforms as T 27 | # from misc.utils import Logger 28 | import models 29 | 30 | from evaluation import map_and_auc, compute_distance, compute_map 31 | 32 | import misc.utils as utils 33 | # from sampler import RandomIdentitySampler 34 | 35 | from IPython.core.debugger import Tracer 36 | debug_here = Tracer() 37 | 38 | 39 | def get_test_data(train_shape_views_folder, test_shape_views_folder, train_shape_flist, test_shape_flist, 40 | train_sketch_folder, test_sketch_folder, train_sketch_flist, test_sketch_flist, 41 | height, width, batch_size, workers, pk_flag=False): 42 | 43 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | 46 | train_transformer = T.Compose([ 47 | # T.RandomSizedRectCrop(height, width), 48 | T.RectScale(height, width), 49 | T.ToTensor(), 50 | normalizer, 51 | ]) 52 | 53 | test_transformer = T.Compose([ 54 | T.RectScale(height, width), 55 | T.ToTensor(), 56 | normalizer, 57 | ]) 58 | 59 | # define sketch dataset 60 | extname = '.png' if 'image' not in train_sketch_flist else '.JPEG' 61 | sketch_train_data = sk_dataset.Sk_Dataset(train_sketch_folder, train_sketch_flist, transform=train_transformer, ext=extname) 62 | sketch_test_data = sk_dataset.Sk_Dataset(test_sketch_folder, test_sketch_flist, transform=test_transformer, ext=extname) 63 | 64 | # define shape views dataset 65 | shape_train_data = sh_views_dataset.Sh_Views_Dataset(train_shape_views_folder, train_shape_flist, transform=train_transformer) 66 | shape_test_data = sh_views_dataset.Sh_Views_Dataset(test_shape_views_folder, test_shape_flist, transform=test_transformer) 67 | 68 | # num_classes = sketch_train_data.num_classes 69 | 70 | 71 | 72 | train_sketch_loader = DataLoader( 73 | sketch_train_data, 74 | 
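# Note: the sketch loader uses batch_size*2 while the shape loader below uses
# batch_size, mirroring the ratio of the training loaders built in main.py.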
batch_size=batch_size*2, num_workers=workers, 75 | shuffle=False, pin_memory=True) 76 | 77 | train_shape_loader = DataLoader( 78 | shape_train_data, 79 | batch_size=batch_size, num_workers=workers, 80 | shuffle=False, pin_memory=True) 81 | 82 | 83 | test_sketch_loader = DataLoader( 84 | sketch_test_data, 85 | batch_size=batch_size, num_workers=workers, 86 | shuffle=False, pin_memory=True) 87 | 88 | test_shape_loader = DataLoader( 89 | shape_test_data, 90 | batch_size=batch_size, num_workers=workers, 91 | shuffle=False, pin_memory=True) 92 | 93 | 94 | # sketch_weight = utils.get_weight(sketch_train_data.imgs) 95 | #��shape_weight = utils.get_weight(shape_train_data.imgs) 96 | # cls_weight = sketch_weight / (train_sketch_loader.batch_size*1.0 / train_shape_loader.batch_size) + shape_weight 97 | # cls_weight = cls_weight / cls_weight.sum() * cls_weight.size 98 | # cls_weight = torch.Tensor(cls_weight) 99 | 100 | return train_sketch_loader, train_shape_loader, test_sketch_loader, test_shape_loader # , cls_weight 101 | 102 | 103 | def main(opt): 104 | np.random.seed(opt.seed) 105 | torch.manual_seed(opt.seed) 106 | cudnn.benchmark = True 107 | 108 | # make output directory 109 | opt.checkpoint_folder += '_'+opt.backbone 110 | if opt.sketch_finetune: 111 | opt.checkpoint_folder += '_finetune' 112 | if not os.path.exists(opt.checkpoint_folder): 113 | os.makedirs(opt.checkpoint_folder) 114 | 115 | print(opt) 116 | # Redirect print to both console and log file 117 | # if not opt.evaluate: 118 | # sys.stdout = Logger(os.path.join(opt.logs_dir, opt.log_name)) 119 | 120 | # Create data loaders 121 | if opt.height is None or opt.width is None: 122 | opt.height, opt.width = (224, 224) 123 | 124 | train_sketch_loader, train_shape_loader, test_sketch_loader, test_shape_loader = get_test_data(opt.train_shape_views_folder, 125 | opt.test_shape_views_folder, opt.train_shape_flist, opt.test_shape_flist, 126 | opt.train_sketch_folder, opt.test_sketch_folder, opt.train_sketch_flist, opt.test_sketch_flist, 127 | opt.height, opt.width, opt.batch_size, opt.workers, pk_flag=False) 128 | 129 | # Create model 130 | 131 | kwargs = {'pool_idx': opt.pool_idx} if opt.pool_idx is not None else {} 132 | backbone = eval('models.'+opt.backbone) 133 | net_bp = backbone.Net_Prev_Pool(**kwargs) 134 | net_vp = backbone.View_And_Pool() 135 | net_ap = backbone.Net_After_Pool(**kwargs) 136 | if opt.sketch_finetune: 137 | net_whole = backbone.Net_Whole(nclasses = 10, use_finetuned=True) 138 | else: 139 | net_whole = backbone.Net_Whole(nclasses = 10) 140 | # for alexnet or vgg, feat_dim = 4096 141 | # for resnet, feat_dim = 2048 142 | net_cls = backbone.Net_Classifier(nclasses = 10) 143 | # Criterion 144 | # criterion = nn.CrossEntropyLoss().cuda() 145 | # if opt.balance: # current no balancing 146 | # crt_cls = nn.CrossEntropyLoss().cuda() 147 | # else: 148 | # classification loss 149 | crt_cls = nn.CrossEntropyLoss().cuda() 150 | # triplet center loss 151 | crt_tlc = custom_loss.TripletCenter10Loss(margin=opt.margin).cuda() 152 | if opt.wn: 153 | crt_tlc = torch.nn.utils.weight_norm(crt_tlc, name='centers') 154 | criterion = [crt_cls, crt_tlc, opt.w1, opt.w2] 155 | 156 | # Load from checkpoint 157 | start_epoch = best_top1 = 0 158 | opt.resume = opt.checkpoint_folder + '/model_best.pth' 159 | if True: 160 | checkpoint = torch.load(opt.resume) 161 | net_bp.load_state_dict(checkpoint['net_bp']) 162 | net_ap.load_state_dict(checkpoint['net_ap']) 163 | net_whole.load_state_dict(checkpoint['net_whole']) 164 | 
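# The checkpoint written by main.py stores one state_dict per sub-network
# ('net_bp', 'net_ap', 'net_whole', 'net_cls') plus the TCL centers under
# 'centers'. dict.has_key() a few lines below is Python-2 only; on Python 3
# it would need to become "'epoch' in checkpoint".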
net_cls.load_state_dict(checkpoint['net_cls']) 165 | crt_tlc.load_state_dict(checkpoint['centers']) 166 | epoch = checkpoint['epoch'] if checkpoint.has_key('epoch') else 0 167 | # best_top1 = checkpoint['best_top1'] 168 | # print("=> Start epoch {} best top1 {:.1%}" 169 | # .format(start_epoch, best_top1)) 170 | 171 | # model = nn.DataParallel(model).cuda() 172 | net_bp = nn.DataParallel(net_bp).cuda() 173 | net_vp = net_vp.cuda() 174 | net_ap = nn.DataParallel(net_ap).cuda() 175 | net_whole = nn.DataParallel(net_whole).cuda() 176 | net_cls = nn.DataParallel(net_cls).cuda() 177 | # wrap multiple models in optimizer 178 | 179 | model = (net_whole, net_bp, net_vp, net_ap, net_cls) 180 | 181 | all_metric = [] 182 | # total_epochs = opt.max_epochs*10 if opt.pk_flag else opt.max_epochs 183 | for phase in ['test']: 184 | 185 | print("\nTest on {} *:".format(phase)) 186 | sketch_loader = eval(phase+'_sketch_loader') 187 | shape_loader = eval(phase+'_shape_loader') 188 | savename = phase + 'full_feat_final.mat' 189 | cur_metric = validate(sketch_loader, shape_loader, model, criterion, epoch, opt, savename) 190 | top1 = cur_metric[-1] 191 | 192 | print('\n * Finished epoch {:3d} top1: {:5.3%}'.format(epoch, top1)) 193 | all_metric.append(cur_metric) 194 | #print('Train Metric ', all_metric[0]) 195 | #print('Test Metric ', all_metric[1]) 196 | print('Test Metric ', all_metric[0]) 197 | 198 | 199 | def train(sketch_dataloader, shape_dataloader, model, criterion, optimizer, epoch, opt): 200 | """ 201 | train for one epoch on the training set 202 | """ 203 | batch_time = utils.AverageMeter() 204 | losses = utils.AverageMeter() 205 | top1 = utils.AverageMeter() 206 | tpl_losses = utils.AverageMeter() 207 | 208 | # training mode 209 | net_whole, net_bp, net_vp, net_ap, net_cls = model 210 | optim_sketch, optim_shape, optim_centers = optimizer 211 | crt_cls, crt_tlc, w1, w2 = criterion 212 | 213 | net_whole.train() 214 | net_bp.train() 215 | net_vp.train() 216 | net_ap.train() 217 | net_cls.train() 218 | 219 | end = time.time() 220 | # debug_here() 221 | for i, ((sketches, k_labels), (shapes, p_labels)) in enumerate(zip(sketch_dataloader, shape_dataloader)): 222 | 223 | shapes = shapes.view(shapes.size(0)*shapes.size(1), shapes.size(2), shapes.size(3), shapes.size(4)) 224 | 225 | # expanding: (bz * 12) x 3 x 224 x 224 226 | shapes = shapes.expand(shapes.size(0), 3, shapes.size(2), shapes.size(3)) 227 | 228 | shapes_v = Variable(shapes.cuda()) 229 | p_labels_v = Variable(p_labels.long().cuda()) 230 | 231 | sketches_v = Variable(sketches.cuda()) 232 | k_labels_v = Variable(k_labels.long().cuda()) 233 | 234 | 235 | o_bp = net_bp(shapes_v) 236 | o_vp = net_vp(o_bp) 237 | shape_feat = net_ap(o_vp) 238 | sketch_feat = net_whole(sketches_v) 239 | feat = torch.cat([shape_feat, sketch_feat]) 240 | target = torch.cat([p_labels_v, k_labels_v]) 241 | score = net_cls(feat) 242 | 243 | cls_loss = crt_cls(score, target) 244 | tpl_loss, _ = crt_tlc(score, target) 245 | # tpl_loss, _ = crt_tlc(feat, target) 246 | 247 | loss = w1 * cls_loss + w2 * tpl_loss 248 | 249 | ## measure accuracy 250 | prec1 = utils.accuracy(score.data, target.data, topk=(1,))[0] 251 | losses.update(cls_loss.data[0], score.size(0)) # batchsize 252 | tpl_losses.update(tpl_loss.data[0], score.size(0)) 253 | top1.update(prec1[0], score.size(0)) 254 | 255 | ## backward 256 | optim_sketch.zero_grad() 257 | optim_shape.zero_grad() 258 | optim_centers.zero_grad() 259 | 260 | loss.backward() 261 | utils.clip_gradient(optim_sketch, opt.gradient_clip) 262 | 
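# clip_gradient (misc/utils.py) clamps every parameter gradient element-wise
# to [-opt.gradient_clip, opt.gradient_clip] for each of the three optimizers
# before they step.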
utils.clip_gradient(optim_shape, opt.gradient_clip) 263 | utils.clip_gradient(optim_centers, opt.gradient_clip) 264 | 265 | optim_sketch.step() 266 | optim_shape.step() 267 | optim_centers.step() 268 | 269 | # measure elapsed time 270 | batch_time.update(time.time() - end) 271 | end = time.time() 272 | if i % opt.print_freq == 0: 273 | print('Epoch: [{0}][{1}/{2}]\t' 274 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 275 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 276 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 277 | 'Trploss {triplet.val:.4f}({triplet.avg:.3f})'.format( 278 | epoch, i, len(sketch_dataloader), batch_time=batch_time, 279 | loss=losses, top1=top1, triplet=tpl_losses)) 280 | # print('triplet loss: ', tpl_center_loss.data[0]) 281 | print(' * Train Prec@1 {top1.avg:.3f}'.format(top1=top1)) 282 | return top1.avg 283 | 284 | def validate(sketch_dataloader, shape_dataloader, model, criterion, epoch, opt, savename): 285 | 286 | """ 287 | test for one epoch on the testing set 288 | """ 289 | sketch_losses = utils.AverageMeter() 290 | sketch_top1 = utils.AverageMeter() 291 | 292 | shape_losses = utils.AverageMeter() 293 | shape_top1 = utils.AverageMeter() 294 | 295 | net_whole, net_bp, net_vp, net_ap, net_cls = model 296 | # optim_sketch, optim_shape, optim_centers = optimizer 297 | crt_cls, crt_tlc, w1, w2 = criterion 298 | 299 | net_whole.eval() 300 | net_bp.eval() 301 | net_vp.eval() 302 | net_ap.eval() 303 | net_cls.eval() 304 | 305 | sketch_features = [] 306 | sketch_scores = [] 307 | sketch_labels = [] 308 | 309 | shape_features = [] 310 | shape_scores = [] 311 | shape_labels = [] 312 | 313 | batch_time = utils.AverageMeter() 314 | end = time.time() 315 | 316 | for i, (sketches, k_labels) in enumerate(sketch_dataloader): 317 | sketches_v = Variable(sketches.cuda()) 318 | k_labels_v = Variable(k_labels.long().cuda()) 319 | sketch_feat = net_whole(sketches_v) 320 | sketch_score = net_cls(sketch_feat) 321 | 322 | loss = crt_cls(sketch_score, k_labels_v) 323 | 324 | prec1 = utils.accuracy(sketch_score.data, k_labels_v.data, topk=(1,))[0] 325 | sketch_losses.update(loss.data[0], sketch_score.size(0)) # batchsize 326 | sketch_top1.update(prec1[0], sketch_score.size(0)) 327 | sketch_features.append(sketch_feat.data.cpu()) 328 | sketch_labels.append(k_labels) 329 | sketch_scores.append(sketch_score.data.cpu()) 330 | 331 | batch_time.update(time.time() - end) 332 | end = time.time() 333 | 334 | if i % opt.print_freq == 0: 335 | print('Test: [{0}/{1}]\t' 336 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 337 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 338 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 339 | i, len(sketch_dataloader), batch_time=batch_time, loss=sketch_losses, 340 | top1=sketch_top1)) 341 | print(' *Sketch Prec@1 {top1.avg:.3f}'.format(top1=sketch_top1)) 342 | 343 | batch_time = utils.AverageMeter() 344 | end = time.time() 345 | for i, (shapes, p_labels) in enumerate(shape_dataloader): 346 | shapes = shapes.view(shapes.size(0)*shapes.size(1), shapes.size(2), shapes.size(3), shapes.size(4)) 347 | # expanding: (bz * 12) x 3 x 224 x 224 348 | shapes = shapes.expand(shapes.size(0), 3, shapes.size(2), shapes.size(3)) 349 | 350 | shapes_v = Variable(shapes.cuda()) 351 | p_labels_v = Variable(p_labels.long().cuda()) 352 | 353 | o_bp = net_bp(shapes_v) 354 | o_vp = net_vp(o_bp) 355 | shape_feat = net_ap(o_vp) 356 | shape_score = net_cls(shape_feat) 357 | 358 | loss = crt_cls(shape_score, p_labels_v) 359 | 360 | prec1 = utils.accuracy(shape_score.data, 
p_labels_v.data, topk=(1,))[0] 361 | shape_losses.update(loss.data[0], shape_score.size(0)) # batchsize 362 | shape_top1.update(prec1[0], shape_score.size(0)) 363 | shape_features.append(shape_feat.data.cpu()) 364 | shape_labels.append(p_labels) 365 | shape_scores.append(shape_score.data.cpu()) 366 | 367 | batch_time.update(time.time() - end) 368 | end = time.time() 369 | 370 | if i % opt.print_freq == 0: 371 | print('Test: [{0}/{1}]\t' 372 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 373 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 374 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 375 | i, len(shape_dataloader), batch_time=batch_time, loss=shape_losses, 376 | top1=shape_top1)) 377 | print(' *Shape Prec@1 {top1.avg:.3f}'.format(top1=shape_top1)) 378 | 379 | shape_features = torch.cat(shape_features, 0).numpy() 380 | sketch_features = torch.cat(sketch_features, 0).numpy() 381 | 382 | shape_scores = torch.cat(shape_scores, 0).numpy() 383 | sketch_scores = torch.cat(sketch_scores, 0).numpy() 384 | 385 | shape_labels = torch.cat(shape_labels, 0).numpy() 386 | sketch_labels = torch.cat(sketch_labels, 0).numpy() 387 | 388 | # d = compute_distance(sketch_features.copy(), shape_features.copy(), l2=False) 389 | # scio.savemat('test/example.mat',{'d':d, 'feat':dataset_features, 'labels':dataset_labels}) 390 | # AUC, mAP = map_and_auc(sketch_labels.copy(), shape_labels.copy(), d) 391 | # print(' * Feature AUC {0:.5} mAP {0:.5}'.format(AUC, mAP)) 392 | 393 | d_feat = compute_distance(sketch_features.copy(), shape_features.copy(), l2=False) 394 | d_feat_norm = compute_distance(sketch_features.copy(), shape_features.copy(), l2=True) 395 | mAP_feat = compute_map(sketch_labels.copy(), shape_labels.copy(), d_feat) 396 | mAP_feat_norm = compute_map(sketch_labels.copy(), shape_labels.copy(), d_feat_norm) 397 | print(' * Feature mAP {0:.5%}\tNorm Feature mAP {1:.5%}'.format(mAP_feat, mAP_feat_norm)) 398 | 399 | 400 | d_score = compute_distance(sketch_scores.copy(), shape_scores.copy(), l2=False) 401 | mAP_score = compute_map(sketch_labels.copy(), shape_labels.copy(), d_score) 402 | d_score_norm = compute_distance(sketch_scores.copy(), shape_scores.copy(), l2=True) 403 | mAP_score_norm = compute_map(sketch_labels.copy(), shape_labels.copy(), d_score_norm) 404 | 405 | shape_paths = [img[0] for img in shape_dataloader.dataset.shape_target_path_list] 406 | sketch_paths = [img[0] for img in sketch_dataloader.dataset.sketch_target_path_list] 407 | scio.savemat('{}/{}'.format(opt.checkpoint_folder, savename), {'score_dist':d_score, 'score_dist_norm': d_score_norm, 'feat_dist': d_feat, 'feat_dist_norm': d_feat_norm,'sketch_features':sketch_features, 'sketch_labels':sketch_labels, 'sketch_scores': sketch_scores, 408 | 'shape_features':shape_features, 'shape_labels':shape_labels, 'shape_scores': shape_scores, 'sketch_pathcs':sketch_paths, 'shape_paths':shape_paths}) 409 | print(' * Score mAP {0:.5%}\tNorm Score mAP {1:.5%}'.format(mAP_score, mAP_score_norm)) 410 | return [sketch_top1.avg, shape_top1.avg, mAP_feat, mAP_feat_norm, mAP_score, mAP_score_norm] 411 | 412 | 413 | if __name__ == '__main__': 414 | from options import get_arguments 415 | 416 | opt = get_arguments() 417 | main(opt) 418 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | ## code on 1032 3 | from __future__ import print_function, absolute_import 4 | import torch 5 | # import torchvision.models 
as models 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.autograd import Variable 9 | from torch.backends import cudnn 10 | import torch.optim as optim 11 | 12 | import os, shutil 13 | import argparse 14 | import sys 15 | import time 16 | import numpy as np 17 | import scipy.io as scio 18 | 19 | # defined by zczhou 20 | # from models.sketch2shape import sketch_net, shape_net 21 | import misc.custom_loss as custom_loss 22 | 23 | import dataset.sk_dataset as sk_dataset 24 | import dataset.sh_views_dataset as sh_views_dataset 25 | 26 | import misc.transforms as T 27 | # from misc.utils import Logger 28 | import models 29 | 30 | from evaluation import map_and_auc, compute_distance, compute_map 31 | 32 | import misc.utils as utils 33 | # from sampler import RandomIdentitySampler 34 | 35 | from IPython.core.debugger import Tracer 36 | debug_here = Tracer() 37 | 38 | 39 | def get_data(train_shape_views_folder, test_shape_views_folder, train_shape_flist, test_shape_flist, 40 | train_sketch_folder, test_sketch_folder, train_sketch_flist, test_sketch_flist, 41 | height, width, batch_size, workers, pk_flag=False): 42 | 43 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | 46 | train_transformer = T.Compose([ 47 | # T.RandomSizedRectCrop(height, width), 48 | T.RectScale(height, width), 49 | T.RandomHorizontalFlip(), 50 | T.ToTensor(), 51 | normalizer, 52 | ]) 53 | 54 | test_transformer = T.Compose([ 55 | T.RectScale(height, width), 56 | T.ToTensor(), 57 | normalizer, 58 | ]) 59 | 60 | # define sketch dataset 61 | extname = '.png' if 'image' not in train_sketch_flist else '.JPEG' 62 | sketch_train_data = sk_dataset.Sk_Dataset(train_sketch_folder, train_sketch_flist, transform=train_transformer, ext=extname) 63 | sketch_test_data = sk_dataset.Sk_Dataset(test_sketch_folder, test_sketch_flist, transform=test_transformer, ext=extname) 64 | 65 | # define shape views dataset 66 | shape_train_data = sh_views_dataset.Sh_Views_Dataset(train_shape_views_folder, train_shape_flist, transform=train_transformer) 67 | shape_test_data = sh_views_dataset.Sh_Views_Dataset(test_shape_views_folder, test_shape_flist, transform=test_transformer) 68 | 69 | # num_classes = sketch_train_data.num_classes 70 | 71 | if pk_flag: 72 | train_sketch_loader = DataLoader( 73 | sketch_train_data, 74 | batch_size=batch_size, num_workers=workers, 75 | # sampler=RandomIdentitySampler(sketch_train_data.imgs, num_instances), 76 | pin_memory=True, drop_last=True) 77 | 78 | train_shape_loader = DataLoader( 79 | shape_train_data, 80 | batch_size=batch_size, num_workers=workers, 81 | # sampler=RandomIdentitySampler(shape_train_data.imgs, num_instances), 82 | pin_memory=True, drop_last=True) 83 | else: 84 | train_sketch_loader = DataLoader( 85 | sketch_train_data, 86 | batch_size=batch_size*2, num_workers=workers, 87 | shuffle=True, pin_memory=True, drop_last=True) 88 | 89 | train_shape_loader = DataLoader( 90 | shape_train_data, 91 | batch_size=batch_size, num_workers=workers, 92 | shuffle=True, pin_memory=True, drop_last=True) 93 | 94 | 95 | test_sketch_loader = DataLoader( 96 | sketch_test_data, 97 | batch_size=batch_size, num_workers=workers, 98 | shuffle=False, pin_memory=True) 99 | 100 | test_shape_loader = DataLoader( 101 | shape_test_data, 102 | batch_size=batch_size, num_workers=workers, 103 | shuffle=False, pin_memory=True) 104 | 105 | 106 | # sketch_weight = utils.get_weight(sketch_train_data.imgs) 107 | # shape_weight = 
utils.get_weight(shape_train_data.imgs) 108 | # cls_weight = sketch_weight / (train_sketch_loader.batch_size*1.0 / train_shape_loader.batch_size) + shape_weight 109 | # cls_weight = cls_weight / cls_weight.sum() * cls_weight.size 110 | # cls_weight = torch.Tensor(cls_weight) 111 | 112 | return train_sketch_loader, train_shape_loader, test_sketch_loader, test_shape_loader # , cls_weight 113 | 114 | 115 | def main(opt): 116 | np.random.seed(opt.seed) 117 | torch.manual_seed(opt.seed) 118 | cudnn.benchmark = True 119 | opt.checkpoint_folder += '_'+opt.backbone 120 | if opt.sketch_finetune: 121 | opt.checkpoint_folder += '_finetune' 122 | if not os.path.exists(opt.checkpoint_folder): 123 | os.makedirs(opt.checkpoint_folder) 124 | 125 | print(opt) 126 | # Redirect print to both console and log file 127 | # if not opt.evaluate: 128 | # sys.stdout = Logger(os.path.join(opt.logs_dir, opt.log_name)) 129 | 130 | # Create data loaders 131 | if opt.height is None or opt.width is None: 132 | opt.height, opt.width = (224, 224) 133 | 134 | train_sketch_loader, train_shape_loader, test_sketch_loader, test_shape_loader = get_data(opt.train_shape_views_folder, 135 | opt.test_shape_views_folder, opt.train_shape_flist, opt.test_shape_flist, 136 | opt.train_sketch_folder, opt.test_sketch_folder, opt.train_sketch_flist, opt.test_sketch_flist, 137 | opt.height, opt.width, opt.batch_size, opt.workers, pk_flag=False) 138 | 139 | # Create model 140 | #if opt.pool_idx is None: 141 | # opt.pool_idx = set_default_pool 142 | kwargs = {'pool_idx': opt.pool_idx} if opt.pool_idx is not None else {} 143 | backbone = eval('models.'+opt.backbone) 144 | net_bp = backbone.Net_Prev_Pool(**kwargs) 145 | net_vp = backbone.View_And_Pool() 146 | net_ap = backbone.Net_After_Pool(**kwargs) 147 | if opt.sketch_finetune: 148 | net_whole = backbone.Net_Whole(nclasses = 10, use_finetuned=True) 149 | else: 150 | net_whole = backbone.Net_Whole(nclasses = 10) 151 | # for alexnet or vgg, feat_dim = 4096 152 | # for resnet, feat_dim = 2048 153 | net_cls = backbone.Net_Classifier(nclasses = 10) 154 | # Criterion 155 | # criterion = nn.CrossEntropyLoss().cuda() 156 | # if opt.balance: # current no balancing 157 | # crt_cls = nn.CrossEntropyLoss().cuda() 158 | # else: 159 | # classification loss 160 | crt_cls = nn.CrossEntropyLoss().cuda() 161 | # triplet center loss 162 | crt_tlc = custom_loss.TripletCenterLoss(margin=opt.margin).cuda() 163 | if opt.wn: 164 | crt_tlc = torch.nn.utils.weight_norm(crt_tlc, name='centers') 165 | criterion = [crt_cls, crt_tlc, opt.w1, opt.w2] 166 | 167 | # Load from checkpoint 168 | start_epoch = best_top1 = 0 169 | if opt.resume: 170 | checkpoint = torch.load(opt.resume) 171 | net_bp.load_state_dict(checkpoint['net_bp']) 172 | net_ap.load_state_dict(checkpoint['net_ap']) 173 | net_whole.load_state_dict(checkpoint['net_whole']) 174 | net_cls.load_state_dict(checkpoint['net_cls']) 175 | crt_tlc.load_state_dict(checkpoint['centers']) 176 | start_epoch = checkpoint['epoch'] 177 | best_top1 = checkpoint['best_prec'] 178 | # start_epoch = checkpoint['epoch'] 179 | # best_top1 = checkpoint['best_top1'] 180 | # print("=> Start epoch {} best top1 {:.1%}" 181 | # .format(start_epoch, best_top1)) 182 | 183 | # model = nn.DataParallel(model).cuda() 184 | net_bp = nn.DataParallel(net_bp).cuda() 185 | net_vp = net_vp.cuda() 186 | net_ap = nn.DataParallel(net_ap).cuda() 187 | net_whole = nn.DataParallel(net_whole).cuda() 188 | net_cls = nn.DataParallel(net_cls).cuda() 189 | # wrap multiple models in optimizer 190 | optim_shape = 
optim.SGD([{'params': net_ap.parameters()}, 191 | {'params': net_bp.parameters(), 'lr':1e-3}, 192 | {'params': net_cls.parameters()}], 193 | lr=0.001, momentum=0.9, weight_decay=opt.weight_decay) 194 | 195 | base_param_ids = set(map(id, net_whole.module.features.parameters())) 196 | new_params = [p for p in net_whole.parameters() if id(p) not in base_param_ids] 197 | param_groups = [ 198 | {'params': net_whole.module.features.parameters(), 'lr_mult':0.1}, 199 | {'params':new_params, 'lr_mult':1.0}] 200 | 201 | # optim_sketch = optim.SGD(net_whole.module.parameters(), lr=0.01) 202 | optim_sketch = optim.SGD(param_groups, lr=0.001, momentum=0.9, weight_decay=opt.weight_decay) 203 | optim_centers = optim.SGD(crt_tlc.parameters(), lr=0.1) 204 | 205 | optimizer = (optim_sketch, optim_shape, optim_centers) 206 | model = (net_whole, net_bp, net_vp, net_ap, net_cls) 207 | 208 | # Schedule learning rate 209 | def adjust_lr(epoch, optimizer): 210 | step_size = 800 if opt.pk_flag else 80 # 40 211 | lr = opt.lr * (0.1 ** (epoch // step_size)) 212 | for g in optimizer.param_groups: 213 | g['lr'] = lr * g.get('lr_mult', 1) 214 | 215 | # Start training 216 | top1 = 0.0 217 | if opt.evaluate: 218 | # validate and compute mAP 219 | _, top1 = validate(test_sketch_loader, test_shape_loader, model, criterion, 0, opt) 220 | exit() 221 | best_epoch = -1 222 | best_metric = None 223 | # total_epochs = opt.max_epochs*10 if opt.pk_flag else opt.max_epochs 224 | for epoch in range(start_epoch, opt.max_epochs): 225 | # adjust_lr(epoch, optim_sketch) 226 | # adjust_lr(epoch, optim_shape) 227 | # adjust_lr(epoch, optim_centers) 228 | # cls acc top1 229 | train_top1 = train(train_sketch_loader, train_shape_loader, model, criterion, optimizer, epoch, opt) 230 | if epoch < opt.start_save and (epoch % opt.interval == 0): 231 | continue 232 | 233 | if train_top1 > 0.1: 234 | print("Test:") 235 | cur_metric = validate(test_sketch_loader, test_shape_loader, model, criterion, epoch, opt) 236 | top1 = cur_metric[-1] 237 | 238 | is_best = top1 > best_top1 239 | if is_best: 240 | best_epoch = epoch + 1 241 | best_metric = cur_metric 242 | best_top1 = max(top1, best_top1) 243 | 244 | 245 | 246 | checkpoint = {} 247 | checkpoint['epoch'] = epoch + 1 248 | checkpoint['current_prec'] = top1 249 | checkpoint['best_prec'] = best_top1 250 | checkpoint['net_bp'] = net_bp.module.state_dict() 251 | checkpoint['net_ap'] = net_ap.module.state_dict() 252 | checkpoint['net_whole'] = net_whole.module.state_dict() 253 | checkpoint['net_cls'] = net_cls.module.state_dict() 254 | checkpoint['centers'] = crt_tlc.state_dict() 255 | 256 | path_checkpoint = '{0}/model_latest.pth'.format(opt.checkpoint_folder) 257 | utils.save_checkpoint(checkpoint, path_checkpoint) 258 | 259 | if is_best: # save checkpoint 260 | path_checkpoint = '{0}/model_best.pth'.format(opt.checkpoint_folder) 261 | utils.save_checkpoint(checkpoint, path_checkpoint) 262 | if opt.sf: 263 | shutil.copyfile(opt.checkpoint_folder+'/test_feat_temp.mat', opt.checkpoint_folder+'/test_feat_best.mat') 264 | 265 | print('\n * Finished epoch {:3d} top1: {:5.3%} best: {:5.3%}{} @epoch {}\n'. 
266 | format(epoch, top1, best_top1, ' *' if is_best else '', best_epoch)) 267 | 268 | print('Best metric', best_metric) 269 | 270 | def train(sketch_dataloader, shape_dataloader, model, criterion, optimizer, epoch, opt): 271 | """ 272 | train for one epoch on the training set 273 | """ 274 | batch_time = utils.AverageMeter() 275 | losses = utils.AverageMeter() 276 | top1 = utils.AverageMeter() 277 | tpl_losses = utils.AverageMeter() 278 | 279 | # training mode 280 | net_whole, net_bp, net_vp, net_ap, net_cls = model 281 | optim_sketch, optim_shape, optim_centers = optimizer 282 | crt_cls, crt_tlc, w1, w2 = criterion 283 | 284 | net_whole.train() 285 | net_bp.train() 286 | net_vp.train() 287 | net_ap.train() 288 | net_cls.train() 289 | 290 | end = time.time() 291 | # debug_here() 292 | for i, ((sketches, k_labels), (shapes, p_labels)) in enumerate(zip(sketch_dataloader, shape_dataloader)): 293 | 294 | shapes = shapes.view(shapes.size(0)*shapes.size(1), shapes.size(2), shapes.size(3), shapes.size(4)) 295 | 296 | # expanding: (bz * 12) x 3 x 224 x 224 297 | shapes = shapes.expand(shapes.size(0), 3, shapes.size(2), shapes.size(3)) 298 | 299 | shapes_v = Variable(shapes.cuda()) 300 | p_labels_v = Variable(p_labels.long().cuda()) 301 | 302 | sketches_v = Variable(sketches.cuda()) 303 | k_labels_v = Variable(k_labels.long().cuda()) 304 | 305 | 306 | o_bp = net_bp(shapes_v) 307 | o_vp = net_vp(o_bp) 308 | shape_feat = net_ap(o_vp) 309 | sketch_feat = net_whole(sketches_v) 310 | feat = torch.cat([shape_feat, sketch_feat]) 311 | target = torch.cat([p_labels_v, k_labels_v]) 312 | score = net_cls(feat) 313 | 314 | cls_loss = crt_cls(score, target) 315 | tpl_loss, _ = crt_tlc(score, target) 316 | # tpl_loss, _ = crt_tlc(feat, target) 317 | 318 | loss = w1 * cls_loss + w2 * tpl_loss 319 | 320 | ## measure accuracy 321 | prec1 = utils.accuracy(score.data, target.data, topk=(1,))[0] 322 | losses.update(cls_loss.data[0], score.size(0)) # batchsize 323 | tpl_losses.update(tpl_loss.data[0], score.size(0)) 324 | top1.update(prec1[0], score.size(0)) 325 | 326 | ## backward 327 | optim_sketch.zero_grad() 328 | optim_shape.zero_grad() 329 | optim_centers.zero_grad() 330 | 331 | loss.backward() 332 | utils.clip_gradient(optim_sketch, opt.gradient_clip) 333 | utils.clip_gradient(optim_shape, opt.gradient_clip) 334 | utils.clip_gradient(optim_centers, opt.gradient_clip) 335 | 336 | optim_sketch.step() 337 | optim_shape.step() 338 | optim_centers.step() 339 | 340 | # measure elapsed time 341 | batch_time.update(time.time() - end) 342 | end = time.time() 343 | if i % opt.print_freq == 0: 344 | print('Epoch: [{0}][{1}/{2}]\t' 345 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 346 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 347 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 348 | 'Trploss {triplet.val:.4f}({triplet.avg:.3f})'.format( 349 | epoch, i, len(sketch_dataloader), batch_time=batch_time, 350 | loss=losses, top1=top1, triplet=tpl_losses)) 351 | # print('triplet loss: ', tpl_center_loss.data[0]) 352 | print(' * Train Prec@1 {top1.avg:.3f}'.format(top1=top1)) 353 | return top1.avg 354 | 355 | def validate(sketch_dataloader, shape_dataloader, model, criterion, epoch, opt): 356 | 357 | """ 358 | test for one epoch on the testing set 359 | """ 360 | sketch_losses = utils.AverageMeter() 361 | sketch_top1 = utils.AverageMeter() 362 | 363 | shape_losses = utils.AverageMeter() 364 | shape_top1 = utils.AverageMeter() 365 | 366 | net_whole, net_bp, net_vp, net_ap, net_cls = model 367 | # optim_sketch, optim_shape, 
optim_centers = optimizer 368 | crt_cls, crt_tlc, w1, w2 = criterion 369 | 370 | net_whole.eval() 371 | net_bp.eval() 372 | net_vp.eval() 373 | net_ap.eval() 374 | net_cls.eval() 375 | 376 | sketch_features = [] 377 | sketch_scores = [] 378 | sketch_labels = [] 379 | 380 | shape_features = [] 381 | shape_scores = [] 382 | shape_labels = [] 383 | 384 | batch_time = utils.AverageMeter() 385 | end = time.time() 386 | 387 | for i, (sketches, k_labels) in enumerate(sketch_dataloader): 388 | sketches_v = Variable(sketches.cuda()) 389 | k_labels_v = Variable(k_labels.long().cuda()) 390 | sketch_feat = net_whole(sketches_v) 391 | sketch_score = net_cls(sketch_feat) 392 | 393 | loss = crt_cls(sketch_score, k_labels_v) 394 | 395 | prec1 = utils.accuracy(sketch_score.data, k_labels_v.data, topk=(1,))[0] 396 | sketch_losses.update(loss.data[0], sketch_score.size(0)) # batchsize 397 | sketch_top1.update(prec1[0], sketch_score.size(0)) 398 | sketch_features.append(sketch_feat.data.cpu()) 399 | sketch_labels.append(k_labels) 400 | sketch_scores.append(sketch_score.data.cpu()) 401 | 402 | batch_time.update(time.time() - end) 403 | end = time.time() 404 | 405 | if i % opt.print_freq == 0: 406 | print('Test: [{0}/{1}]\t' 407 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 408 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 409 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 410 | i, len(sketch_dataloader), batch_time=batch_time, loss=sketch_losses, 411 | top1=sketch_top1)) 412 | print(' *Sketch Prec@1 {top1.avg:.3f}'.format(top1=sketch_top1)) 413 | 414 | batch_time = utils.AverageMeter() 415 | end = time.time() 416 | for i, (shapes, p_labels) in enumerate(shape_dataloader): 417 | shapes = shapes.view(shapes.size(0)*shapes.size(1), shapes.size(2), shapes.size(3), shapes.size(4)) 418 | # expanding: (bz * 12) x 3 x 224 x 224 419 | shapes = shapes.expand(shapes.size(0), 3, shapes.size(2), shapes.size(3)) 420 | 421 | shapes_v = Variable(shapes.cuda()) 422 | p_labels_v = Variable(p_labels.long().cuda()) 423 | 424 | o_bp = net_bp(shapes_v) 425 | o_vp = net_vp(o_bp) 426 | shape_feat = net_ap(o_vp) 427 | shape_score = net_cls(shape_feat) 428 | 429 | loss = crt_cls(shape_score, p_labels_v) 430 | 431 | prec1 = utils.accuracy(shape_score.data, p_labels_v.data, topk=(1,))[0] 432 | shape_losses.update(loss.data[0], shape_score.size(0)) # batchsize 433 | shape_top1.update(prec1[0], shape_score.size(0)) 434 | shape_features.append(shape_feat.data.cpu()) 435 | shape_labels.append(p_labels) 436 | shape_scores.append(shape_score.data.cpu()) 437 | 438 | batch_time.update(time.time() - end) 439 | end = time.time() 440 | 441 | if i % opt.print_freq == 0: 442 | print('Test: [{0}/{1}]\t' 443 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 444 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 445 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 446 | i, len(shape_dataloader), batch_time=batch_time, loss=shape_losses, 447 | top1=shape_top1)) 448 | print(' *Shape Prec@1 {top1.avg:.3f}'.format(top1=shape_top1)) 449 | 450 | shape_features = torch.cat(shape_features, 0).numpy() 451 | sketch_features = torch.cat(sketch_features, 0).numpy() 452 | 453 | shape_scores = torch.cat(shape_scores, 0).numpy() 454 | sketch_scores = torch.cat(sketch_scores, 0).numpy() 455 | 456 | shape_labels = torch.cat(shape_labels, 0).numpy() 457 | sketch_labels = torch.cat(sketch_labels, 0).numpy() 458 | 459 | # d = compute_distance(sketch_features.copy(), shape_features.copy(), l2=False) 460 | # scio.savemat('test/example.mat',{'d':d, 
'feat':dataset_features, 'labels':dataset_labels}) 461 | # AUC, mAP = map_and_auc(sketch_labels.copy(), shape_labels.copy(), d) 462 | # print(' * Feature AUC {0:.5} mAP {0:.5}'.format(AUC, mAP)) 463 | 464 | d_feat = compute_distance(sketch_features.copy(), shape_features.copy(), l2=False) 465 | d_feat_norm = compute_distance(sketch_features.copy(), shape_features.copy(), l2=True) 466 | mAP_feat = compute_map(sketch_labels.copy(), shape_labels.copy(), d_feat) 467 | mAP_feat_norm = compute_map(sketch_labels.copy(), shape_labels.copy(), d_feat_norm) 468 | print(' * Feature mAP {0:.5%}\tNorm Feature mAP {1:.5%}'.format(mAP_feat, mAP_feat_norm)) 469 | 470 | 471 | d_score = compute_distance(sketch_scores.copy(), shape_scores.copy(), l2=False) 472 | mAP_score = compute_map(sketch_labels.copy(), shape_labels.copy(), d_score) 473 | d_score_norm = compute_distance(sketch_scores.copy(), shape_scores.copy(), l2=True) 474 | mAP_score_norm = compute_map(sketch_labels.copy(), shape_labels.copy(), d_score_norm) 475 | if opt.sf: 476 | shape_paths = [img[0] for img in shape_dataloader.dataset.shape_target_path_list] 477 | sketch_paths = [img[0] for img in sketch_dataloader.dataset.sketch_target_path_list] 478 | scio.savemat('{}/test_feat_temp.mat'.format(opt.checkpoint_folder), {'score_dist':d_score, 'score_dist_norm': d_score_norm, 'feat_dist': d_feat, 'feat_dist_norm': d_feat_norm,'sketch_features':sketch_features, 'sketch_labels':sketch_labels, 'sketch_scores': sketch_scores, 479 | 'shape_features':shape_features, 'shape_labels':shape_labels, 'sketch_paths':sketch_paths, 'shape_paths':shape_paths}) 480 | print(' * Score mAP {0:.5%}\tNorm Score mAP {1:.5%}'.format(mAP_score, mAP_score_norm)) 481 | return [sketch_top1.avg, shape_top1.avg, mAP_feat, mAP_feat_norm, mAP_score, mAP_score_norm] 482 | 483 | 484 | if __name__ == '__main__': 485 | from options import get_arguments 486 | 487 | opt = get_arguments() 488 | main(opt) 489 | -------------------------------------------------------------------------------- /misc/custom_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter 9 | import numpy as np 10 | 11 | from IPython.core.debugger import Tracer 12 | debug_here = Tracer() 13 | 14 | ################################################################ 15 | ## Triplet related loss 16 | ################################################################ 17 | def pdist(A, squared=False, eps=1e-4): 18 | prod = torch.mm(A, A.t()) 19 | norm = prod.diag().unsqueeze(1).expand_as(prod) 20 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 21 | return res if squared else (res + eps).sqrt() + eps 22 | 23 | 24 | class TripletCenterLoss(nn.Module): 25 | def __init__(self, margin=0, num_classes=10): 26 | super(TripletCenterLoss, self).__init__() 27 | self.margin = margin 28 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 29 | self.centers = nn.Parameter(torch.randn(num_classes, num_classes)) 30 | 31 | def forward(self, inputs, targets): 32 | batch_size = inputs.size(0) 33 | targets_expand = targets.view(batch_size, 1).expand(batch_size, inputs.size(1)) 34 | centers_batch = self.centers.gather(0, targets_expand) # centers batch 35 | 36 | # compute pairwise distances between input features and corresponding centers 37 | centers_batch_bz = torch.stack([centers_batch]*batch_size) 38 | 
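# After these two stack/transpose ops, dist[i][j] becomes the distance from
# feature i to the center of sample j's class: the max over same-label columns
# is each sample's distance to its own center (dist_ap), and the min over
# different-label columns is its nearest negative center in the batch (dist_an).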
inputs_bz = torch.stack([inputs]*batch_size).transpose(0, 1) 39 | dist = torch.sum((centers_batch_bz -inputs_bz)**2, 2).squeeze() 40 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 41 | 42 | # for each anchor, find the hardest positive and negative 43 | mask = targets.expand(batch_size, batch_size).eq(targets.expand(batch_size, batch_size).t()) 44 | dist_ap, dist_an = [], [] 45 | for i in range(batch_size): # for each sample, we compute distance 46 | dist_ap.append(dist[i][mask[i]].max()) # mask[i]: positive samples of sample i 47 | dist_an.append(dist[i][mask[i]==0].min()) # mask[i]==0: negative samples of sample i 48 | 49 | dist_ap = torch.cat(dist_ap) 50 | dist_an = torch.cat(dist_an) 51 | 52 | # generate a new label y 53 | # compute ranking hinge loss 54 | y = dist_an.data.new() 55 | y.resize_as_(dist_an.data) 56 | y.fill_(1) 57 | y = Variable(y) 58 | # y_i = 1, means dist_an > dist_ap + margin will casuse loss be zero 59 | loss = self.ranking_loss(dist_an, dist_ap, y) 60 | 61 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) # normalize data by batch size 62 | return loss, prec 63 | 64 | 65 | -------------------------------------------------------------------------------- /misc/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from PIL import Image 3 | 4 | from torchvision.transforms import * 5 | 6 | 7 | class RectScale(object): 8 | def __init__(self, height, width, interpolation=Image.BILINEAR): 9 | self.height = height 10 | self.width = width 11 | self.interpolation = interpolation 12 | 13 | def __call__(self, img): 14 | w, h = img.size 15 | if h == self.height and w == self.width: 16 | return img 17 | return img.resize((self.width, self.height), self.interpolation) 18 | 19 | 20 | class RandomSizedRectCrop(object): 21 | def __init__(self, height, width, interpolation=Image.BILINEAR): 22 | self.height = height 23 | self.width = width 24 | self.interpolation = interpolation 25 | 26 | def __call__(self, img): 27 | for attempt in range(10): 28 | area = img.size[0] * img.size[1] 29 | target_area = random.uniform(0.64, 1.0) * area 30 | aspect_ratio = random.uniform(2, 3) 31 | 32 | h = int(round(math.sqrt(target_area * aspect_ratio))) 33 | w = int(round(math.sqrt(target_area / aspect_ratio))) 34 | 35 | if w <= img.size[0] and h <= img.size[1]: 36 | x1 = random.randint(0, img.size[0] - w) 37 | y1 = random.randint(0, img.size[1] - h) 38 | 39 | img = img.crop((x1, y1, x1 + w, y1 + h)) 40 | assert(img.size == (w, h)) 41 | 42 | return img.resize((self.width, self.height), self.interpolation) 43 | 44 | # Fallback 45 | scale = RectScale(self.height, self.width, 46 | interpolation=self.interpolation) 47 | return scale(img) -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | import time 6 | from torch.autograd import Variable 7 | from IPython.core.debugger import Tracer 8 | debug_here = Tracer() 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n # val*n: how many samples predicted correctly among the n 
samples 24 | self.count += n # totoal samples has been through 25 | self.avg = self.sum / self.count 26 | 27 | ################################################# 28 | ## confusion matrix 29 | ################################################# 30 | class ConfusionMatrix(object): 31 | def __init__(self, K): # K is number of classes 32 | self.reset(K) 33 | def reset(self, K): 34 | self.num_classes = K 35 | # declare a table matrix and zero it 36 | self.cm = torch.zeros(K, K) # one row for each class, column is predicted class 37 | # self.valids 38 | self.valids = torch.zeros(K) 39 | # mean average precision, i.e., mean class accuracy 40 | self.mean_class_acc = 0 41 | 42 | def batchAdd(self, outputs, targets): 43 | """ 44 | output is predicetd probability 45 | """ 46 | _, preds = outputs.topk(1, 1, True, True) 47 | # convert cudalong tensor to long tensor 48 | # preds: bz x 1 49 | for m in range(preds.size(0)): 50 | self.cm[targets[m]][preds[m][0]] = self.cm[targets[m]][preds[m][0]] + 1 51 | 52 | 53 | def updateValids(self): 54 | # total = 0 55 | for t in range(self.num_classes): 56 | if self.cm.select(0, t).sum() != 0: # column 57 | # sum of t-th row is the number of samples coresponding to this class (groundtruth) 58 | self.valids[t] = self.cm[t][t] / self.cm.select(0, t).sum() 59 | else: 60 | self.valids[t] = 0 61 | 62 | self.mean_class_acc = self.valids.mean() 63 | 64 | 65 | ################################################# 66 | ## compute accuracy 67 | ################################################# 68 | def accuracy(output, target, topk=(1,)): 69 | """Computes the precision@k for the specified values of k""" 70 | maxk = max(topk) 71 | batch_size = target.size(0) 72 | 73 | _, pred = output.topk(maxk, 1, True, True) 74 | pred = pred.t() 75 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 76 | 77 | res = [] 78 | for k in topk: 79 | # top k 80 | correct_k = correct[:k].view(-1).float().sum(0) 81 | res.append(correct_k.mul_(100.0 / batch_size)) 82 | return res 83 | 84 | def save_checkpoint(model, output_path): 85 | 86 | ## if not os.path.exists(output_dir): 87 | ## os.makedirs("model/") 88 | torch.save(model, output_path) 89 | 90 | print("Checkpoint saved to {}".format(output_path)) 91 | 92 | 93 | # do gradient clip 94 | def clip_gradient(optimizer, grad_clip): 95 | assert grad_clip>0, 'gradient clip value must be greater than 1' 96 | for group in optimizer.param_groups: 97 | for param in group['params']: 98 | # gradient 99 | param.grad.data.clamp_(-grad_clip, grad_clip) 100 | 101 | def preprocess(inputs_12v, mean, std, data_augment): 102 | """ 103 | inputs_12v: (bz * 12) x 3 x 224 x 224 104 | """ 105 | # to tensor 106 | if isinstance(inputs_12v, torch.ByteTensor): 107 | inputs_12v = inputs_12v.float() 108 | 109 | inputs_12v.sub_(mean).div_(std) 110 | 111 | if data_augment: 112 | print('currently not support data augmentation') 113 | 114 | return inputs_12v 115 | 116 | 117 | # centers: 40(or 55) x 3 x 4096 118 | # features: bz * 4096 -> bz * 3 * 4096 119 | # compute distance features between each features and centers 120 | def get_center_loss(centers, features, target, alpha, num_classes): 121 | batch_size = target.size(0) 122 | features_dim = features.size(1) 123 | num_centers = centers.size(1) 124 | # bz x 3 x 4096 125 | features_view = features.unsqueeze(1).expand(batch_size, num_centers, features_dim) 126 | 127 | target_expand = target.view(batch_size,1, 1).expand(batch_size,num_centers, features_dim) 128 | centers_var = Variable(centers) 129 | centers_batch = 
centers_var.gather(0,target_expand) 130 | criterion = nn.MSELoss() 131 | center_loss = criterion(features_view, centers_batch) 132 | 133 | # compute gradient w.r.t. center 134 | diff = centers_batch - features_view # bz x 3 x 4096 135 | 136 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 137 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 138 | appear_times_expand = appear_times.view(-1,1,1).expand(batch_size,num_centers, features_dim).type(torch.FloatTensor) 139 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 140 | diff_cpu = alpha * diff_cpu 141 | 142 | # update related centers 143 | for i in range(batch_size): 144 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 145 | 146 | return center_loss, centers 147 | 148 | def get_centers_loss_margin(centers, features, target, alpha, num_classes, margin=1.0): 149 | batch_size = target.size(0) 150 | features_dim = features.size(1) 151 | # bz x 3 x 4096 152 | features_view = features.unsqueeze(1).expand(batch_size, 3, features_dim) 153 | 154 | target_expand = target.view(batch_size,1, 1).expand(batch_size,3, features_dim) 155 | centers_var = Variable(centers, requires_grad=False) 156 | centers_batch = centers_var.gather(0,target_expand) 157 | criterion = nn.MSELoss() 158 | 159 | center_loss = criterion(features_view, centers_batch) 160 | centers_var.requires_grad=True 161 | 162 | # compute gradient w.r.t. center 163 | diff = centers_batch - features_view # bz x 3 x 4096 164 | 165 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 166 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 167 | appear_times_expand = appear_times.view(-1,1,1).expand(batch_size,3, features_dim).type(torch.FloatTensor) 168 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 169 | diff_cpu = alpha * diff_cpu 170 | 171 | # update related centers 172 | 173 | for i in range(batch_size): 174 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 175 | 176 | dist_c1c2 = torch.norm(centers_var[target.data[i]][0] - centers_var[target.data[i]][1], 2) 177 | dist_hinge1 = torch.clamp(-dist_c1c2 + margin, min=0.0) 178 | dist_c2c3 = torch.norm(centers_var[target.data[i]][1] - centers_var[target.data[i]][2], 2) 179 | dist_hinge2 = torch.clamp(-dist_c2c3 + margin, min=0.0) 180 | dist_c1c3 = torch.norm(centers_var[target.data[i]][0] - centers_var[target.data[i]][2], 2) 181 | dist_hinge3 = torch.clamp(-dist_c1c3 + margin, min=0.0) 182 | loss = dist_hinge1 + dist_hinge2 + dist_hinge3 183 | loss.backward() 184 | centers[target.data[i]] -= centers_var.grad.data[target.data[i]] 185 | 186 | centers_var.requires_grad=False 187 | 188 | return center_loss, centers 189 | 190 | ########################################################### 191 | ## start of centers_loss margin 192 | ########################################################### 193 | def similarity_matrix(feats_bzxD): 194 | # get the product x * y 195 | # here, y = x.t() 196 | r = torch.mm(feats_bzxD, feats_bzxD.t()) 197 | # get the diagonal elements 198 | diag = r.diag().unsqueeze(0) 199 | diag = diag.expand_as(r) 200 | # compute the distance matrix 201 | # D[i, j]: similarity of sample i-th feature feats_bzxD[i] 202 | # in the batch and feats_bzxD[j] 203 | D = diag + diag.t() - 2*r + 1e-6 204 | return D.sqrt() # no square in the orignal paper 205 | def 
convert_y2(y): 206 | bz = y.size(0) # batch size (number of samples in the batch) 207 | y_expand = y.unsqueeze(0).expand(bz, bz) 208 | Y = y_expand.eq(y_expand.t()) 209 | return Y 210 | 211 | ########################################### 212 | ## 213 | ########################################### 214 | def pdist(A, squared=False, eps=1e-4): 215 | prod = torch.mm(A, A.t()) 216 | norm = prod.diag().unsqueeze(1).expand_as(prod) 217 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 218 | return res if squared else (res + eps).sqrt() + eps 219 | def pdist2(A, B, squared=False, eps=1e-4): 220 | """ 221 | input: 222 | A: bz x D 223 | B: bz x D 224 | output: 225 | C: bz x bz 226 | """ 227 | m = A.size(0) 228 | mmp1 = torch.stack([A]*m) 229 | mmp2 = torch.stack([B]*m).transpose(0,1) 230 | C = torch.sum((mmp1-mmp2)**2,2).squeeze() 231 | 232 | return C if squared else (C + eps).sqrt() + eps 233 | 234 | def get_center_loss_single_center_each_class_margin_hard(centers, features, target, alpha, num_classes, margin=1.0): 235 | batch_size = target.size(0) 236 | features_dim = features.size(1) 237 | 238 | target_expand = target.view(batch_size,1).expand(batch_size, features_dim) 239 | centers_var = Variable(centers, requires_grad=False) 240 | 241 | centers_batch = centers_var.gather(0,target_expand) 242 | criterion = nn.MSELoss() 243 | center_loss = criterion(features, centers_batch) 244 | 245 | # compute gradient w.r.t. center 246 | diff = centers_batch - features 247 | 248 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 249 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 250 | appear_times_expand = appear_times.view(-1,1).expand(batch_size, features_dim).type(torch.FloatTensor) 251 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 252 | diff_cpu = alpha * diff_cpu 253 | 254 | # update related centers 255 | for i in range(batch_size): 256 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 257 | 258 | ############################################# 259 | # additional 260 | ############################################# 261 | centers_var.requires_grad = True 262 | # nothing 263 | # else 264 | # centers_batch_b1 = centers_var.gather(0,target_expand) 265 | feats_centers_dist = pdist2(features, centers_batch) 266 | # else 267 | # feats_centers_dist = pdist2(features, centers_batch_b1) 268 | 269 | # normalize data 270 | norms = feats_centers_dist.norm(2, 1) 271 | feats_centers_dist = feats_centers_dist / norms.unsqueeze(1).repeat(1, feats_centers_dist.size(1)) 272 | 273 | pos = torch.eq(*[target.unsqueeze(dim).expand_as(feats_centers_dist) for dim in [0, 1]]).type_as(features) 274 | pd, _ = (pos * feats_centers_dist).max(1) 275 | n_pos = pos.eq(0).float() 276 | nd, _= (feats_centers_dist * n_pos).masked_fill_(pos.byte(), float('inf')).min(1) 277 | dist_m = pd + margin - nd 278 | 279 | loss_mh = torch.clamp(dist_m, min=0.0).mean(0).squeeze() 280 | 281 | center_loss_mh = loss_mh + center_loss 282 | 283 | return center_loss_mh, centers # , centers_var 284 | 285 | def get_centers_loss_margin_hard(centers, features, target, alpha, num_classes, margin=1.0): 286 | batch_size = target.size(0) 287 | features_dim = features.size(1) 288 | num_centers = centers.size(1) 289 | 290 | # bz x 3 x 4096 291 | features_view = features.unsqueeze(1).expand(batch_size, 3, features_dim) # bz x 3 x feat_Dim 292 | target_expand = target.view(batch_size,1, 1).expand(batch_size,3, features_dim) 293 | 
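# centers is laid out as num_classes x 3 x feat_dim (three centers per class,
# cf. the comment above get_center_loss), so gathering with target_expand
# selects the 3 centers belonging to each sample's class.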
centers_var = Variable(centers, requires_grad=False) 294 | centers_batch = centers_var.gather(0,target_expand) # bz x 3 x feat_Dim 295 | 296 | criterion = nn.MSELoss() 297 | center_loss = criterion(features_view, centers_batch) 298 | 299 | # compute gradient w.r.t. center 300 | diff = centers_batch - features_view # bz x 3 x 4096 301 | 302 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 303 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 304 | appear_times_expand = appear_times.view(-1,1,1).expand(batch_size,3, features_dim).type(torch.FloatTensor) 305 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 306 | diff_cpu = alpha * diff_cpu 307 | 308 | centers_var.requires_grad=True 309 | New_Feats = features.view(batch_size, 1, 1, features_dim).expand(batch_size, num_classes, num_centers, features_dim) 310 | centers_var_view = centers_var.view(1, num_classes, num_centers, features_dim).expand(batch_size, num_classes, num_centers, features_dim) 311 | sim_D = torch.pow(New_Feats - centers_var, 2).sum(3).sqrt() # bz x num_class x K 312 | norms = sim_D.norm(2, 2) 313 | sim_D = sim_D / norms.unsqueeze(2).repeat(1,1,sim_D.size(2)) 314 | # make index 315 | pos_mask = Variable(torch.ByteTensor().resize_(sim_D.size()).zero_()) 316 | for i in range(batch_size): 317 | pos_mask.data[i][target.data[i]][:].fill_(1) 318 | 319 | pd, _= sim_D.masked_select(pos_mask.cuda()).view(pos_mask.size(0), -1).max(1) 320 | neg_mask = pos_mask.eq(0) 321 | nd, _= sim_D.masked_select(neg_mask.cuda()).view(neg_mask.size(0), -1).min(1) 322 | margin_tensor = Variable(torch.Tensor([margin]).expand(nd.size(0)).cuda()) 323 | diff_margin = pd + margin_tensor - nd 324 | loss_mh = torch.clamp(diff_margin, min=0.0).mean(0).squeeze() 325 | 326 | # debug_here() 327 | for i in range(batch_size): 328 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 329 | 330 | # debug_here() 331 | center_loss_mh = loss_mh + center_loss 332 | 333 | return center_loss_mh, centers 334 | 335 | def get_centers_loss_margin_hard_v2(centers, features, target, alpha, num_classes, margin=1.0): 336 | batch_size = target.size(0) 337 | features_dim = features.size(1) 338 | num_centers = centers.size(1) 339 | 340 | # bz x 3 x 4096 341 | features_view = features.unsqueeze(1).expand(batch_size, 3, features_dim) # bz x 3 x feat_Dim 342 | target_expand = target.view(batch_size,1, 1).expand(batch_size,3, features_dim) 343 | centers_var = Variable(centers, requires_grad=False) 344 | centers_batch = centers_var.gather(0,target_expand) # bz x 3 x feat_Dim 345 | 346 | criterion = nn.MSELoss() 347 | center_loss = criterion(features_view, centers_batch) 348 | 349 | # compute gradient w.r.t. 
center 350 | diff = centers_batch - features_view # bz x 3 x 4096 351 | 352 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 353 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 354 | appear_times_expand = appear_times.view(-1,1,1).expand(batch_size,3, features_dim).type(torch.FloatTensor) 355 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 356 | diff_cpu = alpha * diff_cpu 357 | 358 | New_Feats = features.unsqueeze(1).expand(batch_size, num_centers, features_dim).contiguous().view(batch_size*num_centers, features_dim) 359 | centers_batch = centers_batch.view(batch_size*num_centers, features_dim) 360 | feats_centers_dist_raw = pdist2(New_Feats, centers_batch) 361 | # declare a mask: 24 x 24: every third element is a one 362 | L = batch_size * num_centers # 8(batch_size) * 3(number_centers) 363 | Mask_Three = torch.arange(0, L*L).view(L, L) 364 | Mask_Three = Mask_Three.apply_(lambda x: 1 if x%3 == 0 else 0).byte() 365 | Mask_Three = Variable(Mask_Three).cuda() 366 | 367 | feats_centers_dist = feats_centers_dist_raw.masked_select(Mask_Three).view(batch_size, batch_size,num_centers) 368 | norms = feats_centers_dist.norm(2, 2) 369 | feats_centers_dist = feats_centers_dist / norms.unsqueeze(2).repeat(1, 1, feats_centers_dist.size(2)) 370 | 371 | # make index 372 | pos_mask = Variable(torch.ByteTensor().resize_(feats_centers_dist.size()).zero_()).float().cuda() 373 | for i in range(batch_size): 374 | pos_mask.data[i][i][:].fill_(1) 375 | 376 | # debug_here() 377 | neg_mask = pos_mask.eq(0).float() 378 | pd_1, _ = (feats_centers_dist * pos_mask).max(1) 379 | pd, _ = pd_1.max(1) 380 | 381 | nd_1, _ = (feats_centers_dist * neg_mask).masked_fill_(pos_mask.byte(), float('inf')).min(1) 382 | nd, _ = nd_1.min(1) 383 | 384 | margin_tensor = Variable(torch.Tensor([margin]).expand(nd.size(0)).cuda()) 385 | diff_margin = pd + margin_tensor - nd 386 | loss_mh = torch.clamp(diff_margin, min=0.0).mean(0).squeeze() 387 | 388 | # debug_here() 389 | for i in range(batch_size): 390 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 391 | 392 | center_loss_mh = loss_mh + center_loss 393 | 394 | return center_loss_mh, centers 395 | 396 | ######################################################################################### 397 | ## end of centers_loss_margin_hard 398 | ######################################################################################### 399 | 400 | # update nearest center 401 | def get_center_loss_nn(centers, features, target, alpha, num_classes): 402 | batch_size = target.size(0) 403 | features_dim = features.size(1) 404 | # bz x 3 x 4096 405 | features_view = features.unsqueeze(1).expand(batch_size, 3, features_dim) 406 | 407 | target_expand = target.view(batch_size,1, 1).expand(batch_size,3, features_dim) 408 | centers_var = Variable(centers) 409 | centers_batch = centers_var.gather(0,target_expand) 410 | criterion = nn.MSELoss() 411 | center_loss = criterion(features_view, centers_batch) 412 | 413 | # compute gradient w.r.t. 
center 414 | diff = centers_batch - features_view # bz x 3 x 4096 415 | 416 | # debug_here() 417 | norm_diff_3dim = torch.norm(diff.data, 2, 2) 418 | _, min_idx = torch.min(norm_diff_3dim, 1) 419 | 420 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 421 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 422 | appear_times_expand = appear_times.view(-1,1,1).expand(batch_size,3, features_dim).type(torch.FloatTensor) 423 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 424 | diff_cpu = alpha * diff_cpu 425 | 426 | # update related centers 427 | 428 | for i in range(batch_size): 429 | centers[target.data[i]][min_idx[i]] -= diff_cpu[i][min_idx[i]].type(centers.type()) 430 | """ 431 | for i in range(batch_size): 432 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 433 | """ 434 | return center_loss, centers 435 | 436 | 437 | def get_center_loss_single_center_each_class(centers, features, target, alpha, num_classes): 438 | batch_size = target.size(0) 439 | features_dim = features.size(1) 440 | 441 | target_expand = target.view(batch_size,1).expand(batch_size, features_dim) 442 | centers_var = Variable(centers) 443 | centers_batch = centers_var.gather(0,target_expand) 444 | criterion = nn.MSELoss() 445 | center_loss = criterion(features, centers_batch) 446 | 447 | # compute gradient w.r.t. center 448 | diff = centers_batch - features 449 | 450 | unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True) 451 | appear_times = torch.from_numpy(unique_count).gather(0,torch.from_numpy(unique_reverse)) 452 | appear_times_expand = appear_times.view(-1,1).expand(batch_size, features_dim).type(torch.FloatTensor) 453 | diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6) 454 | diff_cpu = alpha * diff_cpu 455 | 456 | # update related centers 457 | for i in range(batch_size): 458 | centers[target.data[i]] -= diff_cpu[i].type(centers.type()) 459 | 460 | return center_loss, centers 461 | 462 | 463 | 464 | # this loss will try to drag the center away from each other 465 | def get_contrastive_center_loss(centers, targets): 466 | 467 | num_classes = centers.size(0) # for shapenet55, it is 55 468 | l2_norm = centers.norm(2) # normalize the input 469 | centers = centers.div_(l2_norm) 470 | 471 | centers_var = Variable(centers, requires_grad = True) 472 | centers_var_stack = torch.stack([centers_var]*num_classes) 473 | centers_var_stack_t = torch.stack([centers_var]*num_classes).transpose(0, 1) 474 | 475 | # zero out coresponding centers which are not updated during this iterations 476 | distance_map = torch.sum((centers_var_stack - centers_var_stack_t)**2, 2).squeeze() 477 | 478 | mask = torch.zeros(num_classes, num_classes).long() 479 | mask[targets.data.cpu(), :] = 1 480 | mask = mask.type_as(centers) 481 | distance_map.data = distance_map.data * mask 482 | 483 | # different classes are different, so enforce their distance to 1 484 | # we should normalize centers 485 | cross_target = 1 - np.identity(num_classes) 486 | cross_target = torch.from_numpy(cross_target).type_as(centers) 487 | cross_target = cross_target * mask 488 | 489 | # target = np.ones((num_classes, num_classes)) 490 | cross_target = Variable(cross_target) 491 | 492 | criterion = nn.MSELoss() 493 | 494 | # if we want distance d12 to be 1, then we need to |x1 - x2| to be reach 1 495 | contrastive_center_loss = criterion(distance_map, 
cross_target) 496 | # print(contrastive_center_loss) 497 | 498 | # based on the input centers, we update its centers 499 | contrastive_center_loss.backward() 500 | 501 | centers_var.grad.data = centers_var.grad.data * mask 502 | # up : 0.01 503 | centers_var.data -= 0.01 * centers_var.grad.data 504 | 505 | # resume 506 | centers = centers.mul_(l2_norm) 507 | 508 | return centers 509 | 510 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | # import ipdb 5 | from torch.autograd import Variable 6 | 7 | 8 | ''' 9 | AlexNet ( 10 | (features): Sequential ( 11 | (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)) 12 | (1): ReLU (inplace) 13 | (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)) 14 | (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 15 | (4): ReLU (inplace) 16 | (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)) 17 | (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 18 | (7): ReLU (inplace) 19 | (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 20 | (9): ReLU (inplace) 21 | (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 22 | (11): ReLU (inplace) 23 | (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)) 24 | ) 25 | (classifier): Sequential ( 26 | (0): Dropout (p = 0.5) 27 | (1): Linear (9216 -> 4096) 28 | (2): ReLU (inplace) 29 | (3): Dropout (p = 0.5) 30 | (4): Linear (4096 -> 4096) 31 | (5): ReLU (inplace) 32 | (6): Linear (4096 -> 1000) 33 | ) 34 | ) 35 | ''' 36 | nclasses = 90 37 | original_model = models.alexnet(pretrained=True) 38 | # original_model.classifier._modules['6'] = nn.Linear(4096, nclasses) 39 | def Net_Classifier(nfea=4096, nclasses=90): 40 | return nn.Linear(nfea, nclasses) 41 | 42 | class Net_Prev_Pool(nn.Module): 43 | def __init__(self, pool_idx=13): 44 | super(Net_Prev_Pool, self).__init__() 45 | self.Prev_Pool_Net = nn.Sequential( 46 | # use bottom layers, suppose pool_idx = 1, 47 | # then we use the bottomest layer(i.e, first layer) 48 | *list(original_model.features.children())[:pool_idx] 49 | ) 50 | def forward(self, x): 51 | x = self.Prev_Pool_Net(x) 52 | return x 53 | 54 | 55 | # this layer has no parameters 56 | class View_And_Pool(nn.Module): 57 | def __init__(self): 58 | super(View_And_Pool, self).__init__() 59 | # note that in python, dimension idx starts from 1 60 | # self.Pool_Net = legacy_nn.Max(1) 61 | # only max pool layer, we will use view in forward function 62 | # self.w = nn.Parameter(torch.ones(1, 12, 1, 1, 1), requires_grad=True) 63 | # self. 
= nn.Parameter(torch.zeros(12, 4096), requires_grad=True) 64 | 65 | def forward(self, x): 66 | # view x ( (bz*12) x C x H x W) ) as 67 | # bz x 12 x C x H x W 68 | # transform each view: 12 x C x H x W -> 12 X C x H x W 69 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 70 | # using average pool instead of max pool 71 | x, _= torch.max(x, 1) 72 | 73 | return x 74 | 75 | class Net_After_Pool(nn.Module): 76 | def __init__(self, pool_idx=13): 77 | super(Net_After_Pool, self).__init__() 78 | self.After_Pool_Net = nn.Sequential( 79 | # use top layers, suppose pool_idx = 1, 80 | # then we use from 2 layer up to the topest layer 81 | *list(original_model.features.children())[pool_idx:] 82 | ) 83 | self.modules_list = nn.ModuleList([module for module in original_model.classifier.children()]) 84 | 85 | 86 | def forward(self, x): 87 | x = self.After_Pool_Net(x) 88 | # need to insert a view layer so that we can feed it to classification layers 89 | x = x.view(x.size()[0], -1) 90 | 91 | x = self.modules_list[0](x) 92 | x = self.modules_list[1](x) 93 | x = self.modules_list[2](x) 94 | x = self.modules_list[3](x) 95 | x = self.modules_list[4](x) 96 | out1 = self.modules_list[5](x) 97 | # out2 = self.modules_list[6](out1) 98 | return out1 #[out1, out2] 99 | 100 | class Net_Whole(nn.Module): 101 | def __init__(self, nclasses=90): 102 | super(Net_Whole, self).__init__() 103 | net = models.alexnet(pretrained=True) 104 | self.features = net.features 105 | classifier = net.classifier 106 | # classifier._modules['6'] = nn.Linear(4096, nclasses) 107 | self.modules_list = nn.ModuleList([module for module in classifier.children()]) 108 | 109 | def forward(self, x): 110 | x = self.features(x) 111 | x = x.view(x.size()[0], -1) 112 | x = self.modules_list[0](x) 113 | x = self.modules_list[1](x) 114 | x = self.modules_list[2](x) 115 | x = self.modules_list[3](x) 116 | x = self.modules_list[4](x) 117 | out1 = self.modules_list[5](x) 118 | # out2 = self.modules_list[6](out1) 119 | return out1 #[out1, out2] 120 | 121 | 122 | # no use 123 | class zzc_maxpooling(nn.Module): 124 | def __init__(self): 125 | super(zzc_maxpooling, self).__init__() 126 | net = models.alexnet(pretrained=False) 127 | self.features = net.features 128 | self.classifier = net.classifier 129 | 130 | def forward(self, x): 131 | x = x.view(-1, x.size(2), x.size(3), x.size(4)) 132 | x = self.features(x) 133 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 134 | x, _= torch.max(x, 1) 135 | x = self.classifier(x) 136 | return x 137 | 138 | 139 | if __name__ == '__main__': 140 | ''' 141 | pool_idx = 13 142 | # avoid pool at relu layer, because if relu is inplace, then 143 | # may cause misleading 144 | model_prev_pool = Net_Prev_Pool(pool_idx).cuda() 145 | view_and_pool = View_And_Pool().cuda() 146 | # ipdb.set_trace() 147 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 148 | model_after_pool = Net_After_Pool(pool_idx).cuda() 149 | bp = model_prev_pool(x) 150 | ap = view_and_pool(bp) 151 | o1 = model_after_pool(ap) 152 | 153 | whole = Net_Whole().cuda() 154 | ipdb.set_trace() 155 | 156 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 157 | o2 = whole(x) 158 | ''' 159 | m = zzc_maxpooling() 160 | x = Variable(torch.rand(2, 12, 3, 224, 224).cuda()) 161 | o = m(x) 162 | ipdb.set_trace() 163 | 164 | 165 | 166 | -------------------------------------------------------------------------------- /models/resnet101.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 
as nn 3 | from torchvision import models 4 | # import ipdb 5 | from torch.autograd import Variable 6 | 7 | 8 | ''' 9 | ResNet( 10 | 0: conv1 11 | 1: bn1 12 | 2: relu 13 | 3: maxpool 14 | 4: layer1 15 | 5: layer2 16 | 6: layer3 17 | 7: layer4 18 | 8: avgpool 19 | 9: fc 20 | ) 21 | ''' 22 | nclasses = 90 23 | original_model = models.resnet101(pretrained=True) 24 | # original_model.classifier._modules['6'] = nn.Linear(4096, nclasses) 25 | def Net_Classifier(nfea=2048, nclasses=90): 26 | return nn.Linear(nfea, nclasses) 27 | 28 | class Net_Prev_Pool(nn.Module): 29 | def __init__(self, pool_idx=7): 30 | super(Net_Prev_Pool, self).__init__() 31 | self.Prev_Pool_Net = nn.Sequential( 32 | # use bottom layers, suppose pool_idx = 1, 33 | # then we use the bottomest layer(i.e, first layer) 34 | *list(original_model.children())[:pool_idx] 35 | ) 36 | def forward(self, x): 37 | x = self.Prev_Pool_Net(x) 38 | return x 39 | 40 | 41 | # this layer has no parameters 42 | class View_And_Pool(nn.Module): 43 | def __init__(self): 44 | super(View_And_Pool, self).__init__() 45 | # note that in python, dimension idx starts from 1 46 | # self.Pool_Net = legacy_nn.Max(1) 47 | # only max pool layer, we will use view in forward function 48 | # self.w = nn.Parameter(torch.ones(1, 12, 1, 1, 1), requires_grad=True) 49 | # self. = nn.Parameter(torch.zeros(12, 4096), requires_grad=True) 50 | 51 | def forward(self, x): 52 | # view x ( (bz*12) x C x H x W) ) as 53 | # bz x 12 x C x H x W 54 | # transform each view: 12 x C x H x W -> 12 X C x H x W 55 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 56 | # using average pool instead of max pool 57 | x, _= torch.max(x, 1) 58 | 59 | return x 60 | 61 | class Net_After_Pool(nn.Module): 62 | def __init__(self, pool_idx=7): 63 | super(Net_After_Pool, self).__init__() 64 | self.After_Pool_Net = nn.Sequential( 65 | # use top layers, suppose pool_idx = 1, 66 | # then we use from 2 layer up to the topest layer 67 | *list(original_model.children())[pool_idx:-1] 68 | ) 69 | #self.modules_list = nn.ModuleList([module for module in original_model.fc.children()]) 70 | 71 | 72 | def forward(self, x): 73 | # forward through all left layers except the classifier 74 | # however, an additional reshape operation is needed 75 | x = self.After_Pool_Net(x) 76 | # need to insert a view layer so that we can feed it to classification layers 77 | x = x.view(x.size()[0], -1) 78 | 79 | #x = self.modules_list[0](x) 80 | #x = self.modules_list[1](x) 81 | #x = self.modules_list[2](x) 82 | #x = self.modules_list[3](x) 83 | #x = self.modules_list[4](x) 84 | #out1 = self.modules_list[5](x) 85 | # out2 = self.modules_list[6](out1) 86 | return x #[out1, out2] 87 | 88 | class Net_Whole(nn.Module): 89 | def __init__(self, nclasses=90, use_finetuned=False): 90 | super(Net_Whole, self).__init__() 91 | if use_finetuned: 92 | net = models.resnet101(num_classes=250) 93 | d = torch.load('../sketch_finetune/resnet101.pth.tar') 94 | sd = d['state_dict'] 95 | od = net.state_dict() 96 | for sk, ok in zip(sd.keys(), od.keys()): 97 | od[ok] = sd[sk] 98 | else: 99 | net = models.resnet101(pretrained=True) 100 | self.features = nn.Sequential( 101 | *list(original_model.children())[:-1]) 102 | #classifier = net.classifier 103 | # classifier._modules['6'] = nn.Linear(4096, nclasses) 104 | #self.modules_list = nn.ModuleList([module for module in classifier.children()]) 105 | 106 | def forward(self, x): 107 | x = self.features(x) 108 | x = x.view(x.size()[0], -1) 109 | 110 | # out2 = self.modules_list[6](out1) 111 | 
return x #[out1, out2] 112 | 113 | 114 | # no use 115 | class zzc_maxpooling(nn.Module): 116 | def __init__(self): 117 | super(zzc_maxpooling, self).__init__() 118 | net = models.alexnet(pretrained=False) 119 | self.features = net.features 120 | self.classifier = net.classifier 121 | 122 | def forward(self, x): 123 | x = x.view(-1, x.size(2), x.size(3), x.size(4)) 124 | x = self.features(x) 125 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 126 | x, _= torch.max(x, 1) 127 | x = self.classifier(x) 128 | return x 129 | 130 | 131 | if __name__ == '__main__': 132 | ''' 133 | pool_idx = 13 134 | # avoid pool at relu layer, because if relu is inplace, then 135 | # may cause misleading 136 | model_prev_pool = Net_Prev_Pool(pool_idx).cuda() 137 | view_and_pool = View_And_Pool().cuda() 138 | # ipdb.set_trace() 139 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 140 | model_after_pool = Net_After_Pool(pool_idx).cuda() 141 | bp = model_prev_pool(x) 142 | ap = view_and_pool(bp) 143 | o1 = model_after_pool(ap) 144 | 145 | whole = Net_Whole().cuda() 146 | ipdb.set_trace() 147 | 148 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 149 | o2 = whole(x) 150 | ''' 151 | m = zzc_maxpooling() 152 | x = Variable(torch.rand(2, 12, 3, 224, 224).cuda()) 153 | o = m(x) 154 | ipdb.set_trace() 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /models/resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | # import ipdb 5 | from torch.autograd import Variable 6 | 7 | 8 | ''' 9 | ResNet( 10 | 0: conv1 11 | 1: bn1 12 | 2: relu 13 | 3: maxpool 14 | 4: layer1 15 | 5: layer2 16 | 6: layer3 17 | 7: layer4 18 | 8: avgpool 19 | 9: fc 20 | ) 21 | ''' 22 | nclasses = 90 23 | original_model = models.resnet50(pretrained=True) 24 | # original_model.classifier._modules['6'] = nn.Linear(4096, nclasses) 25 | def Net_Classifier(nfea=2048, nclasses=90): 26 | return nn.Linear(nfea, nclasses) 27 | 28 | class Net_Prev_Pool(nn.Module): 29 | def __init__(self, pool_idx=7): 30 | super(Net_Prev_Pool, self).__init__() 31 | self.Prev_Pool_Net = nn.Sequential( 32 | # use bottom layers, suppose pool_idx = 1, 33 | # then we use the bottomest layer(i.e, first layer) 34 | *list(original_model.children())[:pool_idx] 35 | ) 36 | def forward(self, x): 37 | x = self.Prev_Pool_Net(x) 38 | return x 39 | 40 | 41 | # this layer has no parameters 42 | class View_And_Pool(nn.Module): 43 | def __init__(self): 44 | super(View_And_Pool, self).__init__() 45 | # note that in python, dimension idx starts from 1 46 | # self.Pool_Net = legacy_nn.Max(1) 47 | # only max pool layer, we will use view in forward function 48 | # self.w = nn.Parameter(torch.ones(1, 12, 1, 1, 1), requires_grad=True) 49 | # self. 
= nn.Parameter(torch.zeros(12, 4096), requires_grad=True) 50 | 51 | def forward(self, x): 52 | # view x ( (bz*12) x C x H x W) ) as 53 | # bz x 12 x C x H x W 54 | # transform each view: 12 x C x H x W -> 12 X C x H x W 55 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 56 | # using average pool instead of max pool 57 | x, _= torch.max(x, 1) 58 | 59 | return x 60 | 61 | class Net_After_Pool(nn.Module): 62 | def __init__(self, pool_idx=7): 63 | super(Net_After_Pool, self).__init__() 64 | self.After_Pool_Net = nn.Sequential( 65 | # use top layers, suppose pool_idx = 1, 66 | # then we use from 2 layer up to the topest layer 67 | *list(original_model.children())[pool_idx:-1] 68 | ) 69 | #self.modules_list = nn.ModuleList([module for module in original_model.fc.children()]) 70 | 71 | 72 | def forward(self, x): 73 | # forward through all left layers except the classifier 74 | # however, an additional reshape operation is needed 75 | x = self.After_Pool_Net(x) 76 | # need to insert a view layer so that we can feed it to classification layers 77 | x = x.view(x.size()[0], -1) 78 | 79 | #x = self.modules_list[0](x) 80 | #x = self.modules_list[1](x) 81 | #x = self.modules_list[2](x) 82 | #x = self.modules_list[3](x) 83 | #x = self.modules_list[4](x) 84 | #out1 = self.modules_list[5](x) 85 | # out2 = self.modules_list[6](out1) 86 | return x #[out1, out2] 87 | 88 | class Net_Whole(nn.Module): 89 | def __init__(self, nclasses=90, use_finetuned=False): 90 | super(Net_Whole, self).__init__() 91 | if use_finetuned: 92 | net = models.resnet50(num_classes=250) 93 | d = torch.load('../sketch_finetune/resnet50_sketch.pth.tar') 94 | sd = d['state_dict'] 95 | od = net.state_dict() 96 | for sk, ok in zip(sd.keys(), od.keys()): 97 | od[ok] = sd[sk] 98 | #net.load_state_dict(d['state_dict']) 99 | else: 100 | net = models.resnet50(pretrained=True) 101 | self.features = nn.Sequential( 102 | *list(original_model.children())[:-1]) 103 | #classifier = net.classifier 104 | # classifier._modules['6'] = nn.Linear(4096, nclasses) 105 | #self.modules_list = nn.ModuleList([module for module in classifier.children()]) 106 | 107 | def forward(self, x): 108 | x = self.features(x) 109 | x = x.view(x.size()[0], -1) 110 | 111 | # out2 = self.modules_list[6](out1) 112 | return x #[out1, out2] 113 | 114 | 115 | # no use 116 | class zzc_maxpooling(nn.Module): 117 | def __init__(self): 118 | super(zzc_maxpooling, self).__init__() 119 | net = models.alexnet(pretrained=False) 120 | self.features = net.features 121 | self.classifier = net.classifier 122 | 123 | def forward(self, x): 124 | x = x.view(-1, x.size(2), x.size(3), x.size(4)) 125 | x = self.features(x) 126 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 127 | x, _= torch.max(x, 1) 128 | x = self.classifier(x) 129 | return x 130 | 131 | 132 | if __name__ == '__main__': 133 | ''' 134 | pool_idx = 13 135 | # avoid pool at relu layer, because if relu is inplace, then 136 | # may cause misleading 137 | model_prev_pool = Net_Prev_Pool(pool_idx).cuda() 138 | view_and_pool = View_And_Pool().cuda() 139 | # ipdb.set_trace() 140 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 141 | model_after_pool = Net_After_Pool(pool_idx).cuda() 142 | bp = model_prev_pool(x) 143 | ap = view_and_pool(bp) 144 | o1 = model_after_pool(ap) 145 | 146 | whole = Net_Whole().cuda() 147 | ipdb.set_trace() 148 | 149 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 150 | o2 = whole(x) 151 | ''' 152 | m = zzc_maxpooling() 153 | x = Variable(torch.rand(2, 12, 3, 224, 224).cuda()) 154 
| o = m(x) 155 | ipdb.set_trace() 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /models/vgg11_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | # import ipdb 5 | from torch.autograd import Variable 6 | 7 | 8 | ''' 9 | VGG ( 10 | (features): Sequential ( 11 | (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 12 | (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 13 | (2): ReLU (inplace) 14 | (3): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 15 | (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 17 | (6): ReLU (inplace) 18 | (7): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 19 | (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 20 | (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 21 | (10): ReLU (inplace) 22 | (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 23 | (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 24 | (13): ReLU (inplace) 25 | (14): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 26 | (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 27 | (16): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 28 | (17): ReLU (inplace) 29 | (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 30 | (19): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 31 | (20): ReLU (inplace) 32 | (21): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 33 | (22): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 34 | (23): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 35 | (24): ReLU (inplace) 36 | (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 37 | (26): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 38 | (27): ReLU (inplace) 39 | (28): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 40 | ) 41 | (classifier): Sequential ( 42 | (0): Linear (25088 -> 4096) 43 | (1): ReLU (inplace) 44 | (2): Dropout (p = 0.5) 45 | (3): Linear (4096 -> 4096) 46 | (4): ReLU (inplace) 47 | (5): Dropout (p = 0.5) 48 | (6): Linear (4096 -> 1000) 49 | ) 50 | ) 51 | 52 | ''' 53 | nclasses = 90 54 | original_model = models.vgg11_bn(pretrained=True) 55 | # original_model.classifier._modules['6'] = nn.Linear(4096, nclasses) 56 | def Net_Classifier(nfea=4096, nclasses=90): 57 | return nn.Linear(nfea, nclasses) 58 | 59 | class Net_Prev_Pool(nn.Module): 60 | def __init__(self, pool_idx=21): 61 | super(Net_Prev_Pool, self).__init__() 62 | self.Prev_Pool_Net = nn.Sequential( 63 | # use bottom layers, suppose pool_idx = 1, 64 | # then we use the bottomest layer(i.e, first layer) 65 | *list(original_model.features.children())[:pool_idx] 66 | ) 67 | def forward(self, x): 68 | x = self.Prev_Pool_Net(x) 69 | return x 70 | 71 | 72 | # this layer has no parameters 73 | class View_And_Pool(nn.Module): 74 | def __init__(self): 75 | super(View_And_Pool, self).__init__() 76 | # note that in python, dimension idx starts from 1 77 | # self.Pool_Net = legacy_nn.Max(1) 78 | # only max pool layer, we will use view in forward function 79 | # self.w = nn.Parameter(torch.ones(1, 12, 1, 1, 1), requires_grad=True) 80 | # self. 
= nn.Parameter(torch.zeros(12, 4096), requires_grad=True) 81 | 82 | def forward(self, x): 83 | # view x ( (bz*12) x C x H x W) ) as 84 | # bz x 12 x C x H x W 85 | # transform each view: 12 x C x H x W -> 12 X C x H x W 86 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 87 | # using average pool instead of max pool 88 | x, _= torch.max(x, 1) 89 | 90 | return x 91 | 92 | class Net_After_Pool(nn.Module): 93 | def __init__(self, pool_idx=21): 94 | super(Net_After_Pool, self).__init__() 95 | self.After_Pool_Net = nn.Sequential( 96 | # use top layers, suppose pool_idx = 1, 97 | # then we use from 2 layer up to the topest layer 98 | *list(original_model.features.children())[pool_idx:] 99 | ) 100 | self.modules_list = nn.ModuleList([module for module in original_model.classifier.children()]) 101 | 102 | 103 | def forward(self, x): 104 | x = self.After_Pool_Net(x) 105 | # need to insert a view layer so that we can feed it to classification layers 106 | x = x.view(x.size()[0], -1) 107 | 108 | x = self.modules_list[0](x) 109 | x = self.modules_list[1](x) 110 | x = self.modules_list[2](x) 111 | x = self.modules_list[3](x) 112 | x = self.modules_list[4](x) 113 | out1 = self.modules_list[5](x) 114 | # out2 = self.modules_list[6](out1) 115 | return out1 #[out1, out2] 116 | 117 | class Net_Whole(nn.Module): 118 | def __init__(self, nclasses=90, use_finetuned=False): 119 | super(Net_Whole, self).__init__() 120 | 121 | if use_finetuned: 122 | net = models.vgg11_bn(num_classes=250) 123 | d = torch.load('/home/zczhou/Research/triplet_center_loss/finetune-sketch/vgg11_bn.pth.tar') 124 | sd = d['state_dict'] 125 | od = net.state_dict() 126 | for sk, ok in zip(sd.keys(), od.keys()): 127 | od[ok] = sd[sk] 128 | else: 129 | net = models.vgg11_bn(pretrained=True) 130 | self.features = net.features 131 | classifier = net.classifier 132 | # classifier._modules['6'] = nn.Linear(4096, nclasses) 133 | self.modules_list = nn.ModuleList([module for module in classifier.children()]) 134 | 135 | def forward(self, x): 136 | x = self.features(x) 137 | x = x.view(x.size()[0], -1) 138 | x = self.modules_list[0](x) 139 | x = self.modules_list[1](x) 140 | x = self.modules_list[2](x) 141 | x = self.modules_list[3](x) 142 | x = self.modules_list[4](x) 143 | out1 = self.modules_list[5](x) 144 | # out2 = self.modules_list[6](out1) 145 | return out1 #[out1, out2] 146 | 147 | 148 | # no use 149 | class zzc_maxpooling(nn.Module): 150 | def __init__(self): 151 | super(zzc_maxpooling, self).__init__() 152 | net = models.alexnet(pretrained=False) 153 | self.features = net.features 154 | self.classifier = net.classifier 155 | 156 | def forward(self, x): 157 | x = x.view(-1, x.size(2), x.size(3), x.size(4)) 158 | x = self.features(x) 159 | x = x.view(-1, 12, x.size()[1], x.size()[2], x.size()[3]) 160 | x, _= torch.max(x, 1) 161 | x = self.classifier(x) 162 | return x 163 | 164 | 165 | if __name__ == '__main__': 166 | ''' 167 | pool_idx = 13 168 | # avoid pool at relu layer, because if relu is inplace, then 169 | # may cause misleading 170 | model_prev_pool = Net_Prev_Pool(pool_idx).cuda() 171 | view_and_pool = View_And_Pool().cuda() 172 | # ipdb.set_trace() 173 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 174 | model_after_pool = Net_After_Pool(pool_idx).cuda() 175 | bp = model_prev_pool(x) 176 | ap = view_and_pool(bp) 177 | o1 = model_after_pool(ap) 178 | 179 | whole = Net_Whole().cuda() 180 | ipdb.set_trace() 181 | 182 | x = Variable(torch.rand(12*2, 3, 224, 224).cuda()) 183 | o2 = whole(x) 184 | ''' 185 | m = 
zzc_maxpooling() 186 | x = Variable(torch.rand(2, 12, 3, 224, 224).cuda()) 187 | o = m(x) 188 | ipdb.set_trace() 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_arguments(): 5 | parser = argparse.ArgumentParser(description="Triplet center loss for sketch-based-shape retrieval") 6 | parser.add_argument('-b', '--batch-size', type=int, default=8) 7 | parser.add_argument('-j', '--workers', type=int, default=0) 8 | # parser.add_argument('--split', type=int, default=0) 9 | parser.add_argument('--height', type=int, 10 | help="input height, default: 256 for resnet*, " 11 | "144 for inception") 12 | parser.add_argument('--width', type=int, 13 | help="input width, default: 128 for resnet*, " 14 | "56 for inception") 15 | parser.add_argument('--features', type=int, default=4096) 16 | # parser.add_argument('--dropout', type=float, default=0.5) 17 | # optimizer 18 | parser.add_argument('--lr', type=float, default=0.01, 19 | help="learning rate of new parameters, for pretrained " 20 | "parameters it is 10 times smaller than this") 21 | # parser.add_argument('--momentum', type=float, default=0.9) 22 | # parser.add_argument('--weight-decay', type=float, default=5e-4) 23 | # training configs 24 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 25 | parser.add_argument('--evaluate', action='store_true', 26 | help="evaluation only") 27 | parser.add_argument('--max_epochs', type=int, default=100) 28 | parser.add_argument('--start_save', type=int, default=15, 29 | help="start saving checkpoints after specific epoch") 30 | parser.add_argument('--seed', type=int, default=1) 31 | parser.add_argument('--print-freq', type=int, default=1) 32 | # metric learning 33 | parser.add_argument('--wn', action='store_true', help='weight normalization for centers') 34 | parser.add_argument('--w1', type=float, default=1, help='weight for classification loss') 35 | parser.add_argument('--margin', type=float, default=5, help='margin for triplet center loss') 36 | parser.add_argument('--init', action='store_true', help='initial the norm of centers') 37 | parser.add_argument('--norm', action='store_true', help='feature normalizations') 38 | # clamp parameters into a cube 39 | parser.add_argument('--gradient_clip', type=float, default=0.05) # previous i set it to be 0.01 40 | #parser.add_argument('--pool-idx', type=int, default=13) # 13 is for alexnet 41 | parser.add_argument('--pool-idx', type=int) 42 | # parser.add_argument('--arch', type=str, default='alexnet') 43 | parser.add_argument('--w2', type=float, default=0.1) 44 | parser.add_argument('--sf', action='store_true') 45 | parser.add_argument('--pk-flag', action='store_true') 46 | parser.add_argument('--num-instances',type=int, default=5) 47 | parser.add_argument('--weight-decay', type=float, default=1e-4) 48 | parser.add_argument('--balance', action='store_true') 49 | parser.add_argument('--interval', type=int, default=5) 50 | parser.add_argument('--backbone', choices=['alexnet','vgg11_bn','resnet50','resnet101','vgg13_bn', 'vgg16_bn'],default='alexnet') 51 | parser.add_argument('--sketch_finetune', action='store_true') 52 | 53 | # specify data folders 54 | # test and train shapes/sketches reside in the same folder 55 | parser.add_argument('--train_shape_views_folder', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/Views') 56 | 
parser.add_argument('--test_shape_views_folder', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/Views') 57 | parser.add_argument('--train_shape_flist', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/model_train_pair.txt') 58 | parser.add_argument('--test_shape_flist', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/test_pair.txt') 59 | 60 | parser.add_argument('--train_sketch_folder', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/Sketches') 61 | parser.add_argument('--test_sketch_folder', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/Sketches') 62 | parser.add_argument('--train_sketch_flist', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/sketch_train_pair.txt') 63 | parser.add_argument('--test_sketch_flist', type=str, default='/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/sketch_test_pair.txt') 64 | 65 | # save folder 66 | parser.add_argument('--checkpoint_folder', type=str, default='./models_checkpoint/Sketch') 67 | return parser.parse_args() 68 | -------------------------------------------------------------------------------- /run_image_based.sh: -------------------------------------------------------------------------------- 1 | GPU=$1 2 | backbone=$2 3 | DATE=`date +%Y-%m-%d_%H-%M-%S` 4 | action=$3 5 | 6 | ckptfolder=$4 7 | data3dtest=/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/test_pair.txt 8 | 9 | if [ $action -eq 1 ]; then 10 | Data2dRoot=/home/lxl/dataset/shape/SceneIBR2018_Dataset 11 | Data2dFolder=${Data2dRoot}/Images/Scene 12 | Data3dTrain=/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/model_all.txt 13 | Data2dTrain=${Data2dRoot}/lists/image_train_pair.txt 14 | Data2dTest=${Data2dRoot}/lists/image_test_pair.txt 15 | CUDA_VISIBLE_DEVICES=$GPU python main.py --backbone $backbone --max_epochs 150 --print-freq 10 --margin 1 --w1 1 --w2 0.1 --train_sketch_folder $Data2dFolder --test_sketch_folder $Data2dFolder --train_sketch_flist $Data2dTrain --test_sketch_flist $Data2dTest --train_shape_flist $Data3dTrain --checkpoint_folder $ckptfolder | tee logs/image_${backbone}_${DATE}.txt 16 | fi 17 | 18 | if [ $action -eq 2 ]; then 19 | Data2dRoot=/home/lxl/dataset/shape/SceneIBR2018_Dataset 20 | Data2dFolder=${Data2dRoot}/Images/Scene 21 | Data2dTrain=${Data2dRoot}/lists/image_train_pair.txt 22 | Data2dTest=${Data2dRoot}/lists/image_test_pair.txt 23 | CUDA_VISIBLE_DEVICES=$GPU python extract_features.py --backbone $backbone -b 4 --max_epochs 150 --print-freq 10 --margin 1 --w1 1 --w2 0.1 --train_sketch_folder $Data2dFolder --test_sketch_folder $Data2dFolder --test_shape_flist $data3dtest --train_sketch_flist $Data2dTrain --test_sketch_flist $Data2dTest --checkpoint_folder $ckptfolder | tee logs/image_${backbone}_eval_${DATE}.txt 24 | fi 25 | -------------------------------------------------------------------------------- /run_sketch_based.sh: -------------------------------------------------------------------------------- 1 | GPU=$1 2 | DATE=`date +%Y-%m-%d_%H-%M-%S` 3 | backbone=$2 4 | 5 | action=$3 6 | 7 | ckptfolder=$4 8 | dataroot=/home/lxl/dataset/shape/SceneSBR2018_Dataset 9 | data2dtrain=$dataroot/lists/sketch_all.txt 10 | data3dtrain=$dataroot/lists/model_all.txt 11 | data3dtest=/home/lxl/dataset/shape/SceneSBR2018_Dataset/lists/test_pair.txt 12 | 13 | # standard train 14 | if [ $action -eq 1 ]; then 15 | CUDA_VISIBLE_DEVICES=${GPU} python main.py --checkpoint_folder $ckptfolder --train_shape_flist $data3dtrain 
--train_sketch_flist $data2dtrain --backbone $backbone --margin 1 --w1 1 --w2 0.1 --print-freq 10 --max_epochs 150 | tee logs/sketch_${backbone}_${DATE}.txt 16 | fi 17 | 18 | # standard test 19 | if [ $action -eq 2 ]; then 20 | CUDA_VISIBLE_DEVICES=${GPU} python extract_features.py --checkpoint_folder $ckptfolder --backbone $backbone --margin 1 --w1 1 --w2 0.1 --test_shape_flist $data3dtest --print-freq 10 --max_epochs 150 | tee logs/sketch_${backbone}_eval_${DATE}.txt 21 | fi 22 | 23 | --------------------------------------------------------------------------------
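The center-loss helpers listed above can be exercised on random tensors before wiring them into training. The sketch below is only an illustration: it assumes the single-center helper lives in `misc/custom_loss.py` (adjust the import if it is defined in `misc/utils.py` instead), and the batch size, feature dimension and class count are made up.
```
# Hedged smoke test for get_center_loss_single_center_each_class (CPU, random data).
import torch
from torch.autograd import Variable
from misc.custom_loss import get_center_loss_single_center_each_class  # assumed module path

num_classes, feat_dim, batch_size = 10, 4096, 8
centers = torch.zeros(num_classes, feat_dim)             # one center per class
features = Variable(torch.randn(batch_size, feat_dim))   # embeddings from the backbone
target = Variable(torch.LongTensor(batch_size).random_(0, num_classes))

# alpha is the center update rate; the returned tensor holds the in-place updated centers
loss, centers = get_center_loss_single_center_each_class(
    centers, features, target, alpha=0.1, num_classes=num_classes)
print(loss)
```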
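The model files split each backbone into a per-view trunk, a cross-view max-pooling layer and a shared head. A minimal CPU forward pass through the AlexNet variant is sketched below, assuming `models/alexnet.py` is importable from the repo root; note that importing it instantiates a pretrained AlexNet, so torchvision weights are downloaded on first use, and the input follows the 12-views-per-shape layout that `View_And_Pool` expects.
```
# Hedged usage sketch: forward two shapes (12 views each) through the split AlexNet.
import torch
from torch.autograd import Variable
from models.alexnet import Net_Prev_Pool, View_And_Pool, Net_After_Pool, Net_Classifier

prev_pool = Net_Prev_Pool(pool_idx=13)    # conv trunk applied to every rendered view
view_pool = View_And_Pool()               # max-pools the 12 per-view feature maps of each shape
after_pool = Net_After_Pool(pool_idx=13)  # remaining layers up to the 4096-d descriptor
classifier = Net_Classifier(nfea=4096, nclasses=90)

x = Variable(torch.randn(2 * 12, 3, 224, 224))  # 2 shapes x 12 views stacked along the batch dim
feat = after_pool(view_pool(prev_pool(x)))      # 2 x 4096 shape descriptors
logits = classifier(feat)                       # 2 x 90 class scores
print(feat.size(), logits.size())
```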