├── .gitattributes
├── CVPR2020.pdf
├── images
    ├── results.PNG
    └── framework.PNG
├── util
    ├── optimizers.py
    ├── FeatureExtractor.py
    ├── gumbel.py
    ├── spectral.py
    ├── dataset_loader.py
    ├── transforms.py
    ├── samplers.py
    ├── re_ranking.py
    ├── ms_ssim.py
    ├── local_dist.py
    ├── distance.py
    ├── utils.py
    └── eval_metrics.py
├── LICENSE
├── models
    ├── DenseNet.py
    ├── __init__.py
    ├── LSRO.py
    ├── AlignedReID.py
    ├── IDE.py
    ├── PCB.py
    ├── MuDeep.py
    └── HACNN.py
├── ReID_attr.py
├── advloss.py
├── README.md
├── GD.py
├── train.py
└── opts.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /CVPR2020.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whj363636/Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking/HEAD/CVPR2020.pdf -------------------------------------------------------------------------------- /images/results.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whj363636/Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking/HEAD/images/results.PNG -------------------------------------------------------------------------------- /images/framework.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whj363636/Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking/HEAD/images/framework.PNG -------------------------------------------------------------------------------- /util/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = ['init_optim'] 4 | 5 | def init_optim(optim, params, lr, weight_decay): 6 | if optim == 'adam': 7 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay) 8 | elif optim == 'sgd': 9 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 10 | elif optim == 'rmsprop': 11 | return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay) 12 | else: 13 | raise KeyError("Unsupported optim: {}".format(optim)) -------------------------------------------------------------------------------- /util/FeatureExtractor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from IPython import embed 3 | 4 | class FeatureExtractor(nn.Module): 5 | def __init__(self,submodule,extracted_layers): 6 | super(FeatureExtractor,self).__init__() 7 | self.submodule = submodule 8 | self.extracted_layers = extracted_layers 9 | 10 | def forward(self, x): 11 | outputs = [] 12 | for name, module in self.submodule._modules.items(): 13 | if name == "classifier": 14 | x = x.view(x.size(0),-1) 15 | if name == "base": 16 | for block_name, cnn_block in module._modules.items(): 17 | x = cnn_block(x) 18 | if block_name in self.extracted_layers: 19 | outputs.append(x) 20 | return outputs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 whj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/DenseNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import torchvision 7 | 8 | __all__ = ['DenseNet121'] 9 | 10 | 11 | class DenseNet121(nn.Module): 12 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 13 | super(DenseNet121, self).__init__() 14 | self.loss = loss 15 | densenet121 = torchvision.models.densenet121(pretrained=True) 16 | self.base = densenet121.features 17 | self.classifier = nn.Linear(1024, num_classes) 18 | self.feat_dim = 1024 # feature dimension 19 | 20 | def forward(self, x, is_training): 21 | x = self.base(x) 22 | x = F.avg_pool2d(x, x.size()[2:]) 23 | f = x.view(x.size(0), -1) 24 | if not is_training: 25 | return f 26 | y = self.classifier(f) 27 | 28 | if self.loss == {'xent'}: 29 | return [y] 30 | elif self.loss == {'xent', 'htri'}: 31 | return [y, f] 32 | elif self.loss == {'cent'}: 33 | return [y, f] 34 | elif self.loss == {'ring'}: 35 | return [y, f] 36 | else: 37 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /util/gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import Parameter 9 | 10 | 11 | def _sample_gumbel(shape, eps=1e-10, out=None): 12 | """ 13 | Based on 14 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 15 | (MIT license) 16 | """ 17 | U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape) 18 | return - torch.log(eps - torch.log(U + eps)) 19 | 20 | 21 | def _gumbel_softmax_sample(logits, T=1, eps=1e-10): 22 | """ 23 | Based on 24 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb 25 | (MIT license) 26 | """ 27 | dims = logits.dim() 28 | gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=logits.data.new()) 29 | y = logits + gumbel_noise 30 | return F.softmax(y / T, dims - 1) 31 | 32 | 33 | def gumbel_softmax(logits, k, T=1, hard=True, eps=1e-10): 34 | shape = logits.size() 35 | assert len(shape) == 2 36 | y_soft = _gumbel_softmax_sample(logits, T=T, eps=eps) 37 | if hard: 38 | _, ind = 
torch.topk(y_soft, k=k, dim=-1, largest=True) 39 | y_hard = logits.new_zeros(*shape).scatter_(-1, ind.view(-1, k), 1.0) 40 | y = y_hard - y_soft.detach() + y_soft 41 | else: 42 | y = y_soft 43 | return y -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .DenseNet import * 6 | from .MuDeep import * 7 | from .AlignedReID import * 8 | from .PCB import * 9 | from .HACNN import * 10 | from .IDE import * 11 | from .LSRO import * 12 | 13 | __factory = { 14 | # 1. 15 | 'hacnn': HACNN, 16 | 'densenet121': DenseNet121, 17 | 'ide': IDE, 18 | # 2. 19 | 'aligned': ResNet50, 20 | 'pcb': PCB, 21 | 'mudeep': MuDeep, 22 | # 3. 23 | 'cam': IDE, 24 | 'hhl': IDE, 25 | 'lsro': DenseNet121, 26 | 'spgan': IDE, 27 | } 28 | 29 | def get_names(): 30 | return __factory.keys() 31 | 32 | def init_model(name, pre_dir, *args, **kwargs): 33 | if name not in __factory.keys(): 34 | raise KeyError("Unknown model: {}".format(name)) 35 | 36 | print("Initializing model: {}".format(name)) 37 | net = __factory[name](*args, **kwargs) 38 | # load pretrained model 39 | checkpoint = torch.load(pre_dir) # for Python 2 40 | # checkpoint = torch.load(pre_dir, encoding="latin1") # for Python 3 41 | state_dict = checkpoint['state_dict'] if isinstance(checkpoint, dict) and 'state_dict' in checkpoint else checkpoint 42 | change = False 43 | for k, v in state_dict.items(): 44 | if k[:6] == 'module': 45 | change = True 46 | break 47 | if not change: 48 | new_state_dict = state_dict 49 | else: 50 | from collections import OrderedDict 51 | new_state_dict = OrderedDict() 52 | for k, v in state_dict.items(): 53 | name = k[7:] # remove 'module.' 
of dataparallel 54 | new_state_dict[name]=v 55 | net.load_state_dict(new_state_dict) 56 | # freeze 57 | net.eval() 58 | net.volatile = True 59 | return net -------------------------------------------------------------------------------- /models/LSRO.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import torchvision 6 | from torchvision import models 7 | from torch.autograd import Variable 8 | from torch.nn import functional as F 9 | 10 | __all__ = ['DenseNet121'] 11 | 12 | 13 | def weights_init_kaiming(m): 14 | classname = m.__class__.__name__ 15 | # print(classname) 16 | if classname.find('Conv') != -1: 17 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 18 | elif classname.find('Linear') != -1: 19 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 20 | init.constant(m.bias.data, 0.0) 21 | elif classname.find('BatchNorm1d') != -1: 22 | init.normal(m.weight.data, 1.0, 0.02) 23 | init.constant(m.bias.data, 0.0) 24 | 25 | def weights_init_classifier(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Linear') != -1: 28 | init.normal(m.weight.data, std=0.001) 29 | init.constant(m.bias.data, 0.0) 30 | 31 | class DenseNet121(nn.Module): 32 | def __init__(self, num_classes): 33 | super(DenseNet121,self).__init__() 34 | model_ft = models.densenet121(pretrained=True) 35 | # add pooling to the model 36 | # in the originial version, pooling is written in the forward function 37 | model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1)) 38 | 39 | add_block = [] 40 | num_bottleneck = 512 41 | add_block += [nn.Linear(1024, num_bottleneck)] #For ResNet, it is 2048 42 | add_block += [nn.BatchNorm1d(num_bottleneck)] 43 | add_block += [nn.LeakyReLU(0.1)] 44 | add_block += [nn.Dropout(p=0.5)] 45 | add_block = nn.Sequential(*add_block) 46 | add_block.apply(weights_init_kaiming) 47 | model_ft.fc = add_block 48 | self.model = model_ft 49 | 50 | classifier = [] 51 | classifier += [nn.Linear(num_bottleneck, num_classes)] 52 | classifier = nn.Sequential(*classifier) 53 | classifier.apply(weights_init_classifier) 54 | self.classifier = classifier 55 | 56 | def forward(self, x, is_training): 57 | x = self.model.features(x) 58 | x = x.view(x.size(0),-1) 59 | x = self.model.fc(x) 60 | logits = self.classifier(x) 61 | return [logits, x] 62 | 63 | -------------------------------------------------------------------------------- /models/AlignedReID.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import torchvision 7 | 8 | __all__ = ['ResNet50'] 9 | 10 | class ResNet50(nn.Module): 11 | """ 12 | Alignedreid: Surpassing human-level performance in person re-identification 13 | 14 | Reference: 15 | Zhang, Xuan, et al. "Alignedreid: Surpassing human-level performance in person re-identification." 
arXiv preprint arXiv:1711.08184 (2017) 16 | """ 17 | def __init__(self, num_classes, **kwargs): 18 | super(ResNet50, self).__init__() 19 | self.loss = {'softmax', 'metric'} 20 | resnet50 = torchvision.models.resnet50(pretrained=True) 21 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 22 | self.classifier = nn.Linear(2048, num_classes) 23 | self.feat_dim = 2048 # feature dimension 24 | self.aligned = True 25 | self.horizon_pool = HorizontalMaxPool2d() 26 | if self.aligned: 27 | self.bn = nn.BatchNorm2d(2048) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv1 = nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=True) 30 | 31 | def forward(self, x, is_training): 32 | x = self.base(x) 33 | if not is_training: 34 | lf = self.horizon_pool(x) 35 | if self.aligned and is_training: 36 | lf = self.bn(x) 37 | lf = self.relu(lf) 38 | lf = self.horizon_pool(lf) 39 | lf = self.conv1(lf) 40 | if self.aligned or not is_training: 41 | lf = lf.view(lf.size()[0:3]) 42 | lf = lf / torch.pow(lf,2).sum(dim=1, keepdim=True).clamp(min=1e-12).sqrt() 43 | x = F.avg_pool2d(x, x.size()[2:]) 44 | f = x.view(x.size(0), -1) 45 | #f = 1. * f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12) 46 | if not is_training: 47 | return [f,lf] 48 | y = self.classifier(f) 49 | if self.loss == {'softmax'}: 50 | return [y] 51 | elif self.loss == {'metric'}: 52 | if self.aligned: 53 | return [f, lf] 54 | return [f] 55 | elif self.loss == {'softmax', 'metric'}: 56 | if self.aligned: 57 | return [y, f, lf] 58 | return [y, f] 59 | else: 60 | raise KeyError("Unsupported loss: {}".format(self.loss)) 61 | 62 | class HorizontalMaxPool2d(nn.Module): 63 | def __init__(self): 64 | super(HorizontalMaxPool2d, self).__init__() 65 | 66 | 67 | def forward(self, x): 68 | inp_size = x.size() 69 | return nn.functional.max_pool2d(input=x,kernel_size= (1, inp_size[3])) -------------------------------------------------------------------------------- /util/spectral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch import Tensor 8 | from torch.nn import Parameter 9 | 10 | def l2normalize(v, eps=1e-12): 11 | return v / (v.norm() + eps) 12 | 13 | 14 | class SpectralNorm(nn.Module): 15 | def __init__(self, module, name='weight', power_iterations=1): 16 | super(SpectralNorm, self).__init__() 17 | self.module = module 18 | self.name = name 19 | self.power_iterations = power_iterations 20 | if not self._made_params(): 21 | self._make_params() 22 | 23 | def _update_u_v(self): 24 | u = getattr(self.module, self.name + "_u") 25 | v = getattr(self.module, self.name + "_v") 26 | w = getattr(self.module, self.name + "_bar") 27 | 28 | height = w.data.shape[0] 29 | for _ in range(self.power_iterations): 30 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 31 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 32 | 33 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 34 | sigma = u.dot(w.view(height, -1).mv(v)) 35 | setattr(self.module, self.name, w / sigma.expand_as(w)) 36 | 37 | def _made_params(self): 38 | try: 39 | u = getattr(self.module, self.name + "_u") 40 | v = getattr(self.module, self.name + "_v") 41 | w = getattr(self.module, self.name + "_bar") 42 | return True 43 | except AttributeError: 44 | return False 45 | 46 | 47 | def 
_make_params(self): 48 | w = getattr(self.module, self.name) 49 | 50 | height = w.data.shape[0] 51 | width = w.view(height, -1).data.shape[1] 52 | 53 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 54 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 55 | u.data = l2normalize(u.data) 56 | v.data = l2normalize(v.data) 57 | w_bar = Parameter(w.data) 58 | 59 | del self.module._parameters[self.name] 60 | 61 | self.module.register_parameter(self.name + "_u", u) 62 | self.module.register_parameter(self.name + "_v", v) 63 | self.module.register_parameter(self.name + "_bar", w_bar) 64 | 65 | 66 | def forward(self, *args): 67 | self._update_u_v() 68 | return self.module.forward(*args) -------------------------------------------------------------------------------- /models/IDE.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import pdb 8 | 9 | __all__ = ['IDE'] 10 | 11 | 12 | class IDE(nn.Module): 13 | def __init__(self, pretrained=True, cut_at_pooling=False, 14 | num_features=1024, norm=False, dropout=0, num_classes=0): 15 | super(IDE, self).__init__() 16 | 17 | self.pretrained = pretrained 18 | self.cut_at_pooling = cut_at_pooling 19 | 20 | # Construct base (pretrained) resnet 21 | self.base = torchvision.models.resnet50(pretrained=True) 22 | 23 | if not self.cut_at_pooling: 24 | self.num_features = num_features 25 | self.norm = norm 26 | self.dropout = dropout 27 | self.has_embedding = num_features > 0 28 | self.num_classes = num_classes 29 | 30 | out_planes = self.base.fc.in_features 31 | 32 | # Append new layers 33 | if self.has_embedding: 34 | self.feat = nn.Linear(out_planes, self.num_features) 35 | self.feat_bn = nn.BatchNorm1d(self.num_features) 36 | init.kaiming_normal(self.feat.weight, mode='fan_out') 37 | init.constant(self.feat.bias, 0) 38 | init.constant(self.feat_bn.weight, 1) 39 | init.constant(self.feat_bn.bias, 0) 40 | else: 41 | # Change the num_features to CNN output channels 42 | self.num_features = out_planes 43 | if self.dropout > 0: 44 | self.drop = nn.Dropout(self.dropout) 45 | if self.num_classes > 0: 46 | self.classifier = nn.Linear(self.num_features, self.num_classes) 47 | init.normal(self.classifier.weight, std=0.001) 48 | init.constant(self.classifier.bias, 0) 49 | 50 | if not self.pretrained: 51 | self.reset_params() 52 | 53 | def forward(self, x, is_training, output_feature=None): 54 | for name, module in self.base._modules.items(): 55 | if name == 'avgpool': 56 | break 57 | x = module(x) 58 | 59 | if self.cut_at_pooling: 60 | return x 61 | 62 | x = F.avg_pool2d(x, x.size()[2:]) 63 | x = x.view(x.size(0), -1) 64 | 65 | if output_feature == 'pool5': 66 | x = F.normalize(x) 67 | return x 68 | if self.has_embedding: 69 | x = self.feat(x) 70 | x = self.feat_bn(x) 71 | if self.norm: 72 | x = F.normalize(x) 73 | elif self.has_embedding: 74 | x = F.relu(x) 75 | if self.dropout > 0: 76 | x = self.drop(x) 77 | if self.num_classes > 0: 78 | logits = self.classifier(x) 79 | return [logits, x] 80 | 81 | def reset_params(self): 82 | for m in self.modules(): 83 | if isinstance(m, nn.Conv2d): 84 | init.kaiming_normal(m.weight, mode='fan_out') 85 | if m.bias is not None: 86 | init.constant(m.bias, 0) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | init.constant(m.weight, 1) 89 | init.constant(m.bias, 0) 90 | elif isinstance(m, nn.Linear): 91 
| init.normal(m.weight, std=0.001) 92 | if m.bias is not None: 93 | init.constant(m.bias, 0) -------------------------------------------------------------------------------- /ReID_attr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import math 5 | import random 6 | import glob 7 | import cv2 8 | import torch 9 | from scipy import io 10 | from opts import market1501_train_map, duke_train_map, get_opts 11 | 12 | market_dict = {'age':[1,2,3,4], # young(1), teenager(2), adult(3), old(4) 13 | 'backpack':[1,2], # no(1), yes(2) 14 | 'bag':[1,2], # no(1), yes(2) 15 | 'handbag':[1,2], # no(1), yes(2) 16 | 'downblack':[1,2], # no(1), yes(2) 17 | 'downblue':[1,2], # no(1), yes(2) 18 | 'downbrown':[1,2], # no(1), yes(2) 19 | 'downgray':[1,2], # no(1), yes(2) 20 | 'downgreen':[1,2], # no(1), yes(2) 21 | 'downpink':[1,2], # no(1), yes(2) 22 | 'downpurple':[1,2], # no(1), yes(2) 23 | 'downwhite':[1,2], # no(1), yes(2) 24 | 'downyellow':[1,2], # no(1), yes(2) 25 | 'upblack':[1,2], # no(1), yes(2) 26 | 'upblue':[1,2], # no(1), yes(2) 27 | 'upgreen':[1,2], # no(1), yes(2) 28 | 'upgray':[1,2], # no(1), yes(2) 29 | 'uppurple':[1,2], # no(1), yes(2) 30 | 'upred':[1,2], # no(1), yes(2) 31 | 'upwhite':[1,2], # no(1), yes(2) 32 | 'upyellow':[1,2], # no(1), yes(2) 33 | 'clothes':[1,2], # dress(1), pants(2) 34 | 'down':[1,2], # long lower body clothing(1), short(2) 35 | 'up':[1,2], # long sleeve(1), short sleeve(2) 36 | 'hair':[1,2], # short hair(1), long hair(2) 37 | 'hat':[1,2], # no(1), yes(2) 38 | 'gender':[1,2]}# male(1), female(2) 39 | 40 | duke_dict = {'gender':[1,2], # male(1), female(2) 41 | 'top':[1,2], # short upper body clothing(1), long(2) 42 | 'boots':[1,2], # no(1), yes(2) 43 | 'hat':[1,2], # no(1), yes(2) 44 | 'backpack':[1,2], # no(1), yes(2) 45 | 'bag':[1,2], # no(1), yes(2) 46 | 'handbag':[1,2], # no(1), yes(2) 47 | 'shoes':[1,2], # dark(1), light(2) 48 | 'downblack':[1,2], # no(1), yes(2) 49 | 'downwhite':[1,2], # no(1), yes(2) 50 | 'downred':[1,2], # no(1), yes(2) 51 | 'downgray':[1,2], # no(1), yes(2) 52 | 'downblue':[1,2], # no(1), yes(2) 53 | 'downgreen':[1,2], # no(1), yes(2) 54 | 'downbrown':[1,2], # no(1), yes(2) 55 | 'upblack':[1,2], # no(1), yes(2) 56 | 'upwhite':[1,2], # no(1), yes(2) 57 | 'upred':[1,2], # no(1), yes(2) 58 | 'uppurple':[1,2], # no(1), yes(2) 59 | 'upgray':[1,2], # no(1), yes(2) 60 | 'upblue':[1,2], # no(1), yes(2) 61 | 'upgreen':[1,2], # no(1), yes(2) 62 | 'upbrown':[1,2]} # no(1), yes(2) 63 | 64 | __dict_factory={ 65 | 'market_attribute': market_dict, 66 | 'dukemtmcreid_attribute': duke_dict 67 | } 68 | 69 | def get_keys(dict_name): 70 | for key, value in __dict_factory.items(): 71 | if key == dict_name: 72 | return value.keys() 73 | 74 | def get_target_withattr(attr_matrix, dataset_name, attr_list, pids, pids_raw): 75 | attr_key, attr_value = attr_list 76 | attr_name = 'duke_attribute' if dataset_name == 'dukemtmcreid' else 'market_attribute' 77 | mapping = duke_train_map if dataset_name == 'dukemtmcreid' else market1501_train_map 78 | column = attr_matrix[attr_name][0]['train'][0][0][attr_key][0][0] 79 | 80 | n = pids_raw.size(0) 81 | targets = np.zeros_like(column) 82 | for i in range(n): 83 | if column[mapping[pids_raw[i].item()]] == attr_value: 84 | targets[pids[i].item()] = 1 85 | return torch.from_numpy(targets).view(1,-1).repeat(n, 1) -------------------------------------------------------------------------------- /util/dataset_loader.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import random 6 | import os.path as osp 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | def read_image(img_path): 12 | """Keep reading image until succeed. 13 | This can avoid IOError incurred by heavy IO process.""" 14 | got_img = False 15 | if not osp.exists(img_path): 16 | raise IOError("{} does not exist".format(img_path)) 17 | while not got_img: 18 | try: 19 | img = Image.open(img_path).convert('RGB') 20 | got_img = True 21 | except IOError: 22 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 23 | pass 24 | return img 25 | 26 | class ImageDataset(Dataset): 27 | """Image Person ReID Dataset""" 28 | def __init__(self, dataset, transform=None): 29 | self.dataset = dataset 30 | self.transform = transform 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | def __getitem__(self, index): 36 | tp = self.dataset[index] 37 | if len(tp) == 3: 38 | img_path, pid, camid = tp 39 | pid_raw = pid 40 | elif len(tp) == 4: 41 | img_path, pid, camid, pid_raw = tp 42 | img = read_image(img_path) 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | return img, pid, camid, pid_raw 46 | 47 | class VideoDataset(Dataset): 48 | """Video Person ReID Dataset. 49 | Note batch data has shape (batch, seq_len, channel, height, width). 50 | """ 51 | sample_methods = ['evenly', 'random', 'all'] 52 | 53 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None): 54 | self.dataset = dataset 55 | self.seq_len = seq_len 56 | self.sample = sample 57 | self.transform = transform 58 | 59 | def __len__(self): 60 | return len(self.dataset) 61 | 62 | def __getitem__(self, index): 63 | img_paths, pid, camid = self.dataset[index] 64 | num = len(img_paths) 65 | 66 | if self.sample == 'random': 67 | """ 68 | Randomly sample seq_len items from num items, 69 | if num is smaller than seq_len, then replicate items 70 | """ 71 | indices = np.arange(num) 72 | replace = False if num >= self.seq_len else True 73 | indices = np.random.choice(indices, size=self.seq_len, replace=replace) 74 | # sort indices to keep temporal order 75 | # comment it to be order-agnostic 76 | indices = np.sort(indices) 77 | elif self.sample == 'evenly': 78 | """Evenly sample seq_len items from num items.""" 79 | if num >= self.seq_len: 80 | num -= num % self.seq_len 81 | indices = np.arange(0, num, num/self.seq_len) 82 | else: 83 | # if num is smaller than seq_len, simply replicate the last image 84 | # until the seq_len requirement is satisfied 85 | indices = np.arange(0, num) 86 | num_pads = self.seq_len - num 87 | indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)]) 88 | assert len(indices) == self.seq_len 89 | elif self.sample == 'all': 90 | """ 91 | Sample all items, seq_len is useless now and batch_size needs 92 | to be set to 1. 93 | """ 94 | indices = np.arange(num) 95 | else: 96 | raise KeyError("Unknown sample method: {}. 
Expected one of {}".format(self.sample, self.sample_methods)) 97 | 98 | imgs = [] 99 | for index in indices: 100 | img_path = img_paths[index] 101 | img = read_image(img_path) 102 | if self.transform is not None: 103 | img = self.transform(img) 104 | img = img.unsqueeze(0) 105 | imgs.append(img) 106 | imgs = torch.cat(imgs, dim=0) 107 | 108 | return imgs, pid, camid -------------------------------------------------------------------------------- /models/PCB.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torchvision import models 7 | 8 | __all__ = ['PCB', 'PCB_test'] 9 | 10 | def weights_init_kaiming(m): 11 | classname = m.__class__.__name__ 12 | # print(classname) 13 | if classname.find('Conv') != -1: 14 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 15 | elif classname.find('Linear') != -1: 16 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 17 | init.constant(m.bias.data, 0.0) 18 | elif classname.find('BatchNorm1d') != -1: 19 | init.normal(m.weight.data, 1.0, 0.02) 20 | init.constant(m.bias.data, 0.0) 21 | 22 | def weights_init_classifier(m): 23 | classname = m.__class__.__name__ 24 | if classname.find('Linear') != -1: 25 | init.normal(m.weight.data, std=0.001) 26 | init.constant(m.bias.data, 0.0) 27 | 28 | class ClassBlock(nn.Module): 29 | def __init__(self, input_dim, class_num, dropout=True, relu=True, num_bottleneck=512): 30 | super(ClassBlock, self).__init__() 31 | add_block = [] 32 | add_block += [nn.Linear(input_dim, num_bottleneck)] 33 | add_block += [nn.BatchNorm1d(num_bottleneck)] 34 | if relu: 35 | add_block += [nn.LeakyReLU(0.1)] 36 | if dropout: 37 | add_block += [nn.Dropout(p=0.5)] 38 | add_block = nn.Sequential(*add_block) 39 | add_block.apply(weights_init_kaiming) 40 | 41 | classifier = [] 42 | classifier += [nn.Linear(num_bottleneck, class_num)] 43 | classifier = nn.Sequential(*classifier) 44 | classifier.apply(weights_init_classifier) 45 | 46 | self.add_block = add_block 47 | self.classifier = classifier 48 | def forward(self, x): 49 | x = self.add_block(x) 50 | x = self.classifier(x) 51 | return x 52 | 53 | class PCB(nn.Module): 54 | """ 55 | Based on 56 | https://github.com/layumi/Person_reID_baseline_pytorch 57 | """ 58 | def __init__(self, num_classes): 59 | super(PCB, self).__init__() 60 | 61 | self.part = 6 # We cut the pool5 to 6 parts 62 | model_ft = models.resnet50(pretrained=True) 63 | self.model = model_ft 64 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) 65 | self.dropout = nn.Dropout(p=0.5) 66 | # remove the final downsample 67 | self.model.layer4[0].downsample[0].stride = (1,1) 68 | self.model.layer4[0].conv2.stride = (1,1) 69 | # define 6 classifiers 70 | for i in range(self.part): 71 | name = 'classifier'+str(i) 72 | setattr(self, name, ClassBlock(2048, num_classes, True, False, 256)) 73 | 74 | def forward(self, x, is_training): 75 | x = self.model.conv1(x) 76 | x = self.model.bn1(x) 77 | x = self.model.relu(x) 78 | x = self.model.maxpool(x) 79 | 80 | x = self.model.layer1(x) 81 | x = self.model.layer2(x) 82 | x = self.model.layer3(x) 83 | x = self.model.layer4(x) 84 | x = self.avgpool(x) 85 | x = self.dropout(x) 86 | part = {} 87 | feature = [] 88 | predict = [] 89 | # get six part feature batchsize*2048*6 90 | for i in range(self.part): 91 | part[i] = torch.squeeze(x[:,:,i]) 92 | name = 'classifier'+str(i) 93 | c = getattr(self,name) 94 | feature.append(part[i]) 
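# part[i] is the [batch, 2048] feature of the i-th horizontal stripe after pooling;
# classifier_i maps that stripe feature to identity logits, so PCB returns one prediction per part.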
95 | predict.append(c(part[i])) 96 | return [predict, feature] 97 | 98 | class PCB_test(nn.Module): 99 | def __init__(self, model): 100 | super(PCB_test, self).__init__() 101 | self.part = 6 102 | self.model = model.model 103 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) 104 | # remove the final downsample 105 | self.model.layer4[0].downsample[0].stride = (1,1) 106 | self.model.layer4[0].conv2.stride = (1,1) 107 | 108 | def forward(self, x, is_training): 109 | x = self.model.conv1(x) 110 | x = self.model.bn1(x) 111 | x = self.model.relu(x) 112 | x = self.model.maxpool(x) 113 | 114 | x = self.model.layer1(x) 115 | x = self.model.layer2(x) 116 | x = self.model.layer3(x) 117 | x = self.model.layer4(x) 118 | x = self.avgpool(x) 119 | y = x.view(x.size(0),x.size(1),x.size(2)) 120 | return [y] -------------------------------------------------------------------------------- /util/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | import math 8 | 9 | 10 | class RectScale(object): 11 | def __init__(self, height, width, interpolation=Image.BILINEAR): 12 | self.height = height 13 | self.width = width 14 | self.interpolation = interpolation 15 | 16 | def __call__(self, img): 17 | w, h = img.size 18 | if h == self.height and w == self.width: 19 | return img 20 | return img.resize((self.width, self.height), self.interpolation) 21 | 22 | 23 | class RandomSizedRectCrop(object): 24 | def __init__(self, height, width, interpolation=Image.BILINEAR): 25 | self.height = height 26 | self.width = width 27 | self.interpolation = interpolation 28 | 29 | def __call__(self, img): 30 | for attempt in range(10): 31 | area = img.size[0] * img.size[1] 32 | target_area = random.uniform(0.64, 1.0) * area 33 | aspect_ratio = random.uniform(2, 3) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w <= img.size[0] and h <= img.size[1]: 39 | x1 = random.randint(0, img.size[0] - w) 40 | y1 = random.randint(0, img.size[1] - h) 41 | 42 | img = img.crop((x1, y1, x1 + w, y1 + h)) 43 | assert(img.size == (w, h)) 44 | 45 | return img.resize((self.width, self.height), self.interpolation) 46 | 47 | # Fallback 48 | scale = RectScale(self.height, self.width, 49 | interpolation=self.interpolation) 50 | return scale(img) 51 | 52 | 53 | class RandomErasing(object): 54 | def __init__(self, EPSILON=0.5, mean=[0.485, 0.456, 0.406]): 55 | self.EPSILON = EPSILON 56 | self.mean = mean 57 | 58 | def __call__(self, img): 59 | 60 | if random.uniform(0, 1) > self.EPSILON: 61 | return img 62 | 63 | for attempt in range(100): 64 | area = img.size()[1] * img.size()[2] 65 | 66 | target_area = random.uniform(0.02, 0.2) * area 67 | aspect_ratio = random.uniform(0.3, 3) 68 | 69 | h = int(round(math.sqrt(target_area * aspect_ratio))) 70 | w = int(round(math.sqrt(target_area / aspect_ratio))) 71 | 72 | if w <= img.size()[2] and h <= img.size()[1]: 73 | x1 = random.randint(0, img.size()[1] - h) 74 | y1 = random.randint(0, img.size()[2] - w) 75 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 76 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 77 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 78 | 79 | return img 80 | 81 | return img 82 | 83 | class Random2DTranslation(object): 84 | """ 85 | With a probability, first increase image size to (1 + 1/8), and then perform random 
crop. 86 | 87 | Args: 88 | height (int): target height. 89 | width (int): target width. 90 | p (float): probability of performing this transformation. Default: 0.5. 91 | """ 92 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 93 | self.height = height 94 | self.width = width 95 | self.p = p 96 | self.interpolation = interpolation 97 | 98 | def __call__(self, img): 99 | """ 100 | Args: 101 | img (PIL Image): Image to be cropped. 102 | 103 | Returns: 104 | PIL Image: Cropped image. 105 | """ 106 | if random.random() < self.p: 107 | return img.resize((self.width, self.height), self.interpolation) 108 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 109 | resized_img = img.resize((new_width, new_height), self.interpolation) 110 | x_maxrange = new_width - self.width 111 | y_maxrange = new_height - self.height 112 | x1 = int(round(random.uniform(0, x_maxrange))) 113 | y1 = int(round(random.uniform(0, y_maxrange))) 114 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 115 | return croped_img 116 | 117 | if __name__ == '__main__': 118 | pass -------------------------------------------------------------------------------- /util/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import numpy as np 4 | import os.path as osp 5 | import torch 6 | from torch.utils.data.sampler import Sampler 7 | 8 | class RandomIdentitySampler(Sampler): 9 | """ 10 | Randomly sample N identities, then for each identity, 11 | randomly sample K instances, therefore batch size is N*K. 12 | 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 14 | 15 | Args: 16 | data_source (Dataset): dataset to sample from. 17 | num_instances (int): number of instances per identity. 18 | """ 19 | def __init__(self, data_source, num_instances=4): 20 | self.data_source = data_source 21 | self.num_instances = num_instances 22 | self.index_dic = defaultdict(list) 23 | for index, tp in enumerate(data_source): 24 | if len(tp) == 3: 25 | _, pid, _ = tp 26 | elif len(tp) == 4: 27 | _, pid, _, _ = tp 28 | 29 | self.index_dic[pid].append(index) 30 | self.pids = list(self.index_dic.keys()) 31 | self.num_identities = len(self.pids) 32 | 33 | def __iter__(self): 34 | indices = torch.randperm(self.num_identities) 35 | ret = [] 36 | for i in indices: 37 | pid = self.pids[i] 38 | t = self.index_dic[pid] 39 | replace = False if len(t) >= self.num_instances else True 40 | t = np.random.choice(t, size=self.num_instances, replace=replace) 41 | ret.extend(t) 42 | return iter(ret) 43 | 44 | def __len__(self): 45 | return self.num_identities * self.num_instances 46 | 47 | class RandomIdentitySamplerCls(Sampler): 48 | """ 49 | Randomly sample N identities, then for each identity, 50 | randomly sample K instances, therefore batch size is N*K. 51 | 52 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 53 | 54 | Args: 55 | data_source (Dataset): dataset to sample from. 56 | num_instances (int): number of instances per identity. 
57 | """ 58 | def __init__(self, data_source, num_instances=4): 59 | self.data_source = data_source 60 | self.num_instances = num_instances 61 | self.index_dic = defaultdict(list) 62 | for index, (_, target) in enumerate(data_source): 63 | self.index_dic[target].append(index) 64 | self.targets = list(self.index_dic.keys()) 65 | self.num_identities = len(self.targets) 66 | 67 | def __iter__(self): 68 | indices = torch.randperm(self.num_identities) 69 | ret = [] 70 | for i in indices: 71 | target = self.targets[i] 72 | t = self.index_dic[target] 73 | replace = False if len(t) >= self.num_instances else True 74 | t = np.random.choice(t, size=self.num_instances, replace=replace) 75 | ret.extend(t) 76 | return iter(ret) 77 | 78 | def __len__(self): 79 | return self.num_identities * self.num_instances 80 | 81 | class AttrPool(Sampler): 82 | def __init__(self, data_source, dataset_name, attr_matrix, attr_list, sample_num): 83 | from opts import market1501_train_map, duke_train_map 84 | attr_key, attr_value = attr_list 85 | attr_name = 'duke_attribute' if dataset_name == 'dukemtmcreid' else 'market_attribute' 86 | mapping = duke_train_map if dataset_name == 'dukemtmcreid' else market1501_train_map 87 | column = attr_matrix[attr_name][0]['train'][0][0][attr_key][0][0] 88 | 89 | self.data_source = data_source 90 | self.sample_num = sample_num 91 | self.attr_pool = defaultdict(list) 92 | 93 | for index, (_, pid, _, pid_raw) in enumerate(data_source): 94 | if column[mapping[pid_raw]] == attr_value: 95 | self.attr_pool[0].append(index) 96 | else: 97 | self.attr_pool[1].append(index) 98 | self.attrs = list(self.attr_pool.keys()) 99 | self.num_attrs = len(self.attrs) 100 | 101 | def __iter__(self): 102 | ret = [] 103 | for i in range(700): 104 | t = self.attr_pool[self.attrs[i%2]] 105 | replace = False if len(t) >= self.sample_num else True 106 | t = np.random.choice(t, size=self.sample_num, replace=replace) 107 | ret.extend(t) 108 | return iter(ret) 109 | 110 | def __len__(self): 111 | return self.sample_num*700 -------------------------------------------------------------------------------- /advloss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import random 4 | import numpy as np 5 | from torch import nn 6 | 7 | __all__ = ['DeepSupervision', 'adv_CrossEntropyLoss','adv_CrossEntropyLabelSmooth', 'adv_TripletLoss'] 8 | 9 | def DeepSupervision(criterion, xs, *args, **kwargs): 10 | loss = 0. 11 | for x in xs: loss += criterion(x, *args, **kwargs) 12 | return loss 13 | 14 | class adv_CrossEntropyLoss(nn.Module): 15 | def __init__(self, use_gpu=True): 16 | super(adv_CrossEntropyLoss, self).__init__() 17 | self.use_gpu = use_gpu 18 | self.crossentropy_loss = nn.CrossEntropyLoss() 19 | 20 | def forward(self, logits, pids): 21 | """ 22 | Args: 23 | logits: prediction matrix (before softmax) with shape (batch_size, num_classes) 24 | """ 25 | _, adv_target = torch.min(logits, 1) 26 | 27 | if self.use_gpu: adv_target = adv_target.cuda() 28 | loss = self.crossentropy_loss(logits, adv_target) 29 | return torch.log(loss) 30 | 31 | class adv_CrossEntropyLabelSmooth(nn.Module): 32 | """ 33 | Args: 34 | num_classes (int): number of classes. 35 | epsilon (float): weight. 
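Note: the adversarial target built in forward() is the least-confident class (the argmin of the logits), one-hot encoded and then smoothed with weight epsilon over every class except the ground-truth pids in the batch; the loss returned is the log of that cross-entropy.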
36 | """ 37 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 38 | super(adv_CrossEntropyLabelSmooth, self).__init__() 39 | self.num_classes = num_classes 40 | self.epsilon = epsilon 41 | self.use_gpu = use_gpu 42 | self.logsoftmax = nn.LogSoftmax(dim=1) 43 | 44 | def forward(self, logits, pids): 45 | """ 46 | Args: 47 | logits: prediction matrix (before softmax) with shape (batch_size, num_classes) 48 | pids: ground truth labels with shape (num_classes) 49 | """ 50 | # n = pids.size(0) 51 | # _, top2 = torch.topk(logits, k=2, dim=1, largest=False) 52 | # adv_target = top2[:,0] 53 | # for i in range(n): 54 | # if adv_target[i] == pids[i]: adv_target[i] = top2[i,1] 55 | # else: continue 56 | _, adv_target = torch.min(logits, 1) 57 | # for i in range(n): 58 | # while adv_target[i] == pids[i]: 59 | # adv_target[i] = random.randint(0, self.num_classes) 60 | 61 | log_probs = self.logsoftmax(logits) 62 | adv_target = torch.zeros(log_probs.size()).scatter_(1, adv_target.unsqueeze(1).data.cpu(), 1) 63 | smooth = torch.ones(log_probs.size()) / (self.num_classes-1) 64 | smooth[:, pids.data.cpu()] = 0 # Pytorch1.0 65 | smooth = smooth.cuda() 66 | if self.use_gpu: adv_target = adv_target.cuda() 67 | adv_target = (1 - self.epsilon) * adv_target + self.epsilon * smooth 68 | loss = (- adv_target * log_probs).mean(0).sum() 69 | return torch.log(loss) 70 | 71 | class adv_TripletLoss(nn.Module): 72 | def __init__(self, ak_type, margin=0.3): 73 | super(adv_TripletLoss, self).__init__() 74 | self.margin = margin 75 | self.ak_type = ak_type 76 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 77 | 78 | def forward(self, features, pids, targets=None): 79 | """ 80 | Args: 81 | features: feature matrix with shape (batch_size, feat_dim) 82 | pids: ground truth labels with shape (num_classes) 83 | targets: pids with certain attribute (batch_size, pids) 84 | """ 85 | n = features.size(0) 86 | 87 | dist = torch.pow(features, 2).sum(dim=1, keepdim=True).expand(n, n) 88 | dist = dist + dist.t() 89 | dist.addmm_(1, -2, features, features.t()) 90 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 91 | 92 | if self.ak_type < 0: 93 | mask = pids.expand(n, n).eq(pids.expand(n, n).t()) 94 | dist_ap, dist_an = [], [] 95 | for i in range(n): 96 | dist_an.append(dist[i][mask[i]].min().unsqueeze(0)) # make nearest pos-pos far away 97 | dist_ap.append(dist[i][mask[i] == 0].max().unsqueeze(0)) # make hardest pos-neg closer 98 | 99 | elif self.ak_type > 0: 100 | p = [] 101 | for i in range(n): 102 | p.append(pids[i].item()) 103 | mask = targets[0][p].expand(n, n).eq(targets[0][p].expand(n, n).t()) 104 | dist_ap, dist_an = [], [] 105 | for i in range(n): 106 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 107 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 108 | 109 | dist_ap = torch.cat(dist_ap) 110 | dist_an = torch.cat(dist_an) 111 | 112 | y = torch.ones_like(dist_an) 113 | 114 | loss = self.ranking_loss(dist_an, dist_ap, y) 115 | return torch.log(loss) -------------------------------------------------------------------------------- /util/re_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | @author: luohao 7 | """ 8 | 9 | """ 10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | probFea: all feature vectors of the gallery set (torch tensor) 20 | k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3) 21 | MemorySave: set to 'True' when using MemorySave mode 22 | Minibatch: avaliable when 'MemorySave' is 'True' 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat = None, only_local = False): 29 | # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor 30 | query_num = probFea.size(0) 31 | all_num = query_num + galFea.size(0) 32 | if only_local: 33 | original_dist = local_distmat 34 | else: 35 | feat = torch.cat([probFea,galFea]) 36 | print('using GPU to compute original distance') 37 | distmat = torch.pow(feat,2).sum(dim=1, keepdim=True).expand(all_num,all_num) + \ 38 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 39 | distmat.addmm_(1,-2,feat,feat.t()) 40 | original_dist = distmat.numpy() 41 | del feat 42 | if not local_distmat is None: 43 | original_dist = original_dist + local_distmat 44 | gallery_num = original_dist.shape[0] 45 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 46 | V = np.zeros_like(original_dist).astype(np.float16) 47 | initial_rank = np.argsort(original_dist).astype(np.int32) 48 | 49 | print('starting re_ranking') 50 | for i in range(all_num): 51 | # k-reciprocal neighbors 52 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 53 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 54 | fi = np.where(backward_k_neigh_index == i)[0] 55 | k_reciprocal_index = forward_k_neigh_index[fi] 56 | k_reciprocal_expansion_index = k_reciprocal_index 57 | for j in range(len(k_reciprocal_index)): 58 | candidate = k_reciprocal_index[j] 59 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 60 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 61 | :int(np.around(k1 / 2)) + 1] 62 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 63 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 64 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 65 | candidate_k_reciprocal_index): 66 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 67 | 68 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 69 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 70 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 71 | original_dist = original_dist[:query_num, ] 72 | if k2 != 1: 73 | V_qe = np.zeros_like(V, dtype=np.float16) 74 | for i in range(all_num): 75 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 76 | V = V_qe 77 | del V_qe 78 | del initial_rank 79 | invIndex = [] 80 | for i in range(gallery_num): 81 | invIndex.append(np.where(V[:, i] != 0)[0]) 82 | 83 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 84 | 85 | for i in range(query_num): 86 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 87 | indNonZero = np.where(V[i, :] != 0)[0] 88 | indImages = [invIndex[ind] for ind in indNonZero] 89 
| for j in range(len(indNonZero)): 90 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 91 | V[indImages[j], indNonZero[j]]) 92 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 93 | 94 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 95 | del original_dist 96 | del V 97 | del jaccard_dist 98 | final_dist = final_dist[:query_num, query_num:] 99 | return final_dist 100 | 101 | -------------------------------------------------------------------------------- /util/ms_ssim.py: -------------------------------------------------------------------------------- 1 | """Code imported from https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py""" 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from math import exp 6 | import numpy as np 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 11 | return gauss/gauss.sum() 12 | 13 | 14 | def create_window(window_size, channel=1): 15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 17 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 18 | return window 19 | 20 | 21 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 22 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 23 | if val_range is None: 24 | if torch.max(img1) > 128: 25 | max_val = 255 26 | else: 27 | max_val = 1 28 | 29 | if torch.min(img1) < -0.5: 30 | min_val = -1 31 | else: 32 | min_val = 0 33 | L = max_val - min_val 34 | else: 35 | L = val_range 36 | 37 | padd = 0 38 | (_, channel, height, width) = img1.size() 39 | if window is None: 40 | real_size = min(window_size, height, width) 41 | window = create_window(real_size, channel=channel).to(img1.device) 42 | 43 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 44 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 45 | 46 | mu1_sq = mu1.pow(2) 47 | mu2_sq = mu2.pow(2) 48 | mu1_mu2 = mu1 * mu2 49 | 50 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 51 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 52 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 53 | 54 | C1 = (0.01 * L) ** 2 55 | C2 = (0.03 * L) ** 2 56 | 57 | v1 = 2.0 * sigma12 + C2 58 | v2 = sigma1_sq + sigma2_sq + C2 59 | cs = torch.mean(v1 / v2) # contrast sensitivity 60 | 61 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 62 | 63 | if size_average: 64 | ret = ssim_map.mean() 65 | else: 66 | ret = ssim_map.mean(1).mean(1).mean(1) 67 | 68 | if full: 69 | return ret, cs 70 | return ret 71 | 72 | 73 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 74 | device = img1.device 75 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 76 | levels = weights.size()[0] 77 | mssim = [] 78 | mcs = [] 79 | for _ in range(levels): 80 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 81 | mssim.append(sim) 82 | mcs.append(cs) 83 | 84 | img1 = F.avg_pool2d(img1, (2, 2)) 85 | img2 = F.avg_pool2d(img2, (2, 2)) 86 | 87 | mssim = torch.stack(mssim) 88 | mcs = torch.stack(mcs) 89 | 90 | # Normalize (to avoid NaNs 
during training unstable models, not compliant with original definition) 91 | if normalize: 92 | mssim = (mssim + 1) / 2 93 | mcs = (mcs + 1) / 2 94 | 95 | pow1 = mcs ** weights 96 | pow2 = mssim ** weights 97 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 98 | output = torch.prod(pow1[:-1] * pow2[-1]) 99 | return output 100 | 101 | 102 | # Classes to re-use window 103 | class SSIM(torch.nn.Module): 104 | def __init__(self, window_size=11, size_average=True, val_range=None): 105 | super(SSIM, self).__init__() 106 | self.window_size = window_size 107 | self.size_average = size_average 108 | self.val_range = val_range 109 | 110 | # Assume 1 channel for SSIM 111 | self.channel = 1 112 | self.window = create_window(window_size) 113 | 114 | def forward(self, img1, img2): 115 | (_, channel, _, _) = img1.size() 116 | 117 | if channel == self.channel and self.window.dtype == img1.dtype: 118 | window = self.window 119 | else: 120 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 121 | self.window = window 122 | self.channel = channel 123 | 124 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 125 | 126 | class MSSSIM(torch.nn.Module): 127 | def __init__(self, window_size=11, size_average=True, channel=3): 128 | super(MSSSIM, self).__init__() 129 | self.window_size = window_size 130 | self.size_average = size_average 131 | self.channel = channel 132 | 133 | def forward(self, img1, img2): 134 | # TODO: store window between calls if possible 135 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) -------------------------------------------------------------------------------- /util/local_dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def batch_euclidean_dist(x, y): 4 | """ 5 | Args: 6 | x: pytorch Variable, with shape [Batch size, Local part, Feature channel] 7 | y: pytorch Variable, with shape [Batch size, Local part, Feature channel] 8 | Returns: 9 | dist: pytorch Variable, with shape [Batch size, Local part, Local part] 10 | """ 11 | assert len(x.size()) == 3 12 | assert len(y.size()) == 3 13 | assert x.size(0) == y.size(0) 14 | assert x.size(-1) == y.size(-1) 15 | 16 | N, m, d = x.size() 17 | N, n, d = y.size() 18 | 19 | # shape [N, m, n] 20 | xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n) 21 | yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1) 22 | dist = xx + yy 23 | dist.baddbmm_(1, -2, x, y.permute(0, 2, 1)) 24 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 25 | return dist 26 | 27 | def shortest_dist(dist_mat): 28 | """Parallel version. 29 | Args: 30 | dist_mat: pytorch Variable, available shape: 31 | 1) [m, n] 32 | 2) [m, n, N], N is batch size 33 | 3) [m, n, *], * can be arbitrary additional dimensions 34 | Returns: 35 | dist: three cases corresponding to `dist_mat`: 36 | 1) scalar 37 | 2) pytorch Variable, with shape [N] 38 | 3) pytorch Variable, with shape [*] 39 | """ 40 | m, n = dist_mat.size()[:2] 41 | # Just offering some reference for accessing intermediate distance. 
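# Dynamic programming over the m x n local-distance grid:
#   dist[i][j] = dist_mat[i][j] + min(dist[i-1][j], dist[i][j-1]),
# so dist[-1][-1] is the cost of the cheapest monotone path from (0, 0) to
# (m-1, n-1), i.e. the aligned distance between the two stripe sequences.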
42 | dist = [[0 for _ in range(n)] for _ in range(m)] 43 | for i in range(m): 44 | for j in range(n): 45 | if (i == 0) and (j == 0): 46 | dist[i][j] = dist_mat[i, j] 47 | elif (i == 0) and (j > 0): 48 | dist[i][j] = dist[i][j - 1] + dist_mat[i, j] 49 | elif (i > 0) and (j == 0): 50 | dist[i][j] = dist[i - 1][j] + dist_mat[i, j] 51 | else: 52 | dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j] 53 | dist = dist[-1][-1] 54 | return dist 55 | 56 | def hard_example_mining(dist_mat, labels, return_inds=False): 57 | """For each anchor, find the hardest positive and negative sample. 58 | Args: 59 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 60 | labels: pytorch LongTensor, with shape [N] 61 | return_inds: whether to return the indices. Save time if `False`(?) 62 | Returns: 63 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 64 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 65 | p_inds: pytorch LongTensor, with shape [N]; 66 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 67 | n_inds: pytorch LongTensor, with shape [N]; 68 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 69 | NOTE: Only consider the case in which all labels have same num of samples, 70 | thus we can cope with all anchors in parallel. 71 | """ 72 | 73 | assert len(dist_mat.size()) == 2 74 | assert dist_mat.size(0) == dist_mat.size(1) 75 | N = dist_mat.size(0) 76 | 77 | # shape [N, N] 78 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 79 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 80 | 81 | # `dist_ap` means distance(anchor, positive) 82 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 83 | dist_ap, relative_p_inds = torch.max(dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 84 | # `dist_an` means distance(anchor, negative) 85 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 86 | dist_an, relative_n_inds = torch.min(dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 87 | # shape [N] 88 | dist_ap = dist_ap.squeeze(1) 89 | dist_an = dist_an.squeeze(1) 90 | 91 | if return_inds: 92 | # shape [N, N] 93 | ind = (labels.new().resize_as_(labels).copy_(torch.arange(0, N).long()).unsqueeze( 0).expand(N, N)) 94 | # shape [N, 1] 95 | p_inds = torch.gather(ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 96 | n_inds = torch.gather(ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 97 | # shape [N] 98 | p_inds = p_inds.squeeze(1) 99 | n_inds = n_inds.squeeze(1) 100 | return dist_ap, dist_an, p_inds, n_inds 101 | 102 | return dist_ap, dist_an 103 | 104 | def euclidean_dist(x, y): 105 | """ 106 | Args: 107 | x: pytorch Variable, with shape [m, d] 108 | y: pytorch Variable, with shape [n, d] 109 | Returns: 110 | dist: pytorch Variable, with shape [m, n] 111 | """ 112 | m, n = x.size(0), y.size(0) 113 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 114 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 115 | dist = xx + yy 116 | dist.addmm_(1, -2, x, y.t()) 117 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 118 | return dist 119 | 120 | def batch_local_dist(x, y): 121 | """ 122 | Args: 123 | x: pytorch Variable, with shape [N, m, d] 124 | y: pytorch Variable, with shape [N, n, d] 125 | Returns: 126 | dist: pytorch Variable, with shape [N] 127 | """ 128 | assert len(x.size()) == 3 129 | assert len(y.size()) == 3 130 | assert x.size(0) == y.size(0) 131 | assert x.size(-1) == y.size(-1) 132 | 133 
| # shape [N, m, n] 134 | dist_mat = batch_euclidean_dist(x, y) 135 | dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.) 136 | # shape [N] 137 | dist = shortest_dist(dist_mat.permute(1, 2, 0)) 138 | return dist 139 | 140 | if __name__ == '__main__': 141 | x = torch.randn(32,2048) 142 | y = torch.randn(32,2048) 143 | dist_mat = euclidean_dist(x,y) 144 | dist_ap, dist_an, p_inds, n_inds = hard_example_mining(dist_mat,return_inds=True) 145 | from IPython import embed 146 | embed() -------------------------------------------------------------------------------- /models/MuDeep.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import torchvision 7 | 8 | __all__ = ['MuDeep'] 9 | 10 | class ConvBlock(nn.Module): 11 | """Basic convolutional block: 12 | convolution + batch normalization + relu. 13 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d): 14 | in_c (int): number of input channels. 15 | out_c (int): number of output channels. 16 | k (int or tuple): kernel size. 17 | s (int or tuple): stride. 18 | p (int or tuple): padding. 19 | """ 20 | def __init__(self, in_c, out_c, k, s, p): 21 | super(ConvBlock, self).__init__() 22 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 23 | self.bn = nn.BatchNorm2d(out_c) 24 | 25 | def forward(self, x): 26 | return F.relu(self.bn(self.conv(x))) 27 | 28 | class ConvLayers(nn.Module): 29 | """Preprocessing layers.""" 30 | def __init__(self): 31 | super(ConvLayers, self).__init__() 32 | self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1) 33 | self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1) 34 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 35 | 36 | def forward(self, x): 37 | x = self.conv1(x) 38 | x = self.conv2(x) 39 | x = self.maxpool(x) 40 | return x 41 | 42 | class MultiScaleA(nn.Module): 43 | """Multi-scale stream layer A (Sec.3.1)""" 44 | def __init__(self): 45 | super(MultiScaleA, self).__init__() 46 | self.stream1 = nn.Sequential( 47 | ConvBlock(96, 96, k=1, s=1, p=0), 48 | ConvBlock(96, 24, k=3, s=1, p=1), 49 | ) 50 | self.stream2 = nn.Sequential( 51 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 52 | ConvBlock(96, 24, k=1, s=1, p=0), 53 | ) 54 | self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0) 55 | self.stream4 = nn.Sequential( 56 | ConvBlock(96, 16, k=1, s=1, p=0), 57 | ConvBlock(16, 24, k=3, s=1, p=1), 58 | ConvBlock(24, 24, k=3, s=1, p=1), 59 | ) 60 | 61 | def forward(self, x): 62 | s1 = self.stream1(x) 63 | s2 = self.stream2(x) 64 | s3 = self.stream3(x) 65 | s4 = self.stream4(x) 66 | y = torch.cat([s1, s2, s3, s4], dim=1) 67 | return y 68 | 69 | class Reduction(nn.Module): 70 | """Reduction layer (Sec.3.1)""" 71 | def __init__(self): 72 | super(Reduction, self).__init__() 73 | self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 74 | self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1) 75 | self.stream3 = nn.Sequential( 76 | ConvBlock(96, 48, k=1, s=1, p=0), 77 | ConvBlock(48, 56, k=3, s=1, p=1), 78 | ConvBlock(56, 64, k=3, s=2, p=1), 79 | ) 80 | 81 | def forward(self, x): 82 | s1 = self.stream1(x) 83 | s2 = self.stream2(x) 84 | s3 = self.stream3(x) 85 | y = torch.cat([s1, s2, s3], dim=1) 86 | return y 87 | 88 | class MultiScaleB(nn.Module): 89 | """Multi-scale stream layer B (Sec.3.1)""" 90 | def __init__(self): 91 | super(MultiScaleB, self).__init__() 92 | self.stream1 = nn.Sequential( 93 | 
nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 94 | ConvBlock(256, 256, k=1, s=1, p=0), 95 | ) 96 | self.stream2 = nn.Sequential( 97 | ConvBlock(256, 64, k=1, s=1, p=0), 98 | ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)), 99 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), 100 | ) 101 | self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0) 102 | self.stream4 = nn.Sequential( 103 | ConvBlock(256, 64, k=1, s=1, p=0), 104 | ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)), 105 | ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)), 106 | ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)), 107 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), 108 | ) 109 | 110 | def forward(self, x): 111 | s1 = self.stream1(x) 112 | s2 = self.stream2(x) 113 | s3 = self.stream3(x) 114 | s4 = self.stream4(x) 115 | return s1, s2, s3, s4 116 | 117 | class Fusion(nn.Module): 118 | """Saliency-based learning fusion layer (Sec.3.2)""" 119 | def __init__(self): 120 | super(Fusion, self).__init__() 121 | self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1)) 122 | self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1)) 123 | self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1)) 124 | self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1)) 125 | 126 | # We add an average pooling layer to reduce the spatial dimension 127 | # of feature maps, which differs from the original paper. 128 | self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0) 129 | 130 | def forward(self, x1, x2, x3, x4): 131 | s1 = self.a1.expand_as(x1) * x1 132 | s2 = self.a2.expand_as(x2) * x2 133 | s3 = self.a3.expand_as(x3) * x3 134 | s4 = self.a4.expand_as(x4) * x4 135 | y = self.avgpool(s1 + s2 + s3 + s4) 136 | return y 137 | 138 | class MuDeep(nn.Module): 139 | """Multiscale deep neural network. 140 | Reference: 141 | Qian et al. Multi-scale Deep Learning Architectures for Person Re-identification. ICCV 2017. 142 | """ 143 | def __init__(self, num_classes, loss={'xent', 'htri'}, **kwargs): 144 | super(MuDeep, self).__init__() 145 | self.loss = loss 146 | 147 | self.block1 = ConvLayers() 148 | self.block2 = MultiScaleA() 149 | self.block3 = Reduction() 150 | self.block4 = MultiScaleB() 151 | self.block5 = Fusion() 152 | 153 | # Due to this fully connected layer, input image has to be fixed 154 | # in shape, i.e. (3, 256, 128), such that the last convolutional feature 155 | # maps are of shape (256, 16, 8). If input shape is changed, 156 | # the input dimension of this layer has to be changed accordingly. 157 | self.fc = nn.Sequential( 158 | nn.Linear(256*16*8, 4096), 159 | nn.BatchNorm1d(4096), 160 | nn.ReLU(), 161 | ) 162 | self.classifier = nn.Linear(4096, num_classes) 163 | self.feat_dim = 4096 # feature dimension 164 | 165 | def forward(self, x, is_training): 166 | x = self.block1(x) 167 | x = self.block2(x) 168 | x = self.block3(x) 169 | x = self.block4(x) 170 | x = self.block5(*x) 171 | x = x.view(x.size(0), -1) 172 | x = self.fc(x) 173 | y = self.classifier(x) 174 | 175 | if self.loss == {'xent'}: 176 | return [y] 177 | elif self.loss == {'xent', 'htri'}: 178 | return [y, x] 179 | elif self.loss == {'cent'}: 180 | return [y, x] 181 | else: 182 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /util/distance.py: -------------------------------------------------------------------------------- 1 | """Numpy version of euclidean distance, shortest distance, etc. 
2 | Notice the input/output shape of methods, so that you can better understand 3 | the meaning of these methods.""" 4 | import numpy as np 5 | 6 | 7 | def normalize(nparray, order=2, axis=0): 8 | """Normalize a N-D numpy array along the specified axis.""" 9 | norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True) 10 | return nparray / (norm + np.finfo(np.float32).eps) 11 | 12 | 13 | def compute_dist(array1, array2, type='euclidean'): 14 | """Compute the euclidean or cosine distance of all pairs. 15 | Args: 16 | array1: numpy array with shape [m1, n] 17 | array2: numpy array with shape [m2, n] 18 | type: one of ['cosine', 'euclidean'] 19 | Returns: 20 | numpy array with shape [m1, m2] 21 | """ 22 | assert type in ['cosine', 'euclidean'] 23 | if type == 'cosine': 24 | array1 = normalize(array1, axis=1) 25 | array2 = normalize(array2, axis=1) 26 | dist = np.matmul(array1, array2.T) 27 | return dist 28 | else: 29 | # shape [m1, 1] 30 | square1 = np.sum(np.square(array1), axis=1)[..., np.newaxis] 31 | # shape [1, m2] 32 | square2 = np.sum(np.square(array2), axis=1)[np.newaxis, ...] 33 | squared_dist = - 2 * np.matmul(array1, array2.T) + square1 + square2 34 | squared_dist[squared_dist < 0] = 0 35 | dist = np.sqrt(squared_dist) 36 | return dist 37 | 38 | 39 | def shortest_dist(dist_mat): 40 | """Parallel version. 41 | Args: 42 | dist_mat: numpy array, available shape 43 | 1) [m, n] 44 | 2) [m, n, N], N is batch size 45 | 3) [m, n, *], * can be arbitrary additional dimensions 46 | Returns: 47 | dist: three cases corresponding to `dist_mat` 48 | 1) scalar 49 | 2) numpy array, with shape [N] 50 | 3) numpy array with shape [*] 51 | """ 52 | m, n = dist_mat.shape[:2] 53 | dist = np.zeros_like(dist_mat) 54 | for i in range(m): 55 | for j in range(n): 56 | if (i == 0) and (j == 0): 57 | dist[i, j] = dist_mat[i, j] 58 | elif (i == 0) and (j > 0): 59 | dist[i, j] = dist[i, j - 1] + dist_mat[i, j] 60 | elif (i > 0) and (j == 0): 61 | dist[i, j] = dist[i - 1, j] + dist_mat[i, j] 62 | else: 63 | dist[i, j] = \ 64 | np.min(np.stack([dist[i - 1, j], dist[i, j - 1]], axis=0), axis=0) \ 65 | + dist_mat[i, j] 66 | # I ran into memory disaster when returning this reference! I still don't 67 | # know why. 68 | # dist = dist[-1, -1] 69 | dist = dist[-1, -1].copy() 70 | return dist 71 | 72 | def unaligned_dist(dist_mat): 73 | """Parallel version. 74 | Args: 75 | dist_mat: numpy array, available shape 76 | 1) [m, n] 77 | 2) [m, n, N], N is batch size 78 | 3) [m, n, *], * can be arbitrary additional dimensions 79 | Returns: 80 | dist: three cases corresponding to `dist_mat` 81 | 1) scalar 82 | 2) numpy array, with shape [N] 83 | 3) numpy array with shape [*] 84 | """ 85 | 86 | m = dist_mat.shape[0] 87 | dist = np.zeros_like(dist_mat[0]) 88 | for i in range(m): 89 | dist[i] = dist_mat[i][i] 90 | dist = np.sum(dist, axis=0).copy() 91 | return dist 92 | 93 | 94 | def meta_local_dist(x, y, aligned): 95 | """ 96 | Args: 97 | x: numpy array, with shape [m, d] 98 | y: numpy array, with shape [n, d] 99 | Returns: 100 | dist: scalar 101 | """ 102 | eu_dist = compute_dist(x, y, 'euclidean') 103 | dist_mat = (np.exp(eu_dist) - 1.) / (np.exp(eu_dist) + 1.) 104 | if aligned: 105 | dist = shortest_dist(dist_mat[np.newaxis])[0] 106 | else: 107 | dist = unaligned_dist(dist_mat[np.newaxis])[0] 108 | return dist 109 | 110 | 111 | # Tooooooo slow! 
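# (The serial implementation below is the slow reference version; in practice the
#  parallel / low-memory helpers further down in this file are used, e.g. by
#  make_results() in util/eval_metrics.py for AlignedReID local features.)
#
# Illustrative usage -- a sketch only, with example shapes taken from the docstrings:
#
#   import numpy as np
#   lqf = np.random.rand(4, 8, 128)   # 4 query images, 8 local parts, 128-dim features
#   lgf = np.random.rand(6, 8, 128)   # 6 gallery images
#   d = local_dist(lqf, lgf, aligned=True)               # numpy array, shape [4, 6]
#   d = low_memory_local_dist(lqf, lgf, aligned=True)    # same result, computed in chunks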
112 | def serial_local_dist(x, y): 113 | """ 114 | Args: 115 | x: numpy array, with shape [M, m, d] 116 | y: numpy array, with shape [N, n, d] 117 | Returns: 118 | dist: numpy array, with shape [M, N] 119 | """ 120 | M, N = x.shape[0], y.shape[0] 121 | dist_mat = np.zeros([M, N]) 122 | for i in range(M): 123 | for j in range(N): 124 | dist_mat[i, j] = meta_local_dist(x[i], y[j]) 125 | return dist_mat 126 | 127 | 128 | def parallel_local_dist(x, y, aligned): 129 | """Parallel version. 130 | Args: 131 | x: numpy array, with shape [M, m, d] 132 | y: numpy array, with shape [N, n, d] 133 | Returns: 134 | dist: numpy array, with shape [M, N] 135 | """ 136 | M, m, d = x.shape 137 | N, n, d = y.shape 138 | x = x.reshape([M * m, d]) 139 | y = y.reshape([N * n, d]) 140 | # shape [M * m, N * n] 141 | dist_mat = compute_dist(x, y, type='euclidean') 142 | dist_mat = (np.exp(dist_mat) - 1.) / (np.exp(dist_mat) + 1.) 143 | # shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N] 144 | dist_mat = dist_mat.reshape([M, m, N, n]).transpose([1, 3, 0, 2]) 145 | # shape [M, N] 146 | if aligned: 147 | dist_mat = shortest_dist(dist_mat) 148 | else: 149 | dist_mat = unaligned_dist(dist_mat) 150 | return dist_mat 151 | 152 | 153 | def local_dist(x, y, aligned): 154 | if (x.ndim == 2) and (y.ndim == 2): 155 | return meta_local_dist(x, y, aligned) 156 | elif (x.ndim == 3) and (y.ndim == 3): 157 | return parallel_local_dist(x, y, aligned) 158 | else: 159 | raise NotImplementedError('Input shape not supported.') 160 | 161 | 162 | def low_memory_matrix_op( 163 | func, 164 | x, y, 165 | x_split_axis, y_split_axis, 166 | x_num_splits, y_num_splits, 167 | verbose=False, aligned=True): 168 | """ 169 | For matrix operation like multiplication, in order not to flood the memory 170 | with huge data, split matrices into smaller parts (Divide and Conquer). 171 | 172 | Note: 173 | If still out of memory, increase `*_num_splits`. 174 | 175 | Args: 176 | func: a matrix function func(x, y) -> z with shape [M, N] 177 | x: numpy array, the dimension to split has length M 178 | y: numpy array, the dimension to split has length N 179 | x_split_axis: The axis to split x into parts 180 | y_split_axis: The axis to split y into parts 181 | x_num_splits: number of splits. 1 <= x_num_splits <= M 182 | y_num_splits: number of splits. 
1 <= y_num_splits <= N 183 | verbose: whether to print the progress 184 | 185 | Returns: 186 | mat: numpy array, shape [M, N] 187 | """ 188 | 189 | if verbose: 190 | import sys 191 | import time 192 | printed = True 193 | st = time.time() 194 | last_time = time.time() 195 | 196 | mat = [[] for _ in range(x_num_splits)] 197 | for i, part_x in enumerate( 198 | np.array_split(x, x_num_splits, axis=x_split_axis)): 199 | for j, part_y in enumerate( 200 | np.array_split(y, y_num_splits, axis=y_split_axis)): 201 | part_mat = func(part_x, part_y, aligned) 202 | mat[i].append(part_mat) 203 | 204 | if verbose: 205 | if not printed: 206 | printed = True 207 | else: 208 | # Clean the current line 209 | sys.stdout.write("\033[F\033[K") 210 | print('Matrix part ({}, {}) / ({}, {}), +{:.2f}s, total {:.2f}s' 211 | .format(i + 1, j + 1, x_num_splits, y_num_splits, 212 | time.time() - last_time, time.time() - st)) 213 | last_time = time.time() 214 | mat[i] = np.concatenate(mat[i], axis=1) 215 | mat = np.concatenate(mat, axis=0) 216 | return mat 217 | 218 | 219 | def low_memory_local_dist(x, y, aligned=True): 220 | print('Computing local distance...') 221 | x_num_splits = int(len(x) / 200) + 1 222 | y_num_splits = int(len(y) / 200) + 1 223 | z = low_memory_matrix_op(local_dist, x, y, 0, 0, x_num_splits, y_num_splits, verbose=True, aligned=aligned) 224 | return z -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking 2 | This is the code for the [CVPR'20 paper](https://arxiv.org/abs/2004.04199) "Transferable, Controllable, and Inconspicuous Adversarial Attacks on Person Re-identification With Deep Mis-Ranking." by Hongjun Wang, Guangrun Wang, Ya Li, Dongyu Zhang, Liang Lin. 3 | 4 |
![framework](./images/framework.PNG) 5 | 6 |
7 | 8 | # Prerequisites 9 | * Python2 / Python3 10 | * Pytorch0.4.1 (do not test for >=Pytorch1.0) 11 | * CUDA 12 | * Numpy 13 | * Matplotlib 14 | * Scipy 15 | 16 | # Prepare data 17 | Create a directory to store reid datasets under this repo 18 | ```bash 19 | mkdir data/ 20 | ``` 21 | 22 | If you wanna store datasets in another directory, you need to specify `--root path_to_your/data` when running the training code. Please follow the instructions below to prepare each dataset. After that, you can simply do `-d the_dataset` when running the training code. 23 | 24 | **Market1501** : 25 | 26 | 1. Download dataset to `data/` from http://www.liangzheng.org/Project/project_reid.html. 27 | 2. Extract dataset and rename to `market1501`. The data structure would look like: 28 | ``` 29 | market1501/ 30 | bounding_box_test/ 31 | bounding_box_train/ 32 | ... 33 | ``` 34 | 3. Use `-d market1501` when running the training code. 35 | 36 | **CUHK03** [13]: 37 | 1. Create a folder named `cuhk03/` under `data/`. 38 | 2. Download dataset to `data/cuhk03/` from http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html and extract `cuhk03_release.zip`, so you will have `data/cuhk03/cuhk03_release`. 39 | 3. Download new split [14] from [person-re-ranking](https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation/data/CUHK03). What you need are `cuhk03_new_protocol_config_detected.mat` and `cuhk03_new_protocol_config_labeled.mat`. Put these two mat files under `data/cuhk03`. Finally, the data structure would look like 40 | ``` 41 | cuhk03/ 42 | cuhk03_release/ 43 | cuhk03_new_protocol_config_detected.mat 44 | cuhk03_new_protocol_config_labeled.mat 45 | ... 46 | ``` 47 | 4. Use `-d cuhk03` when running the training code. In default mode, we use new split (767/700). If you wanna use the original splits (1367/100) created by [13], specify `--cuhk03-classic-split`. As [13] computes CMC differently from Market1501, you might need to specify `--use-metric-cuhk03` for fair comparison with their method. In addition, we support both `labeled` and `detected` modes. The default mode loads `detected` images. Specify `--cuhk03-labeled` if you wanna train and test on `labeled` images. 48 | 49 | **DukeMTMC-reID** [16, 17]: 50 | 51 | 1. Create a directory under `data/` called `dukemtmc-reid`. 52 | 2. Download dataset `DukeMTMC-reID.zip` from https://github.com/layumi/DukeMTMC-reID_evaluation#download-dataset and put it to `data/dukemtmc-reid`. Extract the zip file, which leads to 53 | ``` 54 | dukemtmc-reid/ 55 | DukeMTMC-reid.zip # (you can delete this zip file, it is ok) 56 | DukeMTMC-reid/ # this folder contains 8 files. 57 | ``` 58 | 3. Use `-d dukemtmcreid` when running the training code. 59 | 60 | 61 | **MSMT17** [22]: 62 | 1. Create a directory named `msmt17/` under `data/`. 63 | 2. Download dataset `MSMT17_V1.tar.gz` to `data/msmt17/` from http://www.pkuvmc.com/publications/msmt17.html. Extract the file under the same folder, so you will have 64 | ``` 65 | msmt17/ 66 | MSMT17_V1.tar.gz # (do whatever you want with this .tar file) 67 | MSMT17_V1/ 68 | train/ 69 | test/ 70 | list_train.txt 71 | ... (totally six .txt files) 72 | ``` 73 | 3. Use `-d msmt17` when running the training code. 74 | 75 | # Prepare pretrained ReID models 76 | 1. Create a directory to store reid pretrained models under this repo 77 | 78 | ```bash 79 | mkdir models/ 80 | ``` 81 | 2. 
Download the pretrained models, or train them from scratch yourself offline 82 | 83 | 2.1 Download Links 84 | 85 | [IDE](https://drive.google.com/open?id=1hVYGcuhfwMs25QVdo2R-ugXW4WyAzuHF) 86 | 87 | [DenseNet121](https://drive.google.com/drive/folders/1XSiVo0lqULQJyYv4T2pt6uA4qtxKSb6X?usp=sharing) 88 | 89 | [AlignedReID](https://drive.google.com/open?id=1YZ7J85f1Fcjft7sh2rlPs1s0dlcaFpf-) 90 | 91 | [PCB](https://drive.google.com/open?id=1xkA981JDESHxhGM_2N-ZdvboVXXzi3yd) 92 | 93 | [Mudeep](https://drive.google.com/open?id=1g6HBt5uCVSbLQL1JUOY_jZZqYKtRmVsX) 94 | 95 | [HACNN](https://drive.google.com/open?id=1ZxzY149vgagHzDUQLMuJqCpCSG3mtH3M) 96 | 97 | [CamStyle](https://drive.google.com/open?id=11WsAyhme4p8i3lNehYpfdB0jZtSSOTzx) 98 | 99 | [LSRO](https://drive.google.com/drive/folders/1cxeOJ3FU6qraHWU927HWC24E_MpXghP5?usp=sharing) 100 | 101 | [HHL](https://drive.google.com/open?id=1ZStrZ6qrB_kgcoB9BLXre81RiXtybBXD) 102 | 103 | [SPGAN](https://drive.google.com/open?id=1YwnmBjfhBHlVQmTRn1ehaHRe5cXVGg5Z) 104 | 105 | 2.2 Training models from scratch (optional) 106 | 107 | Create a directory under `models/` named after the target model (e.g. `aligned/` or `hacnn/`), following the names used in `__init__.py`, and move the checkpoint of the pretrained model into this directory. The naming rules are described at the download links above. 108 | 109 | 3. Customized ReID models (optional) 110 | 111 | It is easy to test the robustness of any customized ReID model by following the above steps (1→2.2→3). The only extra work is to add the definition of your own model to `models/` and register it in `__init__.py` (see the minimal sketch below, just before the Reminders section). 112 | 113 | # Train 114 | Take attacking AlignedReID trained on Market1501 as an example: 115 | 116 | ```bash 117 | python train.py \ 118 | --targetmodel='aligned' \ 119 | --dataset='market1501' \ 120 | --mode='train' \ 121 | --loss='xent_htri' \ 122 | --ak_type=-1 \ 123 | --temperature=-1 \ 124 | --use_SSIM=2 \ 125 | --epoch=40 126 | ``` 127 | 128 | # Test 129 | Take attacking AlignedReID trained on Market1501 as an example: 130 | 131 | ```bash 132 | python train.py \ 133 | --targetmodel='aligned' \ 134 | --dataset='market1501' \ 135 | --G_resume_dir='./logs/aligned/market1501/best_G.pth.tar' \ 136 | --mode='test' \ 137 | --loss='xent_htri' \ 138 | --ak_type=-1 \ 139 | --temperature=-1 \ 140 | --use_SSIM=2 \ 141 | --epoch=40 142 | ``` 143 | 144 | # Results 145 | 146 |
![results](./images/results.PNG) 147 | 148 |
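# Customized model sketch (optional)

The snippet below is only an illustrative sketch for step 3 of "Prepare pretrained ReID models": it shows what a customized model needs to expose so the rest of the code can drive it, mirroring the interface of `models/MuDeep.py` and `models/HACNN.py` (a constructor taking `num_classes` and `loss`, a `feat_dim` attribute, and `forward(x, is_training)` returning a list of outputs). The file name `MyModel.py`, the class name `MyModel`, and the registration lines at the bottom are hypothetical placeholders; copy whatever pattern `models/__init__.py` actually uses when registering your own model.

```python
# models/MyModel.py -- a minimal, hypothetical sketch of a customized ReID model.
from __future__ import absolute_import

from torch import nn
from torch.nn import functional as F

__all__ = ['MyModel']

class MyModel(nn.Module):
    def __init__(self, num_classes, loss={'xent', 'htri'}, **kwargs):
        super(MyModel, self).__init__()
        self.loss = loss
        # Any backbone works; a toy two-layer conv stack keeps the sketch short.
        self.base = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(True),
        )
        self.classifier = nn.Linear(128, num_classes)
        self.feat_dim = 128  # feature dimension

    def forward(self, x, is_training):
        x = self.base(x)
        f = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
        if not is_training:
            return [f]  # test time: return features only (HACNN.py follows the same pattern)
        y = self.classifier(f)
        if self.loss == {'xent'}:
            return [y]
        elif self.loss == {'xent', 'htri'}:
            return [y, f]
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))

# Register the model in models/__init__.py by copying the pattern of the existing
# entries there. The two lines below are only a guess at that pattern (the actual
# factory dict / function names in __init__.py may differ):
#
#   from .MyModel import MyModel
#   __factory['mymodel'] = MyModel
```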
149 | 150 | 151 | # Reminders 152 | 153 | 1. If you are using your *own* trained ReID models (no matter whether they are customized), be careful about the name of variables and properly change or hold Line 38–53 in `__init__.py` (adaptation to early Pytorch0.3 trained models). 154 | 2. You may notice some arguments and codes involve the attribute information, if you are interested in that you can easily find and download the extra attribute files about Market1501 or DukeMTMC. We have conducted some related experiments about attribute attack but it is *not* the main content of this paper so I delete that part of code. 155 | 156 | # Reference 157 | 158 | If you are interested in our work, please consider citing our paper. 159 | ``` 160 | @InProceedings{Wang_2020_CVPR, 161 | author = {Wang, Hongjun and Wang, Guangrun and Li, Ya and Zhang, Dongyu and Lin, Liang}, 162 | title = {Transferable, Controllable, and Inconspicuous Adversarial Attacks on Person Re-identification With Deep Mis-Ranking}, 163 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 164 | month = {June}, 165 | year = {2020} 166 | } 167 | ``` 168 | 169 | # Acknowledgements 170 | Thanks for the following excellent works: 171 | 172 | - Open-reid [code](https://github.com/Cysu/open-reid) 173 | - AlignedReID [paper](https://www.sciencedirect.com/science/article/abs/pii/S0031320319302031?via%3Dihub#!) and [code](https://github.com/michuanhaohao/AlignedReID) by michuanhaohao 174 | - Person ReID baseline [code](https://github.com/layumi/Person_reID_baseline_pytorch) by layumi 175 | - LSRO [paper](https://arxiv.org/abs/1701.07717) and [code](https://github.com/layumi/Person-reID_GAN) by layumi 176 | - HHL [paper](http://openaccess.thecvf.com/content_ECCV_2018/html/Zhun_Zhong_Generalizing_A_Person_ECCV_2018_paper.html) and [code](https://github.com/zhunzhong07/HHL) by zhunzhong07 177 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | import os 4 | import sys 5 | import errno 6 | import shutil 7 | import json 8 | import time 9 | import os.path as osp 10 | from PIL import Image 11 | import matplotlib 12 | import numpy as np 13 | from numpy import array, argmin 14 | 15 | import torch 16 | 17 | def mkdir_if_missing(directory): 18 | if not osp.exists(directory): 19 | try: 20 | os.makedirs(directory) 21 | except OSError as e: 22 | if e.errno != errno.EEXIST: 23 | raise 24 | 25 | def fliplr(img): 26 | '''flip horizontal''' 27 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 28 | img_flip = img.index_select(3,inv_idx) 29 | return img_flip 30 | 31 | def save_heatmap(path, den): 32 | matplotlib.use('Agg') 33 | import matplotlib.pyplot as plt 34 | from matplotlib.colors import PowerNorm, LogNorm 35 | import matplotlib.cm as cm 36 | plt.axis('off') 37 | plt.imshow(den, 38 | cmap=cm.jet, 39 | Norm=LogNorm(), 40 | interpolation="bicubic") 41 | # save fig 42 | fig = plt.gcf() 43 | fig.savefig(path, format='png', bbox_inches='tight', transparent=True, dpi=600) 44 | plt.close('all') 45 | 46 | class AverageMeter(object): 47 | """Computes and stores the average and current value. 
48 | 49 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 50 | """ 51 | def __init__(self): 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | 60 | def update(self, val, n=1): 61 | self.val = val 62 | self.sum += val * n 63 | self.count += n 64 | self.avg = self.sum / self.count 65 | 66 | def save_checkpoint(state, is_best, G_or_D, fpath='checkpoint.pth.tar'): 67 | mkdir_if_missing(osp.dirname(fpath)) 68 | torch.save(state, fpath) 69 | if is_best: 70 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_'+ G_or_D +'.pth.tar')) 71 | 72 | class Logger(object): 73 | """ 74 | Write console output to external text file. 75 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 76 | """ 77 | def __init__(self, fpath=None): 78 | self.console = sys.stdout 79 | self.file = None 80 | if fpath is not None: 81 | mkdir_if_missing(os.path.dirname(fpath)) 82 | self.file = open(fpath, 'w') 83 | 84 | def __del__(self): 85 | self.close() 86 | 87 | def __enter__(self): 88 | pass 89 | 90 | def __exit__(self, *args): 91 | self.close() 92 | 93 | def write(self, msg): 94 | self.console.write(msg) 95 | if self.file is not None: 96 | self.file.write(msg) 97 | 98 | def flush(self): 99 | self.console.flush() 100 | if self.file is not None: 101 | self.file.flush() 102 | os.fsync(self.file.fileno()) 103 | 104 | def close(self): 105 | self.console.close() 106 | if self.file is not None: 107 | self.file.close() 108 | 109 | def read_json(fpath): 110 | with open(fpath, 'r') as f: 111 | obj = json.load(f) 112 | return obj 113 | 114 | def write_json(obj, fpath): 115 | mkdir_if_missing(osp.dirname(fpath)) 116 | with open(fpath, 'w') as f: 117 | json.dump(obj, f, indent=4, separators=(',', ': ')) 118 | 119 | def _traceback(D): 120 | i,j = array(D.shape)-1 121 | p,q = [i],[j] 122 | while (i>0) or (j>0): 123 | tb = argmin((D[i,j-1], D[i-1,j])) 124 | if tb == 0: 125 | j -= 1 126 | else: #(tb==1) 127 | i -= 1 128 | p.insert(0,i) 129 | q.insert(0,j) 130 | return array(p), array(q) 131 | 132 | def dtw(dist_mat): 133 | m, n = dist_mat.shape[:2] 134 | dist = np.zeros_like(dist_mat) 135 | for i in range(m): 136 | for j in range(n): 137 | if (i == 0) and (j == 0): 138 | dist[i, j] = dist_mat[i, j] 139 | elif (i == 0) and (j > 0): 140 | dist[i, j] = dist[i, j - 1] + dist_mat[i, j] 141 | elif (i > 0) and (j == 0): 142 | dist[i, j] = dist[i - 1, j] + dist_mat[i, j] 143 | else: 144 | dist[i, j] = \ 145 | np.min(np.stack([dist[i - 1, j], dist[i, j - 1]], axis=0), axis=0) \ 146 | + dist_mat[i, j] 147 | path = _traceback(dist) 148 | return dist[-1,-1]/sum(dist.shape), dist, path 149 | 150 | def read_image(img_path): 151 | got_img = False 152 | if not osp.exists(img_path): 153 | raise IOError("{} does not exist".format(img_path)) 154 | while not got_img: 155 | try: 156 | img = Image.open(img_path).convert('RGB') 157 | got_img = True 158 | except IOError: 159 | print("IOError incurred when reading '{}'. Will Redo. Don't worry. 
Just chill".format(img_path)) 160 | pass 161 | return img 162 | 163 | def img_to_tensor(img,transform): 164 | img = transform(img) 165 | img = img.unsqueeze(0) 166 | return img 167 | 168 | def feat_flatten(feat): 169 | shp = feat.shape 170 | feat = feat.reshape(shp[0] * shp[1], shp[2]) 171 | return feat 172 | 173 | def merge_feature(feature_list, shp, sample_rate = None): 174 | def pre_process(torch_feature_map): 175 | numpy_feature_map = torch_feature_map.cpu().data.numpy()[0] 176 | numpy_feature_map = numpy_feature_map.transpose(1,2,0) 177 | shp = numpy_feature_map.shape[:2] 178 | return numpy_feature_map, shp 179 | def resize_as(tfm, shp): 180 | nfm, shp2 = pre_process(tfm) 181 | scale = shp[0]/shp2[0] 182 | nfm1 = nfm.repeat(scale, axis = 0).repeat(scale, axis=1) 183 | return nfm1 184 | final_nfm = resize_as(feature_list[0], shp) 185 | for i in range(1, len(feature_list)): 186 | temp_nfm = resize_as(feature_list[i],shp) 187 | final_nfm = np.concatenate((final_nfm, temp_nfm),axis =-1) 188 | if sample_rate > 0: 189 | final_nfm = final_nfm[0:-1:sample_rate, 0:-1,sample_rate, :] 190 | return final_nfm 191 | 192 | def visualize_ranked_results(distmat, dataset, save_dir, topk=20): 193 | """ 194 | Visualize ranked results 195 | Support both imgreid and vidreid 196 | Args: 197 | - distmat: distance matrix of shape (num_query, num_gallery). 198 | - dataset: has dataset.query and dataset.gallery, both are lists of (img_path, pid, camid); 199 | for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing 200 | a sequence of strings. 201 | - save_dir: directory to save output images. 202 | - topk: int, denoting top-k images in the rank list to be visualized. 203 | """ 204 | num_q, num_g = distmat.shape 205 | 206 | print("Visualizing top-{} ranks in '{}' ...".format(topk, save_dir)) 207 | print("# query: {}. 
# gallery {}".format(num_q, num_g)) 208 | 209 | assert num_q == len(dataset.query) 210 | assert num_g == len(dataset.gallery) 211 | 212 | indices = np.argsort(distmat, axis=1) 213 | mkdir_if_missing(save_dir) 214 | 215 | for q_idx in range(num_q): 216 | qimg_path, qpid, qcamid = dataset.query[q_idx] 217 | qdir = osp.join(save_dir, 'query' + str(q_idx + 1).zfill(5)) 218 | mkdir_if_missing(qdir) 219 | cp_img_to(qimg_path, qdir, rank=0, prefix='query') 220 | 221 | rank_idx = 1 222 | for g_idx in indices[q_idx,:]: 223 | gimg_path, gpid, gcamid = dataset.gallery[g_idx] 224 | invalid = (qpid == gpid) & (qcamid == gcamid) 225 | if not invalid: 226 | cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery') 227 | rank_idx += 1 228 | if rank_idx > topk: 229 | break 230 | 231 | def cp_img_to(src, dst, rank, prefix): 232 | """ 233 | - src: image path or tuple (for vidreid) 234 | - dst: target directory 235 | - rank: int, denoting ranked position, starting from 1 236 | - prefix: string 237 | """ 238 | if isinstance(src, tuple) or isinstance(src, list): 239 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 240 | mkdir_if_missing(dst) 241 | for img_path in src: 242 | shutil.copy(img_path, dst) 243 | else: 244 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 245 | shutil.copy(src, dst) -------------------------------------------------------------------------------- /util/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | import torch 4 | import copy 5 | import os.path as osp 6 | from collections import defaultdict 7 | from opts import market1501_test_map, duke_test_map 8 | import sys 9 | 10 | def make_results(qf, gf, lqf, lgf, q_pids, g_pids, q_camids, g_camids, targetmodel, ak_typ, attr_matrix=None, dataset_name=None, attr=None): 11 | qf, gf = featureNormalization(qf, gf, targetmodel) 12 | m, n = qf.size(0), gf.size(0) 13 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 14 | distmat.addmm_(1, -2, qf, gf.t()) 15 | distmat = distmat.numpy() 16 | 17 | if targetmodel == 'aligned': 18 | from .distance import low_memory_local_dist 19 | lqf, lgf = lqf.permute(0,2,1), lgf.permute(0,2,1) 20 | local_distmat = low_memory_local_dist(lqf.numpy(),lgf.numpy(), aligned=True) 21 | distmat = local_distmat+distmat 22 | 23 | if ak_typ > 0: 24 | distmat, all_hit, ignore_list = evaluate_attr(distmat, q_pids, g_pids, attr_matrix, dataset_name, attr) 25 | return distmat, all_hit, ignore_list 26 | else: 27 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=False) 28 | return distmat, cmc, mAP 29 | 30 | def featureNormalization(qf, gf, targetmodel): 31 | if targetmodel in ['aligned', 'densenet121', 'hacnn', 'mudeep', 'ide', 'cam', 'lsro', 'hhl', 'spgan']: 32 | qf = 1. * qf / (torch.norm(qf, p=2, dim=-1, keepdim=True).expand_as(qf) + 1e-12) 33 | gf = 1. 
* gf / (torch.norm(gf, p=2, dim=-1, keepdim=True).expand_as(gf) + 1e-12) 34 | 35 | elif targetmodel in ['pcb']: 36 | qf = (qf / (np.sqrt(6) * torch.norm(qf, p=2, dim=1, keepdim=True).expand_as(qf))).view(qf.size(0), -1) 37 | gf = (gf / (np.sqrt(6) * torch.norm(gf, p=2, dim=1, keepdim=True).expand_as(gf))).view(gf.size(0), -1) 38 | 39 | return qf, gf 40 | 41 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100): 42 | """Evaluation with cuhk03 metric 43 | Key: one image for each gallery identity is randomly sampled for each query identity. 44 | Random sampling is performed N times (default: N=100). 45 | """ 46 | num_q, num_g = distmat.shape 47 | if num_g < max_rank: 48 | max_rank = num_g 49 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 50 | indices = np.argsort(distmat, axis=1) 51 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 52 | 53 | # compute cmc curve for each query 54 | all_cmc = [] 55 | all_AP = [] 56 | num_valid_q = 0. # number of valid query 57 | for q_idx in range(num_q): 58 | # get query pid and camid 59 | q_pid = q_pids[q_idx] 60 | q_camid = q_camids[q_idx] 61 | 62 | # remove gallery samples that have the same pid and camid with query 63 | order = indices[q_idx] 64 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 65 | keep = np.invert(remove) 66 | 67 | # compute cmc curve 68 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 69 | if not np.any(orig_cmc): 70 | # this condition is true when query identity does not appear in gallery 71 | continue 72 | 73 | kept_g_pids = g_pids[order][keep] 74 | g_pids_dict = defaultdict(list) 75 | for idx, pid in enumerate(kept_g_pids): 76 | g_pids_dict[pid].append(idx) 77 | 78 | cmc, AP = 0., 0. 79 | for repeat_idx in range(N): 80 | mask = np.zeros(len(orig_cmc), dtype=np.bool) 81 | for _, idxs in g_pids_dict.items(): 82 | # randomly sample one image for each gallery person 83 | rnd_idx = np.random.choice(idxs) 84 | mask[rnd_idx] = True 85 | masked_orig_cmc = orig_cmc[mask] 86 | _cmc = masked_orig_cmc.cumsum() 87 | _cmc[_cmc > 1] = 1 88 | cmc += _cmc[:max_rank].astype(np.float32) 89 | # compute AP 90 | num_rel = masked_orig_cmc.sum() 91 | tmp_cmc = masked_orig_cmc.cumsum() 92 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 93 | tmp_cmc = np.asarray(tmp_cmc) * masked_orig_cmc 94 | AP += tmp_cmc.sum() / num_rel 95 | cmc /= N 96 | AP /= N 97 | all_cmc.append(cmc) 98 | all_AP.append(AP) 99 | num_valid_q += 1. 100 | 101 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 102 | 103 | all_cmc = np.asarray(all_cmc).astype(np.float32) 104 | all_cmc = all_cmc.sum(0) / num_valid_q 105 | mAP = np.mean(all_AP) 106 | 107 | return all_cmc, mAP 108 | 109 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 110 | """Evaluation with market1501 metric 111 | Key: for each query identity, its gallery images from the same camera view are discarded. 112 | """ 113 | num_q, num_g = distmat.shape 114 | if num_g < max_rank: 115 | max_rank = num_g 116 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 117 | indices = np.argsort(distmat, axis=1) 118 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 119 | 120 | # compute cmc curve for each query 121 | all_cmc = [] 122 | all_AP = [] 123 | num_valid_q = 0. 
# number of valid query 124 | for q_idx in range(num_q): 125 | # get query pid and camid 126 | q_pid = q_pids[q_idx] 127 | q_camid = q_camids[q_idx] 128 | 129 | # remove gallery samples that have the same pid and camid with query 130 | order = indices[q_idx] 131 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 132 | keep = np.invert(remove) 133 | 134 | # compute cmc curve 135 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 136 | if not np.any(orig_cmc): 137 | # this condition is true when query identity does not appear in gallery 138 | continue 139 | 140 | cmc = orig_cmc.cumsum() 141 | cmc[cmc > 1] = 1 142 | 143 | all_cmc.append(cmc[:max_rank]) 144 | num_valid_q += 1. 145 | 146 | # compute average precision 147 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 148 | num_rel = orig_cmc.sum() 149 | tmp_cmc = orig_cmc.cumsum() 150 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 151 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 152 | AP = tmp_cmc.sum() / num_rel 153 | all_AP.append(AP) 154 | 155 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 156 | 157 | all_cmc = np.asarray(all_cmc).astype(np.float32) 158 | all_cmc = all_cmc.sum(0) / num_valid_q 159 | mAP = np.mean(all_AP) 160 | 161 | return all_cmc, mAP 162 | 163 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20, use_metric_cuhk03=False): 164 | if use_metric_cuhk03: return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 165 | else: return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 166 | 167 | def evaluate_attr(distmat, q_pids, g_pids, attr_matrix, dataset_name, attr_list, max_rank=20): 168 | attr_key, attr_value = attr_list 169 | attr_name = 'duke_attribute' if dataset_name == 'dukemtmcreid' else 'market_attribute' 170 | offset = 0 if dataset_name == 'dukemtmcreid' else 1 171 | mapping = duke_test_map if dataset_name == 'dukemtmcreid' else market1501_test_map 172 | column = attr_matrix[attr_name][0]['test'][0][0][attr_key][0][0] 173 | 174 | num_q, num_g = distmat.shape 175 | indices = np.argsort(distmat, axis=1) 176 | 177 | all_hit = [] 178 | ignore_list = [] 179 | num_valid_q = 0. # number of valid query 180 | for q_idx in range(num_q): 181 | q_pid = q_pids[q_idx] 182 | if column[mapping[q_pid]-offset] == attr_value: 183 | ignore_list.append(q_idx) 184 | continue 185 | 186 | order = indices[q_idx] 187 | matches = np.zeros_like(order) 188 | 189 | for i in range(len(order)): 190 | if column[mapping[g_pids[order[i]]]-offset] == attr_value: 191 | matches[i] = 1 192 | 193 | hit = matches.cumsum() 194 | hit[hit > 1] = 1 195 | all_hit.append(hit[:max_rank]) 196 | num_valid_q += 1. # number of valid query 197 | 198 | assert num_valid_q > 0 199 | all_hit = np.asarray(all_hit).astype(np.float32) 200 | all_hit = all_hit.sum(0) / num_valid_q 201 | 202 | # distmat = np.delete(distmat, ignore_list, axis=0) 203 | 204 | return distmat, all_hit, ignore_list -------------------------------------------------------------------------------- /models/HACNN.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import torchvision 7 | 8 | __all__ = ['HACNN'] 9 | 10 | class ConvBlock(nn.Module): 11 | """Basic convolutional block: 12 | convolution + batch normalization + relu. 
13 | 14 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d): 15 | in_c (int): number of input channels. 16 | out_c (int): number of output channels. 17 | k (int or tuple): kernel size. 18 | s (int or tuple): stride. 19 | p (int or tuple): padding. 20 | """ 21 | def __init__(self, in_c, out_c, k, s=1, p=0): 22 | super(ConvBlock, self).__init__() 23 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 24 | self.bn = nn.BatchNorm2d(out_c) 25 | 26 | def forward(self, x): 27 | return F.relu(self.bn(self.conv(x))) 28 | 29 | class InceptionA(nn.Module): 30 | """ 31 | Args: 32 | in_channels (int): number of input channels 33 | out_channels (int): number of output channels AFTER concatenation 34 | """ 35 | def __init__(self, in_channels, out_channels): 36 | super(InceptionA, self).__init__() 37 | single_out_channels = out_channels // 4 38 | 39 | self.stream1 = nn.Sequential( 40 | ConvBlock(in_channels, single_out_channels, 1), 41 | ConvBlock(single_out_channels, single_out_channels, 3, p=1), 42 | ) 43 | self.stream2 = nn.Sequential( 44 | ConvBlock(in_channels, single_out_channels, 1), 45 | ConvBlock(single_out_channels, single_out_channels, 3, p=1), 46 | ) 47 | self.stream3 = nn.Sequential( 48 | ConvBlock(in_channels, single_out_channels, 1), 49 | ConvBlock(single_out_channels, single_out_channels, 3, p=1), 50 | ) 51 | self.stream4 = nn.Sequential( 52 | nn.AvgPool2d(3, stride=1, padding=1), 53 | ConvBlock(in_channels, single_out_channels, 1), 54 | ) 55 | 56 | def forward(self, x): 57 | s1 = self.stream1(x) 58 | s2 = self.stream2(x) 59 | s3 = self.stream3(x) 60 | s4 = self.stream4(x) 61 | y = torch.cat([s1, s2, s3, s4], dim=1) 62 | return y 63 | 64 | class InceptionB(nn.Module): 65 | """ 66 | Args: 67 | in_channels (int): number of input channels 68 | out_channels (int): number of output channels AFTER concatenation 69 | """ 70 | def __init__(self, in_channels, out_channels): 71 | super(InceptionB, self).__init__() 72 | single_out_channels = out_channels // 4 73 | 74 | self.stream1 = nn.Sequential( 75 | ConvBlock(in_channels, single_out_channels, 1), 76 | ConvBlock(single_out_channels, single_out_channels, 3, s=2, p=1), 77 | ) 78 | self.stream2 = nn.Sequential( 79 | ConvBlock(in_channels, single_out_channels, 1), 80 | ConvBlock(single_out_channels, single_out_channels, 3, p=1), 81 | ConvBlock(single_out_channels, single_out_channels, 3, s=2, p=1), 82 | ) 83 | self.stream3 = nn.Sequential( 84 | nn.MaxPool2d(3, stride=2, padding=1), 85 | ConvBlock(in_channels, single_out_channels*2, 1), 86 | ) 87 | 88 | def forward(self, x): 89 | s1 = self.stream1(x) 90 | s2 = self.stream2(x) 91 | s3 = self.stream3(x) 92 | y = torch.cat([s1, s2, s3], dim=1) 93 | return y 94 | 95 | class SpatialAttn(nn.Module): 96 | """Spatial Attention (Sec. 3.1.I.1)""" 97 | def __init__(self): 98 | super(SpatialAttn, self).__init__() 99 | self.conv1 = ConvBlock(1, 1, 3, s=2, p=1) 100 | self.conv2 = ConvBlock(1, 1, 1) 101 | 102 | def forward(self, x): 103 | # global cross-channel averaging 104 | x = x.mean(1, keepdim=True) 105 | # 3-by-3 conv 106 | x = self.conv1(x) 107 | # bilinear resizing 108 | x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True) 109 | # scaling conv 110 | x = self.conv2(x) 111 | return x 112 | 113 | class ChannelAttn(nn.Module): 114 | """Channel Attention (Sec. 
3.1.I.2)""" 115 | def __init__(self, in_channels, reduction_rate=16): 116 | super(ChannelAttn, self).__init__() 117 | assert in_channels%reduction_rate == 0 118 | self.conv1 = ConvBlock(in_channels, in_channels//reduction_rate, 1) 119 | self.conv2 = ConvBlock(in_channels//reduction_rate, in_channels, 1) 120 | 121 | def forward(self, x): 122 | # squeeze operation (global average pooling) 123 | x = F.avg_pool2d(x, x.size()[2:]) 124 | # excitation operation (2 conv layers) 125 | x = self.conv1(x) 126 | x = self.conv2(x) 127 | return x 128 | 129 | class SoftAttn(nn.Module): 130 | """Soft Attention (Sec. 3.1.I) 131 | Aim: Spatial Attention + Channel Attention 132 | Output: attention maps with shape identical to input. 133 | """ 134 | def __init__(self, in_channels): 135 | super(SoftAttn, self).__init__() 136 | self.spatial_attn = SpatialAttn() 137 | self.channel_attn = ChannelAttn(in_channels) 138 | self.conv = ConvBlock(in_channels, in_channels, 1) 139 | 140 | def forward(self, x): 141 | y_spatial = self.spatial_attn(x) 142 | y_channel = self.channel_attn(x) 143 | y = y_spatial * y_channel 144 | y = F.sigmoid(self.conv(y)) 145 | return y 146 | 147 | class HardAttn(nn.Module): 148 | """Hard Attention (Sec. 3.1.II)""" 149 | def __init__(self, in_channels): 150 | super(HardAttn, self).__init__() 151 | self.fc = nn.Linear(in_channels, 4*2) 152 | self.init_params() 153 | 154 | def init_params(self): 155 | self.fc.weight.data.zero_() 156 | self.fc.bias.data.copy_(torch.tensor([0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float)) 157 | 158 | def forward(self, x): 159 | # squeeze operation (global average pooling) 160 | x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1)) 161 | # predict transformation parameters 162 | theta = F.tanh(self.fc(x)) 163 | theta = theta.view(-1, 4, 2) 164 | return theta 165 | 166 | class HarmAttn(nn.Module): 167 | """Harmonious Attention (Sec. 3.1)""" 168 | def __init__(self, in_channels): 169 | super(HarmAttn, self).__init__() 170 | self.soft_attn = SoftAttn(in_channels) 171 | self.hard_attn = HardAttn(in_channels) 172 | 173 | def forward(self, x): 174 | y_soft_attn = self.soft_attn(x) 175 | theta = self.hard_attn(x) 176 | return y_soft_attn, theta 177 | 178 | class HACNN(nn.Module): 179 | """ 180 | Harmonious Attention Convolutional Neural Network 181 | 182 | Reference: 183 | Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018. 184 | 185 | Args: 186 | num_classes (int): number of classes to predict 187 | nchannels (list): number of channels AFTER concatenation 188 | feat_dim (int): feature dimension for a single stream 189 | learn_region (bool): whether to learn region features (i.e. 
local branch) 190 | """ 191 | def __init__(self, num_classes, loss={'xent', 'htri'}, nchannels=[128, 256, 384], feat_dim=512, learn_region=True, use_gpu=True, **kwargs): 192 | super(HACNN, self).__init__() 193 | self.loss = loss 194 | self.learn_region = learn_region 195 | self.use_gpu = use_gpu 196 | 197 | self.conv = ConvBlock(3, 32, 3, s=2, p=1) 198 | 199 | # Construct Inception + HarmAttn blocks 200 | # ============== Block 1 ============== 201 | self.inception1 = nn.Sequential( 202 | InceptionA(32, nchannels[0]), 203 | InceptionB(nchannels[0], nchannels[0]), 204 | ) 205 | self.ha1 = HarmAttn(nchannels[0]) 206 | 207 | # ============== Block 2 ============== 208 | self.inception2 = nn.Sequential( 209 | InceptionA(nchannels[0], nchannels[1]), 210 | InceptionB(nchannels[1], nchannels[1]), 211 | ) 212 | self.ha2 = HarmAttn(nchannels[1]) 213 | 214 | # ============== Block 3 ============== 215 | self.inception3 = nn.Sequential( 216 | InceptionA(nchannels[1], nchannels[2]), 217 | InceptionB(nchannels[2], nchannels[2]), 218 | ) 219 | self.ha3 = HarmAttn(nchannels[2]) 220 | 221 | self.fc_global = nn.Sequential( 222 | nn.Linear(nchannels[2], feat_dim), 223 | nn.BatchNorm1d(feat_dim), 224 | nn.ReLU(), 225 | ) 226 | self.classifier_global = nn.Linear(feat_dim, num_classes) 227 | 228 | if self.learn_region: 229 | self.init_scale_factors() 230 | self.local_conv1 = InceptionB(32, nchannels[0]) 231 | self.local_conv2 = InceptionB(nchannels[0], nchannels[1]) 232 | self.local_conv3 = InceptionB(nchannels[1], nchannels[2]) 233 | self.fc_local = nn.Sequential( 234 | nn.Linear(nchannels[2]*4, feat_dim), 235 | nn.BatchNorm1d(feat_dim), 236 | nn.ReLU(), 237 | ) 238 | self.classifier_local = nn.Linear(feat_dim, num_classes) 239 | self.feat_dim = feat_dim * 2 240 | else: 241 | self.feat_dim = feat_dim 242 | 243 | def init_scale_factors(self): 244 | # initialize scale factors (s_w, s_h) for four regions 245 | self.scale_factors = [] 246 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)) 247 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)) 248 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)) 249 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float)) 250 | 251 | def stn(self, x, theta): 252 | """Perform spatial transform 253 | x: (batch, channel, height, width) 254 | theta: (batch, 2, 3) 255 | """ 256 | grid = F.affine_grid(theta, x.size()) 257 | x = F.grid_sample(x, grid) 258 | return x 259 | 260 | def transform_theta(self, theta_i, region_idx): 261 | """Transform theta to include (s_w, s_h), 262 | resulting in (batch, 2, 3)""" 263 | scale_factors = self.scale_factors[region_idx] 264 | theta = torch.zeros(theta_i.size(0), 2, 3) 265 | theta[:,:,:2] = scale_factors 266 | theta[:,:,-1] = theta_i 267 | if self.use_gpu: theta = theta.cuda() 268 | return theta 269 | 270 | def forward(self, x, is_training): 271 | assert x.size(2) == 160 and x.size(3) == 64, \ 272 | "Input size does not match, expected (160, 64) but got ({}, {})".format(x.size(2), x.size(3)) 273 | x = self.conv(x) 274 | 275 | # ============== Block 1 ============== 276 | # global branch 277 | x1 = self.inception1(x) 278 | x1_attn, x1_theta = self.ha1(x1) 279 | x1_out = x1 * x1_attn 280 | # local branch 281 | if self.learn_region: 282 | x1_local_list = [] 283 | for region_idx in range(4): 284 | x1_theta_i = x1_theta[:,region_idx,:] 285 | x1_theta_i = self.transform_theta(x1_theta_i, region_idx) 286 | x1_trans_i = self.stn(x, 
x1_theta_i) 287 | x1_trans_i = F.upsample(x1_trans_i, (24, 28), mode='bilinear', align_corners=True) 288 | x1_local_i = self.local_conv1(x1_trans_i) 289 | x1_local_list.append(x1_local_i) 290 | 291 | # ============== Block 2 ============== 292 | # Block 2 293 | # global branch 294 | x2 = self.inception2(x1_out) 295 | x2_attn, x2_theta = self.ha2(x2) 296 | x2_out = x2 * x2_attn 297 | # local branch 298 | if self.learn_region: 299 | x2_local_list = [] 300 | for region_idx in range(4): 301 | x2_theta_i = x2_theta[:,region_idx,:] 302 | x2_theta_i = self.transform_theta(x2_theta_i, region_idx) 303 | x2_trans_i = self.stn(x1_out, x2_theta_i) 304 | x2_trans_i = F.upsample(x2_trans_i, (12, 14), mode='bilinear', align_corners=True) 305 | x2_local_i = x2_trans_i + x1_local_list[region_idx] 306 | x2_local_i = self.local_conv2(x2_local_i) 307 | x2_local_list.append(x2_local_i) 308 | 309 | # ============== Block 3 ============== 310 | # Block 3 311 | # global branch 312 | x3 = self.inception3(x2_out) 313 | x3_attn, x3_theta = self.ha3(x3) 314 | x3_out = x3 * x3_attn 315 | # local branch 316 | if self.learn_region: 317 | x3_local_list = [] 318 | for region_idx in range(4): 319 | x3_theta_i = x3_theta[:,region_idx,:] 320 | x3_theta_i = self.transform_theta(x3_theta_i, region_idx) 321 | x3_trans_i = self.stn(x2_out, x3_theta_i) 322 | x3_trans_i = F.upsample(x3_trans_i, (6, 7), mode='bilinear', align_corners=True) 323 | x3_local_i = x3_trans_i + x2_local_list[region_idx] 324 | x3_local_i = self.local_conv3(x3_local_i) 325 | x3_local_list.append(x3_local_i) 326 | 327 | # ============== Feature generation ============== 328 | # global branch 329 | x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1)) 330 | x_global = self.fc_global(x_global) 331 | # local branch 332 | if self.learn_region: 333 | x_local_list = [] 334 | for region_idx in range(4): 335 | x_local_i = x3_local_list[region_idx] 336 | x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(x_local_i.size(0), -1) 337 | x_local_list.append(x_local_i) 338 | x_local = torch.cat(x_local_list, 1) 339 | x_local = self.fc_local(x_local) 340 | 341 | if not is_training: 342 | # l2 normalization before concatenation 343 | if self.learn_region: 344 | x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True) 345 | x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True) 346 | return [torch.cat([x_global, x_local], 1)] 347 | else: 348 | return [x_global] 349 | 350 | prelogits_global = self.classifier_global(x_global) 351 | if self.learn_region: 352 | prelogits_local = self.classifier_local(x_local) 353 | 354 | if self.loss == {'xent'}: 355 | if self.learn_region: 356 | return [prelogits_global, prelogits_local] 357 | else: 358 | return [prelogits_global] 359 | elif self.loss == {'xent', 'htri'}: 360 | if self.learn_region: 361 | return [(prelogits_global, prelogits_local), (x_global, x_local)] 362 | else: 363 | return [prelogits_global, x_global] 364 | else: 365 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /GD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import functools 6 | from torch.autograd import Variable 7 | from torch.optim import lr_scheduler 8 | from util.spectral import SpectralNorm 9 | from util.gumbel import gumbel_softmax 10 | import numpy as np 11 | import 
math 12 | 13 | class Pat_Discriminator(nn.Module): 14 | """ 15 | Defines a PatchGAN discriminator 16 | Code based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 17 | """ 18 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='bn'): 19 | """Construct a PatchGAN discriminator 20 | Parameters: 21 | input_nc (int) -- the number of channels in input images 22 | ndf (int) -- the number of filters in the last conv layer 23 | n_layers (int) -- the number of conv layers in the discriminator 24 | norm_layer -- normalization layer 25 | """ 26 | super(Pat_Discriminator, self).__init__() 27 | 28 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d 29 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 30 | use_bias = norm_layer.func != nn.BatchNorm2d 31 | else: 32 | use_bias = norm_layer != nn.BatchNorm2d 33 | 34 | kw = 4 35 | padw = 1 36 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 37 | nf_mult = 1 38 | nf_mult_prev = 1 39 | for n in range(1, n_layers): # gradually increase the number of filters 40 | nf_mult_prev = nf_mult 41 | nf_mult = min(2 ** n, 8) 42 | sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)] 43 | 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n_layers, 8) 46 | sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)] 47 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 48 | self.model = nn.Sequential(*sequence) 49 | 50 | def forward(self, x): 51 | return self.model(x), torch.ones_like(x) 52 | 53 | 54 | class MS_Discriminator(nn.Module): 55 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='bn', num_D=3, temperature=-1, use_gumbel=False): 56 | super(MS_Discriminator, self).__init__() 57 | self.num_D = num_D 58 | self.n_layers = n_layers 59 | self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, count_include_pad=False) 60 | self.same0 = SamePadding(kernel_size=3, stride=2) 61 | self.same1 = SamePadding(kernel_size=4, stride=2) 62 | self.same2 = SamePadding(kernel_size=4, stride=1) 63 | self.Mask = Mask(norm, temperature, use_gumbel) 64 | 65 | for i in range(num_D): 66 | netD = sub_Discriminator(input_nc, ndf, n_layers, norm) 67 | for j in range(n_layers+2): setattr(self, 'D'+str(i)+'_layer'+str(j), getattr(netD, 'layer'+str(j))) 68 | 69 | def single_forward(self, model, x): 70 | result = [x] 71 | for i in range(len(model)): 72 | samepadding = self.same1 if i < len(model)-2 else self.same2 73 | result.append(model[i](samepadding(result[-1]))) 74 | return result[1:] 75 | 76 | def forward(self, x): 77 | num_D = self.num_D 78 | proposal = [] 79 | result = [] 80 | mask = None 81 | input_downsampled = x 82 | for i in range(num_D): 83 | model = [getattr(self, 'D'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] 84 | proposal.append(self.single_forward(model, input_downsampled)) #[[D2L0, D2L1,..., D2L4],[D1L0,...,D1L4],[D0L0,...,D0L4]] 85 | if i != (num_D-1): input_downsampled = self.downsample(self.same0(input_downsampled)) 86 | for i in proposal: result.append(i[-1]) 87 | mask = self.Mask(x, proposal) 88 | return result, mask 89 | 90 | # (64,128,256,512,1) 91 | class sub_Discriminator(nn.Module): 92 | def 
__init__(self, input_nc, ndf=64, n_layers=3, norm='in'): 93 | super(sub_Discriminator, self).__init__() 94 | self.n_layers = n_layers 95 | 96 | use_bias = norm == 'in' 97 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d 98 | sequence = [[SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, bias=use_bias)), nn.LeakyReLU(0.2, True)]] 99 | nf = ndf 100 | for n in range(1, n_layers): 101 | nf_prev = nf 102 | nf = min(nf*2, 512) 103 | sequence += [[SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=4, stride=2, bias=use_bias)), norm_layer(nf), nn.LeakyReLU(0.2, True)]] 104 | 105 | nf_prev = nf 106 | nf = min(nf*2, 512) 107 | sequence += [[SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=4, stride=1, bias=use_bias)), norm_layer(nf), nn.LeakyReLU(0.2, True)]] 108 | sequence += [[nn.Conv2d(nf, 1, kernel_size=4, stride=1)]] 109 | 110 | for n in range(len(sequence)): 111 | setattr(self, 'layer'+str(n), nn.Sequential(*sequence[n])) 112 | 113 | def forward(self, input): 114 | res = [input] 115 | for n in range(self.n_layers+2): 116 | model = getattr(self, 'layer'+str(n)) 117 | res.append(model(res[-1])) 118 | return res[1:] 119 | 120 | class Mask(nn.Module): 121 | def __init__(self, norm, temperature, use_gumbel, fused=1): 122 | super(Mask, self).__init__() 123 | self.temperature = temperature 124 | self.use_gumbel = use_gumbel 125 | self.fused = fused 126 | self.T = nn.Parameter(torch.Tensor([1])) 127 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d 128 | small_channels = [512, 512, 256, 128] 129 | big_channels = [512+128, 512+128+64, 128+64, 64] if self.fused == 2 else [512, 512, 128, 64] 130 | 131 | self.up32_16 = UpLayer(big_channels=big_channels[0], out_channels=512, small_channels=small_channels[0], norm_layer=norm_layer) 132 | self.up16_8 = UpLayer(big_channels=big_channels[1], out_channels=256, small_channels=small_channels[1], norm_layer=norm_layer) 133 | self.up8_4 = UpLayer(big_channels=big_channels[2], out_channels=128, small_channels=small_channels[2], norm_layer=norm_layer) 134 | # self.up4_2 = UpLayer(big_channels=big_channels[3], out_channels=64, small_channels=small_channels[3], norm_layer=norm_layer) 135 | self.deconv1 = nn.Sequential(*[SpectralNorm(nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)), nn.LeakyReLU(0.2, True)]) 136 | self.deconv2 = nn.Sequential(*[SpectralNorm(nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)), nn.LeakyReLU(0.2, True)]) 137 | self.conv2 = nn.Sequential(*[nn.Conv2d(128, 1, kernel_size=1, stride=1)]) 138 | self.logsoftmax = nn.LogSoftmax(dim=1) 139 | 140 | def forward(self, x, proposal): 141 | n,c,h,w = x.size() 142 | if self.temperature == -1: return torch.ones((n,1,h,w)).cuda() 143 | scale32 = proposal[2][3] 144 | scale16 = torch.cat((proposal[2][1], proposal[1][3]),1) if self.fused == 2 else proposal[1][3] 145 | scale8 = torch.cat((proposal[0][3], proposal[1][1], proposal[2][0]),1) if self.fused == 2 else proposal[0][3] 146 | scale4 = torch.cat((proposal[0][1], proposal[1][0]),1) if self.fused == 2 else proposal[0][1] 147 | scale2 = proposal[0][0] 148 | out = self.up32_16(scale32, scale16) 149 | out = self.up16_8(out, scale8) 150 | out = self.up8_4(out, scale4) 151 | # out = self.up4_2(out, scale2) 152 | out = self.deconv1(out) 153 | out = self.deconv2(out) 154 | out = self.conv2(out) 155 | 156 | if not self.use_gumbel: 157 | logits = self.logsoftmax(out.view(n, -1)) 158 | th, _ = torch.topk(logits, k=int(self.temperature), dim=1, largest=True) 159 | mask, zeros, ones = 
torch.zeros_like(logits).cuda(), torch.zeros(h*w).cuda(), torch.ones(h*w).cuda() 160 | for i in range(n): 161 | mask[i,:] = torch.where(logits[i,:]>=th[i, int(self.temperature)-1], ones, zeros) 162 | mask = mask.view(n, 1, h, w) 163 | elif self.use_gumbel: 164 | logits = gumbel_softmax(out.view(n, -1), k=int(self.temperature), T=self.T, hard=True, eps=1e-10).view(n, 1, h, w) 165 | mask = logits.cuda() 166 | # logits = F.gumbel_softmax(out.view(n, -1), tau=self.temperature).view(n, 1, h, w) 167 | # # logits_normed = torch.clamp((logits_normed+1e-4), min=0, max=1) 168 | # logits = np.minimum(1.0, logits.data.cpu().numpy()*(h*w)+1e-4) 169 | # mask = torch.bernoulli(torch.from_numpy(logits)).cuda() 170 | return mask 171 | 172 | class UpLayer(nn.Module): 173 | def __init__(self, big_channels, out_channels, small_channels, norm_layer): 174 | super(UpLayer, self).__init__() 175 | self.big_channels = big_channels 176 | self.out_channels = out_channels 177 | self.small_channels = small_channels 178 | self.conv1 = nn.Sequential(*[SpectralNorm(nn.Conv2d(self.big_channels, self.small_channels, kernel_size=1, stride=1)), norm_layer(self.small_channels), nn.LeakyReLU(0.2, True)]) 179 | self.conv2 = nn.Sequential(*[SpectralNorm(nn.Conv2d(self.small_channels, self.out_channels, kernel_size=3, stride=1, padding=1)), norm_layer(self.out_channels), nn.LeakyReLU(0.2, True)]) 180 | 181 | def forward(self, small, big): 182 | small = F.upsample(small, size=(big.size()[2], big.size()[3]), mode='bilinear') 183 | big = self.conv1(big) 184 | out = self.conv2(big+small) 185 | return out 186 | 187 | class Generator(nn.Module): 188 | def __init__(self, input_nc, output_nc, ngf, norm='bn', n_blocks=6): 189 | super(Generator, self).__init__() 190 | 191 | n_downsampling = n_upsampling = 2 192 | use_bias = norm == 'in' 193 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d 194 | begin_layers, down_layers, res_layers, up_layers, end_layers = [], [], [], [], [] 195 | for i in range(n_upsampling): 196 | up_layers.append([]) 197 | # ngf 198 | begin_layers = [nn.ReflectionPad2d(3), SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)), norm_layer(ngf), nn.ReLU(True)] 199 | # 2ngf, 4ngf 200 | for i in range(n_downsampling): 201 | mult = 2**i 202 | down_layers += [SpectralNorm(nn.Conv2d(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, padding=1, bias=use_bias)), norm_layer(ngf*mult*2), nn.ReLU(True)] 203 | # 4ngf 204 | mult = 2**n_downsampling 205 | for i in range(n_blocks): 206 | res_layers += [ResnetBlock(ngf*mult, norm_layer, use_bias)] 207 | # 2ngf, ngf 208 | for i in range(n_upsampling): 209 | mult = 2**(n_upsampling - i) 210 | up_layers[i] += [SpectralNorm(nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias)), norm_layer(int(ngf*mult/2)), nn.ReLU(True)] 211 | 212 | end_layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 213 | 214 | self.l1 = nn.Sequential(*begin_layers) 215 | self.l2 = nn.Sequential(*down_layers) 216 | self.l3 = nn.Sequential(*res_layers) 217 | self.l4_1 = nn.Sequential(*up_layers[0]) 218 | self.l4_2 = nn.Sequential(*up_layers[1]) 219 | self.l5 = nn.Sequential(*end_layers) 220 | 221 | def forward(self, inputs): 222 | out = self.l1(inputs) 223 | out = self.l2(out) 224 | out = self.l3(out) 225 | out = self.l4_1(out) 226 | out = self.l4_2(out) 227 | out = self.l5(out) 228 | return out 229 | 230 | class ResnetG(nn.Module): 231 | def __init__(self, input_nc, 
output_nc, ngf, norm='bn', n_blocks=6): 232 | super(ResnetG, self).__init__() 233 | 234 | n_downsampling = n_upsampling = 2 235 | use_bias = norm == 'in' 236 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d 237 | begin_layers, down_layers, res_layers, up_layers, end_layers = [], [], [], [], [] 238 | for i in range(n_upsampling): 239 | up_layers.append([]) 240 | # ngf 241 | begin_layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)] 242 | # 2ngf, 4ngf 243 | for i in range(n_downsampling): 244 | mult = 2**i 245 | down_layers += [nn.Conv2d(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf*mult*2), nn.ReLU(True)] 246 | # 4ngf 247 | mult = 2**n_downsampling 248 | for i in range(n_blocks): 249 | res_layers += [ResnetBlock(ngf*mult, norm_layer, use_bias)] 250 | # 2ngf, ngf 251 | for i in range(n_upsampling): 252 | mult = 2**(n_upsampling - i) 253 | up_layers[i] += [nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), norm_layer(int(ngf*mult/2)), nn.ReLU(True)] 254 | 255 | end_layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 256 | 257 | self.l1 = nn.Sequential(*begin_layers) 258 | self.l2 = nn.Sequential(*down_layers) 259 | self.l3 = nn.Sequential(*res_layers) 260 | self.l4_1 = nn.Sequential(*up_layers[0]) 261 | self.l4_2 = nn.Sequential(*up_layers[1]) 262 | self.l5 = nn.Sequential(*end_layers) 263 | 264 | def forward(self, inputs): 265 | out = self.l1(inputs) 266 | out = self.l2(out) 267 | out = self.l3(out) 268 | out = self.l4_1(out) 269 | out = self.l4_2(out) 270 | out = self.l5(out) 271 | return out 272 | 273 | # Define a resnet block 274 | class ResnetBlock(nn.Module): 275 | def __init__(self, dim, norm_layer, use_bias): 276 | super(ResnetBlock, self).__init__() 277 | self.conv_block = self.build_conv_block(dim, norm_layer, use_bias) 278 | 279 | def build_conv_block(self, dim, norm_layer, use_bias): 280 | conv_block = [] 281 | for i in range(2): 282 | conv_block += [nn.ReflectionPad2d(1)] 283 | conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)), norm_layer(dim)] 284 | if i < 1: 285 | conv_block += [nn.ReLU(True)] 286 | return nn.Sequential(*conv_block) 287 | 288 | def forward(self, x): 289 | out = x + self.conv_block(x) 290 | return out 291 | 292 | class SamePadding(nn.Module): 293 | def __init__(self, kernel_size, stride): 294 | super(SamePadding, self).__init__() 295 | self.kernel_size = torch.nn.modules.utils._pair(kernel_size) 296 | self.stride = torch.nn.modules.utils._pair(stride) 297 | 298 | def forward(self, input): 299 | in_width = input.size()[2] 300 | in_height = input.size()[3] 301 | out_width = math.ceil(float(in_width) / float(self.stride[0])) 302 | out_height = math.ceil(float(in_height) / float(self.stride[1])) 303 | pad_along_width = ((out_width - 1) * self.stride[0] + 304 | self.kernel_size[0] - in_width) 305 | pad_along_height = ((out_height - 1) * self.stride[1] + 306 | self.kernel_size[1] - in_height) 307 | pad_left = int(pad_along_width / 2) 308 | pad_top = int(pad_along_height / 2) 309 | pad_right = pad_along_width - pad_left 310 | pad_bottom = pad_along_height - pad_top 311 | return F.pad(input, (int(pad_left), int(pad_right), int(pad_top), int(pad_bottom)), 'constant', 0) 312 | 313 | def __repr__(self): 314 | return self.__class__.__name__ 315 | 316 | def weights_init(m): 317 | 
classname = m.__class__.__name__ 318 | # print(dir(m)) 319 | if classname.find('Conv') != -1: 320 | if 'weight' in dir(m): 321 | m.weight.data.normal_(0.0, 1) 322 | elif classname.find('BatchNorm2d') != -1: 323 | m.weight.data.normal_(1.0, 0.02) 324 | m.bias.data.fill_(0) 325 | 326 | class GANLoss(nn.Module): 327 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.cuda.FloatTensor): 328 | super(GANLoss, self).__init__() 329 | self.real_label = target_real_label 330 | self.fake_label = target_fake_label 331 | self.real_label_var = None 332 | self.fake_label_var = None 333 | self.Tensor = tensor 334 | if use_lsgan: self.loss = nn.MSELoss() 335 | else: self.loss = nn.BCELoss() 336 | 337 | def get_target_tensor(self, input, target_is_real): 338 | target_tensor = None 339 | if target_is_real: 340 | create_label = ((self.real_label_var is None) or 341 | (self.real_label_var.numel() != input.numel())) 342 | if create_label: 343 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 344 | self.real_label_var = Variable(real_tensor, requires_grad=False) 345 | target_tensor = self.real_label_var 346 | else: 347 | create_label = ((self.fake_label_var is None) or 348 | (self.fake_label_var.numel() != input.numel())) 349 | if create_label: 350 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 351 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 352 | target_tensor = self.fake_label_var 353 | return target_tensor 354 | 355 | def __call__(self, input, target_is_real): 356 | if isinstance(input[0], list): 357 | loss = 0 358 | for input_i in input: 359 | pred = input_i[-1] 360 | target_tensor = self.get_target_tensor(pred, target_is_real) 361 | loss += self.loss(pred, target_tensor) 362 | return loss 363 | else: 364 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 365 | return self.loss(input[-1], target_tensor) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function, division 3 | import sys 4 | import time 5 | import datetime 6 | import argparse 7 | import os 8 | import numpy as np 9 | import os.path as osp 10 | import math 11 | from random import sample 12 | from scipy import io 13 | 14 | import torchvision 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | from torch.utils.data import DataLoader 19 | import torch.backends.cudnn as cudnn 20 | 21 | import models 22 | from models.PCB import PCB_test 23 | # from ReID_attr import get_target_withattr # Need Attribute file 24 | from opts import get_opts, Imagenet_mean, Imagenet_stddev 25 | from GD import Generator, MS_Discriminator, Pat_Discriminator, GANLoss, weights_init 26 | from advloss import DeepSupervision, adv_CrossEntropyLoss, adv_CrossEntropyLabelSmooth, adv_TripletLoss 27 | from util import data_manager 28 | from util.dataset_loader import ImageDataset 29 | from util.utils import fliplr, Logger, save_checkpoint, visualize_ranked_results 30 | from util.eval_metrics import make_results 31 | from util.samplers import RandomIdentitySampler, AttrPool 32 | 33 | # Training settings 34 | parser = argparse.ArgumentParser(description='adversarial attack') 35 | parser.add_argument('--root', type=str, default='data', help="root path to data directory") 36 | parser.add_argument('--targetmodel', type=str, default='aligned', 
choices=models.get_names()) 37 | parser.add_argument('--dataset', type=str, default='market1501', choices=data_manager.get_names()) 38 | # PATH 39 | parser.add_argument('--G_resume_dir', type=str, default='', help='path to resume G') 40 | parser.add_argument('--pre_dir', type=str, default='models', help='path to the model to be attacked') 41 | parser.add_argument('--attr_dir', type=str, default='', help='path to attribute file') 42 | parser.add_argument('--save_dir', type=str, default='logs', help='path to save model') 43 | parser.add_argument('--vis_dir', type=str, default='vis', help='path to save visualization result') 44 | parser.add_argument('--ablation', type=str, default='', help='for ablation study') 45 | # var 46 | parser.add_argument('--mode', type=str, default='train', help='train/test') 47 | parser.add_argument('--D', type=str, default='MSGAN', help='Type of discriminator: PatchGAN or Multi-stage GAN') 48 | parser.add_argument('--normalization', type=str, default='bn', help='bn or in') 49 | parser.add_argument('--loss', type=str, default='xent_htri', choices=['cent', 'xent', 'htri', 'xent_htri']) 50 | parser.add_argument('--ak_type', type=int, default=-1, help='-1 if non-targeted, 1 if attribute attack') 51 | parser.add_argument('--attr_key', type=str, default='upwhite', help='[attribute, value]') 52 | parser.add_argument('--attr_value', type=int, default=2, help='[attribute, value]') 53 | parser.add_argument('--mag_in', type=float, default=16.0, help='l_inf magnitude of perturbation') 54 | parser.add_argument('--temperature', type=float, default=-1, help="tau in paper") 55 | parser.add_argument('--usegumbel', action='store_true', default=False, help='whether to use gumbel softmax') 56 | parser.add_argument('--use_SSIM', type=int, default=2, help="0: None, 1: SSIM, 2: MS-SSIM ") 57 | # Base 58 | parser.add_argument('--train_batch', default=32, type=int, help="train batch size") 59 | parser.add_argument('--test_batch', default=32, type=int, help="test batch size") 60 | parser.add_argument('--epoch', type=int, default=50, help='number of epochs to train for') 61 | 62 | parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") 63 | parser.add_argument('--num_ker', type=int, default=32, help='generator filters in first conv layer') 64 | parser.add_argument('--lr', type=float, default=0.0002, help='Learning Rate. Default=0.0002') 65 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') 66 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. 
Default=123') 67 | parser.add_argument('--print_freq', type=int, default=20, help="print frequency") 68 | parser.add_argument('--eval_freq', type=int, default=1, help="eval frequency") 69 | parser.add_argument('--usevis', action='store_true', default=False, help='whether to save vis') 70 | 71 | args = parser.parse_args() 72 | is_training = args.mode == 'train' 73 | attr_list = [args.attr_key, args.attr_value] 74 | attr_matrix = None 75 | if args.attr_dir: 76 | assert args.dataset in ['dukemtmcreid', 'market1501'] 77 | attr_matrix = io.loadmat(args.attr_dir) 78 | args.ablation = osp.join('attr', args.attr_key + '=' + str(args.attr_value)) 79 | 80 | pre_dir = osp.join(args.pre_dir, args.targetmodel, args.dataset+'.pth.tar') 81 | save_dir = osp.join(args.save_dir, args.targetmodel, args.dataset, args.ablation) 82 | vis_dir = osp.join(args.vis_dir, args.targetmodel, args.dataset, args.ablation) 83 | 84 | 85 | def main(opt): 86 | if not osp.exists(save_dir): os.makedirs(save_dir) 87 | if not osp.exists(vis_dir): os.makedirs(vis_dir) 88 | 89 | use_gpu = torch.cuda.is_available() 90 | pin_memory = True if use_gpu else False 91 | 92 | if args.mode == 'train': 93 | sys.stdout = Logger(osp.join(save_dir, 'log_train.txt')) 94 | else: 95 | sys.stdout = Logger(osp.join(save_dir, 'log_test.txt')) 96 | print("==========\nArgs:{}\n==========".format(args)) 97 | 98 | if use_gpu: 99 | print("GPU mode") 100 | cudnn.benchmark = True 101 | torch.cuda.manual_seed(args.seed) 102 | else: 103 | print("CPU mode") 104 | 105 | ### Setup dataset loader ### 106 | print("Initializing dataset {}".format(args.dataset)) 107 | dataset = data_manager.init_img_dataset(root=args.root, name=args.dataset, split_id=opt['split_id'], cuhk03_labeled=opt['cuhk03_labeled'], cuhk03_classic_split=opt['cuhk03_classic_split']) 108 | if args.ak_type < 0: 109 | trainloader = DataLoader(ImageDataset(dataset.train, transform=opt['transform_train']), sampler=RandomIdentitySampler(dataset.train, num_instances=opt['num_instances']), batch_size=args.train_batch, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=True) 110 | elif args.ak_type > 0: 111 | trainloader = DataLoader(ImageDataset(dataset.train, transform=opt['transform_train']), sampler=AttrPool(dataset.train, args.dataset, attr_matrix, attr_list, sample_num=16), batch_size=args.train_batch, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=True) 112 | queryloader = DataLoader(ImageDataset(dataset.query, transform=opt['transform_test']), batch_size=args.test_batch, shuffle=False, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=False) 113 | galleryloader = DataLoader(ImageDataset(dataset.gallery, transform=opt['transform_test']), batch_size=args.test_batch, shuffle=False, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=False) 114 | 115 | ### Prepare criterion ### 116 | if args.ak_type<0: 117 | clf_criterion = adv_CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu) if args.loss in ['xent', 'xent_htri'] else adv_CrossEntropyLoss(use_gpu=use_gpu) 118 | else: 119 | clf_criterion = nn.MultiLabelSoftMarginLoss() 120 | metric_criterion = adv_TripletLoss(margin=args.margin, ak_type=args.ak_type) 121 | criterionGAN = GANLoss() 122 | 123 | ### Prepare pretrained model ### 124 | target_net = models.init_model(name=args.targetmodel, pre_dir=pre_dir, num_classes=dataset.num_train_pids) 125 | check_freezen(target_net, need_modified=True, after_modified=False) 126 | 127 | ### Prepare main net ### 128 | G = Generator(3, 3, 
args.num_ker, norm=args.normalization).apply(weights_init) 129 | if args.D == 'PatchGAN': 130 | D = Pat_Discriminator(input_nc=6, norm=args.normalization).apply(weights_init) 131 | elif args.D == 'MSGAN': 132 | D = MS_Discriminator(input_nc=6, norm=args.normalization, temperature=args.temperature, use_gumbel=args.usegumbel).apply(weights_init) 133 | check_freezen(G, need_modified=True, after_modified=True) 134 | check_freezen(D, need_modified=True, after_modified=True) 135 | print("Model size: {:.5f}M".format((sum(g.numel() for g in G.parameters())+sum(d.numel() for d in D.parameters()))/1000000.0)) 136 | # setup optimizer 137 | optimizer_G = optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, 0.999)) 138 | optimizer_D = optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, 0.999)) 139 | 140 | if use_gpu: 141 | test_target_net = nn.DataParallel(target_net).cuda() if not args.targetmodel == 'pcb' else nn.DataParallel(PCB_test(target_net)).cuda() 142 | target_net = nn.DataParallel(target_net).cuda() 143 | G = nn.DataParallel(G).cuda() 144 | D = nn.DataParallel(D).cuda() 145 | 146 | if args.mode == 'test': 147 | epoch = 'test' 148 | test(G, D, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=True) 149 | return 0 150 | 151 | # Ready 152 | start_time = time.time() 153 | train_time = 0 154 | worst_mAP, worst_rank1, worst_rank5, worst_rank10, worst_epoch = np.inf, np.inf, np.inf, np.inf, 0 155 | best_hit, best_epoch = -np.inf, 0 156 | print("==> Start training") 157 | 158 | for epoch in range(1,args.epoch+1): 159 | start_train_time = time.time() 160 | train(epoch, G, D, target_net, criterionGAN, clf_criterion, metric_criterion, optimizer_G, optimizer_D, trainloader, use_gpu) 161 | train_time += round(time.time() - start_train_time) 162 | 163 | if epoch % args.eval_freq == 0: 164 | print("==> Eval at epoch {}".format(epoch)) 165 | if args.ak_type < 0: 166 | cmc, mAP = test(G, D, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False) 167 | is_worst = cmc[0]<=worst_rank1 and cmc[1]<=worst_rank5 and cmc[2]<=worst_rank10 and mAP<=worst_mAP 168 | if is_worst: 169 | worst_mAP, worst_rank1, worst_epoch = mAP, cmc[0], epoch 170 | print("==> Worst_epoch is {}, Worst mAP {:.1%}, Worst rank-1 {:.1%}".format(worst_epoch, worst_mAP, worst_rank1)) 171 | save_checkpoint(G.state_dict(), is_worst, 'G', osp.join(save_dir, 'G_ep' + str(epoch) + '.pth.tar')) 172 | save_checkpoint(D.state_dict(), is_worst, 'D', osp.join(save_dir, 'D_ep' + str(epoch) + '.pth.tar')) 173 | 174 | else: 175 | all_hits = test(G, D, target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False) 176 | is_best = all_hits[0]>=best_hit 177 | if is_best: 178 | best_hit, best_epoch = all_hits[0], epoch 179 | print("==> Best_epoch is {}, Best rank-1 {:.1%}".format(best_epoch, best_hit)) 180 | save_checkpoint(G.state_dict(), is_best, 'G', osp.join(save_dir, 'G_ep' + str(epoch) + '.pth.tar')) 181 | save_checkpoint(D.state_dict(), is_best, 'D', osp.join(save_dir, 'D_ep' + str(epoch) + '.pth.tar')) 182 | 183 | elapsed = round(time.time() - start_time) 184 | elapsed = str(datetime.timedelta(seconds=elapsed)) 185 | train_time = str(datetime.timedelta(seconds=train_time)) 186 | print("Finished. Total elapsed time (h:m:s): {}. 
Training time (h:m:s): {}.".format(elapsed, train_time)) 187 | 188 | def train(epoch, G, D, target_net, criterionGAN, clf_criterion, metric_criterion, optimizer_G, optimizer_D, trainloader, use_gpu): 189 | G.train() 190 | D.train() 191 | global is_training 192 | is_training = True 193 | 194 | for batch_idx, (imgs, pids, _, pids_raw) in enumerate(trainloader): 195 | if use_gpu: 196 | imgs, pids, pids_raw = imgs.cuda(), pids.cuda(), pids_raw.cuda() 197 | 198 | new_imgs, mask = perturb(imgs, G, D, train_or_test='train') 199 | new_imgs = new_imgs.cuda() 200 | mask = mask.cuda() 201 | # Fake Detection and Loss 202 | pred_fake_pool, _ = D(torch.cat((imgs, new_imgs.detach()), 1)) 203 | loss_D_fake = criterionGAN(pred_fake_pool, False) 204 | 205 | # Real Detection and Loss 206 | num = args.train_batch//2 207 | pred_real, _ = D(torch.cat((imgs[0:num,:,:,:], imgs[num:,:,:,:].detach()), 1)) 208 | loss_D_real = criterionGAN(pred_real, True) 209 | 210 | # GAN loss (Fake Passability Loss) 211 | pred_fake, _ = D(torch.cat((imgs, new_imgs), 1)) 212 | loss_G_GAN = criterionGAN(pred_fake, True) 213 | 214 | # Re-ID advloss 215 | ls = target_net(new_imgs, is_training) 216 | if len(ls) == 1: new_outputs = ls[0] 217 | if len(ls) == 2: new_outputs, new_features = ls 218 | if len(ls) == 3: new_outputs, new_features, new_local_features = ls 219 | xent_loss, global_loss, loss_G_ssim = 0, 0, 0 220 | targets = None 221 | 222 | if args.loss in ['cent', 'xent', 'xent_htri']: 223 | if args.ak_type < 0: 224 | xent_loss = DeepSupervision(clf_criterion, new_outputs, pids) if isinstance(new_features, (tuple, list)) else clf_criterion(new_outputs, pids) 225 | 226 | elif args.ak_type > 0: 227 | targets = get_target_withattr(attr_matrix, args.dataset, attr_list, pids, pids_raw).float().cuda() 228 | xent_loss = 0#DeepSupervision(clf_criterion, new_outputs, targets) if isinstance(new_features, (tuple, list)) else clf_criterion(new_outputs, targets) 229 | 230 | if args.loss in ['htri', 'xent_htri']: 231 | assert len(ls) >= 2 232 | global_loss = DeepSupervision(metric_criterion, new_features, pids, targets) if isinstance(new_features, (tuple, list)) else metric_criterion(new_features, pids, targets) 233 | 234 | loss_G_ReID = (xent_loss+ global_loss)*opt['ReID_factor'] 235 | 236 | # # SSIM loss 237 | if not args.use_SSIM == 0: 238 | from util.ms_ssim import msssim, ssim 239 | loss_func = msssim if args.use_SSIM == 2 else ssim 240 | loss_G_ssim = (1-loss_func(imgs, new_imgs))*0.1 241 | 242 | ############## Forward ############### 243 | loss_D = (loss_D_fake + loss_D_real)/2 244 | loss_G = loss_G_GAN + loss_G_ReID + loss_G_ssim 245 | ############## Backward ############# 246 | # update generator weights 247 | optimizer_G.zero_grad() 248 | # loss_G.backward(retain_graph=True) 249 | loss_G.backward() 250 | optimizer_G.step() 251 | # update discriminator weights 252 | optimizer_D.zero_grad() 253 | loss_D.backward() 254 | optimizer_D.step() 255 | if (batch_idx+1) % args.print_freq == 0: 256 | print("===> Epoch[{}]({}/{}) loss_D: {:.4f} loss_G_GAN: {:.4f} loss_G_ReID: {:.4f} loss_G_SSIM: {:.4f}".format(epoch, batch_idx, len(trainloader), loss_D.item(), loss_G_GAN.item(), loss_G_ReID.item(), loss_G_ssim)) 257 | 258 | def test(G, D, target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False, ranks=[1, 5, 10, 20]): 259 | global is_training 260 | is_training = False 261 | if args.mode == 'test' and args.G_resume_dir: 262 | G_resume_dir, D_resume_dir = args.G_resume_dir, args.G_resume_dir.replace('G', 'D') 263 | 
G_checkpoint, D_checkpoint = torch.load(G_resume_dir), torch.load(D_resume_dir) 264 | G_state_dict = G_checkpoint['state_dict'] if isinstance(G_checkpoint, dict) and 'state_dict' in G_checkpoint else G_checkpoint 265 | D_state_dict = D_checkpoint['state_dict'] if isinstance(D_checkpoint, dict) and 'state_dict' in D_checkpoint else D_checkpoint 266 | 267 | G.load_state_dict(G_state_dict) 268 | D.load_state_dict(D_state_dict) 269 | print("Successfully loaded {} and {}".format(G_resume_dir, D_resume_dir)) 270 | 271 | with torch.no_grad(): 272 | qf, lqf, new_qf, new_lqf, q_pids, q_camids = extract_and_perturb(queryloader, G, D, target_net, use_gpu, query_or_gallery='query', is_test=is_test, epoch=epoch) 273 | gf, lgf, g_pids, g_camids = extract_and_perturb(galleryloader, G, D, target_net, use_gpu, query_or_gallery='gallery', is_test=is_test, epoch=epoch) 274 | 275 | if args.ak_type > 0: 276 | distmat, hits, ignore_list = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type, attr_matrix, args.dataset, attr_list) 277 | print("Hits rate, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(ranks[0], hits[ranks[0]-1], ranks[1], hits[ranks[1]-1], ranks[2], hits[ranks[2]-1], ranks[3], hits[ranks[3]-1])) 278 | if not is_test: 279 | return hits 280 | 281 | else: 282 | if is_test: 283 | distmat, cmc, mAP = make_results(qf, gf, lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type) 284 | new_distmat, new_cmc, new_mAP = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type) 285 | print("Results ----------") 286 | print("Before, mAP: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(mAP, ranks[0], cmc[ranks[0]-1], ranks[1], cmc[ranks[1]-1], ranks[2], cmc[ranks[2]-1], ranks[3], cmc[ranks[3]-1])) 287 | print("After , mAP: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(new_mAP, ranks[0], new_cmc[ranks[0]-1], ranks[1], new_cmc[ranks[1]-1], ranks[2], new_cmc[ranks[2]-1], ranks[3], new_cmc[ranks[3]-1])) 288 | if args.usevis: 289 | visualize_ranked_results(distmat, dataset, save_dir=osp.join(vis_dir, 'origin_results'), topk=20) 290 | if args.usevis: 291 | visualize_ranked_results(new_distmat, dataset, save_dir=osp.join(vis_dir, 'polluted_results'), topk=20) 292 | else: 293 | _, new_cmc, new_mAP = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type) 294 | print("mAP: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(new_mAP, ranks[0], new_cmc[ranks[0]-1], ranks[1], new_cmc[ranks[1]-1], ranks[2], new_cmc[ranks[2]-1], ranks[3], new_cmc[ranks[3]-1])) 295 | return new_cmc, new_mAP 296 | 297 | def extract_and_perturb(loader, G, D, target_net, use_gpu, query_or_gallery, is_test, epoch): 298 | f, lf, new_f, new_lf, l_pids, l_camids = [], [], [], [], [], [] 299 | ave_mask, num = 0, 0 300 | for batch_idx, (imgs, pids, camids, pids_raw) in enumerate(loader): 301 | if use_gpu: 302 | imgs = imgs.cuda() 303 | ls = extract(imgs, target_net) 304 | if len(ls) == 1: features = ls[0] 305 | if len(ls) == 2: 306 | features, local_features = ls 307 | lf.append(local_features.detach().data.cpu()) 308 | 309 | f.append(features.detach().data.cpu()) 310 | l_pids.extend(pids) 311 | l_camids.extend(camids) 312 | 313 | if query_or_gallery == 'query': 314 | G.eval() 315 | D.eval() 316 | new_imgs, delta, mask = perturb(imgs, G, D, 
train_or_test='test') 317 | ave_mask += torch.sum(mask.detach()).cpu().numpy() 318 | num += imgs.size(0) 319 | 320 | ls = extract(new_imgs, target_net) 321 | if len(ls) == 1: new_features = ls[0] 322 | if len(ls) == 2: 323 | new_features, new_local_features = ls 324 | new_lf.append(new_local_features.detach().data.cpu()) 325 | new_f.append(new_features.detach().data.cpu()) 326 | 327 | ls = [imgs, new_imgs, delta, mask] 328 | if is_test: 329 | save_img(ls, pids, camids, epoch, batch_idx) 330 | 331 | f = torch.cat(f, 0) 332 | if not lf == []: lf = torch.cat(lf, 0) 333 | l_pids, l_camids = np.asarray(l_pids), np.asarray(l_camids) 334 | 335 | print("Extracted features for {} set, obtained {}-by-{} matrix".format(query_or_gallery, f.size(0), f.size(1))) 336 | if query_or_gallery == 'gallery': 337 | return [f, lf, l_pids, l_camids] 338 | elif query_or_gallery == 'query': 339 | new_f = torch.cat(new_f, 0) 340 | if not new_lf == []: 341 | new_lf = torch.cat(new_lf, 0) 342 | return [f, lf, new_f, new_lf, l_pids, l_camids] 343 | 344 | def extract(imgs, target_net): 345 | if args.targetmodel in ['pcb', 'lsro']: 346 | ls = [target_net(imgs, is_training)[0] + target_net(fliplr(imgs), is_training)[0]] 347 | else: 348 | ls = target_net(imgs, is_training) 349 | for i in range(len(ls)): ls[i] = ls[i].data.cpu() 350 | return ls 351 | 352 | def perturb(imgs, G, D, train_or_test='test'): 353 | n,c,h,w = imgs.size() 354 | delta = G(imgs) 355 | delta = L_norm(delta, train_or_test) 356 | new_imgs = torch.add(imgs.cuda(), delta[0:imgs.size(0)].cuda()) 357 | 358 | _, mask = D(torch.cat((imgs, new_imgs.detach()), 1)) 359 | delta = delta * mask 360 | new_imgs = torch.add(imgs.cuda(), delta[0:imgs.size(0)].cuda()) 361 | 362 | for c in range(3): 363 | new_imgs.data[:,c,:,:] = new_imgs.data[:,c,:,:].clamp(new_imgs.data[:,c,:,:].min(), new_imgs.data[:,c,:,:].max()) # do clamping per channel 364 | if train_or_test == 'train': 365 | return new_imgs, mask 366 | elif train_or_test == 'test': 367 | return new_imgs, delta, mask 368 | 369 | def L_norm(delta, mode='train'): 370 | delta.data += 1 371 | delta.data *= 0.5 372 | 373 | for c in range(3): 374 | delta.data[:,c,:,:] = (delta.data[:,c,:,:] - Imagenet_mean[c]) / Imagenet_stddev[c] 375 | 376 | bs = args.train_batch if (mode == 'train') else args.test_batch 377 | for i in range(bs): 378 | # do per channel l_inf normalization 379 | for ci in range(3): 380 | try: 381 | l_inf_channel = delta[i,ci,:,:].data.abs().max() 382 | # l_inf_channel = torch.norm(delta[i,ci,:,:]).data 383 | mag_in_scaled_c = args.mag_in/(255.0*Imagenet_stddev[ci]) 384 | delta[i,ci,:,:].data *= np.minimum(1.0, mag_in_scaled_c / l_inf_channel.cpu()).float().cuda() 385 | except IndexError: 386 | break 387 | return delta 388 | 389 | def save_img(ls, pids, camids, epoch, batch_idx): 390 | image, new_image, delta, mask = ls 391 | # undo normalize image color channels 392 | delta_tmp = torch.zeros(delta.size()) 393 | for c in range(3): 394 | image.data[:,c,:,:] = (image.data[:,c,:,:] * Imagenet_stddev[c]) + Imagenet_mean[c] 395 | new_image.data[:,c,:,:] = (new_image.data[:,c,:,:] * Imagenet_stddev[c]) + Imagenet_mean[c] 396 | delta_tmp.data[:,c,:,:] = (delta.data[:,c,:,:] * Imagenet_stddev[c]) + Imagenet_mean[c] 397 | 398 | if args.usevis: 399 | torchvision.utils.save_image(image.data, osp.join(vis_dir, 'original_epoch{}_batch{}.png'.format(epoch, batch_idx))) 400 | torchvision.utils.save_image(new_image.data, osp.join(vis_dir, 'polluted_epoch{}_batch{}.png'.format(epoch, batch_idx))) 401 | 
torchvision.utils.save_image(delta_tmp.data, osp.join(vis_dir, 'delta_epoch{}_batch{}.png'.format(epoch, batch_idx))) 402 | torchvision.utils.save_image(mask.data*255, osp.join(vis_dir, 'mask_epoch{}_batch{}.png'.format(epoch, batch_idx))) 403 | 404 | def check_freezen(net, need_modified=False, after_modified=None): 405 | # print(net) 406 | cc = 0 407 | for child in net.children(): 408 | for param in child.parameters(): 409 | if need_modified: param.requires_grad = after_modified 410 | # if param.requires_grad: print('child', cc , 'was active') 411 | # else: print('child', cc , 'was forzen') 412 | cc += 1 413 | 414 | if __name__ == '__main__': 415 | opt = get_opts(args.targetmodel) 416 | main(opt) 417 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torchvision 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms as transforms 6 | from util import transforms as T 7 | 8 | Imagenet_mean = [0.485, 0.456, 0.406] 9 | Imagenet_stddev = [0.229, 0.224, 0.225] 10 | 11 | market1501_train_map = {2: 0, 7: 1, 10: 2, 11: 3, 12: 4, 20: 5, 22: 6, 23: 7, 27: 8, 28: 9, 30: 10, 32: 11, 35: 12, 37: 13, 42: 14, 43: 15, 46: 16, 47: 17, 48: 18, 52: 19, 53: 20, 56: 21, 57: 22, 59: 23, 64: 24, 65: 25, 67: 26, 68: 27, 69: 28, 70: 29, 76: 30, 77: 31, 79: 32, 81: 33, 82: 34, 84: 35, 86: 36, 88: 37, 90: 38, 93: 39, 95: 40, 97: 41, 98: 42, 99: 43, 100: 44, 104: 45, 105: 46, 106: 47, 107: 48, 108: 49, 110: 50, 111: 51, 114: 52, 115: 53, 116: 54, 117: 55, 118: 56, 121: 57, 122: 58, 123: 59, 125: 60, 127: 61, 129: 62, 132: 63, 134: 64, 135: 65, 136: 66, 139: 67, 140: 68, 141: 69, 142: 70, 143: 71, 148: 72, 149: 73, 150: 74, 151: 75, 158: 76, 159: 77, 160: 78, 162: 79, 164: 80, 166: 81, 167: 82, 169: 83, 172: 84, 173: 85, 175: 86, 176: 87, 177: 88, 178: 89, 179: 90, 180: 91, 181: 92, 184: 93, 185: 94, 190: 95, 193: 96, 195: 97, 197: 98, 199: 99, 201: 100, 202: 101, 204: 102, 206: 103, 208: 104, 209: 105, 211: 106, 212: 107, 214: 108, 216: 109, 221: 110, 222: 111, 223: 112, 224: 113, 225: 114, 232: 115, 234: 116, 236: 117, 237: 118, 239: 119, 241: 120, 242: 121, 243: 122, 245: 123, 248: 124, 249: 125, 250: 126, 251: 127, 254: 128, 255: 129, 259: 130, 261: 131, 264: 132, 266: 133, 268: 134, 269: 135, 272: 136, 273: 137, 276: 138, 277: 139, 279: 140, 281: 141, 282: 142, 287: 143, 296: 144, 297: 145, 298: 146, 299: 147, 301: 148, 303: 149, 306: 150, 307: 151, 308: 152, 309: 153, 313: 154, 314: 155, 317: 156, 318: 157, 321: 158, 323: 159, 324: 160, 325: 161, 326: 162, 327: 163, 328: 164, 331: 165, 332: 166, 333: 167, 335: 168, 338: 169, 339: 170, 340: 171, 341: 172, 344: 173, 347: 174, 348: 175, 349: 176, 350: 177, 352: 178, 354: 179, 357: 180, 358: 181, 359: 182, 361: 183, 367: 184, 368: 185, 369: 186, 370: 187, 371: 188, 374: 189, 375: 190, 376: 191, 377: 192, 379: 193, 380: 194, 382: 195, 383: 196, 384: 197, 385: 198, 386: 199, 389: 200, 390: 201, 392: 202, 393: 203, 394: 204, 397: 205, 398: 206, 399: 207, 402: 208, 403: 209, 404: 210, 407: 211, 408: 212, 409: 213, 410: 214, 411: 215, 413: 216, 414: 217, 415: 218, 419: 219, 420: 220, 421: 221, 423: 222, 424: 223, 427: 224, 429: 225, 430: 226, 432: 227, 433: 228, 434: 229, 435: 230, 437: 231, 441: 232, 442: 233, 444: 234, 445: 235, 446: 236, 449: 237, 450: 238, 451: 239, 456: 240, 457: 241, 459: 242, 464: 243, 466: 244, 468: 245, 470: 246, 472: 247, 475: 248, 477: 249, 480: 
250, 481: 251, 482: 252, 484: 253, 485: 254, 486: 255, 491: 256, 494: 257, 496: 258, 499: 259, 500: 260, 503: 261, 508: 262, 509: 263, 513: 264, 515: 265, 516: 266, 517: 267, 518: 268, 519: 269, 522: 270, 524: 271, 525: 272, 528: 273, 529: 274, 534: 275, 536: 276, 537: 277, 539: 278, 540: 279, 545: 280, 546: 281, 547: 282, 549: 283, 551: 284, 552: 285, 554: 286, 555: 287, 556: 288, 557: 289, 558: 290, 563: 291, 564: 292, 565: 293, 566: 294, 570: 295, 571: 296, 572: 297, 573: 298, 575: 299, 579: 300, 581: 301, 584: 302, 586: 303, 588: 304, 589: 305, 592: 306, 593: 307, 594: 308, 596: 309, 597: 310, 599: 311, 603: 312, 604: 313, 605: 314, 606: 315, 611: 316, 612: 317, 613: 318, 614: 319, 615: 320, 616: 321, 619: 322, 620: 323, 622: 324, 623: 325, 628: 326, 629: 327, 630: 328, 633: 329, 635: 330, 636: 331, 637: 332, 639: 333, 640: 334, 641: 335, 642: 336, 645: 337, 647: 338, 648: 339, 649: 340, 652: 341, 653: 342, 655: 343, 656: 344, 657: 345, 658: 346, 659: 347, 660: 348, 661: 349, 662: 350, 663: 351, 665: 352, 666: 353, 667: 354, 669: 355, 670: 356, 673: 357, 674: 358, 676: 359, 677: 360, 681: 361, 682: 362, 683: 363, 685: 364, 688: 365, 689: 366, 696: 367, 697: 368, 700: 369, 701: 370, 702: 371, 703: 372, 704: 373, 705: 374, 706: 375, 707: 376, 708: 377, 709: 378, 711: 379, 712: 380, 714: 381, 718: 382, 724: 383, 726: 384, 729: 385, 730: 386, 733: 387, 734: 388, 738: 389, 739: 390, 741: 391, 742: 392, 744: 393, 748: 394, 749: 395, 752: 396, 754: 397, 755: 398, 757: 399, 759: 400, 760: 401, 761: 402, 762: 403, 765: 404, 766: 405, 767: 406, 772: 407, 773: 408, 774: 409, 779: 410, 780: 411, 781: 412, 782: 413, 785: 414, 787: 415, 788: 416, 792: 417, 793: 418, 795: 419, 796: 420, 802: 421, 803: 422, 806: 423, 809: 424, 810: 425, 814: 426, 816: 427, 818: 428, 820: 429, 821: 430, 823: 431, 826: 432, 828: 433, 830: 434, 832: 435, 833: 436, 837: 437, 839: 438, 840: 439, 842: 440, 843: 441, 844: 442, 848: 443, 849: 444, 850: 445, 851: 446, 854: 447, 855: 448, 857: 449, 859: 450, 862: 451, 863: 452, 864: 453, 868: 454, 871: 455, 872: 456, 875: 457, 876: 458, 879: 459, 882: 460, 883: 461, 885: 462, 886: 463, 887: 464, 890: 465, 891: 466, 892: 467, 893: 468, 894: 469, 895: 470, 896: 471, 898: 472, 900: 473, 901: 474, 902: 475, 903: 476, 904: 477, 905: 478, 907: 479, 914: 480, 915: 481, 917: 482, 919: 483, 926: 484, 930: 485, 933: 486, 936: 487, 939: 488, 940: 489, 941: 490, 942: 491, 943: 492, 945: 493, 946: 494, 947: 495, 948: 496, 952: 497, 953: 498, 954: 499, 955: 500, 957: 501, 958: 502, 961: 503, 962: 504, 963: 505, 967: 506, 969: 507, 970: 508, 971: 509, 972: 510, 973: 511, 975: 512, 976: 513, 979: 514, 982: 515, 984: 516, 986: 517, 987: 518, 988: 519, 990: 520, 991: 521, 992: 522, 994: 523, 995: 524, 997: 525, 998: 526, 999: 527, 1000: 528, 1001: 529, 1002: 530, 1003: 531, 1004: 532, 1007: 533, 1010: 534, 1011: 535, 1012: 536, 1017: 537, 1018: 538, 1019: 539, 1023: 540, 1025: 541, 1027: 542, 1030: 543, 1031: 544, 1032: 545, 1033: 546, 1038: 547, 1039: 548, 1041: 549, 1045: 550, 1048: 551, 1049: 552, 1051: 553, 1052: 554, 1055: 555, 1056: 556, 1066: 557, 1071: 558, 1072: 559, 1075: 560, 1076: 561, 1078: 562, 1079: 563, 1080: 564, 1081: 565, 1086: 566, 1088: 567, 1091: 568, 1093: 569, 1094: 570, 1096: 571, 1097: 572, 1098: 573, 1099: 574, 1100: 575, 1101: 576, 1106: 577, 1107: 578, 1110: 579, 1111: 580, 1112: 581, 1113: 582, 1114: 583, 1115: 584, 1116: 585, 1117: 586, 1123: 587, 1124: 588, 1126: 589, 1127: 590, 1129: 591, 1132: 592, 1134: 593, 1135: 594, 1138: 595, 1140: 596, 1142: 597, 1152: 
598, 1157: 599, 1158: 600, 1159: 601, 1162: 602, 1165: 603, 1167: 604, 1168: 605, 1169: 606, 1173: 607, 1176: 608, 1177: 609, 1178: 610, 1179: 611, 1189: 612, 1193: 613, 1197: 614, 1198: 615, 1200: 616, 1201: 617, 1204: 618, 1206: 619, 1213: 620, 1217: 621, 1218: 622, 1219: 623, 1220: 624, 1227: 625, 1230: 626, 1231: 627, 1232: 628, 1234: 629, 1235: 630, 1237: 631, 1238: 632, 1240: 633, 1242: 634, 1243: 635, 1244: 636, 1250: 637, 1252: 638, 1253: 639, 1254: 640, 1257: 641, 1258: 642, 1260: 643, 1261: 644, 1263: 645, 1266: 646, 1269: 647, 1275: 648, 1278: 649, 1281: 650, 1286: 651, 1289: 652, 1291: 653, 1292: 654, 1294: 655, 1295: 656, 1296: 657, 1297: 658, 1300: 659, 1303: 660, 1304: 661, 1309: 662, 1313: 663, 1315: 664, 1316: 665, 1318: 666, 1320: 667, 1321: 668, 1325: 669, 1326: 670, 1327: 671, 1330: 672, 1331: 673, 1332: 674, 1334: 675, 1335: 676, 1336: 677, 1338: 678, 1339: 679, 1341: 680, 1343: 681, 1344: 682, 1346: 683, 1350: 684, 1353: 685, 1358: 686, 1363: 687, 1364: 688, 1365: 689, 1368: 690, 1372: 691, 1373: 692, 1379: 693, 1380: 694, 1381: 695, 1385: 696, 1386: 697, 1389: 698, 1391: 699, 1392: 700, 1393: 701, 1400: 702, 1402: 703, 1404: 704, 1405: 705, 1406: 706, 1407: 707, 1408: 708, 1409: 709, 1411: 710, 1415: 711, 1420: 712, 1421: 713, 1422: 714, 1426: 715, 1427: 716, 1428: 717, 1430: 718, 1432: 719, 1433: 720, 1434: 721, 1437: 722, 1442: 723, 1443: 724, 1445: 725, 1447: 726, 1449: 727, 1451: 728, 1453: 729, 1454: 730, 1455: 731, 1458: 732, 1463: 733, 1464: 734, 1466: 735, 1467: 736, 1469: 737, 1470: 738, 1471: 739, 1473: 740, 1474: 741, 1475: 742, 1479: 743, 1480: 744, 1487: 745, 1489: 746, 1492: 747, 1495: 748, 1496: 749, 1500: 750} 12 | market1501_test_map = {0: 0, 1: 1, 3: 2, 4: 3, 5: 4, 6: 5, 8: 6, 9: 7, 13: 8, 14: 9, 15: 10, 16: 11, 17: 12, 18: 13, 19: 14, 21: 15, 24: 16, 25: 17, 26: 18, 29: 19, 31: 20, 33: 21, 34: 22, 36: 23, 38: 24, 39: 25, 40: 26, 41: 27, 44: 28, 45: 29, 49: 30, 50: 31, 51: 32, 54: 33, 55: 34, 58: 35, 60: 36, 61: 37, 62: 38, 63: 39, 66: 40, 71: 41, 72: 42, 73: 43, 74: 44, 75: 45, 78: 46, 80: 47, 83: 48, 85: 49, 87: 50, 89: 51, 91: 52, 92: 53, 94: 54, 96: 55, 101: 56, 102: 57, 103: 58, 109: 59, 112: 60, 113: 61, 119: 62, 120: 63, 124: 64, 126: 65, 128: 66, 130: 67, 131: 68, 133: 69, 137: 70, 138: 71, 144: 72, 145: 73, 146: 74, 147: 75, 152: 76, 153: 77, 154: 78, 155: 79, 156: 80, 157: 81, 161: 82, 163: 83, 165: 84, 168: 85, 170: 86, 171: 87, 174: 88, 182: 89, 183: 90, 186: 91, 187: 92, 188: 93, 189: 94, 191: 95, 192: 96, 194: 97, 196: 98, 198: 99, 200: 100, 203: 101, 205: 102, 207: 103, 210: 104, 213: 105, 215: 106, 217: 107, 218: 108, 219: 109, 220: 110, 226: 111, 227: 112, 228: 113, 229: 114, 230: 115, 231: 116, 233: 117, 235: 118, 238: 119, 240: 120, 244: 121, 246: 122, 247: 123, 252: 124, 253: 125, 256: 126, 257: 127, 258: 128, 260: 129, 262: 130, 263: 131, 265: 132, 267: 133, 270: 134, 271: 135, 274: 136, 275: 137, 278: 138, 280: 139, 283: 140, 284: 141, 285: 142, 286: 143, 288: 144, 289: 145, 290: 146, 291: 147, 292: 148, 293: 149, 294: 150, 295: 151, 300: 152, 302: 153, 304: 154, 305: 155, 310: 156, 311: 157, 312: 158, 315: 159, 316: 160, 319: 161, 320: 162, 322: 163, 329: 164, 330: 165, 334: 166, 336: 167, 337: 168, 342: 169, 343: 170, 345: 171, 346: 172, 351: 173, 353: 174, 355: 175, 356: 176, 360: 177, 362: 178, 363: 179, 364: 180, 365: 181, 366: 182, 372: 183, 373: 184, 378: 185, 381: 186, 387: 187, 388: 188, 391: 189, 395: 190, 396: 191, 400: 192, 401: 193, 405: 194, 406: 195, 412: 196, 416: 197, 417: 198, 418: 199, 422: 200, 425: 201, 
426: 202, 428: 203, 431: 204, 436: 205, 438: 206, 439: 207, 440: 208, 443: 209, 447: 210, 448: 211, 452: 212, 453: 213, 454: 214, 455: 215, 458: 216, 460: 217, 461: 218, 462: 219, 463: 220, 465: 221, 467: 222, 469: 223, 471: 224, 473: 225, 474: 226, 476: 227, 478: 228, 479: 229, 483: 230, 487: 231, 488: 232, 489: 233, 490: 234, 492: 235, 493: 236, 495: 237, 497: 238, 498: 239, 501: 240, 502: 241, 504: 242, 505: 243, 506: 244, 507: 245, 510: 246, 511: 247, 512: 248, 514: 249, 520: 250, 521: 251, 523: 252, 526: 253, 527: 254, 530: 255, 531: 256, 532: 257, 533: 258, 535: 259, 538: 260, 541: 261, 542: 262, 543: 263, 544: 264, 548: 265, 550: 266, 553: 267, 559: 268, 560: 269, 561: 270, 562: 271, 567: 272, 568: 273, 569: 274, 574: 275, 576: 276, 577: 277, 578: 278, 580: 279, 582: 280, 583: 281, 585: 282, 587: 283, 590: 284, 591: 285, 595: 286, 598: 287, 600: 288, 601: 289, 602: 290, 607: 291, 608: 292, 609: 293, 610: 294, 617: 295, 618: 296, 621: 297, 624: 298, 625: 299, 626: 300, 627: 301, 631: 302, 632: 303, 634: 304, 638: 305, 643: 306, 644: 307, 646: 308, 650: 309, 651: 310, 654: 311, 664: 312, 668: 313, 671: 314, 672: 315, 675: 316, 678: 317, 679: 318, 680: 319, 684: 320, 686: 321, 687: 322, 690: 323, 691: 324, 692: 325, 693: 326, 694: 327, 695: 328, 698: 329, 699: 330, 710: 331, 713: 332, 715: 333, 716: 334, 717: 335, 719: 336, 720: 337, 721: 338, 722: 339, 723: 340, 725: 341, 727: 342, 728: 343, 731: 344, 732: 345, 735: 346, 736: 347, 737: 348, 740: 349, 743: 350, 745: 351, 746: 352, 747: 353, 750: 354, 751: 355, 753: 356, 756: 357, 758: 358, 763: 359, 764: 360, 768: 361, 769: 362, 770: 363, 771: 364, 775: 365, 776: 366, 777: 367, 778: 368, 783: 369, 784: 370, 786: 371, 789: 372, 790: 373, 791: 374, 794: 375, 797: 376, 798: 377, 799: 378, 800: 379, 801: 380, 804: 381, 805: 382, 807: 383, 808: 384, 811: 385, 812: 386, 813: 387, 815: 388, 817: 389, 819: 390, 822: 391, 824: 392, 825: 393, 827: 394, 829: 395, 831: 396, 834: 397, 835: 398, 836: 399, 838: 400, 841: 401, 845: 402, 846: 403, 847: 404, 852: 405, 853: 406, 856: 407, 858: 408, 860: 409, 861: 410, 865: 411, 866: 412, 867: 413, 869: 414, 870: 415, 873: 416, 874: 417, 877: 418, 878: 419, 880: 420, 881: 421, 884: 422, 888: 423, 889: 424, 897: 425, 899: 426, 906: 427, 908: 428, 909: 429, 910: 430, 911: 431, 912: 432, 913: 433, 916: 434, 918: 435, 920: 436, 921: 437, 922: 438, 923: 439, 924: 440, 925: 441, 927: 442, 928: 443, 929: 444, 931: 445, 932: 446, 934: 447, 935: 448, 937: 449, 938: 450, 944: 451, 949: 452, 950: 453, 951: 454, 956: 455, 959: 456, 960: 457, 964: 458, 965: 459, 966: 460, 968: 461, 974: 462, 977: 463, 978: 464, 980: 465, 981: 466, 983: 467, 985: 468, 989: 469, 993: 470, 996: 471, 1005: 472, 1006: 473, 1008: 474, 1009: 475, 1013: 476, 1014: 477, 1015: 478, 1016: 479, 1020: 480, 1021: 481, 1022: 482, 1024: 483, 1026: 484, 1028: 485, 1029: 486, 1034: 487, 1035: 488, 1036: 489, 1037: 490, 1040: 491, 1042: 492, 1043: 493, 1044: 494, 1046: 495, 1047: 496, 1050: 497, 1053: 498, 1054: 499, 1057: 500, 1058: 501, 1059: 502, 1060: 503, 1061: 504, 1062: 505, 1063: 506, 1064: 507, 1065: 508, 1067: 509, 1068: 510, 1069: 511, 1070: 512, 1073: 513, 1074: 514, 1077: 515, 1082: 516, 1083: 517, 1084: 518, 1085: 519, 1087: 520, 1089: 521, 1090: 522, 1092: 523, 1095: 524, 1102: 525, 1103: 526, 1104: 527, 1105: 528, 1108: 529, 1109: 530, 1118: 531, 1119: 532, 1120: 533, 1121: 534, 1122: 535, 1125: 536, 1128: 537, 1130: 538, 1131: 539, 1133: 540, 1136: 541, 1137: 542, 1139: 543, 1141: 544, 1143: 545, 1144: 546, 1145: 547, 1146: 548, 1147: 
549, 1148: 550, 1149: 551, 1150: 552, 1151: 553, 1153: 554, 1154: 555, 1155: 556, 1156: 557, 1160: 558, 1161: 559, 1163: 560, 1164: 561, 1166: 562, 1170: 563, 1171: 564, 1172: 565, 1174: 566, 1175: 567, 1180: 568, 1181: 569, 1182: 570, 1183: 571, 1184: 572, 1185: 573, 1186: 574, 1187: 575, 1188: 576, 1190: 577, 1191: 578, 1192: 579, 1194: 580, 1195: 581, 1196: 582, 1199: 583, 1202: 584, 1203: 585, 1205: 586, 1207: 587, 1208: 588, 1209: 589, 1210: 590, 1211: 591, 1212: 592, 1214: 593, 1215: 594, 1216: 595, 1221: 596, 1222: 597, 1223: 598, 1224: 599, 1225: 600, 1226: 601, 1228: 602, 1229: 603, 1233: 604, 1236: 605, 1239: 606, 1241: 607, 1245: 608, 1246: 609, 1247: 610, 1248: 611, 1249: 612, 1251: 613, 1255: 614, 1256: 615, 1259: 616, 1262: 617, 1264: 618, 1265: 619, 1267: 620, 1268: 621, 1270: 622, 1271: 623, 1272: 624, 1273: 625, 1274: 626, 1276: 627, 1277: 628, 1279: 629, 1280: 630, 1282: 631, 1283: 632, 1284: 633, 1285: 634, 1287: 635, 1288: 636, 1290: 637, 1293: 638, 1298: 639, 1299: 640, 1301: 641, 1302: 642, 1305: 643, 1306: 644, 1307: 645, 1308: 646, 1310: 647, 1311: 648, 1312: 649, 1314: 650, 1317: 651, 1319: 652, 1322: 653, 1323: 654, 1324: 655, 1328: 656, 1329: 657, 1333: 658, 1337: 659, 1340: 660, 1342: 661, 1345: 662, 1347: 663, 1348: 664, 1349: 665, 1351: 666, 1352: 667, 1354: 668, 1355: 669, 1356: 670, 1357: 671, 1359: 672, 1360: 673, 1361: 674, 1362: 675, 1366: 676, 1367: 677, 1369: 678, 1370: 679, 1371: 680, 1374: 681, 1375: 682, 1376: 683, 1377: 684, 1378: 685, 1382: 686, 1383: 687, 1384: 688, 1387: 689, 1388: 690, 1390: 691, 1394: 692, 1395: 693, 1396: 694, 1397: 695, 1398: 696, 1399: 697, 1401: 698, 1403: 699, 1410: 700, 1412: 701, 1413: 702, 1414: 703, 1416: 704, 1417: 705, 1418: 706, 1419: 707, 1423: 708, 1424: 709, 1425: 710, 1429: 711, 1431: 712, 1435: 713, 1436: 714, 1438: 715, 1439: 716, 1440: 717, 1441: 718, 1444: 719, 1446: 720, 1448: 721, 1450: 722, 1452: 723, 1456: 724, 1457: 725, 1459: 726, 1460: 727, 1461: 728, 1462: 729, 1465: 730, 1468: 731, 1472: 732, 1476: 733, 1477: 734, 1478: 735, 1481: 736, 1482: 737, 1483: 738, 1484: 739, 1485: 740, 1486: 741, 1488: 742, 1490: 743, 1491: 744, 1493: 745, 1494: 746, 1497: 747, 1498: 748, 1499: 749, 1501: 750} 13 | duke_train_map = {1: 0, 8: 1, 13: 2, 14: 3, 15: 4, 16: 5, 17: 6, 18: 7, 20: 8, 22: 9, 24: 10, 26: 11, 28: 12, 29: 13, 32: 14, 36: 15, 37: 16, 38: 17, 40: 18, 41: 19, 45: 20, 48: 21, 52: 22, 54: 23, 55: 24, 57: 25, 58: 26, 59: 27, 60: 28, 62: 29, 63: 30, 64: 31, 65: 32, 67: 33, 70: 34, 71: 35, 73: 36, 74: 37, 81: 38, 82: 39, 84: 40, 85: 41, 87: 42, 93: 43, 94: 44, 96: 45, 100: 46, 102: 47, 104: 48, 105: 49, 108: 50, 110: 51, 113: 52, 116: 53, 120: 54, 121: 55, 124: 56, 129: 57, 130: 58, 131: 59, 132: 60, 133: 61, 138: 62, 139: 63, 144: 64, 146: 65, 148: 66, 152: 67, 153: 68, 154: 69, 155: 70, 156: 71, 157: 72, 160: 73, 161: 74, 165: 75, 166: 76, 168: 77, 172: 78, 173: 79, 176: 80, 177: 81, 178: 82, 179: 83, 182: 84, 185: 85, 189: 86, 190: 87, 191: 88, 193: 89, 195: 90, 196: 91, 198: 92, 202: 93, 203: 94, 208: 95, 209: 96, 216: 97, 217: 98, 222: 99, 224: 100, 225: 101, 226: 102, 227: 103, 228: 104, 231: 105, 232: 106, 233: 107, 234: 108, 236: 109, 242: 110, 245: 111, 246: 112, 248: 113, 250: 114, 252: 115, 255: 116, 258: 117, 259: 118, 263: 119, 265: 120, 271: 121, 278: 122, 280: 123, 281: 124, 282: 125, 283: 126, 284: 127, 286: 128, 289: 129, 290: 130, 291: 131, 296: 132, 297: 133, 306: 134, 307: 135, 308: 136, 309: 137, 310: 138, 312: 139, 317: 140, 318: 141, 319: 142, 320: 143, 322: 144, 325: 145, 326: 146, 
327: 147, 328: 148, 330: 149, 331: 150, 333: 151, 335: 152, 336: 153, 338: 154, 339: 155, 343: 156, 345: 157, 348: 158, 349: 159, 357: 160, 362: 161, 365: 162, 366: 163, 368: 164, 370: 165, 373: 166, 374: 167, 382: 168, 383: 169, 384: 170, 385: 171, 387: 172, 388: 173, 392: 174, 393: 175, 396: 176, 397: 177, 398: 178, 401: 179, 402: 180, 403: 181, 404: 182, 406: 183, 407: 184, 411: 185, 413: 186, 417: 187, 419: 188, 421: 189, 422: 190, 423: 191, 424: 192, 425: 193, 430: 194, 432: 195, 435: 196, 436: 197, 437: 198, 438: 199, 439: 200, 440: 201, 441: 202, 443: 203, 445: 204, 446: 205, 447: 206, 448: 207, 450: 208, 452: 209, 454: 210, 456: 211, 458: 212, 463: 213, 464: 214, 465: 215, 472: 216, 473: 217, 474: 218, 478: 219, 480: 220, 481: 221, 483: 222, 485: 223, 487: 224, 489: 225, 490: 226, 491: 227, 493: 228, 496: 229, 498: 230, 502: 231, 504: 232, 505: 233, 507: 234, 510: 235, 511: 236, 512: 237, 518: 238, 519: 239, 520: 240, 521: 241, 522: 242, 524: 243, 526: 244, 528: 245, 530: 246, 531: 247, 532: 248, 534: 249, 536: 250, 544: 251, 545: 252, 546: 253, 547: 254, 548: 255, 550: 256, 556: 257, 557: 258, 558: 259, 559: 260, 561: 261, 562: 262, 563: 263, 564: 264, 566: 265, 568: 266, 569: 267, 572: 268, 573: 269, 574: 270, 575: 271, 578: 272, 579: 273, 582: 274, 585: 275, 588: 276, 589: 277, 595: 278, 598: 279, 600: 280, 602: 281, 604: 282, 606: 283, 607: 284, 610: 285, 613: 286, 614: 287, 615: 288, 616: 289, 617: 290, 618: 291, 619: 292, 622: 293, 623: 294, 624: 295, 628: 296, 630: 297, 633: 298, 634: 299, 636: 300, 637: 301, 638: 302, 639: 303, 640: 304, 642: 305, 645: 306, 650: 307, 653: 308, 655: 309, 657: 310, 658: 311, 659: 312, 660: 313, 662: 314, 664: 315, 665: 316, 666: 317, 667: 318, 668: 319, 669: 320, 670: 321, 671: 322, 673: 323, 675: 324, 677: 325, 679: 326, 682: 327, 684: 328, 687: 329, 689: 330, 692: 331, 696: 332, 697: 333, 704: 334, 708: 335, 710: 336, 713: 337, 714: 338, 715: 339, 716: 340, 719: 341, 720: 342, 721: 343, 723: 344, 724: 345, 725: 346, 727: 347, 728: 348, 730: 349, 731: 350, 732: 351, 735: 352, 737: 353, 739: 354, 740: 355, 744: 356, 745: 357, 747: 358, 751: 359, 753: 360, 759: 361, 761: 362, 762: 363, 764: 364, 767: 365, 768: 366, 770: 367, 771: 368, 774: 369, 776: 370, 778: 371, 779: 372, 780: 373, 782: 374, 783: 375, 784: 376, 785: 377, 789: 378, 793: 379, 795: 380, 796: 381, 797: 382, 798: 383, 799: 384, 802: 385, 805: 386, 808: 387, 811: 388, 813: 389, 814: 390, 815: 391, 817: 392, 819: 393, 821: 394, 825: 395, 829: 396, 831: 397, 835: 398, 836: 399, 837: 400, 839: 401, 842: 402, 843: 403, 844: 404, 848: 405, 855: 406, 859: 407, 860: 408, 883: 409, 1034: 410, 1120: 411, 1174: 412, 1239: 413, 1240: 414, 1242: 415, 1246: 416, 1248: 417, 1252: 418, 1259: 419, 1312: 420, 1333: 421, 1358: 422, 1363: 423, 1396: 424, 1397: 425, 1438: 426, 1471: 427, 1472: 428, 1501: 429, 1524: 430, 1526: 431, 1532: 432, 1542: 433, 1559: 434, 1562: 435, 1565: 436, 1587: 437, 1589: 438, 1614: 439, 1631: 440, 1636: 441, 1665: 442, 1671: 443, 1672: 444, 1693: 445, 1696: 446, 1716: 447, 1729: 448, 1732: 449, 1746: 450, 1756: 451, 1760: 452, 1767: 453, 1776: 454, 1786: 455, 1794: 456, 1812: 457, 1827: 458, 1830: 459, 1874: 460, 1879: 461, 1911: 462, 1953: 463, 1954: 464, 1973: 465, 1988: 466, 1989: 467, 1996: 468, 2004: 469, 2016: 470, 2032: 471, 2036: 472, 2044: 473, 2058: 474, 2408: 475, 2410: 476, 2420: 477, 2421: 478, 2422: 479, 2432: 480, 2435: 481, 2436: 482, 2446: 483, 2464: 484, 2469: 485, 2496: 486, 2515: 487, 2520: 488, 2529: 489, 2542: 490, 2558: 491, 2581: 492, 2597: 493, 
2598: 494, 2642: 495, 2726: 496, 2735: 497, 2742: 498, 2748: 499, 2770: 500, 2953: 501, 3058: 502, 3253: 503, 3261: 504, 3344: 505, 3362: 506, 3363: 507, 3368: 508, 3370: 509, 3371: 510, 3451: 511, 3516: 512, 3520: 513, 3545: 514, 3546: 515, 3555: 516, 3582: 517, 3614: 518, 3619: 519, 3621: 520, 3680: 521, 3688: 522, 3715: 523, 3716: 524, 3732: 525, 3753: 526, 3758: 527, 3765: 528, 3776: 529, 3782: 530, 4061: 531, 4063: 532, 4064: 533, 4068: 534, 4076: 535, 4084: 536, 4096: 537, 4104: 538, 4105: 539, 4107: 540, 4108: 541, 4111: 542, 4115: 543, 4120: 544, 4132: 545, 4133: 546, 4135: 547, 4136: 548, 4140: 549, 4145: 550, 4151: 551, 4160: 552, 4164: 553, 4167: 554, 4180: 555, 4184: 556, 4186: 557, 4187: 558, 4192: 559, 4195: 560, 4198: 561, 4199: 562, 4201: 563, 4206: 564, 4208: 565, 4209: 566, 4211: 567, 4212: 568, 4215: 569, 4216: 570, 4225: 571, 4235: 572, 4237: 573, 4238: 574, 4243: 575, 4250: 576, 4258: 577, 4260: 578, 4261: 579, 4263: 580, 4275: 581, 4276: 582, 4277: 583, 4278: 584, 4279: 585, 4286: 586, 4288: 587, 4292: 588, 4301: 589, 4306: 590, 4307: 591, 4317: 592, 4323: 593, 4330: 594, 4333: 595, 4336: 596, 4344: 597, 4355: 598, 4362: 599, 4365: 600, 4387: 601, 4389: 602, 4391: 603, 4393: 604, 4406: 605, 4410: 606, 4412: 607, 4415: 608, 4417: 609, 4423: 610, 4425: 611, 4426: 612, 4430: 613, 4431: 614, 4432: 615, 4438: 616, 4445: 617, 4448: 618, 4451: 619, 4453: 620, 4461: 621, 4462: 622, 4463: 623, 4464: 624, 4472: 625, 4481: 626, 4484: 627, 4487: 628, 4488: 629, 4490: 630, 4492: 631, 4493: 632, 4495: 633, 4499: 634, 4501: 635, 4502: 636, 4509: 637, 4512: 638, 4513: 639, 4515: 640, 4520: 641, 4526: 642, 4527: 643, 4528: 644, 4532: 645, 4537: 646, 4538: 647, 4548: 648, 4551: 649, 4553: 650, 4555: 651, 4556: 652, 4567: 653, 4577: 654, 4583: 655, 4590: 656, 4597: 657, 4602: 658, 4618: 659, 4624: 660, 4625: 661, 4627: 662, 4629: 663, 4631: 664, 4656: 665, 4664: 666, 4667: 667, 4679: 668, 4683: 669, 4684: 670, 4685: 671, 4689: 672, 4690: 673, 4694: 674, 4707: 675, 4721: 676, 4728: 677, 4733: 678, 4740: 679, 4741: 680, 4751: 681, 4767: 682, 4768: 683, 4791: 684, 4796: 685, 4800: 686, 4802: 687, 4805: 688, 4810: 689, 4811: 690, 4812: 691, 4815: 692, 5251: 693, 5254: 694, 5258: 695, 5259: 696, 5339: 697, 5388: 698, 5398: 699, 7136: 700, 7140: 701} 14 | duke_test_map = {2: 0, 3: 1, 4: 2, 5: 3, 7: 4, 9: 5, 10: 6, 11: 7, 12: 8, 19: 9, 21: 10, 23: 11, 25: 12, 27: 13, 30: 14, 31: 15, 33: 16, 34: 17, 35: 18, 39: 19, 42: 20, 43: 21, 44: 22, 46: 23, 47: 24, 49: 25, 50: 26, 51: 27, 53: 28, 56: 29, 61: 30, 66: 31, 68: 32, 69: 33, 72: 34, 75: 35, 76: 36, 77: 37, 78: 38, 79: 39, 80: 40, 83: 41, 86: 42, 88: 43, 89: 44, 90: 45, 91: 46, 92: 47, 95: 48, 97: 49, 98: 50, 99: 51, 101: 52, 103: 53, 106: 54, 107: 55, 109: 56, 111: 57, 112: 58, 114: 59, 115: 60, 117: 61, 118: 62, 119: 63, 122: 64, 123: 65, 125: 66, 126: 67, 127: 68, 128: 69, 134: 70, 135: 71, 136: 72, 137: 73, 140: 74, 141: 75, 142: 76, 143: 77, 145: 78, 147: 79, 149: 80, 150: 81, 151: 82, 158: 83, 159: 84, 162: 85, 163: 86, 164: 87, 167: 88, 169: 89, 170: 90, 171: 91, 174: 92, 175: 93, 180: 94, 181: 95, 183: 96, 184: 97, 186: 98, 187: 99, 188: 100, 192: 101, 194: 102, 197: 103, 199: 104, 200: 105, 201: 106, 204: 107, 205: 108, 206: 109, 207: 110, 210: 111, 211: 112, 212: 113, 213: 114, 214: 115, 215: 116, 218: 117, 219: 118, 220: 119, 221: 120, 223: 121, 229: 122, 230: 123, 235: 124, 237: 125, 238: 126, 239: 127, 240: 128, 241: 129, 243: 130, 244: 131, 247: 132, 249: 133, 251: 134, 253: 135, 254: 136, 256: 137, 257: 138, 261: 139, 262: 140, 
264: 141, 266: 142, 267: 143, 268: 144, 269: 145, 270: 146, 272: 147, 273: 148, 274: 149, 275: 150, 276: 151, 277: 152, 279: 153, 285: 154, 287: 155, 288: 156, 292: 157, 293: 158, 294: 159, 295: 160, 298: 161, 299: 162, 300: 163, 301: 164, 302: 165, 303: 166, 304: 167, 305: 168, 311: 169, 313: 170, 314: 171, 315: 172, 316: 173, 321: 174, 323: 175, 324: 176, 329: 177, 332: 178, 334: 179, 337: 180, 340: 181, 341: 182, 342: 183, 344: 184, 346: 185, 347: 186, 350: 187, 351: 188, 352: 189, 353: 190, 354: 191, 355: 192, 356: 193, 358: 194, 359: 195, 360: 196, 361: 197, 363: 198, 364: 199, 367: 200, 369: 201, 371: 202, 372: 203, 375: 204, 376: 205, 377: 206, 378: 207, 379: 208, 380: 209, 381: 210, 386: 211, 389: 212, 390: 213, 391: 214, 394: 215, 395: 216, 400: 217, 405: 218, 408: 219, 409: 220, 410: 221, 412: 222, 414: 223, 415: 224, 416: 225, 418: 226, 420: 227, 426: 228, 427: 229, 428: 230, 429: 231, 431: 232, 433: 233, 434: 234, 442: 235, 444: 236, 449: 237, 451: 238, 453: 239, 455: 240, 457: 241, 459: 242, 460: 243, 461: 244, 462: 245, 466: 246, 467: 247, 468: 248, 469: 249, 470: 250, 471: 251, 479: 252, 482: 253, 484: 254, 486: 255, 488: 256, 492: 257, 494: 258, 495: 259, 497: 260, 499: 261, 500: 262, 501: 263, 503: 264, 506: 265, 508: 266, 509: 267, 513: 268, 514: 269, 515: 270, 516: 271, 517: 272, 523: 273, 525: 274, 527: 275, 529: 276, 533: 277, 535: 278, 537: 279, 538: 280, 539: 281, 540: 282, 541: 283, 542: 284, 543: 285, 549: 286, 551: 287, 552: 288, 553: 289, 554: 290, 555: 291, 560: 292, 565: 293, 567: 294, 570: 295, 571: 296, 576: 297, 577: 298, 580: 299, 581: 300, 583: 301, 584: 302, 586: 303, 587: 304, 590: 305, 591: 306, 592: 307, 593: 308, 594: 309, 596: 310, 597: 311, 599: 312, 601: 313, 603: 314, 605: 315, 608: 316, 609: 317, 611: 318, 612: 319, 620: 320, 621: 321, 625: 322, 626: 323, 627: 324, 629: 325, 631: 326, 632: 327, 635: 328, 641: 329, 643: 330, 644: 331, 646: 332, 647: 333, 648: 334, 649: 335, 651: 336, 652: 337, 654: 338, 656: 339, 661: 340, 663: 341, 672: 342, 674: 343, 676: 344, 678: 345, 680: 346, 681: 347, 683: 348, 685: 349, 686: 350, 688: 351, 690: 352, 691: 353, 693: 354, 694: 355, 695: 356, 698: 357, 699: 358, 700: 359, 701: 360, 702: 361, 703: 362, 705: 363, 706: 364, 707: 365, 709: 366, 711: 367, 712: 368, 717: 369, 718: 370, 722: 371, 726: 372, 729: 373, 733: 374, 734: 375, 736: 376, 738: 377, 741: 378, 742: 379, 743: 380, 746: 381, 748: 382, 749: 383, 750: 384, 752: 385, 754: 386, 755: 387, 756: 388, 757: 389, 758: 390, 760: 391, 763: 392, 765: 393, 766: 394, 769: 395, 772: 396, 773: 397, 775: 398, 777: 399, 781: 400, 786: 401, 787: 402, 788: 403, 790: 404, 791: 405, 792: 406, 794: 407, 800: 408, 803: 409, 804: 410, 806: 411, 807: 412, 809: 413, 810: 414, 812: 415, 816: 416, 818: 417, 820: 418, 823: 419, 824: 420, 826: 421, 828: 422, 830: 423, 832: 424, 834: 425, 838: 426, 840: 427, 845: 428, 846: 429, 847: 430, 849: 431, 850: 432, 851: 433, 852: 434, 853: 435, 854: 436, 856: 437, 857: 438, 858: 439, 863: 440, 864: 441, 884: 442, 1104: 443, 1108: 444, 1109: 445, 1110: 446, 1226: 447, 1228: 448, 1229: 449, 1233: 450, 1243: 451, 1244: 452, 1290: 453, 1297: 454, 1300: 455, 1307: 456, 1314: 457, 1328: 458, 1343: 459, 1346: 460, 1366: 461, 1382: 462, 1386: 463, 1391: 464, 1398: 465, 1403: 466, 1408: 467, 1421: 468, 1426: 469, 1440: 470, 1463: 471, 1467: 472, 1480: 473, 1486: 474, 1487: 475, 1489: 476, 1490: 477, 1518: 478, 1555: 479, 1584: 480, 1585: 481, 1586: 482, 1598: 483, 1601: 484, 1626: 485, 1635: 486, 1637: 487, 1642: 488, 1673: 489, 1682: 490, 1698: 
491, 1699: 492, 1723: 493, 1724: 494, 1725: 495, 1730: 496, 1737: 497, 1741: 498, 1745: 499, 1749: 500, 1750: 501, 1758: 502, 1759: 503, 1762: 504, 1766: 505, 1775: 506, 1782: 507, 1784: 508, 1785: 509, 1788: 510, 1790: 511, 1811: 512, 1834: 513, 1849: 514, 1893: 515, 1901: 516, 1922: 517, 1946: 518, 1949: 519, 2001: 520, 2012: 521, 2023: 522, 2053: 523, 2407: 524, 2429: 525, 2454: 526, 2470: 527, 2471: 528, 2479: 529, 2488: 530, 2495: 531, 2532: 532, 2556: 533, 2557: 534, 2573: 535, 2599: 536, 2724: 537, 2736: 538, 2754: 539, 2768: 540, 2772: 541, 2777: 542, 2942: 543, 2988: 544, 3201: 545, 3202: 546, 3259: 547, 3335: 548, 3353: 549, 3354: 550, 3358: 551, 3410: 552, 3446: 553, 3495: 554, 3515: 555, 3561: 556, 3609: 557, 3618: 558, 3638: 559, 3649: 560, 3664: 561, 3674: 562, 3731: 563, 3761: 564, 3763: 565, 4055: 566, 4057: 567, 4059: 568, 4060: 569, 4062: 570, 4065: 571, 4066: 572, 4070: 573, 4071: 574, 4072: 575, 4075: 576, 4079: 577, 4082: 578, 4099: 579, 4100: 580, 4102: 581, 4106: 582, 4110: 583, 4113: 584, 4114: 585, 4116: 586, 4117: 587, 4118: 588, 4119: 589, 4121: 590, 4128: 591, 4134: 592, 4141: 593, 4143: 594, 4144: 595, 4146: 596, 4147: 597, 4150: 598, 4152: 599, 4158: 600, 4159: 601, 4163: 602, 4169: 603, 4170: 604, 4174: 605, 4176: 606, 4177: 607, 4178: 608, 4185: 609, 4190: 610, 4197: 611, 4204: 612, 4205: 613, 4207: 614, 4210: 615, 4219: 616, 4221: 617, 4226: 618, 4227: 619, 4228: 620, 4230: 621, 4239: 622, 4245: 623, 4246: 624, 4247: 625, 4249: 626, 4254: 627, 4255: 628, 4256: 629, 4257: 630, 4271: 631, 4272: 632, 4274: 633, 4280: 634, 4284: 635, 4285: 636, 4309: 637, 4310: 638, 4315: 639, 4319: 640, 4321: 641, 4324: 642, 4326: 643, 4329: 644, 4331: 645, 4332: 646, 4334: 647, 4335: 648, 4337: 649, 4341: 650, 4349: 651, 4356: 652, 4361: 653, 4366: 654, 4372: 655, 4373: 656, 4374: 657, 4380: 658, 4386: 659, 4392: 660, 4398: 661, 4405: 662, 4411: 663, 4416: 664, 4419: 665, 4422: 666, 4427: 667, 4428: 668, 4433: 669, 4443: 670, 4447: 671, 4449: 672, 4452: 673, 4459: 674, 4460: 675, 4473: 676, 4477: 677, 4480: 678, 4483: 679, 4489: 680, 4494: 681, 4500: 682, 4503: 683, 4504: 684, 4508: 685, 4510: 686, 4511: 687, 4514: 688, 4519: 689, 4521: 690, 4540: 691, 4541: 692, 4547: 693, 4550: 694, 4558: 695, 4560: 696, 4563: 697, 4568: 698, 4572: 699, 4573: 700, 4580: 701, 4582: 702, 4587: 703, 4594: 704, 4596: 705, 4605: 706, 4606: 707, 4607: 708, 4609: 709, 4613: 710, 4622: 711, 4632: 712, 4633: 713, 4634: 714, 4639: 715, 4640: 716, 4646: 717, 4647: 718, 4654: 719, 4672: 720, 4681: 721, 4693: 722, 4695: 723, 4699: 724, 4708: 725, 4713: 726, 4717: 727, 4719: 728, 4723: 729, 4725: 730, 4726: 731, 4727: 732, 4729: 733, 4736: 734, 4739: 735, 4743: 736, 4750: 737, 4757: 738, 4758: 739, 4759: 740, 4760: 741, 4769: 742, 4772: 743, 4774: 744, 4779: 745, 4782: 746, 4789: 747, 4790: 748, 4804: 749, 4807: 750, 4808: 751, 4809: 752, 4817: 753, 4823: 754, 5249: 755, 5272: 756, 5333: 757, 5358: 758, 5474: 759, 5587: 760, 5599: 761, 5842: 762, 5849: 763, 5855: 764, 5856: 765, 5860: 766, 5867: 767, 5876: 768, 5877: 769, 5887: 770, 5889: 771, 5899: 772, 5904: 773, 5905: 774, 5906: 775, 5907: 776, 5910: 777, 5911: 778, 5920: 779, 5921: 780, 5922: 781, 5924: 782, 5927: 783, 5937: 784, 5939: 785, 5940: 786, 5941: 787, 5943: 788, 5947: 789, 5948: 790, 5949: 791, 5951: 792, 5952: 793, 5966: 794, 5970: 795, 5971: 796, 5972: 797, 5973: 798, 5974: 799, 5975: 800, 5977: 801, 5982: 802, 5985: 803, 5994: 804, 6008: 805, 6019: 806, 6031: 807, 6040: 808, 6046: 809, 6048: 810, 6049: 811, 6050: 812, 6051: 813, 6054: 
814, 6056: 815, 6058: 816, 6059: 817, 6063: 818, 6068: 819, 6070: 820, 6071: 821, 6072: 822, 6073: 823, 6074: 824, 6076: 825, 6077: 826, 6084: 827, 6087: 828, 6088: 829, 6091: 830, 6093: 831, 6094: 832, 6097: 833, 6100: 834, 6101: 835, 6102: 836, 6103: 837, 6105: 838, 6109: 839, 6110: 840, 6111: 841, 6112: 842, 6115: 843, 6117: 844, 6119: 845, 6122: 846, 6123: 847, 6134: 848, 6136: 849, 6137: 850, 6139: 851, 6140: 852, 6143: 853, 6146: 854, 6147: 855, 6148: 856, 6151: 857, 6155: 858, 6156: 859, 6158: 860, 6161: 861, 6164: 862, 6166: 863, 6172: 864, 6176: 865, 6178: 866, 6179: 867, 6180: 868, 6185: 869, 6188: 870, 6189: 871, 6191: 872, 6195: 873, 6196: 874, 6198: 875, 6199: 876, 6202: 877, 6204: 878, 6205: 879, 6208: 880, 6210: 881, 6212: 882, 6213: 883, 6214: 884, 6215: 885, 6216: 886, 6219: 887, 6220: 888, 6223: 889, 6224: 890, 6225: 891, 6227: 892, 6230: 893, 6235: 894, 6236: 895, 6244: 896, 6246: 897, 6247: 898, 6252: 899, 6253: 900, 6255: 901, 6257: 902, 6258: 903, 6259: 904, 6262: 905, 6263: 906, 6264: 907, 6269: 908, 6271: 909, 6277: 910, 6279: 911, 6281: 912, 6285: 913, 6287: 914, 6290: 915, 6291: 916, 6296: 917, 6297: 918, 6299: 919, 6301: 920, 6319: 921, 6320: 922, 6328: 923, 6331: 924, 6337: 925, 6338: 926, 6339: 927, 6340: 928, 6342: 929, 6344: 930, 6345: 931, 6347: 932, 6348: 933, 6351: 934, 6352: 935, 6353: 936, 6355: 937, 6356: 938, 6357: 939, 6359: 940, 6362: 941, 6365: 942, 6366: 943, 6367: 944, 6368: 945, 6369: 946, 6370: 947, 6371: 948, 6376: 949, 6377: 950, 6389: 951, 6391: 952, 6393: 953, 6396: 954, 6397: 955, 6398: 956, 6399: 957, 6400: 958, 6402: 959, 6403: 960, 6406: 961, 6407: 962, 6408: 963, 6410: 964, 6412: 965, 6414: 966, 6415: 967, 6416: 968, 6422: 969, 6423: 970, 6429: 971, 6433: 972, 6439: 973, 6440: 974, 6441: 975, 6446: 976, 6447: 977, 6448: 978, 6449: 979, 6452: 980, 6459: 981, 6464: 982, 6465: 983, 6474: 984, 6476: 985, 6479: 986, 6481: 987, 6482: 988, 6483: 989, 6486: 990, 6489: 991, 6494: 992, 6499: 993, 6500: 994, 6502: 995, 6503: 996, 6504: 997, 6505: 998, 6506: 999, 6507: 1000, 6509: 1001, 6517: 1002, 6522: 1003, 6524: 1004, 6528: 1005, 6530: 1006, 6531: 1007, 6533: 1008, 6535: 1009, 6539: 1010, 6540: 1011, 6543: 1012, 6545: 1013, 6546: 1014, 6547: 1015, 6548: 1016, 6549: 1017, 6550: 1018, 6552: 1019, 6558: 1020, 6559: 1021, 6566: 1022, 6569: 1023, 6571: 1024, 6577: 1025, 6578: 1026, 6585: 1027, 6586: 1028, 6592: 1029, 6595: 1030, 6596: 1031, 6602: 1032, 6603: 1033, 6605: 1034, 6606: 1035, 6607: 1036, 6609: 1037, 6610: 1038, 6611: 1039, 6614: 1040, 6615: 1041, 6616: 1042, 6617: 1043, 6621: 1044, 6636: 1045, 6637: 1046, 6639: 1047, 6641: 1048, 6648: 1049, 6649: 1050, 6651: 1051, 6660: 1052, 6661: 1053, 6662: 1054, 6665: 1055, 6668: 1056, 6669: 1057, 6670: 1058, 6671: 1059, 6672: 1060, 6673: 1061, 6674: 1062, 6676: 1063, 6679: 1064, 6680: 1065, 6685: 1066, 6686: 1067, 6688: 1068, 6689: 1069, 6690: 1070, 6694: 1071, 6695: 1072, 6697: 1073, 6698: 1074, 6699: 1075, 6700: 1076, 6704: 1077, 6708: 1078, 6709: 1079, 6710: 1080, 6717: 1081, 6722: 1082, 6725: 1083, 6726: 1084, 6732: 1085, 6741: 1086, 6744: 1087, 6745: 1088, 6755: 1089, 6758: 1090, 6759: 1091, 6763: 1092, 6764: 1093, 6767: 1094, 6770: 1095, 6776: 1096, 6777: 1097, 6778: 1098, 6779: 1099, 6785: 1100, 6788: 1101, 6789: 1102, 6794: 1103, 6799: 1104, 6804: 1105, 6805: 1106, 6813: 1107, 7138: 1108, 7139: 1109} 15 | 16 | base_opt = {'workers': 4, 17 | 'split_id': 0, 18 | 'cuhk03_labeled': False, 19 | 'cuhk03_classic_split': False, 20 | 'use_metric_cuhk03': False, 21 | 'num_instances': 4, 22 | 
            'ReID_factor': 10, }
23 | 
24 | def get_opts(name):
25 |     # 1.
26 |     if name == 'ide':
27 |         base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0)])
28 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
29 | 
30 |     elif name == 'densenet121':
31 |         base_opt['transform_train'] = T.Compose([T.Random2DTranslation(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
32 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
33 | 
34 |     elif name == 'mudeep':
35 |         base_opt['transform_train'] = T.Compose([T.Random2DTranslation(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
36 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
37 | 
38 |     # 2.
39 |     elif name == 'aligned':
40 |         base_opt['transform_train'] = T.Compose([T.Random2DTranslation(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
41 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
42 | 
43 |     elif name == 'pcb':
44 |         base_opt['transform_train'] = T.Compose([T.Resize((384,192), interpolation=3), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(Imagenet_mean, Imagenet_stddev)])
45 |         base_opt['transform_test'] = T.Compose([T.Resize((384,192), interpolation=3), T.ToTensor(), T.Normalize(Imagenet_mean, Imagenet_stddev)])
46 |         base_opt['ReID_factor'] = 2
47 |         base_opt['workers'] = 16
48 | 
49 |     elif name == 'hacnn':
50 |         base_opt['transform_train'] = T.Compose([T.Random2DTranslation(160, 64), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
51 |         base_opt['transform_test'] = T.Compose([T.Resize((160, 64)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
52 | 
53 |     # 3.
54 |     elif name == 'cam':
55 |         base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0.5)])
56 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
57 | 
58 |     elif name == 'lsro':
59 |         base_opt['transform_train'] = T.Compose([T.Resize(144, interpolation=3), T.RandomCrop((256,128)), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
60 |         base_opt['transform_test'] = T.Compose([T.Resize((288,144), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
61 | 
62 |     elif name == 'hhl':
63 |         base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0)])
64 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
65 | 
66 |     elif name == 'spgan':
67 |         base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0)])
68 |         base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
69 | 
70 |     return base_opt
--------------------------------------------------------------------------------
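How get_opts() is typically consumed: each branch overwrites the train/test transforms in the shared base_opt dictionary (and, for 'pcb', also its ReID_factor and worker count); the three numbered comment groups appear to separate backbone-only models (ide, densenet121, mudeep), part/alignment-based models (aligned, pcb, hacnn), and augmentation or domain-transfer models (cam, lsro, hhl, spgan). The Python sketch below is illustrative only and not part of the repository: the import path for opts and the dummy PIL image are assumptions standing in for the real data loading done in train.py and util/dataset_loader.py.

# Illustrative sketch (assumed usage, not repository code): fetch per-model
# options and apply the resulting transforms to a dummy person crop.
from PIL import Image

from opts import get_opts  # assumes opts.py is importable from the repo root

opt = get_opts('pcb')                  # 'pcb' expects 384x192 inputs and ReID_factor = 2
train_tf = opt['transform_train']      # stochastic augmentation pipeline for training
test_tf = opt['transform_test']        # deterministic pipeline for evaluation

img = Image.new('RGB', (64, 128))      # placeholder for an image from util/dataset_loader.py
x_train = train_tf(img)                # tensor, should be [3, 384, 192] for 'pcb'
x_test = test_tf(img)
print(x_train.shape, x_test.shape)

One behaviour worth noting: get_opts() mutates the module-level base_opt in place rather than returning a copy, so values written for one model (e.g. workers = 16 and ReID_factor = 2 for 'pcb') carry over into any later call with a different model name in the same process; calling it once per run avoids the issue.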