├── .gitattributes
├── CVPR2020.pdf
├── images
│   ├── results.PNG
│   └── framework.PNG
├── util
│   ├── optimizers.py
│   ├── FeatureExtractor.py
│   ├── gumbel.py
│   ├── spectral.py
│   ├── dataset_loader.py
│   ├── transforms.py
│   ├── samplers.py
│   ├── re_ranking.py
│   ├── ms_ssim.py
│   ├── local_dist.py
│   ├── distance.py
│   ├── utils.py
│   └── eval_metrics.py
├── LICENSE
├── models
│   ├── DenseNet.py
│   ├── __init__.py
│   ├── LSRO.py
│   ├── AlignedReID.py
│   ├── IDE.py
│   ├── PCB.py
│   ├── MuDeep.py
│   └── HACNN.py
├── ReID_attr.py
├── advloss.py
├── README.md
├── GD.py
├── train.py
└── opts.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/CVPR2020.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whj363636/Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking/HEAD/CVPR2020.pdf
--------------------------------------------------------------------------------
/images/results.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whj363636/Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking/HEAD/images/results.PNG
--------------------------------------------------------------------------------
/images/framework.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whj363636/Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking/HEAD/images/framework.PNG
--------------------------------------------------------------------------------
/util/optimizers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | __all__ = ['init_optim']
4 |
5 | def init_optim(optim, params, lr, weight_decay):
6 | if optim == 'adam':
7 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
8 | elif optim == 'sgd':
9 | return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
10 | elif optim == 'rmsprop':
11 | return torch.optim.RMSprop(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
12 | else:
13 | raise KeyError("Unsupported optim: {}".format(optim))
--------------------------------------------------------------------------------
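
A minimal usage sketch for init_optim (assuming the repository root is on the Python path; the model and hyperparameter values below are placeholders, not values taken from this repo):

    import torch.nn as nn
    from util.optimizers import init_optim

    # any nn.Module's parameters work here; this linear layer is just a stand-in
    model = nn.Linear(2048, 751)
    optimizer = init_optim('adam', model.parameters(), lr=0.0003, weight_decay=5e-04)
    optimizer.zero_grad()  # ready for a training step
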
/util/FeatureExtractor.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from IPython import embed
3 |
4 | class FeatureExtractor(nn.Module):
5 | def __init__(self,submodule,extracted_layers):
6 | super(FeatureExtractor,self).__init__()
7 | self.submodule = submodule
8 | self.extracted_layers = extracted_layers
9 |
10 | def forward(self, x):
11 | outputs = []
12 | for name, module in self.submodule._modules.items():
13 |             if name == "classifier":
14 |                 x = x.view(x.size(0), -1)
15 |             if name == "base":
16 | for block_name, cnn_block in module._modules.items():
17 | x = cnn_block(x)
18 | if block_name in self.extracted_layers:
19 | outputs.append(x)
20 | return outputs
--------------------------------------------------------------------------------
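
A usage sketch, assuming a model whose `base` attribute holds the backbone (e.g. the DenseNet121 defined in models/DenseNet.py); the block names are those of torchvision's densenet121 features and are shown only for illustration:

    import torch
    from util.FeatureExtractor import FeatureExtractor
    from models.DenseNet import DenseNet121

    model = DenseNet121(num_classes=751)
    extractor = FeatureExtractor(model, extracted_layers=['denseblock3', 'denseblock4'])
    feats = extractor(torch.randn(4, 3, 256, 128))  # list with one feature map per requested block
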
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 whj
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/DenseNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import torchvision
7 |
8 | __all__ = ['DenseNet121']
9 |
10 |
11 | class DenseNet121(nn.Module):
12 | def __init__(self, num_classes, loss={'xent'}, **kwargs):
13 | super(DenseNet121, self).__init__()
14 | self.loss = loss
15 | densenet121 = torchvision.models.densenet121(pretrained=True)
16 | self.base = densenet121.features
17 | self.classifier = nn.Linear(1024, num_classes)
18 | self.feat_dim = 1024 # feature dimension
19 |
20 | def forward(self, x, is_training):
21 | x = self.base(x)
22 | x = F.avg_pool2d(x, x.size()[2:])
23 | f = x.view(x.size(0), -1)
24 | if not is_training:
25 | return f
26 | y = self.classifier(f)
27 |
28 | if self.loss == {'xent'}:
29 | return [y]
30 | elif self.loss == {'xent', 'htri'}:
31 | return [y, f]
32 | elif self.loss == {'cent'}:
33 | return [y, f]
34 | elif self.loss == {'ring'}:
35 | return [y, f]
36 | else:
37 | raise KeyError("Unsupported loss: {}".format(self.loss))
--------------------------------------------------------------------------------
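
A quick shape check with random input (a sketch only; 751 is the Market-1501 training identity count and is used here purely as an example class number):

    import torch
    from models.DenseNet import DenseNet121

    net = DenseNet121(num_classes=751, loss={'xent'})
    x = torch.randn(8, 3, 256, 128)
    f = net(x, is_training=False)   # pooled features, shape (8, 1024)
    y, = net(x, is_training=True)   # [logits] because loss == {'xent'}
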
/util/gumbel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer, required
3 |
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch import Tensor
8 | from torch.nn import Parameter
9 |
10 |
11 | def _sample_gumbel(shape, eps=1e-10, out=None):
12 | """
13 | Based on
14 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
15 | (MIT license)
16 | """
17 | U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
18 | return - torch.log(eps - torch.log(U + eps))
19 |
20 |
21 | def _gumbel_softmax_sample(logits, T=1, eps=1e-10):
22 | """
23 | Based on
24 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
25 | (MIT license)
26 | """
27 | dims = logits.dim()
28 | gumbel_noise = _sample_gumbel(logits.size(), eps=eps, out=logits.data.new())
29 | y = logits + gumbel_noise
30 | return F.softmax(y / T, dims - 1)
31 |
32 |
33 | def gumbel_softmax(logits, k, T=1, hard=True, eps=1e-10):
34 | shape = logits.size()
35 | assert len(shape) == 2
36 | y_soft = _gumbel_softmax_sample(logits, T=T, eps=eps)
37 | if hard:
38 | _, ind = torch.topk(y_soft, k=k, dim=-1, largest=True)
39 | y_hard = logits.new_zeros(*shape).scatter_(-1, ind.view(-1, k), 1.0)
40 | y = y_hard - y_soft.detach() + y_soft
41 | else:
42 | y = y_soft
43 | return y
--------------------------------------------------------------------------------
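
A small sketch of drawing a hard top-k mask from logits with the straight-through estimator (the shapes and temperature are arbitrary):

    import torch
    from util.gumbel import gumbel_softmax

    logits = torch.randn(4, 10)                            # (batch, categories)
    mask = gumbel_softmax(logits, k=3, T=0.5, hard=True)
    print(mask.sum(dim=-1))                                # each row sums to k; gradients flow through the soft sample
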
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch
3 | import torch.nn as nn
4 |
5 | from .DenseNet import *
6 | from .MuDeep import *
7 | from .AlignedReID import *
8 | from .PCB import *
9 | from .HACNN import *
10 | from .IDE import *
11 | from .LSRO import *
12 |
13 | __factory = {
14 | # 1.
15 | 'hacnn': HACNN,
16 | 'densenet121': DenseNet121,
17 | 'ide': IDE,
18 | # 2.
19 | 'aligned': ResNet50,
20 | 'pcb': PCB,
21 | 'mudeep': MuDeep,
22 | # 3.
23 | 'cam': IDE,
24 | 'hhl': IDE,
25 | 'lsro': DenseNet121,
26 | 'spgan': IDE,
27 | }
28 |
29 | def get_names():
30 | return __factory.keys()
31 |
32 | def init_model(name, pre_dir, *args, **kwargs):
33 | if name not in __factory.keys():
34 | raise KeyError("Unknown model: {}".format(name))
35 |
36 | print("Initializing model: {}".format(name))
37 | net = __factory[name](*args, **kwargs)
38 | # load pretrained model
39 | checkpoint = torch.load(pre_dir) # for Python 2
40 | # checkpoint = torch.load(pre_dir, encoding="latin1") # for Python 3
41 | state_dict = checkpoint['state_dict'] if isinstance(checkpoint, dict) and 'state_dict' in checkpoint else checkpoint
42 | change = False
43 | for k, v in state_dict.items():
44 | if k[:6] == 'module':
45 | change = True
46 | break
47 | if not change:
48 | new_state_dict = state_dict
49 | else:
50 | from collections import OrderedDict
51 | new_state_dict = OrderedDict()
52 | for k, v in state_dict.items():
53 | name = k[7:] # remove 'module.' of dataparallel
54 | new_state_dict[name]=v
55 | net.load_state_dict(new_state_dict)
56 | # freeze
57 | net.eval()
58 | net.volatile = True
59 | return net
--------------------------------------------------------------------------------
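
A loading sketch for init_model (the checkpoint path is a placeholder; num_classes must match the classifier size stored in the checkpoint):

    from models import get_names, init_model

    print(sorted(get_names()))   # all registered target models
    net = init_model('aligned', pre_dir='./checkpoints/aligned_market.pth.tar', num_classes=751)
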
/models/LSRO.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import torchvision
6 | from torchvision import models
7 | from torch.autograd import Variable
8 | from torch.nn import functional as F
9 |
10 | __all__ = ['DenseNet121']
11 |
12 |
13 | def weights_init_kaiming(m):
14 | classname = m.__class__.__name__
15 | # print(classname)
16 | if classname.find('Conv') != -1:
17 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
18 | elif classname.find('Linear') != -1:
19 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out')
20 | init.constant(m.bias.data, 0.0)
21 | elif classname.find('BatchNorm1d') != -1:
22 | init.normal(m.weight.data, 1.0, 0.02)
23 | init.constant(m.bias.data, 0.0)
24 |
25 | def weights_init_classifier(m):
26 | classname = m.__class__.__name__
27 | if classname.find('Linear') != -1:
28 | init.normal(m.weight.data, std=0.001)
29 | init.constant(m.bias.data, 0.0)
30 |
31 | class DenseNet121(nn.Module):
32 | def __init__(self, num_classes):
33 | super(DenseNet121,self).__init__()
34 | model_ft = models.densenet121(pretrained=True)
35 | # add pooling to the model
36 |         # in the original version, pooling is written in the forward function
37 | model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1))
38 |
39 | add_block = []
40 | num_bottleneck = 512
41 | add_block += [nn.Linear(1024, num_bottleneck)] #For ResNet, it is 2048
42 | add_block += [nn.BatchNorm1d(num_bottleneck)]
43 | add_block += [nn.LeakyReLU(0.1)]
44 | add_block += [nn.Dropout(p=0.5)]
45 | add_block = nn.Sequential(*add_block)
46 | add_block.apply(weights_init_kaiming)
47 | model_ft.fc = add_block
48 | self.model = model_ft
49 |
50 | classifier = []
51 | classifier += [nn.Linear(num_bottleneck, num_classes)]
52 | classifier = nn.Sequential(*classifier)
53 | classifier.apply(weights_init_classifier)
54 | self.classifier = classifier
55 |
56 | def forward(self, x, is_training):
57 | x = self.model.features(x)
58 | x = x.view(x.size(0),-1)
59 | x = self.model.fc(x)
60 | logits = self.classifier(x)
61 | return [logits, x]
62 |
63 |
--------------------------------------------------------------------------------
/models/AlignedReID.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import torchvision
7 |
8 | __all__ = ['ResNet50']
9 |
10 | class ResNet50(nn.Module):
11 | """
12 | Alignedreid: Surpassing human-level performance in person re-identification
13 |
14 | Reference:
15 | Zhang, Xuan, et al. "Alignedreid: Surpassing human-level performance in person re-identification." arXiv preprint arXiv:1711.08184 (2017)
16 | """
17 | def __init__(self, num_classes, **kwargs):
18 | super(ResNet50, self).__init__()
19 | self.loss = {'softmax', 'metric'}
20 | resnet50 = torchvision.models.resnet50(pretrained=True)
21 | self.base = nn.Sequential(*list(resnet50.children())[:-2])
22 | self.classifier = nn.Linear(2048, num_classes)
23 | self.feat_dim = 2048 # feature dimension
24 | self.aligned = True
25 | self.horizon_pool = HorizontalMaxPool2d()
26 | if self.aligned:
27 | self.bn = nn.BatchNorm2d(2048)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.conv1 = nn.Conv2d(2048, 128, kernel_size=1, stride=1, padding=0, bias=True)
30 |
31 | def forward(self, x, is_training):
32 | x = self.base(x)
33 | if not is_training:
34 | lf = self.horizon_pool(x)
35 | if self.aligned and is_training:
36 | lf = self.bn(x)
37 | lf = self.relu(lf)
38 | lf = self.horizon_pool(lf)
39 | lf = self.conv1(lf)
40 | if self.aligned or not is_training:
41 | lf = lf.view(lf.size()[0:3])
42 | lf = lf / torch.pow(lf,2).sum(dim=1, keepdim=True).clamp(min=1e-12).sqrt()
43 | x = F.avg_pool2d(x, x.size()[2:])
44 | f = x.view(x.size(0), -1)
45 | #f = 1. * f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12)
46 | if not is_training:
47 | return [f,lf]
48 | y = self.classifier(f)
49 | if self.loss == {'softmax'}:
50 | return [y]
51 | elif self.loss == {'metric'}:
52 | if self.aligned:
53 | return [f, lf]
54 | return [f]
55 | elif self.loss == {'softmax', 'metric'}:
56 | if self.aligned:
57 | return [y, f, lf]
58 | return [y, f]
59 | else:
60 | raise KeyError("Unsupported loss: {}".format(self.loss))
61 |
62 | class HorizontalMaxPool2d(nn.Module):
63 | def __init__(self):
64 | super(HorizontalMaxPool2d, self).__init__()
65 |
66 |
67 | def forward(self, x):
68 | inp_size = x.size()
69 | return nn.functional.max_pool2d(input=x,kernel_size= (1, inp_size[3]))
--------------------------------------------------------------------------------
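
A shape sketch for the aligned model (the 256x128 input size is used only to illustrate the output sizes):

    import torch
    from models.AlignedReID import ResNet50

    net = ResNet50(num_classes=751)
    x = torch.randn(2, 3, 256, 128)
    f, lf = net(x, is_training=False)   # global feature (2, 2048) and local feature (2, 2048, 8) at test time
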
/util/spectral.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer, required
3 |
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch import Tensor
8 | from torch.nn import Parameter
9 |
10 | def l2normalize(v, eps=1e-12):
11 | return v / (v.norm() + eps)
12 |
13 |
14 | class SpectralNorm(nn.Module):
15 | def __init__(self, module, name='weight', power_iterations=1):
16 | super(SpectralNorm, self).__init__()
17 | self.module = module
18 | self.name = name
19 | self.power_iterations = power_iterations
20 | if not self._made_params():
21 | self._make_params()
22 |
23 | def _update_u_v(self):
24 | u = getattr(self.module, self.name + "_u")
25 | v = getattr(self.module, self.name + "_v")
26 | w = getattr(self.module, self.name + "_bar")
27 |
28 | height = w.data.shape[0]
29 | for _ in range(self.power_iterations):
30 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
31 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
32 |
33 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
34 | sigma = u.dot(w.view(height, -1).mv(v))
35 | setattr(self.module, self.name, w / sigma.expand_as(w))
36 |
37 | def _made_params(self):
38 | try:
39 | u = getattr(self.module, self.name + "_u")
40 | v = getattr(self.module, self.name + "_v")
41 | w = getattr(self.module, self.name + "_bar")
42 | return True
43 | except AttributeError:
44 | return False
45 |
46 |
47 | def _make_params(self):
48 | w = getattr(self.module, self.name)
49 |
50 | height = w.data.shape[0]
51 | width = w.view(height, -1).data.shape[1]
52 |
53 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
54 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
55 | u.data = l2normalize(u.data)
56 | v.data = l2normalize(v.data)
57 | w_bar = Parameter(w.data)
58 |
59 | del self.module._parameters[self.name]
60 |
61 | self.module.register_parameter(self.name + "_u", u)
62 | self.module.register_parameter(self.name + "_v", v)
63 | self.module.register_parameter(self.name + "_bar", w_bar)
64 |
65 |
66 | def forward(self, *args):
67 | self._update_u_v()
68 | return self.module.forward(*args)
--------------------------------------------------------------------------------
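
A minimal sketch of wrapping a layer with SpectralNorm (the layer itself is arbitrary):

    import torch
    import torch.nn as nn
    from util.spectral import SpectralNorm

    conv = SpectralNorm(nn.Conv2d(64, 128, kernel_size=3, padding=1))
    out = conv(torch.randn(1, 64, 32, 32))  # each forward runs one power iteration, then the wrapped conv
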
/models/IDE.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import init
6 | import torchvision
7 | import pdb
8 |
9 | __all__ = ['IDE']
10 |
11 |
12 | class IDE(nn.Module):
13 | def __init__(self, pretrained=True, cut_at_pooling=False,
14 | num_features=1024, norm=False, dropout=0, num_classes=0):
15 | super(IDE, self).__init__()
16 |
17 | self.pretrained = pretrained
18 | self.cut_at_pooling = cut_at_pooling
19 |
20 | # Construct base (pretrained) resnet
21 | self.base = torchvision.models.resnet50(pretrained=True)
22 |
23 | if not self.cut_at_pooling:
24 | self.num_features = num_features
25 | self.norm = norm
26 | self.dropout = dropout
27 | self.has_embedding = num_features > 0
28 | self.num_classes = num_classes
29 |
30 | out_planes = self.base.fc.in_features
31 |
32 | # Append new layers
33 | if self.has_embedding:
34 | self.feat = nn.Linear(out_planes, self.num_features)
35 | self.feat_bn = nn.BatchNorm1d(self.num_features)
36 | init.kaiming_normal(self.feat.weight, mode='fan_out')
37 | init.constant(self.feat.bias, 0)
38 | init.constant(self.feat_bn.weight, 1)
39 | init.constant(self.feat_bn.bias, 0)
40 | else:
41 | # Change the num_features to CNN output channels
42 | self.num_features = out_planes
43 | if self.dropout > 0:
44 | self.drop = nn.Dropout(self.dropout)
45 | if self.num_classes > 0:
46 | self.classifier = nn.Linear(self.num_features, self.num_classes)
47 | init.normal(self.classifier.weight, std=0.001)
48 | init.constant(self.classifier.bias, 0)
49 |
50 | if not self.pretrained:
51 | self.reset_params()
52 |
53 | def forward(self, x, is_training, output_feature=None):
54 | for name, module in self.base._modules.items():
55 | if name == 'avgpool':
56 | break
57 | x = module(x)
58 |
59 | if self.cut_at_pooling:
60 | return x
61 |
62 | x = F.avg_pool2d(x, x.size()[2:])
63 | x = x.view(x.size(0), -1)
64 |
65 | if output_feature == 'pool5':
66 | x = F.normalize(x)
67 | return x
68 | if self.has_embedding:
69 | x = self.feat(x)
70 | x = self.feat_bn(x)
71 | if self.norm:
72 | x = F.normalize(x)
73 | elif self.has_embedding:
74 | x = F.relu(x)
75 | if self.dropout > 0:
76 | x = self.drop(x)
77 | if self.num_classes > 0:
78 | logits = self.classifier(x)
79 | return [logits, x]
80 |
81 | def reset_params(self):
82 | for m in self.modules():
83 | if isinstance(m, nn.Conv2d):
84 | init.kaiming_normal(m.weight, mode='fan_out')
85 | if m.bias is not None:
86 | init.constant(m.bias, 0)
87 | elif isinstance(m, nn.BatchNorm2d):
88 | init.constant(m.weight, 1)
89 | init.constant(m.bias, 0)
90 | elif isinstance(m, nn.Linear):
91 | init.normal(m.weight, std=0.001)
92 | if m.bias is not None:
93 | init.constant(m.bias, 0)
--------------------------------------------------------------------------------
/ReID_attr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import math
5 | import random
6 | import glob
7 | import cv2
8 | import torch
9 | from scipy import io
10 | from opts import market1501_train_map, duke_train_map, get_opts
11 |
12 | market_dict = {'age':[1,2,3,4], # young(1), teenager(2), adult(3), old(4)
13 | 'backpack':[1,2], # no(1), yes(2)
14 | 'bag':[1,2], # no(1), yes(2)
15 | 'handbag':[1,2], # no(1), yes(2)
16 | 'downblack':[1,2], # no(1), yes(2)
17 | 'downblue':[1,2], # no(1), yes(2)
18 | 'downbrown':[1,2], # no(1), yes(2)
19 | 'downgray':[1,2], # no(1), yes(2)
20 | 'downgreen':[1,2], # no(1), yes(2)
21 | 'downpink':[1,2], # no(1), yes(2)
22 | 'downpurple':[1,2], # no(1), yes(2)
23 | 'downwhite':[1,2], # no(1), yes(2)
24 | 'downyellow':[1,2], # no(1), yes(2)
25 | 'upblack':[1,2], # no(1), yes(2)
26 | 'upblue':[1,2], # no(1), yes(2)
27 | 'upgreen':[1,2], # no(1), yes(2)
28 | 'upgray':[1,2], # no(1), yes(2)
29 | 'uppurple':[1,2], # no(1), yes(2)
30 | 'upred':[1,2], # no(1), yes(2)
31 | 'upwhite':[1,2], # no(1), yes(2)
32 | 'upyellow':[1,2], # no(1), yes(2)
33 | 'clothes':[1,2], # dress(1), pants(2)
34 | 'down':[1,2], # long lower body clothing(1), short(2)
35 | 'up':[1,2], # long sleeve(1), short sleeve(2)
36 | 'hair':[1,2], # short hair(1), long hair(2)
37 | 'hat':[1,2], # no(1), yes(2)
38 | 'gender':[1,2]}# male(1), female(2)
39 |
40 | duke_dict = {'gender':[1,2], # male(1), female(2)
41 | 'top':[1,2], # short upper body clothing(1), long(2)
42 | 'boots':[1,2], # no(1), yes(2)
43 | 'hat':[1,2], # no(1), yes(2)
44 | 'backpack':[1,2], # no(1), yes(2)
45 | 'bag':[1,2], # no(1), yes(2)
46 | 'handbag':[1,2], # no(1), yes(2)
47 | 'shoes':[1,2], # dark(1), light(2)
48 | 'downblack':[1,2], # no(1), yes(2)
49 | 'downwhite':[1,2], # no(1), yes(2)
50 | 'downred':[1,2], # no(1), yes(2)
51 | 'downgray':[1,2], # no(1), yes(2)
52 | 'downblue':[1,2], # no(1), yes(2)
53 | 'downgreen':[1,2], # no(1), yes(2)
54 | 'downbrown':[1,2], # no(1), yes(2)
55 | 'upblack':[1,2], # no(1), yes(2)
56 | 'upwhite':[1,2], # no(1), yes(2)
57 | 'upred':[1,2], # no(1), yes(2)
58 | 'uppurple':[1,2], # no(1), yes(2)
59 | 'upgray':[1,2], # no(1), yes(2)
60 | 'upblue':[1,2], # no(1), yes(2)
61 | 'upgreen':[1,2], # no(1), yes(2)
62 | 'upbrown':[1,2]} # no(1), yes(2)
63 |
64 | __dict_factory={
65 | 'market_attribute': market_dict,
66 | 'dukemtmcreid_attribute': duke_dict
67 | }
68 |
69 | def get_keys(dict_name):
70 | for key, value in __dict_factory.items():
71 | if key == dict_name:
72 | return value.keys()
73 |
74 | def get_target_withattr(attr_matrix, dataset_name, attr_list, pids, pids_raw):
75 | attr_key, attr_value = attr_list
76 | attr_name = 'duke_attribute' if dataset_name == 'dukemtmcreid' else 'market_attribute'
77 | mapping = duke_train_map if dataset_name == 'dukemtmcreid' else market1501_train_map
78 | column = attr_matrix[attr_name][0]['train'][0][0][attr_key][0][0]
79 |
80 | n = pids_raw.size(0)
81 | targets = np.zeros_like(column)
82 | for i in range(n):
83 | if column[mapping[pids_raw[i].item()]] == attr_value:
84 | targets[pids[i].item()] = 1
85 | return torch.from_numpy(targets).view(1,-1).repeat(n, 1)
--------------------------------------------------------------------------------
/util/dataset_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | import random
6 | import os.path as osp
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | def read_image(img_path):
12 |     """Keep reading the image until it succeeds.
13 |     This avoids IOErrors caused by heavy IO load."""
14 | got_img = False
15 | if not osp.exists(img_path):
16 | raise IOError("{} does not exist".format(img_path))
17 | while not got_img:
18 | try:
19 | img = Image.open(img_path).convert('RGB')
20 | got_img = True
21 | except IOError:
22 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
23 | pass
24 | return img
25 |
26 | class ImageDataset(Dataset):
27 | """Image Person ReID Dataset"""
28 | def __init__(self, dataset, transform=None):
29 | self.dataset = dataset
30 | self.transform = transform
31 |
32 | def __len__(self):
33 | return len(self.dataset)
34 |
35 | def __getitem__(self, index):
36 | tp = self.dataset[index]
37 | if len(tp) == 3:
38 | img_path, pid, camid = tp
39 | pid_raw = pid
40 | elif len(tp) == 4:
41 | img_path, pid, camid, pid_raw = tp
42 | img = read_image(img_path)
43 | if self.transform is not None:
44 | img = self.transform(img)
45 | return img, pid, camid, pid_raw
46 |
47 | class VideoDataset(Dataset):
48 | """Video Person ReID Dataset.
49 | Note batch data has shape (batch, seq_len, channel, height, width).
50 | """
51 | sample_methods = ['evenly', 'random', 'all']
52 |
53 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None):
54 | self.dataset = dataset
55 | self.seq_len = seq_len
56 | self.sample = sample
57 | self.transform = transform
58 |
59 | def __len__(self):
60 | return len(self.dataset)
61 |
62 | def __getitem__(self, index):
63 | img_paths, pid, camid = self.dataset[index]
64 | num = len(img_paths)
65 |
66 | if self.sample == 'random':
67 | """
68 | Randomly sample seq_len items from num items,
69 | if num is smaller than seq_len, then replicate items
70 | """
71 | indices = np.arange(num)
72 | replace = False if num >= self.seq_len else True
73 | indices = np.random.choice(indices, size=self.seq_len, replace=replace)
74 | # sort indices to keep temporal order
75 | # comment it to be order-agnostic
76 | indices = np.sort(indices)
77 | elif self.sample == 'evenly':
78 | """Evenly sample seq_len items from num items."""
79 | if num >= self.seq_len:
80 | num -= num % self.seq_len
81 |                 indices = np.arange(0, num, num // self.seq_len)  # integer step so indices stay ints
82 | else:
83 | # if num is smaller than seq_len, simply replicate the last image
84 | # until the seq_len requirement is satisfied
85 | indices = np.arange(0, num)
86 | num_pads = self.seq_len - num
87 | indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)])
88 | assert len(indices) == self.seq_len
89 | elif self.sample == 'all':
90 | """
91 | Sample all items, seq_len is useless now and batch_size needs
92 | to be set to 1.
93 | """
94 | indices = np.arange(num)
95 | else:
96 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods))
97 |
98 | imgs = []
99 | for index in indices:
100 | img_path = img_paths[index]
101 | img = read_image(img_path)
102 | if self.transform is not None:
103 | img = self.transform(img)
104 | img = img.unsqueeze(0)
105 | imgs.append(img)
106 | imgs = torch.cat(imgs, dim=0)
107 |
108 | return imgs, pid, camid
--------------------------------------------------------------------------------
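
A minimal sketch of wrapping a list of (img_path, pid, camid) records with ImageDataset (the image path below is a placeholder):

    from torch.utils.data import DataLoader
    from torchvision import transforms as T
    from util.dataset_loader import ImageDataset

    records = [('/path/to/0001_c1s1_000151_01.jpg', 0, 0)]   # (img_path, pid, camid)
    tfm = T.Compose([T.Resize((256, 128)), T.ToTensor()])
    loader = DataLoader(ImageDataset(records, transform=tfm), batch_size=32, shuffle=True)
    for imgs, pids, camids, pids_raw in loader:
        pass  # imgs: (B, 3, 256, 128)
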
/models/PCB.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torchvision import models
7 |
8 | __all__ = ['PCB', 'PCB_test']
9 |
10 | def weights_init_kaiming(m):
11 | classname = m.__class__.__name__
12 | # print(classname)
13 | if classname.find('Conv') != -1:
14 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
15 | elif classname.find('Linear') != -1:
16 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out')
17 | init.constant(m.bias.data, 0.0)
18 | elif classname.find('BatchNorm1d') != -1:
19 | init.normal(m.weight.data, 1.0, 0.02)
20 | init.constant(m.bias.data, 0.0)
21 |
22 | def weights_init_classifier(m):
23 | classname = m.__class__.__name__
24 | if classname.find('Linear') != -1:
25 | init.normal(m.weight.data, std=0.001)
26 | init.constant(m.bias.data, 0.0)
27 |
28 | class ClassBlock(nn.Module):
29 | def __init__(self, input_dim, class_num, dropout=True, relu=True, num_bottleneck=512):
30 | super(ClassBlock, self).__init__()
31 | add_block = []
32 | add_block += [nn.Linear(input_dim, num_bottleneck)]
33 | add_block += [nn.BatchNorm1d(num_bottleneck)]
34 | if relu:
35 | add_block += [nn.LeakyReLU(0.1)]
36 | if dropout:
37 | add_block += [nn.Dropout(p=0.5)]
38 | add_block = nn.Sequential(*add_block)
39 | add_block.apply(weights_init_kaiming)
40 |
41 | classifier = []
42 | classifier += [nn.Linear(num_bottleneck, class_num)]
43 | classifier = nn.Sequential(*classifier)
44 | classifier.apply(weights_init_classifier)
45 |
46 | self.add_block = add_block
47 | self.classifier = classifier
48 | def forward(self, x):
49 | x = self.add_block(x)
50 | x = self.classifier(x)
51 | return x
52 |
53 | class PCB(nn.Module):
54 | """
55 | Based on
56 | https://github.com/layumi/Person_reID_baseline_pytorch
57 | """
58 | def __init__(self, num_classes):
59 | super(PCB, self).__init__()
60 |
61 | self.part = 6 # We cut the pool5 to 6 parts
62 | model_ft = models.resnet50(pretrained=True)
63 | self.model = model_ft
64 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1))
65 | self.dropout = nn.Dropout(p=0.5)
66 | # remove the final downsample
67 | self.model.layer4[0].downsample[0].stride = (1,1)
68 | self.model.layer4[0].conv2.stride = (1,1)
69 | # define 6 classifiers
70 | for i in range(self.part):
71 | name = 'classifier'+str(i)
72 | setattr(self, name, ClassBlock(2048, num_classes, True, False, 256))
73 |
74 | def forward(self, x, is_training):
75 | x = self.model.conv1(x)
76 | x = self.model.bn1(x)
77 | x = self.model.relu(x)
78 | x = self.model.maxpool(x)
79 |
80 | x = self.model.layer1(x)
81 | x = self.model.layer2(x)
82 | x = self.model.layer3(x)
83 | x = self.model.layer4(x)
84 | x = self.avgpool(x)
85 | x = self.dropout(x)
86 | part = {}
87 | feature = []
88 | predict = []
89 | # get six part feature batchsize*2048*6
90 | for i in range(self.part):
91 | part[i] = torch.squeeze(x[:,:,i])
92 | name = 'classifier'+str(i)
93 | c = getattr(self,name)
94 | feature.append(part[i])
95 | predict.append(c(part[i]))
96 | return [predict, feature]
97 |
98 | class PCB_test(nn.Module):
99 | def __init__(self, model):
100 | super(PCB_test, self).__init__()
101 | self.part = 6
102 | self.model = model.model
103 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1))
104 | # remove the final downsample
105 | self.model.layer4[0].downsample[0].stride = (1,1)
106 | self.model.layer4[0].conv2.stride = (1,1)
107 |
108 | def forward(self, x, is_training):
109 | x = self.model.conv1(x)
110 | x = self.model.bn1(x)
111 | x = self.model.relu(x)
112 | x = self.model.maxpool(x)
113 |
114 | x = self.model.layer1(x)
115 | x = self.model.layer2(x)
116 | x = self.model.layer3(x)
117 | x = self.model.layer4(x)
118 | x = self.avgpool(x)
119 | y = x.view(x.size(0),x.size(1),x.size(2))
120 | return [y]
--------------------------------------------------------------------------------
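
A shape sketch for PCB and its test-time wrapper (384x128 is the part-based input size commonly used with PCB; here it only illustrates the shapes):

    import torch
    from models.PCB import PCB, PCB_test

    net = PCB(num_classes=751)
    x = torch.randn(2, 3, 384, 128)
    predict, feature = net(x, is_training=True)   # two lists: 6 part logits and 6 part features
    y, = PCB_test(net)(x, is_training=False)      # part-pooled features, shape (2, 2048, 6)
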
/util/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 | import math
8 |
9 |
10 | class RectScale(object):
11 | def __init__(self, height, width, interpolation=Image.BILINEAR):
12 | self.height = height
13 | self.width = width
14 | self.interpolation = interpolation
15 |
16 | def __call__(self, img):
17 | w, h = img.size
18 | if h == self.height and w == self.width:
19 | return img
20 | return img.resize((self.width, self.height), self.interpolation)
21 |
22 |
23 | class RandomSizedRectCrop(object):
24 | def __init__(self, height, width, interpolation=Image.BILINEAR):
25 | self.height = height
26 | self.width = width
27 | self.interpolation = interpolation
28 |
29 | def __call__(self, img):
30 | for attempt in range(10):
31 | area = img.size[0] * img.size[1]
32 | target_area = random.uniform(0.64, 1.0) * area
33 | aspect_ratio = random.uniform(2, 3)
34 |
35 | h = int(round(math.sqrt(target_area * aspect_ratio)))
36 | w = int(round(math.sqrt(target_area / aspect_ratio)))
37 |
38 | if w <= img.size[0] and h <= img.size[1]:
39 | x1 = random.randint(0, img.size[0] - w)
40 | y1 = random.randint(0, img.size[1] - h)
41 |
42 | img = img.crop((x1, y1, x1 + w, y1 + h))
43 | assert(img.size == (w, h))
44 |
45 | return img.resize((self.width, self.height), self.interpolation)
46 |
47 | # Fallback
48 | scale = RectScale(self.height, self.width,
49 | interpolation=self.interpolation)
50 | return scale(img)
51 |
52 |
53 | class RandomErasing(object):
54 | def __init__(self, EPSILON=0.5, mean=[0.485, 0.456, 0.406]):
55 | self.EPSILON = EPSILON
56 | self.mean = mean
57 |
58 | def __call__(self, img):
59 |
60 | if random.uniform(0, 1) > self.EPSILON:
61 | return img
62 |
63 | for attempt in range(100):
64 | area = img.size()[1] * img.size()[2]
65 |
66 | target_area = random.uniform(0.02, 0.2) * area
67 | aspect_ratio = random.uniform(0.3, 3)
68 |
69 | h = int(round(math.sqrt(target_area * aspect_ratio)))
70 | w = int(round(math.sqrt(target_area / aspect_ratio)))
71 |
72 | if w <= img.size()[2] and h <= img.size()[1]:
73 | x1 = random.randint(0, img.size()[1] - h)
74 | y1 = random.randint(0, img.size()[2] - w)
75 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
76 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
77 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
78 |
79 | return img
80 |
81 | return img
82 |
83 | class Random2DTranslation(object):
84 | """
85 | With a probability, first increase image size to (1 + 1/8), and then perform random crop.
86 |
87 | Args:
88 | height (int): target height.
89 | width (int): target width.
90 | p (float): probability of performing this transformation. Default: 0.5.
91 | """
92 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
93 | self.height = height
94 | self.width = width
95 | self.p = p
96 | self.interpolation = interpolation
97 |
98 | def __call__(self, img):
99 | """
100 | Args:
101 | img (PIL Image): Image to be cropped.
102 |
103 | Returns:
104 | PIL Image: Cropped image.
105 | """
106 | if random.random() < self.p:
107 | return img.resize((self.width, self.height), self.interpolation)
108 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))
109 | resized_img = img.resize((new_width, new_height), self.interpolation)
110 | x_maxrange = new_width - self.width
111 | y_maxrange = new_height - self.height
112 | x1 = int(round(random.uniform(0, x_maxrange)))
113 | y1 = int(round(random.uniform(0, y_maxrange)))
114 |         cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
115 |         return cropped_img
116 |
117 | if __name__ == '__main__':
118 | pass
--------------------------------------------------------------------------------
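
A typical training-time composition using these transforms (a sketch; the normalization statistics are the standard ImageNet values, and RandomErasing comes after ToTensor because it operates on tensors):

    from torchvision import transforms as T
    from util.transforms import Random2DTranslation, RandomErasing

    train_tfm = T.Compose([
        Random2DTranslation(height=256, width=128, p=0.5),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomErasing(EPSILON=0.5),
    ])
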
/util/samplers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 | import numpy as np
4 | import os.path as osp
5 | import torch
6 | from torch.utils.data.sampler import Sampler
7 |
8 | class RandomIdentitySampler(Sampler):
9 | """
10 | Randomly sample N identities, then for each identity,
11 | randomly sample K instances, therefore batch size is N*K.
12 |
13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
14 |
15 | Args:
16 | data_source (Dataset): dataset to sample from.
17 | num_instances (int): number of instances per identity.
18 | """
19 | def __init__(self, data_source, num_instances=4):
20 | self.data_source = data_source
21 | self.num_instances = num_instances
22 | self.index_dic = defaultdict(list)
23 | for index, tp in enumerate(data_source):
24 | if len(tp) == 3:
25 | _, pid, _ = tp
26 | elif len(tp) == 4:
27 | _, pid, _, _ = tp
28 |
29 | self.index_dic[pid].append(index)
30 | self.pids = list(self.index_dic.keys())
31 | self.num_identities = len(self.pids)
32 |
33 | def __iter__(self):
34 | indices = torch.randperm(self.num_identities)
35 | ret = []
36 | for i in indices:
37 | pid = self.pids[i]
38 | t = self.index_dic[pid]
39 | replace = False if len(t) >= self.num_instances else True
40 | t = np.random.choice(t, size=self.num_instances, replace=replace)
41 | ret.extend(t)
42 | return iter(ret)
43 |
44 | def __len__(self):
45 | return self.num_identities * self.num_instances
46 |
47 | class RandomIdentitySamplerCls(Sampler):
48 | """
49 | Randomly sample N identities, then for each identity,
50 | randomly sample K instances, therefore batch size is N*K.
51 |
52 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.
53 |
54 | Args:
55 | data_source (Dataset): dataset to sample from.
56 | num_instances (int): number of instances per identity.
57 | """
58 | def __init__(self, data_source, num_instances=4):
59 | self.data_source = data_source
60 | self.num_instances = num_instances
61 | self.index_dic = defaultdict(list)
62 | for index, (_, target) in enumerate(data_source):
63 | self.index_dic[target].append(index)
64 | self.targets = list(self.index_dic.keys())
65 | self.num_identities = len(self.targets)
66 |
67 | def __iter__(self):
68 | indices = torch.randperm(self.num_identities)
69 | ret = []
70 | for i in indices:
71 | target = self.targets[i]
72 | t = self.index_dic[target]
73 | replace = False if len(t) >= self.num_instances else True
74 | t = np.random.choice(t, size=self.num_instances, replace=replace)
75 | ret.extend(t)
76 | return iter(ret)
77 |
78 | def __len__(self):
79 | return self.num_identities * self.num_instances
80 |
81 | class AttrPool(Sampler):
82 | def __init__(self, data_source, dataset_name, attr_matrix, attr_list, sample_num):
83 | from opts import market1501_train_map, duke_train_map
84 | attr_key, attr_value = attr_list
85 | attr_name = 'duke_attribute' if dataset_name == 'dukemtmcreid' else 'market_attribute'
86 | mapping = duke_train_map if dataset_name == 'dukemtmcreid' else market1501_train_map
87 | column = attr_matrix[attr_name][0]['train'][0][0][attr_key][0][0]
88 |
89 | self.data_source = data_source
90 | self.sample_num = sample_num
91 | self.attr_pool = defaultdict(list)
92 |
93 | for index, (_, pid, _, pid_raw) in enumerate(data_source):
94 | if column[mapping[pid_raw]] == attr_value:
95 | self.attr_pool[0].append(index)
96 | else:
97 | self.attr_pool[1].append(index)
98 | self.attrs = list(self.attr_pool.keys())
99 | self.num_attrs = len(self.attrs)
100 |
101 | def __iter__(self):
102 | ret = []
103 | for i in range(700):
104 | t = self.attr_pool[self.attrs[i%2]]
105 | replace = False if len(t) >= self.sample_num else True
106 | t = np.random.choice(t, size=self.sample_num, replace=replace)
107 | ret.extend(t)
108 | return iter(ret)
109 |
110 | def __len__(self):
111 | return self.sample_num*700
--------------------------------------------------------------------------------
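
A small sketch of RandomIdentitySampler on a toy split (the image paths are placeholders; in practice the sampler is passed to a DataLoader together with ImageDataset):

    from util.samplers import RandomIdentitySampler

    # toy list of (img_path, pid, camid) records: 8 identities, 8 images each
    train = [('img_%03d.jpg' % i, i % 8, 0) for i in range(64)]
    sampler = RandomIdentitySampler(train, num_instances=4)
    indices = list(iter(sampler))   # 8 identities x 4 instances = 32 indices per pass
    # typically: DataLoader(ImageDataset(train, transform=...), batch_size=32, sampler=sampler, drop_last=True)
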
/advloss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch
3 | import random
4 | import numpy as np
5 | from torch import nn
6 |
7 | __all__ = ['DeepSupervision', 'adv_CrossEntropyLoss','adv_CrossEntropyLabelSmooth', 'adv_TripletLoss']
8 |
9 | def DeepSupervision(criterion, xs, *args, **kwargs):
10 | loss = 0.
11 | for x in xs: loss += criterion(x, *args, **kwargs)
12 | return loss
13 |
14 | class adv_CrossEntropyLoss(nn.Module):
15 | def __init__(self, use_gpu=True):
16 | super(adv_CrossEntropyLoss, self).__init__()
17 | self.use_gpu = use_gpu
18 | self.crossentropy_loss = nn.CrossEntropyLoss()
19 |
20 | def forward(self, logits, pids):
21 | """
22 | Args:
23 | logits: prediction matrix (before softmax) with shape (batch_size, num_classes)
24 | """
25 | _, adv_target = torch.min(logits, 1)
26 |
27 | if self.use_gpu: adv_target = adv_target.cuda()
28 | loss = self.crossentropy_loss(logits, adv_target)
29 | return torch.log(loss)
30 |
31 | class adv_CrossEntropyLabelSmooth(nn.Module):
32 | """
33 | Args:
34 | num_classes (int): number of classes.
35 | epsilon (float): weight.
36 | """
37 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
38 | super(adv_CrossEntropyLabelSmooth, self).__init__()
39 | self.num_classes = num_classes
40 | self.epsilon = epsilon
41 | self.use_gpu = use_gpu
42 | self.logsoftmax = nn.LogSoftmax(dim=1)
43 |
44 | def forward(self, logits, pids):
45 | """
46 | Args:
47 | logits: prediction matrix (before softmax) with shape (batch_size, num_classes)
48 | pids: ground truth labels with shape (num_classes)
49 | """
50 | # n = pids.size(0)
51 | # _, top2 = torch.topk(logits, k=2, dim=1, largest=False)
52 | # adv_target = top2[:,0]
53 | # for i in range(n):
54 | # if adv_target[i] == pids[i]: adv_target[i] = top2[i,1]
55 | # else: continue
56 | _, adv_target = torch.min(logits, 1)
57 | # for i in range(n):
58 | # while adv_target[i] == pids[i]:
59 | # adv_target[i] = random.randint(0, self.num_classes)
60 |
61 | log_probs = self.logsoftmax(logits)
62 | adv_target = torch.zeros(log_probs.size()).scatter_(1, adv_target.unsqueeze(1).data.cpu(), 1)
63 | smooth = torch.ones(log_probs.size()) / (self.num_classes-1)
64 | smooth[:, pids.data.cpu()] = 0 # Pytorch1.0
65 |         if self.use_gpu: smooth = smooth.cuda()
66 | if self.use_gpu: adv_target = adv_target.cuda()
67 | adv_target = (1 - self.epsilon) * adv_target + self.epsilon * smooth
68 | loss = (- adv_target * log_probs).mean(0).sum()
69 | return torch.log(loss)
70 |
71 | class adv_TripletLoss(nn.Module):
72 | def __init__(self, ak_type, margin=0.3):
73 | super(adv_TripletLoss, self).__init__()
74 | self.margin = margin
75 | self.ak_type = ak_type
76 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
77 |
78 | def forward(self, features, pids, targets=None):
79 | """
80 | Args:
81 | features: feature matrix with shape (batch_size, feat_dim)
82 | pids: ground truth labels with shape (num_classes)
83 | targets: pids with certain attribute (batch_size, pids)
84 | """
85 | n = features.size(0)
86 |
87 | dist = torch.pow(features, 2).sum(dim=1, keepdim=True).expand(n, n)
88 | dist = dist + dist.t()
89 | dist.addmm_(1, -2, features, features.t())
90 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
91 |
92 | if self.ak_type < 0:
93 | mask = pids.expand(n, n).eq(pids.expand(n, n).t())
94 | dist_ap, dist_an = [], []
95 | for i in range(n):
96 | dist_an.append(dist[i][mask[i]].min().unsqueeze(0)) # make nearest pos-pos far away
97 | dist_ap.append(dist[i][mask[i] == 0].max().unsqueeze(0)) # make hardest pos-neg closer
98 |
99 | elif self.ak_type > 0:
100 | p = []
101 | for i in range(n):
102 | p.append(pids[i].item())
103 | mask = targets[0][p].expand(n, n).eq(targets[0][p].expand(n, n).t())
104 | dist_ap, dist_an = [], []
105 | for i in range(n):
106 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
107 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
108 |
109 | dist_ap = torch.cat(dist_ap)
110 | dist_an = torch.cat(dist_an)
111 |
112 | y = torch.ones_like(dist_an)
113 |
114 | loss = self.ranking_loss(dist_an, dist_ap, y)
115 | return torch.log(loss)
--------------------------------------------------------------------------------
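
A shape sketch with random inputs (the batch size and the 751-class count are arbitrary; use_gpu=False keeps the example CPU-only):

    import torch
    from advloss import adv_CrossEntropyLabelSmooth, adv_TripletLoss

    logits = torch.randn(16, 751)         # classifier outputs
    feats = torch.randn(16, 2048)         # embedding features
    pids = torch.randint(0, 751, (16,))   # ground-truth identities

    adv_xent = adv_CrossEntropyLabelSmooth(num_classes=751, use_gpu=False)
    adv_htri = adv_TripletLoss(ak_type=-1, margin=0.3)
    loss = adv_xent(logits, pids) + adv_htri(feats, pids)
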
/util/re_ranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 |
6 | @author: luohao
7 | """
8 |
9 | """
10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 |
15 | """
16 | API
17 |
18 | probFea: all feature vectors of the query set (torch tensor)
19 | galFea: all feature vectors of the gallery set (torch tensor)
20 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3)
21 | local_distmat: optional precomputed local distance matrix (numpy array), added to the global distance
22 | only_local: if True, use local_distmat alone as the original distance
23 | """
24 |
25 | import numpy as np
26 | import torch
27 |
28 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat = None, only_local = False):
29 | # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor
30 | query_num = probFea.size(0)
31 | all_num = query_num + galFea.size(0)
32 | if only_local:
33 | original_dist = local_distmat
34 | else:
35 | feat = torch.cat([probFea,galFea])
36 | print('using GPU to compute original distance')
37 | distmat = torch.pow(feat,2).sum(dim=1, keepdim=True).expand(all_num,all_num) + \
38 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
39 | distmat.addmm_(1,-2,feat,feat.t())
40 | original_dist = distmat.numpy()
41 | del feat
42 | if not local_distmat is None:
43 | original_dist = original_dist + local_distmat
44 | gallery_num = original_dist.shape[0]
45 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
46 | V = np.zeros_like(original_dist).astype(np.float16)
47 | initial_rank = np.argsort(original_dist).astype(np.int32)
48 |
49 | print('starting re_ranking')
50 | for i in range(all_num):
51 | # k-reciprocal neighbors
52 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
53 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
54 | fi = np.where(backward_k_neigh_index == i)[0]
55 | k_reciprocal_index = forward_k_neigh_index[fi]
56 | k_reciprocal_expansion_index = k_reciprocal_index
57 | for j in range(len(k_reciprocal_index)):
58 | candidate = k_reciprocal_index[j]
59 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
60 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
61 | :int(np.around(k1 / 2)) + 1]
62 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
63 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
64 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
65 | candidate_k_reciprocal_index):
66 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
67 |
68 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
69 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
70 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
71 | original_dist = original_dist[:query_num, ]
72 | if k2 != 1:
73 | V_qe = np.zeros_like(V, dtype=np.float16)
74 | for i in range(all_num):
75 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
76 | V = V_qe
77 | del V_qe
78 | del initial_rank
79 | invIndex = []
80 | for i in range(gallery_num):
81 | invIndex.append(np.where(V[:, i] != 0)[0])
82 |
83 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
84 |
85 | for i in range(query_num):
86 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
87 | indNonZero = np.where(V[i, :] != 0)[0]
88 | indImages = [invIndex[ind] for ind in indNonZero]
89 | for j in range(len(indNonZero)):
90 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
91 | V[indImages[j], indNonZero[j]])
92 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
93 |
94 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
95 | del original_dist
96 | del V
97 | del jaccard_dist
98 | final_dist = final_dist[:query_num, query_num:]
99 | return final_dist
100 |
101 |
--------------------------------------------------------------------------------
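
A sketch with random features (k1, k2 and lambda_value are the defaults mentioned in the header above; the feature dimension is arbitrary):

    import torch
    from util.re_ranking import re_ranking

    qf = torch.randn(10, 2048)   # query features
    gf = torch.randn(50, 2048)   # gallery features
    distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)   # numpy array of shape (10, 50)
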
/util/ms_ssim.py:
--------------------------------------------------------------------------------
1 | """Code imported from https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py"""
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from math import exp
6 | import numpy as np
7 |
8 |
9 | def gaussian(window_size, sigma):
10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
11 | return gauss/gauss.sum()
12 |
13 |
14 | def create_window(window_size, channel=1):
15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
17 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
18 | return window
19 |
20 |
21 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
22 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
23 | if val_range is None:
24 | if torch.max(img1) > 128:
25 | max_val = 255
26 | else:
27 | max_val = 1
28 |
29 | if torch.min(img1) < -0.5:
30 | min_val = -1
31 | else:
32 | min_val = 0
33 | L = max_val - min_val
34 | else:
35 | L = val_range
36 |
37 | padd = 0
38 | (_, channel, height, width) = img1.size()
39 | if window is None:
40 | real_size = min(window_size, height, width)
41 | window = create_window(real_size, channel=channel).to(img1.device)
42 |
43 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
44 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
45 |
46 | mu1_sq = mu1.pow(2)
47 | mu2_sq = mu2.pow(2)
48 | mu1_mu2 = mu1 * mu2
49 |
50 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
51 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
52 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
53 |
54 | C1 = (0.01 * L) ** 2
55 | C2 = (0.03 * L) ** 2
56 |
57 | v1 = 2.0 * sigma12 + C2
58 | v2 = sigma1_sq + sigma2_sq + C2
59 | cs = torch.mean(v1 / v2) # contrast sensitivity
60 |
61 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
62 |
63 | if size_average:
64 | ret = ssim_map.mean()
65 | else:
66 | ret = ssim_map.mean(1).mean(1).mean(1)
67 |
68 | if full:
69 | return ret, cs
70 | return ret
71 |
72 |
73 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
74 | device = img1.device
75 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
76 | levels = weights.size()[0]
77 | mssim = []
78 | mcs = []
79 | for _ in range(levels):
80 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
81 | mssim.append(sim)
82 | mcs.append(cs)
83 |
84 | img1 = F.avg_pool2d(img1, (2, 2))
85 | img2 = F.avg_pool2d(img2, (2, 2))
86 |
87 | mssim = torch.stack(mssim)
88 | mcs = torch.stack(mcs)
89 |
90 |     # Normalize (to avoid NaNs when training unstable models; not compliant with the original definition)
91 | if normalize:
92 | mssim = (mssim + 1) / 2
93 | mcs = (mcs + 1) / 2
94 |
95 | pow1 = mcs ** weights
96 | pow2 = mssim ** weights
97 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
98 | output = torch.prod(pow1[:-1] * pow2[-1])
99 | return output
100 |
101 |
102 | # Classes to re-use window
103 | class SSIM(torch.nn.Module):
104 | def __init__(self, window_size=11, size_average=True, val_range=None):
105 | super(SSIM, self).__init__()
106 | self.window_size = window_size
107 | self.size_average = size_average
108 | self.val_range = val_range
109 |
110 | # Assume 1 channel for SSIM
111 | self.channel = 1
112 | self.window = create_window(window_size)
113 |
114 | def forward(self, img1, img2):
115 | (_, channel, _, _) = img1.size()
116 |
117 | if channel == self.channel and self.window.dtype == img1.dtype:
118 | window = self.window
119 | else:
120 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
121 | self.window = window
122 | self.channel = channel
123 |
124 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
125 |
126 | class MSSSIM(torch.nn.Module):
127 | def __init__(self, window_size=11, size_average=True, channel=3):
128 | super(MSSSIM, self).__init__()
129 | self.window_size = window_size
130 | self.size_average = size_average
131 | self.channel = channel
132 |
133 | def forward(self, img1, img2):
134 | # TODO: store window between calls if possible
135 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
--------------------------------------------------------------------------------
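
A quick sketch comparing two image batches (values kept in [0, 1] so val_range is inferred as 1; the perturbation is arbitrary):

    import torch
    from util.ms_ssim import ssim, msssim

    a = torch.rand(4, 3, 256, 128)
    b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)
    print(ssim(a, b).item())     # single-scale SSIM
    print(msssim(a, b).item())   # 5-scale MS-SSIM; the window shrinks automatically at coarse scales
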
/util/local_dist.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def batch_euclidean_dist(x, y):
4 | """
5 | Args:
6 | x: pytorch Variable, with shape [Batch size, Local part, Feature channel]
7 | y: pytorch Variable, with shape [Batch size, Local part, Feature channel]
8 | Returns:
9 | dist: pytorch Variable, with shape [Batch size, Local part, Local part]
10 | """
11 | assert len(x.size()) == 3
12 | assert len(y.size()) == 3
13 | assert x.size(0) == y.size(0)
14 | assert x.size(-1) == y.size(-1)
15 |
16 | N, m, d = x.size()
17 | N, n, d = y.size()
18 |
19 | # shape [N, m, n]
20 | xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)
21 | yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)
22 | dist = xx + yy
23 | dist.baddbmm_(1, -2, x, y.permute(0, 2, 1))
24 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
25 | return dist
26 |
27 | def shortest_dist(dist_mat):
28 | """Parallel version.
29 | Args:
30 | dist_mat: pytorch Variable, available shape:
31 | 1) [m, n]
32 | 2) [m, n, N], N is batch size
33 | 3) [m, n, *], * can be arbitrary additional dimensions
34 | Returns:
35 | dist: three cases corresponding to `dist_mat`:
36 | 1) scalar
37 | 2) pytorch Variable, with shape [N]
38 | 3) pytorch Variable, with shape [*]
39 | """
40 | m, n = dist_mat.size()[:2]
41 | # Just offering some reference for accessing intermediate distance.
42 | dist = [[0 for _ in range(n)] for _ in range(m)]
43 | for i in range(m):
44 | for j in range(n):
45 | if (i == 0) and (j == 0):
46 | dist[i][j] = dist_mat[i, j]
47 | elif (i == 0) and (j > 0):
48 | dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
49 | elif (i > 0) and (j == 0):
50 | dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
51 | else:
52 | dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
53 | dist = dist[-1][-1]
54 | return dist
55 |
56 | def hard_example_mining(dist_mat, labels, return_inds=False):
57 | """For each anchor, find the hardest positive and negative sample.
58 | Args:
59 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
60 | labels: pytorch LongTensor, with shape [N]
61 | return_inds: whether to return the indices. Save time if `False`(?)
62 | Returns:
63 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
64 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
65 | p_inds: pytorch LongTensor, with shape [N];
66 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
67 | n_inds: pytorch LongTensor, with shape [N];
68 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
69 | NOTE: Only consider the case in which all labels have same num of samples,
70 | thus we can cope with all anchors in parallel.
71 | """
72 |
73 | assert len(dist_mat.size()) == 2
74 | assert dist_mat.size(0) == dist_mat.size(1)
75 | N = dist_mat.size(0)
76 |
77 | # shape [N, N]
78 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
79 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
80 |
81 | # `dist_ap` means distance(anchor, positive)
82 | # both `dist_ap` and `relative_p_inds` with shape [N, 1]
83 | dist_ap, relative_p_inds = torch.max(dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
84 | # `dist_an` means distance(anchor, negative)
85 | # both `dist_an` and `relative_n_inds` with shape [N, 1]
86 | dist_an, relative_n_inds = torch.min(dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
87 | # shape [N]
88 | dist_ap = dist_ap.squeeze(1)
89 | dist_an = dist_an.squeeze(1)
90 |
91 | if return_inds:
92 | # shape [N, N]
93 | ind = (labels.new().resize_as_(labels).copy_(torch.arange(0, N).long()).unsqueeze( 0).expand(N, N))
94 | # shape [N, 1]
95 | p_inds = torch.gather(ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
96 | n_inds = torch.gather(ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
97 | # shape [N]
98 | p_inds = p_inds.squeeze(1)
99 | n_inds = n_inds.squeeze(1)
100 | return dist_ap, dist_an, p_inds, n_inds
101 |
102 | return dist_ap, dist_an
103 |
104 | def euclidean_dist(x, y):
105 | """
106 | Args:
107 | x: pytorch Variable, with shape [m, d]
108 | y: pytorch Variable, with shape [n, d]
109 | Returns:
110 | dist: pytorch Variable, with shape [m, n]
111 | """
112 | m, n = x.size(0), y.size(0)
113 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
114 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
115 | dist = xx + yy
116 | dist.addmm_(1, -2, x, y.t())
117 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
118 | return dist
119 |
120 | def batch_local_dist(x, y):
121 | """
122 | Args:
123 | x: pytorch Variable, with shape [N, m, d]
124 | y: pytorch Variable, with shape [N, n, d]
125 | Returns:
126 | dist: pytorch Variable, with shape [N]
127 | """
128 | assert len(x.size()) == 3
129 | assert len(y.size()) == 3
130 | assert x.size(0) == y.size(0)
131 | assert x.size(-1) == y.size(-1)
132 |
133 | # shape [N, m, n]
134 | dist_mat = batch_euclidean_dist(x, y)
135 | dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
136 | # shape [N]
137 | dist = shortest_dist(dist_mat.permute(1, 2, 0))
138 | return dist
139 |
140 | if __name__ == '__main__':
141 | x = torch.randn(32,2048)
142 | y = torch.randn(32,2048)
143 | dist_mat = euclidean_dist(x,y)
144 |     dist_ap, dist_an, p_inds, n_inds = hard_example_mining(dist_mat, torch.arange(32).long() % 8, return_inds=True)  # labels: 8 ids x 4 samples each
145 | from IPython import embed
146 | embed()
--------------------------------------------------------------------------------
/models/MuDeep.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import torchvision
7 |
8 | __all__ = ['MuDeep']
9 |
10 | class ConvBlock(nn.Module):
11 | """Basic convolutional block:
12 | convolution + batch normalization + relu.
13 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
14 | in_c (int): number of input channels.
15 | out_c (int): number of output channels.
16 | k (int or tuple): kernel size.
17 | s (int or tuple): stride.
18 | p (int or tuple): padding.
19 | """
20 | def __init__(self, in_c, out_c, k, s, p):
21 | super(ConvBlock, self).__init__()
22 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
23 | self.bn = nn.BatchNorm2d(out_c)
24 |
25 | def forward(self, x):
26 | return F.relu(self.bn(self.conv(x)))
27 |
28 | class ConvLayers(nn.Module):
29 | """Preprocessing layers."""
30 | def __init__(self):
31 | super(ConvLayers, self).__init__()
32 | self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1)
33 | self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1)
34 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
35 |
36 | def forward(self, x):
37 | x = self.conv1(x)
38 | x = self.conv2(x)
39 | x = self.maxpool(x)
40 | return x
41 |
42 | class MultiScaleA(nn.Module):
43 | """Multi-scale stream layer A (Sec.3.1)"""
44 | def __init__(self):
45 | super(MultiScaleA, self).__init__()
46 | self.stream1 = nn.Sequential(
47 | ConvBlock(96, 96, k=1, s=1, p=0),
48 | ConvBlock(96, 24, k=3, s=1, p=1),
49 | )
50 | self.stream2 = nn.Sequential(
51 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
52 | ConvBlock(96, 24, k=1, s=1, p=0),
53 | )
54 | self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0)
55 | self.stream4 = nn.Sequential(
56 | ConvBlock(96, 16, k=1, s=1, p=0),
57 | ConvBlock(16, 24, k=3, s=1, p=1),
58 | ConvBlock(24, 24, k=3, s=1, p=1),
59 | )
60 |
61 | def forward(self, x):
62 | s1 = self.stream1(x)
63 | s2 = self.stream2(x)
64 | s3 = self.stream3(x)
65 | s4 = self.stream4(x)
66 | y = torch.cat([s1, s2, s3, s4], dim=1)
67 | return y
68 |
69 | class Reduction(nn.Module):
70 | """Reduction layer (Sec.3.1)"""
71 | def __init__(self):
72 | super(Reduction, self).__init__()
73 | self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
74 | self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1)
75 | self.stream3 = nn.Sequential(
76 | ConvBlock(96, 48, k=1, s=1, p=0),
77 | ConvBlock(48, 56, k=3, s=1, p=1),
78 | ConvBlock(56, 64, k=3, s=2, p=1),
79 | )
80 |
81 | def forward(self, x):
82 | s1 = self.stream1(x)
83 | s2 = self.stream2(x)
84 | s3 = self.stream3(x)
85 | y = torch.cat([s1, s2, s3], dim=1)
86 | return y
87 |
88 | class MultiScaleB(nn.Module):
89 | """Multi-scale stream layer B (Sec.3.1)"""
90 | def __init__(self):
91 | super(MultiScaleB, self).__init__()
92 | self.stream1 = nn.Sequential(
93 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
94 | ConvBlock(256, 256, k=1, s=1, p=0),
95 | )
96 | self.stream2 = nn.Sequential(
97 | ConvBlock(256, 64, k=1, s=1, p=0),
98 | ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)),
99 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
100 | )
101 | self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0)
102 | self.stream4 = nn.Sequential(
103 | ConvBlock(256, 64, k=1, s=1, p=0),
104 | ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)),
105 | ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)),
106 | ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)),
107 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
108 | )
109 |
110 | def forward(self, x):
111 | s1 = self.stream1(x)
112 | s2 = self.stream2(x)
113 | s3 = self.stream3(x)
114 | s4 = self.stream4(x)
115 | return s1, s2, s3, s4
116 |
117 | class Fusion(nn.Module):
118 | """Saliency-based learning fusion layer (Sec.3.2)"""
119 | def __init__(self):
120 | super(Fusion, self).__init__()
121 | self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1))
122 | self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1))
123 | self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1))
124 | self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1))
125 |
126 | # We add an average pooling layer to reduce the spatial dimension
127 | # of feature maps, which differs from the original paper.
128 | self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)
129 |
130 | def forward(self, x1, x2, x3, x4):
131 | s1 = self.a1.expand_as(x1) * x1
132 | s2 = self.a2.expand_as(x2) * x2
133 | s3 = self.a3.expand_as(x3) * x3
134 | s4 = self.a4.expand_as(x4) * x4
135 | y = self.avgpool(s1 + s2 + s3 + s4)
136 | return y
137 |
138 | class MuDeep(nn.Module):
139 | """Multiscale deep neural network.
140 | Reference:
141 | Qian et al. Multi-scale Deep Learning Architectures for Person Re-identification. ICCV 2017.
142 | """
143 | def __init__(self, num_classes, loss={'xent', 'htri'}, **kwargs):
144 | super(MuDeep, self).__init__()
145 | self.loss = loss
146 |
147 | self.block1 = ConvLayers()
148 | self.block2 = MultiScaleA()
149 | self.block3 = Reduction()
150 | self.block4 = MultiScaleB()
151 | self.block5 = Fusion()
152 |
153 | # Due to this fully connected layer, input image has to be fixed
154 | # in shape, i.e. (3, 256, 128), such that the last convolutional feature
155 | # maps are of shape (256, 16, 8). If input shape is changed,
156 | # the input dimension of this layer has to be changed accordingly.
157 | self.fc = nn.Sequential(
158 | nn.Linear(256*16*8, 4096),
159 | nn.BatchNorm1d(4096),
160 | nn.ReLU(),
161 | )
162 | self.classifier = nn.Linear(4096, num_classes)
163 | self.feat_dim = 4096 # feature dimension
164 |
165 | def forward(self, x, is_training):
166 | x = self.block1(x)
167 | x = self.block2(x)
168 | x = self.block3(x)
169 | x = self.block4(x)
170 | x = self.block5(*x)
171 | x = x.view(x.size(0), -1)
172 | x = self.fc(x)
173 | y = self.classifier(x)
174 |
175 | if self.loss == {'xent'}:
176 | return [y]
177 | elif self.loss == {'xent', 'htri'}:
178 | return [y, x]
179 | elif self.loss == {'cent'}:
180 | return [y, x]
181 | else:
182 | raise KeyError("Unsupported loss: {}".format(self.loss))
--------------------------------------------------------------------------------
/util/distance.py:
--------------------------------------------------------------------------------
1 | """Numpy version of euclidean distance, shortest distance, etc.
2 | Notice the input/output shape of methods, so that you can better understand
3 | the meaning of these methods."""
4 | import numpy as np
5 |
6 |
7 | def normalize(nparray, order=2, axis=0):
8 | """Normalize a N-D numpy array along the specified axis."""
9 | norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
10 | return nparray / (norm + np.finfo(np.float32).eps)
11 |
12 |
13 | def compute_dist(array1, array2, type='euclidean'):
14 | """Compute the euclidean or cosine distance of all pairs.
15 | Args:
16 | array1: numpy array with shape [m1, n]
17 | array2: numpy array with shape [m2, n]
18 | type: one of ['cosine', 'euclidean']
19 | Returns:
20 | numpy array with shape [m1, m2]
21 | """
22 | assert type in ['cosine', 'euclidean']
23 | if type == 'cosine':
24 | array1 = normalize(array1, axis=1)
25 | array2 = normalize(array2, axis=1)
26 | dist = np.matmul(array1, array2.T)
27 | return dist
28 | else:
29 | # shape [m1, 1]
30 | square1 = np.sum(np.square(array1), axis=1)[..., np.newaxis]
31 | # shape [1, m2]
32 | square2 = np.sum(np.square(array2), axis=1)[np.newaxis, ...]
33 | squared_dist = - 2 * np.matmul(array1, array2.T) + square1 + square2
34 | squared_dist[squared_dist < 0] = 0
35 | dist = np.sqrt(squared_dist)
36 | return dist
37 |
38 |
39 | def shortest_dist(dist_mat):
40 | """Parallel version.
41 | Args:
42 | dist_mat: numpy array, available shape
43 | 1) [m, n]
44 | 2) [m, n, N], N is batch size
45 | 3) [m, n, *], * can be arbitrary additional dimensions
46 | Returns:
47 | dist: three cases corresponding to `dist_mat`
48 | 1) scalar
49 | 2) numpy array, with shape [N]
50 | 3) numpy array with shape [*]
51 | """
52 | m, n = dist_mat.shape[:2]
53 | dist = np.zeros_like(dist_mat)
54 | for i in range(m):
55 | for j in range(n):
56 | if (i == 0) and (j == 0):
57 | dist[i, j] = dist_mat[i, j]
58 | elif (i == 0) and (j > 0):
59 | dist[i, j] = dist[i, j - 1] + dist_mat[i, j]
60 | elif (i > 0) and (j == 0):
61 | dist[i, j] = dist[i - 1, j] + dist_mat[i, j]
62 | else:
63 | dist[i, j] = \
64 | np.min(np.stack([dist[i - 1, j], dist[i, j - 1]], axis=0), axis=0) \
65 | + dist_mat[i, j]
66 | # I ran into memory disaster when returning this reference! I still don't
67 | # know why.
68 | # dist = dist[-1, -1]
69 | dist = dist[-1, -1].copy()
70 | return dist
71 |
72 | def unaligned_dist(dist_mat):
73 | """Parallel version.
74 | Args:
75 | dist_mat: numpy array, available shape
76 | 1) [m, n]
77 | 2) [m, n, N], N is batch size
78 | 3) [m, n, *], * can be arbitrary additional dimensions
79 | Returns:
80 | dist: three cases corresponding to `dist_mat`
81 | 1) scalar
82 | 2) numpy array, with shape [N]
83 | 3) numpy array with shape [*]
84 | """
85 |
86 | m = dist_mat.shape[0]
87 | dist = np.zeros_like(dist_mat[0])
88 | for i in range(m):
89 | dist[i] = dist_mat[i][i]
90 | dist = np.sum(dist, axis=0).copy()
91 | return dist
92 |
93 |
94 | def meta_local_dist(x, y, aligned):
95 | """
96 | Args:
97 | x: numpy array, with shape [m, d]
98 | y: numpy array, with shape [n, d]
99 | Returns:
100 | dist: scalar
101 | """
102 | eu_dist = compute_dist(x, y, 'euclidean')
103 | dist_mat = (np.exp(eu_dist) - 1.) / (np.exp(eu_dist) + 1.)
104 | if aligned:
105 | dist = shortest_dist(dist_mat[np.newaxis])[0]
106 | else:
107 | dist = unaligned_dist(dist_mat[np.newaxis])[0]
108 | return dist
109 |
110 |
111 | # Tooooooo slow!
112 | def serial_local_dist(x, y, aligned):
113 | """
114 | Args:
115 | x: numpy array, with shape [M, m, d]
116 | y: numpy array, with shape [N, n, d]
117 | Returns:
118 | dist: numpy array, with shape [M, N]
119 | """
120 | M, N = x.shape[0], y.shape[0]
121 | dist_mat = np.zeros([M, N])
122 | for i in range(M):
123 | for j in range(N):
124 |             dist_mat[i, j] = meta_local_dist(x[i], y[j], aligned)
125 | return dist_mat
126 |
127 |
128 | def parallel_local_dist(x, y, aligned):
129 | """Parallel version.
130 | Args:
131 | x: numpy array, with shape [M, m, d]
132 | y: numpy array, with shape [N, n, d]
133 | Returns:
134 | dist: numpy array, with shape [M, N]
135 | """
136 | M, m, d = x.shape
137 | N, n, d = y.shape
138 | x = x.reshape([M * m, d])
139 | y = y.reshape([N * n, d])
140 | # shape [M * m, N * n]
141 | dist_mat = compute_dist(x, y, type='euclidean')
142 | dist_mat = (np.exp(dist_mat) - 1.) / (np.exp(dist_mat) + 1.)
143 | # shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N]
144 | dist_mat = dist_mat.reshape([M, m, N, n]).transpose([1, 3, 0, 2])
145 | # shape [M, N]
146 | if aligned:
147 | dist_mat = shortest_dist(dist_mat)
148 | else:
149 | dist_mat = unaligned_dist(dist_mat)
150 | return dist_mat
151 |
152 |
153 | def local_dist(x, y, aligned):
154 | if (x.ndim == 2) and (y.ndim == 2):
155 | return meta_local_dist(x, y, aligned)
156 | elif (x.ndim == 3) and (y.ndim == 3):
157 | return parallel_local_dist(x, y, aligned)
158 | else:
159 | raise NotImplementedError('Input shape not supported.')
160 |
161 |
162 | def low_memory_matrix_op(
163 | func,
164 | x, y,
165 | x_split_axis, y_split_axis,
166 | x_num_splits, y_num_splits,
167 | verbose=False, aligned=True):
168 | """
169 | For matrix operation like multiplication, in order not to flood the memory
170 | with huge data, split matrices into smaller parts (Divide and Conquer).
171 |
172 | Note:
173 | If still out of memory, increase `*_num_splits`.
174 |
175 | Args:
176 | func: a matrix function func(x, y) -> z with shape [M, N]
177 | x: numpy array, the dimension to split has length M
178 | y: numpy array, the dimension to split has length N
179 | x_split_axis: The axis to split x into parts
180 | y_split_axis: The axis to split y into parts
181 | x_num_splits: number of splits. 1 <= x_num_splits <= M
182 | y_num_splits: number of splits. 1 <= y_num_splits <= N
183 | verbose: whether to print the progress
184 |
185 | Returns:
186 | mat: numpy array, shape [M, N]
187 | """
188 |
189 | if verbose:
190 | import sys
191 | import time
192 |         printed = False
193 | st = time.time()
194 | last_time = time.time()
195 |
196 | mat = [[] for _ in range(x_num_splits)]
197 | for i, part_x in enumerate(
198 | np.array_split(x, x_num_splits, axis=x_split_axis)):
199 | for j, part_y in enumerate(
200 | np.array_split(y, y_num_splits, axis=y_split_axis)):
201 | part_mat = func(part_x, part_y, aligned)
202 | mat[i].append(part_mat)
203 |
204 | if verbose:
205 | if not printed:
206 | printed = True
207 | else:
208 | # Clean the current line
209 | sys.stdout.write("\033[F\033[K")
210 | print('Matrix part ({}, {}) / ({}, {}), +{:.2f}s, total {:.2f}s'
211 | .format(i + 1, j + 1, x_num_splits, y_num_splits,
212 | time.time() - last_time, time.time() - st))
213 | last_time = time.time()
214 | mat[i] = np.concatenate(mat[i], axis=1)
215 | mat = np.concatenate(mat, axis=0)
216 | return mat
217 |
218 |
219 | def low_memory_local_dist(x, y, aligned=True):
220 | print('Computing local distance...')
221 | x_num_splits = int(len(x) / 200) + 1
222 | y_num_splits = int(len(y) / 200) + 1
223 | z = low_memory_matrix_op(local_dist, x, y, 0, 0, x_num_splits, y_num_splits, verbose=True, aligned=aligned)
224 | return z
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial-attack-on-Person-ReID-With-Deep-Mis-Ranking
2 | This is the code for the [CVPR'20 paper](https://arxiv.org/abs/2004.04199) "Transferable, Controllable, and Inconspicuous Adversarial Attacks on Person Re-identification With Deep Mis-Ranking" by Hongjun Wang, Guangrun Wang, Ya Li, Dongyu Zhang, and Liang Lin.
3 |
4 |
5 |
6 |
7 |
8 | # Prerequisites
9 | * Python2 / Python3
10 | * PyTorch 0.4.1 (not tested with PyTorch >= 1.0)
11 | * CUDA
12 | * Numpy
13 | * Matplotlib
14 | * Scipy
15 |
16 | # Prepare data
17 | Create a directory to store reid datasets under this repo
18 | ```bash
19 | mkdir data/
20 | ```
21 |
22 | If you want to store datasets in another directory, specify `--root path_to_your/data` when running the training code. Please follow the instructions below to prepare each dataset; after that, you can simply pass `-d the_dataset` when running the training code (see the example command below).
23 |
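For example, a minimal invocation with a custom data root might look like the following (the path is illustrative; see the Train section below for the full set of arguments):

```bash
python train.py --root /path/to/your/data -d market1501
```
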
24 | **Market1501** :
25 |
26 | 1. Download dataset to `data/` from http://www.liangzheng.org/Project/project_reid.html.
27 | 2. Extract dataset and rename to `market1501`. The data structure would look like:
28 | ```
29 | market1501/
30 | bounding_box_test/
31 | bounding_box_train/
32 | ...
33 | ```
34 | 3. Use `-d market1501` when running the training code.
35 |
36 | **CUHK03** [13]:
37 | 1. Create a folder named `cuhk03/` under `data/`.
38 | 2. Download dataset to `data/cuhk03/` from http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html and extract `cuhk03_release.zip`, so you will have `data/cuhk03/cuhk03_release`.
39 | 3. Download new split [14] from [person-re-ranking](https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation/data/CUHK03). What you need are `cuhk03_new_protocol_config_detected.mat` and `cuhk03_new_protocol_config_labeled.mat`. Put these two mat files under `data/cuhk03`. Finally, the data structure would look like
40 | ```
41 | cuhk03/
42 | cuhk03_release/
43 | cuhk03_new_protocol_config_detected.mat
44 | cuhk03_new_protocol_config_labeled.mat
45 | ...
46 | ```
47 | 4. Use `-d cuhk03` when running the training code. By default, the new split (767/700) is used. If you want to use the original split (1367/100) created by [13], specify `--cuhk03-classic-split`. As [13] computes CMC differently from Market1501, you might need to specify `--use-metric-cuhk03` for a fair comparison with their method. In addition, we support both `labeled` and `detected` modes; the default mode loads `detected` images. Specify `--cuhk03-labeled` if you want to train and test on `labeled` images. An example command combining these options is given after this list.
48 |
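For example, a training command that combines these CUHK03 options might look like this (a sketch only; the remaining arguments follow the Train section below):

```bash
python train.py -d cuhk03 --cuhk03-labeled --cuhk03-classic-split --use-metric-cuhk03
```
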
49 | **DukeMTMC-reID** [16, 17]:
50 |
51 | 1. Create a directory under `data/` called `dukemtmc-reid`.
52 | 2. Download dataset `DukeMTMC-reID.zip` from https://github.com/layumi/DukeMTMC-reID_evaluation#download-dataset and put it under `data/dukemtmc-reid`. Extract the zip file, which leads to
53 | ```
54 | dukemtmc-reid/
55 | DukeMTMC-reid.zip # (you can delete this zip file, it is ok)
56 | DukeMTMC-reid/ # this folder contains 8 files.
57 | ```
58 | 3. Use `-d dukemtmcreid` when running the training code.
59 |
60 |
61 | **MSMT17** [22]:
62 | 1. Create a directory named `msmt17/` under `data/`.
63 | 2. Download dataset `MSMT17_V1.tar.gz` to `data/msmt17/` from http://www.pkuvmc.com/publications/msmt17.html. Extract the file under the same folder, so you will have
64 | ```
65 | msmt17/
66 | MSMT17_V1.tar.gz # (do whatever you want with this .tar file)
67 | MSMT17_V1/
68 | train/
69 | test/
70 | list_train.txt
71 | ... (totally six .txt files)
72 | ```
73 | 3. Use `-d msmt17` when running the training code.
74 |
75 | # Prepare pretrained ReID models
76 | 1. Create a directory to store reid pretrained models under this repo
77 |
78 | ```bash
79 | mkdir models/
80 | ```
81 | 2. Download the pretrained models or train the models from scratch by yourself offline
82 |
83 | 2.1 Download Links
84 |
85 | [IDE](https://drive.google.com/open?id=1hVYGcuhfwMs25QVdo2R-ugXW4WyAzuHF)
86 |
87 | [DenseNet121](https://drive.google.com/drive/folders/1XSiVo0lqULQJyYv4T2pt6uA4qtxKSb6X?usp=sharing)
88 |
89 | [AlignedReID](https://drive.google.com/open?id=1YZ7J85f1Fcjft7sh2rlPs1s0dlcaFpf-)
90 |
91 | [PCB](https://drive.google.com/open?id=1xkA981JDESHxhGM_2N-ZdvboVXXzi3yd)
92 |
93 | [Mudeep](https://drive.google.com/open?id=1g6HBt5uCVSbLQL1JUOY_jZZqYKtRmVsX)
94 |
95 | [HACNN](https://drive.google.com/open?id=1ZxzY149vgagHzDUQLMuJqCpCSG3mtH3M)
96 |
97 | [CamStyle](https://drive.google.com/open?id=11WsAyhme4p8i3lNehYpfdB0jZtSSOTzx)
98 |
99 | [LSRO](https://drive.google.com/drive/folders/1cxeOJ3FU6qraHWU927HWC24E_MpXghP5?usp=sharing)
100 |
101 | [HHL](https://drive.google.com/open?id=1ZStrZ6qrB_kgcoB9BLXre81RiXtybBXD)
102 |
103 | [SPGAN](https://drive.google.com/open?id=1YwnmBjfhBHlVQmTRn1ehaHRe5cXVGg5Z)
104 |
105 | 2.2 Training models from scratch (optional)
106 |
107 | Create a directory named after the target model (e.g. `aligned/` or `hacnn/`), following `__init__.py` under `models/`, and move the model checkpoint to this directory. For the naming rules, refer to the download links above.
108 |
109 | 3. Customized ReID models (optional)
110 |
111 | It is easy to test the robustness of any customized ReID model by following the above steps (1→2.2→3). The only extra thing you need to do is add the structure of your own model to `models/` and register it in `__init__.py` (a minimal sketch is given below).
112 |
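A minimal sketch of what a customized model and its registration might look like is shown here; the factory-style mapping in `__init__.py` is an assumption, so follow the actual structure of `models/__init__.py` in this repo:

```python
# models/MyReID.py -- a hypothetical, minimal custom ReID model (sketch only)
import torch.nn as nn

class MyReID(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(MyReID, self).__init__()
        self.loss = loss
        # Toy backbone: conv -> global average pooling -> linear classifier.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x, is_training=True):
        f = self.backbone(x).view(x.size(0), -1)  # features, shape [N, 64]
        y = self.classifier(f)                    # logits,   shape [N, num_classes]
        # Follow the convention of the other models: logits (+ features for htri).
        return [y, f] if self.loss == {'xent', 'htri'} else [y]

# In models/__init__.py (assumed layout): import the class and register a name,
# e.g. so that `--targetmodel='myreid'` can find it.
# from .MyReID import MyReID
# __factory['myreid'] = MyReID
```
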
113 | # Train
114 | Take attacking AlignedReID trained on Market1501 as an example:
115 |
116 | ```bash
117 | python train.py \
118 | --targetmodel='aligned' \
119 | --dataset='market1501'\
120 | --mode='train' \
121 | --loss='xent_htri' \
122 | --ak_type=-1 \
123 | --temperature=-1 \
124 | --use_SSIM=2 \
125 | --epoch=40
126 | ```
127 |
128 | # Test
129 | Take attacking AlignedReID trained on Market1501 as an example:
130 |
131 | ```bash
132 | python train.py \
133 | --targetmodel='aligned' \
134 | --dataset='market1501'\
135 | --G_resume_dir='./logs/aligned/market1501/best_G.pth.tar' \
136 | --mode='test' \
137 | --loss='xent_htri' \
138 | --ak_type=-1 \
139 | --temperature=-1 \
140 | --use_SSIM=2 \
141 | --epoch=40
142 | ```
143 |
144 | # Results
145 |
146 |
147 |
148 |
149 |
150 |
151 | # Reminders
152 |
153 | 1. If you are using your *own* trained ReID models (whether or not they are customized), be careful about the variable names, and change or keep Lines 38–53 in `__init__.py` accordingly (they adapt checkpoints trained with early PyTorch 0.3). A sketch of this kind of key remapping is given after this list.
154 | 2. You may notice that some arguments and code involve attribute information. If you are interested, you can easily find and download the extra attribute files for Market1501 or DukeMTMC. We conducted some related experiments on attribute attacks, but since they are *not* the main content of this paper, that part of the code has been removed.
155 |
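As a rough illustration of the variable-name adaptation mentioned in Reminder 1, the sketch below remaps `state_dict` keys before loading an old checkpoint. The `module.` prefix handling and the renaming rule are assumptions; inspect your own checkpoint and Lines 38–53 of `__init__.py` for the exact rules.

```python
import torch

# Load an old-style (PyTorch 0.3 era) checkpoint on CPU and pull out its state dict.
checkpoint = torch.load('path/to/your_checkpoint.pth.tar', map_location='cpu')
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

# Remap variable names so they match the current model definition.
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace('module.', '')                 # strip a DataParallel prefix, if any
    # k = k.replace('old_layer.', 'new_layer.')  # add model-specific renames here
    new_state_dict[k] = v

# model.load_state_dict(new_state_dict)          # `model` is your target ReID network
```
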
156 | # Reference
157 |
158 | If you are interested in our work, please consider citing our paper.
159 | ```
160 | @InProceedings{Wang_2020_CVPR,
161 | author = {Wang, Hongjun and Wang, Guangrun and Li, Ya and Zhang, Dongyu and Lin, Liang},
162 | title = {Transferable, Controllable, and Inconspicuous Adversarial Attacks on Person Re-identification With Deep Mis-Ranking},
163 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
164 | month = {June},
165 | year = {2020}
166 | }
167 | ```
168 |
169 | # Acknowledgements
170 | Thanks to the following excellent works:
171 |
172 | - Open-reid [code](https://github.com/Cysu/open-reid)
173 | - AlignedReID [paper](https://www.sciencedirect.com/science/article/abs/pii/S0031320319302031?via%3Dihub#!) and [code](https://github.com/michuanhaohao/AlignedReID) by michuanhaohao
174 | - Person ReID baseline [code](https://github.com/layumi/Person_reID_baseline_pytorch) by layumi
175 | - LSRO [paper](https://arxiv.org/abs/1701.07717) and [code](https://github.com/layumi/Person-reID_GAN) by layumi
176 | - HHL [paper](http://openaccess.thecvf.com/content_ECCV_2018/html/Zhun_Zhong_Generalizing_A_Person_ECCV_2018_paper.html) and [code](https://github.com/zhunzhong07/HHL) by zhunzhong07
177 |
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | import os
4 | import sys
5 | import errno
6 | import shutil
7 | import json
8 | import time
9 | import os.path as osp
10 | from PIL import Image
11 | import matplotlib
12 | import numpy as np
13 | from numpy import array, argmin
14 |
15 | import torch
16 |
17 | def mkdir_if_missing(directory):
18 | if not osp.exists(directory):
19 | try:
20 | os.makedirs(directory)
21 | except OSError as e:
22 | if e.errno != errno.EEXIST:
23 | raise
24 |
25 | def fliplr(img):
26 | '''flip horizontal'''
27 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W
28 | img_flip = img.index_select(3,inv_idx)
29 | return img_flip
30 |
31 | def save_heatmap(path, den):
32 | matplotlib.use('Agg')
33 | import matplotlib.pyplot as plt
34 | from matplotlib.colors import PowerNorm, LogNorm
35 | import matplotlib.cm as cm
36 | plt.axis('off')
37 | plt.imshow(den,
38 | cmap=cm.jet,
39 |                norm=LogNorm(),
40 | interpolation="bicubic")
41 | # save fig
42 | fig = plt.gcf()
43 | fig.savefig(path, format='png', bbox_inches='tight', transparent=True, dpi=600)
44 | plt.close('all')
45 |
46 | class AverageMeter(object):
47 | """Computes and stores the average and current value.
48 |
49 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
50 | """
51 | def __init__(self):
52 | self.reset()
53 |
54 | def reset(self):
55 | self.val = 0
56 | self.avg = 0
57 | self.sum = 0
58 | self.count = 0
59 |
60 | def update(self, val, n=1):
61 | self.val = val
62 | self.sum += val * n
63 | self.count += n
64 | self.avg = self.sum / self.count
65 |
66 | def save_checkpoint(state, is_best, G_or_D, fpath='checkpoint.pth.tar'):
67 | mkdir_if_missing(osp.dirname(fpath))
68 | torch.save(state, fpath)
69 | if is_best:
70 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_'+ G_or_D +'.pth.tar'))
71 |
72 | class Logger(object):
73 | """
74 | Write console output to external text file.
75 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
76 | """
77 | def __init__(self, fpath=None):
78 | self.console = sys.stdout
79 | self.file = None
80 | if fpath is not None:
81 | mkdir_if_missing(os.path.dirname(fpath))
82 | self.file = open(fpath, 'w')
83 |
84 | def __del__(self):
85 | self.close()
86 |
87 | def __enter__(self):
88 | pass
89 |
90 | def __exit__(self, *args):
91 | self.close()
92 |
93 | def write(self, msg):
94 | self.console.write(msg)
95 | if self.file is not None:
96 | self.file.write(msg)
97 |
98 | def flush(self):
99 | self.console.flush()
100 | if self.file is not None:
101 | self.file.flush()
102 | os.fsync(self.file.fileno())
103 |
104 | def close(self):
105 | self.console.close()
106 | if self.file is not None:
107 | self.file.close()
108 |
109 | def read_json(fpath):
110 | with open(fpath, 'r') as f:
111 | obj = json.load(f)
112 | return obj
113 |
114 | def write_json(obj, fpath):
115 | mkdir_if_missing(osp.dirname(fpath))
116 | with open(fpath, 'w') as f:
117 | json.dump(obj, f, indent=4, separators=(',', ': '))
118 |
119 | def _traceback(D):
120 | i,j = array(D.shape)-1
121 | p,q = [i],[j]
122 | while (i>0) or (j>0):
123 | tb = argmin((D[i,j-1], D[i-1,j]))
124 | if tb == 0:
125 | j -= 1
126 | else: #(tb==1)
127 | i -= 1
128 | p.insert(0,i)
129 | q.insert(0,j)
130 | return array(p), array(q)
131 |
132 | def dtw(dist_mat):
133 | m, n = dist_mat.shape[:2]
134 | dist = np.zeros_like(dist_mat)
135 | for i in range(m):
136 | for j in range(n):
137 | if (i == 0) and (j == 0):
138 | dist[i, j] = dist_mat[i, j]
139 | elif (i == 0) and (j > 0):
140 | dist[i, j] = dist[i, j - 1] + dist_mat[i, j]
141 | elif (i > 0) and (j == 0):
142 | dist[i, j] = dist[i - 1, j] + dist_mat[i, j]
143 | else:
144 | dist[i, j] = \
145 | np.min(np.stack([dist[i - 1, j], dist[i, j - 1]], axis=0), axis=0) \
146 | + dist_mat[i, j]
147 | path = _traceback(dist)
148 | return dist[-1,-1]/sum(dist.shape), dist, path
149 |
150 | def read_image(img_path):
151 | got_img = False
152 | if not osp.exists(img_path):
153 | raise IOError("{} does not exist".format(img_path))
154 | while not got_img:
155 | try:
156 | img = Image.open(img_path).convert('RGB')
157 | got_img = True
158 | except IOError:
159 | print("IOError incurred when reading '{}'. Will Redo. Don't worry. Just chill".format(img_path))
160 | pass
161 | return img
162 |
163 | def img_to_tensor(img,transform):
164 | img = transform(img)
165 | img = img.unsqueeze(0)
166 | return img
167 |
168 | def feat_flatten(feat):
169 | shp = feat.shape
170 | feat = feat.reshape(shp[0] * shp[1], shp[2])
171 | return feat
172 |
173 | def merge_feature(feature_list, shp, sample_rate = None):
174 | def pre_process(torch_feature_map):
175 | numpy_feature_map = torch_feature_map.cpu().data.numpy()[0]
176 | numpy_feature_map = numpy_feature_map.transpose(1,2,0)
177 | shp = numpy_feature_map.shape[:2]
178 | return numpy_feature_map, shp
179 | def resize_as(tfm, shp):
180 | nfm, shp2 = pre_process(tfm)
181 |         scale = shp[0] // shp2[0]  # integer scale factor (np.repeat expects an int count)
182 | nfm1 = nfm.repeat(scale, axis = 0).repeat(scale, axis=1)
183 | return nfm1
184 | final_nfm = resize_as(feature_list[0], shp)
185 | for i in range(1, len(feature_list)):
186 | temp_nfm = resize_as(feature_list[i],shp)
187 | final_nfm = np.concatenate((final_nfm, temp_nfm),axis =-1)
188 |     if sample_rate is not None and sample_rate > 0:
189 |         final_nfm = final_nfm[0:-1:sample_rate, 0:-1:sample_rate, :]
190 | return final_nfm
191 |
192 | def visualize_ranked_results(distmat, dataset, save_dir, topk=20):
193 | """
194 | Visualize ranked results
195 | Support both imgreid and vidreid
196 | Args:
197 | - distmat: distance matrix of shape (num_query, num_gallery).
198 | - dataset: has dataset.query and dataset.gallery, both are lists of (img_path, pid, camid);
199 | for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing
200 | a sequence of strings.
201 | - save_dir: directory to save output images.
202 | - topk: int, denoting top-k images in the rank list to be visualized.
203 | """
204 | num_q, num_g = distmat.shape
205 |
206 | print("Visualizing top-{} ranks in '{}' ...".format(topk, save_dir))
207 | print("# query: {}. # gallery {}".format(num_q, num_g))
208 |
209 | assert num_q == len(dataset.query)
210 | assert num_g == len(dataset.gallery)
211 |
212 | indices = np.argsort(distmat, axis=1)
213 | mkdir_if_missing(save_dir)
214 |
215 | for q_idx in range(num_q):
216 | qimg_path, qpid, qcamid = dataset.query[q_idx]
217 | qdir = osp.join(save_dir, 'query' + str(q_idx + 1).zfill(5))
218 | mkdir_if_missing(qdir)
219 | cp_img_to(qimg_path, qdir, rank=0, prefix='query')
220 |
221 | rank_idx = 1
222 | for g_idx in indices[q_idx,:]:
223 | gimg_path, gpid, gcamid = dataset.gallery[g_idx]
224 | invalid = (qpid == gpid) & (qcamid == gcamid)
225 | if not invalid:
226 | cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
227 | rank_idx += 1
228 | if rank_idx > topk:
229 | break
230 |
231 | def cp_img_to(src, dst, rank, prefix):
232 | """
233 | - src: image path or tuple (for vidreid)
234 | - dst: target directory
235 | - rank: int, denoting ranked position, starting from 1
236 | - prefix: string
237 | """
238 | if isinstance(src, tuple) or isinstance(src, list):
239 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
240 | mkdir_if_missing(dst)
241 | for img_path in src:
242 | shutil.copy(img_path, dst)
243 | else:
244 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
245 | shutil.copy(src, dst)
--------------------------------------------------------------------------------
/util/eval_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | import torch
4 | import copy
5 | import os.path as osp
6 | from collections import defaultdict
7 | from opts import market1501_test_map, duke_test_map
8 | import sys
9 |
10 | def make_results(qf, gf, lqf, lgf, q_pids, g_pids, q_camids, g_camids, targetmodel, ak_typ, attr_matrix=None, dataset_name=None, attr=None):
11 | qf, gf = featureNormalization(qf, gf, targetmodel)
12 | m, n = qf.size(0), gf.size(0)
13 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
14 | distmat.addmm_(1, -2, qf, gf.t())
15 | distmat = distmat.numpy()
16 |
17 | if targetmodel == 'aligned':
18 | from .distance import low_memory_local_dist
19 | lqf, lgf = lqf.permute(0,2,1), lgf.permute(0,2,1)
20 | local_distmat = low_memory_local_dist(lqf.numpy(),lgf.numpy(), aligned=True)
21 | distmat = local_distmat+distmat
22 |
23 | if ak_typ > 0:
24 | distmat, all_hit, ignore_list = evaluate_attr(distmat, q_pids, g_pids, attr_matrix, dataset_name, attr)
25 | return distmat, all_hit, ignore_list
26 | else:
27 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, use_metric_cuhk03=False)
28 | return distmat, cmc, mAP
29 |
30 | def featureNormalization(qf, gf, targetmodel):
31 | if targetmodel in ['aligned', 'densenet121', 'hacnn', 'mudeep', 'ide', 'cam', 'lsro', 'hhl', 'spgan']:
32 | qf = 1. * qf / (torch.norm(qf, p=2, dim=-1, keepdim=True).expand_as(qf) + 1e-12)
33 | gf = 1. * gf / (torch.norm(gf, p=2, dim=-1, keepdim=True).expand_as(gf) + 1e-12)
34 |
35 | elif targetmodel in ['pcb']:
36 | qf = (qf / (np.sqrt(6) * torch.norm(qf, p=2, dim=1, keepdim=True).expand_as(qf))).view(qf.size(0), -1)
37 | gf = (gf / (np.sqrt(6) * torch.norm(gf, p=2, dim=1, keepdim=True).expand_as(gf))).view(gf.size(0), -1)
38 |
39 | return qf, gf
40 |
41 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, N=100):
42 | """Evaluation with cuhk03 metric
43 | Key: one image for each gallery identity is randomly sampled for each query identity.
44 | Random sampling is performed N times (default: N=100).
45 | """
46 | num_q, num_g = distmat.shape
47 | if num_g < max_rank:
48 | max_rank = num_g
49 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
50 | indices = np.argsort(distmat, axis=1)
51 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
52 |
53 | # compute cmc curve for each query
54 | all_cmc = []
55 | all_AP = []
56 | num_valid_q = 0. # number of valid query
57 | for q_idx in range(num_q):
58 | # get query pid and camid
59 | q_pid = q_pids[q_idx]
60 | q_camid = q_camids[q_idx]
61 |
62 | # remove gallery samples that have the same pid and camid with query
63 | order = indices[q_idx]
64 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
65 | keep = np.invert(remove)
66 |
67 | # compute cmc curve
68 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
69 | if not np.any(orig_cmc):
70 | # this condition is true when query identity does not appear in gallery
71 | continue
72 |
73 | kept_g_pids = g_pids[order][keep]
74 | g_pids_dict = defaultdict(list)
75 | for idx, pid in enumerate(kept_g_pids):
76 | g_pids_dict[pid].append(idx)
77 |
78 | cmc, AP = 0., 0.
79 | for repeat_idx in range(N):
80 | mask = np.zeros(len(orig_cmc), dtype=np.bool)
81 | for _, idxs in g_pids_dict.items():
82 | # randomly sample one image for each gallery person
83 | rnd_idx = np.random.choice(idxs)
84 | mask[rnd_idx] = True
85 | masked_orig_cmc = orig_cmc[mask]
86 | _cmc = masked_orig_cmc.cumsum()
87 | _cmc[_cmc > 1] = 1
88 | cmc += _cmc[:max_rank].astype(np.float32)
89 | # compute AP
90 | num_rel = masked_orig_cmc.sum()
91 | tmp_cmc = masked_orig_cmc.cumsum()
92 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
93 | tmp_cmc = np.asarray(tmp_cmc) * masked_orig_cmc
94 | AP += tmp_cmc.sum() / num_rel
95 | cmc /= N
96 | AP /= N
97 | all_cmc.append(cmc)
98 | all_AP.append(AP)
99 | num_valid_q += 1.
100 |
101 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
102 |
103 | all_cmc = np.asarray(all_cmc).astype(np.float32)
104 | all_cmc = all_cmc.sum(0) / num_valid_q
105 | mAP = np.mean(all_AP)
106 |
107 | return all_cmc, mAP
108 |
109 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
110 | """Evaluation with market1501 metric
111 | Key: for each query identity, its gallery images from the same camera view are discarded.
112 | """
113 | num_q, num_g = distmat.shape
114 | if num_g < max_rank:
115 | max_rank = num_g
116 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
117 | indices = np.argsort(distmat, axis=1)
118 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
119 |
120 | # compute cmc curve for each query
121 | all_cmc = []
122 | all_AP = []
123 | num_valid_q = 0. # number of valid query
124 | for q_idx in range(num_q):
125 | # get query pid and camid
126 | q_pid = q_pids[q_idx]
127 | q_camid = q_camids[q_idx]
128 |
129 | # remove gallery samples that have the same pid and camid with query
130 | order = indices[q_idx]
131 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
132 | keep = np.invert(remove)
133 |
134 | # compute cmc curve
135 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
136 | if not np.any(orig_cmc):
137 | # this condition is true when query identity does not appear in gallery
138 | continue
139 |
140 | cmc = orig_cmc.cumsum()
141 | cmc[cmc > 1] = 1
142 |
143 | all_cmc.append(cmc[:max_rank])
144 | num_valid_q += 1.
145 |
146 | # compute average precision
147 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
148 | num_rel = orig_cmc.sum()
149 | tmp_cmc = orig_cmc.cumsum()
150 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
151 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
152 | AP = tmp_cmc.sum() / num_rel
153 | all_AP.append(AP)
154 |
155 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
156 |
157 | all_cmc = np.asarray(all_cmc).astype(np.float32)
158 | all_cmc = all_cmc.sum(0) / num_valid_q
159 | mAP = np.mean(all_AP)
160 |
161 | return all_cmc, mAP
162 |
163 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20, use_metric_cuhk03=False):
164 | if use_metric_cuhk03: return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
165 | else: return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
166 |
167 | def evaluate_attr(distmat, q_pids, g_pids, attr_matrix, dataset_name, attr_list, max_rank=20):
168 | attr_key, attr_value = attr_list
169 | attr_name = 'duke_attribute' if dataset_name == 'dukemtmcreid' else 'market_attribute'
170 | offset = 0 if dataset_name == 'dukemtmcreid' else 1
171 | mapping = duke_test_map if dataset_name == 'dukemtmcreid' else market1501_test_map
172 | column = attr_matrix[attr_name][0]['test'][0][0][attr_key][0][0]
173 |
174 | num_q, num_g = distmat.shape
175 | indices = np.argsort(distmat, axis=1)
176 |
177 | all_hit = []
178 | ignore_list = []
179 | num_valid_q = 0. # number of valid query
180 | for q_idx in range(num_q):
181 | q_pid = q_pids[q_idx]
182 | if column[mapping[q_pid]-offset] == attr_value:
183 | ignore_list.append(q_idx)
184 | continue
185 |
186 | order = indices[q_idx]
187 | matches = np.zeros_like(order)
188 |
189 | for i in range(len(order)):
190 | if column[mapping[g_pids[order[i]]]-offset] == attr_value:
191 | matches[i] = 1
192 |
193 | hit = matches.cumsum()
194 | hit[hit > 1] = 1
195 | all_hit.append(hit[:max_rank])
196 | num_valid_q += 1. # number of valid query
197 |
198 | assert num_valid_q > 0
199 | all_hit = np.asarray(all_hit).astype(np.float32)
200 | all_hit = all_hit.sum(0) / num_valid_q
201 |
202 | # distmat = np.delete(distmat, ignore_list, axis=0)
203 |
204 | return distmat, all_hit, ignore_list
--------------------------------------------------------------------------------
/models/HACNN.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import torchvision
7 |
8 | __all__ = ['HACNN']
9 |
10 | class ConvBlock(nn.Module):
11 | """Basic convolutional block:
12 | convolution + batch normalization + relu.
13 |
14 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
15 | in_c (int): number of input channels.
16 | out_c (int): number of output channels.
17 | k (int or tuple): kernel size.
18 | s (int or tuple): stride.
19 | p (int or tuple): padding.
20 | """
21 | def __init__(self, in_c, out_c, k, s=1, p=0):
22 | super(ConvBlock, self).__init__()
23 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
24 | self.bn = nn.BatchNorm2d(out_c)
25 |
26 | def forward(self, x):
27 | return F.relu(self.bn(self.conv(x)))
28 |
29 | class InceptionA(nn.Module):
30 | """
31 | Args:
32 | in_channels (int): number of input channels
33 | out_channels (int): number of output channels AFTER concatenation
34 | """
35 | def __init__(self, in_channels, out_channels):
36 | super(InceptionA, self).__init__()
37 | single_out_channels = out_channels // 4
38 |
39 | self.stream1 = nn.Sequential(
40 | ConvBlock(in_channels, single_out_channels, 1),
41 | ConvBlock(single_out_channels, single_out_channels, 3, p=1),
42 | )
43 | self.stream2 = nn.Sequential(
44 | ConvBlock(in_channels, single_out_channels, 1),
45 | ConvBlock(single_out_channels, single_out_channels, 3, p=1),
46 | )
47 | self.stream3 = nn.Sequential(
48 | ConvBlock(in_channels, single_out_channels, 1),
49 | ConvBlock(single_out_channels, single_out_channels, 3, p=1),
50 | )
51 | self.stream4 = nn.Sequential(
52 | nn.AvgPool2d(3, stride=1, padding=1),
53 | ConvBlock(in_channels, single_out_channels, 1),
54 | )
55 |
56 | def forward(self, x):
57 | s1 = self.stream1(x)
58 | s2 = self.stream2(x)
59 | s3 = self.stream3(x)
60 | s4 = self.stream4(x)
61 | y = torch.cat([s1, s2, s3, s4], dim=1)
62 | return y
63 |
64 | class InceptionB(nn.Module):
65 | """
66 | Args:
67 | in_channels (int): number of input channels
68 | out_channels (int): number of output channels AFTER concatenation
69 | """
70 | def __init__(self, in_channels, out_channels):
71 | super(InceptionB, self).__init__()
72 | single_out_channels = out_channels // 4
73 |
74 | self.stream1 = nn.Sequential(
75 | ConvBlock(in_channels, single_out_channels, 1),
76 | ConvBlock(single_out_channels, single_out_channels, 3, s=2, p=1),
77 | )
78 | self.stream2 = nn.Sequential(
79 | ConvBlock(in_channels, single_out_channels, 1),
80 | ConvBlock(single_out_channels, single_out_channels, 3, p=1),
81 | ConvBlock(single_out_channels, single_out_channels, 3, s=2, p=1),
82 | )
83 | self.stream3 = nn.Sequential(
84 | nn.MaxPool2d(3, stride=2, padding=1),
85 | ConvBlock(in_channels, single_out_channels*2, 1),
86 | )
87 |
88 | def forward(self, x):
89 | s1 = self.stream1(x)
90 | s2 = self.stream2(x)
91 | s3 = self.stream3(x)
92 | y = torch.cat([s1, s2, s3], dim=1)
93 | return y
94 |
95 | class SpatialAttn(nn.Module):
96 | """Spatial Attention (Sec. 3.1.I.1)"""
97 | def __init__(self):
98 | super(SpatialAttn, self).__init__()
99 | self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
100 | self.conv2 = ConvBlock(1, 1, 1)
101 |
102 | def forward(self, x):
103 | # global cross-channel averaging
104 | x = x.mean(1, keepdim=True)
105 | # 3-by-3 conv
106 | x = self.conv1(x)
107 | # bilinear resizing
108 | x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
109 | # scaling conv
110 | x = self.conv2(x)
111 | return x
112 |
113 | class ChannelAttn(nn.Module):
114 | """Channel Attention (Sec. 3.1.I.2)"""
115 | def __init__(self, in_channels, reduction_rate=16):
116 | super(ChannelAttn, self).__init__()
117 | assert in_channels%reduction_rate == 0
118 | self.conv1 = ConvBlock(in_channels, in_channels//reduction_rate, 1)
119 | self.conv2 = ConvBlock(in_channels//reduction_rate, in_channels, 1)
120 |
121 | def forward(self, x):
122 | # squeeze operation (global average pooling)
123 | x = F.avg_pool2d(x, x.size()[2:])
124 | # excitation operation (2 conv layers)
125 | x = self.conv1(x)
126 | x = self.conv2(x)
127 | return x
128 |
129 | class SoftAttn(nn.Module):
130 | """Soft Attention (Sec. 3.1.I)
131 | Aim: Spatial Attention + Channel Attention
132 | Output: attention maps with shape identical to input.
133 | """
134 | def __init__(self, in_channels):
135 | super(SoftAttn, self).__init__()
136 | self.spatial_attn = SpatialAttn()
137 | self.channel_attn = ChannelAttn(in_channels)
138 | self.conv = ConvBlock(in_channels, in_channels, 1)
139 |
140 | def forward(self, x):
141 | y_spatial = self.spatial_attn(x)
142 | y_channel = self.channel_attn(x)
143 | y = y_spatial * y_channel
144 | y = F.sigmoid(self.conv(y))
145 | return y
146 |
147 | class HardAttn(nn.Module):
148 | """Hard Attention (Sec. 3.1.II)"""
149 | def __init__(self, in_channels):
150 | super(HardAttn, self).__init__()
151 | self.fc = nn.Linear(in_channels, 4*2)
152 | self.init_params()
153 |
154 | def init_params(self):
155 | self.fc.weight.data.zero_()
156 | self.fc.bias.data.copy_(torch.tensor([0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float))
157 |
158 | def forward(self, x):
159 | # squeeze operation (global average pooling)
160 | x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
161 | # predict transformation parameters
162 | theta = F.tanh(self.fc(x))
163 | theta = theta.view(-1, 4, 2)
164 | return theta
165 |
166 | class HarmAttn(nn.Module):
167 | """Harmonious Attention (Sec. 3.1)"""
168 | def __init__(self, in_channels):
169 | super(HarmAttn, self).__init__()
170 | self.soft_attn = SoftAttn(in_channels)
171 | self.hard_attn = HardAttn(in_channels)
172 |
173 | def forward(self, x):
174 | y_soft_attn = self.soft_attn(x)
175 | theta = self.hard_attn(x)
176 | return y_soft_attn, theta
177 |
178 | class HACNN(nn.Module):
179 | """
180 | Harmonious Attention Convolutional Neural Network
181 |
182 | Reference:
183 | Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
184 |
185 | Args:
186 | num_classes (int): number of classes to predict
187 | nchannels (list): number of channels AFTER concatenation
188 | feat_dim (int): feature dimension for a single stream
189 | learn_region (bool): whether to learn region features (i.e. local branch)
190 | """
191 | def __init__(self, num_classes, loss={'xent', 'htri'}, nchannels=[128, 256, 384], feat_dim=512, learn_region=True, use_gpu=True, **kwargs):
192 | super(HACNN, self).__init__()
193 | self.loss = loss
194 | self.learn_region = learn_region
195 | self.use_gpu = use_gpu
196 |
197 | self.conv = ConvBlock(3, 32, 3, s=2, p=1)
198 |
199 | # Construct Inception + HarmAttn blocks
200 | # ============== Block 1 ==============
201 | self.inception1 = nn.Sequential(
202 | InceptionA(32, nchannels[0]),
203 | InceptionB(nchannels[0], nchannels[0]),
204 | )
205 | self.ha1 = HarmAttn(nchannels[0])
206 |
207 | # ============== Block 2 ==============
208 | self.inception2 = nn.Sequential(
209 | InceptionA(nchannels[0], nchannels[1]),
210 | InceptionB(nchannels[1], nchannels[1]),
211 | )
212 | self.ha2 = HarmAttn(nchannels[1])
213 |
214 | # ============== Block 3 ==============
215 | self.inception3 = nn.Sequential(
216 | InceptionA(nchannels[1], nchannels[2]),
217 | InceptionB(nchannels[2], nchannels[2]),
218 | )
219 | self.ha3 = HarmAttn(nchannels[2])
220 |
221 | self.fc_global = nn.Sequential(
222 | nn.Linear(nchannels[2], feat_dim),
223 | nn.BatchNorm1d(feat_dim),
224 | nn.ReLU(),
225 | )
226 | self.classifier_global = nn.Linear(feat_dim, num_classes)
227 |
228 | if self.learn_region:
229 | self.init_scale_factors()
230 | self.local_conv1 = InceptionB(32, nchannels[0])
231 | self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
232 | self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
233 | self.fc_local = nn.Sequential(
234 | nn.Linear(nchannels[2]*4, feat_dim),
235 | nn.BatchNorm1d(feat_dim),
236 | nn.ReLU(),
237 | )
238 | self.classifier_local = nn.Linear(feat_dim, num_classes)
239 | self.feat_dim = feat_dim * 2
240 | else:
241 | self.feat_dim = feat_dim
242 |
243 | def init_scale_factors(self):
244 | # initialize scale factors (s_w, s_h) for four regions
245 | self.scale_factors = []
246 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
247 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
248 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
249 | self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
250 |
251 | def stn(self, x, theta):
252 | """Perform spatial transform
253 | x: (batch, channel, height, width)
254 | theta: (batch, 2, 3)
255 | """
256 | grid = F.affine_grid(theta, x.size())
257 | x = F.grid_sample(x, grid)
258 | return x
259 |
260 | def transform_theta(self, theta_i, region_idx):
261 | """Transform theta to include (s_w, s_h),
262 | resulting in (batch, 2, 3)"""
263 | scale_factors = self.scale_factors[region_idx]
264 | theta = torch.zeros(theta_i.size(0), 2, 3)
265 | theta[:,:,:2] = scale_factors
266 | theta[:,:,-1] = theta_i
267 | if self.use_gpu: theta = theta.cuda()
268 | return theta
269 |
270 | def forward(self, x, is_training):
271 | assert x.size(2) == 160 and x.size(3) == 64, \
272 | "Input size does not match, expected (160, 64) but got ({}, {})".format(x.size(2), x.size(3))
273 | x = self.conv(x)
274 |
275 | # ============== Block 1 ==============
276 | # global branch
277 | x1 = self.inception1(x)
278 | x1_attn, x1_theta = self.ha1(x1)
279 | x1_out = x1 * x1_attn
280 | # local branch
281 | if self.learn_region:
282 | x1_local_list = []
283 | for region_idx in range(4):
284 | x1_theta_i = x1_theta[:,region_idx,:]
285 | x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
286 | x1_trans_i = self.stn(x, x1_theta_i)
287 | x1_trans_i = F.upsample(x1_trans_i, (24, 28), mode='bilinear', align_corners=True)
288 | x1_local_i = self.local_conv1(x1_trans_i)
289 | x1_local_list.append(x1_local_i)
290 |
291 | # ============== Block 2 ==============
292 | # Block 2
293 | # global branch
294 | x2 = self.inception2(x1_out)
295 | x2_attn, x2_theta = self.ha2(x2)
296 | x2_out = x2 * x2_attn
297 | # local branch
298 | if self.learn_region:
299 | x2_local_list = []
300 | for region_idx in range(4):
301 | x2_theta_i = x2_theta[:,region_idx,:]
302 | x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
303 | x2_trans_i = self.stn(x1_out, x2_theta_i)
304 | x2_trans_i = F.upsample(x2_trans_i, (12, 14), mode='bilinear', align_corners=True)
305 | x2_local_i = x2_trans_i + x1_local_list[region_idx]
306 | x2_local_i = self.local_conv2(x2_local_i)
307 | x2_local_list.append(x2_local_i)
308 |
309 | # ============== Block 3 ==============
310 | # Block 3
311 | # global branch
312 | x3 = self.inception3(x2_out)
313 | x3_attn, x3_theta = self.ha3(x3)
314 | x3_out = x3 * x3_attn
315 | # local branch
316 | if self.learn_region:
317 | x3_local_list = []
318 | for region_idx in range(4):
319 | x3_theta_i = x3_theta[:,region_idx,:]
320 | x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
321 | x3_trans_i = self.stn(x2_out, x3_theta_i)
322 | x3_trans_i = F.upsample(x3_trans_i, (6, 7), mode='bilinear', align_corners=True)
323 | x3_local_i = x3_trans_i + x2_local_list[region_idx]
324 | x3_local_i = self.local_conv3(x3_local_i)
325 | x3_local_list.append(x3_local_i)
326 |
327 | # ============== Feature generation ==============
328 | # global branch
329 | x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
330 | x_global = self.fc_global(x_global)
331 | # local branch
332 | if self.learn_region:
333 | x_local_list = []
334 | for region_idx in range(4):
335 | x_local_i = x3_local_list[region_idx]
336 | x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(x_local_i.size(0), -1)
337 | x_local_list.append(x_local_i)
338 | x_local = torch.cat(x_local_list, 1)
339 | x_local = self.fc_local(x_local)
340 |
341 | if not is_training:
342 | # l2 normalization before concatenation
343 | if self.learn_region:
344 | x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
345 | x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
346 | return [torch.cat([x_global, x_local], 1)]
347 | else:
348 | return [x_global]
349 |
350 | prelogits_global = self.classifier_global(x_global)
351 | if self.learn_region:
352 | prelogits_local = self.classifier_local(x_local)
353 |
354 | if self.loss == {'xent'}:
355 | if self.learn_region:
356 | return [prelogits_global, prelogits_local]
357 | else:
358 | return [prelogits_global]
359 | elif self.loss == {'xent', 'htri'}:
360 | if self.learn_region:
361 | return [(prelogits_global, prelogits_local), (x_global, x_local)]
362 | else:
363 | return [prelogits_global, x_global]
364 | else:
365 | raise KeyError("Unsupported loss: {}".format(self.loss))
--------------------------------------------------------------------------------
/GD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 | import functools
6 | from torch.autograd import Variable
7 | from torch.optim import lr_scheduler
8 | from util.spectral import SpectralNorm
9 | from util.gumbel import gumbel_softmax
10 | import numpy as np
11 | import math
12 |
13 | class Pat_Discriminator(nn.Module):
14 | """
15 | Defines a PatchGAN discriminator
16 | Code based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
17 | """
18 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='bn'):
19 | """Construct a PatchGAN discriminator
20 | Parameters:
21 | input_nc (int) -- the number of channels in input images
22 | ndf (int) -- the number of filters in the last conv layer
23 | n_layers (int) -- the number of conv layers in the discriminator
24 |             norm (str)      -- type of normalization layer, 'bn' or 'in'
25 | """
26 | super(Pat_Discriminator, self).__init__()
27 |
28 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
29 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
30 | use_bias = norm_layer.func != nn.BatchNorm2d
31 | else:
32 | use_bias = norm_layer != nn.BatchNorm2d
33 |
34 | kw = 4
35 | padw = 1
36 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
37 | nf_mult = 1
38 | nf_mult_prev = 1
39 | for n in range(1, n_layers): # gradually increase the number of filters
40 | nf_mult_prev = nf_mult
41 | nf_mult = min(2 ** n, 8)
42 | sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)]
43 |
44 | nf_mult_prev = nf_mult
45 | nf_mult = min(2 ** n_layers, 8)
46 | sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True)]
47 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
48 | self.model = nn.Sequential(*sequence)
49 |
50 | def forward(self, x):
51 | return self.model(x), torch.ones_like(x)
52 |
53 |
54 | class MS_Discriminator(nn.Module):
55 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='bn', num_D=3, temperature=-1, use_gumbel=False):
56 | super(MS_Discriminator, self).__init__()
57 | self.num_D = num_D
58 | self.n_layers = n_layers
59 | self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, count_include_pad=False)
60 | self.same0 = SamePadding(kernel_size=3, stride=2)
61 | self.same1 = SamePadding(kernel_size=4, stride=2)
62 | self.same2 = SamePadding(kernel_size=4, stride=1)
63 | self.Mask = Mask(norm, temperature, use_gumbel)
64 |
65 | for i in range(num_D):
66 | netD = sub_Discriminator(input_nc, ndf, n_layers, norm)
67 | for j in range(n_layers+2): setattr(self, 'D'+str(i)+'_layer'+str(j), getattr(netD, 'layer'+str(j)))
68 |
69 | def single_forward(self, model, x):
70 | result = [x]
71 | for i in range(len(model)):
72 | samepadding = self.same1 if i < len(model)-2 else self.same2
73 | result.append(model[i](samepadding(result[-1])))
74 | return result[1:]
75 |
76 | def forward(self, x):
77 | num_D = self.num_D
78 | proposal = []
79 | result = []
80 | mask = None
81 | input_downsampled = x
82 | for i in range(num_D):
83 | model = [getattr(self, 'D'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
84 | proposal.append(self.single_forward(model, input_downsampled)) #[[D2L0, D2L1,..., D2L4],[D1L0,...,D1L4],[D0L0,...,D0L4]]
85 | if i != (num_D-1): input_downsampled = self.downsample(self.same0(input_downsampled))
86 | for i in proposal: result.append(i[-1])
87 | mask = self.Mask(x, proposal)
88 | return result, mask
89 |
90 | # (64,128,256,512,1)
91 | class sub_Discriminator(nn.Module):
92 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='in'):
93 | super(sub_Discriminator, self).__init__()
94 | self.n_layers = n_layers
95 |
96 | use_bias = norm == 'in'
97 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
98 | sequence = [[SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, bias=use_bias)), nn.LeakyReLU(0.2, True)]]
99 | nf = ndf
100 | for n in range(1, n_layers):
101 | nf_prev = nf
102 | nf = min(nf*2, 512)
103 | sequence += [[SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=4, stride=2, bias=use_bias)), norm_layer(nf), nn.LeakyReLU(0.2, True)]]
104 |
105 | nf_prev = nf
106 | nf = min(nf*2, 512)
107 | sequence += [[SpectralNorm(nn.Conv2d(nf_prev, nf, kernel_size=4, stride=1, bias=use_bias)), norm_layer(nf), nn.LeakyReLU(0.2, True)]]
108 | sequence += [[nn.Conv2d(nf, 1, kernel_size=4, stride=1)]]
109 |
110 | for n in range(len(sequence)):
111 | setattr(self, 'layer'+str(n), nn.Sequential(*sequence[n]))
112 |
113 | def forward(self, input):
114 | res = [input]
115 | for n in range(self.n_layers+2):
116 | model = getattr(self, 'layer'+str(n))
117 | res.append(model(res[-1]))
118 | return res[1:]
119 |
120 | class Mask(nn.Module):
121 | def __init__(self, norm, temperature, use_gumbel, fused=1):
122 | super(Mask, self).__init__()
123 | self.temperature = temperature
124 | self.use_gumbel = use_gumbel
125 | self.fused = fused
126 | self.T = nn.Parameter(torch.Tensor([1]))
127 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
128 | small_channels = [512, 512, 256, 128]
129 | big_channels = [512+128, 512+128+64, 128+64, 64] if self.fused == 2 else [512, 512, 128, 64]
130 |
131 | self.up32_16 = UpLayer(big_channels=big_channels[0], out_channels=512, small_channels=small_channels[0], norm_layer=norm_layer)
132 | self.up16_8 = UpLayer(big_channels=big_channels[1], out_channels=256, small_channels=small_channels[1], norm_layer=norm_layer)
133 | self.up8_4 = UpLayer(big_channels=big_channels[2], out_channels=128, small_channels=small_channels[2], norm_layer=norm_layer)
134 | # self.up4_2 = UpLayer(big_channels=big_channels[3], out_channels=64, small_channels=small_channels[3], norm_layer=norm_layer)
135 | self.deconv1 = nn.Sequential(*[SpectralNorm(nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)), nn.LeakyReLU(0.2, True)])
136 | self.deconv2 = nn.Sequential(*[SpectralNorm(nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)), nn.LeakyReLU(0.2, True)])
137 | self.conv2 = nn.Sequential(*[nn.Conv2d(128, 1, kernel_size=1, stride=1)])
138 | self.logsoftmax = nn.LogSoftmax(dim=1)
139 |
140 | def forward(self, x, proposal):
141 | n,c,h,w = x.size()
142 | if self.temperature == -1: return torch.ones((n,1,h,w)).cuda()
143 | scale32 = proposal[2][3]
144 | scale16 = torch.cat((proposal[2][1], proposal[1][3]),1) if self.fused == 2 else proposal[1][3]
145 | scale8 = torch.cat((proposal[0][3], proposal[1][1], proposal[2][0]),1) if self.fused == 2 else proposal[0][3]
146 | scale4 = torch.cat((proposal[0][1], proposal[1][0]),1) if self.fused == 2 else proposal[0][1]
147 | scale2 = proposal[0][0]
148 | out = self.up32_16(scale32, scale16)
149 | out = self.up16_8(out, scale8)
150 | out = self.up8_4(out, scale4)
151 | # out = self.up4_2(out, scale2)
152 | out = self.deconv1(out)
153 | out = self.deconv2(out)
154 | out = self.conv2(out)
155 |
156 | if not self.use_gumbel:
157 | logits = self.logsoftmax(out.view(n, -1))
158 | th, _ = torch.topk(logits, k=int(self.temperature), dim=1, largest=True)
159 | mask, zeros, ones = torch.zeros_like(logits).cuda(), torch.zeros(h*w).cuda(), torch.ones(h*w).cuda()
160 | for i in range(n):
161 | mask[i,:] = torch.where(logits[i,:]>=th[i, int(self.temperature)-1], ones, zeros)
162 | mask = mask.view(n, 1, h, w)
163 | elif self.use_gumbel:
164 | logits = gumbel_softmax(out.view(n, -1), k=int(self.temperature), T=self.T, hard=True, eps=1e-10).view(n, 1, h, w)
165 | mask = logits.cuda()
166 | # logits = F.gumbel_softmax(out.view(n, -1), tau=self.temperature).view(n, 1, h, w)
167 | # # logits_normed = torch.clamp((logits_normed+1e-4), min=0, max=1)
168 | # logits = np.minimum(1.0, logits.data.cpu().numpy()*(h*w)+1e-4)
169 | # mask = torch.bernoulli(torch.from_numpy(logits)).cuda()
170 | return mask
171 |
172 | class UpLayer(nn.Module):
173 | def __init__(self, big_channels, out_channels, small_channels, norm_layer):
174 | super(UpLayer, self).__init__()
175 | self.big_channels = big_channels
176 | self.out_channels = out_channels
177 | self.small_channels = small_channels
178 | self.conv1 = nn.Sequential(*[SpectralNorm(nn.Conv2d(self.big_channels, self.small_channels, kernel_size=1, stride=1)), norm_layer(self.small_channels), nn.LeakyReLU(0.2, True)])
179 | self.conv2 = nn.Sequential(*[SpectralNorm(nn.Conv2d(self.small_channels, self.out_channels, kernel_size=3, stride=1, padding=1)), norm_layer(self.out_channels), nn.LeakyReLU(0.2, True)])
180 |
181 | def forward(self, small, big):
182 |         small = F.interpolate(small, size=(big.size(2), big.size(3)), mode='bilinear')  # F.upsample is deprecated; interpolate behaves the same here
183 | big = self.conv1(big)
184 | out = self.conv2(big+small)
185 | return out
186 |
187 | class Generator(nn.Module):
188 | def __init__(self, input_nc, output_nc, ngf, norm='bn', n_blocks=6):
189 | super(Generator, self).__init__()
190 |
191 | n_downsampling = n_upsampling = 2
192 | use_bias = norm == 'in'
193 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
194 | begin_layers, down_layers, res_layers, up_layers, end_layers = [], [], [], [], []
195 | for i in range(n_upsampling):
196 | up_layers.append([])
197 | # ngf
198 | begin_layers = [nn.ReflectionPad2d(3), SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)), norm_layer(ngf), nn.ReLU(True)]
199 | # 2ngf, 4ngf
200 | for i in range(n_downsampling):
201 | mult = 2**i
202 | down_layers += [SpectralNorm(nn.Conv2d(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, padding=1, bias=use_bias)), norm_layer(ngf*mult*2), nn.ReLU(True)]
203 | # 4ngf
204 | mult = 2**n_downsampling
205 | for i in range(n_blocks):
206 | res_layers += [ResnetBlock(ngf*mult, norm_layer, use_bias)]
207 | # 2ngf, ngf
208 | for i in range(n_upsampling):
209 | mult = 2**(n_upsampling - i)
210 | up_layers[i] += [SpectralNorm(nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias)), norm_layer(int(ngf*mult/2)), nn.ReLU(True)]
211 |
212 | end_layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
213 |
214 | self.l1 = nn.Sequential(*begin_layers)
215 | self.l2 = nn.Sequential(*down_layers)
216 | self.l3 = nn.Sequential(*res_layers)
217 | self.l4_1 = nn.Sequential(*up_layers[0])
218 | self.l4_2 = nn.Sequential(*up_layers[1])
219 | self.l5 = nn.Sequential(*end_layers)
220 |
221 | def forward(self, inputs):
222 | out = self.l1(inputs)
223 | out = self.l2(out)
224 | out = self.l3(out)
225 | out = self.l4_1(out)
226 | out = self.l4_2(out)
227 | out = self.l5(out)
228 | return out
229 |
230 | class ResnetG(nn.Module):
231 | def __init__(self, input_nc, output_nc, ngf, norm='bn', n_blocks=6):
232 | super(ResnetG, self).__init__()
233 |
234 | n_downsampling = n_upsampling = 2
235 | use_bias = norm == 'in'
236 | norm_layer = nn.BatchNorm2d if norm == 'bn' else nn.InstanceNorm2d
237 | begin_layers, down_layers, res_layers, up_layers, end_layers = [], [], [], [], []
238 | for i in range(n_upsampling):
239 | up_layers.append([])
240 | # ngf
241 | begin_layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)]
242 | # 2ngf, 4ngf
243 | for i in range(n_downsampling):
244 | mult = 2**i
245 | down_layers += [nn.Conv2d(ngf*mult, ngf*mult*2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf*mult*2), nn.ReLU(True)]
246 | # 4ngf
247 | mult = 2**n_downsampling
248 | for i in range(n_blocks):
249 | res_layers += [ResnetBlock(ngf*mult, norm_layer, use_bias)]
250 | # 2ngf, ngf
251 | for i in range(n_upsampling):
252 | mult = 2**(n_upsampling - i)
253 | up_layers[i] += [nn.ConvTranspose2d(ngf*mult, int(ngf*mult/2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), norm_layer(int(ngf*mult/2)), nn.ReLU(True)]
254 |
255 | end_layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
256 |
257 | self.l1 = nn.Sequential(*begin_layers)
258 | self.l2 = nn.Sequential(*down_layers)
259 | self.l3 = nn.Sequential(*res_layers)
260 | self.l4_1 = nn.Sequential(*up_layers[0])
261 | self.l4_2 = nn.Sequential(*up_layers[1])
262 | self.l5 = nn.Sequential(*end_layers)
263 |
264 | def forward(self, inputs):
265 | out = self.l1(inputs)
266 | out = self.l2(out)
267 | out = self.l3(out)
268 | out = self.l4_1(out)
269 | out = self.l4_2(out)
270 | out = self.l5(out)
271 | return out
272 |
273 | # Define a resnet block
274 | class ResnetBlock(nn.Module):
275 | def __init__(self, dim, norm_layer, use_bias):
276 | super(ResnetBlock, self).__init__()
277 | self.conv_block = self.build_conv_block(dim, norm_layer, use_bias)
278 |
279 | def build_conv_block(self, dim, norm_layer, use_bias):
280 | conv_block = []
281 | for i in range(2):
282 | conv_block += [nn.ReflectionPad2d(1)]
283 | conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)), norm_layer(dim)]
284 | if i < 1:
285 | conv_block += [nn.ReLU(True)]
286 | return nn.Sequential(*conv_block)
287 |
288 | def forward(self, x):
289 | out = x + self.conv_block(x)
290 | return out
291 |
292 | class SamePadding(nn.Module):
293 | def __init__(self, kernel_size, stride):
294 | super(SamePadding, self).__init__()
295 | self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
296 | self.stride = torch.nn.modules.utils._pair(stride)
297 |
298 |     def forward(self, input):
299 |         # NCHW layout: dim 2 is height, dim 3 is width
300 |         in_height = input.size(2)
301 |         in_width = input.size(3)
302 |         out_height = math.ceil(float(in_height) / float(self.stride[0]))
303 |         out_width = math.ceil(float(in_width) / float(self.stride[1]))
304 |         pad_along_height = (out_height - 1) * self.stride[0] + self.kernel_size[0] - in_height
305 |         pad_along_width = (out_width - 1) * self.stride[1] + self.kernel_size[1] - in_width
306 |         pad_top = int(pad_along_height // 2)
307 |         pad_left = int(pad_along_width // 2)
308 |         pad_bottom = pad_along_height - pad_top
309 |         pad_right = pad_along_width - pad_left
310 |         # F.pad expects (left, right, top, bottom) for 4D input
311 |         return F.pad(input, (int(pad_left), int(pad_right), int(pad_top), int(pad_bottom)), 'constant', 0)
312 |
313 | def __repr__(self):
314 | return self.__class__.__name__
315 |
316 | def weights_init(m):
317 | classname = m.__class__.__name__
318 | # print(dir(m))
319 | if classname.find('Conv') != -1:
320 | if 'weight' in dir(m):
321 | m.weight.data.normal_(0.0, 1)
322 | elif classname.find('BatchNorm2d') != -1:
323 | m.weight.data.normal_(1.0, 0.02)
324 | m.bias.data.fill_(0)
325 |
326 | class GANLoss(nn.Module):
327 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.cuda.FloatTensor):
328 | super(GANLoss, self).__init__()
329 | self.real_label = target_real_label
330 | self.fake_label = target_fake_label
331 | self.real_label_var = None
332 | self.fake_label_var = None
333 | self.Tensor = tensor
334 | if use_lsgan: self.loss = nn.MSELoss()
335 | else: self.loss = nn.BCELoss()
336 |
337 | def get_target_tensor(self, input, target_is_real):
338 | target_tensor = None
339 | if target_is_real:
340 | create_label = ((self.real_label_var is None) or
341 | (self.real_label_var.numel() != input.numel()))
342 | if create_label:
343 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
344 | self.real_label_var = Variable(real_tensor, requires_grad=False)
345 | target_tensor = self.real_label_var
346 | else:
347 | create_label = ((self.fake_label_var is None) or
348 | (self.fake_label_var.numel() != input.numel()))
349 | if create_label:
350 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
351 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
352 | target_tensor = self.fake_label_var
353 | return target_tensor
354 |
355 | def __call__(self, input, target_is_real):
356 | if isinstance(input[0], list):
357 | loss = 0
358 | for input_i in input:
359 | pred = input_i[-1]
360 | target_tensor = self.get_target_tensor(pred, target_is_real)
361 | loss += self.loss(pred, target_tensor)
362 | return loss
363 | else:
364 | target_tensor = self.get_target_tensor(input[-1], target_is_real)
365 | return self.loss(input[-1], target_tensor)
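366 | 
367 | # ----------------------------------------------------------------------
368 | # Illustrative shape-check sketch (added for clarity; not part of the
369 | # original file). It assumes a CUDA device, since the Mask module above
370 | # hard-codes .cuda(), and it uses 256x128 person crops with the
371 | # repository's default channel counts.
372 | if __name__ == '__main__':
373 |     imgs = torch.randn(2, 3, 256, 128).cuda()                # fake batch of person crops
374 |     G = Generator(3, 3, ngf=32, norm='bn').cuda()            # perturbation generator (Tanh output in [-1, 1])
375 |     D = MS_Discriminator(input_nc=6, norm='bn', temperature=-1).cuda()
376 |     delta = G(imgs)                                          # same spatial size as the input
377 |     preds, mask = D(torch.cat((imgs, imgs + delta), 1))      # num_D patch maps + an (N,1,H,W) mask
378 |     print(delta.shape, mask.shape, len(preds))               # -> (2,3,256,128), (2,1,256,128), 3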
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function, division
3 | import sys
4 | import time
5 | import datetime
6 | import argparse
7 | import os
8 | import numpy as np
9 | import os.path as osp
10 | import math
11 | from random import sample
12 | from scipy import io
13 |
14 | import torchvision
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | from torch.utils.data import DataLoader
19 | import torch.backends.cudnn as cudnn
20 |
21 | import models
22 | from models.PCB import PCB_test
23 | # from ReID_attr import get_target_withattr # Need Attribute file
24 | from opts import get_opts, Imagenet_mean, Imagenet_stddev
25 | from GD import Generator, MS_Discriminator, Pat_Discriminator, GANLoss, weights_init
26 | from advloss import DeepSupervision, adv_CrossEntropyLoss, adv_CrossEntropyLabelSmooth, adv_TripletLoss
27 | from util import data_manager
28 | from util.dataset_loader import ImageDataset
29 | from util.utils import fliplr, Logger, save_checkpoint, visualize_ranked_results
30 | from util.eval_metrics import make_results
31 | from util.samplers import RandomIdentitySampler, AttrPool
32 |
33 | # Training settings
34 | parser = argparse.ArgumentParser(description='adversarial attack')
35 | parser.add_argument('--root', type=str, default='data', help="root path to data directory")
36 | parser.add_argument('--targetmodel', type=str, default='aligned', choices=models.get_names())
37 | parser.add_argument('--dataset', type=str, default='market1501', choices=data_manager.get_names())
38 | # PATH
39 | parser.add_argument('--G_resume_dir', type=str, default='', help='path to resume G')
40 | parser.add_argument('--pre_dir', type=str, default='models', help='path to the model to be attacked')
41 | parser.add_argument('--attr_dir', type=str, default='', help='path to attribute file')
42 | parser.add_argument('--save_dir', type=str, default='logs', help='path to save model')
43 | parser.add_argument('--vis_dir', type=str, default='vis', help='path to save visualization result')
44 | parser.add_argument('--ablation', type=str, default='', help='for ablation study')
45 | # var
46 | parser.add_argument('--mode', type=str, default='train', help='train/test')
47 | parser.add_argument('--D', type=str, default='MSGAN', choices=['PatchGAN', 'MSGAN'], help='type of discriminator: PatchGAN or multi-stage MSGAN')
48 | parser.add_argument('--normalization', type=str, default='bn', help='bn or in')
49 | parser.add_argument('--loss', type=str, default='xent_htri', choices=['cent', 'xent', 'htri', 'xent_htri'])
50 | parser.add_argument('--ak_type', type=int, default=-1, help='-1 if non-targeted, 1 if attribute attack')
51 | parser.add_argument('--attr_key', type=str, default='upwhite', help='[attribute, value]')
52 | parser.add_argument('--attr_value', type=int, default=2, help='[attribute, value]')
53 | parser.add_argument('--mag_in', type=float, default=16.0, help='l_inf magnitude of perturbation')
54 | parser.add_argument('--temperature', type=float, default=-1, help="tau in paper")
55 | parser.add_argument('--usegumbel', action='store_true', default=False, help='whether to use gumbel softmax')
56 | parser.add_argument('--use_SSIM', type=int, default=2, help="0: None, 1: SSIM, 2: MS-SSIM ")
57 | # Base
58 | parser.add_argument('--train_batch', default=32, type=int,help="train batch size")
59 | parser.add_argument('--test_batch', default=32, type=int, help="test batch size")
60 | parser.add_argument('--epoch', type=int, default=50, help='number of epochs to train for')
61 |
62 | parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
63 | parser.add_argument('--num_ker', type=int, default=32, help='generator filters in first conv layer')
64 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate. Default=0.0002')
65 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
66 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
67 | parser.add_argument('--print_freq', type=int, default=20, help="print frequency")
68 | parser.add_argument('--eval_freq', type=int, default=1, help="eval frequency")
69 | parser.add_argument('--usevis', action='store_true', default=False, help='whether to save vis')
70 |
71 | args = parser.parse_args()
72 | is_training = args.mode == 'train'
73 | attr_list = [args.attr_key, args.attr_value]
74 | attr_matrix = None
75 | if args.attr_dir:
76 | assert args.dataset in ['dukemtmcreid', 'market1501']
77 | attr_matrix = io.loadmat(args.attr_dir)
78 | args.ablation = osp.join('attr', args.attr_key + '=' + str(args.attr_value))
79 |
80 | pre_dir = osp.join(args.pre_dir, args.targetmodel, args.dataset+'.pth.tar')
81 | save_dir = osp.join(args.save_dir, args.targetmodel, args.dataset, args.ablation)
82 | vis_dir = osp.join(args.vis_dir, args.targetmodel, args.dataset, args.ablation)
83 |
84 |
85 | def main(opt):
86 | if not osp.exists(save_dir): os.makedirs(save_dir)
87 | if not osp.exists(vis_dir): os.makedirs(vis_dir)
88 |
89 | use_gpu = torch.cuda.is_available()
90 | pin_memory = True if use_gpu else False
91 |
92 | if args.mode == 'train':
93 | sys.stdout = Logger(osp.join(save_dir, 'log_train.txt'))
94 | else:
95 | sys.stdout = Logger(osp.join(save_dir, 'log_test.txt'))
96 | print("==========\nArgs:{}\n==========".format(args))
97 |
98 | if use_gpu:
99 | print("GPU mode")
100 | cudnn.benchmark = True
101 | torch.cuda.manual_seed(args.seed)
102 | else:
103 | print("CPU mode")
104 |
105 | ### Setup dataset loader ###
106 | print("Initializing dataset {}".format(args.dataset))
107 | dataset = data_manager.init_img_dataset(root=args.root, name=args.dataset, split_id=opt['split_id'], cuhk03_labeled=opt['cuhk03_labeled'], cuhk03_classic_split=opt['cuhk03_classic_split'])
108 | if args.ak_type < 0:
109 | trainloader = DataLoader(ImageDataset(dataset.train, transform=opt['transform_train']), sampler=RandomIdentitySampler(dataset.train, num_instances=opt['num_instances']), batch_size=args.train_batch, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=True)
110 | elif args.ak_type > 0:
111 | trainloader = DataLoader(ImageDataset(dataset.train, transform=opt['transform_train']), sampler=AttrPool(dataset.train, args.dataset, attr_matrix, attr_list, sample_num=16), batch_size=args.train_batch, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=True)
112 | queryloader = DataLoader(ImageDataset(dataset.query, transform=opt['transform_test']), batch_size=args.test_batch, shuffle=False, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=False)
113 | galleryloader = DataLoader(ImageDataset(dataset.gallery, transform=opt['transform_test']), batch_size=args.test_batch, shuffle=False, num_workers=opt['workers'], pin_memory=pin_memory, drop_last=False)
114 |
115 | ### Prepare criterion ###
116 | if args.ak_type<0:
117 | clf_criterion = adv_CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids, use_gpu=use_gpu) if args.loss in ['xent', 'xent_htri'] else adv_CrossEntropyLoss(use_gpu=use_gpu)
118 | else:
119 | clf_criterion = nn.MultiLabelSoftMarginLoss()
120 | metric_criterion = adv_TripletLoss(margin=args.margin, ak_type=args.ak_type)
121 | criterionGAN = GANLoss()
122 |
123 | ### Prepare pretrained model ###
124 | target_net = models.init_model(name=args.targetmodel, pre_dir=pre_dir, num_classes=dataset.num_train_pids)
125 | check_freezen(target_net, need_modified=True, after_modified=False)
126 |
127 | ### Prepare main net ###
128 | G = Generator(3, 3, args.num_ker, norm=args.normalization).apply(weights_init)
129 | if args.D == 'PatchGAN':
130 | D = Pat_Discriminator(input_nc=6, norm=args.normalization).apply(weights_init)
131 | elif args.D == 'MSGAN':
132 | D = MS_Discriminator(input_nc=6, norm=args.normalization, temperature=args.temperature, use_gumbel=args.usegumbel).apply(weights_init)
133 | check_freezen(G, need_modified=True, after_modified=True)
134 | check_freezen(D, need_modified=True, after_modified=True)
135 | print("Model size: {:.5f}M".format((sum(g.numel() for g in G.parameters())+sum(d.numel() for d in D.parameters()))/1000000.0))
136 | # setup optimizer
137 | optimizer_G = optim.Adam(G.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
138 | optimizer_D = optim.Adam(D.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
139 |
140 | if use_gpu:
141 | test_target_net = nn.DataParallel(target_net).cuda() if not args.targetmodel == 'pcb' else nn.DataParallel(PCB_test(target_net)).cuda()
142 | target_net = nn.DataParallel(target_net).cuda()
143 | G = nn.DataParallel(G).cuda()
144 | D = nn.DataParallel(D).cuda()
145 |
146 | if args.mode == 'test':
147 | epoch = 'test'
148 | test(G, D, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=True)
149 | return 0
150 |
151 | # Ready
152 | start_time = time.time()
153 | train_time = 0
154 | worst_mAP, worst_rank1, worst_rank5, worst_rank10, worst_epoch = np.inf, np.inf, np.inf, np.inf, 0
155 | best_hit, best_epoch = -np.inf, 0
156 | print("==> Start training")
157 |
158 | for epoch in range(1,args.epoch+1):
159 | start_train_time = time.time()
160 | train(epoch, G, D, target_net, criterionGAN, clf_criterion, metric_criterion, optimizer_G, optimizer_D, trainloader, use_gpu)
161 | train_time += round(time.time() - start_train_time)
162 |
163 | if epoch % args.eval_freq == 0:
164 | print("==> Eval at epoch {}".format(epoch))
165 | if args.ak_type < 0:
166 | cmc, mAP = test(G, D, test_target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False)
167 |                 is_worst = cmc[0]<=worst_rank1 and cmc[4]<=worst_rank5 and cmc[9]<=worst_rank10 and mAP<=worst_mAP  # rank-1/5/10 and mAP all at their lowest so far
168 | if is_worst:
169 | worst_mAP, worst_rank1, worst_epoch = mAP, cmc[0], epoch
170 | print("==> Worst_epoch is {}, Worst mAP {:.1%}, Worst rank-1 {:.1%}".format(worst_epoch, worst_mAP, worst_rank1))
171 | save_checkpoint(G.state_dict(), is_worst, 'G', osp.join(save_dir, 'G_ep' + str(epoch) + '.pth.tar'))
172 | save_checkpoint(D.state_dict(), is_worst, 'D', osp.join(save_dir, 'D_ep' + str(epoch) + '.pth.tar'))
173 |
174 | else:
175 | all_hits = test(G, D, target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False)
176 | is_best = all_hits[0]>=best_hit
177 | if is_best:
178 | best_hit, best_epoch = all_hits[0], epoch
179 | print("==> Best_epoch is {}, Best rank-1 {:.1%}".format(best_epoch, best_hit))
180 | save_checkpoint(G.state_dict(), is_best, 'G', osp.join(save_dir, 'G_ep' + str(epoch) + '.pth.tar'))
181 | save_checkpoint(D.state_dict(), is_best, 'D', osp.join(save_dir, 'D_ep' + str(epoch) + '.pth.tar'))
182 |
183 | elapsed = round(time.time() - start_time)
184 | elapsed = str(datetime.timedelta(seconds=elapsed))
185 | train_time = str(datetime.timedelta(seconds=train_time))
186 | print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
187 |
188 | def train(epoch, G, D, target_net, criterionGAN, clf_criterion, metric_criterion, optimizer_G, optimizer_D, trainloader, use_gpu):
189 | G.train()
190 | D.train()
191 | global is_training
192 | is_training = True
193 |
194 | for batch_idx, (imgs, pids, _, pids_raw) in enumerate(trainloader):
195 | if use_gpu:
196 | imgs, pids, pids_raw = imgs.cuda(), pids.cuda(), pids_raw.cuda()
197 |
198 | new_imgs, mask = perturb(imgs, G, D, train_or_test='train')
199 | new_imgs = new_imgs.cuda()
200 | mask = mask.cuda()
201 | # Fake Detection and Loss
202 | pred_fake_pool, _ = D(torch.cat((imgs, new_imgs.detach()), 1))
203 | loss_D_fake = criterionGAN(pred_fake_pool, False)
204 |
205 | # Real Detection and Loss
206 | num = args.train_batch//2
207 | pred_real, _ = D(torch.cat((imgs[0:num,:,:,:], imgs[num:,:,:,:].detach()), 1))
208 | loss_D_real = criterionGAN(pred_real, True)
209 |
210 | # GAN loss (Fake Passability Loss)
211 | pred_fake, _ = D(torch.cat((imgs, new_imgs), 1))
212 | loss_G_GAN = criterionGAN(pred_fake, True)
213 |
214 | # Re-ID advloss
215 | ls = target_net(new_imgs, is_training)
216 | if len(ls) == 1: new_outputs = ls[0]
217 | if len(ls) == 2: new_outputs, new_features = ls
218 | if len(ls) == 3: new_outputs, new_features, new_local_features = ls
219 | xent_loss, global_loss, loss_G_ssim = 0, 0, 0
220 | targets = None
221 |
222 | if args.loss in ['cent', 'xent', 'xent_htri']:
223 | if args.ak_type < 0:
224 | xent_loss = DeepSupervision(clf_criterion, new_outputs, pids) if isinstance(new_features, (tuple, list)) else clf_criterion(new_outputs, pids)
225 |
226 | elif args.ak_type > 0:
227 |                 targets = get_target_withattr(attr_matrix, args.dataset, attr_list, pids, pids_raw).float().cuda()  # needs get_target_withattr from ReID_attr (its import is commented out at the top because it requires the attribute file)
228 |                 xent_loss = 0  # DeepSupervision(clf_criterion, new_outputs, targets) if isinstance(new_features, (tuple, list)) else clf_criterion(new_outputs, targets)
229 |
230 | if args.loss in ['htri', 'xent_htri']:
231 | assert len(ls) >= 2
232 | global_loss = DeepSupervision(metric_criterion, new_features, pids, targets) if isinstance(new_features, (tuple, list)) else metric_criterion(new_features, pids, targets)
233 |
234 | loss_G_ReID = (xent_loss+ global_loss)*opt['ReID_factor']
235 |
236 | # # SSIM loss
237 | if not args.use_SSIM == 0:
238 | from util.ms_ssim import msssim, ssim
239 | loss_func = msssim if args.use_SSIM == 2 else ssim
240 | loss_G_ssim = (1-loss_func(imgs, new_imgs))*0.1
241 |
242 | ############## Forward ###############
243 | loss_D = (loss_D_fake + loss_D_real)/2
244 | loss_G = loss_G_GAN + loss_G_ReID + loss_G_ssim
245 | ############## Backward #############
246 | # update generator weights
247 | optimizer_G.zero_grad()
248 | # loss_G.backward(retain_graph=True)
249 | loss_G.backward()
250 | optimizer_G.step()
251 | # update discriminator weights
252 | optimizer_D.zero_grad()
253 | loss_D.backward()
254 | optimizer_D.step()
255 | if (batch_idx+1) % args.print_freq == 0:
256 | print("===> Epoch[{}]({}/{}) loss_D: {:.4f} loss_G_GAN: {:.4f} loss_G_ReID: {:.4f} loss_G_SSIM: {:.4f}".format(epoch, batch_idx, len(trainloader), loss_D.item(), loss_G_GAN.item(), loss_G_ReID.item(), loss_G_ssim))
257 |
258 | def test(G, D, target_net, dataset, queryloader, galleryloader, epoch, use_gpu, is_test=False, ranks=[1, 5, 10, 20]):
259 | global is_training
260 | is_training = False
261 | if args.mode == 'test' and args.G_resume_dir:
262 | G_resume_dir, D_resume_dir = args.G_resume_dir, args.G_resume_dir.replace('G', 'D')
263 | G_checkpoint, D_checkpoint = torch.load(G_resume_dir), torch.load(D_resume_dir)
264 | G_state_dict = G_checkpoint['state_dict'] if isinstance(G_checkpoint, dict) and 'state_dict' in G_checkpoint else G_checkpoint
265 | D_state_dict = D_checkpoint['state_dict'] if isinstance(D_checkpoint, dict) and 'state_dict' in D_checkpoint else D_checkpoint
266 |
267 | G.load_state_dict(G_state_dict)
268 | D.load_state_dict(D_state_dict)
269 |         print("Successfully loaded {} and {}".format(G_resume_dir, D_resume_dir))
270 |
271 | with torch.no_grad():
272 | qf, lqf, new_qf, new_lqf, q_pids, q_camids = extract_and_perturb(queryloader, G, D, target_net, use_gpu, query_or_gallery='query', is_test=is_test, epoch=epoch)
273 | gf, lgf, g_pids, g_camids = extract_and_perturb(galleryloader, G, D, target_net, use_gpu, query_or_gallery='gallery', is_test=is_test, epoch=epoch)
274 |
275 | if args.ak_type > 0:
276 | distmat, hits, ignore_list = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type, attr_matrix, args.dataset, attr_list)
277 | print("Hits rate, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(ranks[0], hits[ranks[0]-1], ranks[1], hits[ranks[1]-1], ranks[2], hits[ranks[2]-1], ranks[3], hits[ranks[3]-1]))
278 | if not is_test:
279 | return hits
280 |
281 | else:
282 | if is_test:
283 | distmat, cmc, mAP = make_results(qf, gf, lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type)
284 | new_distmat, new_cmc, new_mAP = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type)
285 | print("Results ----------")
286 | print("Before, mAP: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(mAP, ranks[0], cmc[ranks[0]-1], ranks[1], cmc[ranks[1]-1], ranks[2], cmc[ranks[2]-1], ranks[3], cmc[ranks[3]-1]))
287 | print("After , mAP: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(new_mAP, ranks[0], new_cmc[ranks[0]-1], ranks[1], new_cmc[ranks[1]-1], ranks[2], new_cmc[ranks[2]-1], ranks[3], new_cmc[ranks[3]-1]))
288 | if args.usevis:
289 | visualize_ranked_results(distmat, dataset, save_dir=osp.join(vis_dir, 'origin_results'), topk=20)
290 | if args.usevis:
291 | visualize_ranked_results(new_distmat, dataset, save_dir=osp.join(vis_dir, 'polluted_results'), topk=20)
292 | else:
293 | _, new_cmc, new_mAP = make_results(new_qf, gf, new_lqf, lgf, q_pids, g_pids, q_camids, g_camids, args.targetmodel, args.ak_type)
294 | print("mAP: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}, Rank-{}: {:.1%}".format(new_mAP, ranks[0], new_cmc[ranks[0]-1], ranks[1], new_cmc[ranks[1]-1], ranks[2], new_cmc[ranks[2]-1], ranks[3], new_cmc[ranks[3]-1]))
295 | return new_cmc, new_mAP
296 |
297 | def extract_and_perturb(loader, G, D, target_net, use_gpu, query_or_gallery, is_test, epoch):
298 | f, lf, new_f, new_lf, l_pids, l_camids = [], [], [], [], [], []
299 | ave_mask, num = 0, 0
300 | for batch_idx, (imgs, pids, camids, pids_raw) in enumerate(loader):
301 | if use_gpu:
302 | imgs = imgs.cuda()
303 | ls = extract(imgs, target_net)
304 | if len(ls) == 1: features = ls[0]
305 | if len(ls) == 2:
306 | features, local_features = ls
307 | lf.append(local_features.detach().data.cpu())
308 |
309 | f.append(features.detach().data.cpu())
310 | l_pids.extend(pids)
311 | l_camids.extend(camids)
312 |
313 | if query_or_gallery == 'query':
314 | G.eval()
315 | D.eval()
316 | new_imgs, delta, mask = perturb(imgs, G, D, train_or_test='test')
317 | ave_mask += torch.sum(mask.detach()).cpu().numpy()
318 | num += imgs.size(0)
319 |
320 | ls = extract(new_imgs, target_net)
321 | if len(ls) == 1: new_features = ls[0]
322 | if len(ls) == 2:
323 | new_features, new_local_features = ls
324 | new_lf.append(new_local_features.detach().data.cpu())
325 | new_f.append(new_features.detach().data.cpu())
326 |
327 | ls = [imgs, new_imgs, delta, mask]
328 | if is_test:
329 | save_img(ls, pids, camids, epoch, batch_idx)
330 |
331 | f = torch.cat(f, 0)
332 | if not lf == []: lf = torch.cat(lf, 0)
333 | l_pids, l_camids = np.asarray(l_pids), np.asarray(l_camids)
334 |
335 | print("Extracted features for {} set, obtained {}-by-{} matrix".format(query_or_gallery, f.size(0), f.size(1)))
336 | if query_or_gallery == 'gallery':
337 | return [f, lf, l_pids, l_camids]
338 | elif query_or_gallery == 'query':
339 | new_f = torch.cat(new_f, 0)
340 | if not new_lf == []:
341 | new_lf = torch.cat(new_lf, 0)
342 | return [f, lf, new_f, new_lf, l_pids, l_camids]
343 |
344 | def extract(imgs, target_net):
345 | if args.targetmodel in ['pcb', 'lsro']:
346 | ls = [target_net(imgs, is_training)[0] + target_net(fliplr(imgs), is_training)[0]]
347 | else:
348 | ls = target_net(imgs, is_training)
349 | for i in range(len(ls)): ls[i] = ls[i].data.cpu()
350 | return ls
351 |
352 | def perturb(imgs, G, D, train_or_test='test'):
353 | n,c,h,w = imgs.size()
354 | delta = G(imgs)
355 | delta = L_norm(delta, train_or_test)
356 | new_imgs = torch.add(imgs.cuda(), delta[0:imgs.size(0)].cuda())
357 |
358 | _, mask = D(torch.cat((imgs, new_imgs.detach()), 1))
359 | delta = delta * mask
360 | new_imgs = torch.add(imgs.cuda(), delta[0:imgs.size(0)].cuda())
361 |
362 | for c in range(3):
363 |             new_imgs.data[:,c,:,:] = new_imgs.data[:,c,:,:].clamp(imgs.data[:,c,:,:].min(), imgs.data[:,c,:,:].max()) # clamp each channel of the perturbed image to the clean image's value range
364 | if train_or_test == 'train':
365 | return new_imgs, mask
366 | elif train_or_test == 'test':
367 | return new_imgs, delta, mask
368 |
369 | def L_norm(delta, mode='train'):
370 |     delta.data += 1    # generator output is Tanh in [-1, 1];
371 |     delta.data *= 0.5  # shift and scale it into [0, 1]
372 |
373 | for c in range(3):
374 | delta.data[:,c,:,:] = (delta.data[:,c,:,:] - Imagenet_mean[c]) / Imagenet_stddev[c]
375 |
376 | bs = args.train_batch if (mode == 'train') else args.test_batch
377 | for i in range(bs):
378 | # do per channel l_inf normalization
379 | for ci in range(3):
380 | try:
381 | l_inf_channel = delta[i,ci,:,:].data.abs().max()
382 | # l_inf_channel = torch.norm(delta[i,ci,:,:]).data
383 | mag_in_scaled_c = args.mag_in/(255.0*Imagenet_stddev[ci])
384 |                 delta[i,ci,:,:].data *= torch.clamp(mag_in_scaled_c / l_inf_channel, max=1.0)  # shrink the channel only if it exceeds the budget
385 | except IndexError:
386 | break
387 | return delta
388 |
389 | def save_img(ls, pids, camids, epoch, batch_idx):
390 | image, new_image, delta, mask = ls
391 | # undo normalize image color channels
392 | delta_tmp = torch.zeros(delta.size())
393 | for c in range(3):
394 | image.data[:,c,:,:] = (image.data[:,c,:,:] * Imagenet_stddev[c]) + Imagenet_mean[c]
395 | new_image.data[:,c,:,:] = (new_image.data[:,c,:,:] * Imagenet_stddev[c]) + Imagenet_mean[c]
396 | delta_tmp.data[:,c,:,:] = (delta.data[:,c,:,:] * Imagenet_stddev[c]) + Imagenet_mean[c]
397 |
398 | if args.usevis:
399 | torchvision.utils.save_image(image.data, osp.join(vis_dir, 'original_epoch{}_batch{}.png'.format(epoch, batch_idx)))
400 | torchvision.utils.save_image(new_image.data, osp.join(vis_dir, 'polluted_epoch{}_batch{}.png'.format(epoch, batch_idx)))
401 | torchvision.utils.save_image(delta_tmp.data, osp.join(vis_dir, 'delta_epoch{}_batch{}.png'.format(epoch, batch_idx)))
402 | torchvision.utils.save_image(mask.data*255, osp.join(vis_dir, 'mask_epoch{}_batch{}.png'.format(epoch, batch_idx)))
403 |
404 | def check_freezen(net, need_modified=False, after_modified=None):
405 | # print(net)
406 | cc = 0
407 | for child in net.children():
408 | for param in child.parameters():
409 | if need_modified: param.requires_grad = after_modified
410 | # if param.requires_grad: print('child', cc , 'was active')
411 |         # else: print('child', cc, 'was frozen')
412 | cc += 1
413 |
414 | if __name__ == '__main__':
415 | opt = get_opts(args.targetmodel)
416 | main(opt)
417 |
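418 | # ----------------------------------------------------------------------
419 | # Worked example of the per-channel L_inf budget applied in L_norm() above
420 | # (added for clarity; not part of the original script). With the default
421 | # --mag_in 16 and the ImageNet red-channel std 0.229, the cap expressed in
422 | # normalized-image units is
423 | #     mag_in_scaled_c = 16 / (255 * 0.229) ~= 0.274,
424 | # so a channel whose current |delta|_inf is, say, 0.5 gets rescaled by
425 | #     min(1.0, 0.274 / 0.5) ~= 0.55,
426 | # which keeps the perturbation within a 16/255 budget at the pixel level.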
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torchvision
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms as transforms
6 | from util import transforms as T
7 |
8 | Imagenet_mean = [0.485, 0.456, 0.406]
9 | Imagenet_stddev = [0.229, 0.224, 0.225]
10 |
11 | market1501_train_map = {2: 0, 7: 1, 10: 2, 11: 3, 12: 4, 20: 5, 22: 6, 23: 7, 27: 8, 28: 9, 30: 10, 32: 11, 35: 12, 37: 13, 42: 14, 43: 15, 46: 16, 47: 17, 48: 18, 52: 19, 53: 20, 56: 21, 57: 22, 59: 23, 64: 24, 65: 25, 67: 26, 68: 27, 69: 28, 70: 29, 76: 30, 77: 31, 79: 32, 81: 33, 82: 34, 84: 35, 86: 36, 88: 37, 90: 38, 93: 39, 95: 40, 97: 41, 98: 42, 99: 43, 100: 44, 104: 45, 105: 46, 106: 47, 107: 48, 108: 49, 110: 50, 111: 51, 114: 52, 115: 53, 116: 54, 117: 55, 118: 56, 121: 57, 122: 58, 123: 59, 125: 60, 127: 61, 129: 62, 132: 63, 134: 64, 135: 65, 136: 66, 139: 67, 140: 68, 141: 69, 142: 70, 143: 71, 148: 72, 149: 73, 150: 74, 151: 75, 158: 76, 159: 77, 160: 78, 162: 79, 164: 80, 166: 81, 167: 82, 169: 83, 172: 84, 173: 85, 175: 86, 176: 87, 177: 88, 178: 89, 179: 90, 180: 91, 181: 92, 184: 93, 185: 94, 190: 95, 193: 96, 195: 97, 197: 98, 199: 99, 201: 100, 202: 101, 204: 102, 206: 103, 208: 104, 209: 105, 211: 106, 212: 107, 214: 108, 216: 109, 221: 110, 222: 111, 223: 112, 224: 113, 225: 114, 232: 115, 234: 116, 236: 117, 237: 118, 239: 119, 241: 120, 242: 121, 243: 122, 245: 123, 248: 124, 249: 125, 250: 126, 251: 127, 254: 128, 255: 129, 259: 130, 261: 131, 264: 132, 266: 133, 268: 134, 269: 135, 272: 136, 273: 137, 276: 138, 277: 139, 279: 140, 281: 141, 282: 142, 287: 143, 296: 144, 297: 145, 298: 146, 299: 147, 301: 148, 303: 149, 306: 150, 307: 151, 308: 152, 309: 153, 313: 154, 314: 155, 317: 156, 318: 157, 321: 158, 323: 159, 324: 160, 325: 161, 326: 162, 327: 163, 328: 164, 331: 165, 332: 166, 333: 167, 335: 168, 338: 169, 339: 170, 340: 171, 341: 172, 344: 173, 347: 174, 348: 175, 349: 176, 350: 177, 352: 178, 354: 179, 357: 180, 358: 181, 359: 182, 361: 183, 367: 184, 368: 185, 369: 186, 370: 187, 371: 188, 374: 189, 375: 190, 376: 191, 377: 192, 379: 193, 380: 194, 382: 195, 383: 196, 384: 197, 385: 198, 386: 199, 389: 200, 390: 201, 392: 202, 393: 203, 394: 204, 397: 205, 398: 206, 399: 207, 402: 208, 403: 209, 404: 210, 407: 211, 408: 212, 409: 213, 410: 214, 411: 215, 413: 216, 414: 217, 415: 218, 419: 219, 420: 220, 421: 221, 423: 222, 424: 223, 427: 224, 429: 225, 430: 226, 432: 227, 433: 228, 434: 229, 435: 230, 437: 231, 441: 232, 442: 233, 444: 234, 445: 235, 446: 236, 449: 237, 450: 238, 451: 239, 456: 240, 457: 241, 459: 242, 464: 243, 466: 244, 468: 245, 470: 246, 472: 247, 475: 248, 477: 249, 480: 250, 481: 251, 482: 252, 484: 253, 485: 254, 486: 255, 491: 256, 494: 257, 496: 258, 499: 259, 500: 260, 503: 261, 508: 262, 509: 263, 513: 264, 515: 265, 516: 266, 517: 267, 518: 268, 519: 269, 522: 270, 524: 271, 525: 272, 528: 273, 529: 274, 534: 275, 536: 276, 537: 277, 539: 278, 540: 279, 545: 280, 546: 281, 547: 282, 549: 283, 551: 284, 552: 285, 554: 286, 555: 287, 556: 288, 557: 289, 558: 290, 563: 291, 564: 292, 565: 293, 566: 294, 570: 295, 571: 296, 572: 297, 573: 298, 575: 299, 579: 300, 581: 301, 584: 302, 586: 303, 588: 304, 589: 305, 592: 306, 593: 307, 594: 308, 596: 309, 597: 310, 599: 311, 603: 312, 604: 313, 605: 314, 606: 315, 611: 316, 612: 317, 613: 318, 614: 319, 615: 320, 616: 321, 619: 322, 620: 323, 622: 324, 623: 325, 628: 326, 629: 327, 630: 328, 633: 329, 635: 330, 636: 331, 637: 332, 639: 333, 640: 334, 641: 335, 642: 336, 645: 337, 647: 338, 648: 339, 649: 340, 652: 341, 653: 342, 655: 343, 656: 344, 657: 345, 658: 346, 659: 347, 660: 348, 661: 349, 662: 350, 663: 351, 665: 352, 666: 353, 667: 354, 669: 355, 670: 356, 673: 357, 674: 358, 676: 359, 677: 360, 681: 361, 682: 362, 683: 363, 685: 364, 688: 365, 689: 366, 696: 367, 
697: 368, 700: 369, 701: 370, 702: 371, 703: 372, 704: 373, 705: 374, 706: 375, 707: 376, 708: 377, 709: 378, 711: 379, 712: 380, 714: 381, 718: 382, 724: 383, 726: 384, 729: 385, 730: 386, 733: 387, 734: 388, 738: 389, 739: 390, 741: 391, 742: 392, 744: 393, 748: 394, 749: 395, 752: 396, 754: 397, 755: 398, 757: 399, 759: 400, 760: 401, 761: 402, 762: 403, 765: 404, 766: 405, 767: 406, 772: 407, 773: 408, 774: 409, 779: 410, 780: 411, 781: 412, 782: 413, 785: 414, 787: 415, 788: 416, 792: 417, 793: 418, 795: 419, 796: 420, 802: 421, 803: 422, 806: 423, 809: 424, 810: 425, 814: 426, 816: 427, 818: 428, 820: 429, 821: 430, 823: 431, 826: 432, 828: 433, 830: 434, 832: 435, 833: 436, 837: 437, 839: 438, 840: 439, 842: 440, 843: 441, 844: 442, 848: 443, 849: 444, 850: 445, 851: 446, 854: 447, 855: 448, 857: 449, 859: 450, 862: 451, 863: 452, 864: 453, 868: 454, 871: 455, 872: 456, 875: 457, 876: 458, 879: 459, 882: 460, 883: 461, 885: 462, 886: 463, 887: 464, 890: 465, 891: 466, 892: 467, 893: 468, 894: 469, 895: 470, 896: 471, 898: 472, 900: 473, 901: 474, 902: 475, 903: 476, 904: 477, 905: 478, 907: 479, 914: 480, 915: 481, 917: 482, 919: 483, 926: 484, 930: 485, 933: 486, 936: 487, 939: 488, 940: 489, 941: 490, 942: 491, 943: 492, 945: 493, 946: 494, 947: 495, 948: 496, 952: 497, 953: 498, 954: 499, 955: 500, 957: 501, 958: 502, 961: 503, 962: 504, 963: 505, 967: 506, 969: 507, 970: 508, 971: 509, 972: 510, 973: 511, 975: 512, 976: 513, 979: 514, 982: 515, 984: 516, 986: 517, 987: 518, 988: 519, 990: 520, 991: 521, 992: 522, 994: 523, 995: 524, 997: 525, 998: 526, 999: 527, 1000: 528, 1001: 529, 1002: 530, 1003: 531, 1004: 532, 1007: 533, 1010: 534, 1011: 535, 1012: 536, 1017: 537, 1018: 538, 1019: 539, 1023: 540, 1025: 541, 1027: 542, 1030: 543, 1031: 544, 1032: 545, 1033: 546, 1038: 547, 1039: 548, 1041: 549, 1045: 550, 1048: 551, 1049: 552, 1051: 553, 1052: 554, 1055: 555, 1056: 556, 1066: 557, 1071: 558, 1072: 559, 1075: 560, 1076: 561, 1078: 562, 1079: 563, 1080: 564, 1081: 565, 1086: 566, 1088: 567, 1091: 568, 1093: 569, 1094: 570, 1096: 571, 1097: 572, 1098: 573, 1099: 574, 1100: 575, 1101: 576, 1106: 577, 1107: 578, 1110: 579, 1111: 580, 1112: 581, 1113: 582, 1114: 583, 1115: 584, 1116: 585, 1117: 586, 1123: 587, 1124: 588, 1126: 589, 1127: 590, 1129: 591, 1132: 592, 1134: 593, 1135: 594, 1138: 595, 1140: 596, 1142: 597, 1152: 598, 1157: 599, 1158: 600, 1159: 601, 1162: 602, 1165: 603, 1167: 604, 1168: 605, 1169: 606, 1173: 607, 1176: 608, 1177: 609, 1178: 610, 1179: 611, 1189: 612, 1193: 613, 1197: 614, 1198: 615, 1200: 616, 1201: 617, 1204: 618, 1206: 619, 1213: 620, 1217: 621, 1218: 622, 1219: 623, 1220: 624, 1227: 625, 1230: 626, 1231: 627, 1232: 628, 1234: 629, 1235: 630, 1237: 631, 1238: 632, 1240: 633, 1242: 634, 1243: 635, 1244: 636, 1250: 637, 1252: 638, 1253: 639, 1254: 640, 1257: 641, 1258: 642, 1260: 643, 1261: 644, 1263: 645, 1266: 646, 1269: 647, 1275: 648, 1278: 649, 1281: 650, 1286: 651, 1289: 652, 1291: 653, 1292: 654, 1294: 655, 1295: 656, 1296: 657, 1297: 658, 1300: 659, 1303: 660, 1304: 661, 1309: 662, 1313: 663, 1315: 664, 1316: 665, 1318: 666, 1320: 667, 1321: 668, 1325: 669, 1326: 670, 1327: 671, 1330: 672, 1331: 673, 1332: 674, 1334: 675, 1335: 676, 1336: 677, 1338: 678, 1339: 679, 1341: 680, 1343: 681, 1344: 682, 1346: 683, 1350: 684, 1353: 685, 1358: 686, 1363: 687, 1364: 688, 1365: 689, 1368: 690, 1372: 691, 1373: 692, 1379: 693, 1380: 694, 1381: 695, 1385: 696, 1386: 697, 1389: 698, 1391: 699, 1392: 700, 1393: 701, 1400: 702, 1402: 703, 1404: 704, 1405: 
705, 1406: 706, 1407: 707, 1408: 708, 1409: 709, 1411: 710, 1415: 711, 1420: 712, 1421: 713, 1422: 714, 1426: 715, 1427: 716, 1428: 717, 1430: 718, 1432: 719, 1433: 720, 1434: 721, 1437: 722, 1442: 723, 1443: 724, 1445: 725, 1447: 726, 1449: 727, 1451: 728, 1453: 729, 1454: 730, 1455: 731, 1458: 732, 1463: 733, 1464: 734, 1466: 735, 1467: 736, 1469: 737, 1470: 738, 1471: 739, 1473: 740, 1474: 741, 1475: 742, 1479: 743, 1480: 744, 1487: 745, 1489: 746, 1492: 747, 1495: 748, 1496: 749, 1500: 750}
12 | market1501_test_map = {0: 0, 1: 1, 3: 2, 4: 3, 5: 4, 6: 5, 8: 6, 9: 7, 13: 8, 14: 9, 15: 10, 16: 11, 17: 12, 18: 13, 19: 14, 21: 15, 24: 16, 25: 17, 26: 18, 29: 19, 31: 20, 33: 21, 34: 22, 36: 23, 38: 24, 39: 25, 40: 26, 41: 27, 44: 28, 45: 29, 49: 30, 50: 31, 51: 32, 54: 33, 55: 34, 58: 35, 60: 36, 61: 37, 62: 38, 63: 39, 66: 40, 71: 41, 72: 42, 73: 43, 74: 44, 75: 45, 78: 46, 80: 47, 83: 48, 85: 49, 87: 50, 89: 51, 91: 52, 92: 53, 94: 54, 96: 55, 101: 56, 102: 57, 103: 58, 109: 59, 112: 60, 113: 61, 119: 62, 120: 63, 124: 64, 126: 65, 128: 66, 130: 67, 131: 68, 133: 69, 137: 70, 138: 71, 144: 72, 145: 73, 146: 74, 147: 75, 152: 76, 153: 77, 154: 78, 155: 79, 156: 80, 157: 81, 161: 82, 163: 83, 165: 84, 168: 85, 170: 86, 171: 87, 174: 88, 182: 89, 183: 90, 186: 91, 187: 92, 188: 93, 189: 94, 191: 95, 192: 96, 194: 97, 196: 98, 198: 99, 200: 100, 203: 101, 205: 102, 207: 103, 210: 104, 213: 105, 215: 106, 217: 107, 218: 108, 219: 109, 220: 110, 226: 111, 227: 112, 228: 113, 229: 114, 230: 115, 231: 116, 233: 117, 235: 118, 238: 119, 240: 120, 244: 121, 246: 122, 247: 123, 252: 124, 253: 125, 256: 126, 257: 127, 258: 128, 260: 129, 262: 130, 263: 131, 265: 132, 267: 133, 270: 134, 271: 135, 274: 136, 275: 137, 278: 138, 280: 139, 283: 140, 284: 141, 285: 142, 286: 143, 288: 144, 289: 145, 290: 146, 291: 147, 292: 148, 293: 149, 294: 150, 295: 151, 300: 152, 302: 153, 304: 154, 305: 155, 310: 156, 311: 157, 312: 158, 315: 159, 316: 160, 319: 161, 320: 162, 322: 163, 329: 164, 330: 165, 334: 166, 336: 167, 337: 168, 342: 169, 343: 170, 345: 171, 346: 172, 351: 173, 353: 174, 355: 175, 356: 176, 360: 177, 362: 178, 363: 179, 364: 180, 365: 181, 366: 182, 372: 183, 373: 184, 378: 185, 381: 186, 387: 187, 388: 188, 391: 189, 395: 190, 396: 191, 400: 192, 401: 193, 405: 194, 406: 195, 412: 196, 416: 197, 417: 198, 418: 199, 422: 200, 425: 201, 426: 202, 428: 203, 431: 204, 436: 205, 438: 206, 439: 207, 440: 208, 443: 209, 447: 210, 448: 211, 452: 212, 453: 213, 454: 214, 455: 215, 458: 216, 460: 217, 461: 218, 462: 219, 463: 220, 465: 221, 467: 222, 469: 223, 471: 224, 473: 225, 474: 226, 476: 227, 478: 228, 479: 229, 483: 230, 487: 231, 488: 232, 489: 233, 490: 234, 492: 235, 493: 236, 495: 237, 497: 238, 498: 239, 501: 240, 502: 241, 504: 242, 505: 243, 506: 244, 507: 245, 510: 246, 511: 247, 512: 248, 514: 249, 520: 250, 521: 251, 523: 252, 526: 253, 527: 254, 530: 255, 531: 256, 532: 257, 533: 258, 535: 259, 538: 260, 541: 261, 542: 262, 543: 263, 544: 264, 548: 265, 550: 266, 553: 267, 559: 268, 560: 269, 561: 270, 562: 271, 567: 272, 568: 273, 569: 274, 574: 275, 576: 276, 577: 277, 578: 278, 580: 279, 582: 280, 583: 281, 585: 282, 587: 283, 590: 284, 591: 285, 595: 286, 598: 287, 600: 288, 601: 289, 602: 290, 607: 291, 608: 292, 609: 293, 610: 294, 617: 295, 618: 296, 621: 297, 624: 298, 625: 299, 626: 300, 627: 301, 631: 302, 632: 303, 634: 304, 638: 305, 643: 306, 644: 307, 646: 308, 650: 309, 651: 310, 654: 311, 664: 312, 668: 313, 671: 314, 672: 315, 675: 316, 678: 317, 679: 318, 680: 319, 684: 320, 686: 321, 687: 322, 690: 323, 691: 324, 692: 325, 693: 326, 694: 327, 695: 328, 698: 329, 699: 330, 710: 331, 713: 332, 715: 333, 716: 334, 717: 335, 719: 336, 720: 337, 721: 338, 722: 339, 723: 340, 725: 341, 727: 342, 728: 343, 731: 344, 732: 345, 735: 346, 736: 347, 737: 348, 740: 349, 743: 350, 745: 351, 746: 352, 747: 353, 750: 354, 751: 355, 753: 356, 756: 357, 758: 358, 763: 359, 764: 360, 768: 361, 769: 362, 770: 363, 771: 364, 775: 365, 776: 366, 777: 367, 778: 368, 783: 369, 
784: 370, 786: 371, 789: 372, 790: 373, 791: 374, 794: 375, 797: 376, 798: 377, 799: 378, 800: 379, 801: 380, 804: 381, 805: 382, 807: 383, 808: 384, 811: 385, 812: 386, 813: 387, 815: 388, 817: 389, 819: 390, 822: 391, 824: 392, 825: 393, 827: 394, 829: 395, 831: 396, 834: 397, 835: 398, 836: 399, 838: 400, 841: 401, 845: 402, 846: 403, 847: 404, 852: 405, 853: 406, 856: 407, 858: 408, 860: 409, 861: 410, 865: 411, 866: 412, 867: 413, 869: 414, 870: 415, 873: 416, 874: 417, 877: 418, 878: 419, 880: 420, 881: 421, 884: 422, 888: 423, 889: 424, 897: 425, 899: 426, 906: 427, 908: 428, 909: 429, 910: 430, 911: 431, 912: 432, 913: 433, 916: 434, 918: 435, 920: 436, 921: 437, 922: 438, 923: 439, 924: 440, 925: 441, 927: 442, 928: 443, 929: 444, 931: 445, 932: 446, 934: 447, 935: 448, 937: 449, 938: 450, 944: 451, 949: 452, 950: 453, 951: 454, 956: 455, 959: 456, 960: 457, 964: 458, 965: 459, 966: 460, 968: 461, 974: 462, 977: 463, 978: 464, 980: 465, 981: 466, 983: 467, 985: 468, 989: 469, 993: 470, 996: 471, 1005: 472, 1006: 473, 1008: 474, 1009: 475, 1013: 476, 1014: 477, 1015: 478, 1016: 479, 1020: 480, 1021: 481, 1022: 482, 1024: 483, 1026: 484, 1028: 485, 1029: 486, 1034: 487, 1035: 488, 1036: 489, 1037: 490, 1040: 491, 1042: 492, 1043: 493, 1044: 494, 1046: 495, 1047: 496, 1050: 497, 1053: 498, 1054: 499, 1057: 500, 1058: 501, 1059: 502, 1060: 503, 1061: 504, 1062: 505, 1063: 506, 1064: 507, 1065: 508, 1067: 509, 1068: 510, 1069: 511, 1070: 512, 1073: 513, 1074: 514, 1077: 515, 1082: 516, 1083: 517, 1084: 518, 1085: 519, 1087: 520, 1089: 521, 1090: 522, 1092: 523, 1095: 524, 1102: 525, 1103: 526, 1104: 527, 1105: 528, 1108: 529, 1109: 530, 1118: 531, 1119: 532, 1120: 533, 1121: 534, 1122: 535, 1125: 536, 1128: 537, 1130: 538, 1131: 539, 1133: 540, 1136: 541, 1137: 542, 1139: 543, 1141: 544, 1143: 545, 1144: 546, 1145: 547, 1146: 548, 1147: 549, 1148: 550, 1149: 551, 1150: 552, 1151: 553, 1153: 554, 1154: 555, 1155: 556, 1156: 557, 1160: 558, 1161: 559, 1163: 560, 1164: 561, 1166: 562, 1170: 563, 1171: 564, 1172: 565, 1174: 566, 1175: 567, 1180: 568, 1181: 569, 1182: 570, 1183: 571, 1184: 572, 1185: 573, 1186: 574, 1187: 575, 1188: 576, 1190: 577, 1191: 578, 1192: 579, 1194: 580, 1195: 581, 1196: 582, 1199: 583, 1202: 584, 1203: 585, 1205: 586, 1207: 587, 1208: 588, 1209: 589, 1210: 590, 1211: 591, 1212: 592, 1214: 593, 1215: 594, 1216: 595, 1221: 596, 1222: 597, 1223: 598, 1224: 599, 1225: 600, 1226: 601, 1228: 602, 1229: 603, 1233: 604, 1236: 605, 1239: 606, 1241: 607, 1245: 608, 1246: 609, 1247: 610, 1248: 611, 1249: 612, 1251: 613, 1255: 614, 1256: 615, 1259: 616, 1262: 617, 1264: 618, 1265: 619, 1267: 620, 1268: 621, 1270: 622, 1271: 623, 1272: 624, 1273: 625, 1274: 626, 1276: 627, 1277: 628, 1279: 629, 1280: 630, 1282: 631, 1283: 632, 1284: 633, 1285: 634, 1287: 635, 1288: 636, 1290: 637, 1293: 638, 1298: 639, 1299: 640, 1301: 641, 1302: 642, 1305: 643, 1306: 644, 1307: 645, 1308: 646, 1310: 647, 1311: 648, 1312: 649, 1314: 650, 1317: 651, 1319: 652, 1322: 653, 1323: 654, 1324: 655, 1328: 656, 1329: 657, 1333: 658, 1337: 659, 1340: 660, 1342: 661, 1345: 662, 1347: 663, 1348: 664, 1349: 665, 1351: 666, 1352: 667, 1354: 668, 1355: 669, 1356: 670, 1357: 671, 1359: 672, 1360: 673, 1361: 674, 1362: 675, 1366: 676, 1367: 677, 1369: 678, 1370: 679, 1371: 680, 1374: 681, 1375: 682, 1376: 683, 1377: 684, 1378: 685, 1382: 686, 1383: 687, 1384: 688, 1387: 689, 1388: 690, 1390: 691, 1394: 692, 1395: 693, 1396: 694, 1397: 695, 1398: 696, 1399: 697, 1401: 698, 1403: 699, 1410: 700, 1412: 701, 
1413: 702, 1414: 703, 1416: 704, 1417: 705, 1418: 706, 1419: 707, 1423: 708, 1424: 709, 1425: 710, 1429: 711, 1431: 712, 1435: 713, 1436: 714, 1438: 715, 1439: 716, 1440: 717, 1441: 718, 1444: 719, 1446: 720, 1448: 721, 1450: 722, 1452: 723, 1456: 724, 1457: 725, 1459: 726, 1460: 727, 1461: 728, 1462: 729, 1465: 730, 1468: 731, 1472: 732, 1476: 733, 1477: 734, 1478: 735, 1481: 736, 1482: 737, 1483: 738, 1484: 739, 1485: 740, 1486: 741, 1488: 742, 1490: 743, 1491: 744, 1493: 745, 1494: 746, 1497: 747, 1498: 748, 1499: 749, 1501: 750}
13 | duke_train_map = {1: 0, 8: 1, 13: 2, 14: 3, 15: 4, 16: 5, 17: 6, 18: 7, 20: 8, 22: 9, 24: 10, 26: 11, 28: 12, 29: 13, 32: 14, 36: 15, 37: 16, 38: 17, 40: 18, 41: 19, 45: 20, 48: 21, 52: 22, 54: 23, 55: 24, 57: 25, 58: 26, 59: 27, 60: 28, 62: 29, 63: 30, 64: 31, 65: 32, 67: 33, 70: 34, 71: 35, 73: 36, 74: 37, 81: 38, 82: 39, 84: 40, 85: 41, 87: 42, 93: 43, 94: 44, 96: 45, 100: 46, 102: 47, 104: 48, 105: 49, 108: 50, 110: 51, 113: 52, 116: 53, 120: 54, 121: 55, 124: 56, 129: 57, 130: 58, 131: 59, 132: 60, 133: 61, 138: 62, 139: 63, 144: 64, 146: 65, 148: 66, 152: 67, 153: 68, 154: 69, 155: 70, 156: 71, 157: 72, 160: 73, 161: 74, 165: 75, 166: 76, 168: 77, 172: 78, 173: 79, 176: 80, 177: 81, 178: 82, 179: 83, 182: 84, 185: 85, 189: 86, 190: 87, 191: 88, 193: 89, 195: 90, 196: 91, 198: 92, 202: 93, 203: 94, 208: 95, 209: 96, 216: 97, 217: 98, 222: 99, 224: 100, 225: 101, 226: 102, 227: 103, 228: 104, 231: 105, 232: 106, 233: 107, 234: 108, 236: 109, 242: 110, 245: 111, 246: 112, 248: 113, 250: 114, 252: 115, 255: 116, 258: 117, 259: 118, 263: 119, 265: 120, 271: 121, 278: 122, 280: 123, 281: 124, 282: 125, 283: 126, 284: 127, 286: 128, 289: 129, 290: 130, 291: 131, 296: 132, 297: 133, 306: 134, 307: 135, 308: 136, 309: 137, 310: 138, 312: 139, 317: 140, 318: 141, 319: 142, 320: 143, 322: 144, 325: 145, 326: 146, 327: 147, 328: 148, 330: 149, 331: 150, 333: 151, 335: 152, 336: 153, 338: 154, 339: 155, 343: 156, 345: 157, 348: 158, 349: 159, 357: 160, 362: 161, 365: 162, 366: 163, 368: 164, 370: 165, 373: 166, 374: 167, 382: 168, 383: 169, 384: 170, 385: 171, 387: 172, 388: 173, 392: 174, 393: 175, 396: 176, 397: 177, 398: 178, 401: 179, 402: 180, 403: 181, 404: 182, 406: 183, 407: 184, 411: 185, 413: 186, 417: 187, 419: 188, 421: 189, 422: 190, 423: 191, 424: 192, 425: 193, 430: 194, 432: 195, 435: 196, 436: 197, 437: 198, 438: 199, 439: 200, 440: 201, 441: 202, 443: 203, 445: 204, 446: 205, 447: 206, 448: 207, 450: 208, 452: 209, 454: 210, 456: 211, 458: 212, 463: 213, 464: 214, 465: 215, 472: 216, 473: 217, 474: 218, 478: 219, 480: 220, 481: 221, 483: 222, 485: 223, 487: 224, 489: 225, 490: 226, 491: 227, 493: 228, 496: 229, 498: 230, 502: 231, 504: 232, 505: 233, 507: 234, 510: 235, 511: 236, 512: 237, 518: 238, 519: 239, 520: 240, 521: 241, 522: 242, 524: 243, 526: 244, 528: 245, 530: 246, 531: 247, 532: 248, 534: 249, 536: 250, 544: 251, 545: 252, 546: 253, 547: 254, 548: 255, 550: 256, 556: 257, 557: 258, 558: 259, 559: 260, 561: 261, 562: 262, 563: 263, 564: 264, 566: 265, 568: 266, 569: 267, 572: 268, 573: 269, 574: 270, 575: 271, 578: 272, 579: 273, 582: 274, 585: 275, 588: 276, 589: 277, 595: 278, 598: 279, 600: 280, 602: 281, 604: 282, 606: 283, 607: 284, 610: 285, 613: 286, 614: 287, 615: 288, 616: 289, 617: 290, 618: 291, 619: 292, 622: 293, 623: 294, 624: 295, 628: 296, 630: 297, 633: 298, 634: 299, 636: 300, 637: 301, 638: 302, 639: 303, 640: 304, 642: 305, 645: 306, 650: 307, 653: 308, 655: 309, 657: 310, 658: 311, 659: 312, 660: 313, 662: 314, 664: 315, 665: 316, 666: 317, 667: 318, 668: 319, 669: 320, 670: 321, 671: 322, 673: 323, 675: 324, 677: 325, 679: 326, 682: 327, 684: 328, 687: 329, 689: 330, 692: 331, 696: 332, 697: 333, 704: 334, 708: 335, 710: 336, 713: 337, 714: 338, 715: 339, 716: 340, 719: 341, 720: 342, 721: 343, 723: 344, 724: 345, 725: 346, 727: 347, 728: 348, 730: 349, 731: 350, 732: 351, 735: 352, 737: 353, 739: 354, 740: 355, 744: 356, 745: 357, 747: 358, 751: 359, 753: 360, 759: 361, 761: 362, 762: 363, 764: 364, 767: 365, 768: 366, 770: 367, 771: 368, 
774: 369, 776: 370, 778: 371, 779: 372, 780: 373, 782: 374, 783: 375, 784: 376, 785: 377, 789: 378, 793: 379, 795: 380, 796: 381, 797: 382, 798: 383, 799: 384, 802: 385, 805: 386, 808: 387, 811: 388, 813: 389, 814: 390, 815: 391, 817: 392, 819: 393, 821: 394, 825: 395, 829: 396, 831: 397, 835: 398, 836: 399, 837: 400, 839: 401, 842: 402, 843: 403, 844: 404, 848: 405, 855: 406, 859: 407, 860: 408, 883: 409, 1034: 410, 1120: 411, 1174: 412, 1239: 413, 1240: 414, 1242: 415, 1246: 416, 1248: 417, 1252: 418, 1259: 419, 1312: 420, 1333: 421, 1358: 422, 1363: 423, 1396: 424, 1397: 425, 1438: 426, 1471: 427, 1472: 428, 1501: 429, 1524: 430, 1526: 431, 1532: 432, 1542: 433, 1559: 434, 1562: 435, 1565: 436, 1587: 437, 1589: 438, 1614: 439, 1631: 440, 1636: 441, 1665: 442, 1671: 443, 1672: 444, 1693: 445, 1696: 446, 1716: 447, 1729: 448, 1732: 449, 1746: 450, 1756: 451, 1760: 452, 1767: 453, 1776: 454, 1786: 455, 1794: 456, 1812: 457, 1827: 458, 1830: 459, 1874: 460, 1879: 461, 1911: 462, 1953: 463, 1954: 464, 1973: 465, 1988: 466, 1989: 467, 1996: 468, 2004: 469, 2016: 470, 2032: 471, 2036: 472, 2044: 473, 2058: 474, 2408: 475, 2410: 476, 2420: 477, 2421: 478, 2422: 479, 2432: 480, 2435: 481, 2436: 482, 2446: 483, 2464: 484, 2469: 485, 2496: 486, 2515: 487, 2520: 488, 2529: 489, 2542: 490, 2558: 491, 2581: 492, 2597: 493, 2598: 494, 2642: 495, 2726: 496, 2735: 497, 2742: 498, 2748: 499, 2770: 500, 2953: 501, 3058: 502, 3253: 503, 3261: 504, 3344: 505, 3362: 506, 3363: 507, 3368: 508, 3370: 509, 3371: 510, 3451: 511, 3516: 512, 3520: 513, 3545: 514, 3546: 515, 3555: 516, 3582: 517, 3614: 518, 3619: 519, 3621: 520, 3680: 521, 3688: 522, 3715: 523, 3716: 524, 3732: 525, 3753: 526, 3758: 527, 3765: 528, 3776: 529, 3782: 530, 4061: 531, 4063: 532, 4064: 533, 4068: 534, 4076: 535, 4084: 536, 4096: 537, 4104: 538, 4105: 539, 4107: 540, 4108: 541, 4111: 542, 4115: 543, 4120: 544, 4132: 545, 4133: 546, 4135: 547, 4136: 548, 4140: 549, 4145: 550, 4151: 551, 4160: 552, 4164: 553, 4167: 554, 4180: 555, 4184: 556, 4186: 557, 4187: 558, 4192: 559, 4195: 560, 4198: 561, 4199: 562, 4201: 563, 4206: 564, 4208: 565, 4209: 566, 4211: 567, 4212: 568, 4215: 569, 4216: 570, 4225: 571, 4235: 572, 4237: 573, 4238: 574, 4243: 575, 4250: 576, 4258: 577, 4260: 578, 4261: 579, 4263: 580, 4275: 581, 4276: 582, 4277: 583, 4278: 584, 4279: 585, 4286: 586, 4288: 587, 4292: 588, 4301: 589, 4306: 590, 4307: 591, 4317: 592, 4323: 593, 4330: 594, 4333: 595, 4336: 596, 4344: 597, 4355: 598, 4362: 599, 4365: 600, 4387: 601, 4389: 602, 4391: 603, 4393: 604, 4406: 605, 4410: 606, 4412: 607, 4415: 608, 4417: 609, 4423: 610, 4425: 611, 4426: 612, 4430: 613, 4431: 614, 4432: 615, 4438: 616, 4445: 617, 4448: 618, 4451: 619, 4453: 620, 4461: 621, 4462: 622, 4463: 623, 4464: 624, 4472: 625, 4481: 626, 4484: 627, 4487: 628, 4488: 629, 4490: 630, 4492: 631, 4493: 632, 4495: 633, 4499: 634, 4501: 635, 4502: 636, 4509: 637, 4512: 638, 4513: 639, 4515: 640, 4520: 641, 4526: 642, 4527: 643, 4528: 644, 4532: 645, 4537: 646, 4538: 647, 4548: 648, 4551: 649, 4553: 650, 4555: 651, 4556: 652, 4567: 653, 4577: 654, 4583: 655, 4590: 656, 4597: 657, 4602: 658, 4618: 659, 4624: 660, 4625: 661, 4627: 662, 4629: 663, 4631: 664, 4656: 665, 4664: 666, 4667: 667, 4679: 668, 4683: 669, 4684: 670, 4685: 671, 4689: 672, 4690: 673, 4694: 674, 4707: 675, 4721: 676, 4728: 677, 4733: 678, 4740: 679, 4741: 680, 4751: 681, 4767: 682, 4768: 683, 4791: 684, 4796: 685, 4800: 686, 4802: 687, 4805: 688, 4810: 689, 4811: 690, 4812: 691, 4815: 692, 5251: 693, 5254: 694, 5258: 
695, 5259: 696, 5339: 697, 5388: 698, 5398: 699, 7136: 700, 7140: 701}
14 | duke_test_map = {2: 0, 3: 1, 4: 2, 5: 3, 7: 4, 9: 5, 10: 6, 11: 7, 12: 8, 19: 9, 21: 10, 23: 11, 25: 12, 27: 13, 30: 14, 31: 15, 33: 16, 34: 17, 35: 18, 39: 19, 42: 20, 43: 21, 44: 22, 46: 23, 47: 24, 49: 25, 50: 26, 51: 27, 53: 28, 56: 29, 61: 30, 66: 31, 68: 32, 69: 33, 72: 34, 75: 35, 76: 36, 77: 37, 78: 38, 79: 39, 80: 40, 83: 41, 86: 42, 88: 43, 89: 44, 90: 45, 91: 46, 92: 47, 95: 48, 97: 49, 98: 50, 99: 51, 101: 52, 103: 53, 106: 54, 107: 55, 109: 56, 111: 57, 112: 58, 114: 59, 115: 60, 117: 61, 118: 62, 119: 63, 122: 64, 123: 65, 125: 66, 126: 67, 127: 68, 128: 69, 134: 70, 135: 71, 136: 72, 137: 73, 140: 74, 141: 75, 142: 76, 143: 77, 145: 78, 147: 79, 149: 80, 150: 81, 151: 82, 158: 83, 159: 84, 162: 85, 163: 86, 164: 87, 167: 88, 169: 89, 170: 90, 171: 91, 174: 92, 175: 93, 180: 94, 181: 95, 183: 96, 184: 97, 186: 98, 187: 99, 188: 100, 192: 101, 194: 102, 197: 103, 199: 104, 200: 105, 201: 106, 204: 107, 205: 108, 206: 109, 207: 110, 210: 111, 211: 112, 212: 113, 213: 114, 214: 115, 215: 116, 218: 117, 219: 118, 220: 119, 221: 120, 223: 121, 229: 122, 230: 123, 235: 124, 237: 125, 238: 126, 239: 127, 240: 128, 241: 129, 243: 130, 244: 131, 247: 132, 249: 133, 251: 134, 253: 135, 254: 136, 256: 137, 257: 138, 261: 139, 262: 140, 264: 141, 266: 142, 267: 143, 268: 144, 269: 145, 270: 146, 272: 147, 273: 148, 274: 149, 275: 150, 276: 151, 277: 152, 279: 153, 285: 154, 287: 155, 288: 156, 292: 157, 293: 158, 294: 159, 295: 160, 298: 161, 299: 162, 300: 163, 301: 164, 302: 165, 303: 166, 304: 167, 305: 168, 311: 169, 313: 170, 314: 171, 315: 172, 316: 173, 321: 174, 323: 175, 324: 176, 329: 177, 332: 178, 334: 179, 337: 180, 340: 181, 341: 182, 342: 183, 344: 184, 346: 185, 347: 186, 350: 187, 351: 188, 352: 189, 353: 190, 354: 191, 355: 192, 356: 193, 358: 194, 359: 195, 360: 196, 361: 197, 363: 198, 364: 199, 367: 200, 369: 201, 371: 202, 372: 203, 375: 204, 376: 205, 377: 206, 378: 207, 379: 208, 380: 209, 381: 210, 386: 211, 389: 212, 390: 213, 391: 214, 394: 215, 395: 216, 400: 217, 405: 218, 408: 219, 409: 220, 410: 221, 412: 222, 414: 223, 415: 224, 416: 225, 418: 226, 420: 227, 426: 228, 427: 229, 428: 230, 429: 231, 431: 232, 433: 233, 434: 234, 442: 235, 444: 236, 449: 237, 451: 238, 453: 239, 455: 240, 457: 241, 459: 242, 460: 243, 461: 244, 462: 245, 466: 246, 467: 247, 468: 248, 469: 249, 470: 250, 471: 251, 479: 252, 482: 253, 484: 254, 486: 255, 488: 256, 492: 257, 494: 258, 495: 259, 497: 260, 499: 261, 500: 262, 501: 263, 503: 264, 506: 265, 508: 266, 509: 267, 513: 268, 514: 269, 515: 270, 516: 271, 517: 272, 523: 273, 525: 274, 527: 275, 529: 276, 533: 277, 535: 278, 537: 279, 538: 280, 539: 281, 540: 282, 541: 283, 542: 284, 543: 285, 549: 286, 551: 287, 552: 288, 553: 289, 554: 290, 555: 291, 560: 292, 565: 293, 567: 294, 570: 295, 571: 296, 576: 297, 577: 298, 580: 299, 581: 300, 583: 301, 584: 302, 586: 303, 587: 304, 590: 305, 591: 306, 592: 307, 593: 308, 594: 309, 596: 310, 597: 311, 599: 312, 601: 313, 603: 314, 605: 315, 608: 316, 609: 317, 611: 318, 612: 319, 620: 320, 621: 321, 625: 322, 626: 323, 627: 324, 629: 325, 631: 326, 632: 327, 635: 328, 641: 329, 643: 330, 644: 331, 646: 332, 647: 333, 648: 334, 649: 335, 651: 336, 652: 337, 654: 338, 656: 339, 661: 340, 663: 341, 672: 342, 674: 343, 676: 344, 678: 345, 680: 346, 681: 347, 683: 348, 685: 349, 686: 350, 688: 351, 690: 352, 691: 353, 693: 354, 694: 355, 695: 356, 698: 357, 699: 358, 700: 359, 701: 360, 702: 361, 703: 362, 705: 363, 706: 364, 707: 365, 709: 366, 711: 367, 712: 368, 717: 369, 
718: 370, 722: 371, 726: 372, 729: 373, 733: 374, 734: 375, 736: 376, 738: 377, 741: 378, 742: 379, 743: 380, 746: 381, 748: 382, 749: 383, 750: 384, 752: 385, 754: 386, 755: 387, 756: 388, 757: 389, 758: 390, 760: 391, 763: 392, 765: 393, 766: 394, 769: 395, 772: 396, 773: 397, 775: 398, 777: 399, 781: 400, 786: 401, 787: 402, 788: 403, 790: 404, 791: 405, 792: 406, 794: 407, 800: 408, 803: 409, 804: 410, 806: 411, 807: 412, 809: 413, 810: 414, 812: 415, 816: 416, 818: 417, 820: 418, 823: 419, 824: 420, 826: 421, 828: 422, 830: 423, 832: 424, 834: 425, 838: 426, 840: 427, 845: 428, 846: 429, 847: 430, 849: 431, 850: 432, 851: 433, 852: 434, 853: 435, 854: 436, 856: 437, 857: 438, 858: 439, 863: 440, 864: 441, 884: 442, 1104: 443, 1108: 444, 1109: 445, 1110: 446, 1226: 447, 1228: 448, 1229: 449, 1233: 450, 1243: 451, 1244: 452, 1290: 453, 1297: 454, 1300: 455, 1307: 456, 1314: 457, 1328: 458, 1343: 459, 1346: 460, 1366: 461, 1382: 462, 1386: 463, 1391: 464, 1398: 465, 1403: 466, 1408: 467, 1421: 468, 1426: 469, 1440: 470, 1463: 471, 1467: 472, 1480: 473, 1486: 474, 1487: 475, 1489: 476, 1490: 477, 1518: 478, 1555: 479, 1584: 480, 1585: 481, 1586: 482, 1598: 483, 1601: 484, 1626: 485, 1635: 486, 1637: 487, 1642: 488, 1673: 489, 1682: 490, 1698: 491, 1699: 492, 1723: 493, 1724: 494, 1725: 495, 1730: 496, 1737: 497, 1741: 498, 1745: 499, 1749: 500, 1750: 501, 1758: 502, 1759: 503, 1762: 504, 1766: 505, 1775: 506, 1782: 507, 1784: 508, 1785: 509, 1788: 510, 1790: 511, 1811: 512, 1834: 513, 1849: 514, 1893: 515, 1901: 516, 1922: 517, 1946: 518, 1949: 519, 2001: 520, 2012: 521, 2023: 522, 2053: 523, 2407: 524, 2429: 525, 2454: 526, 2470: 527, 2471: 528, 2479: 529, 2488: 530, 2495: 531, 2532: 532, 2556: 533, 2557: 534, 2573: 535, 2599: 536, 2724: 537, 2736: 538, 2754: 539, 2768: 540, 2772: 541, 2777: 542, 2942: 543, 2988: 544, 3201: 545, 3202: 546, 3259: 547, 3335: 548, 3353: 549, 3354: 550, 3358: 551, 3410: 552, 3446: 553, 3495: 554, 3515: 555, 3561: 556, 3609: 557, 3618: 558, 3638: 559, 3649: 560, 3664: 561, 3674: 562, 3731: 563, 3761: 564, 3763: 565, 4055: 566, 4057: 567, 4059: 568, 4060: 569, 4062: 570, 4065: 571, 4066: 572, 4070: 573, 4071: 574, 4072: 575, 4075: 576, 4079: 577, 4082: 578, 4099: 579, 4100: 580, 4102: 581, 4106: 582, 4110: 583, 4113: 584, 4114: 585, 4116: 586, 4117: 587, 4118: 588, 4119: 589, 4121: 590, 4128: 591, 4134: 592, 4141: 593, 4143: 594, 4144: 595, 4146: 596, 4147: 597, 4150: 598, 4152: 599, 4158: 600, 4159: 601, 4163: 602, 4169: 603, 4170: 604, 4174: 605, 4176: 606, 4177: 607, 4178: 608, 4185: 609, 4190: 610, 4197: 611, 4204: 612, 4205: 613, 4207: 614, 4210: 615, 4219: 616, 4221: 617, 4226: 618, 4227: 619, 4228: 620, 4230: 621, 4239: 622, 4245: 623, 4246: 624, 4247: 625, 4249: 626, 4254: 627, 4255: 628, 4256: 629, 4257: 630, 4271: 631, 4272: 632, 4274: 633, 4280: 634, 4284: 635, 4285: 636, 4309: 637, 4310: 638, 4315: 639, 4319: 640, 4321: 641, 4324: 642, 4326: 643, 4329: 644, 4331: 645, 4332: 646, 4334: 647, 4335: 648, 4337: 649, 4341: 650, 4349: 651, 4356: 652, 4361: 653, 4366: 654, 4372: 655, 4373: 656, 4374: 657, 4380: 658, 4386: 659, 4392: 660, 4398: 661, 4405: 662, 4411: 663, 4416: 664, 4419: 665, 4422: 666, 4427: 667, 4428: 668, 4433: 669, 4443: 670, 4447: 671, 4449: 672, 4452: 673, 4459: 674, 4460: 675, 4473: 676, 4477: 677, 4480: 678, 4483: 679, 4489: 680, 4494: 681, 4500: 682, 4503: 683, 4504: 684, 4508: 685, 4510: 686, 4511: 687, 4514: 688, 4519: 689, 4521: 690, 4540: 691, 4541: 692, 4547: 693, 4550: 694, 4558: 695, 4560: 696, 4563: 697, 4568: 698, 4572: 
699, 4573: 700, 4580: 701, 4582: 702, 4587: 703, 4594: 704, 4596: 705, 4605: 706, 4606: 707, 4607: 708, 4609: 709, 4613: 710, 4622: 711, 4632: 712, 4633: 713, 4634: 714, 4639: 715, 4640: 716, 4646: 717, 4647: 718, 4654: 719, 4672: 720, 4681: 721, 4693: 722, 4695: 723, 4699: 724, 4708: 725, 4713: 726, 4717: 727, 4719: 728, 4723: 729, 4725: 730, 4726: 731, 4727: 732, 4729: 733, 4736: 734, 4739: 735, 4743: 736, 4750: 737, 4757: 738, 4758: 739, 4759: 740, 4760: 741, 4769: 742, 4772: 743, 4774: 744, 4779: 745, 4782: 746, 4789: 747, 4790: 748, 4804: 749, 4807: 750, 4808: 751, 4809: 752, 4817: 753, 4823: 754, 5249: 755, 5272: 756, 5333: 757, 5358: 758, 5474: 759, 5587: 760, 5599: 761, 5842: 762, 5849: 763, 5855: 764, 5856: 765, 5860: 766, 5867: 767, 5876: 768, 5877: 769, 5887: 770, 5889: 771, 5899: 772, 5904: 773, 5905: 774, 5906: 775, 5907: 776, 5910: 777, 5911: 778, 5920: 779, 5921: 780, 5922: 781, 5924: 782, 5927: 783, 5937: 784, 5939: 785, 5940: 786, 5941: 787, 5943: 788, 5947: 789, 5948: 790, 5949: 791, 5951: 792, 5952: 793, 5966: 794, 5970: 795, 5971: 796, 5972: 797, 5973: 798, 5974: 799, 5975: 800, 5977: 801, 5982: 802, 5985: 803, 5994: 804, 6008: 805, 6019: 806, 6031: 807, 6040: 808, 6046: 809, 6048: 810, 6049: 811, 6050: 812, 6051: 813, 6054: 814, 6056: 815, 6058: 816, 6059: 817, 6063: 818, 6068: 819, 6070: 820, 6071: 821, 6072: 822, 6073: 823, 6074: 824, 6076: 825, 6077: 826, 6084: 827, 6087: 828, 6088: 829, 6091: 830, 6093: 831, 6094: 832, 6097: 833, 6100: 834, 6101: 835, 6102: 836, 6103: 837, 6105: 838, 6109: 839, 6110: 840, 6111: 841, 6112: 842, 6115: 843, 6117: 844, 6119: 845, 6122: 846, 6123: 847, 6134: 848, 6136: 849, 6137: 850, 6139: 851, 6140: 852, 6143: 853, 6146: 854, 6147: 855, 6148: 856, 6151: 857, 6155: 858, 6156: 859, 6158: 860, 6161: 861, 6164: 862, 6166: 863, 6172: 864, 6176: 865, 6178: 866, 6179: 867, 6180: 868, 6185: 869, 6188: 870, 6189: 871, 6191: 872, 6195: 873, 6196: 874, 6198: 875, 6199: 876, 6202: 877, 6204: 878, 6205: 879, 6208: 880, 6210: 881, 6212: 882, 6213: 883, 6214: 884, 6215: 885, 6216: 886, 6219: 887, 6220: 888, 6223: 889, 6224: 890, 6225: 891, 6227: 892, 6230: 893, 6235: 894, 6236: 895, 6244: 896, 6246: 897, 6247: 898, 6252: 899, 6253: 900, 6255: 901, 6257: 902, 6258: 903, 6259: 904, 6262: 905, 6263: 906, 6264: 907, 6269: 908, 6271: 909, 6277: 910, 6279: 911, 6281: 912, 6285: 913, 6287: 914, 6290: 915, 6291: 916, 6296: 917, 6297: 918, 6299: 919, 6301: 920, 6319: 921, 6320: 922, 6328: 923, 6331: 924, 6337: 925, 6338: 926, 6339: 927, 6340: 928, 6342: 929, 6344: 930, 6345: 931, 6347: 932, 6348: 933, 6351: 934, 6352: 935, 6353: 936, 6355: 937, 6356: 938, 6357: 939, 6359: 940, 6362: 941, 6365: 942, 6366: 943, 6367: 944, 6368: 945, 6369: 946, 6370: 947, 6371: 948, 6376: 949, 6377: 950, 6389: 951, 6391: 952, 6393: 953, 6396: 954, 6397: 955, 6398: 956, 6399: 957, 6400: 958, 6402: 959, 6403: 960, 6406: 961, 6407: 962, 6408: 963, 6410: 964, 6412: 965, 6414: 966, 6415: 967, 6416: 968, 6422: 969, 6423: 970, 6429: 971, 6433: 972, 6439: 973, 6440: 974, 6441: 975, 6446: 976, 6447: 977, 6448: 978, 6449: 979, 6452: 980, 6459: 981, 6464: 982, 6465: 983, 6474: 984, 6476: 985, 6479: 986, 6481: 987, 6482: 988, 6483: 989, 6486: 990, 6489: 991, 6494: 992, 6499: 993, 6500: 994, 6502: 995, 6503: 996, 6504: 997, 6505: 998, 6506: 999, 6507: 1000, 6509: 1001, 6517: 1002, 6522: 1003, 6524: 1004, 6528: 1005, 6530: 1006, 6531: 1007, 6533: 1008, 6535: 1009, 6539: 1010, 6540: 1011, 6543: 1012, 6545: 1013, 6546: 1014, 6547: 1015, 6548: 1016, 6549: 1017, 6550: 1018, 6552: 1019, 6558: 
1020, 6559: 1021, 6566: 1022, 6569: 1023, 6571: 1024, 6577: 1025, 6578: 1026, 6585: 1027, 6586: 1028, 6592: 1029, 6595: 1030, 6596: 1031, 6602: 1032, 6603: 1033, 6605: 1034, 6606: 1035, 6607: 1036, 6609: 1037, 6610: 1038, 6611: 1039, 6614: 1040, 6615: 1041, 6616: 1042, 6617: 1043, 6621: 1044, 6636: 1045, 6637: 1046, 6639: 1047, 6641: 1048, 6648: 1049, 6649: 1050, 6651: 1051, 6660: 1052, 6661: 1053, 6662: 1054, 6665: 1055, 6668: 1056, 6669: 1057, 6670: 1058, 6671: 1059, 6672: 1060, 6673: 1061, 6674: 1062, 6676: 1063, 6679: 1064, 6680: 1065, 6685: 1066, 6686: 1067, 6688: 1068, 6689: 1069, 6690: 1070, 6694: 1071, 6695: 1072, 6697: 1073, 6698: 1074, 6699: 1075, 6700: 1076, 6704: 1077, 6708: 1078, 6709: 1079, 6710: 1080, 6717: 1081, 6722: 1082, 6725: 1083, 6726: 1084, 6732: 1085, 6741: 1086, 6744: 1087, 6745: 1088, 6755: 1089, 6758: 1090, 6759: 1091, 6763: 1092, 6764: 1093, 6767: 1094, 6770: 1095, 6776: 1096, 6777: 1097, 6778: 1098, 6779: 1099, 6785: 1100, 6788: 1101, 6789: 1102, 6794: 1103, 6799: 1104, 6804: 1105, 6805: 1106, 6813: 1107, 7138: 1108, 7139: 1109}
15 |
16 | base_opt = {'workers': 4,
17 | 'split_id': 0,
18 | 'cuhk03_labeled': False,
19 | 'cuhk03_classic_split': False,
20 | 'use_metric_cuhk03': False,
21 | 'num_instances': 4,
22 | 'ReID_factor': 10, }
23 |
24 | def get_opts(name):
25 |     # 1. ide / densenet121 / mudeep
26 | if name == 'ide':
27 | base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0)])
28 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
29 |
30 | elif name == 'densenet121':
31 | base_opt['transform_train'] = T.Compose([T.Random2DTranslation(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
32 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
33 |
34 | elif name == 'mudeep':
35 | base_opt['transform_train'] = T.Compose([T.Random2DTranslation(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
36 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
37 |
38 |     # 2. aligned / pcb / hacnn
39 | elif name == 'aligned':
40 | base_opt['transform_train'] = T.Compose([T.Random2DTranslation(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
41 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
42 |
43 | elif name == 'pcb':
44 | base_opt['transform_train'] = T.Compose([T.Resize((384,192), interpolation=3), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(Imagenet_mean, Imagenet_stddev)])
45 | base_opt['transform_test'] = T.Compose([T.Resize((384,192), interpolation=3), T.ToTensor(), T.Normalize(Imagenet_mean, Imagenet_stddev)])
46 |         base_opt['ReID_factor'] = 2   # pcb-specific overrides of the shared base_opt defaults
47 | base_opt['workers'] = 16
48 |
49 | elif name == 'hacnn':
50 | base_opt['transform_train'] = T.Compose([T.Random2DTranslation(160, 64), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
51 | base_opt['transform_test'] = T.Compose([T.Resize((160, 64)), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
52 |
53 |     # 3. cam / lsro / hhl / spgan
54 | elif name == 'cam':
55 | base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0.5)])
56 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
57 |
58 | elif name == 'lsro':
59 | base_opt['transform_train'] = T.Compose([T.Resize(144, interpolation=3), T.RandomCrop((256,128)), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
60 | base_opt['transform_test'] = T.Compose([T.Resize((288,144), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
61 |
62 | elif name == 'hhl':
63 | base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0)])
64 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
65 |
66 | elif name == 'spgan':
67 | base_opt['transform_train'] = T.Compose([T.RandomSizedRectCrop(256, 128), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev), T.RandomErasing(EPSILON=0)])
68 | base_opt['transform_test'] = T.Compose([T.Resize((256, 128), interpolation=3), T.ToTensor(), T.Normalize(mean=Imagenet_mean, std=Imagenet_stddev)])
69 |
70 | return base_opt
--------------------------------------------------------------------------------
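Usage sketch (not part of the repository): a minimal example of how the options returned by `get_opts` might be used to preprocess a single image. It assumes `opts.py` is importable from the repository root, that Pillow is installed, and that the image path is a placeholder. Note that for names not listed above, `get_opts` returns `base_opt` without any transforms, so indexing `'transform_test'` would raise a `KeyError`.

from PIL import Image

from opts import get_opts  # assumes the repo root is on PYTHONPATH

# Per-model options; e.g. 'pcb' switches to 384x192 inputs and overrides
# 'ReID_factor' and 'workers' in the shared base_opt dict.
opt = get_opts('pcb')
test_tf = opt['transform_test']   # a T.Compose built from util/transforms.py

img = Image.open('query.jpg').convert('RGB')   # placeholder image path
x = test_tf(img)                               # float tensor of shape (3, 384, 192) for pcb
print(x.shape, opt['ReID_factor'], opt['workers'])

Because `get_opts` mutates the module-level `base_opt` rather than copying it, per-model overrides such as pcb's `'workers' = 16` persist across subsequent calls; calling it once per run, for a single target model, avoids surprises.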