├── README.md └── models ├── HPM.py ├── ResNet.py └── __init__.py /README.md: -------------------------------------------------------------------------------- 1 | # Horizontal Pyramid Matching for Person Re-identification(HPM) 2 | 3 | ### Citing HPM 4 | This repository contains the the core source codes of proposed HPM, which may help you to reproduce the performance reported in the paper. If you find this repository or the HPM approach useful in your research, please consider citing: 5 | 6 | @article{fu2018horizontal, 7 | title={Horizontal Pyramid Matching for Person Re-identification}, 8 | author={Fu, Yang and Wei, Yunchao and Zhou, Yuqian and Shi, Honghui and Huang, Gao and Wang, Xinchao and Yao, Zhiqiang and Huang, Thomas}, 9 | journal={AAAI}, 10 | year={2019} 11 | } 12 | -------------------------------------------------------------------------------- /models/HPM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torchvision import models 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from .ResNet import resnet50 8 | # from random_erasing import RandomErasing_vertical, RandomErasing_2x2 9 | import math 10 | 11 | __all__ = ['HPM'] 12 | ###################################################################### 13 | def weights_init_kaiming(m): 14 | classname = m.__class__.__name__ 15 | # print(classname) 16 | if classname.find('Conv') != -1: 17 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 18 | elif classname.find('Linear') != -1: 19 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 20 | init.constant(m.bias.data, 0.0) 21 | elif classname.find('BatchNorm1d') != -1: 22 | init.normal_(m.weight.data, 1.0, 0.02) 23 | init.constant_(m.bias.data, 0.0) 24 | 25 | def weights_init_classifier(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Linear') != -1: 28 | init.normal_(m.weight.data, std=0.001) 29 | # init.constant(m.bias.data, 0.0) 30 | 31 | def weight_init(m): 32 | if isinstance(m, nn.Conv2d): 33 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 34 | m.weight.data.normal_(0, math.sqrt(2. / n)) 35 | elif isinstance(m, nn.BatchNorm2d): 36 | m.weight.data.fill_(1) 37 | m.bias.data.zero_() 38 | elif isinstance(m, nn.Linear): 39 | m.weight.data.normal_(0, 0.001) 40 | 41 | 42 | 43 | def pcb_block(num_ftrs, num_stripes, local_conv_out_channels, num_classes, avg=False): 44 | if avg: 45 | pooling_list = nn.ModuleList([nn.AdaptiveAvgPool2d(1) for _ in range(num_stripes)]) 46 | else: 47 | pooling_list = nn.ModuleList([nn.AdaptiveMaxPool2d(1) for _ in range(num_stripes)]) 48 | conv_list = nn.ModuleList([nn.Conv2d(num_ftrs, local_conv_out_channels, 1, bias=False) for _ in range(num_stripes)]) 49 | batchnorm_list = nn.ModuleList([nn.BatchNorm2d(local_conv_out_channels) for _ in range(num_stripes)]) 50 | relu_list = nn.ModuleList([nn.ReLU(inplace=True) for _ in range(num_stripes)]) 51 | fc_list = nn.ModuleList([nn.Linear(local_conv_out_channels, num_classes, bias=False) for _ in range(num_stripes)]) 52 | for m in conv_list: 53 | weight_init(m) 54 | for m in batchnorm_list: 55 | weight_init(m) 56 | for m in fc_list: 57 | weight_init(m) 58 | return pooling_list, conv_list, batchnorm_list, relu_list, fc_list 59 | 60 | 61 | def spp_vertical(feats, pool_list, conv_list, bn_list, relu_list, fc_list, num_strides, feat_list=[], logits_list=[]): 62 | for i in range(num_strides): 63 | pcb_feat = pool_list[i](feats[:, :, i * int(feats.size(2) / num_strides): (i+1) * int(feats.size(2) / num_strides), :]) 64 | pcb_feat = conv_list[i](pcb_feat) 65 | pcb_feat = bn_list[i](pcb_feat) 66 | pcb_feat = relu_list[i](pcb_feat) 67 | pcb_feat = pcb_feat.view(pcb_feat.size(0), -1) 68 | feat_list.append(pcb_feat) 69 | logits_list.append(fc_list[i](pcb_feat)) 70 | return feat_list, logits_list 71 | 72 | def global_pcb(feats, pool, conv, bn, relu, fc, feat_list=[], logits_list=[]): 73 | global_feat = pool(feats) 74 | global_feat = conv(global_feat) 75 | global_feat = bn(global_feat) 76 | global_feat = relu(global_feat) 77 | global_feat = global_feat.view(feats.size(0), -1) 78 | feat_list.append(global_feat) 79 | logits_list.append(fc(global_feat)) 80 | return feat_list, logits_list 81 | 82 | 83 | 84 | 85 | class HPM(nn.Module): 86 | def __init__(self, num_classes, num_stripes=6, local_conv_out_channels=256, erase=0, loss={'xent'}, avg=False, **kwargs): 87 | super(HPM, self).__init__() 88 | self.erase = erase 89 | self.num_stripes = num_stripes 90 | self.loss = loss 91 | 92 | model_ft = resnet50(pretrained=True, remove_last=True, last_conv_stride=1) 93 | self.num_ftrs = list(model_ft.layer4)[-1].conv1.in_channels 94 | self.features = model_ft 95 | # PSP 96 | # self.psp_pool, self.psp_conv, self.psp_bn, self.psp_relu, self.psp_upsample, self.conv = psp_block(self.num_ftrs) 97 | 98 | # global 99 | self.global_pooling = nn.AdaptiveMaxPool2d(1) 100 | self.global_conv = nn.Conv2d(self.num_ftrs, local_conv_out_channels, 1, bias=False) 101 | self.global_bn = nn.BatchNorm2d(local_conv_out_channels) 102 | self.global_relu = nn.ReLU(inplace=True) 103 | self.global_fc = nn.Linear(local_conv_out_channels, num_classes, bias=False) 104 | 105 | weight_init(self.global_conv) 106 | weight_init(self.global_bn) 107 | weight_init(self.global_fc) 108 | 109 | 110 | # 2x 111 | self.pcb2_pool_list, self.pcb2_conv_list, self.pcb2_batchnorm_list, self.pcb2_relu_list, self.pcb2_fc_list = pcb_block(self.num_ftrs, 2, local_conv_out_channels, num_classes, avg) 112 | # 4x 113 | self.pcb4_pool_list, self.pcb4_conv_list, self.pcb4_batchnorm_list, self.pcb4_relu_list, self.pcb4_fc_list = pcb_block(self.num_ftrs, 4, local_conv_out_channels, num_classes, avg) 114 | # 8x 115 | self.pcb8_pool_list, self.pcb8_conv_list, self.pcb8_batchnorm_list, self.pcb8_relu_list, self.pcb8_fc_list = pcb_block(self.num_ftrs, 8, local_conv_out_channels, num_classes, avg) 116 | 117 | 118 | 119 | def forward(self, x): 120 | feat_list = [] 121 | logits_list = [] 122 | feats = self.features(x) # N, C, H, W 123 | assert feats.size(2) == 24 124 | assert feats.size(-1) == 8 125 | assert feats.size(2) % self.num_stripes == 0 126 | 127 | if self.erase>0: 128 | # print('Random Erasing') 129 | erasing = RandomErasing_vertical(probability=self.erase) 130 | feats = erasing(feats) 131 | 132 | feat_list, logits_list = global_pcb(feats, self.global_pooling, self.global_conv, self.global_bn, 133 | self.global_relu, self.global_fc, [], []) 134 | feat_list, logits_list = spp_vertical(feats, self.pcb2_pool_list, self.pcb2_conv_list, 135 | self.pcb2_batchnorm_list, self.pcb2_relu_list, self.pcb2_fc_list, 2, feat_list, logits_list) 136 | feat_list, logits_list = spp_vertical(feats, self.pcb4_pool_list, self.pcb4_conv_list, 137 | self.pcb4_batchnorm_list, self.pcb4_relu_list, self.pcb4_fc_list, 4, feat_list, logits_list) 138 | 139 | feat_list, logits_list = spp_vertical(feats, self.pcb8_pool_list, self.pcb8_conv_list, 140 | self.pcb8_batchnorm_list, self.pcb8_relu_list, self.pcb8_fc_list, 8, feat_list, logits_list) 141 | 142 | if not self.training: 143 | return torch.cat(feat_list, dim=1) 144 | 145 | if self.loss == {'xent'}: 146 | return logits_list 147 | elif self.loss == {'xent', 'htri'}: 148 | return logits_list, feat_list 149 | elif self.loss == {'cent'}: 150 | return logits_list, feat_list 151 | elif self.loss == {'ring'}: 152 | return logits_list, feat_list 153 | else: 154 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, last_conv_stride=2, num_classes=1000): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride) 110 | self.avgpool = nn.AvgPool2d(7, stride=1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 116 | elif isinstance(m, nn.BatchNorm2d): 117 | nn.init.constant_(m.weight, 1) 118 | nn.init.constant_(m.bias, 0) 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | 155 | def resnet18(pretrained=False, **kwargs): 156 | """Constructs a ResNet-18 model. 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 162 | if pretrained: 163 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 164 | return model 165 | 166 | 167 | def resnet34(pretrained=False, **kwargs): 168 | """Constructs a ResNet-34 model. 169 | 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | """ 173 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 174 | if pretrained: 175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 176 | return model 177 | 178 | 179 | def resnet50(pretrained=False, **kwargs): 180 | """Constructs a ResNet-50 model. 181 | 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 186 | if pretrained: 187 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 188 | return model 189 | 190 | 191 | def resnet101(pretrained=False, **kwargs): 192 | """Constructs a ResNet-101 model. 193 | 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 198 | if pretrained: 199 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 200 | return model 201 | 202 | 203 | def resnet152(pretrained=False, **kwargs): 204 | """Constructs a ResNet-152 model. 205 | 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 212 | return model -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .ResNet import * 4 | from .ResNeXt import * 5 | from .SEResNet import * 6 | from .DenseNet import * 7 | from .MuDeep import * 8 | from .HACNN import * 9 | from .SqueezeNet import * 10 | from .MobileNet import * 11 | from .ShuffleNet import * 12 | from .Xception import * 13 | from .InceptionV4 import * 14 | from .NASNet import * 15 | from .DPN import * 16 | from .InceptionResNetV2 import * 17 | from .HPM import * 18 | from .ResNetM import * 19 | from .IDE import * 20 | from .ResNetVideo import * 21 | 22 | __factory = { 23 | # 'resnet50': resnet50, 24 | # 'resnet101': resnet101, 25 | 'resnet50': ResNet50, 26 | 'resnet101': ResNet101, 27 | 'seresnet50': SEResNet50, 28 | 'seresnet101': SEResNet101, 29 | 'seresnext50': SEResNeXt50, 30 | 'seresnext101': SEResNeXt101, 31 | 'resnext101': ResNeXt101_32x4d, 32 | 'resnet50m': ResNet50M, 33 | 'densenet121': DenseNet121, 34 | 'squeezenet': SqueezeNet, 35 | 'mobilenet': MobileNetV2, 36 | 'shufflenet': ShuffleNet, 37 | 'xception': Xception, 38 | 'inceptionv4': InceptionV4ReID, 39 | 'nasnet': NASNetAMobile, 40 | 'dpn92': DPN, 41 | 'inceptionresnetv2': InceptionResNetV2, 42 | 'mudeep': MuDeep, 43 | 'hacnn': HACNN, 44 | 'hpm': HPM, 45 | 'resnet50tp': ResNet50TP, 46 | 'resnet50ta': ResNet50TA, 47 | 'resnet50rnn': ResNet50RNN, 48 | 'resnet50ide': ResNet50IDE, 49 | 'resnet50mide': ResNet50MIDE, 50 | 'resnet50st':ResNet50ST, 51 | 'resnet50sta':ResNet50STA, 52 | } 53 | 54 | def get_names(): 55 | return __factory.keys() 56 | 57 | def init_model(name, *args, **kwargs): 58 | if name not in __factory.keys(): 59 | raise KeyError("Unknown model: {}".format(name)) 60 | return __factory[name](*args, **kwargs) --------------------------------------------------------------------------------