├── .gitignore ├── README.md ├── ResNet.py ├── blocks ├── DMLLoss.py ├── __init__.py ├── ms_non_local_block.py └── ms_part_guided_block.py ├── duke.py ├── model.py ├── scripts ├── res50_dpb_softmax.sh ├── res50_latent_softmax.sh └── resnet50_softmax.sh ├── train_duke.py └── utils ├── __init__.py ├── logging.py ├── random_erasing.py ├── sampler.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | __pycache__/ 3 | *.pyc 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # P2Net 2 | Implementation of the ICCV 2019 paper [Beyond Human Parts: Dual Part-Aligned Representations for Person ReID](https://arxiv.org/pdf/1910.10111.pdf) 3 | 4 | The code in this repo reproduces our results on DukeMTMC-reID. 5 | 6 | ## Prerequisites 7 | 8 | - Python 3.6 9 | - GPU Memory >= 6G 10 | - Numpy 11 | - Pytorch >= 0.4 12 | - Torchvision >= 0.2.0 13 | 14 | ## DukeMTMC-reID 15 | 16 | ### Dataset & Preparation 17 | Download the DukeMTMC-reID dataset. 18 | 19 | Preparation: you may need our generated human part masks from [BaiduCloud](https://pan.baidu.com/s/18IIrRSnRN97mC8IShlmXwQ). 20 | Remember to change the dataset path in duke.py to your own path. 21 | 22 | CUHK03 human part masks from [BaiduCloud](https://pan.baidu.com/s/123ps1dHowd_17tL1dyrPjw). pwd: q39a 23 | 24 | Market-1501 human part masks from [BaiduCloud](https://pan.baidu.com/s/1ikHvcjDLEhDqyKsq0c81RA). pwd: uyus 25 | 26 | Generated human part masks from [Google Drive](https://drive.google.com/drive/folders/1iGZkYxoJA7dgmFcGTV6oPltV9vNK7ao_?usp=sharing). 27 | 28 | 29 | ### Train 30 | Train a model by running 31 | ```bash 32 | cd scripts 33 | sh resnet50_softmax.sh 34 | ``` 35 | 36 | ### Results 37 | 38 | This model is based on ResNet-50. Input images are resized to 384x128. 39 | 40 | **Note that these results may be better than those in Table 9 of the paper.
(Setting here is batchsize 48 on 1 GPU)** 41 | 42 | | Method | Rank-1 | Rank-5 | Rank-10 | mAP | Model | 43 | | :----- | :-----: | :-----: | :-----: | :-----: | :-----: | 44 | | Baseline | 81.10 | 89.59 | 92.19 | 64.87 |[BaiduCloud](https://pan.baidu.com/s/1JZ_fHiqXjNDtWearwEIQ3Q) | 45 | |1 x Latent | 82.92 | 91.03 | 93.49 | 67.09 |[BaiduCloud](https://pan.baidu.com/s/1rvPB_-hOB8huqWTJuBDYSw) | 46 | |1 x DPB | 84.83 | 92.28 | 94.08 | 68.62 |[BaiduCloud](https://pan.baidu.com/s/1BSb51t8iIihyzKAyLcOgLQ) | 47 | 48 | ## Citation 49 | ``` 50 | @InProceedings{Guo_2019_ICCV, 51 | author = {Guo, Jianyuan and Yuan, Yuhui and Huang, Lang and Zhang, Chao and Yao, Jin-Ge and Han, Kai}, 52 | title = {Beyond Human Parts: Dual Part-Aligned Representations for Person Re-Identification}, 53 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 54 | month = {October}, 55 | year = {2019} 56 | } 57 | ``` 58 | 59 | ## Acknowledgement 60 | -------------------------------------------------------------------------------- /ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import pdb 6 | import functools 7 | 8 | torch_ver = torch.__version__[:3] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 67 | self.relu = nn.ReLU(inplace=False) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | out = self.conv2(out) 77 | out 
= self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv3(out) 80 | out = self.bn3(out) 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | def __init__(self, block, layers, num_classes=1000): 91 | self.inplanes = 64 92 | super(ResNet, self).__init__() 93 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 94 | bias=False) 95 | self.bn1 = nn.BatchNorm2d(64) 96 | self.relu = nn.ReLU(inplace=False) 97 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 98 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 99 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 101 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 102 | self.avgpool = nn.AvgPool2d(7, stride=1) 103 | self.fc = nn.Linear(512 * block.expansion, num_classes) 104 | 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 108 | elif isinstance(m, nn.BatchNorm2d): 109 | nn.init.constant_(m.weight, 1) 110 | nn.init.constant_(m.bias, 0) 111 | 112 | def _make_layer(self, block, planes, blocks, stride=1): 113 | downsample = None 114 | if stride != 1 or self.inplanes != planes * block.expansion: 115 | downsample = nn.Sequential( 116 | nn.Conv2d(self.inplanes, planes * block.expansion, 117 | kernel_size=1, stride=stride, bias=False), 118 | nn.BatchNorm2d(planes * block.expansion), 119 | ) 120 | 121 | layers = [] 122 | layers.append(block(self.inplanes, planes, stride, downsample)) 123 | self.inplanes = planes * block.expansion 124 | for i in range(1, blocks): 125 | layers.append(block(self.inplanes, planes)) 126 | 127 | return nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | 135 | x = self.layer1(x) 136 | x = self.layer2(x) 137 | x = self.layer3(x) 138 | x = self.layer4(x) 139 | 140 | x = self.avgpool(x) 141 | x = x.view(x.size(0), -1) 142 | x = self.fc(x) 143 | 144 | return x 145 | 146 | 147 | def resnet50(pretrained=False, **kwargs): 148 | """Constructs a ResNet-50 model. 149 | Args: 150 | pretrained (bool): If True, returns a model pre-trained on ImageNet 151 | """ 152 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 153 | if pretrained: 154 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 155 | return model 156 | 157 | def resnet101(pretrained=False, **kwargs): 158 | """Constructs a ResNet-50 model. 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | """ 162 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 163 | if pretrained: 164 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 165 | return model 166 | 167 | def resnet152(pretrained=False, **kwargs): 168 | """Constructs a ResNet-50 model. 
169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 175 | return model -------------------------------------------------------------------------------- /blocks/DMLLoss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: xyliao1993@qq.com 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from __future__ import unicode_literals 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | 17 | 18 | def normalize(x, axis=-1): 19 | """Normalizing to unit length along the specified dimension. 20 | Args: 21 | x: pytorch Variable 22 | Returns: 23 | x: pytorch Variable, same shape as input 24 | """ 25 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 26 | return x 27 | 28 | 29 | def euclidean_dist(x, y): 30 | """ 31 | Args: 32 | x: pytorch Variable, with shape [m, d] 33 | y: pytorch Variable, with shape [n, d] 34 | Returns: 35 | dist: pytorch Variable, with shape [m, n] 36 | """ 37 | m, n = x.size(0), y.size(0) 38 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 39 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 40 | dist = xx + yy 41 | dist.addmm_(1, -2, x, y.t()) 42 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 43 | return dist 44 | 45 | 46 | def hard_example_mining(dist_mat, labels, return_inds=False): 47 | """For each anchor, find the hardest positive and negative sample. 48 | Args: 49 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 50 | labels: pytorch LongTensor, with shape [N] 51 | return_inds: whether to return the indices. Save time if `False`(?) 52 | Returns: 53 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 54 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 55 | p_inds: pytorch LongTensor, with shape [N]; 56 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 57 | n_inds: pytorch LongTensor, with shape [N]; 58 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 59 | NOTE: Only consider the case in which all labels have same num of samples, 60 | thus we can cope with all anchors in parallel. 
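For example, with the identity-balanced batches built by RandomIdentitySampler in this repo (triplet defaults in train_duke.py: batchsize 48, num_instances 8, i.e. 6 identities x 8 images each), every row of dist_mat has exactly 8 positive and 40 negative entries, so the hardest positive/negative per anchor reduces to a single max/min over the reshaped [N, 8] and [N, 40] views below.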
61 | """ 62 | assert len(dist_mat.size()) == 2 63 | assert dist_mat.size(0) == dist_mat.size(1) 64 | N = dist_mat.size(0) 65 | 66 | # shape [N, N] 67 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 68 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 69 | 70 | # `dist_ap` means distance(anchor, positive) 71 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 72 | dist_ap, relative_p_inds = torch.max( 73 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 74 | # `dist_an` means distance(anchor, negative) 75 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 76 | dist_an, relative_n_inds = torch.min( 77 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 78 | # shape [N] 79 | dist_ap = dist_ap.squeeze(1) 80 | dist_an = dist_an.squeeze(1) 81 | 82 | if return_inds: 83 | # shape [N, N] 84 | ind = (labels.new().resize_as_(labels) 85 | .copy_(torch.arange(0, N).long()) 86 | .unsqueeze(0).expand(N, N)) 87 | # shape [N, 1] 88 | p_inds = torch.gather( 89 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 90 | n_inds = torch.gather( 91 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 92 | # shape [N] 93 | p_inds = p_inds.squeeze(1) 94 | n_inds = n_inds.squeeze(1) 95 | return dist_ap, dist_an, p_inds, n_inds 96 | 97 | return dist_ap, dist_an 98 | 99 | 100 | class TripletLoss(object): 101 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 102 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 103 | Loss for Person Re-Identification'.""" 104 | 105 | def __init__(self, margin=None): 106 | self.margin = margin 107 | if margin is not None: 108 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 109 | else: 110 | self.ranking_loss = nn.SoftMarginLoss() 111 | 112 | def __call__(self, global_feat, labels, normalize_feature=False): 113 | if normalize_feature: 114 | global_feat = normalize(global_feat, axis=-1) 115 | dist_mat = euclidean_dist(global_feat, global_feat) 116 | dist_ap, dist_an = hard_example_mining( 117 | dist_mat, labels) 118 | y = dist_an.data.new().resize_as_(dist_an.data).fill_(1) 119 | if self.margin is not None: 120 | loss = self.ranking_loss(dist_an, dist_ap, Variable(y)) 121 | else: 122 | loss = self.ranking_loss(dist_an - dist_ap, Variable(y)) 123 | return loss, dist_ap, dist_an 124 | 125 | 126 | class CrossEntropyLabelSmooth(nn.Module): 127 | """Cross entropy loss with label smoothing regularizer. 128 | Reference: 129 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 130 | Equation: y = (1 - epsilon) * y + epsilon / K. 131 | Args: 132 | num_classes (int): number of classes. 133 | epsilon (float): weight. 
134 | """ 135 | 136 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 137 | super(CrossEntropyLabelSmooth, self).__init__() 138 | self.num_classes = num_classes 139 | self.epsilon = epsilon 140 | self.use_gpu = use_gpu 141 | self.logsoftmax = nn.LogSoftmax(dim=1) 142 | 143 | def forward(self, inputs, targets): 144 | """ 145 | Args: 146 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 147 | targets: ground truth labels with shape (num_classes) 148 | """ 149 | log_probs = self.logsoftmax(inputs) 150 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu().data, 1) 151 | if self.use_gpu: targets = targets.cuda() 152 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 153 | targets = Variable(targets) 154 | loss = (- targets * log_probs).mean(0).sum() 155 | return loss 156 | -------------------------------------------------------------------------------- /blocks/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | from .ms_part_guided_block import * 4 | from .ms_non_local_block import * 5 | from .DMLLoss import * 6 | -------------------------------------------------------------------------------- /blocks/ms_non_local_block.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Human part / Latent part / self-attention 4 | # Edited by Jianyuan Guo 5 | # jyguo@pku.edu.cn 6 | # 2019.10 7 | 8 | import torch 9 | import os 10 | import sys 11 | import pdb 12 | from torch import nn 13 | from torch.nn import functional as F 14 | import functools 15 | from torch.nn import init 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | import random 19 | import numpy as np 20 | import matplotlib.cm as cm 21 | 22 | 23 | def weights_init_kaiming(m): 24 | classname = m.__class__.__name__ 25 | if classname.find('Conv') != -1: 26 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 27 | elif classname.find('Linear') != -1: 28 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 29 | init.constant_(m.bias.data, 0.0) 30 | elif classname.find('BatchNorm1d') != -1: 31 | init.constant_(m.weight.data, 0.0) 32 | init.constant_(m.bias.data, 0.0) 33 | elif classname.find('BatchNorm2d') != -1: 34 | init.constant_(m.weight.data, 0.0) 35 | init.constant_(m.bias.data, 0.0) 36 | 37 | 38 | class _MS_NonLocalBlockND(nn.Module): 39 | ''' 40 | Modified based on the _NonLocalBlockND, compute the affinity based feature maps of two scales. 41 | thus to get the context information of the specified scale. 
42 | Input: 43 | N X C X H X W 44 | Parameters: 45 | in_channels: the dimension of the input feature map 46 | c1 : the dimension of W_theta and W_phi 47 | c2 : the dimension of W_g and W_rho 48 | bn_layer : whether use BN within W_rho 49 | use_g / use_w: whether use the W_g transform and W_rho transform 50 | scale : choose the scale to downsample the input feature maps 51 | Return: 52 | N X C X H X W 53 | ''' 54 | def __init__(self, in_channels, c1, c2, out_channels=None, mode='embedded_gaussian', 55 | sub_sample=False, bn_layer=False, use_g=True, use_w=True, scale=1, vis=False): 56 | super(_MS_NonLocalBlockND, self).__init__() 57 | 58 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation'] 59 | self.vis = vis 60 | self.mode = mode 61 | self.use_g = use_g 62 | self.use_w = use_w 63 | self.scale = scale 64 | self.in_channels = in_channels 65 | self.out_channels = out_channels 66 | self.inter_channels = c1 67 | self.context_channels = in_channels 68 | if use_g: 69 | self.context_channels = c2 70 | if out_channels == None: 71 | self.out_channels = in_channels 72 | 73 | self.pool = nn.AvgPool2d(kernel_size=(scale, scale)) 74 | self.theta = nn.Sequential( 75 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 76 | kernel_size=1, stride=1, padding=0, bias=True) 77 | ) 78 | self.phi = nn.Sequential( 79 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 80 | kernel_size=1, stride=1, padding=0, bias=True) 81 | ) 82 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.context_channels, 83 | kernel_size=1, stride=1, padding=0, bias=True) 84 | self.W = nn.Sequential( 85 | nn.Conv2d(in_channels=self.context_channels, out_channels=self.out_channels, 86 | kernel_size=1, stride=1, padding=0, bias=True), 87 | nn.BatchNorm2d(self.out_channels) 88 | ) 89 | self.theta.apply(weights_init_kaiming) 90 | self.phi.apply(weights_init_kaiming) 91 | self.g.apply(weights_init_kaiming) 92 | self.W.apply(weights_init_kaiming) 93 | 94 | def forward(self, x, path=None): 95 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 96 | x_scale = self.pool(x) 97 | if self.use_g: 98 | g_x = self.g(x_scale).view(batch_size, self.context_channels, -1) 99 | else: 100 | g_x = x_scale.view(batch_size, self.context_channels, -1) 101 | g_x = g_x.permute(0, 2, 1) 102 | 103 | theta_x = self.theta(x_scale).view(batch_size, self.inter_channels, -1) 104 | theta_x = theta_x.permute(0, 2, 1) 105 | phi_x = self.phi(x_scale).view(batch_size, self.inter_channels, -1) 106 | 107 | f = torch.matmul(theta_x, phi_x) 108 | f = (self.inter_channels**-.5) * f 109 | f_div_C = F.softmax(f, dim=-1) 110 | 111 | y = torch.matmul(f_div_C, g_x) 112 | y = y.permute(0, 2, 1).contiguous() 113 | y = y.view(batch_size, self.context_channels, *x.size()[2:]) 114 | 115 | if self.use_w: 116 | W_y = self.W(y) 117 | else: 118 | W_y = y 119 | return W_y 120 | 121 | 122 | class MS_NONLocalBlock2D(_MS_NonLocalBlockND): 123 | def __init__(self, in_channels, c1=None, c2=None, out_channels=None, mode='embedded_gaussian', 124 | bn_layer=False, use_g=True, use_w=True, scale=1, vis=False): 125 | super(MS_NONLocalBlock2D, self).__init__(in_channels, 126 | c1=c1, 127 | c2=c2, 128 | out_channels=out_channels, 129 | mode=mode, 130 | bn_layer=bn_layer, 131 | use_g=use_g, 132 | use_w=use_w, 133 | scale=scale, 134 | vis=vis) 135 | 136 | 137 | class MSPyramidAttentionContextModule(nn.Module): 138 | """ 139 | Parameters: 140 | in_features / out_features: the channels of the input / output feature 
maps. 141 | dropout: specify the dropout ratio 142 | fusion: We provide two different fusion method, "concat" or "add" 143 | sizes: compute the attention based on diverse scales based context 144 | Return: 145 | features after "concat" or "add" 146 | """ 147 | def __init__(self, in_channels, out_channels, c1, c2, dropout=0, fusion="concat", sizes=(1,4,8,16), use_head_bn=False, if_gc=0, vis=False, norm=0): 148 | super(MSPyramidAttentionContextModule, self).__init__() 149 | self.norm = norm 150 | self.if_gc = if_gc 151 | self.fusion = fusion 152 | self.stages = [] 153 | self.group = len(sizes) 154 | self.c1 = c1 155 | self.c2 = c2 156 | self.stages = nn.ModuleList([self._make_stage(in_channels, self.c1, self.c2, in_channels//self.group, size, use_head_bn=use_head_bn, vis=vis) for size in sizes]) 157 | self.bottleneck_add = nn.Sequential( 158 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0), 159 | nn.BatchNorm2d(out_channels), 160 | nn.Dropout2d(dropout) 161 | ) 162 | if self.if_gc == 1: 163 | channel_in = 3 164 | else: 165 | channel_in = 2 166 | self.bottleneck_concat = nn.Sequential( 167 | nn.Conv2d(in_channels*channel_in, out_channels, kernel_size=1, padding=0), 168 | nn.BatchNorm2d(out_channels), 169 | nn.Dropout2d(dropout) 170 | ) 171 | self.bottleneck_add.apply(weights_init_kaiming) 172 | self.bottleneck_concat.apply(weights_init_kaiming) 173 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 174 | 175 | self.W = nn.Sequential( 176 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels, 177 | kernel_size=1, stride=1, padding=0, bias=True), 178 | nn.BatchNorm2d(in_channels) 179 | ) 180 | self.W.apply(weights_init_kaiming) 181 | 182 | def _make_stage(self, in_channels, c1, c2, out_channels, size, use_head_bn, vis): 183 | return MS_NONLocalBlock2D(in_channels=in_channels, c1=c1, c2=c2, out_channels=out_channels, mode='dot_product', use_g=True, use_w=True, scale=size, bn_layer=use_head_bn, vis=vis) 184 | 185 | def forward(self, feats, parsing=None, path=None): 186 | if parsing is not None: 187 | batch_size, channel, h, w = feats.size(0), feats.size(1), feats.size(2), feats.size(3) 188 | x = F.interpolate(input=feats, size=((h,w)), mode='bilinear',align_corners=True) 189 | value = x.view(batch_size, channel, -1) 190 | value = value.permute(0, 2, 1) 191 | label = F.interpolate(input=parsing.unsqueeze(1).type(torch.cuda.FloatTensor), size=((h,w)), mode='nearest') 192 | label_row_vec = label.view(batch_size, 1, -1).expand(batch_size, h*w, h*w) 193 | label_col_vec = label_row_vec.permute(0, 2, 1) 194 | pair_label = label_col_vec.eq(label_row_vec) 195 | sim_map = F.normalize(pair_label.type(torch.cuda.FloatTensor), p=1, dim=2) 196 | context = torch.matmul(sim_map, value) 197 | context = context.permute(0, 2, 1).contiguous() 198 | context = context.view(batch_size, channel, *x.size()[2:]) 199 | context = F.interpolate(input=context, size=((h,w)), mode='bilinear',align_corners=True) 200 | parsing = self.W(context) 201 | 202 | if self.norm > 0: 203 | feats_norm = self.norm * F.normalize(feats, p=2, dim=1) 204 | priors = [stage(feats_norm) for stage in self.stages] 205 | else: 206 | priors = [stage(feats, path) for stage in self.stages] 207 | 208 | if self.if_gc >= 1: 209 | batch_size, c, h, w = feats.size(0), feats.size(1), feats.size(2), feats.size(3) 210 | gc = self.avgpool(feats) 211 | if self.norm > 0: 212 | gc = self.norm * F.normalize(gc, p=2, dim=1) 213 | gc = gc.repeat(1,1,h,w) 214 | 215 | if self.fusion == "concat": 216 | context = feats 217 | for i in range(len(priors)): 
218 | context = torch.cat([context, priors[i]], 1) 219 | if self.if_gc == 1: 220 | bottle = self.bottleneck_concat(torch.cat([context, 0.5*gc], 1)) 221 | else: 222 | bottle = self.bottleneck_concat(context) # torch.cat([context, parsing], 1) 223 | return bottle 224 | elif self.fusion == 'add': 225 | context = priors[0] 226 | for i in range(1, len(priors)): 227 | context += priors[i] 228 | bottle = self.bottleneck_add(context + feats) 229 | elif self.fusion == '+': 230 | context = [priors[0]] 231 | for i in range(1, len(priors)): 232 | context += [priors[i]] 233 | if self.if_gc == 1: 234 | if parsing is not None: 235 | bottle = torch.cat(context, 1) + parsing + feats + gc 236 | else: 237 | bottle = torch.cat(context, 1) + feats + gc 238 | else: 239 | if parsing is not None: 240 | bottle = torch.cat(context, 1) + parsing + feats 241 | else: 242 | bottle = torch.cat(context, 1) + feats 243 | 244 | return bottle -------------------------------------------------------------------------------- /blocks/ms_part_guided_block.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Human part module 4 | # Edited by Jianyuan Guo 5 | # jyguo@pku.edu.cn 6 | # 2019.10 7 | 8 | import torch 9 | import os 10 | import sys 11 | import pdb 12 | from torch import nn 13 | from torch.nn import functional as F 14 | import functools 15 | from torch.nn import init 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | from torchvision import transforms 19 | import random 20 | import numpy as np 21 | import matplotlib.cm as cm 22 | from scipy.misc import imresize 23 | 24 | 25 | def weights_init_kaiming(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Conv') != -1: 28 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 29 | elif classname.find('Linear') != -1: 30 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 31 | init.constant(m.bias.data, 0.0) 32 | elif classname.find('BatchNorm1d') != -1: 33 | init.constant(m.weight.data, 0.0) 34 | init.constant(m.bias.data, 0.0) 35 | elif classname.find('BatchNorm2d') != -1: 36 | init.constant(m.weight.data, 0.0) 37 | init.constant(m.bias.data, 0.0) 38 | 39 | 40 | class MS_PartGuidedBlock(nn.Module): 41 | ''' 42 | Compute the affinity based feature maps and segmentation map of two scales. 43 | Thus to get the context information of the specified scale. 
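In short, positions that share the same human-part label in the (resized) parsing map attend to one another with uniform weight: the binary same-label affinity is L1-normalized per row and used to average the value features, while background positions fall back to attending over all positions (the "background use global" handling in forward).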
44 | Input: 45 | N X C X H X W 46 | Parameters: 47 | in_channels : the dimension of the input feature map 48 | value_channels : the dimension of W_g 49 | bn_layer : whether use BN within W_rho 50 | scale : choose the scale to downsample the input feature maps 51 | Return: 52 | N X C X H X W 53 | position-aware context features.(w/o concate or add with the input) 54 | ''' 55 | def __init__(self, in_channels, value_channels, out_channels=None, mode='embedded_gaussian', 56 | bn_layer=True, scale=1, vis=False, choice=1): 57 | super(MS_PartGuidedBlock, self).__init__() 58 | 59 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation'] 60 | self.vis = vis 61 | self.choice = choice 62 | self.mode = mode 63 | self.scale = scale 64 | self.in_channels = in_channels 65 | self.out_channels = out_channels 66 | self.value_channels = value_channels 67 | if out_channels == None: 68 | self.out_channels = in_channels 69 | 70 | self.pool = nn.AvgPool2d(kernel_size=(scale, scale)) 71 | 72 | def forward(self, x, label, path=None): 73 | batch_size, h0, w0 = x.size(0), x.size(2), x.size(3) 74 | if self.scale > 1: 75 | x = self.pool(x) 76 | 77 | h,w = h0,w0 78 | if self.choice == 1: 79 | x = F.interpolate(input=x, size=((h,w)), mode='bilinear',align_corners=True) 80 | value = x.view(batch_size, self.in_channels, -1) 81 | elif self.choice == 2: 82 | x = self.value_2(x) 83 | x = F.interpolate(input=x, size=((h,w)), mode='bilinear',align_corners=True) 84 | value = x.view(batch_size, self.value_channels, -1) 85 | elif self.choice == 3: 86 | x = self.value_3(x) 87 | x = F.interpolate(input=x, size=((h,w)), mode='bilinear',align_corners=True) 88 | value = x.view(batch_size, self.value_channels, -1) 89 | 90 | value = value.permute(0, 2, 1) 91 | label = F.interpolate(input=label.unsqueeze(1).type(torch.cuda.FloatTensor), size=((h,w)), mode='nearest') 92 | 93 | label_row_vec = label.view(batch_size, 1, -1).expand(batch_size, h*w, h*w) 94 | label_col_vec = label_row_vec.permute(0, 2, 1) 95 | pair_label = label_col_vec.eq(label_row_vec) 96 | 97 | # background use global, commented by Huang Lang 98 | label_col_vec = 1-label_col_vec 99 | label_col_vec[label_col_vec<0]=0 100 | pair_label = pair_label.long() + label_col_vec.long() 101 | pair_label[pair_label>0]=1 102 | 103 | sim_map = F.normalize(pair_label.type(torch.cuda.FloatTensor), p=1, dim=2) 104 | 105 | context = torch.matmul(sim_map, value) 106 | context = context.permute(0, 2, 1).contiguous() 107 | if self.choice == 1: 108 | context = context.view(batch_size, self.in_channels, *x.size()[2:]) 109 | context = F.interpolate(input=context, size=((h0,w0)), mode='bilinear',align_corners=True) 110 | elif self.choice == 2: 111 | context = context.view(batch_size, self.value_channels, *x.size()[2:]) 112 | context = F.interpolate(input=context, size=((h0,w0)), mode='bilinear',align_corners=True) 113 | context = self.W(context) 114 | elif self.choice == 3: 115 | context = context.view(batch_size, self.value_channels, *x.size()[2:]) 116 | context = F.interpolate(input=context, size=((h0,w0)), mode='bilinear',align_corners=True) 117 | context = self.W(context) 118 | 119 | return context 120 | 121 | 122 | class MSPartGuidedModule(nn.Module): 123 | """ 124 | Parameters: 125 | in_features / out_features: the channels of the input / output feature maps. 
126 | dropout: specify the dropout ratio 127 | fusion: We provide two different fusion method, "concat" or "add" 128 | sizes: compute the attention based on diverse scales based context 129 | Return: 130 | features after "concat" or "add" 131 | """ 132 | def __init__(self, in_channels, out_channels, value_channels, fusion="concat", sizes=(1,4,8,16), vis=False, choice=1): 133 | super(MSPartGuidedModule, self).__init__() 134 | 135 | self.fusion = fusion 136 | self.stages = [] 137 | self.group = len(sizes) 138 | self.stages = nn.ModuleList([MS_PartGuidedBlock(in_channels, value_channels, in_channels//self.group, scale=size, vis=vis, choice=choice) for size in sizes]) 139 | self.bottleneck_add = nn.Sequential( 140 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False), 141 | nn.BatchNorm2d(out_channels), 142 | # nn.LeakyReLU(0.5) 143 | ) 144 | self.bottleneck_concat = nn.Sequential( 145 | nn.Conv2d(in_channels*2, out_channels, kernel_size=1, stride=1, padding=0, bias=True), 146 | nn.BatchNorm2d(out_channels), 147 | #nn.ReLU() 148 | ) 149 | self.bottleneck_concat.apply(weights_init_kaiming) 150 | self.bottleneck_add.apply(weights_init_kaiming) 151 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 152 | 153 | def forward(self, feats, label, path=None): 154 | priors = [stage(feats, label, path) for stage in self.stages] 155 | 156 | if self.fusion == "concat": 157 | context = feats 158 | for i in range(len(priors)): 159 | context = torch.cat([context, priors[i]], 1) 160 | bottle = self.bottleneck_concat(context) 161 | elif self.fusion == 'add': 162 | context = priors[0] 163 | for i in range(1, len(priors)): 164 | context += priors[i] 165 | bottle = self.bottleneck_add(context + feats) 166 | elif self.fusion == '+': 167 | context = [priors[0]] 168 | for i in range(1, len(priors)): 169 | context += [priors[i]] 170 | bottle = self.bottleneck_add(torch.cat(context, 1)) + feats 171 | 172 | return bottle 173 | 174 | 175 | if __name__=='__main__': 176 | pass -------------------------------------------------------------------------------- /duke.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # DukeMTMC-reID dataset loader 4 | # Edited by Jianyuan Guo 5 | # jyguo@pku.edu.cn 6 | # laynehuang@pku.edu.cn 7 | # 2019.10 8 | 9 | import sys 10 | import numpy as np 11 | import random 12 | import torch 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms 15 | from PIL import Image 16 | from scipy.misc import imresize 17 | import pdb 18 | 19 | 20 | class DukeDataset(Dataset): 21 | def __init__(self, image_root='/home/guojianyuan/Desktop/ReId/data/DukeMTMC-reID/', txt_root=None, mode='train', transform=None): 22 | self.mode = mode 23 | self.transform = transform 24 | self.image_root = image_root 25 | if self.mode == 'train': 26 | self.name_list = np.genfromtxt(self.image_root + 'train_list.txt', dtype=str, delimiter=' ', usecols=[0]) 27 | self.label_list = np.genfromtxt(self.image_root + 'train_list.txt', dtype=int, delimiter=' ', usecols=[1]) 28 | 29 | def __getitem__(self, index): 30 | img = Image.open(self.image_root + self.name_list[index]) 31 | img = self.transform(img) 32 | label = self.label_list[index] 33 | return img, label 34 | 35 | def __len__(self): 36 | return len(self.name_list) 37 | 38 | 39 | class DukePartDataset(Dataset): 40 | def __init__(self, image_root='/home/guojianyuan/Desktop/ReId/data/DukeMTMC-reID/', parsing_root="/home/huanglang/research/CE2P/duke/5classes/", 
mode='train', transform=None): 41 | self.mode = mode 42 | self.transform = transform 43 | self.image_root = image_root 44 | self.parsing_root = parsing_root 45 | supported_modes = ('train', 'query', 'gallery') 46 | assert self.mode in supported_modes, print("Only support mode from {}".format(supported_modes)) 47 | self.name_list = np.genfromtxt(image_root + self.mode + '_list.txt', dtype=str, delimiter=' ', usecols=[0]) 48 | self.label_list = np.genfromtxt(image_root + self.mode + '_list.txt', dtype=int, delimiter=' ', usecols=[1]) 49 | 50 | def __getitem__(self, index): 51 | img = Image.open(self.image_root + self.name_list[index]) 52 | part_map = Image.open(self.parsing_root + self.name_list[index][:-3] + "png") 53 | 54 | if self.mode == 'train' and random.random() < 0.5: 55 | img = transforms.functional.hflip(img) 56 | part_map = transforms.functional.hflip(part_map) 57 | 58 | transforms_tensor = transforms.Compose([transforms.ToTensor()]) 59 | img_tensor = transforms_tensor(img) 60 | img = self.transform(img) 61 | part_map = imresize(part_map, (96, 32), interp="nearest") 62 | part_map = torch.from_numpy(np.asarray(part_map, dtype=np.float)) 63 | label = self.label_list[index] 64 | return img, label, part_map, part_map # img_tensor # For other purpose 65 | 66 | def __len__(self): 67 | return len(self.name_list) 68 | 69 | 70 | if __name__ == '__main__': 71 | pass 72 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torch.autograd import Variable 5 | import pdb 6 | import ResNet 7 | 8 | from blocks import MSPyramidAttentionContextModule 9 | 10 | 11 | def weights_init_kaiming(m): 12 | classname = m.__class__.__name__ 13 | # print(classname) 14 | if classname.find('Conv') != -1: 15 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 16 | elif classname.find('Linear') != -1: 17 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 18 | init.constant_(m.bias.data, 0.0) 19 | elif classname.find('BatchNorm1d') != -1: 20 | init.normal_(m.weight.data, 1.0, 0.02) 21 | init.constant_(m.bias.data, 0.0) 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.normal_(m.weight.data, 1.0, 0.02) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | def weights_init_classifier(m): 27 | classname = m.__class__.__name__ 28 | if classname.find('Linear') != -1: 29 | init.normal_(m.weight.data, std=0.001) 30 | init.constant_(m.bias.data, 0.0) 31 | elif classname.find('BatchNorm1d') != -1: 32 | init.normal_(m.weight.data, 1.0, 0.02) 33 | init.constant_(m.bias.data, 0.0) 34 | elif classname.find('BatchNorm2d') != -1: 35 | init.normal_(m.weight.data, 1.0, 0.02) 36 | init.constant_(m.bias.data, 0.0) 37 | 38 | 39 | class ClassBlock(nn.Module): 40 | def __init__(self, input_dim, class_num, dropout=True, relu=True, num_bottleneck=512): 41 | super(ClassBlock, self).__init__() 42 | add_block_1 = [] 43 | add_block_1 += [nn.Linear(input_dim, num_bottleneck)] 44 | add_block_1 += [nn.BatchNorm1d(num_bottleneck)] 45 | if relu: 46 | add_block_1 += [nn.LeakyReLU(0.5)] 47 | #add_block_1 += [nn.SELU()] 48 | if dropout: 49 | add_block_1 += [nn.Dropout(p=0.3)] 50 | add_block_1 = nn.Sequential(*add_block_1) 51 | add_block_1.apply(weights_init_kaiming) 52 | 53 | classifier_1 = [] 54 | classifier_1 += [nn.Linear(num_bottleneck, class_num)] 55 | classifier_1 = nn.Sequential(*classifier_1) 56 | 
classifier_1.apply(weights_init_classifier) 57 | 58 | self.add_block_1 = add_block_1 59 | self.classifier_1 = classifier_1 60 | 61 | def forward(self, x): 62 | x = self.add_block_1(x) 63 | x = self.classifier_1(x) 64 | 65 | return x 66 | 67 | 68 | class P2Net(nn.Module): 69 | def __init__(self, class_num=702, fusion='+', layer='50', block_num=0): 70 | super().__init__() 71 | if layer == '50': 72 | backbone = ResNet.resnet50(pretrained =True) 73 | elif layer == '152': 74 | backbone = ResNet.resnet152(pretrained=True) 75 | 76 | # avg pooling to global pooling 77 | backbone.avgpool = nn.AdaptiveAvgPool2d((1,1)) 78 | 79 | self.model = backbone 80 | self.block_num = block_num 81 | self.classifier = ClassBlock(2048, class_num, dropout=False, relu=True, num_bottleneck=256) 82 | 83 | self.context_l2_1 = MSPyramidAttentionContextModule(in_channels=1024, out_channels=1024, c1=512, c2=512, 84 | dropout=0, fusion=fusion, sizes=([1]), if_gc=0) if block_num > 0 else nn.Sequential() 85 | ''' 86 | self.context_l2_2 = MSPyramidAttentionContextModule(in_channels=2048, out_channels=2048, c1=1024, c2=1024, 87 | dropout=0, fusion='+', sizes=([1]), if_gc=0) 88 | self.context_l3_1 = MSPyramidAttentionContextModule(in_channels=512, out_channels=512, c1=256, c2=256, 89 | dropout=0, fusion='+', sizes=([1]), if_gc=0) 90 | self.context_l3_2 = MSPyramidAttentionContextModule(in_channels=512, out_channels=512, c1=256, c2=256, 91 | dropout=0, fusion='+', sizes=([1]), if_gc=0) 92 | self.context_l3_3 = MSPyramidAttentionContextModule(in_channels=1024, out_channels=1024, c1=512, c2=512, 93 | dropout=0, fusion='+', sizes=([1]), if_gc=0) 94 | ''' 95 | self.context_l2_2 = nn.Sequential() 96 | self.context_l3_1 = nn.Sequential() 97 | self.context_l3_2 = nn.Sequential() 98 | self.context_l3_3 = nn.Sequential() 99 | 100 | def forward(self, x, part_map=None, path=None): 101 | if x.dim() == 3: 102 | x = x.unsqueeze(0) 103 | x = self.model.conv1(x) 104 | x = self.model.bn1(x) 105 | x = self.model.relu(x) 106 | x = self.model.maxpool(x) 107 | x = self.model.layer1(x) 108 | 109 | x = self.model.layer2(x) 110 | 111 | x = self.model.layer3(x) 112 | if self.block_num > 0: 113 | x = self.context_l2_1(x, part_map) 114 | 115 | x = self.model.layer4(x) 116 | x = self.model.avgpool(x) 117 | x = torch.squeeze(x) 118 | if x.dim() == 1: 119 | x = x.unsqueeze(0) 120 | 121 | feature = self.classifier.add_block_1(x) 122 | category = self.classifier.classifier_1(feature) 123 | 124 | return category, feature, x, x 125 | 126 | 127 | if __name__ == '__main__': 128 | net = P2Net(751) 129 | input = Variable(torch.FloatTensor(8,3,256,128).cuda()) 130 | net=net.cuda() 131 | output = net(input) 132 | -------------------------------------------------------------------------------- /scripts/res50_dpb_softmax.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../ 4 | 5 | GPUID=6 6 | NAME="res50_dpb_softmax_test" 7 | BATCHSIZE=48 8 | FILE_NAME="results" 9 | BLOCK_NUM=1 10 | FUSION='+' 11 | PMAP=1 12 | 13 | 14 | python -u train_duke.py --gpu_ids ${GPUID} --name ${NAME} --batchsize ${BATCHSIZE} --file_name ${FILE_NAME} --block_num ${BLOCK_NUM} --fusion ${FUSION} --pmap ${PMAP} -------------------------------------------------------------------------------- /scripts/res50_latent_softmax.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../ 4 | 5 | GPUID=6 6 | NAME="res50_latent_softmax" 7 | BATCHSIZE=48 8 | FILE_NAME="results" 
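# BLOCK_NUM=1 inserts one latent-part self-attention block after layer3 of the ResNet backbone (see --block_num in train_duke.py and P2Net in model.py); FUSION='+' selects the element-wise fusion branch.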
9 | BLOCK_NUM=1 10 | FUSION='+' 11 | 12 | 13 | python -u train_duke.py --gpu_ids ${GPUID} --name ${NAME} --batchsize ${BATCHSIZE} --file_name ${FILE_NAME} --block_num ${BLOCK_NUM} --fusion ${FUSION} -------------------------------------------------------------------------------- /scripts/resnet50_softmax.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../ 4 | 5 | GPUID=5 6 | NAME="res50_softmax" 7 | BATCHSIZE=48 8 | FILE_NAME="results" 9 | 10 | 11 | python -u train_duke.py --gpu_ids ${GPUID} --name ${NAME} --batchsize ${BATCHSIZE} --file_name ${FILE_NAME} -------------------------------------------------------------------------------- /train_duke.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import argparse 5 | import json 6 | import os 7 | import pdb 8 | import sys 9 | import scipy.io 10 | import time 11 | from PIL import Image 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.backends.cudnn as cudnn 17 | import torchvision 18 | from torch.autograd import Variable 19 | from torchvision import datasets, models, transforms 20 | 21 | from utils import RandomIdentitySampler, logging, RandomErasing 22 | from utils.test_utils import * 23 | from model import P2Net 24 | from duke import * 25 | from blocks import TripletLoss 26 | 27 | 28 | parser = argparse.ArgumentParser(description='Training') 29 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1') 30 | parser.add_argument('--name',default='ResNet50', type=str, help='output model name') 31 | parser.add_argument('--batchsize', default=48, type=int, help='batchsize') 32 | parser.add_argument('--block_num', default=0, type=int, help='dual part block number') 33 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='training weight decay') 34 | parser.add_argument('--PCB', action='store_true', help='if use PCB+ResNet50' ) 35 | parser.add_argument('--era', default=0, type=float, help='Random Erasing probability, in [0,1]') 36 | parser.add_argument('--fusion', default='concat', type=str, help='self-attention-choice') 37 | parser.add_argument('--loss', default='softmax', type=str, help='choice of loss') 38 | parser.add_argument('--layer', default='50', type=str, help='Resnet-layer') 39 | parser.add_argument('--num_instances', default=8, type=int, help='for triplet loss') 40 | parser.add_argument('--epoch', default=60, type=int, help='training epoch') 41 | parser.add_argument('--margin', default=4, type=float, help='triplet loss margin') 42 | parser.add_argument('--file_name', default='result', type=str, help='file name to save') 43 | parser.add_argument('--pmap', default=False, help='use part_map') 44 | parser.add_argument('--mat', default='', type=str, help='name for saving representation' ) 45 | opt = parser.parse_args() 46 | 47 | sys.stdout = logging.Logger(os.path.join('/home/guojianyuan/ReID_Duke/'+opt.file_name+'/'+opt.name+'/', 'log.txt')) 48 | tripletloss = TripletLoss(opt.margin) 49 | 50 | gpu_ids = [] 51 | str_gpu_ids = opt.gpu_ids.split(',') 52 | for str_id in str_gpu_ids: 53 | gpu_ids.append(int(str_id)) 54 | torch.cuda.set_device(gpu_ids[0]) 55 | 56 | # Load Data 57 | if opt.pmap: 58 | transform_train_list = [ 59 | transforms.Resize((384,128), interpolation=3), 60 | transforms.ToTensor(), 61 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 
0.225]) 62 | ] 63 | else: 64 | transform_train_list = [ 65 | transforms.Resize((384,128), interpolation=3), 66 | transforms.RandomHorizontalFlip(), 67 | transforms.ToTensor(), 68 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 69 | ] 70 | 71 | if opt.era > 0: 72 | transform_train_list = transform_train_list + [RandomErasing(probability = opt.era, mean=[0.0, 0.0, 0.0])] 73 | 74 | data_transforms = { 75 | 'train': transforms.Compose(transform_train_list) 76 | } 77 | 78 | if 'tri' not in opt.loss: 79 | if opt.pmap: 80 | cls_datasets = DukePartDataset(transform=data_transforms['train']) 81 | else: 82 | cls_datasets = DukeDataset(transform=data_transforms['train']) 83 | cls_loader = torch.utils.data.DataLoader( 84 | cls_datasets, 85 | batch_size=opt.batchsize, 86 | shuffle=True, 87 | num_workers=10, 88 | drop_last=True) 89 | dataset_sizes_allSample = len(cls_loader) 90 | else: 91 | triplet_datasets = DukeDataset(transform=data_transforms['train']) 92 | triplet_loader = torch.utils.data.DataLoader( 93 | triplet_datasets, 94 | sampler=RandomIdentitySampler(triplet_datasets, opt.num_instances), 95 | batch_size=opt.batchsize, num_workers=10, 96 | drop_last=True) 97 | dataset_sizes_metricSample = len(triplet_loader) 98 | 99 | use_gpu = torch.cuda.is_available() 100 | 101 | 102 | def test(model): 103 | model = model.eval() 104 | print('-' * 10) 105 | print('test model now...') 106 | data_transforms = transforms.Compose([ 107 | transforms.Resize((384,128), interpolation=3), 108 | transforms.ToTensor(), 109 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 110 | ]) 111 | 112 | data_dir = '/home/guojianyuan/Desktop/ReId/data/Duke_pytorch_eva' 113 | if opt.pmap: 114 | image_datasets = {x: DukePartDataset(mode=x, transform=data_transforms) for x in ['gallery','query']} 115 | else: 116 | image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), data_transforms) for x in ['gallery','query']} 117 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8, 118 | shuffle=False, num_workers=10) for x in ['gallery','query']} 119 | 120 | if opt.pmap: 121 | gallery_path = image_datasets['gallery'].name_list 122 | query_path = image_datasets['query'].name_list 123 | gallery_cam,gallery_label = get_id_with_part_map(gallery_path) 124 | query_cam,query_label = get_id_with_part_map(query_path) 125 | else: 126 | gallery_path = image_datasets['gallery'].imgs 127 | query_path = image_datasets['query'].imgs 128 | gallery_cam,gallery_label = get_id(gallery_path) 129 | query_cam,query_label = get_id(query_path) 130 | 131 | # Extract feature 132 | if opt.pmap: 133 | query_feature, query_feature_embed = extract_feature_with_part_map(model,dataloaders['query']) 134 | gallery_feature, gallery_feature_embed = extract_feature_with_part_map(model,dataloaders['gallery']) 135 | else: 136 | gallery_feature, gallery_feature_embed = extract_feature(model,dataloaders['gallery']) 137 | query_feature, query_feature_embed = extract_feature(model,dataloaders['query']) 138 | 139 | # Save to Matlab for check 140 | result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} 141 | scipy.io.savemat('./'+opt.file_name+'/'+opt.name+'/'+opt.mat+'.mat',result) 142 | result = scipy.io.loadmat('./'+opt.file_name+'/'+opt.name+'/'+opt.mat+'.mat') 143 | 144 | query_feature = result['query_f'] 145 | query_cam = result['query_cam'][0] 146 | query_label =
result['query_label'][0] 147 | gallery_feature = result['gallery_f'] 148 | gallery_cam = result['gallery_cam'][0] 149 | gallery_label = result['gallery_label'][0] 150 | 151 | CMC = torch.IntTensor(len(gallery_label)).zero_() 152 | ap = 0.0 153 | for i in range(len(query_label)): 154 | ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam) 155 | if CMC_tmp[0]==-1: 156 | continue 157 | CMC = CMC + CMC_tmp 158 | ap += ap_tmp 159 | 160 | CMC = CMC.float() 161 | CMC = CMC/len(query_label) #average CMC 162 | print('Pool5 top1:%f top5:%f top10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label))) 163 | 164 | result = {'gallery_f':gallery_feature_embed.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature_embed.numpy(),'query_label':query_label,'query_cam':query_cam} 165 | scipy.io.savemat('./'+opt.file_name+'/'+opt.name+'/'+opt.mat+'.mat',result) 166 | result = scipy.io.loadmat('./'+opt.file_name+'/'+opt.name+'/'+opt.mat+'.mat') 167 | 168 | query_feature = result['query_f'] 169 | query_cam = result['query_cam'][0] 170 | query_label = result['query_label'][0] 171 | gallery_feature = result['gallery_f'] 172 | gallery_cam = result['gallery_cam'][0] 173 | gallery_label = result['gallery_label'][0] 174 | 175 | CMC = torch.IntTensor(len(gallery_label)).zero_() 176 | ap = 0.0 177 | 178 | for i in range(len(query_label)): 179 | ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam) 180 | if CMC_tmp[0]==-1: 181 | continue 182 | CMC = CMC + CMC_tmp 183 | ap += ap_tmp 184 | 185 | CMC = CMC.float() 186 | CMC = CMC/len(query_label) #average CMC 187 | print('Embed top1:%f top5:%f top10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label))) 188 | 189 | 190 | def train_model(model, optimizer, scheduler, num_epochs=25): 191 | start_time = time.time() 192 | 193 | for epoch in range(num_epochs): 194 | if epoch == 0: 195 | save_network(model, epoch) 196 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 197 | print('-' * 10) 198 | 199 | # Each epoch has a training and validation phase 200 | for phase in ['train']: 201 | if phase == 'train': 202 | if 'tri' in opt.loss: 203 | adjust_lr_triplet(optimizer, epoch) 204 | else: 205 | adjust_lr_softmax(optimizer, epoch) 206 | model.train(True) # Set model to training mode 207 | else: 208 | model.train(False) # Set model to evaluate mode 209 | 210 | running_loss = 0.0 211 | running_correct_category = 0 212 | 213 | # Iterate over data. 
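# The network returns [category, feature, pool5, embed]: `category` holds the class logits used by the softmax / label-smooth losses, `feature` is the 256-d bottleneck embedding fed to the triplet loss, and `pool5` is the 2048-d global-pooled backbone descriptor (see P2Net.forward in model.py).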
214 | if 'tri' in opt.loss: 215 | dataloaders = triplet_loader 216 | else: 217 | dataloaders = cls_loader 218 | for data in dataloaders: 219 | if opt.pmap: 220 | inputs, labels, part_map, inputs_ori = data 221 | else: 222 | inputs, labels = data 223 | if use_gpu: 224 | inputs = Variable(inputs.cuda()) 225 | labels = Variable(labels.cuda()) 226 | if opt.pmap: 227 | part_map = Variable(part_map.cuda()) 228 | inputs_ori = Variable(inputs_ori.cuda()) 229 | else: 230 | inputs, labels = Variable(inputs), Variable(labels) 231 | 232 | # zero the parameter gradients 233 | optimizer.zero_grad() 234 | 235 | # forward 236 | if opt.pmap: 237 | [category, feature, pool5, embed] = model(inputs, part_map=part_map, path=inputs_ori) 238 | else: 239 | [category, feature, pool5,_] = model(inputs) 240 | 241 | if not opt.PCB: 242 | _,category_preds = torch.max(category.data, 1) 243 | 244 | if opt.loss == 'softmax': 245 | loss = criterion_softmax(category, labels) 246 | elif opt.loss == 'labelsmooth': 247 | loss = criterion_labelsmooth(category, labels) 248 | elif opt.loss == 'triplet': 249 | loss,_,_ = criterion_triplet(feature, labels) 250 | elif opt.loss == 'softmax+triplet': 251 | loss_softmax = criterion_softmax(category, labels) 252 | loss_triplet,_,_ = criterion_triplet(feature, labels) 253 | loss = loss_softmax + loss_triplet 254 | elif opt.loss == 'labelsmooth+triplet': 255 | loss_softmax = criterion_labelsmooth(category, labels) 256 | loss_triplet,_,_ = criterion_triplet(feature, labels) 257 | loss = loss_softmax + loss_triplet 258 | 259 | else: 260 | part = {} 261 | sm = nn.Softmax(dim=1) 262 | num_part = 6 263 | for i in range(num_part): 264 | part[i] = outputs[i] 265 | score = sm(part[0]) + sm(part[1]) +sm(part[2]) + sm(part[3]) +sm(part[4]) +sm(part[5]) 266 | _, preds = torch.max(score.data, 1) 267 | loss = criterion(part[0], labels) 268 | for i in range(num_part-1): 269 | loss += criterion(part[i+1], labels) 270 | 271 | # backward + optimize only if in training phase 272 | if phase == 'train': 273 | loss.backward() 274 | optimizer.step() 275 | 276 | # statistics 277 | running_loss += loss.item() 278 | running_correct_category += torch.sum(category_preds == labels.data) 279 | 280 | if 'tri' not in opt.loss: 281 | epoch_loss = running_loss / dataset_sizes_allSample / opt.batchsize 282 | epoch_acc = running_correct_category.cpu().numpy() / dataset_sizes_allSample / opt.batchsize 283 | else: 284 | epoch_loss = running_loss / dataset_sizes_metricSample / opt.batchsize 285 | epoch_acc = running_correct_category.cpu().numpy() / dataset_sizes_metricSample / opt.batchsize 286 | 287 | print('{} Loss: {:.4f} Acc_category: {:.4f}'.format( 288 | phase, epoch_loss, epoch_acc)) 289 | 290 | if 'tri' not in opt.loss: 291 | if epoch == 59 or epoch == 0: 292 | save_network(model, epoch) 293 | test(model) 294 | else: 295 | if epoch == 299 or epoch == 249 or epoch == 0: 296 | save_network(model, epoch) 297 | test(model) 298 | 299 | time_elapsed = time.time() - start_time 300 | print('Training complete in {:.0f}m {:.0f}s'.format( 301 | time_elapsed // 60, time_elapsed % 60)) 302 | 303 | return model 304 | 305 | 306 | # Save model 307 | def save_network(network, epoch_label): 308 | save_filename = 'net_%s.pth'% epoch_label 309 | save_path = os.path.join('./' + opt.file_name, opt.name, save_filename) 310 | torch.save(network.cpu().state_dict(), save_path) 311 | if torch.cuda.is_available: 312 | network.cuda(gpu_ids[0]) 313 | 314 | 315 | def load_network(network, path): 316 | pretrained_dict = torch.load(path) 317 | 
model_dict = network.state_dict() 318 | pretrained_dict = {k: v for k,v in pretrained_dict.items() if k in model_dict} 319 | model_dict.update(pretrained_dict) 320 | network.load_state_dict(model_dict) 321 | return network 322 | 323 | # Finetuning the convnet 324 | model = P2Net(class_num=702, fusion=opt.fusion, layer=opt.layer, block_num=opt.block_num) 325 | 326 | if use_gpu: 327 | cudnn.enabled = True 328 | cudnn.benchmark = True 329 | if len(gpu_ids)>1: 330 | model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda() 331 | else: 332 | model = model.cuda() 333 | 334 | criterion_softmax = nn.CrossEntropyLoss() 335 | criterion_triplet = tripletloss 336 | 337 | # Train and evaluate 338 | dir_name = os.path.join('./' + opt.file_name, opt.name) 339 | if not os.path.isdir(dir_name): 340 | os.makedirs(dir_name) 341 | 342 | # save opts 343 | with open('%s/opts.json'%dir_name,'w') as fp: 344 | json.dump(vars(opt), fp, indent=1) 345 | 346 | if len(opt.gpu_ids)>1: 347 | ignored_params = list(map(id, model.module.model.fc.parameters() )) + list(map(id, model.module.classifier.parameters() )) +\ 348 | list(map(id, model.module.context_l3_1.parameters() )) + list(map(id, model.module.context_l3_2.parameters() )) +\ 349 | list(map(id, model.module.context_l3_3.parameters() )) + list(map(id, model.module.context_l2_1.parameters() )) +\ 350 | list(map(id, model.module.context_l2_2.parameters() )) 351 | base_params = filter(lambda p: id(p) not in ignored_params, model.module.parameters() ) 352 | optimizer = torch.optim.SGD([ 353 | {'params': base_params}, 354 | {'params': model.module.model.fc.parameters()}, 355 | {'params': model.module.classifier.parameters()}, 356 | {'params': model.module.context_l3_1.parameters()}, 357 | {'params': model.module.context_l3_2.parameters()}, 358 | {'params': model.module.context_l3_3.parameters()}, 359 | {'params': model.module.context_l2_1.parameters()}, 360 | {'params': model.module.context_l2_2.parameters()} 361 | ], lr=0.1, weight_decay=opt.weight_decay, momentum=0.9, nesterov=True) 362 | else: 363 | ignored_params = list(map(id, model.model.fc.parameters() )) + list(map(id, model.classifier.parameters() )) +\ 364 | list(map(id, model.context_l3_1.parameters() )) + list(map(id, model.context_l3_2.parameters() )) +\ 365 | list(map(id, model.context_l3_3.parameters() )) + list(map(id, model.context_l2_1.parameters() )) +\ 366 | list(map(id, model.context_l2_2.parameters() )) 367 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters() ) 368 | optimizer = torch.optim.SGD([ 369 | {'params': base_params}, 370 | {'params': model.model.fc.parameters()}, 371 | {'params': model.classifier.parameters()}, 372 | {'params': model.context_l3_1.parameters()}, 373 | {'params': model.context_l3_2.parameters()}, 374 | {'params': model.context_l3_3.parameters()}, 375 | {'params': model.context_l2_1.parameters()}, 376 | {'params': model.context_l2_2.parameters()} 377 | ], lr=0.1, weight_decay=opt.weight_decay, momentum=0.9, nesterov=True) 378 | 379 | 380 | def adjust_lr_triplet(optimizer, ep): 381 | if ep < 20: 382 | lr = 1e-2 * (ep + 1) / 2 383 | elif ep < 130: 384 | lr = 1e-1 385 | elif ep < 200: 386 | lr = 1e-2 387 | elif ep < 240: 388 | lr = 1e-3 389 | elif ep < 280: 390 | lr = 1e-3 * (ep - 240 + 1) / 40 391 | elif ep < 340: 392 | lr = 1e-3 393 | for index in range(len(optimizer.param_groups)): 394 | if index == 0: 395 | optimizer.param_groups[index]['lr'] = lr * 0.1 396 | else: 397 | optimizer.param_groups[index]['lr'] = lr 398 | 399 | 400 | def 
adjust_lr_softmax(optimizer, ep): 401 | if ep < 40: 402 | lr = 0.1 403 | elif ep < 60: 404 | lr = 0.01 405 | else: 406 | lr = 0.001 407 | for index in range(len(optimizer.param_groups)): 408 | if index == 0: 409 | optimizer.param_groups[index]['lr'] = lr * 0.1 410 | else: 411 | optimizer.param_groups[index]['lr'] = lr 412 | 413 | 414 | model = train_model(model, optimizer, None, num_epochs=opt.epoch) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .sampler import RandomIdentitySampler, RandomIdentitySamplerWithPartMap 4 | from .random_erasing import RandomErasing 5 | from .test_utils import * 6 | import torch 7 | 8 | def to_numpy(tensor): 9 | if torch.is_tensor(tensor): 10 | return tensor.cpu().numpy() 11 | elif type(tensor).__module__ != 'numpy': 12 | raise ValueError("Cannot convert {} to numpy array" 13 | .format(type(tensor))) 14 | return tensor 15 | 16 | 17 | def to_torch(ndarray): 18 | if type(ndarray).__module__ == 'numpy': 19 | return torch.from_numpy(ndarray) 20 | elif not torch.is_tensor(ndarray): 21 | raise ValueError("Cannot convert {} to torch tensor" 22 | .format(type(ndarray))) 23 | return ndarray 24 | 25 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | import sys 5 | 6 | 7 | class Logger(object): 8 | def __init__(self, fpath=None): 9 | self.console = sys.stdout 10 | self.file = None 11 | if fpath is not None: 12 | try: 13 | os.makedirs(os.path.dirname(fpath)) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | self.file = open(fpath, 'w') 18 | 19 | def __del__(self): 20 | self.close() 21 | 22 | def __enter__(self): 23 | pass 24 | 25 | def __exit__(self, *args): 26 | self.close() 27 | 28 | def write(self, msg): 29 | self.console.write(msg) 30 | if self.file is not None: 31 | self.file.write(msg) 32 | 33 | def flush(self): 34 | self.console.flush() 35 | if self.file is not None: 36 | self.file.flush() 37 | os.fsync(self.file.fileno()) 38 | 39 | def close(self): 40 | self.console.close() 41 | if self.file is not None: 42 | self.file.close() 43 | -------------------------------------------------------------------------------- /utils/random_erasing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 
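Example (how train_duke.py enables it when --era > 0; it is appended after ToTensor()/Normalize() because erasing operates on a CHW tensor): transform_train_list = transform_train_list + [RandomErasing(probability = opt.era, mean=[0.0, 0.0, 0.0])]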
21 | """ 22 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 23 | self.probability = probability 24 | self.mean = mean 25 | self.sl = sl 26 | self.sh = sh 27 | self.r1 = r1 28 | 29 | def __call__(self, img): 30 | if random.uniform(0, 1) > self.probability: 31 | return img 32 | for attempt in range(100): 33 | area = img.size()[1] * img.size()[2] 34 | target_area = random.uniform(self.sl, self.sh) * area 35 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 36 | h = int(round(math.sqrt(target_area * aspect_ratio))) 37 | w = int(round(math.sqrt(target_area / aspect_ratio))) 38 | 39 | if w < img.size()[2] and h < img.size()[1]: 40 | x1 = random.randint(0, img.size()[1] - h) 41 | y1 = random.randint(0, img.size()[2] - w) 42 | if img.size()[0] == 3: 43 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 44 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 45 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 46 | else: 47 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 48 | return img 49 | return img -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=1): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, (_, pid) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | 19 | self.pids = list(self.index_dic.keys()) 20 | self.num_samples = len(self.pids) 21 | 22 | def __len__(self): 23 | return self.num_samples * self.num_instances 24 | 25 | def __iter__(self): 26 | indices = torch.randperm(self.num_samples) 27 | ret = [] 28 | for i in indices: 29 | pid = self.pids[i] 30 | t = self.index_dic[pid] 31 | if len(t) >= self.num_instances: 32 | t = np.random.choice(t, size=self.num_instances, replace=False) 33 | else: 34 | t = np.random.choice(t, size=self.num_instances, replace=True) 35 | ret.extend(t) 36 | return iter(ret) 37 | 38 | 39 | class RandomIdentitySamplerWithPartMap(Sampler): 40 | def __init__(self, data_source, num_instances=1): 41 | self.data_source = data_source 42 | self.num_instances = num_instances 43 | self.index_dic = defaultdict(list) 44 | for index, (_, pid, _, _) in enumerate(data_source): 45 | self.index_dic[pid].append(index) 46 | 47 | self.pids = list(self.index_dic.keys()) 48 | self.num_samples = len(self.pids) 49 | 50 | def __len__(self): 51 | return self.num_samples * self.num_instances 52 | 53 | def __iter__(self): 54 | indices = torch.randperm(self.num_samples) 55 | ret = [] 56 | for i in indices: 57 | pid = self.pids[i] 58 | t = self.index_dic[pid] 59 | if len(t) >= self.num_instances: 60 | t = np.random.choice(t, size=self.num_instances, replace=False) 61 | else: 62 | t = np.random.choice(t, size=self.num_instances, replace=True) 63 | ret.extend(t) 64 | return iter(ret) 65 | -------------------------------------------------------------------------------- /utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.autograd import Variable 4 | import scipy.io 5 | import pdb 
6 | 7 | 8 | def fliplr(img): 9 | '''flip horizontal''' 10 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 11 | img_flip = img.index_select(3,inv_idx) 12 | return img_flip 13 | 14 | 15 | def extract_feature(model,dataloaders): 16 | pool5_features = torch.FloatTensor() 17 | embed_features = torch.FloatTensor() 18 | count = 0 19 | for data in dataloaders: 20 | img, label = data 21 | n, c, h, w = img.size() 22 | count += n 23 | for i in range(2): 24 | if(i==1): 25 | img = fliplr(img) 26 | input_img = Variable(img.cuda()) 27 | _, embed_feature, pool5_feature,_ = model(input_img) 28 | if(i==0): 29 | ff_pool5 = torch.FloatTensor(n,pool5_feature.size(1)).zero_() 30 | ff_embed = torch.FloatTensor(n,embed_feature.size(1)).zero_() 31 | f_pool5 = pool5_feature.data.cpu() 32 | ff_pool5 = ff_pool5 + f_pool5 33 | f_embed = embed_feature.data.cpu() 34 | ff_embed = ff_embed + f_embed 35 | fnorm_pool5 = torch.norm(ff_pool5, p=2, dim=1, keepdim=True) 36 | fnorm_embed = torch.norm(ff_embed, p=2, dim=1, keepdim=True) 37 | ff_pool5 = ff_pool5.div(fnorm_pool5.expand_as(ff_pool5)) 38 | ff_embed = ff_embed.div(fnorm_embed.expand_as(ff_embed)) 39 | pool5_features = torch.cat((pool5_features,ff_pool5), 0) 40 | embed_features = torch.cat((embed_features,ff_embed), 0) 41 | return pool5_features, embed_features 42 | 43 | 44 | def extract_feature_with_part_map(model,dataloaders): 45 | pool5_features = torch.FloatTensor() 46 | embed_features = torch.FloatTensor() 47 | count = 0 48 | for data in dataloaders: 49 | img, label, part_map, _ = data 50 | n, c, h, w = img.size() 51 | count += n 52 | for i in range(2): 53 | if(i==1): 54 | img = fliplr(img) 55 | part_map = fliplr(part_map.unsqueeze(1)).squeeze() 56 | input_img = Variable(img.cuda()) 57 | input_part_map = Variable(part_map.cuda()) 58 | _, embed_feature, pool5_feature,_ = model(input_img, input_part_map) 59 | if(i==0): 60 | ff_pool5 = torch.FloatTensor(n,pool5_feature.size(1)).zero_() 61 | ff_embed = torch.FloatTensor(n,embed_feature.size(1)).zero_() 62 | f_pool5 = pool5_feature.data.cpu() 63 | ff_pool5 = ff_pool5 + f_pool5 64 | f_embed = embed_feature.data.cpu() 65 | ff_embed = ff_embed + f_embed 66 | fnorm_pool5 = torch.norm(ff_pool5, p=2, dim=1, keepdim=True) 67 | fnorm_embed = torch.norm(ff_embed, p=2, dim=1, keepdim=True) 68 | ff_pool5 = ff_pool5.div(fnorm_pool5.expand_as(ff_pool5)) 69 | ff_embed = ff_embed.div(fnorm_embed.expand_as(ff_embed)) 70 | pool5_features = torch.cat((pool5_features,ff_pool5), 0) 71 | embed_features = torch.cat((embed_features,ff_embed), 0) 72 | return pool5_features, embed_features 73 | 74 | 75 | def extract_feature_embed(model,dataloaders): 76 | features = torch.FloatTensor() 77 | count = 0 78 | for data in dataloaders: 79 | img, label = data 80 | n, c, h, w = img.size() 81 | count += n 82 | ff = torch.FloatTensor(n,256).zero_() 83 | for i in range(2): 84 | if(i==1): 85 | img = fliplr(img) 86 | input_img = Variable(img.cuda()) 87 | _,output,_,_ = model(input_img) 88 | f = output.data.cpu() 89 | ff = ff+f 90 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 91 | ff = ff.div(fnorm.expand_as(ff)) 92 | features = torch.cat((features,ff), 0) 93 | return features 94 | 95 | 96 | def extract_feature_256(model,dataloaders): 97 | features = torch.FloatTensor() 98 | count = 0 99 | for data in dataloaders: 100 | img, label = data 101 | n, c, h, w = img.size() 102 | count += n 103 | ff = torch.FloatTensor(n,256).zero_() 104 | for i in range(2): 105 | if(i==1): 106 | img = fliplr(img) 107 | input_img = Variable(img.cuda()) 108 
| _,_,_,output = model(input_img) 109 | f = output.data.cpu() 110 | ff = ff+f 111 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 112 | ff = ff.div(fnorm.expand_as(ff)) 113 | features = torch.cat((features,ff), 0) 114 | return features 115 | 116 | 117 | def extract_feature_HPM(model,dataloaders): 118 | features = torch.FloatTensor() 119 | count = 0 120 | for data in dataloaders: 121 | img, label = data 122 | n, c, h, w = img.size() 123 | count += n 124 | ff = torch.FloatTensor(n,2840).zero_() 125 | for i in range(2): 126 | if(i==1): 127 | img = fliplr(img) 128 | input_img = Variable(img.cuda()) 129 | _,_,output = model(input_img) 130 | f = output.data.cpu() 131 | ff = ff+f 132 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) 133 | ff = ff.div(fnorm.expand_as(ff)) 134 | features = torch.cat((features,ff), 0) 135 | return features 136 | 137 | 138 | def get_id(img_path): 139 | camera_id = [] 140 | labels = [] 141 | for path, v in img_path: 142 | filename = path.split('/')[-1] 143 | label = filename[0:4] 144 | camera = filename.split('c')[1] 145 | if label[0:2]=='-1': 146 | labels.append(-1) 147 | else: 148 | labels.append(int(label)) 149 | camera_id.append(int(camera[0])) 150 | return camera_id, labels 151 | 152 | 153 | def get_id_with_part_map(img_path): 154 | # -1_c1s1_000401_03.jpg 155 | camera_id = [] 156 | labels = [] 157 | for path in img_path: 158 | filename = path.split('/')[-1] 159 | label = filename.split('_')[0] 160 | camera = filename.split('c')[1] 161 | if label[0:2]=='-1': 162 | labels.append(-1) 163 | else: 164 | labels.append(int(label)) 165 | camera_id.append(int(camera[0])) 166 | return camera_id, labels 167 | 168 | 169 | def evaluate(qf,ql,qc,gf,gl,gc): 170 | query = qf 171 | score = np.dot(gf,query) 172 | # predict index 173 | index = np.argsort(score) #from small to large 174 | index = index[::-1] 175 | #index = index[0:2000] 176 | # good index 177 | query_index = np.argwhere(gl==ql) 178 | camera_index = np.argwhere(gc==qc) 179 | 180 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 181 | junk_index1 = np.argwhere(gl==-1) 182 | junk_index2 = np.intersect1d(query_index, camera_index) 183 | junk_index = np.append(junk_index2, junk_index1) #.flatten()) 184 | 185 | CMC_tmp = compute_mAP(index, good_index, junk_index) 186 | return CMC_tmp 187 | 188 | 189 | def compute_mAP(index, good_index, junk_index): 190 | ap = 0 191 | cmc = torch.IntTensor(len(index)).zero_() 192 | if good_index.size==0: # if empty 193 | cmc[0] = -1 194 | return ap,cmc 195 | 196 | # remove junk_index 197 | mask = np.in1d(index, junk_index, invert=True) 198 | index = index[mask] 199 | 200 | # find good_index index 201 | ngood = len(good_index) 202 | mask = np.in1d(index, good_index) 203 | rows_good = np.argwhere(mask==True) 204 | rows_good = rows_good.flatten() 205 | 206 | cmc[rows_good[0]:] = 1 207 | for i in range(ngood): 208 | d_recall = 1.0/ngood 209 | precision = (i+1)*1.0/(rows_good[i]+1) 210 | if rows_good[i]!=0: 211 | old_precision = i*1.0/rows_good[i] 212 | else: 213 | old_precision=1.0 214 | ap = ap + d_recall*(old_precision + precision)/2 215 | 216 | return ap, cmc --------------------------------------------------------------------------------
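For reference, a minimal usage sketch (not part of the repo) of how `RandomIdentitySampler` from `utils/sampler.py` is wired into a `DataLoader`. `ToyReIDDataset` is a hypothetical stand-in for the Duke dataset in `duke.py`; the sampler only requires that iterating the dataset yields `(image, pid)` pairs, and it assumes the script is run from the repo root.

```python
# Sketch only: ToyReIDDataset is a placeholder for the real Duke dataset.
import torch
from torch.utils.data import Dataset, DataLoader
from utils import RandomIdentitySampler


class ToyReIDDataset(Dataset):
    def __init__(self, num_ids=8, per_id=6):
        # Fake 384x128 images (the input size used in the README) with pid labels.
        self.samples = [(torch.randn(3, 384, 128), pid)
                        for pid in range(num_ids) for _ in range(per_id)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


dataset = ToyReIDDataset()
# Each drawn identity contributes num_instances images, the layout that
# triplet-style losses expect; batch_size should be a multiple of num_instances.
sampler = RandomIdentitySampler(dataset, num_instances=4)
loader = DataLoader(dataset, batch_size=16, sampler=sampler, drop_last=True)
imgs, pids = next(iter(loader))
print(imgs.shape, pids.tolist())
```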
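Below is a sketch of how the retrieval metrics in `utils/test_utils.py` are typically aggregated over a query set. The random feature matrices and label/camera arrays are placeholders for the L2-normalised output of `extract_feature` and the IDs returned by `get_id`; only the call pattern of `evaluate` / `compute_mAP` comes from the code above, and the import assumes the repo root is on the path with the listed prerequisites installed.

```python
# Sketch only: random features stand in for extract_feature() output and
# synthetic labels/cameras stand in for get_id() output.
import numpy as np
import torch
from utils.test_utils import evaluate

num_query, num_gallery, dim = 5, 100, 256
query_feature = np.random.randn(num_query, dim).astype(np.float32)
gallery_feature = np.random.randn(num_gallery, dim).astype(np.float32)
query_label = np.random.randint(0, 10, num_query)
query_cam = np.random.randint(1, 9, num_query)
gallery_label = np.random.randint(0, 10, num_gallery)
gallery_cam = np.random.randint(1, 9, num_gallery)

CMC = torch.IntTensor(num_gallery).zero_()
ap = 0.0
for i in range(num_query):
    # evaluate() ranks the gallery by dot-product score, drops same-camera
    # and junk entries, and returns (AP, CMC curve) for this query.
    ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i],
                               gallery_feature, gallery_label, gallery_cam)
    if CMC_tmp[0] == -1:  # no valid cross-camera match for this query
        continue
    CMC = CMC + CMC_tmp
    ap += ap_tmp

CMC = CMC.float() / num_query
print('Rank-1 %.4f  Rank-5 %.4f  mAP %.4f' % (CMC[0], CMC[4], ap / num_query))
```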