├── README.md
├── evaluate_gpu.py
├── evaluate_rerank.py
├── model.py
├── model
│   └── .gitkeep
├── prepare.py
├── random_erasing.py
├── re_ranking.py
├── test.py
├── train_new.py
└── tripletfolder.py

/README.md:
--------------------------------------------------------------------------------
## Person_reID_triplet-loss-baseline

Baseline code (with bottleneck) for person re-ID (PyTorch).

We achieved **Rank@1=86.45%, mAP=70.66%** with ResNet stride=2.
The SGD optimizer is used.

Any suggestion is welcome.

## Model Structure
You may learn more from `model.py`. We use the L2-normalized 2048-dim feature as the input of the classifier.

## Tips
- News: I added fp16 support.
- I did not optimize the code. I strongly suggest using fp16 and `with torch.no_grad()`. I will update the code later.
- A larger margin may lead to a worse local minimum. (margin = 0.1-0.3 may provide a better result.)
- A per-class sampler (stratified sampler) is not necessary.
- The Adam optimizer is not necessary.

## Prerequisites
- Python 3.6
- GPU memory >= 6G
- NumPy
- PyTorch 0.3+

**(Some reports found that updating NumPy achieves the right accuracy. If you only get 50~80 Top-1 accuracy, just try it.)**
We have successfully run the code with NumPy 1.12.1 and 1.13.1.

## Getting started
### Installation
- Install PyTorch from http://pytorch.org/
- Install Torchvision from source:
```
git clone https://github.com/pytorch/vision
cd vision
python setup.py install
```
We build Torchvision from source because PyTorch and Torchvision are ongoing projects.

Note that our code is tested with PyTorch 0.3.0/0.4.0 and Torchvision 0.2.0.

## Dataset & Preparation
Download the [Market1501 Dataset](http://www.liangzheng.org/Project/project_reid.html).

Preparation: put the images with the same ID in one folder. You may use
```bash
python prepare.py
```
Remember to change the dataset path to your own path.

Furthermore, you can also test our code on the [DukeMTMC-reID Dataset](https://github.com/layumi/DukeMTMC-reID_evaluation).
Our baseline result is not as high on DukeMTMC-reID (**Rank@1=64.23%, mAP=43.92%**); the hyperparameters need to be tuned.

To save the trained model, make a directory:
```bash
mkdir model
```

## Train
Train a model with
```bash
python train_new.py --gpu_ids 0 --name ft_ResNet50 --train_all --batchsize 32 --data_dir your_data_path
```
`--gpu_ids` which GPU to run on.

`--name` the name of the model.

`--data_dir` the path of the training data.

`--train_all` use all training images.

`--batchsize` batch size.

`--erasing_p` random erasing probability.

Train a model with random erasing:
```bash
python train_new.py --gpu_ids 0 --name ft_ResNet50 --train_all --batchsize 32 --data_dir your_data_path --erasing_p 0.5
```

## Test
Use the trained model to extract features:
```bash
python test.py --gpu_ids 0 --name ft_ResNet50 --test_dir your_data_path --which_epoch 59
```
`--gpu_ids` which GPU to run on.

`--name` the directory name of the trained model.

`--which_epoch` select the i-th model.

`--test_dir` the path of the testing data.


## Evaluation
```bash
python evaluate.py
```
It will output Rank@1, Rank@5, Rank@10 and mAP results.
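At its core, the evaluation is a ranked cosine-similarity search over the saved features. The following minimal sketch (assuming the `pytorch_result.mat` written by `test.py`, and skipping the junk/same-camera filtering that the full scripts apply) shows the idea:
```python
import scipy.io
import numpy as np

# Load the features saved by test.py (already L2-normalized).
result = scipy.io.loadmat('pytorch_result.mat')
qf, gf = result['query_f'], result['gallery_f']
ql, gl = result['query_label'][0], result['gallery_label'][0]

# Cosine similarity, then rank the gallery from best to worst per query.
scores = np.dot(qf, gf.T)
ranks = np.argsort(-scores, axis=1)

# Rough Rank@1: fraction of queries whose best gallery match shares the label.
print('Rank@1 (unfiltered): %f' % np.mean(gl[ranks[:, 0]] == ql))
```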
You may also try `evaluate_gpu.py` to conduct a faster evaluation with the GPU.

For the mAP calculation, you can also refer to the [C++ code for Oxford Building](http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/compute_ap.cpp). We use the triangle mAP calculation (consistent with the original Market1501 code).


## Related Repos
1. [Pedestrian Alignment Network](https://github.com/layumi/Pedestrian_Alignment)
2. [2stream Person re-ID](https://github.com/layumi/2016_person_re-ID)
3. [Pedestrian GAN](https://github.com/layumi/Person-reID_GAN)
4. [Language Person Search](https://github.com/layumi/Image-Text-Embedding)
--------------------------------------------------------------------------------
/evaluate_gpu.py:
--------------------------------------------------------------------------------
import scipy.io
import torch
import numpy as np
import time
import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
#######################################################################
# Evaluate

cam_metric = torch.zeros(6,6)  # camera-to-camera counts of the top-10 matches

def evaluate(qf,ql,qc,gf,gl,gc):
    query = qf.view(-1,1)
    # print(query.shape)
    score = torch.mm(gf,query)
    score = score.squeeze(1).cpu()
    score = score.numpy()
    # predict index
    index = np.argsort(score)  # from small to large
    index = index[::-1]        # from large to small (cosine similarity: larger is better)
    # index = index[0:2000]
    # good index
    query_index = np.argwhere(gl==ql)
    # same camera
    camera_index = np.argwhere(gc==qc)

    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    junk_index1 = np.argwhere(gl==-1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1)

    CMC_tmp = compute_mAP(index, qc, good_index, junk_index)
    return CMC_tmp


def compute_mAP(index, qc, good_index, junk_index):
    ap = 0
    cmc = torch.IntTensor(len(index)).zero_()
    if good_index.size==0:   # if empty
        cmc[0] = -1
        return ap,cmc

    # remove junk_index
    ranked_camera = gallery_cam[index]  # gallery_cam is a module-level global (loaded below)
    mask = np.in1d(index, junk_index, invert=True)
    #mask2 = np.in1d(index, np.append(good_index,junk_index), invert=True)
    index = index[mask]
    ranked_camera = ranked_camera[mask]
    for i in range(10):
        cam_metric[ qc-1, ranked_camera[i]-1 ] += 1

    # find good_index index
    ngood = len(good_index)
    mask = np.in1d(index, good_index)
    rows_good = np.argwhere(mask==True)
    rows_good = rows_good.flatten()

    cmc[rows_good[0]:] = 1
    for i in range(ngood):
        d_recall = 1.0/ngood
        precision = (i+1)*1.0/(rows_good[i]+1)
        if rows_good[i]!=0:
            old_precision = i*1.0/rows_good[i]
        else:
            old_precision = 1.0
        ap = ap + d_recall*(old_precision + precision)/2

    return ap, cmc

######################################################################
result = scipy.io.loadmat('pytorch_result.mat')
query_feature = torch.FloatTensor(result['query_f'])
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_feature = torch.FloatTensor(result['gallery_f'])
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]

multi = os.path.isfile('multi_query.mat')

if multi:
    m_result = scipy.io.loadmat('multi_query.mat')
    mquery_feature = torch.FloatTensor(m_result['mquery_f'])
    mquery_cam = m_result['mquery_cam'][0]
    mquery_label = m_result['mquery_label'][0]
    mquery_feature = mquery_feature.cuda()

query_feature = query_feature.cuda()
gallery_feature = gallery_feature.cuda()

print(query_feature.shape)
CMC = torch.IntTensor(len(gallery_label)).zero_()
ap = 0.0
#print(query_label)
for i in range(len(query_label)):
    ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
    if CMC_tmp[0]==-1:
        continue
    CMC = CMC + CMC_tmp
    ap += ap_tmp
    #print(i, CMC_tmp[0])

CMC = CMC.float()
CMC = CMC/len(query_label)  # average CMC
print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))

# multiple-query
CMC = torch.IntTensor(len(gallery_label)).zero_()
ap = 0.0
if multi:
    for i in range(len(query_label)):
        mquery_index1 = np.argwhere(mquery_label==query_label[i])
        mquery_index2 = np.argwhere(mquery_cam==query_cam[i])
        mquery_index = np.intersect1d(mquery_index1, mquery_index2)
        mq = torch.mean(mquery_feature[mquery_index,:], dim=0)
        ap_tmp, CMC_tmp = evaluate(mq,query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
        if CMC_tmp[0]==-1:
            continue
        CMC = CMC + CMC_tmp
        ap += ap_tmp
        #print(i, CMC_tmp[0])
    CMC = CMC.float()
    CMC = CMC/len(query_label)  # average CMC
    print('multi Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
--------------------------------------------------------------------------------
/evaluate_rerank.py:
--------------------------------------------------------------------------------
import scipy.io
import torch
import numpy as np
import time
from re_ranking import re_ranking
#######################################################################
# Evaluate
def evaluate(score,ql,qc,gl,gc):
    index = np.argsort(score)  # from small to large (re-ranked distance: smaller is better)
    #index = index[::-1]
    # good index
    query_index = np.argwhere(gl==ql)
    camera_index = np.argwhere(gc==qc)

    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    junk_index1 = np.argwhere(gl==-1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1)

    CMC_tmp = compute_mAP(index, good_index, junk_index)
    return CMC_tmp


def compute_mAP(index, good_index, junk_index):
    ap = 0
    cmc = torch.IntTensor(len(index)).zero_()
    if good_index.size==0:   # if empty
        cmc[0] = -1
        return ap,cmc

    # remove junk_index
    mask = np.in1d(index, junk_index, invert=True)
    index = index[mask]

    # find good_index index
    ngood = len(good_index)
    mask = np.in1d(index, good_index)
    rows_good = np.argwhere(mask==True)
    rows_good = rows_good.flatten()

    cmc[rows_good[0]:] = 1
    for i in range(ngood):
        d_recall = 1.0/ngood
        precision = (i+1)*1.0/(rows_good[i]+1)
        if rows_good[i]!=0:
            old_precision = i*1.0/rows_good[i]
        else:
            old_precision = 1.0
        ap = ap + d_recall*(old_precision + precision)/2

    return ap, cmc

######################################################################
result = scipy.io.loadmat('pytorch_result.mat')
query_feature = result['query_f']
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_feature = result['gallery_f']
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]

CMC = torch.IntTensor(len(gallery_label)).zero_()
ap = 0.0
# re-ranking
print('calculate initial distance')
q_g_dist = np.dot(query_feature, np.transpose(gallery_feature))
q_q_dist = np.dot(query_feature, np.transpose(query_feature))
g_g_dist = np.dot(gallery_feature, np.transpose(gallery_feature))
since = time.time()
re_rank = re_ranking(q_g_dist, q_q_dist, g_g_dist)
time_elapsed = time.time() - since
print('Reranking complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
for i in range(len(query_label)):
    ap_tmp, CMC_tmp = evaluate(re_rank[i,:],query_label[i],query_cam[i],gallery_label,gallery_cam)
    if CMC_tmp[0]==-1:
        continue
    CMC = CMC + CMC_tmp
    ap += ap_tmp
    #print(i, CMC_tmp[0])

CMC = CMC.float()
CMC = CMC/len(query_label)  # average CMC
print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable

######################################################################
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal(m.weight.data, a=0, mode='fan_out')
        init.constant(m.bias.data, 0.0)
    elif classname.find('BatchNorm1d') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)

def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal(m.weight.data, std=0.001)
        init.constant(m.bias.data, 0.0)

# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):
    def __init__(self, input_dim, class_num, dropout=False, relu=False, num_bottleneck=512):
        super(ClassBlock, self).__init__()
        add_block = []
        #add_block += [nn.Linear(input_dim, num_bottleneck)]
        num_bottleneck = input_dim  # NOTE: the bottleneck Linear above is disabled, so num_bottleneck is forced to input_dim
        add_block += [nn.BatchNorm1d(num_bottleneck)]
        if relu:
            add_block += [nn.LeakyReLU(0.1)]
        if dropout:
            add_block += [nn.Dropout(p=0.5)]
        add_block = nn.Sequential(*add_block)
        add_block.apply(weights_init_kaiming)

        classifier = []
        classifier += [nn.Linear(num_bottleneck, class_num)]
        classifier = nn.Sequential(*classifier)
        classifier.apply(weights_init_classifier)

        self.add_block = add_block
        self.classifier = classifier
    def forward(self, x):
        f = self.add_block(x)
        f_norm = f.norm(p=2, dim=1, keepdim=True) + 1e-8
        f = f.div(f_norm)  # L2-normalize the feature before the classifier
        x = self.classifier(f)
        return x,f

# Define the ResNet50-based Model
class ft_net(nn.Module):

    def __init__(self, class_num ):
        super(ft_net, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        # avg pooling to global pooling
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.model = model_ft
        self.classifier = ClassBlock(2048, class_num, dropout=False, relu=False)
        # remove the final downsample
        # self.model.layer4[0].downsample[0].stride = (1,1)
        # self.model.layer4[0].conv2.stride = (1,1)
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.squeeze(x)
        x,f = self.classifier(x)
        return x,f

# Define the DenseNet121-based Model
class ft_net_dense(nn.Module):

    def __init__(self, class_num ):
        super().__init__()
        model_ft = models.densenet121(pretrained=True)
        model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1))
        model_ft.fc = nn.Sequential()
        self.model = model_ft
        # For DenseNet, the feature dim is 1024
        self.classifier = ClassBlock(1024, class_num)

    def forward(self, x):
        x = self.model.features(x)
        x = torch.squeeze(x)
        x = self.classifier(x)
        return x

# Define the ResNet50-based Model (Middle-Concat)
# In the spirit of "The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching." Yu, Qian, et al. arXiv:1711.08106 (2017).
class ft_net_middle(nn.Module):

    def __init__(self, class_num ):
        super(ft_net_middle, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        # avg pooling to global pooling
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.model = model_ft
        self.classifier = ClassBlock(2048+1024, class_num)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        # x0  n*1024*1*1
        x0 = self.model.avgpool(x)
        x = self.model.layer4(x)
        # x1  n*2048*1*1
        x1 = self.model.avgpool(x)
        x = torch.cat((x0,x1),1)
        x = torch.squeeze(x)
        x = self.classifier(x)
        return x

# Part Model proposed in Yifan Sun et al. (2018)
class PCB(nn.Module):
    def __init__(self, class_num ):
        super(PCB, self).__init__()

        self.part = 6  # We cut the pool5 feature map into 6 horizontal parts
        model_ft = models.resnet50(pretrained=True)
        self.model = model_ft
        self.avgpool = nn.AdaptiveAvgPool2d((self.part,1))
        self.dropout = nn.Dropout(p=0.5)
        # remove the final downsample
        self.model.layer4[0].downsample[0].stride = (1,1)
        self.model.layer4[0].conv2.stride = (1,1)
        # define 6 classifiers
        for i in range(self.part):
            name = 'classifier'+str(i)
            setattr(self, name, ClassBlock(2048, class_num, True, False, 256))

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.avgpool(x)
        x = self.dropout(x)
        part = {}
        predict = {}
        # get six part features, batchsize*2048*6
        for i in range(self.part):
            part[i] = torch.squeeze(x[:,:,i])
            name = 'classifier'+str(i)
            c = getattr(self,name)
            predict[i] = c(part[i])

        # sum prediction
        #y = predict[0]
        #for i in range(self.part-1):
        #    y += predict[i+1]
        y = []
        for i in range(self.part):
            y.append(predict[i])
        return y

class PCB_test(nn.Module):
    def __init__(self,model):
        super(PCB_test,self).__init__()
        self.part = 6
        self.model = model.model
        self.avgpool = nn.AdaptiveAvgPool2d((self.part,1))
        # remove the final downsample
        self.model.layer4[0].downsample[0].stride = (1,1)
        self.model.layer4[0].conv2.stride = (1,1)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.avgpool(x)
        y = x.view(x.size(0),x.size(1),x.size(2))
        return y

# debug model structure
#net = ft_net(751)
#print(net)
#input = Variable(torch.FloatTensor(8, 3, 224, 224))
#output,f = net(input)
#print('net output size:')
#print(f.shape)
--------------------------------------------------------------------------------
/model/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
import os
from shutil import copyfile

# You only need to change this line to your dataset download path
download_path = '/home/zzheng/Downloads/Market'

if not os.path.isdir(download_path):
    print('please change the download_path')

save_path = download_path + '/pytorch'
if not os.path.isdir(save_path):
    os.mkdir(save_path)
#-----------------------------------------
# query
query_path = download_path + '/query'
query_save_path = download_path + '/pytorch/query'
if not os.path.isdir(query_save_path):
    os.mkdir(query_save_path)

for root, dirs, files in os.walk(query_path, topdown=True):
    for name in files:
        if not name[-3:]=='jpg':
            continue
        ID = name.split('_')
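        # Market-1501 file names look like 0002_c1s1_000451_03.jpg;
        # the first '_'-separated field is the person ID, so ID[0] names the folder.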
        src_path = query_path + '/' + name
        dst_path = query_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

#-----------------------------------------
# multi-query
query_path = download_path + '/gt_bbox'
query_save_path = download_path + '/pytorch/multi-query'
if not os.path.isdir(query_save_path):
    os.mkdir(query_save_path)

for root, dirs, files in os.walk(query_path, topdown=True):
    for name in files:
        if not name[-3:]=='jpg':
            continue
        ID = name.split('_')
        src_path = query_path + '/' + name
        dst_path = query_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

#-----------------------------------------
# gallery
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = download_path + '/pytorch/gallery'
if not os.path.isdir(gallery_save_path):
    os.mkdir(gallery_save_path)

for root, dirs, files in os.walk(gallery_path, topdown=True):
    for name in files:
        if not name[-3:]=='jpg':
            continue
        ID = name.split('_')
        src_path = gallery_path + '/' + name
        dst_path = gallery_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

#---------------------------------------
# train_all
train_path = download_path + '/bounding_box_train'
train_save_path = download_path + '/pytorch/train_all'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)

for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:]=='jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)


#---------------------------------------
# train_val
train_path = download_path + '/bounding_box_train'
train_save_path = download_path + '/pytorch/train'
val_save_path = download_path + '/pytorch/val'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)
    os.mkdir(val_save_path)

for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:]=='jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
            dst_path = val_save_path + '/' + ID[0]  # the first image of each ID is used as the val image
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)
--------------------------------------------------------------------------------
/random_erasing.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from torchvision.transforms import *

from PIL import Image
import random
import math
import numpy as np
import torch

class RandomErasing(object):
    """ Randomly selects a rectangular region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf
    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
    """

    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):

        if random.uniform(0, 1) > self.probability:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
                else:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                return img

        return img
--------------------------------------------------------------------------------
/re_ranking.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 26 14:46:56 2017
@author: luohao
Modified by Houjing Huang, 2017-12-22.
- This version accepts a distance matrix instead of raw features.
- The difference of `/` division between Python 2 and 3 is handled.
- numpy.float16 is replaced by numpy.float32 for numerical precision.

Modified by Zhedong Zheng, 2018-1-12.
- replace sort with topK, which saves about 30s.
"""

"""
CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""

"""
API
q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]
q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]
g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]
k1, k2, lambda_value: parameters, the original paper uses (k1=20, k2=6, lambda_value=0.3)
Returns:
  final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
"""

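# Example usage (a sketch; it mirrors evaluate_rerank.py): despite the "dist"
# names, the inputs in this repo are built as similarity matrices from the
# L2-normalized features, and re_ranking converts them to distances internally:
#   q_g_dist = np.dot(query_feature, gallery_feature.T)
#   q_q_dist = np.dot(query_feature, query_feature.T)
#   g_g_dist = np.dot(gallery_feature, gallery_feature.T)
#   final_dist = re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3)
# Rank each query by ascending final_dist to obtain the re-ranked result.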

import numpy as np

def k_reciprocal_neigh( initial_rank, i, k1):
    forward_k_neigh_index = initial_rank[i,:k1+1]
    backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
    fi = np.where(backward_k_neigh_index==i)[0]
    return forward_k_neigh_index[fi]

def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
    # The following naming, e.g. gallery_num, is different from the outer scope.
    # Don't worry about it.
    original_dist = np.concatenate(
        [np.concatenate([q_q_dist, q_g_dist], axis=1),
         np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
        axis=0)
    original_dist = 2. - 2 * original_dist  # turn cosine similarity into (squared) Euclidean distance
    original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
    V = np.zeros_like(original_dist).astype(np.float32)
    #initial_rank = np.argsort(original_dist).astype(np.int32)
    # top K1+1
    initial_rank = np.argpartition( original_dist, range(1,k1+1) )

    query_num = q_g_dist.shape[0]
    all_num = original_dist.shape[0]

    for i in range(all_num):
        # k-reciprocal neighbors
        k_reciprocal_index = k_reciprocal_neigh( initial_rank, i, k1)
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_k_reciprocal_index = k_reciprocal_neigh( initial_rank, candidate, int(np.around(k1/2)))
            if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
        V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)

    original_dist = original_dist[:query_num,]
    if k2 != 1:
        V_qe = np.zeros_like(V,dtype=np.float32)
        for i in range(all_num):
            V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    invIndex = []
    for i in range(all_num):
        invIndex.append(np.where(V[:,i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1,all_num],dtype=np.float32)
        indNonZero = np.where(V[i,:] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0,indImages[j]] = temp_min[0,indImages[j]] + np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
        jaccard_dist[i] = 1-temp_min/(2.-temp_min)

    final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num,query_num:]
    return final_dist
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

from __future__ import print_function, division

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import scipy.io
from model import ft_net, ft_net_dense, PCB, PCB_test

######################################################################
# Options
# --------
parser = argparse.ArgumentParser(description='Testing')
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--test_dir',default='../Market/pytorch',type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=64, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
parser.add_argument('--PCB', action='store_true', help='use PCB' )
parser.add_argument('--multi', action='store_true', help='use multiple query' )

opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
#which_epoch = opt.which_epoch
name = opt.name
test_dir = opt.test_dir

gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        gpu_ids.append(id)

# set gpu ids
if len(gpu_ids)>0:
    torch.cuda.set_device(gpu_ids[0])

######################################################################
# Load Data
# ---------
#
# We will use torchvision and torch.utils.data packages for loading the
# data.
#
data_transforms = transforms.Compose([
    transforms.Resize((256,128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ############### Ten Crop
    #transforms.TenCrop(224),
    #transforms.Lambda(lambda crops: torch.stack(
    #   [transforms.ToTensor()(crop)
    #      for crop in crops]
    #   )),
    #transforms.Lambda(lambda crops: torch.stack(
    #   [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
    #       for crop in crops]
    #   ))
])

if opt.PCB:
    data_transforms = transforms.Compose([
        transforms.Resize((384,192), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


data_dir = test_dir
image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                              shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}

class_names = image_datasets['query'].classes
use_gpu = torch.cuda.is_available()

######################################################################
# Load model
#---------------------------
def load_network(network):
    save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
    network.load_state_dict(torch.load(save_path))
    return network


######################################################################
# Extract feature
# ----------------------
#
# Extract feature from a trained model.
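# Each image is passed through the model twice -- original and horizontally
# flipped -- and the two features are summed before L2 normalization.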
#
def fliplr(img):
    '''flip horizontal'''
    inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
    img_flip = img.index_select(3,inv_idx)
    return img_flip

def extract_feature(model,dataloaders):
    features = torch.FloatTensor()
    count = 0
    for data in dataloaders:
        img, label = data
        n, c, h, w = img.size()
        count += n
        print(count)
        ff = torch.FloatTensor(n, 2048).zero_()
        for i in range(2):
            if(i==1):
                img = fliplr(img)
            input_img = Variable(img.cuda())
            outputs,f = model(input_img)
            f = f.data.cpu()
            ff = ff+f
        # norm feature
        if opt.PCB:
            # feature size (n,2048,6)
            # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
            # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6)
            ff = ff.div(fnorm.expand_as(ff))
            ff = ff.view(ff.size(0), -1)
        else:
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))

        features = torch.cat((features,ff), 0)
    return features

def get_id(img_path):
    camera_id = []
    labels = []
    for path, v in img_path:
        #filename = path.split('/')[-1]
        filename = os.path.basename(path)
        label = filename[0:4]
        camera = filename.split('c')[1]
        if label[0:2]=='-1':
            labels.append(-1)
        else:
            labels.append(int(label))
        camera_id.append(int(camera[0]))
    return camera_id, labels

gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs
mquery_path = image_datasets['multi-query'].imgs

gallery_cam,gallery_label = get_id(gallery_path)
query_cam,query_label = get_id(query_path)
mquery_cam,mquery_label = get_id(mquery_path)

######################################################################
# Load the trained model
print('-------test-----------')
if opt.use_dense:
    model_structure = ft_net_dense(751)
else:
    model_structure = ft_net(751)

if opt.PCB:
    model_structure = PCB(751)

model = load_network(model_structure)

# Remove the final fc layer and classifier layer
#if not opt.PCB:
#    model.model.fc = nn.Sequential()
#    model.classifier = nn.Sequential()
#else:
#    model = PCB_test(model)

# Change to test mode
model = model.eval()
if use_gpu:
    model = model.cuda()

# Extract feature
gallery_feature = extract_feature(model,dataloaders['gallery'])
query_feature = extract_feature(model,dataloaders['query'])
if opt.multi:
    mquery_feature = extract_feature(model,dataloaders['multi-query'])

# Save to Matlab for check
result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
scipy.io.savemat('pytorch_result.mat',result)
if opt.multi:
    result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam}
    scipy.io.savemat('multi_query.mat',result)
--------------------------------------------------------------------------------
/train_new.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

from __future__ import print_function, division

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import copy
from PIL import Image
import time
import os
#from reid_sampler import StratifiedSampler
from model import ft_net, ft_net_dense, PCB
from random_erasing import RandomErasing
from tripletfolder import TripletFolder
import json
from shutil import copyfile

version = torch.__version__

# fp16
try:
    from apex.fp16_utils import *
    from apex import amp, optimizers
except ImportError:
    print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')


######################################################################
# Options
# --------
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
parser.add_argument('--name',default='ft_ResNet50', type=str, help='output model name')
parser.add_argument('--data_dir',default='../Market/pytorch',type=str, help='training dir path')
parser.add_argument('--train_all', action='store_true', help='use all training data' )
parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training' )
parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
parser.add_argument('--poolsize', default=128, type=int, help='poolsize')
parser.add_argument('--margin', default=0.3, type=float, help='margin')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--alpha', default=0.0, type=float, help='regularization, push to -1')
parser.add_argument('--erasing_p', default=0, type=float, help='Random Erasing probability, in [0,1]')
parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
parser.add_argument('--PCB', action='store_true', help='use PCB+ResNet50' )
parser.add_argument('--fp16', action='store_true', help='use float16 instead of float32, which will save about 50% memory' )
opt = parser.parse_args()

data_dir = opt.data_dir
name = opt.name
fp16 = opt.fp16
str_ids = opt.gpu_ids.split(',')
gpu_ids = []
for str_id in str_ids:
    gid = int(str_id)
    if gid >= 0:
        gpu_ids.append(gid)

# set gpu ids
if len(gpu_ids)>0:
    torch.cuda.set_device(gpu_ids[0])
#print(gpu_ids[0])


######################################################################
# Load Data
# ---------
#

transform_train_list = [
    #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC
    transforms.Resize((256,128), interpolation=3),
    #transforms.RandomCrop((256,128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

transform_val_list = [
    transforms.Resize(size=(256,128),interpolation=3), #Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

if opt.PCB:
    transform_train_list = [
        transforms.Resize((384,192), interpolation=3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
    transform_val_list = [
        transforms.Resize(size=(384,192),interpolation=3), #Image.BICUBIC
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]

if opt.erasing_p>0:
    transform_train_list = transform_train_list + [RandomErasing(probability = opt.erasing_p, mean=[0.0, 0.0, 0.0])]

if opt.color_jitter:
    transform_train_list = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train_list

print(transform_train_list)
data_transforms = {
    'train': transforms.Compose( transform_train_list ),
    'val': transforms.Compose(transform_val_list),
}


train_all = ''
if opt.train_all:
    train_all = '_all'

image_datasets = {}
image_datasets['train'] = TripletFolder(os.path.join(data_dir, 'train' + train_all),  # 'train' or 'train_all', depending on --train_all
                                        data_transforms['train'])
image_datasets['val'] = TripletFolder(os.path.join(data_dir, 'val'),
                                      data_transforms['val'])

batch = {}

class_names = image_datasets['train'].classes
class_vector = [s[1] for s in image_datasets['train'].samples]
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                              shuffle=True, num_workers=8)
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

use_gpu = torch.cuda.is_available()

since = time.time()
#inputs, classes, pos, pos_classes = next(iter(dataloaders['train']))
print(time.time()-since)


######################################################################
# Training the model
# ------------------
#
# Now, let's write a general function to train a model. Here, we will
# illustrate:
#
# - Scheduling the learning rate
# - Saving the best model
#
# In the following, parameter ``scheduler`` is an LR scheduler object from
# ``torch.optim.lr_scheduler``.

y_loss = {}  # loss history
y_loss['train'] = []
y_loss['val'] = []
y_err = {}
y_err['train'] = []
y_err['val'] = []

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    last_model_wts = model.state_dict()  # fall-back weights in case the margin never improves
    best_acc = 0.0
    last_margin = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                model.train(True)   # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_margin = 0.0
            running_reg = 0.0
            # Iterate over data.
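            # Each batch item from TripletFolder is (anchor image, label,
            # 4 positive images of the same ID, positive label); see tripletfolder.py.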
            for data in dataloaders[phase]:
                # get the inputs
                inputs, labels, pos, pos_labels = data
                now_batch_size,c,h,w = inputs.shape

                if now_batch_size < opt.batchsize:  # skip the last incomplete batch
                    continue

                # [... the forward pass and triplet-loss computation (original
                #  lines 196-291 of train_new.py) are missing from this dump
                #  and are not reconstructed here ...]

                if int(version[0]) > 0 or int(version[2]) > 3:  # for the new version like 0.4.0 and 0.5.0
                    running_loss += loss_triplet.item()  #* opt.batchsize
                else:  # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss_triplet.data[0]  #* opt.batchsize
                #print(loss_triplet.item())
                running_corrects += float(torch.sum(pscore>nscore+opt.margin))
                running_margin += float(torch.sum(pscore-nscore))
                running_reg += reg

            datasize = dataset_sizes['train']//opt.batchsize * opt.batchsize
            epoch_loss = running_loss / datasize
            epoch_reg = opt.alpha*running_reg / datasize
            epoch_acc = running_corrects / datasize
            epoch_margin = running_margin / datasize

            #if epoch_acc>0.75:
            #    opt.margin = min(opt.margin+0.02, 1.0)
            print('now_margin: %.4f'%opt.margin)
            print('{} Loss: {:.4f} Reg: {:.4f} Acc: {:.4f} MeanMargin: {:.4f}'.format(
                phase, epoch_loss, epoch_reg, epoch_acc, epoch_margin))
            if phase == 'train':
                scheduler.step()
            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0-epoch_acc)
            # deep copy the model
            if epoch_margin>last_margin:
                last_margin = epoch_margin
                last_model_wts = model.state_dict()

            if epoch%10 == 9:
                save_network(model, epoch)
                draw_curve(epoch)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    #print('Best val Acc: {:4f}'.format(best_acc))

    # load the weights with the best mean margin
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')
    return model


######################################################################
# Draw Curve
#---------------------------
x_epoch = []
fig = plt.figure()
ax0 = fig.add_subplot(121, title="triplet_loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    # ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    # ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    if current_epoch == 0:
        ax0.legend()
        ax1.legend()
    fig.savefig( os.path.join('./model',name,'train.jpg'))

######################################################################
# Save model
#---------------------------
def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    save_path = os.path.join('./model',name,save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    if torch.cuda.is_available():
        network.cuda(gpu_ids[0])


######################################################################
# Finetuning the convnet
# ----------------------
#
# Load a pretrained model and reset the final fully connected layer.
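# (ft_net is the ResNet50 baseline; --use_dense switches to DenseNet121 and
# --PCB to the part-based model, all defined in model.py.)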
#

if opt.use_dense:
    model = ft_net_dense(len(class_names))
else:
    model = ft_net(len(class_names))

if opt.PCB:
    model = PCB(len(class_names))

print(model)

if use_gpu:
    model = model.cuda()

criterion = nn.CrossEntropyLoss()

if not opt.PCB:
    ignored_params = list(map(id, model.model.fc.parameters() )) + list(map(id, model.classifier.parameters() ))
    base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
    optimizer_ft = optim.SGD([
        {'params': base_params, 'lr': 0.1*opt.lr},
        {'params': model.model.fc.parameters(), 'lr': opt.lr},
        {'params': model.classifier.parameters(), 'lr': opt.lr}
    ], weight_decay=5e-4, momentum=0.9, nesterov=True)
else:
    ignored_params = list(map(id, model.model.fc.parameters() ))
    ignored_params += (list(map(id, model.classifier0.parameters() ))
                       +list(map(id, model.classifier1.parameters() ))
                       +list(map(id, model.classifier2.parameters() ))
                       +list(map(id, model.classifier3.parameters() ))
                       +list(map(id, model.classifier4.parameters() ))
                       +list(map(id, model.classifier5.parameters() ))
                       #+list(map(id, model.classifier6.parameters() ))
                       #+list(map(id, model.classifier7.parameters() ))
                       )
    base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
    optimizer_ft = optim.SGD([
        {'params': base_params, 'lr': 0.001},
        {'params': model.model.fc.parameters(), 'lr': 0.01},
        {'params': model.classifier0.parameters(), 'lr': 0.01},
        {'params': model.classifier1.parameters(), 'lr': 0.01},
        {'params': model.classifier2.parameters(), 'lr': 0.01},
        {'params': model.classifier3.parameters(), 'lr': 0.01},
        {'params': model.classifier4.parameters(), 'lr': 0.01},
        {'params': model.classifier5.parameters(), 'lr': 0.01},
        #{'params': model.classifier6.parameters(), 'lr': 0.01},
        #{'params': model.classifier7.parameters(), 'lr': 0.01}
    ], weight_decay=5e-4, momentum=0.9, nesterov=True)

exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[40,60], gamma=0.1)

######################################################################
# Train and evaluate
# ^^^^^^^^^^^^^^^^^^
#
# It should take around 1-2 hours on a GPU.
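# (Training runs for 70 epochs; a checkpoint is written every 10 epochs and the
# weights with the best mean margin are saved as net_last.pth under ./model/<name>/.)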
#
dir_name = os.path.join('./model',name)
if not os.path.isdir(dir_name):
    os.mkdir(dir_name)
copyfile('./train_new.py', dir_name+'/train_new.py')
copyfile('./model.py', dir_name+'/model.py')
copyfile('./tripletfolder.py', dir_name+'/tripletfolder.py')

# save opts
with open('%s/opts.json'%dir_name,'w') as fp:
    json.dump(vars(opt), fp, indent=1)

if fp16:
    model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level = "O1")

model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                    num_epochs=70)

--------------------------------------------------------------------------------
/tripletfolder.py:
--------------------------------------------------------------------------------
from torchvision import datasets
import os
import numpy as np
import random
import torch

class TripletFolder(datasets.ImageFolder):

    def __init__(self, root, transform):
        super(TripletFolder, self).__init__(root, transform)
        targets = np.asarray([s[1] for s in self.samples])
        self.targets = targets
        cams = []
        for s in self.samples:
            cams.append( self._get_cam_id(s[0]) )
        self.cams = np.asarray(cams)

    def _get_cam_id(self, path):
        filename = os.path.basename(path)
        camera_id = filename.split('c')[1][0]
        #camera_id = filename.split('_')[2][0:2]
        return int(camera_id)-1

    def _get_pos_sample(self, target, index):
        pos_index = np.argwhere(self.targets == target)
        pos_index = pos_index.flatten()
        pos_index = np.setdiff1d(pos_index, index)  # exclude the anchor itself
        rand = np.random.permutation(len(pos_index))
        result_path = []
        for i in range(4):
            t = i%len(rand)  # wrap around if the ID has fewer than 4 other images
            tmp_index = pos_index[rand[t]]
            result_path.append(self.samples[tmp_index][0])
        return result_path

    def _get_neg_sample(self, target):
        neg_index = np.argwhere(self.targets != target)
        neg_index = neg_index.flatten()
        rand = random.randint(0,len(neg_index)-1)
        return self.samples[neg_index[rand]]

    def __getitem__(self, index):
        path, target = self.samples[index]
        cam = self.cams[index]
        # pos_path, neg_path
        pos_path = self._get_pos_sample(target, index)

        sample = self.loader(path)
        pos0 = self.loader(pos_path[0])
        pos1 = self.loader(pos_path[1])
        pos2 = self.loader(pos_path[2])
        pos3 = self.loader(pos_path[3])

        if self.transform is not None:
            sample = self.transform(sample)
            pos0 = self.transform(pos0)
            pos1 = self.transform(pos1)
            pos2 = self.transform(pos2)
            pos3 = self.transform(pos3)

        if self.target_transform is not None:
            target = self.target_transform(target)

        c,h,w = pos0.shape
        pos = torch.cat((pos0.view(1,c,h,w), pos1.view(1,c,h,w), pos2.view(1,c,h,w), pos3.view(1,c,h,w)), 0)
        pos_target = target
        return sample, target, pos, pos_target
--------------------------------------------------------------------------------
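For reference, a minimal sketch of how `TripletFolder` is consumed (the loader settings mirror `train_new.py`; the dataset root is an assumption -- point it at the folders created by `prepare.py`):
```python
import torch
from torchvision import transforms
from tripletfolder import TripletFolder

# Assumed dataset root created by prepare.py.
dataset = TripletFolder('../Market/pytorch/train_all',
                        transforms.Compose([transforms.Resize((256,128), interpolation=3),
                                            transforms.ToTensor()]))
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)

inputs, labels, pos, pos_labels = next(iter(loader))
print(inputs.shape)  # torch.Size([32, 3, 256, 128])    -- anchor images
print(pos.shape)     # torch.Size([32, 4, 3, 256, 128]) -- 4 positives per anchor
```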