├── .gitattributes
├── LICENSE
├── README.md
├── evaluate_rerank_duke.py
├── evaluate_rerank_market.py
├── evaluate_st.py
├── gen_rerank_all_scores_mat.py
├── gen_st_model_duke.py
├── gen_st_model_market.py
├── model.py
├── plot_st_distribution.py
├── prepare.py
├── random_erasing.py
├── re_ranking.py
├── readme.txt
├── test_st_duke.py
├── test_st_market.py
├── train_duke.py
└── train_market.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2021 Guangcong Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so,
10 | subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatial-Temporal Person Re-identification
2 |
3 | ----------
4 | Code for ST-ReID (PyTorch). We achieve **Rank@1=98.1%, mAP=87.6%** without re-ranking and **Rank@1=98.0%, mAP=95.5%** with re-ranking on Market1501. On DukeMTMC-reID, we achieve **Rank@1=94.4%, mAP=83.9%** without re-ranking and **Rank@1=94.5%, mAP=92.7%** with re-ranking.
5 |
6 | ## Updates and FAQ
7 | - 2023.12.26: I will no longer maintain this code base. The PyTorch version and some of the packages it relies on may be obsolete, but I don't think that is a big deal. I may still offer implementation suggestions where possible.
8 | - 2020.01.08: If you do not want to re-train a model, you can follow this link: https://github.com/Wanggcong/Spatial-Temporal-Re-identification/issues/26#issuecomment-571905649
9 | - 2019.12.26: A demo figure has been added. I am not sure whether it still works because it was written one year ago. I will update this file in the future.
10 | - 2019.07.28: Models (+RE) (Google Drive link: https://drive.google.com/drive/folders/1FIreE0pUGiqLzppzz_f7gHw0kaXZb1kC)
11 | - 2019.07.11: Models (+RE) (Baidu Yun link: https://pan.baidu.com/s/1QMp22dVGJvBH45e4XPdeKw password: dn7b) are released. Note that the Market1501 results differ slightly from those in the paper because these models were trained with PyTorch 0.4.1 (mAP is slightly higher and rank-1 slightly lower than in the paper). We may reproduce the paper results with PyTorch 0.3 later.
12 | - 2019.07.11: README.md: python3 prepare --Duke ---> python3 prepare.py --Duke
13 | - 2019.06.02: How to add the spatial-temporal constraint to conventional re-ID models? Replace step 2 and step 3 with your own visual feature representation.
14 | - 2019.05.31: gen_st_model_market.py, added Line 68~69.
15 |
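For the 2019.06.02 note above, here is a minimal sketch of how a spatial-temporal probability could be fused with any visual similarity. It follows the scoring expression in `evaluate_st.py`; the function name `joint_score` and the sample numbers are illustrative assumptions, and `alpha=5` mirrors the script default.

```python
import math


def joint_score(visual_sim, st_prob, alpha=5.0):
    """Fuse a visual similarity with a spatial-temporal probability.

    Same expression as in evaluate_st.py:
    1/(1+exp(-alpha*s)) * 1/(1+2*exp(-alpha*p))
    """
    visual_term = 1.0 / (1.0 + math.exp(-alpha * visual_sim))
    st_term = 1.0 / (1.0 + 2.0 * math.exp(-alpha * st_prob))
    return visual_term * st_term


# Hypothetical gallery candidates: (cosine similarity, ST probability).
candidates = {"a": (0.80, 0.001), "b": (0.75, 0.050)}
ranked = sorted(candidates, key=lambda k: joint_score(*candidates[k]), reverse=True)
print(ranked)
```

In this toy case candidate `b` overtakes `a` despite a slightly lower visual similarity, because its camera transition is more plausible. Any feature extractor can supply `visual_sim` as long as it is a similarity (larger means more alike).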
16 |
17 | ## 1. ST-ReID
18 | ### 1.1 model
19 | 
20 |
21 | ### 1.2 result
22 | 
23 | ----------
24 |
25 | 
26 |
27 |
28 | ## 2. Prerequisites
29 | - **Pytorch 0.3**
30 | - Python 3.6
31 | - Numpy
32 |
33 |
34 | ## 3. experiment
35 | ### Market1501
36 | 1. data prepare
37 | 1) change the path of dataset
38 | 2) python3 prepare.py --Market
39 |
40 | 2. train (appearance feature learning)
41 | python3 train_market.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_market_e --erasing_p 0.5 --train_all --data_dir "/home/huangpg/st-reid/dataset/market_rename/"
42 |
43 | 3. test (appearance feature extraction)
44 | python3 test_st_market.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_market_e --test_dir "/home/huangpg/st-reid/dataset/market_rename/"
45 |
46 | 4. generate st model (spatial-temporal distribution)
47 | python3 gen_st_model_market.py --name ft_ResNet50_pcb_market_e --data_dir "/home/huangpg/st-reid/dataset/market_rename/"
48 | 5. evaluate (joint metric, you can use your own visual feature or spatial-temporal streams)
49 | python3 evaluate_st.py --name ft_ResNet50_pcb_market_e
50 |
51 | 6. re-rank
52 | 6.1) python3 gen_rerank_all_scores_mat.py --name ft_ResNet50_pcb_market_e
53 | 6.2) python3 evaluate_rerank_market.py --name ft_ResNet50_pcb_market_e
54 |
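The evaluation steps above report Rank@k and mAP. As a rough illustration of what `compute_mAP` in the evaluation scripts does for a single query, the following dependency-free sketch drops junk gallery entries, locates the correct matches, and averages precision with the same trapezoidal rule. The name `average_precision` and the toy ranking are assumptions made for this example, and it assumes the ranking covers the whole gallery.

```python
def average_precision(ranking, good, junk):
    """ranking: gallery indices, best match first.
    good: set of correct matches; junk: set of indices to ignore.
    Returns (ap, position of the first correct match) or (None, None).
    """
    ranking = [g for g in ranking if g not in junk]
    rows_good = [pos for pos, g in enumerate(ranking) if g in good]
    if not rows_good:
        return None, None  # query has no valid match; the scripts skip it

    ap = 0.0
    d_recall = 1.0 / len(rows_good)
    for i, pos in enumerate(rows_good):
        precision = (i + 1) / (pos + 1)               # precision including this hit
        old_precision = i / pos if pos != 0 else 1.0  # precision just before it
        ap += d_recall * (old_precision + precision) / 2
    return ap, rows_good[0]


# Hypothetical query: gallery items 2 and 4 are correct, item 0 is junk.
ap, first_hit = average_precision([3, 0, 2, 1, 4], good={2, 4}, junk={0})
print(ap, first_hit)
```

The CMC curve for the query is 0 before `first_hit` and 1 from there on, so Rank@1 counts the queries whose `first_hit` is 0.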
55 |
56 | ### DukeMTMC-reID
57 | 1. data prepare
58 | python3 prepare.py --Duke
59 |
60 | 2. train (appearance feature learning)
61 | python3 train_duke.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_duke_e --erasing_p 0.5 --train_all --data_dir "/home/huangpg/st-reid/dataset/DukeMTMC_prepare/"
62 |
63 | 3. test (appearance feature extraction)
64 | python3 test_st_duke.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_duke_e --test_dir "/home/huangpg/st-reid/dataset/DukeMTMC_prepare/"
65 |
66 | 4. generate st model (spatial-temporal distribution)
67 | python3 gen_st_model_duke.py --name ft_ResNet50_pcb_duke_e --data_dir "/home/huangpg/st-reid/dataset/DukeMTMC_prepare/"
68 |
69 | 5. evaluate (joint metric, you can use your own visual feature or spatial-temporal streams)
70 | python3 evaluate_st.py --name ft_ResNet50_pcb_duke_e
71 |
72 | 6. re-rank
73 | 6.1) python3 gen_rerank_all_scores_mat.py --name ft_ResNet50_pcb_duke_e
74 | 6.2) python3 evaluate_rerank_duke.py --name ft_ResNet50_pcb_duke_e
75 |
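Step 4 builds a histogram of camera-to-camera transition times, and step 5 looks it up for each query/gallery pair. The sketch below mirrors the logic of `gen_st_model_duke.py` and `evaluate_st.py` in plain Python on toy data. The function names, the toy samples, and the small bin count are assumptions for illustration; the scripts use NumPy arrays with 3000 bins and add a small epsilon when normalising, whereas this sketch simply leaves empty camera pairs at zero.

```python
from collections import defaultdict

INTERVAL = 100  # frames per histogram bin, as in the scripts


def build_st_histogram(samples, num_cams, num_bins):
    """samples: iterable of (person_index, camera_id starting at 1, frame)."""
    frames = defaultdict(list)
    for pid, cam, frame in samples:
        frames[(pid, cam - 1)].append(frame)
    # Mean frame number of each person in each camera they appear in.
    mean = {key: sum(vals) / len(vals) for key, vals in frames.items()}

    hist = [[[0.0] * num_bins for _ in range(num_cams)] for _ in range(num_cams)]
    for pid in {p for p, _ in mean}:
        cams = sorted(c for p, c in mean if p == pid)
        for pos, j in enumerate(cams):
            for k in cams[pos + 1:]:
                # Index order is [later camera][earlier camera].
                later, earlier = (j, k) if mean[(pid, j)] > mean[(pid, k)] else (k, j)
                diff = mean[(pid, later)] - mean[(pid, earlier)]
                hist[later][earlier][int(diff / INTERVAL)] += 1

    # Normalise each camera pair so its bins sum to 1 (empty pairs stay 0).
    for row in hist:
        for bins in row:
            total = sum(bins)
            if total:
                bins[:] = [b / total for b in bins]
    return hist


def st_probability(hist, q_cam, q_frame, g_cam, g_frame):
    """Look up the transition probability for a query/gallery pair."""
    if q_frame > g_frame:
        return hist[q_cam - 1][g_cam - 1][int((q_frame - g_frame) / INTERVAL)]
    return hist[g_cam - 1][q_cam - 1][int((g_frame - q_frame) / INTERVAL)]


# Hypothetical toy data: two people walking from camera 1 to camera 2.
samples = [(0, 1, 100), (0, 1, 300), (0, 2, 1250), (1, 1, 50), (1, 2, 1480)]
hist = build_st_histogram(samples, num_cams=2, num_bins=20)
print(st_probability(hist, q_cam=2, q_frame=1300, g_cam=1, g_frame=250))
```

The lookup is symmetric in query and gallery: whichever image has the larger frame number selects the first index. The scripts additionally smooth each histogram with a Gaussian kernel before use, which this sketch omits.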
76 | ## Citation
77 |
78 | If you use this code, please kindly cite it in your paper.
79 |
80 | ```latex
81 | @inproceedings{guangcong2019aaai,
82 | title={Spatial-Temporal Person Re-identification},
83 | author={Wang, Guangcong and Lai, Jianhuang and Huang, Peigen and Xie, Xiaohua},
84 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
85 | pages={8933-8940},
86 | year={2019}
87 | }
88 | ```
89 | Paper link: https://www.aaai.org/ojs/index.php/AAAI/article/view/4921
90 | or https://arxiv.org/abs/1812.03282
91 | ## Related Repos
92 |
93 | Our code is mainly based on this [repository](https://github.com/layumi/Person_reID_baseline_pytorch)
94 |
--------------------------------------------------------------------------------
/evaluate_rerank_duke.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import torch
3 | import numpy as np
4 | import time
5 | from re_ranking import re_ranking
6 | import argparse
7 | import os
8 | import math
9 |
10 | parser = argparse.ArgumentParser(description='evaluate')
11 | parser.add_argument('--name',default='ft_ResNet50_duke_pcb_r_c', type=str, help='model name, i.e. the sub-directory of ./model to evaluate')
12 | opt = parser.parse_args()
13 | name = opt.name
14 |
15 | #######################################################################
16 | # Evaluate
17 | def evaluate(score,ql,qc,gl,gc):
18 |     index = np.argsort(score)  #from small to large
19 |     #index = index[::-1]
20 |     # good index
21 |     query_index = np.argwhere(gl==ql)  # gallery images with the query identity
22 |     camera_index = np.argwhere(gc==qc)  # gallery images from the query camera
23 |
24 |     good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)  # same identity, different camera
25 |     junk_index1 = np.argwhere(gl==-1)  # gallery images labelled -1
26 |     junk_index2 = np.intersect1d(query_index, camera_index)  # same identity, same camera
27 |     junk_index = np.append(junk_index2, junk_index1) #.flatten())
28 |
29 |     CMC_tmp = compute_mAP(index, good_index, junk_index)  # tuple (ap, cmc)
30 |     return CMC_tmp
31 |
32 |
33 | def compute_mAP(index, good_index, junk_index):
34 |     ap = 0
35 |     cmc = torch.IntTensor(len(index)).zero_()
36 |     if good_index.size==0:   # if empty
37 |         cmc[0] = -1  # marker: the caller skips this query
38 |         return ap,cmc
39 |
40 |     # remove junk_index
41 |     mask = np.in1d(index, junk_index, invert=True)
42 |     index = index[mask]
43 |
44 |     # find good_index index
45 |     ngood = len(good_index)
46 |     mask = np.in1d(index, good_index)
47 |     rows_good = np.argwhere(mask==True)
48 |     rows_good = rows_good.flatten()
49 |
50 |     cmc[rows_good[0]:] = 1  # hit at every rank from the first correct match on
51 |     for i in range(ngood):
52 |         d_recall = 1.0/ngood
53 |         precision = (i+1)*1.0/(rows_good[i]+1)  # precision including this match
54 |         if rows_good[i]!=0:
55 |             old_precision = i*1.0/rows_good[i]  # precision just before this match
56 |         else:
57 |             old_precision=1.0
58 |         ap = ap + d_recall*(old_precision + precision)/2  # trapezoidal average
59 |
60 |     return ap, cmc
61 |
62 | ######################################################################
63 | result = scipy.io.loadmat('model/'+name+'/pytorch_result.mat')
64 | query_feature = result['query_f']
65 | query_cam = result['query_cam'][0]
66 | query_label = result['query_label'][0]
67 | gallery_feature = result['gallery_f']
68 | gallery_cam = result['gallery_cam'][0]
69 | gallery_label = result['gallery_label'][0]
70 |
71 |
72 | mat_path = 'model/'+name+'/all_scores.mat'
73 | all_scores = scipy.io.loadmat(mat_path) #important
74 | all_dist = all_scores['all_scores']
75 | print('all_dist shape:',all_dist.shape)
76 | print('query_cam shape:',query_cam.shape)
77 |
78 | CMC = torch.IntTensor(len(gallery_label)).zero_()
79 | ap = 0.0
80 | #re-ranking
81 | print('calculate initial distance')
82 | # q_g_dist = np.dot(query_feature, np.transpose(gallery_feature))
83 | # q_q_dist = np.dot(query_feature, np.transpose(query_feature))
84 | # g_g_dist = np.dot(gallery_feature, np.transpose(gallery_feature))
85 |
86 | since = time.time()
87 | re_rank = re_ranking(len(query_cam), all_dist)
88 | time_elapsed = time.time() - since
89 | print('Reranking complete in {:.0f}m {:.0f}s'.format(
90 |     time_elapsed // 60, time_elapsed % 60))
91 | for i in range(len(query_label)):
92 |     ap_tmp, CMC_tmp = evaluate(re_rank[i,:],query_label[i],query_cam[i],gallery_label,gallery_cam)
93 |     if CMC_tmp[0]==-1:
94 |         continue
95 |     CMC = CMC + CMC_tmp
96 |     ap += ap_tmp
97 |     #print(i, CMC_tmp[0])
98 |
99 | CMC = CMC.float()
100 | CMC = CMC/len(query_label) #average CMC
101 | print('top1:%f top5:%f top10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
102 |
--------------------------------------------------------------------------------
/evaluate_rerank_market.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import torch
3 | import numpy as np
4 | import time
5 | from re_ranking import re_ranking
6 | import argparse
7 | import os
8 | import math
9 |
10 | parser = argparse.ArgumentParser(description='evaluate')
11 | parser.add_argument('--name',default='ft_ResNet50_duke_pcb_r_c', type=str, help='model name, i.e. the sub-directory of ./model to evaluate')
12 | opt = parser.parse_args()
13 | name = opt.name
14 |
15 | #######################################################################
16 | # Evaluate
17 | def evaluate(score,ql,qc,gl,gc):
18 |     index = np.argsort(score)  #from small to large
19 |     #index = index[::-1]
20 |     # good index
21 |     query_index = np.argwhere(gl==ql)  # gallery images with the query identity
22 |     camera_index = np.argwhere(gc==qc)  # gallery images from the query camera
23 |
24 |     good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)  # same identity, different camera
25 |     junk_index1 = np.argwhere(gl==-1)  # gallery images labelled -1
26 |     junk_index2 = np.intersect1d(query_index, camera_index)  # same identity, same camera
27 |     junk_index = np.append(junk_index2, junk_index1) #.flatten())
28 |
29 |     CMC_tmp = compute_mAP(index, good_index, junk_index)  # tuple (ap, cmc)
30 |     return CMC_tmp
31 |
32 |
33 | def compute_mAP(index, good_index, junk_index):
34 |     ap = 0
35 |     cmc = torch.IntTensor(len(index)).zero_()
36 |     if good_index.size==0:   # if empty
37 |         cmc[0] = -1  # marker: the caller skips this query
38 |         return ap,cmc
39 |
40 |     # remove junk_index
41 |     mask = np.in1d(index, junk_index, invert=True)
42 |     index = index[mask]
43 |
44 |     # find good_index index
45 |     ngood = len(good_index)
46 |     mask = np.in1d(index, good_index)
47 |     rows_good = np.argwhere(mask==True)
48 |     rows_good = rows_good.flatten()
49 |
50 |     cmc[rows_good[0]:] = 1  # hit at every rank from the first correct match on
51 |     for i in range(ngood):
52 |         d_recall = 1.0/ngood
53 |         precision = (i+1)*1.0/(rows_good[i]+1)  # precision including this match
54 |         if rows_good[i]!=0:
55 |             old_precision = i*1.0/rows_good[i]  # precision just before this match
56 |         else:
57 |             old_precision=1.0
58 |         ap = ap + d_recall*(old_precision + precision)/2  # trapezoidal average
59 |
60 |     return ap, cmc
61 |
62 | ######################################################################
63 | result = scipy.io.loadmat('./model/'+name+'/pytorch_result.mat')
64 | query_feature = result['query_f']
65 | query_cam = result['query_cam'][0]
66 | query_label = result['query_label'][0]
67 | gallery_feature = result['gallery_f']
68 | gallery_cam = result['gallery_cam'][0]
69 | gallery_label = result['gallery_label'][0]
70 |
71 |
72 | mat_path = 'model/'+name+'/all_scores.mat'
73 | all_scores = scipy.io.loadmat(mat_path) #important
74 | all_dist = all_scores['all_scores']
75 | print('all_dist shape:',all_dist.shape)
76 | print('query_cam shape:',query_cam.shape)
77 |
78 | CMC = torch.IntTensor(len(gallery_label)).zero_()
79 | ap = 0.0
80 | #re-ranking
81 | print('calculate initial distance')
82 | # q_g_dist = np.dot(query_feature, np.transpose(gallery_feature))
83 | # q_q_dist = np.dot(query_feature, np.transpose(query_feature))
84 | # g_g_dist = np.dot(gallery_feature, np.transpose(gallery_feature))
85 |
86 | since = time.time()
87 | re_rank = re_ranking(len(query_cam), all_dist)
88 | time_elapsed = time.time() - since
89 | print('Reranking complete in {:.0f}m {:.0f}s'.format(
90 | time_elapsed // 60, time_elapsed % 60))
91 | for i in range(len(query_label)):
92 | ap_tmp, CMC_tmp = evaluate(re_rank[i,:],query_label[i],query_cam[i],gallery_label,gallery_cam)
93 | if CMC_tmp[0]==-1:
94 | continue
95 | CMC = CMC + CMC_tmp
96 | ap += ap_tmp
97 | #print(i, CMC_tmp[0])
98 |
99 | CMC = CMC.float()
100 | CMC = CMC/len(query_label) #average CMC
101 | print('top1:%f top5:%f top10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
102 |
--------------------------------------------------------------------------------
/evaluate_st.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import torch
3 | import numpy as np
4 | import time
5 | import argparse
6 | import os
7 | import math
8 |
9 | parser = argparse.ArgumentParser(description='evaluate')
10 | parser.add_argument('--name',default='ft_ResNet50_market_pcb_r', type=str, help='model name, i.e. the sub-directory of ./model to evaluate')
11 | parser.add_argument('--alpha', default=5, type=float, help='alpha')
12 | parser.add_argument('--smooth', default=50, type=float, help='smooth')
13 | opt = parser.parse_args()
14 | name = opt.name
15 | alpha=opt.alpha
16 | smooth=opt.smooth
17 |
18 |
19 |
20 | #######################################################################
21 | # Evaluate
22 | def evaluate(qf,ql,qc,qfr,gf,gl,gc,gfr,distribution):
23 |     query = qf
24 |     score = np.dot(gf,query)  # visual similarity of the query to every gallery image
25 |
26 |     # spatial temporal scores: qfr,gfr, qc, gc
27 |     # TODO
28 |     interval = 100  # frames per histogram bin
29 |     score_st = np.zeros(len(gc))
30 |     for i in range(len(gc)):
31 |         if qfr>gfr[i]:
32 |             diff = qfr-gfr[i]
33 |             hist_ = int(diff/interval)
34 |             pr = distribution[qc-1][gc[i]-1][hist_]  # [later camera][earlier camera][bin]
35 |         else:
36 |             diff = gfr[i]-qfr
37 |             hist_ = int(diff/interval)
38 |             pr = distribution[gc[i]-1][qc-1][hist_]
39 |         score_st[i] = pr
40 |
41 |     # ========================
42 |     score = 1/(1+np.exp(-alpha*score))*1/(1+2*np.exp(-alpha*score_st))  # joint metric: visual term * spatial-temporal term
43 |     index = np.argsort(-score)  #from large to small
44 |     query_index = np.argwhere(gl==ql)
45 |     camera_index = np.argwhere(gc==qc)
46 |     good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
47 |     junk_index1 = np.argwhere(gl==-1)
48 |     junk_index2 = np.intersect1d(query_index, camera_index)
49 |     junk_index = np.append(junk_index2, junk_index1) #.flatten())
50 |     CMC_tmp = compute_mAP(index, good_index, junk_index)  # tuple (ap, cmc)
51 |     return CMC_tmp
52 |
53 |
54 | def compute_mAP(index, good_index, junk_index):
55 | ap = 0
56 | cmc = torch.IntTensor(len(index)).zero_()
57 | if good_index.size==0: # if empty
58 | cmc[0] = -1
59 | return ap,cmc
60 |
61 | # remove junk_index
62 | mask = np.in1d(index, junk_index, invert=True)
63 | index = index[mask]
64 |
65 | # find good_index index
66 | ngood = len(good_index)
67 | mask = np.in1d(index, good_index)
68 | rows_good = np.argwhere(mask==True)
69 | rows_good = rows_good.flatten()
70 |
71 | cmc[rows_good[0]:] = 1
72 | for i in range(ngood):
73 | d_recall = 1.0/ngood
74 | precision = (i+1)*1.0/(rows_good[i]+1)
75 | if rows_good[i]!=0:
76 | old_precision = i*1.0/rows_good[i]
77 | else:
78 | old_precision=1.0
79 | ap = ap + d_recall*(old_precision + precision)/2
80 |
81 | return ap, cmc
82 |
83 | def gaussian_func(x, u, o=50):
84 |     if (o == 0):
85 |         print("In gaussian, o shouldn't be equal to zero")
86 |         return 0
87 |     temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
88 |     temp2 = -(math.pow(x - u, 2)) / (2 * math.pow(o, 2))
89 |     return temp1 * math.exp(temp2)
90 |
91 | def gaussian_func2(x, u, o=50):
92 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
93 | temp2 = -(np.power(x - u, 2)) / (2 * np.power(o, 2))
94 | return temp1 * np.exp(temp2)
95 |
96 |
97 | def gauss_smooth(arr):
98 | hist_num = len(arr)
99 | vect= np.zeros((hist_num,1))
100 | for i in range(hist_num):
101 | vect[i,0]=i
102 | # gaussian_vect= gaussian_func2(vect,0,1)
103 | gaussian_vect= gaussian_func2(vect,0,50)
104 | matrix = np.zeros((hist_num,hist_num))
105 | # matrix = np.eye(hist_num)
106 | for i in range(hist_num):
107 | for j in range(i,hist_num):
108 | matrix[i][j]=gaussian_vect[j-i]
109 | matrix = matrix+matrix.transpose()
110 | for i in range(hist_num):
111 | matrix[i][i]=matrix[i][i]/2
112 | # for i in range(hist_num):
113 | # for j in range(i):
114 | # matrix[i][j]=gaussian_vect[j]
115 | xxx = np.dot(matrix,arr)
116 | return xxx
117 |
118 | # faster gauss_smooth
119 | def gauss_smooth2(arr,o):
120 |     hist_num = len(arr)
121 |     vect= np.zeros((hist_num,1))
122 |     for i in range(hist_num):
123 |         vect[i,0]=i
124 |     # gaussian_vect= gaussian_func2(vect,0,1)
125 |     # o=50
126 |     approximate_delta = 3*o     #  when x-u>approximate_delta (here 3*o), the gaussian value is approximately equal to 0.
127 |     gaussian_vect= gaussian_func2(vect,0,o)
128 |     matrix = np.zeros((hist_num,hist_num))
129 |     for i in range(hist_num):
130 |         k=0
131 |         for j in range(i,hist_num):
132 |             if k>approximate_delta:
133 |                 break  # the remaining entries of this row stay zero
134 |             matrix[i][j]=gaussian_vect[j-i]
135 |             k=k+1
136 |     matrix = matrix+matrix.transpose()
137 |     for i in range(hist_num):
138 |         matrix[i][i]=matrix[i][i]/2  # the diagonal was counted twice by the symmetrisation
139 |     # for i in range(hist_num):
140 |     #     for j in range(i):
141 |     #         matrix[i][j]=gaussian_vect[j]
142 |     xxx = np.dot(matrix,arr)
143 |     return xxx
144 |
145 | ######################################################################
146 | # result = scipy.io.loadmat('pytorch_result.mat')
147 | result = scipy.io.loadmat('model/'+name+'/'+'pytorch_result.mat')
148 | query_feature = result['query_f']
149 | query_cam = result['query_cam'][0]
150 | query_label = result['query_label'][0]
151 | query_frames = result['query_frames'][0]
152 |
153 |
154 | gallery_feature = result['gallery_f']
155 | gallery_cam = result['gallery_cam'][0]
156 | gallery_label = result['gallery_label'][0]
157 | gallery_frames = result['gallery_frames'][0]
158 |
159 | query_feature=query_feature.transpose()/np.power(np.sum(np.power(query_feature,2),axis=1),0.5)
160 | query_feature=query_feature.transpose()
161 | print('query_feature:',query_feature.shape)
162 | gallery_feature=gallery_feature.transpose()/np.power(np.sum(np.power(gallery_feature,2),axis=1),0.5)
163 | gallery_feature=gallery_feature.transpose()
164 | print('gallery_feature:',gallery_feature.shape)
165 |
166 |
167 | #############################################################
168 |
169 | result2 = scipy.io.loadmat('model/'+name+'/'+'pytorch_result2.mat')
170 | distribution = result2['distribution']
171 |
172 | #############################################################
173 | for i in range(0,8):
174 |     for j in range(0,8):
175 |         print("gauss "+str(i)+"->"+str(j))
176 |         # gauss_smooth(distribution[i][j])
177 |         distribution[i][j][:]=gauss_smooth2(distribution[i][j][:],smooth)
178 |
179 |
180 | eps = 0.0000001
181 | sum_ = np.sum(distribution,axis=2)
182 | for i in range(8):
183 |     for j in range(8):
184 |         distribution[i][j][:]=distribution[i][j][:]/(sum_[i][j]+eps)
185 | #############################################################
186 |
187 | CMC = torch.IntTensor(len(gallery_label)).zero_()
188 | ap = 0.0
189 | #print(query_label)
190 | for i in range(len(query_label)):
191 |     ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],query_frames[i], gallery_feature,gallery_label,gallery_cam,gallery_frames,distribution)
192 |     if CMC_tmp[0]==-1:
193 |         continue
194 |     CMC = CMC + CMC_tmp
195 |     ap += ap_tmp
196 |     print(i, CMC_tmp[0])
197 |     # if i%10==0:
198 |     #     print('i:',i)
199 |
200 | CMC = CMC.float()
201 | CMC = CMC/len(query_label) #average CMC
202 | print('top1:%f top5:%f top10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
203 | print('alpha,smooth:',alpha,smooth)
204 |
205 | result = {'CMC':CMC.numpy()}
206 |
207 | scipy.io.savemat('model/'+name+'/'+'CMC_duke_two_stream_add'+str(alpha)+'.mat',result)
208 |
209 |
210 |
--------------------------------------------------------------------------------
/gen_rerank_all_scores_mat.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import torch
3 | import numpy as np
4 | import time
5 | import argparse
6 | import os
7 | import math
8 |
9 | parser = argparse.ArgumentParser(description='evaluate')
10 | parser.add_argument('--name',default='ft_ResNet50_market_pcb_r', type=str, help='model name, i.e. the sub-directory of ./model to evaluate')
11 | parser.add_argument('--alpha', default=5, type=float, help='alpha')
12 | parser.add_argument('--smooth', default=50, type=float, help='smooth')
13 | opt = parser.parse_args()
14 | name = opt.name
15 | alpha=opt.alpha
16 | smooth=opt.smooth
17 |
18 |
19 |
20 | #######################################################################
21 | # Evaluate
22 | def evaluate(qf,ql,qc,qfr,gf,gl,gc,gfr,distribution):
23 | query = qf
24 | score = np.dot(gf,query)
25 |
26 | # spatial temporal scores: qfr,gfr, qc, gc
27 | # TODO
28 | interval = 100
29 | score_st = np.zeros(len(gc))
30 | for i in range(len(gc)):
31 | if qfr>gfr[i]:
32 | diff = qfr-gfr[i]
33 | hist_ = int(diff/interval)
34 | pr = distribution[qc-1][gc[i]-1][hist_]
35 | else:
36 | diff = gfr[i]-qfr
37 | hist_ = int(diff/interval)
38 | pr = distribution[gc[i]-1][qc-1][hist_]
39 | score_st[i] = pr
40 | # ========================
41 | score = 1/(1+np.exp(-alpha*score))*1/(1+2*np.exp(-alpha*score_st))
42 | ###############################################################################################
43 |
44 | index = np.argsort(-score) #from large to small
45 |
46 | query_index = np.argwhere(gl==ql)
47 | camera_index = np.argwhere(gc==qc)
48 |
49 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
50 | junk_index1 = np.argwhere(gl==-1)
51 | junk_index2 = np.intersect1d(query_index, camera_index)
52 | junk_index = np.append(junk_index2, junk_index1) #.flatten())
53 |
54 | CMC_tmp = compute_mAP(index, good_index, junk_index)
55 | return CMC_tmp
56 | def evaluate2(qf,ql,qc,qfr,gf,gl,gc,gfr,distribution):
57 | query = qf
58 | score = np.dot(gf,query)
59 |
60 | # spatial temporal scores: qfr,gfr, qc, gc
61 | # TODO
62 | interval = 100
63 | score_st = np.zeros(len(gc))
64 | for i in range(len(gc)):
65 | if qfr>gfr[i]:
66 | diff = qfr-gfr[i]
67 | hist_ = int(diff/interval)
68 | # print('debug:',qc-1,gc[i]-1,hist_)
69 | pr = distribution[qc-1][gc[i]-1][hist_]
70 | else:
71 | diff = gfr[i]-qfr
72 | hist_ = int(diff/interval)
73 | # print('debug:',qc-1,gc[i]-1,hist_)
74 | pr = distribution[gc[i]-1][qc-1][hist_]
75 | score_st[i] = pr
76 | # ========================
77 | score = 1/(1+np.exp(-alpha*score))*1/(1+2*np.exp(-alpha*score_st))
78 | return score
79 |
80 | def compute_mAP(index, good_index, junk_index):
81 | ap = 0
82 | cmc = torch.IntTensor(len(index)).zero_()
83 | if good_index.size==0: # if empty
84 | cmc[0] = -1
85 | return ap,cmc
86 |
87 | # remove junk_index
88 | mask = np.in1d(index, junk_index, invert=True)
89 | index = index[mask]
90 |
91 | # find good_index index
92 | ngood = len(good_index)
93 | mask = np.in1d(index, good_index)
94 | rows_good = np.argwhere(mask==True)
95 | rows_good = rows_good.flatten()
96 |
97 | cmc[rows_good[0]:] = 1
98 | for i in range(ngood):
99 | d_recall = 1.0/ngood
100 | precision = (i+1)*1.0/(rows_good[i]+1)
101 | if rows_good[i]!=0:
102 | old_precision = i*1.0/rows_good[i]
103 | else:
104 | old_precision=1.0
105 | ap = ap + d_recall*(old_precision + precision)/2
106 |
107 | return ap, cmc
108 |
109 | def gaussian_func(x, u, o=50):
110 | if (o == 0):
111 |         print("In gaussian, o shouldn't be equal to zero")
112 | return 0
113 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
114 | temp2 = -(math.pow(x - u, 2)) / (2 * math.pow(o, 2))
115 | return temp1 * math.exp(temp2)
116 |
117 | def gaussian_func2(x, u, o=50):
118 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
119 | temp2 = -(np.power(x - u, 2)) / (2 * np.power(o, 2))
120 | return temp1 * np.exp(temp2)
121 |
122 |
123 | def gauss_smooth(arr):
124 | hist_num = len(arr)
125 | vect= np.zeros((hist_num,1))
126 | for i in range(hist_num):
127 | vect[i,0]=i
128 | # gaussian_vect= gaussian_func2(vect,0,1)
129 | gaussian_vect= gaussian_func2(vect,0,50)
130 | matrix = np.zeros((hist_num,hist_num))
131 | # matrix = np.eye(hist_num)
132 | for i in range(hist_num):
133 | for j in range(i,hist_num):
134 | matrix[i][j]=gaussian_vect[j-i]
135 | matrix = matrix+matrix.transpose()
136 | for i in range(hist_num):
137 | matrix[i][i]=matrix[i][i]/2
138 | xxx = np.dot(matrix,arr)
139 | return xxx
140 |
141 | # faster gauss_smooth
142 | def gauss_smooth2(arr,o):
143 | hist_num = len(arr)
144 | vect= np.zeros((hist_num,1))
145 | for i in range(hist_num):
146 | vect[i,0]=i
147 |     approximate_delta = 3*o     #  when x-u>approximate_delta (here 3*o), the gaussian value is approximately equal to 0.
148 | gaussian_vect= gaussian_func2(vect,0,o)
149 | matrix = np.zeros((hist_num,hist_num))
150 | for i in range(hist_num):
151 | k=0
152 | for j in range(i,hist_num):
153 | if k>approximate_delta:
154 | continue
155 | matrix[i][j]=gaussian_vect[j-i]
156 | k=k+1
157 | matrix = matrix+matrix.transpose()
158 | for i in range(hist_num):
159 | matrix[i][i]=matrix[i][i]/2
160 | xxx = np.dot(matrix,arr)
161 | return xxx
162 |
163 | ######################################################################
164 | result = scipy.io.loadmat('model/'+name+'/'+'pytorch_result.mat')
165 | query_feature = result['query_f']
166 | query_cam = result['query_cam'][0]
167 | query_label = result['query_label'][0]
168 | query_frames = result['query_frames'][0]
169 |
170 |
171 | gallery_feature = result['gallery_f']
172 | gallery_cam = result['gallery_cam'][0]
173 | gallery_label = result['gallery_label'][0]
174 | gallery_frames = result['gallery_frames'][0]
175 |
176 | query_feature=query_feature.transpose()/np.power(np.sum(np.power(query_feature,2),axis=1),0.5)
177 | query_feature=query_feature.transpose()
178 | print('query_feature:',query_feature.shape)
179 | gallery_feature=gallery_feature.transpose()/np.power(np.sum(np.power(gallery_feature,2),axis=1),0.5)
180 | gallery_feature=gallery_feature.transpose()
181 | print('gallery_feature:',gallery_feature.shape)
182 |
183 |
184 | #############################################################
185 |
186 | result2 = scipy.io.loadmat('model/'+name+'/'+'pytorch_result2.mat')
187 | distribution = result2['distribution']
188 |
189 | #############################################################
190 | for i in range(0,8):
191 | for j in range(0,8):
192 | print("gauss "+str(i)+"->"+str(j))
193 | # gauss_smooth(distribution[i][j])
194 | distribution[i][j][:]=gauss_smooth2(distribution[i][j][:],smooth)
195 |
196 |
197 | eps = 0.0000001
198 | sum_ = np.sum(distribution,axis=2)
199 | for i in range(8):
200 | for j in range(8):
201 | distribution[i][j][:]=distribution[i][j][:]/(sum_[i][j]+eps)
202 | #############################################################
203 |
204 | all_features = np.concatenate([query_feature,gallery_feature],axis=0)
205 | all_labels = np.concatenate([query_label,gallery_label],axis=0)
206 | all_cams = np.concatenate([query_cam,gallery_cam],axis=0)
207 | all_frames = np.concatenate([query_frames,gallery_frames],axis=0)
208 |
209 | all_scores = np.zeros((len(all_labels),len(all_labels)))
210 |
211 | print('all_features shape:',all_features.shape)
212 | print('all_labels shape:',all_labels.shape)
213 | print('all_cams shape:',all_cams.shape)
214 | print('all_frames shape:',all_frames.shape)
215 | print('all_scores shape:',all_scores.shape)
216 |
217 |
218 | CMC = torch.IntTensor(len(all_labels)).zero_()
219 | ap = 0.0
220 | for i in range(len(all_labels)):
221 | scores_new = evaluate2(all_features[i],all_labels[i],all_cams[i],all_frames[i], all_features,all_labels,all_cams,all_frames,distribution)
222 | print('scores_new shape:',scores_new.shape)
223 | all_scores[i,:] = scores_new
224 | print(i)
225 |
226 | print('type(all_scores):',type(all_scores))
227 | all_scores = {'all_scores':all_scores}
228 | scipy.io.savemat('model/'+name+'/'+'all_scores'+'.mat',all_scores)
229 | ###############################################################################################
230 |
231 |
--------------------------------------------------------------------------------
/gen_st_model_duke.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | from torchvision import datasets
15 | import os
16 | import scipy.io
17 | ######################################################################
18 | # Options
19 | # --------
20 | parser = argparse.ArgumentParser(description='Training')
21 | parser.add_argument('--data_dir',default="/home/sdb1/huangpg/st-reid/st_baseline/Duke/pytorch/",type=str, help='./train_data')
22 | parser.add_argument('--name', default='ft_ResNet50_duke_pcb', type=str, help='save model path')
23 |
24 | opt = parser.parse_args()
25 | name = opt.name
26 | data_dir = opt.data_dir
27 |
28 |
29 | def get_id(img_path):
30 |     camera_id = []
31 |     labels = []
32 |     frames = []
33 |     for path, v in img_path:
34 |         filename = os.path.basename(path)
35 |         label = filename[0:4]  # first four characters: person ID
36 |         camera = filename.split('c')[1]  # text after the first 'c' starts with the camera number
37 |         frame = filename[9:16]  # seven-digit frame number
38 |         if label[0:2]=='-1':
39 |             labels.append(-1)
40 |         else:
41 |             labels.append(int(label))
42 |         camera_id.append(int(camera[0]))
43 |         frames.append(int(frame))
44 |     return camera_id, labels, frames
45 |
46 | def spatial_temporal_distribution(camera_id, labels, frames):
47 |     spatial_temporal_sum = np.zeros((702,8))
48 |     spatial_temporal_count = np.zeros((702,8))
49 |     eps = 0.0000001
50 |     interval = 100.0
51 |
52 |     for i in range(len(camera_id)):
53 |         label_k = labels[i]       # ImageFolder class index (0..701), not the raw person ID
54 |         cam_k = camera_id[i]-1    # camera IDs start at 1, not 0
55 |         frame_k = frames[i]
56 |         spatial_temporal_sum[label_k][cam_k]=spatial_temporal_sum[label_k][cam_k]+frame_k
57 |         spatial_temporal_count[label_k][cam_k] = spatial_temporal_count[label_k][cam_k] + 1
58 |     spatial_temporal_avg = spatial_temporal_sum/(spatial_temporal_count+eps)          # spatial_temporal_avg: 702 ids, 8 cameras, center point
59 |
60 |     distribution = np.zeros((8,8,3000))
61 |     for i in range(702):
62 |         for j in range(8-1):
63 |             for k in range(j+1,8):
64 |                 # added: skip camera pairs in which this identity does not appear in both cameras
65 |                 if spatial_temporal_count[i][j]==0 or spatial_temporal_count[i][k]==0:
66 |                     continue
67 |                 st_ij = spatial_temporal_avg[i][j]
68 |                 st_ik = spatial_temporal_avg[i][k]
69 |                 if st_ij>st_ik:
70 |                     diff = st_ij-st_ik
71 |                     hist_ = int(diff/interval)
72 |                     distribution[j][k][hist_] = distribution[j][k][hist_]+1     # [big][small]
73 |                 else:
74 |                     diff = st_ik-st_ij
75 |                     hist_ = int(diff/interval)
76 |                     distribution[k][j][hist_] = distribution[k][j][hist_]+1
77 |
78 |     sum_ = np.sum(distribution,axis=2)
79 |     for i in range(8):
80 |         for j in range(8):
81 |             distribution[i][j][:]=distribution[i][j][:]/(sum_[i][j]+eps)
82 |
83 |     return distribution                    # [to][from], to xxx camera, from xxx camera
84 |
85 | transform_train_list = [
86 | transforms.Resize(144, interpolation=3),
87 | transforms.RandomCrop((256,128)),
88 | transforms.RandomHorizontalFlip(),
89 | transforms.ToTensor(),
90 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
91 | ]
92 |
93 |
94 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,transform_train_list) for x in ['train_all']}
95 | train_path = image_datasets['train_all'].imgs
96 | train_cam, train_label, train_frames = get_id(train_path)
97 |
98 | train_label_order = []
99 | for i in range(len(train_path)):
100 | train_label_order.append(train_path[i][1])
101 |
102 |
103 | # distribution = spatial_temporal_distribution(train_cam, train_label, train_frames)
104 | distribution = spatial_temporal_distribution(train_cam, train_label_order, train_frames)
105 | result = {'distribution':distribution}
106 | scipy.io.savemat('model/'+name+'/'+'pytorch_result2.mat',result)
--------------------------------------------------------------------------------
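The histogram built by `spatial_temporal_distribution` above can be illustrated with a small pure-Python sketch. It is not the repository's implementation: it uses dictionaries instead of NumPy arrays, keeps the 1-based camera ids from the filenames, and skips the final normalisation. The function name `st_histogram` and the sample values are hypothetical.

```python
from collections import defaultdict


def st_histogram(camera_ids, labels, frames, interval=100.0):
    """Count time-gap bins per (later_camera, earlier_camera) pair."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for cam, label, frame in zip(camera_ids, labels, frames):
        sums[(label, cam)] += frame
        counts[(label, cam)] += 1
    # Mean frame index of each identity under each camera it appears in.
    means = {key: sums[key] / counts[key] for key in sums}

    by_label = defaultdict(list)
    for (label, cam), mean in means.items():
        by_label[label].append((cam, mean))

    hist = defaultdict(int)
    for sightings in by_label.values():
        sightings.sort()
        for a in range(len(sightings)):
            for b in range(a + 1, len(sightings)):
                cam_j, t_j = sightings[a]
                cam_k, t_k = sightings[b]
                # Index order follows the original: [later camera][earlier camera].
                if t_j > t_k:
                    key = (cam_j, cam_k, int((t_j - t_k) / interval))
                else:
                    key = (cam_k, cam_j, int((t_k - t_j) / interval))
                hist[key] += 1
    return dict(hist)


# One identity seen twice under camera 1 and once under camera 2.
print(st_histogram([1, 1, 2], [0, 0, 0], [100, 300, 1000]))  # {(2, 1, 8): 1}
```

Cameras where an identity never appears simply have no entry here, which plays the role of the `spatial_temporal_count == 0` check in the script.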
/gen_st_model_market.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | from torchvision import datasets
15 | import os
16 | import scipy.io
17 | import math
18 | ######################################################################
19 | # Options
20 | # --------
21 | parser = argparse.ArgumentParser(description='Training')
22 | parser.add_argument('--data_dir',default="/home/huangpg/test_gc_code/Market/market_rename/",type=str, help='./train_data')
23 | parser.add_argument('--name', default='ft_ResNet50_market_rename_pcb', type=str, help='save model path')
24 |
25 | opt = parser.parse_args()
26 | name = opt.name
27 | data_dir = opt.data_dir
28 |
29 |
30 | def get_id(img_path):
31 | camera_id = []
32 | labels = []
33 | frames = []
34 | for path, v in img_path:
35 | filename = path.split('/')[-1]
36 | label = filename[0:4]
37 | camera = filename.split('c')[1]
38 | # frame = filename[9:16]
39 | frame = filename.split('_')[2][1:]
40 | if label[0:2]=='-1':
41 | labels.append(-1)
42 | else:
43 | labels.append(int(label))
44 | camera_id.append(int(camera[0]))
45 | frames.append(int(frame))
46 | return camera_id, labels, frames
47 |
48 | def spatial_temporal_distribution(camera_id, labels, frames):
49 | class_num=751
50 | max_hist = 5000
51 | spatial_temporal_sum = np.zeros((class_num,8))
52 | spatial_temporal_count = np.zeros((class_num,8))
53 | eps = 0.0000001
54 | interval = 100.0
55 |
56 | for i in range(len(camera_id)):
57 | label_k = labels[i] #### not in order, done
58 | cam_k = camera_id[i]-1 ##### ### ### ### ### ### ### ### ### ### ### ### # from 1, not 0
59 | frame_k = frames[i]
60 | spatial_temporal_sum[label_k][cam_k]=spatial_temporal_sum[label_k][cam_k]+frame_k
61 | spatial_temporal_count[label_k][cam_k] = spatial_temporal_count[label_k][cam_k] + 1
62 | spatial_temporal_avg = spatial_temporal_sum/(spatial_temporal_count+eps) # spatial_temporal_avg: 751 ids, 8cameras, center point
63 |
64 | distribution = np.zeros((8,8,max_hist))
65 | for i in range(class_num):
66 | for j in range(8-1):
67 | for k in range(j+1,8):
68 | if spatial_temporal_count[i][j]==0 or spatial_temporal_count[i][k]==0:
69 | continue
70 | st_ij = spatial_temporal_avg[i][j]
71 | st_ik = spatial_temporal_avg[i][k]
72 | if st_ij>st_ik:
73 | diff = st_ij-st_ik
74 | hist_ = int(diff/interval)
75 | distribution[j][k][hist_] = distribution[j][k][hist_]+1 # [big][small]
76 | else:
77 | diff = st_ik-st_ij
78 | hist_ = int(diff/interval)
79 | distribution[k][j][hist_] = distribution[k][j][hist_]+1
80 |
81 | sum_ = np.sum(distribution,axis=2)
82 | for i in range(8):
83 | for j in range(8):
84 | distribution[i][j][:]=distribution[i][j][:]/(sum_[i][j]+eps)
85 |
86 | return distribution # [to][from], to xxx camera, from xxx camera
87 |
88 | def gaussian_func(x, u, o=0.1):
89 | if (o == 0):
90 | print("In gaussian, o shouldn't be equal to zero")
91 | return 0
92 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
93 | temp2 = -(math.pow(x - u, 2)) / (2 * math.pow(o, 2))
94 | return temp1 * math.exp(temp2)
95 |
96 |
97 | def gauss_smooth(arr):
98 | # print(gaussian_func(0,0))
99 | for u, element in enumerate(arr):
100 | # print(u," ",element)
101 | if element != 0:
102 | for index in range(len(arr)):
103 | arr[index] = arr[index] + element * gaussian_func(index, u)
104 |
105 | sum = 0
106 | for v in arr:
107 | sum = sum + v
108 | if sum==0:
109 | return arr
110 | for i in range(len(arr)):
111 | arr[i] = arr[i] / sum
112 | return arr
113 |
114 |
115 | transform_train_list = [
116 | transforms.Resize(144, interpolation=3),
117 | transforms.RandomCrop((256,128)),
118 | transforms.RandomHorizontalFlip(),
119 | transforms.ToTensor(),
120 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
121 | ]
122 |
123 |
124 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,transform_train_list) for x in ['train_all']}
125 | train_path = image_datasets['train_all'].imgs
126 | train_cam, train_label, train_frames = get_id(train_path)
127 |
128 | train_label_order = []
129 | for i in range(len(train_path)):
130 | train_label_order.append(train_path[i][1])
131 |
132 |
133 | # distribution = spatial_temporal_distribution(train_cam, train_label, train_frames)
134 | distribution = spatial_temporal_distribution(train_cam, train_label_order, train_frames)
135 |
136 | # for i in range(0,8):
137 | # for j in range(0,8):
138 | # print("gauss "+str(i)+"->"+str(j))
139 | # gauss_smooth(distribution[i][j])
140 |
141 | result = {'distribution':distribution}
142 | scipy.io.savemat('model/'+name+'/'+'pytorch_result2.mat',result)
143 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torchvision import models
5 | from torch.autograd import Variable
6 |
7 | ######################################################################
8 | def weights_init_kaiming(m):
9 | classname = m.__class__.__name__
10 | # print(classname)
11 | if classname.find('Conv') != -1:
12 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
13 | elif classname.find('Linear') != -1:
14 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
15 | init.constant_(m.bias.data, 0.0)
16 | elif classname.find('BatchNorm1d') != -1:
17 | init.normal_(m.weight.data, 1.0, 0.02)
18 | init.constant_(m.bias.data, 0.0)
19 |
20 | def weights_init_classifier(m):
21 | classname = m.__class__.__name__
22 | if classname.find('Linear') != -1:
23 | init.normal_(m.weight.data, std=0.001)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 | # Defines the new fc layer and classification layer
27 | # |--Linear--|--bn--|--relu--|--Linear--|
28 | class ClassBlock(nn.Module):
29 | def __init__(self, input_dim, class_num, dropout=True, relu=True, num_bottleneck=512):
30 | super(ClassBlock, self).__init__()
31 | add_block = []
32 | add_block += [nn.Linear(input_dim, num_bottleneck)]
33 | add_block += [nn.BatchNorm1d(num_bottleneck)]
34 | if relu:
35 | add_block += [nn.LeakyReLU(0.1)]
36 | if dropout:
37 | add_block += [nn.Dropout(p=0.5)]
38 | add_block = nn.Sequential(*add_block)
39 | add_block.apply(weights_init_kaiming)
40 |
41 | classifier = []
42 | classifier += [nn.Linear(num_bottleneck, class_num)]
43 | classifier = nn.Sequential(*classifier)
44 | classifier.apply(weights_init_classifier)
45 |
46 | self.add_block = add_block
47 | self.classifier = classifier
48 | def forward(self, x):
49 | x = self.add_block(x)
50 | x = self.classifier(x)
51 | return x
52 |
53 | # Define the ResNet50-based Model
54 | class ft_net(nn.Module):
55 |
56 | def __init__(self, class_num ):
57 | super(ft_net, self).__init__()
58 | model_ft = models.resnet50(pretrained=True)
59 | # avg pooling to global pooling
60 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
61 | self.model = model_ft
62 | self.classifier = ClassBlock(2048, class_num)
63 |
64 | def forward(self, x):
65 | x = self.model.conv1(x)
66 | x = self.model.bn1(x)
67 | x = self.model.relu(x)
68 | x = self.model.maxpool(x)
69 | x = self.model.layer1(x)
70 | x = self.model.layer2(x)
71 | x = self.model.layer3(x)
72 | x = self.model.layer4(x)
73 | x = self.model.avgpool(x)
74 | x = torch.squeeze(x)
75 | x = self.classifier(x)
76 | return x
77 |
78 | # Define the DenseNet121-based Model
79 | class ft_net_dense(nn.Module):
80 |
81 | def __init__(self, class_num ):
82 | super().__init__()
83 | model_ft = models.densenet121(pretrained=True)
84 | model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1))
85 | model_ft.fc = nn.Sequential()
86 | self.model = model_ft
87 | # For DenseNet, the feature dim is 1024
88 | self.classifier = ClassBlock(1024, class_num)
89 |
90 | def forward(self, x):
91 | x = self.model.features(x)
92 | x = torch.squeeze(x)
93 | x = self.classifier(x)
94 | return x
95 |
96 | # Define the ResNet50-based Model (Middle-Concat)
97 | # In the spirit of "The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching." Yu, Qian, et al. arXiv:1711.08106 (2017).
98 | class ft_net_middle(nn.Module):
99 |
100 | def __init__(self, class_num ):
101 | super(ft_net_middle, self).__init__()
102 | model_ft = models.resnet50(pretrained=True)
103 | # avg pooling to global pooling
104 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
105 | self.model = model_ft
106 | self.classifier = ClassBlock(2048+1024, class_num)
107 |
108 | def forward(self, x):
109 | x = self.model.conv1(x)
110 | x = self.model.bn1(x)
111 | x = self.model.relu(x)
112 | x = self.model.maxpool(x)
113 | x = self.model.layer1(x)
114 | x = self.model.layer2(x)
115 | x = self.model.layer3(x)
116 | # x0 n*1024*1*1
117 | x0 = self.model.avgpool(x)
118 | x = self.model.layer4(x)
119 | # x1 n*2048*1*1
120 | x1 = self.model.avgpool(x)
121 | x = torch.cat((x0,x1),1)
122 | x = torch.squeeze(x)
123 | x = self.classifier(x)
124 | return x
125 |
126 | # Part Model proposed in Yifan Sun et al. (2018)
127 | class PCB(nn.Module):
128 | def __init__(self, class_num ):
129 | super(PCB, self).__init__()
130 |
131 | self.part = 6 # We cut the pool5 to 6 parts
132 | model_ft = models.resnet50(pretrained=True)
133 | self.model = model_ft
134 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1))
135 | self.dropout = nn.Dropout(p=0.5)
136 | # remove the final downsample
137 | self.model.layer4[0].downsample[0].stride = (1,1)
138 | self.model.layer4[0].conv2.stride = (1,1)
139 | # define 6 classifiers
140 | for i in range(self.part):
141 | name = 'classifier'+str(i)
142 | setattr(self, name, ClassBlock(2048, class_num, True, False, 256))
143 |
144 | def forward(self, x):
145 | x = self.model.conv1(x)
146 | x = self.model.bn1(x)
147 | x = self.model.relu(x)
148 | x = self.model.maxpool(x)
149 |
150 | x = self.model.layer1(x)
151 | x = self.model.layer2(x)
152 | x = self.model.layer3(x)
153 | x = self.model.layer4(x)
154 | x = self.avgpool(x)
155 | x = self.dropout(x)
156 | part = {}
157 | predict = {}
158 | # get six part feature batchsize*2048*6
159 | for i in range(self.part):
160 | part[i] = torch.squeeze(x[:,:,i])
161 | name = 'classifier'+str(i)
162 | c = getattr(self,name)
163 | predict[i] = c(part[i])
164 |
165 | # sum prediction
166 | #y = predict[0]
167 | #for i in range(self.part-1):
168 | # y += predict[i+1]
169 | y = []
170 | for i in range(self.part):
171 | y.append(predict[i])
172 | return y
173 |
174 | class PCB_test(nn.Module):
175 | def __init__(self,model):
176 | super(PCB_test,self).__init__()
177 | self.part = 6
178 | self.model = model.model
179 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1))
180 | # remove the final downsample
181 | self.model.layer4[0].downsample[0].stride = (1,1)
182 | self.model.layer4[0].conv2.stride = (1,1)
183 |
184 | def forward(self, x):
185 | x = self.model.conv1(x)
186 | x = self.model.bn1(x)
187 | x = self.model.relu(x)
188 | x = self.model.maxpool(x)
189 |
190 | x = self.model.layer1(x)
191 | x = self.model.layer2(x)
192 | x = self.model.layer3(x)
193 | x = self.model.layer4(x)
194 | x = self.avgpool(x)
195 | y = x.view(x.size(0),x.size(1),x.size(2))
196 | return y
197 |
198 | # debug model structure (runs only when model.py is executed directly, not on import)
199 | if __name__ == '__main__':
200 |     #net = ft_net(751)
201 |     net = ft_net_dense(751)
202 |     input = torch.randn(8, 3, 224, 224)
203 |     output = net(input)
204 |     print('net output size:')
205 |     print(output.shape)
206 |
--------------------------------------------------------------------------------
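The `PCB` model above pools the final feature map into six horizontal stripes with `nn.AdaptiveAvgPool2d((self.part,1))`. A minimal pure-Python sketch of that idea for a single column of values follows. It assumes the height is divisible by the number of parts, in which case adaptive average pooling reduces to averaging equal chunks. The helper `stripe_means` is hypothetical and not part of the repository.

```python
def stripe_means(column, parts=6):
    """Average a column of feature values over `parts` equal horizontal stripes."""
    if len(column) % parts != 0:
        raise ValueError('this sketch assumes the height is divisible by parts')
    size = len(column) // parts
    return [sum(column[p * size:(p + 1) * size]) / size for p in range(parts)]


# A column of height 24 split into 6 stripes of 4 rows each.
print(stripe_means(list(range(24))))  # [1.5, 5.5, 9.5, 13.5, 17.5, 21.5]
```

In the model each stripe yields a 2048-dimensional vector, and each of those vectors is fed to its own `ClassBlock` classifier.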
/plot_st_distribution.py:
--------------------------------------------------------------------------------
1 | import scipy.io
2 | import torch
3 | import numpy as np
4 | import time
5 | import os
6 | import math
7 |
8 | import matplotlib.pyplot as plt
9 |
10 | # import matplotlib
11 | # matplotlib.use('agg')
12 | # import matplotlib.pyplot as plt
13 |
14 | def gaussian_func(x, u, o=50):
15 | if (o == 0):
16 | print("In gaussian, o shouldn't be equal to zero")
17 | return 0
18 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
19 | temp2 = -(math.pow(x - u, 2)) / (2 * math.pow(o, 2))
20 | return temp1 * math.exp(temp2)
21 |
22 | def gaussian_func2(x, u, o=50):
23 | temp1 = 1.0 / (o * math.sqrt(2 * math.pi))
24 | temp2 = -(np.power(x - u, 2)) / (2 * np.power(o, 2))
25 | return temp1 * np.exp(temp2)
26 |
27 | def gauss_smooth(arr):
28 | hist_num = len(arr)
29 | vect= np.zeros((hist_num,1))
30 | for i in range(hist_num):
31 | vect[i,0]=i
32 | # gaussian_vect= gaussian_func2(vect,0,1)
33 | gaussian_vect= gaussian_func2(vect,0,50)
34 | matrix = np.zeros((hist_num,hist_num))
35 | # matrix = np.eye(hist_num)
36 | for i in range(hist_num):
37 | for j in range(i,hist_num):
38 | matrix[i][j]=gaussian_vect[j-i]
39 | matrix = matrix+matrix.transpose()
40 | for i in range(hist_num):
41 | matrix[i][i]=matrix[i][i]/2
42 | # for i in range(hist_num):
43 | # for j in range(i):
44 | # matrix[i][j]=gaussian_vect[j]
45 | xxx = np.dot(matrix,arr)
46 | return xxx
47 |
48 | # faster gauss_smooth
49 | def gauss_smooth2(arr):
50 | hist_num = len(arr)
51 | vect= np.zeros((hist_num,1))
52 | for i in range(hist_num):
53 | vect[i,0]=i
54 | # gaussian_vect= gaussian_func2(vect,0,1)
55 | o=5
56 | approximate_delta = 6*o # when x-u>approximate_delta, e.g., 6*o, the gaussian value is approximately equal to 0.
57 | gaussian_vect= gaussian_func2(vect,0,o)
58 | matrix = np.zeros((hist_num,hist_num))
59 | for i in range(hist_num):
60 | k=0
61 | for j in range(i,hist_num):
62 | if k>approximate_delta:
63 | break
64 | matrix[i][j]=gaussian_vect[j-i]
65 | k=k+1
66 | matrix = matrix+matrix.transpose()
67 | for i in range(hist_num):
68 | matrix[i][i]=matrix[i][i]/2
69 | # for i in range(hist_num):
70 | # for j in range(i):
71 | # matrix[i][j]=gaussian_vect[j]
72 | xxx = np.dot(matrix,arr)
73 | return xxx
74 |
75 |
76 | result2 = scipy.io.loadmat('model/'+'ft_ResNet50_market_pcb'+'/'+'pytorch_result2.mat')
77 | distribution = result2['distribution']
78 |
79 | #############################################################
80 | for i in range(0,8):
81 | for j in range(0,8):
82 | print("gauss "+str(i)+"->"+str(j))
83 | # gauss_smooth(distribution[i][j])
84 | distribution[i][j][:]=gauss_smooth2(distribution[i][j][:])
85 |
86 |
87 | eps = 0.0000001
88 | sum_ = np.sum(distribution,axis=2)
89 | for i in range(1):
90 | for j in range(8):
91 | distribution[i][j][:]=distribution[i][j][:]/(sum_[i][j]+eps)
92 |
93 | # plot:
94 |
95 | for i in range(1):
96 | for j in range(8):
97 | one_distr=distribution[i][j][:]
98 | x=range(0,len(one_distr))
99 | plt.figure(1)
100 | plt.plot(x, one_distr)
101 | # plt.savefig(str(i)+'-'+str(j)+'.jpg')
102 | # plt.show()
103 | # plt.axis([0, 3000, 0, 0.05])
104 | plt.xlim((0,1000))
105 | plt.show()
106 |
--------------------------------------------------------------------------------
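The truncated Gaussian smoothing in `gauss_smooth2` above can also be expressed without building the full `hist_num x hist_num` matrix. The sketch below is a pure-Python illustration under the same assumptions (standard deviation `o=5`, kernel cut off at `6*o` bins). The names `gaussian` and `smooth` are hypothetical, and the code is not a drop-in replacement for the script.

```python
import math


def gaussian(x, mu, sigma):
    """Value of the normal density with mean mu and std sigma at x."""
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))


def smooth(hist, sigma=5):
    """Spread every non-zero bin over its neighbours within 6 * sigma bins."""
    radius = 6 * sigma
    kernel = [gaussian(d, 0, sigma) for d in range(radius + 1)]
    out = [0.0] * len(hist)
    for i, value in enumerate(hist):
        if value == 0:
            continue
        lo = max(0, i - radius)
        hi = min(len(hist), i + radius + 1)
        for j in range(lo, hi):
            out[j] += value * kernel[abs(i - j)]
    return out


hist = [0.0] * 101
hist[50] = 1.0
smoothed = smooth(hist)
print(round(sum(smoothed), 3))  # 1.0
```

Because most histogram bins are empty, skipping zero bins keeps the cost close to the number of non-zero bins times the kernel width.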
/prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | from shutil import copyfile
3 | import argparse
4 | import shutil
5 |
6 | download_path = "./raw-dataset/DukeMTMC-reID/"
7 |
8 | parser = argparse.ArgumentParser(description='prepare')
9 | parser.add_argument('--Market', action='store_true', help='prepare dataset market1501')
10 | parser.add_argument('--Duke', action='store_true', help='prepare dataset Duke-MTMC')
11 | opt = parser.parse_args()
12 |
13 | if not os.path.isdir(download_path):
14 | print('please change the download_path')
15 |
16 | if opt.Market:
17 | save_path = "./dataset/Market1501_prepare/"
18 | else:
19 | save_path = "./dataset/DukeMTMC_prepare/"
20 |
21 | if not os.path.exists(save_path):
22 | os.makedirs(save_path)
23 | # -----------------------------------------
24 | # query
25 | query_path = download_path + '/query'
26 | query_save_path = save_path + '/query'
27 | if not os.path.exists(query_save_path):
28 | os.makedirs(query_save_path)
29 |
30 | for root, dirs, files in os.walk(query_path, topdown=True):
31 | for name in files:
32 | if not name[-3:] == 'jpg':
33 | continue
34 | ID = name.split('_')
35 | src_path = query_path + '/' + name
36 | dst_path = query_save_path + '/' + ID[0]
37 | if not os.path.isdir(dst_path):
38 | os.mkdir(dst_path)
39 | copyfile(src_path, dst_path + '/' + name)
40 |
41 | # -----------------------------------------
42 | # gallery
43 | gallery_path = download_path + '/bounding_box_test'
44 | gallery_save_path = save_path + '/gallery'
45 | if not os.path.exists(gallery_save_path):
46 | os.makedirs(gallery_save_path)
47 |
48 | for root, dirs, files in os.walk(gallery_path, topdown=True):
49 | for name in files:
50 | if not name[-3:] == 'jpg':
51 | continue
52 | ID = name.split('_')
53 | src_path = gallery_path + '/' + name
54 | dst_path = gallery_save_path + '/' + ID[0]
55 | if not os.path.isdir(dst_path):
56 | os.mkdir(dst_path)
57 | copyfile(src_path, dst_path + '/' + name)
58 |
59 | # ---------------------------------------
60 | # train_all
61 | train_path = download_path + '/bounding_box_train'
62 | train_save_path = save_path + '/train_all'
63 | if not os.path.exists(train_save_path):
64 | os.makedirs(train_save_path)
65 |
66 | for root, dirs, files in os.walk(train_path, topdown=True):
67 | for name in files:
68 | if not name[-3:] == 'jpg':
69 | continue
70 | ID = name.split('_')
71 | src_path = train_path + '/' + name
72 | dst_path = train_save_path + '/' + ID[0]
73 | if not os.path.isdir(dst_path):
74 | os.mkdir(dst_path)
75 | copyfile(src_path, dst_path + '/' + name)
76 |
77 | # ---------------------------------------
78 | # train_val
79 | train_path = download_path + '/bounding_box_train'
80 | train_save_path = save_path + '/train'
81 | val_save_path = save_path + '/val'
82 | if not os.path.exists(train_save_path):
83 | os.makedirs(train_save_path)
84 | os.makedirs(val_save_path)
85 |
86 | for root, dirs, files in os.walk(train_path, topdown=True):
87 | for name in files:
88 | if not name[-3:] == 'jpg':
89 | continue
90 | ID = name.split('_')
91 | src_path = train_path + '/' + name
92 | dst_path = train_save_path + '/' + ID[0]
93 | if not os.path.isdir(dst_path):
94 | os.mkdir(dst_path)
95 | dst_path = val_save_path + '/' + ID[0] # first image is used as val image
96 | os.mkdir(dst_path)
97 | copyfile(src_path, dst_path + '/' + name)
98 |
99 |
100 | # ================================================================================================
101 | # market1501_rename
102 | # ================================================================================================
103 |
104 | def parse_frame(imgname, dict_cam_seq_max={}):
105 | dict_cam_seq_max = {
106 | 11: 72681, 12: 74546, 13: 74881, 14: 74661, 15: 74891, 16: 54346, 17: 0, 18: 0,
107 | 21: 163691, 22: 164677, 23: 98102, 24: 0, 25: 0, 26: 0, 27: 0, 28: 0,
108 | 31: 161708, 32: 161769, 33: 104469, 34: 0, 35: 0, 36: 0, 37: 0, 38: 0,
109 | 41: 72107, 42: 72373, 43: 74810, 44: 74541, 45: 74910, 46: 50616, 47: 0, 48: 0,
110 | 51: 161095, 52: 161724, 53: 103487, 54: 0, 55: 0, 56: 0, 57: 0, 58: 0,
111 | 61: 87551, 62: 131268, 63: 95817, 64: 30952, 65: 0, 66: 0, 67: 0, 68: 0}
112 | fid = imgname.strip().split("_")[0]
113 | cam = int(imgname.strip().split("_")[1][1])
114 | seq = int(imgname.strip().split("_")[1][3])
115 | frame = int(imgname.strip().split("_")[2])
116 | count = imgname.strip().split("_")[-1]
117 | # print(id)
118 | # print(cam) # 1
119 | # print(seq) # 2
120 | # print(frame)
121 | re = 0
122 | for i in range(1, seq):
123 | re = re + dict_cam_seq_max[int(str(cam) + str(i))]
124 | re = re + frame
125 | new_name = str(fid) + "_c" + str(cam) + "_f" + '{:0>7}'.format(str(re)) + "_" + count
126 | # print(new_name)
127 | return new_name
128 |
129 |
130 | def gen_train_all_rename():
131 | path = "./dataset/Market1501_prepare/train_all/"
132 | folderName = []
133 | for root, dirs, files in os.walk(path):
134 | folderName = dirs
135 | break
136 | # print(len(folderName))
137 |
138 | for fname in folderName:
139 | # print(fname)
140 |
141 | if not os.path.exists("./dataset/market_rename/train_all/" + fname):
142 | os.makedirs("./dataset/market_rename/train_all/" + fname)
143 |
144 | img_names = []
145 | for root, dirs, files in os.walk(path + fname):
146 | img_names = files
147 | break
148 | # print(img_names)
149 | # print(len(img_names))
150 | for imgname in img_names:
151 | newname = parse_frame(imgname)
152 | # print(newname)
153 | srcfile = path + fname + "/" + imgname
154 | dstfile = "./dataset/market_rename/train_all/" + fname + "/" + newname
155 | shutil.copyfile(srcfile, dstfile)
156 | # break # test with a single id
157 |
158 |
159 | def gen_train_rename():
160 | path = "./dataset/Market1501_prepare/train/"
161 | folderName = []
162 | for root, dirs, files in os.walk(path):
163 | folderName = dirs
164 | break
165 | # print(len(folderName))
166 |
167 | for fname in folderName:
168 | # print(fname)
169 |
170 | if not os.path.exists("./dataset/market_rename/train/" + fname):
171 | os.makedirs("./dataset/market_rename/train/" + fname)
172 |
173 | img_names = []
174 | for root, dirs, files in os.walk(path + fname):
175 | img_names = files
176 | break
177 | # print(img_names)
178 | # print(len(img_names))
179 | for imgname in img_names:
180 | newname = parse_frame(imgname)
181 | # print(newname)
182 | srcfile = path + fname + "/" + imgname
183 | dstfile = "./dataset/market_rename/train/" + fname + "/" + newname
184 | shutil.copyfile(srcfile, dstfile)
185 | # break # test with a single id
186 |
187 |
188 | def gen_val_rename():
189 | path = "./dataset/Market1501_prepare/val/"
190 | folderName = []
191 | for root, dirs, files in os.walk(path):
192 | folderName = dirs
193 | break
194 | # print(len(folderName))
195 |
196 | for fname in folderName:
197 | # print(fname)
198 |
199 | if not os.path.exists("./dataset/market_rename/val/" + fname):
200 | os.makedirs("./dataset/market_rename/val/" + fname)
201 |
202 | img_names = []
203 | for root, dirs, files in os.walk(path + fname):
204 | img_names = files
205 | break
206 | # print(img_names)
207 | # print(len(img_names))
208 | for imgname in img_names:
209 | newname = parse_frame(imgname)
210 | # print(newname)
211 | srcfile = path + fname + "/" + imgname
212 | dstfile = "./dataset/market_rename/val/" + fname + "/" + newname
213 | shutil.copyfile(srcfile, dstfile)
214 | # break # test with a single id
215 |
216 |
217 | def gen_query_rename():
218 | path = "./dataset/Market1501_prepare/query/"
219 | folderName = []
220 | for root, dirs, files in os.walk(path):
221 | folderName = dirs
222 | break
223 | # print(len(folderName))
224 |
225 | for fname in folderName:
226 | # print(fname)
227 |
228 | if not os.path.exists("./dataset/market_rename/query/" + fname):
229 | os.makedirs("./dataset/market_rename/query/" + fname)
230 |
231 | img_names = []
232 | for root, dirs, files in os.walk(path + fname):
233 | img_names = files
234 | break
235 | # print(img_names)
236 | # print(len(img_names))
237 | for imgname in img_names:
238 | newname = parse_frame(imgname)
239 | # print(newname)
240 | srcfile = path + fname + "/" + imgname
241 | dstfile = "./dataset/market_rename/query/" + fname + "/" + newname
242 | shutil.copyfile(srcfile, dstfile)
243 | # break # test with a single id
244 |
245 |
246 | def gen_gallery_rename():
247 | path = "./dataset/Market1501_prepare/gallery/"
248 | folderName = []
249 | for root, dirs, files in os.walk(path):
250 | folderName = dirs
251 | break
252 | # print(len(folderName))
253 |
254 | for fname in folderName:
255 | # print(fname)
256 |
257 | if not os.path.exists("./dataset/market_rename/gallery/" + fname):
258 | os.makedirs("./dataset/market_rename/gallery/" + fname)
259 |
260 | img_names = []
261 | for root, dirs, files in os.walk(path + fname):
262 | img_names = files
263 | break
264 | # print(img_names)
265 | # print(len(img_names))
266 | for imgname in img_names:
267 | newname = parse_frame(imgname)
268 | # print(newname)
269 | srcfile = path + fname + "/" + imgname
270 | dstfile = "./dataset/market_rename/gallery/" + fname + "/" + newname
271 | shutil.copyfile(srcfile, dstfile)
272 | # break # test with a single id
273 |
274 |
275 | if opt.Market:
276 | gen_train_all_rename()
277 | gen_train_rename()
278 | gen_val_rename()
279 | gen_query_rename()
280 | gen_gallery_rename()
281 | shutil.rmtree("./dataset/Market1501_prepare/")
282 | print("Done!")
283 |
--------------------------------------------------------------------------------
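`parse_frame` above turns a per-sequence frame number into a single running frame index per camera by adding the lengths of all earlier sequences. The sketch below works through one case. The three sequence lengths are copied from the `dict_cam_seq_max` entries 11 to 13 (camera 1). The sample filename and the helper names `global_frame` and `rename` are hypothetical.

```python
# Lengths of sequences 1-3 of camera 1, taken from dict_cam_seq_max (keys 11, 12, 13).
SEQ_LENGTHS_CAM1 = {1: 72681, 2: 74546, 3: 74881}


def global_frame(seq, frame, seq_lengths):
    """Frame index counted from the start of sequence 1 of the same camera."""
    return sum(seq_lengths[s] for s in range(1, seq)) + frame


def rename(imgname, seq_lengths):
    """Rewrite 'id_cXsY_frame_count.jpg' as 'id_cX_fNNNNNNN_count.jpg'."""
    fid, camseq, frame, count = imgname.strip().split('_')
    cam, seq = int(camseq[1]), int(camseq[3])
    total = global_frame(seq, int(frame), seq_lengths)
    return '{}_c{}_f{:07d}_{}'.format(fid, cam, total, count)


print(rename('0002_c1s2_000451_03.jpg', SEQ_LENGTHS_CAM1))
# 0002_c1_f0073132_03.jpg
```

The renamed form is what `gen_st_model_market.py` expects when it reads the frame with `filename.split('_')[2][1:]`.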
/random_erasing.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 |
5 | from PIL import Image
6 | import random
7 | import math
8 | import numpy as np
9 | import torch
10 |
11 | class RandomErasing(object):
12 | """ Randomly selects a rectangle region in an image and erases its pixels.
13 | 'Random Erasing Data Augmentation' by Zhong et al.
14 | See https://arxiv.org/pdf/1708.04896.pdf
15 | Args:
16 | probability: The probability that the Random Erasing operation will be performed.
17 | sl: Minimum proportion of erased area against input image.
18 | sh: Maximum proportion of erased area against input image.
19 | r1: Minimum aspect ratio of erased area.
20 | mean: Erasing value.
21 | """
22 |
23 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
24 | self.probability = probability
25 | self.mean = mean
26 | self.sl = sl
27 | self.sh = sh
28 | self.r1 = r1
29 |
30 | def __call__(self, img):
31 |
32 | if random.uniform(0, 1) > self.probability:
33 | return img
34 |
35 | for attempt in range(100):
36 | area = img.size()[1] * img.size()[2]
37 |
38 | target_area = random.uniform(self.sl, self.sh) * area
39 | aspect_ratio = random.uniform(self.r1, 1/self.r1)
40 |
41 | h = int(round(math.sqrt(target_area * aspect_ratio)))
42 | w = int(round(math.sqrt(target_area / aspect_ratio)))
43 |
44 | if w < img.size()[2] and h < img.size()[1]:
45 | x1 = random.randint(0, img.size()[1] - h)
46 | y1 = random.randint(0, img.size()[2] - w)
47 | if img.size()[0] == 3:
48 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
49 | img[1, x1:x1+h, y1:y1+w] = self.mean[1]
50 | img[2, x1:x1+h, y1:y1+w] = self.mean[2]
51 | else:
52 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
53 | return img
54 |
55 | return img
56 |
--------------------------------------------------------------------------------
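The rectangle sampling inside `RandomErasing.__call__` above can be isolated into a small function that needs no tensors. This is a sketch for illustration only. The function name `sample_erase_box` and the `rng` parameter are hypothetical additions, while the default values mirror the constructor defaults in the class.

```python
import math
import random


def sample_erase_box(height, width, sl=0.02, sh=0.4, r1=0.3, rng=random, attempts=100):
    """Return (x1, y1, h, w) for a rectangle that fits inside the image, or None."""
    area = height * width
    for _ in range(attempts):
        target_area = rng.uniform(sl, sh) * area
        aspect_ratio = rng.uniform(r1, 1 / r1)
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))
        # Reject candidates that do not fit, as the original loop does.
        if w < width and h < height:
            x1 = rng.randint(0, height - h)
            y1 = rng.randint(0, width - w)
            return x1, y1, h, w
    return None


box = sample_erase_box(256, 128, rng=random.Random(0))
print(box)
```

Passing a seeded `random.Random` instance makes the sampled box reproducible, which can help when debugging augmentation.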
/re_ranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2/python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Jun 26 14:46:56 2017
5 | @author: luohao
6 | Modified by Houjing Huang, 2017-12-22.
7 | - This version accepts distance matrix instead of raw features.
8 | - The difference of `/` division between python 2 and 3 is handled.
9 | - numpy.float16 is replaced by numpy.float32 for numerical precision.
10 |
11 | Modified by Zhedong Zheng, 2018-1-12.
12 | - replace sort with topK, which saves about 30s.
13 | """
14 |
15 | """
16 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
17 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
18 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
19 | """
20 |
21 | """
22 | API
23 | query_num: number of query samples; they occupy the first query_num rows and columns of all_dist
24 | all_dist: pairwise score matrix over query and gallery samples together, numpy array, shape [num_query + num_gallery, num_query + num_gallery]
25 | (all_dist is converted internally with 2 - 2 * all_dist, so it is expected to hold similarity scores rather than distances)
26 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)
27 | Returns:
28 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]
29 | """
30 |
31 |
32 | import numpy as np
33 |
34 | def k_reciprocal_neigh( initial_rank, i, k1):
35 | forward_k_neigh_index = initial_rank[i,:k1+1]
36 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
37 | fi = np.where(backward_k_neigh_index==i)[0]
38 | return forward_k_neigh_index[fi]
39 |
40 | def re_ranking(query_num, all_dist, k1=20, k2=6, lambda_value=0.3):
41 | # The following naming, e.g. gallery_num, is different from outer scope.
42 | # Don't care about it.
43 | original_dist = all_dist
44 |
45 | original_dist = 2. - 2 * original_dist #np.power(original_dist, 2).astype(np.float32)
46 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
47 | V = np.zeros_like(original_dist).astype(np.float32)
48 | #initial_rank = np.argsort(original_dist).astype(np.int32)
49 | # top K1+1
50 | initial_rank = np.argpartition( original_dist, range(1,k1+1) )
51 |
52 | print('all_dist shape:',all_dist.shape)
53 | print('initial_rank shape:',initial_rank.shape)
54 |
55 | # query_num = q_g_dist.shape[0]
56 | all_num = original_dist.shape[0]
57 |
58 | for i in range(all_num):
59 | # k-reciprocal neighbors
60 | k_reciprocal_index = k_reciprocal_neigh( initial_rank, i, k1)
61 | k_reciprocal_expansion_index = k_reciprocal_index
62 | for j in range(len(k_reciprocal_index)):
63 | candidate = k_reciprocal_index[j]
64 | candidate_k_reciprocal_index = k_reciprocal_neigh( initial_rank, candidate, int(np.around(k1/2)))
65 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
66 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
67 |
68 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
69 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
70 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
71 |
72 | original_dist = original_dist[:query_num,]
73 | if k2 != 1:
74 | V_qe = np.zeros_like(V,dtype=np.float32)
75 | for i in range(all_num):
76 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
77 | V = V_qe
78 | del V_qe
79 | del initial_rank
80 | invIndex = []
81 | for i in range(all_num):
82 | invIndex.append(np.where(V[:,i] != 0)[0])
83 |
84 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
85 |
86 | for i in range(query_num):
87 | temp_min = np.zeros(shape=[1,all_num],dtype=np.float32)
88 | indNonZero = np.where(V[i,:] != 0)[0]
89 | indImages = []
90 | indImages = [invIndex[ind] for ind in indNonZero]
91 | for j in range(len(indNonZero)):
92 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
93 | jaccard_dist[i] = 1-temp_min/(2.-temp_min)
94 |
95 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
96 | del original_dist
97 | del V
98 | del jaccard_dist
99 | final_dist = final_dist[:query_num,query_num:]
100 | return final_dist
101 |
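The k-reciprocal test used by `k_reciprocal_neigh` can be illustrated on a toy example. This is a minimal sketch, not part of the repository: the 1-D `points` array and the `np.argsort` ranking are assumptions made for illustration (the function above builds `initial_rank` with `np.argpartition` instead), and the function body is reconstructed from the listing.

```python
import numpy as np

def k_reciprocal_neigh(initial_rank, i, k1):
    # Candidates: the k1+1 nearest items to i (i itself is included).
    forward_k_neigh_index = initial_rank[i, :k1 + 1]
    # For every candidate, look up its own k1+1 nearest items.
    backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
    # Keep only the candidates that list i among their own neighbors.
    fi = np.where(backward_k_neigh_index == i)[0]
    return forward_k_neigh_index[fi]

# Hypothetical toy data: five points on a line, chosen so that no two
# pairwise distances tie.
points = np.array([0.0, 1.0, 2.5, 10.0, 10.5])
dist = np.abs(points[:, None] - points[None, :])
initial_rank = np.argsort(dist, axis=1)

for i in (0, 2, 3):
    print(i, k_reciprocal_neigh(initial_rank, i, 1))
```

With `k1=1`, point 2 keeps only itself: point 1 is its nearest neighbor, but point 1's nearest neighbor is point 0, so the relation is not reciprocal.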
--------------------------------------------------------------------------------
/readme.txt:
--------------------------------------------------------------------------------
1 | Market1501
2 |
3 | 1. data preparation
4 | python3 prepare.py --Market
5 | 2. train
6 | python3 train_market.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_market_e --erasing_p 0.5 --train_all --data_dir "/home/huangpg/st-reid/dataset/market_rename/"
7 | 3. test
8 | python3 test_st_market.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_market_e --test_dir "/home/huangpg/st-reid/dataset/market_rename/"
9 | 4. st model
10 | python3 gen_st_model_market.py --name ft_ResNet50_pcb_market_e --data_dir "/home/huangpg/st-reid/dataset/market_rename/"
11 | 5. evaluate
12 | python3 evaluate_st.py --name ft_ResNet50_pcb_market_e
13 | 6. re-rank
14 | python3 gen_rerank_all_scores_mat.py --name ft_ResNet50_pcb_market_e
15 | python3 evaluate_rerank_market.py --name ft_ResNet50_pcb_market_e
16 |
17 | Duke:
18 |
19 | 1. data preparation
20 | python3 prepare.py --Duke
21 | 2. train
22 | python3 train_duke.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_duke_e --erasing_p 0.5 --train_all --data_dir "/home/huangpg/st-reid/dataset/DukeMTMC_prepare/"
23 | 3. test
24 | python3 test_st_duke.py --PCB --gpu_ids 2 --name ft_ResNet50_pcb_duke_e --test_dir "/home/huangpg/st-reid/dataset/DukeMTMC_prepare/"
25 | 4. st model
26 | python3 gen_st_model_duke.py --name ft_ResNet50_pcb_duke_e --data_dir "/home/huangpg/st-reid/dataset/DukeMTMC_prepare/"
27 | 5. evaluate
28 | python3 evaluate_st.py --name ft_ResNet50_pcb_duke_e
29 | 6. re-rank
30 | python3 gen_rerank_all_scores_mat.py --name ft_ResNet50_pcb_duke_e
31 | python3 evaluate_rerank_duke.py --name ft_ResNet50_pcb_duke_e
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/test_st_duke.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | import time
15 | import os
16 | import scipy.io
17 | from model import ft_net, ft_net_dense, PCB, PCB_test
18 |
19 | ######################################################################
20 | # Options
21 | # --------
22 | parser = argparse.ArgumentParser(description='Testing')
23 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
24 | parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
25 | parser.add_argument('--test_dir',default='/home/zzd/Market/pytorch',type=str, help='./test_data')
26 | parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
27 | parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
28 | parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
29 | parser.add_argument('--PCB', action='store_true', help='use PCB' )
30 |
31 | opt = parser.parse_args()
32 |
33 | str_ids = opt.gpu_ids.split(',')
34 | #which_epoch = opt.which_epoch
35 | name = opt.name
36 | test_dir = opt.test_dir
37 |
38 | gpu_ids = []
39 | for str_id in str_ids:
40 | id = int(str_id)
41 | if id >=0:
42 | gpu_ids.append(id)
43 |
44 | # set gpu ids
45 | if len(gpu_ids)>0:
46 | torch.cuda.set_device(gpu_ids[0])
47 |
48 | ######################################################################
49 | # Load Data
50 | # ---------
51 | #
52 | # We will use torchvision and torch.utils.data packages for loading the
53 | # data.
54 | #
55 | data_transforms = transforms.Compose([
56 | transforms.Resize((288,144), interpolation=3),
57 | transforms.ToTensor(),
58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59 | ############### Ten Crop
60 | #transforms.TenCrop(224),
61 | #transforms.Lambda(lambda crops: torch.stack(
62 | # [transforms.ToTensor()(crop)
63 | # for crop in crops]
64 | # )),
65 | #transforms.Lambda(lambda crops: torch.stack(
66 | # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
67 | # for crop in crops]
68 | # ))
69 | ])
70 |
71 | if opt.PCB:
72 | data_transforms = transforms.Compose([
73 | transforms.Resize((384,192), interpolation=3),
74 | transforms.ToTensor(),
75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
76 | ])
77 |
78 |
79 | data_dir = test_dir
80 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
81 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
82 | shuffle=False, num_workers=16) for x in ['gallery','query']}
83 |
84 | class_names = image_datasets['query'].classes
85 | use_gpu = torch.cuda.is_available()
86 |
87 | ######################################################################
88 | # Load model
89 | #---------------------------
90 | def load_network(network):
91 | save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
92 | network.load_state_dict(torch.load(save_path))
93 | return network
94 |
95 |
96 | ######################################################################
97 | # Extract feature
98 | # ----------------------
99 | #
100 | # Extract feature from a trained model.
101 | #
102 | def fliplr(img):
103 | '''flip horizontal'''
104 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W
105 | img_flip = img.index_select(3,inv_idx)
106 | return img_flip
107 |
108 | def extract_feature(model,dataloaders):
109 | features = torch.FloatTensor()
110 | count = 0
111 | for data in dataloaders:
112 | img, label = data
113 | n, c, h, w = img.size()
114 | count += n
115 | print(count)
116 | if opt.use_dense:
117 | ff = torch.FloatTensor(n,1024).zero_()
118 | else:
119 | ff = torch.FloatTensor(n,2048).zero_()
120 | if opt.PCB:
121 | ff = torch.FloatTensor(n,2048,6).zero_() # PCB uses six parts
122 | for i in range(2):
123 | if(i==1):
124 | img = fliplr(img)
125 | input_img = Variable(img.cuda())
126 | outputs = model(input_img)
127 | f = outputs.data.cpu()
128 | ff = ff+f
129 | # norm feature
130 | if opt.PCB:
131 | # feature size (n,2048,6)
132 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
133 | ff = ff.div(fnorm.expand_as(ff))
134 | ff = ff.view(ff.size(0), -1)
135 | else:
136 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
137 | ff = ff.div(fnorm.expand_as(ff))
138 |
139 | features = torch.cat((features,ff), 0)
140 | return features
141 |
142 | def get_id(img_path):
143 | camera_id = []
144 | labels = []
145 | frames = []
146 | for path, v in img_path:
147 | filename = path.split('/')[-1]
148 | label = filename[0:4]
149 | camera = filename.split('c')[1]
150 | frame = filename[9:16]
151 | if label[0:2]=='-1':
152 | labels.append(-1)
153 | else:
154 | labels.append(int(label))
155 | camera_id.append(int(camera[0]))
156 | frames.append(int(frame))
157 | return camera_id, labels, frames
158 |
159 | gallery_path = image_datasets['gallery'].imgs
160 | query_path = image_datasets['query'].imgs
161 |
162 | gallery_cam,gallery_label, gallery_frames = get_id(gallery_path)
163 | query_cam,query_label, query_frames = get_id(query_path)
164 |
165 | ######################################################################
166 | # Load Collected data Trained model
167 | class_num=702
168 | # class_num=751
169 | print('-------test-----------')
170 | if opt.use_dense:
171 | model_structure = ft_net_dense(class_num)
172 | else:
173 | model_structure = ft_net(class_num)
174 |
175 | if opt.PCB:
176 | model_structure = PCB(class_num)
177 |
178 | model = load_network(model_structure)
179 |
180 | # Remove the final fc layer and classifier layer
181 | if not opt.PCB:
182 | model.model.fc = nn.Sequential()
183 | model.classifier = nn.Sequential()
184 | else:
185 | model = PCB_test(model)
186 |
187 | # Change to test mode
188 | model = model.eval()
189 | if use_gpu:
190 | model = model.cuda()
191 |
192 | # Extract feature
193 | gallery_feature = extract_feature(model,dataloaders['gallery'])
194 | query_feature = extract_feature(model,dataloaders['query'])
195 |
196 | # Save to Matlab for check
197 | result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'gallery_frames':gallery_frames,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam,'query_frames':query_frames}
198 | scipy.io.savemat('model/'+name+'/'+'pytorch_result.mat',result)
199 |
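The filename parsing done by `get_id` above can be checked in isolation. The following is a minimal sketch under stated assumptions: `parse_duke_name` is a hypothetical helper (not in the repository), the sample path is illustrative, and `os.path.basename` is swapped in for `path.split('/')[-1]`. It assumes names of the form `<4-digit id>_c<camera>_f<7-digit frame>.jpg`, which is what the slices `[0:4]` and `[9:16]` imply.

```python
import os

def parse_duke_name(path):
    """Return (camera_id, person_id, frame) parsed from one image path."""
    filename = os.path.basename(path)
    label = filename[0:4]            # person id, e.g. '0005'
    camera = filename.split('c')[1]  # text after the first 'c', e.g. '2_f0046985.jpg'
    frame = filename[9:16]           # 7-digit frame number, e.g. '0046985'
    person_id = -1 if label[0:2] == '-1' else int(label)
    return int(camera[0]), person_id, int(frame)

# Illustrative path; the actual directory layout is an assumption.
print(parse_duke_name('gallery/0005/0005_c2_f0046985.jpg'))
```

Because the frame number is taken with a fixed slice, a filename with a different id width or prefix would silently yield a wrong frame value; `test_st_market.py` uses `filename.split('_')[2][1:]` instead for its renamed Market files.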
--------------------------------------------------------------------------------
/test_st_market.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | import time
15 | import os
16 | import scipy.io
17 | from model import ft_net, ft_net_dense, PCB, PCB_test
18 |
19 | ######################################################################
20 | # Options
21 | # --------
22 | parser = argparse.ArgumentParser(description='Testing')
23 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
24 | parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
25 | parser.add_argument('--test_dir',default='/home/zzd/Market/pytorch',type=str, help='./test_data')
26 | parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
27 | parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
28 | parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
29 | parser.add_argument('--PCB', action='store_true', help='use PCB' )
30 |
31 | opt = parser.parse_args()
32 |
33 | str_ids = opt.gpu_ids.split(',')
34 | #which_epoch = opt.which_epoch
35 | name = opt.name
36 | test_dir = opt.test_dir
37 |
38 | gpu_ids = []
39 | for str_id in str_ids:
40 | id = int(str_id)
41 | if id >=0:
42 | gpu_ids.append(id)
43 |
44 | # set gpu ids
45 | if len(gpu_ids)>0:
46 | torch.cuda.set_device(gpu_ids[0])
47 |
48 | ######################################################################
49 | # Load Data
50 | # ---------
51 | #
52 | # We will use torchvision and torch.utils.data packages for loading the
53 | # data.
54 | #
55 | data_transforms = transforms.Compose([
56 | transforms.Resize((288,144), interpolation=3),
57 | transforms.ToTensor(),
58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59 | ############### Ten Crop
60 | #transforms.TenCrop(224),
61 | #transforms.Lambda(lambda crops: torch.stack(
62 | # [transforms.ToTensor()(crop)
63 | # for crop in crops]
64 | # )),
65 | #transforms.Lambda(lambda crops: torch.stack(
66 | # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
67 | # for crop in crops]
68 | # ))
69 | ])
70 |
71 | if opt.PCB:
72 | data_transforms = transforms.Compose([
73 | transforms.Resize((384,192), interpolation=3),
74 | transforms.ToTensor(),
75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
76 | ])
77 |
78 |
79 | data_dir = test_dir
80 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
81 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
82 | shuffle=False, num_workers=16) for x in ['gallery','query']}
83 |
84 | class_names = image_datasets['query'].classes
85 | use_gpu = torch.cuda.is_available()
86 |
87 | ######################################################################
88 | # Load model
89 | #---------------------------
90 | def load_network(network):
91 | save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
92 | network.load_state_dict(torch.load(save_path))
93 | return network
94 |
95 |
96 | ######################################################################
97 | # Extract feature
98 | # ----------------------
99 | #
100 | # Extract feature from a trained model.
101 | #
102 | def fliplr(img):
103 | '''flip horizontal'''
104 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W
105 | img_flip = img.index_select(3,inv_idx)
106 | return img_flip
107 |
108 | def extract_feature(model,dataloaders):
109 | features = torch.FloatTensor()
110 | count = 0
111 | for data in dataloaders:
112 | img, label = data
113 | n, c, h, w = img.size()
114 | count += n
115 | print(count)
116 | if opt.use_dense:
117 | ff = torch.FloatTensor(n,1024).zero_()
118 | else:
119 | ff = torch.FloatTensor(n,2048).zero_()
120 | if opt.PCB:
121 | ff = torch.FloatTensor(n,2048,6).zero_() # PCB uses six parts
122 | for i in range(2):
123 | if(i==1):
124 | img = fliplr(img)
125 | input_img = Variable(img.cuda())
126 | outputs = model(input_img)
127 | f = outputs.data.cpu()
128 | ff = ff+f
129 | # norm feature
130 | if opt.PCB:
131 | # feature size (n,2048,6)
132 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
133 | ff = ff.div(fnorm.expand_as(ff))
134 | ff = ff.view(ff.size(0), -1)
135 | else:
136 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
137 | ff = ff.div(fnorm.expand_as(ff))
138 |
139 | features = torch.cat((features,ff), 0)
140 | return features
141 |
142 | def get_id(img_path):
143 | camera_id = []
144 | labels = []
145 | frames = []
146 | for path, v in img_path:
147 | filename = path.split('/')[-1]
148 | label = filename[0:4]
149 | camera = filename.split('c')[1]
150 | # frame = filename[9:16]
151 | frame = filename.split('_')[2][1:]
152 | if label[0:2]=='-1':
153 | labels.append(-1)
154 | else:
155 | labels.append(int(label))
156 | camera_id.append(int(camera[0]))
157 | frames.append(int(frame))
158 | return camera_id, labels, frames
159 |
160 | gallery_path = image_datasets['gallery'].imgs
161 | query_path = image_datasets['query'].imgs
162 |
163 | gallery_cam,gallery_label, gallery_frames = get_id(gallery_path)
164 | query_cam,query_label, query_frames = get_id(query_path)
165 |
166 | ######################################################################
167 | # Load Collected data Trained model
168 | class_num=751
169 | print('-------test-----------')
170 | if opt.use_dense:
171 | model_structure = ft_net_dense(class_num)
172 | else:
173 | model_structure = ft_net(class_num)
174 |
175 | if opt.PCB:
176 | model_structure = PCB(class_num)
177 |
178 | model = load_network(model_structure)
179 |
180 | # Remove the final fc layer and classifier layer
181 | if not opt.PCB:
182 | model.model.fc = nn.Sequential()
183 | model.classifier = nn.Sequential()
184 | else:
185 | model = PCB_test(model)
186 |
187 | # Change to test mode
188 | model = model.eval()
189 | if use_gpu:
190 | model = model.cuda()
191 |
192 | # Extract feature
193 | gallery_feature = extract_feature(model,dataloaders['gallery'])
194 | query_feature = extract_feature(model,dataloaders['query'])
195 |
196 | # Save to Matlab for check
197 | result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'gallery_frames':gallery_frames,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam,'query_frames':query_frames}
198 | scipy.io.savemat('model/'+name+'/'+'pytorch_result.mat',result)
199 |
--------------------------------------------------------------------------------
/train_duke.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | import matplotlib
15 | matplotlib.use('agg')
16 | import matplotlib.pyplot as plt
17 | from PIL import Image
18 | import time
19 | import os
20 | from model import ft_net, ft_net_dense, PCB
21 | from random_erasing import RandomErasing
22 | import json
23 |
24 | ######################################################################
25 | # Options
26 | # --------
27 | parser = argparse.ArgumentParser(description='Training')
28 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
29 | parser.add_argument('--name',default='ft_ResNet50', type=str, help='output model name')
30 | parser.add_argument('--data_dir',default='/home/zzd/Market/pytorch',type=str, help='training dir path')
31 | parser.add_argument('--train_all', action='store_true', help='use all training data' )
32 | parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training' )
33 | parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
34 | parser.add_argument('--erasing_p', default=0, type=float, help='Random Erasing probability, in [0,1]')
35 | parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
36 | parser.add_argument('--PCB', action='store_true', help='use PCB+ResNet50' )
37 | opt = parser.parse_args()
38 |
39 | data_dir = opt.data_dir
40 | name = opt.name
41 | str_ids = opt.gpu_ids.split(',')
42 | gpu_ids = []
43 | for str_id in str_ids:
44 | gid = int(str_id)
45 | if gid >=0:
46 | gpu_ids.append(gid)
47 |
48 | # set gpu ids
49 | if len(gpu_ids)>0:
50 | torch.cuda.set_device(gpu_ids[0])
51 | #print(gpu_ids[0])
52 |
53 | if not os.path.exists("./model/"):
54 | os.makedirs("./model/")
55 |
56 | ######################################################################
57 | # Load Data
58 | # ---------
59 | #
60 |
61 | transform_train_list = [
62 | #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
63 | transforms.Resize((288,144), interpolation=3),
64 | transforms.RandomCrop((256,128)),
65 | transforms.RandomHorizontalFlip(),
66 | transforms.ToTensor(),
67 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
68 | ]
69 |
70 | transform_val_list = [
71 | transforms.Resize(size=(256,128),interpolation=3), #Image.BICUBIC
72 | transforms.ToTensor(),
73 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
74 | ]
75 |
76 | if opt.PCB:
77 | transform_train_list = [
78 | transforms.Resize((384,192), interpolation=3),
79 | transforms.RandomHorizontalFlip(),
80 | transforms.ToTensor(),
81 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
82 | ]
83 | transform_val_list = [
84 | transforms.Resize(size=(384,192),interpolation=3), #Image.BICUBIC
85 | transforms.ToTensor(),
86 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
87 | ]
88 |
89 | if opt.erasing_p>0:
90 | transform_train_list = transform_train_list + [RandomErasing(probability = opt.erasing_p, mean=[0.0, 0.0, 0.0])]
91 |
92 | if opt.color_jitter:
93 | transform_train_list = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train_list
94 |
95 | print(transform_train_list)
96 | data_transforms = {
97 | 'train': transforms.Compose( transform_train_list ),
98 | 'val': transforms.Compose(transform_val_list),
99 | }
100 |
101 |
102 | train_all = ''
103 | if opt.train_all:
104 | train_all = '_all'
105 |
106 | image_datasets = {}
107 | image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train' + train_all),
108 | data_transforms['train'])
109 | image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
110 | data_transforms['val'])
111 |
112 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
113 | shuffle=True, num_workers=16)
114 | for x in ['train', 'val']}
115 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
116 | class_names = image_datasets['train'].classes
117 |
118 | use_gpu = torch.cuda.is_available()
119 |
120 | inputs, classes = next(iter(dataloaders['train']))
121 |
122 | ######################################################################
123 | # Training the model
124 | # ------------------
125 | #
126 | # Now, let's write a general function to train a model. Here, we will
127 | # illustrate:
128 | #
129 | # - Scheduling the learning rate
130 | # - Saving the best model
131 | #
132 | # In the following, parameter ``scheduler`` is an LR scheduler object from
133 | # ``torch.optim.lr_scheduler``.
134 |
135 | y_loss = {} # loss history
136 | y_loss['train'] = []
137 | y_loss['val'] = []
138 | y_err = {}
139 | y_err['train'] = []
140 | y_err['val'] = []
141 |
142 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
143 | since = time.time()
144 |
145 | best_model_wts = model.state_dict()
146 | best_acc = 0.0
147 |
148 | for epoch in range(num_epochs):
149 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
150 | print('-' * 10)
151 |
152 | # Each epoch has a training and validation phase
153 | for phase in ['train', 'val']:
154 | if phase == 'train':
155 | scheduler.step()
156 | model.train(True) # Set model to training mode
157 | else:
158 | model.train(False) # Set model to evaluate mode
159 |
160 | running_loss = 0.0
161 | running_corrects = 0
162 | # Iterate over data.
163 | for data in dataloaders[phase]:
164 | # get the inputs
165 | inputs, labels = data
166 | #print(inputs.shape)
167 | # wrap them in Variable
168 | if use_gpu:
169 | inputs = Variable(inputs.cuda())
170 | labels = Variable(labels.cuda())
171 | else:
172 | inputs, labels = Variable(inputs), Variable(labels)
173 |
174 | # zero the parameter gradients
175 | optimizer.zero_grad()
176 |
177 | # forward
178 | outputs = model(inputs)
179 | if not opt.PCB:
180 | _, preds = torch.max(outputs.data, 1)
181 | loss = criterion(outputs, labels)
182 | else:
183 | part = {}
184 | sm = nn.Softmax(dim=1)
185 | num_part = 6
186 | for i in range(num_part):
187 | part[i] = outputs[i]
188 |
189 | score = sm(part[0]) + sm(part[1]) +sm(part[2]) + sm(part[3]) +sm(part[4]) +sm(part[5])
190 | _, preds = torch.max(score.data, 1)
191 |
192 | loss = criterion(part[0], labels)
193 | for i in range(num_part-1):
194 | loss += criterion(part[i+1], labels)
195 |
196 | # backward + optimize only if in training phase
197 | if phase == 'train':
198 | loss.backward()
199 | optimizer.step()
200 |
201 | # statistics
202 | running_loss += loss.item()
203 | running_corrects += torch.sum(preds == labels.data)
204 | # print(running_corrects)
205 |
206 | epoch_loss = running_loss / dataset_sizes[phase]
207 | epoch_acc = (running_corrects.item()) / dataset_sizes[phase]
208 | # print(running_corrects.item())
209 | # print(dataset_sizes[phase])
210 | # print(epoch_acc)
211 |
212 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(
213 | phase, epoch_loss, epoch_acc))
214 |
215 | y_loss[phase].append(epoch_loss)
216 | y_err[phase].append(1.0-epoch_acc)
217 | # deep copy the model
218 | if phase == 'val':
219 | last_model_wts = model.state_dict()
220 | if epoch%10 == 9:
221 | save_network(model, epoch)
222 | draw_curve(epoch)
223 |
224 | print()
225 |
226 | time_elapsed = time.time() - since
227 | print('Training complete in {:.0f}m {:.0f}s'.format(
228 | time_elapsed // 60, time_elapsed % 60))
229 | #print('Best val Acc: {:4f}'.format(best_acc))
230 |
231 | # load the weights from the last epoch (the best model is not tracked)
232 | model.load_state_dict(last_model_wts)
233 | save_network(model, 'last')
234 | return model
235 |
236 |
237 | ######################################################################
238 | # Draw Curve
239 | #---------------------------
240 | x_epoch = []
241 | fig = plt.figure()
242 | ax0 = fig.add_subplot(121, title="loss")
243 | ax1 = fig.add_subplot(122, title="top1err")
244 | def draw_curve(current_epoch):
245 | x_epoch.append(current_epoch)
246 | ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
247 | ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
248 | ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
249 | ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
250 | if current_epoch == 0:
251 | ax0.legend()
252 | ax1.legend()
253 | fig.savefig( os.path.join('./model',name,'train.jpg'))
254 |
255 | ######################################################################
256 | # Save model
257 | #---------------------------
258 | def save_network(network, epoch_label):
259 | save_filename = 'net_%s.pth'% epoch_label
260 | save_path = os.path.join('./model',name,save_filename)
261 | torch.save(network.cpu().state_dict(), save_path)
262 | if torch.cuda.is_available():
263 | network.cuda(gpu_ids[0])
264 |
265 |
266 | ######################################################################
267 | # Finetuning the convnet
268 | # ----------------------
269 | #
270 | # Load a pretrained model and reset the final fully connected layer.
271 | #
272 |
273 | if opt.use_dense:
274 | model = ft_net_dense(len(class_names))
275 | else:
276 | model = ft_net(len(class_names))
277 |
278 | if opt.PCB:
279 | model = PCB(len(class_names))
280 |
281 | print(model)
282 |
283 | if use_gpu:
284 | model = model.cuda()
285 |
286 | criterion = nn.CrossEntropyLoss()
287 |
288 | if not opt.PCB:
289 | ignored_params = list(map(id, model.model.fc.parameters() )) + list(map(id, model.classifier.parameters() ))
290 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
291 | optimizer_ft = optim.SGD([
292 | {'params': base_params, 'lr': 0.01},
293 | {'params': model.model.fc.parameters(), 'lr': 0.1},
294 | {'params': model.classifier.parameters(), 'lr': 0.1}
295 | ], weight_decay=5e-4, momentum=0.9, nesterov=True)
296 | else:
297 | ignored_params = list(map(id, model.model.fc.parameters() ))
298 | ignored_params += (list(map(id, model.classifier0.parameters() ))
299 | +list(map(id, model.classifier1.parameters() ))
300 | +list(map(id, model.classifier2.parameters() ))
301 | +list(map(id, model.classifier3.parameters() ))
302 | +list(map(id, model.classifier4.parameters() ))
303 | +list(map(id, model.classifier5.parameters() ))
304 | #+list(map(id, model.classifier6.parameters() ))
305 | #+list(map(id, model.classifier7.parameters() ))
306 | )
307 | base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
308 | optimizer_ft = optim.SGD([
309 | {'params': base_params, 'lr': 0.01},
310 | {'params': model.model.fc.parameters(), 'lr': 0.1},
311 | {'params': model.classifier0.parameters(), 'lr': 0.1},
312 | {'params': model.classifier1.parameters(), 'lr': 0.1},
313 | {'params': model.classifier2.parameters(), 'lr': 0.1},
314 | {'params': model.classifier3.parameters(), 'lr': 0.1},
315 | {'params': model.classifier4.parameters(), 'lr': 0.1},
316 | {'params': model.classifier5.parameters(), 'lr': 0.1},
317 | #{'params': model.classifier6.parameters(), 'lr': 0.01},
318 | #{'params': model.classifier7.parameters(), 'lr': 0.01}
319 | ], weight_decay=5e-4, momentum=0.9, nesterov=True)
320 |
321 | # Decay LR by a factor of 0.1 every 40 epochs
322 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=40, gamma=0.1)
323 |
324 | ######################################################################
325 | # Train and evaluate
326 | # ^^^^^^^^^^^^^^^^^^
327 | #
328 | # It should take around 1-2 hours on GPU.
329 | #
330 | dir_name = os.path.join('./model',name)
331 | if not os.path.isdir(dir_name):
332 | os.mkdir(dir_name)
333 |
334 | # save opts
335 | with open('%s/opts.json'%dir_name,'w') as fp:
336 | json.dump(vars(opt), fp, indent=1)
337 |
338 | model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
339 | num_epochs=60)
340 |
341 |
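The learning-rate schedule configured above (`StepLR` with `step_size=40`, `gamma=0.1`, run for 60 epochs) can be sketched without PyTorch. This is a rough illustration using the documented closed form of a step decay; `step_lr` is a hypothetical helper, and because the script calls `scheduler.step()` at the start of each training phase, the exact epoch at which the drop takes effect may differ by one depending on the PyTorch version.

```python
def step_lr(base_lr, epoch, step_size=40, gamma=0.1):
    # Closed form of a step decay: multiply by gamma once per completed
    # block of step_size epochs.
    return base_lr * gamma ** (epoch // step_size)

# Base rates from the script: 0.01 for the backbone, 0.1 for the new layers.
for epoch in (0, 39, 40, 59):
    print(epoch, step_lr(0.01, epoch), step_lr(0.1, epoch))
```

Under these assumptions both parameter groups train at their base rate for roughly the first 40 epochs and at one tenth of it for the remaining 20.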
--------------------------------------------------------------------------------
/train_market.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function, division
4 |
5 | import argparse
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.optim import lr_scheduler
10 | from torch.autograd import Variable
11 | import numpy as np
12 | import torchvision
13 | from torchvision import datasets, models, transforms
14 | import matplotlib
15 | matplotlib.use('agg')
16 | import matplotlib.pyplot as plt
17 | from PIL import Image
18 | import time
19 | import os
20 | from model import ft_net, ft_net_dense, PCB
21 | from random_erasing import RandomErasing
22 | import json
23 |
24 | ######################################################################
25 | # Options
26 | # --------
27 | parser = argparse.ArgumentParser(description='Training')
28 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
29 | parser.add_argument('--name',default='ft_ResNet50', type=str, help='output model name')
30 | parser.add_argument('--data_dir',default='/home/zzd/Market/pytorch',type=str, help='training dir path')
31 | parser.add_argument('--train_all', action='store_true', help='use all training data' )
32 | parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training' )
33 | parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
34 | parser.add_argument('--erasing_p', default=0, type=float, help='Random Erasing probability, in [0,1]')
35 | parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
36 | parser.add_argument('--PCB', action='store_true', help='use PCB+ResNet50' )
37 | opt = parser.parse_args()
38 |
39 | data_dir = opt.data_dir
40 | name = opt.name
41 | str_ids = opt.gpu_ids.split(',')
42 | gpu_ids = []
43 | for str_id in str_ids:
44 | gid = int(str_id)
45 | if gid >=0:
46 | gpu_ids.append(gid)
47 |
48 | # set gpu ids
49 | if len(gpu_ids)>0:
50 | torch.cuda.set_device(gpu_ids[0])
51 | #print(gpu_ids[0])
52 |
53 | if not os.path.exists("./model/"):
54 | os.makedirs("./model/")
55 |
56 | ######################################################################
57 | # Load Data
58 | # ---------
59 | #
60 |
61 | transform_train_list = [
62 |     #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
63 |     transforms.Resize((288,144), interpolation=3),
64 |     transforms.RandomCrop((256,128)),
65 |     transforms.RandomHorizontalFlip(),
66 |     transforms.ToTensor(),
67 |     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
68 |     ]
69 |
70 | transform_val_list = [
71 |     transforms.Resize(size=(256,128),interpolation=3), #Image.BICUBIC
72 |     transforms.ToTensor(),
73 |     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
74 |     ]
75 |
76 | if opt.PCB:
77 |     transform_train_list = [
78 |         transforms.Resize((384,192), interpolation=3),
79 |         transforms.RandomHorizontalFlip(),
80 |         transforms.ToTensor(),
81 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
82 |         ]
83 |     transform_val_list = [
84 |         transforms.Resize(size=(384,192),interpolation=3), #Image.BICUBIC
85 |         transforms.ToTensor(),
86 |         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
87 |         ]
88 |
89 | if opt.erasing_p>0:
90 |     transform_train_list = transform_train_list + [RandomErasing(probability = opt.erasing_p, mean=[0.0, 0.0, 0.0])]
91 |
92 | if opt.color_jitter:
93 |     transform_train_list = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train_list
94 |
95 | print(transform_train_list)
96 | data_transforms = {
97 |     'train': transforms.Compose( transform_train_list ),
98 |     'val': transforms.Compose(transform_val_list),
99 | }
100 |
101 |
102 | train_all = ''
103 | if opt.train_all:
104 |     train_all = '_all'
105 |
106 | image_datasets = {}
107 | image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train' + train_all),
108 |                                                data_transforms['train'])
109 | image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
110 |                                              data_transforms['val'])
111 |
112 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
113 |                                               shuffle=True, num_workers=16)
114 |                for x in ['train', 'val']}
115 | dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
116 | class_names = image_datasets['train'].classes
117 |
118 | use_gpu = torch.cuda.is_available()
119 |
120 | inputs, classes = next(iter(dataloaders['train']))
121 |
122 | ######################################################################
123 | # Training the model
124 | # ------------------
125 | #
126 | # Now, let's write a general function to train a model. Here, we will
127 | # illustrate:
128 | #
129 | # - Scheduling the learning rate
130 | # - Saving the best model
131 | #
132 | # In the following, parameter ``scheduler`` is an LR scheduler object from
133 | # ``torch.optim.lr_scheduler``.
134 |
135 | y_loss = {} # loss history
136 | y_loss['train'] = []
137 | y_loss['val'] = []
138 | y_err = {}
139 | y_err['train'] = []
140 | y_err['val'] = []
141 |
142 | def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
143 |     since = time.time()
144 |
145 |     best_model_wts = model.state_dict()
146 |     best_acc = 0.0
147 |
148 |     for epoch in range(num_epochs):
149 |         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
150 |         print('-' * 10)
151 |
152 |         # Each epoch has a training and validation phase
153 |         for phase in ['train', 'val']:
154 |             if phase == 'train':
155 |                 scheduler.step()
156 |                 model.train(True)  # Set model to training mode
157 |             else:
158 |                 model.train(False)  # Set model to evaluate mode
159 |
160 |             running_loss = 0.0
161 |             running_corrects = 0
162 |             # Iterate over data.
163 |             for data in dataloaders[phase]:
164 |                 # get the inputs
165 |                 inputs, labels = data
166 |                 #print(inputs.shape)
167 |                 # wrap them in Variable
168 |                 if use_gpu:
169 |                     inputs = Variable(inputs.cuda())
170 |                     labels = Variable(labels.cuda())
171 |                 else:
172 |                     inputs, labels = Variable(inputs), Variable(labels)
173 |
174 |                 # zero the parameter gradients
175 |                 optimizer.zero_grad()
176 |
177 |                 # forward
178 |                 outputs = model(inputs)
179 |                 if not opt.PCB:
180 |                     _, preds = torch.max(outputs.data, 1)
181 |                     loss = criterion(outputs, labels)
182 |                 else:
183 |                     part = {}
184 |                     sm = nn.Softmax(dim=1)
185 |                     num_part = 6
186 |                     for i in range(num_part):
187 |                         part[i] = outputs[i]
188 |
189 |                     score = sm(part[0]) + sm(part[1]) +sm(part[2]) + sm(part[3]) +sm(part[4]) +sm(part[5])
190 |                     _, preds = torch.max(score.data, 1)
191 |
192 |                     loss = criterion(part[0], labels)
193 |                     for i in range(num_part-1):
194 |                         loss += criterion(part[i+1], labels)
195 |
196 |                 # backward + optimize only if in training phase
197 |                 if phase == 'train':
198 |                     loss.backward()
199 |                     optimizer.step()
200 |
201 |                 # statistics (criterion returns a batch mean, so weight it by the batch size)
202 |                 running_loss += loss.item() * inputs.size(0)
203 |                 running_corrects += torch.sum(preds == labels.data)
204 |                 # print(running_corrects)
205 |
206 |             epoch_loss = running_loss / dataset_sizes[phase]
207 |             epoch_acc = (running_corrects.item()) / dataset_sizes[phase]
208 |             # print(running_corrects.item())
209 |             # print(dataset_sizes[phase])
210 |             # print(epoch_acc)
211 |
212 |             print('{} Loss: {:.4f} Acc: {:.4f}'.format(
213 |                 phase, epoch_loss, epoch_acc))
214 |
215 |             y_loss[phase].append(epoch_loss)
216 |             y_err[phase].append(1.0-epoch_acc)
217 |             # keep the latest weights (state_dict() returns references, not a deep copy)
218 |             if phase == 'val':
219 |                 last_model_wts = model.state_dict()
220 |                 if epoch%10 == 9:
221 |                     save_network(model, epoch)
222 |                 draw_curve(epoch)
223 |
224 |         print()
225 |
226 |     time_elapsed = time.time() - since
227 |     print('Training complete in {:.0f}m {:.0f}s'.format(
228 |         time_elapsed // 60, time_elapsed % 60))
229 |     #print('Best val Acc: {:4f}'.format(best_acc))
230 |
231 |     # load the last-epoch weights (no best-model selection is performed)
232 |     model.load_state_dict(last_model_wts)
233 |     save_network(model, 'last')
234 |     return model
235 |
236 |
237 | ######################################################################
238 | # Draw Curve
239 | #---------------------------
240 | x_epoch = []
241 | fig = plt.figure()
242 | ax0 = fig.add_subplot(121, title="loss")
243 | ax1 = fig.add_subplot(122, title="top1err")
244 | def draw_curve(current_epoch):
245 |     x_epoch.append(current_epoch)
246 |     ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
247 |     ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
248 |     ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
249 |     ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
250 |     if current_epoch == 0:
251 |         ax0.legend()
252 |         ax1.legend()
253 |     fig.savefig( os.path.join('./model',name,'train.jpg'))
254 |
255 | ######################################################################
256 | # Save model
257 | #---------------------------
258 | def save_network(network, epoch_label):
259 |     save_filename = 'net_%s.pth'% epoch_label
260 |     save_path = os.path.join('./model',name,save_filename)
261 |     torch.save(network.cpu().state_dict(), save_path)
262 |     if torch.cuda.is_available():
263 |         network.cuda(gpu_ids[0])
264 |
265 |
266 | ######################################################################
267 | # Finetuning the convnet
268 | # ----------------------
269 | #
270 | # Load a pretrained model and reset the final fully connected layer.
271 | #
272 |
273 | if opt.use_dense:
274 |     model = ft_net_dense(len(class_names))
275 | else:
276 |     model = ft_net(len(class_names))
277 |
278 | if opt.PCB:
279 |     model = PCB(len(class_names))
280 |
281 | print(model)
282 |
283 | if use_gpu:
284 |     model = model.cuda()
285 |
286 | criterion = nn.CrossEntropyLoss()
287 |
288 | if not opt.PCB:
289 |     ignored_params = list(map(id, model.model.fc.parameters() )) + list(map(id, model.classifier.parameters() ))
290 |     base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
291 |     optimizer_ft = optim.SGD([
292 |              {'params': base_params, 'lr': 0.01},
293 |              {'params': model.model.fc.parameters(), 'lr': 0.1},
294 |              {'params': model.classifier.parameters(), 'lr': 0.1}
295 |          ], weight_decay=5e-4, momentum=0.9, nesterov=True)
296 | else:
297 |     ignored_params = list(map(id, model.model.fc.parameters() ))
298 |     ignored_params += (list(map(id, model.classifier0.parameters() ))
299 |                      +list(map(id, model.classifier1.parameters() ))
300 |                      +list(map(id, model.classifier2.parameters() ))
301 |                      +list(map(id, model.classifier3.parameters() ))
302 |                      +list(map(id, model.classifier4.parameters() ))
303 |                      +list(map(id, model.classifier5.parameters() ))
304 |                      #+list(map(id, model.classifier6.parameters() ))
305 |                      #+list(map(id, model.classifier7.parameters() ))
306 |                       )
307 |     base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
308 |     optimizer_ft = optim.SGD([
309 |              {'params': base_params, 'lr': 0.01},
310 |              {'params': model.model.fc.parameters(), 'lr': 0.1},
311 |              {'params': model.classifier0.parameters(), 'lr': 0.1},
312 |              {'params': model.classifier1.parameters(), 'lr': 0.1},
313 |              {'params': model.classifier2.parameters(), 'lr': 0.1},
314 |              {'params': model.classifier3.parameters(), 'lr': 0.1},
315 |              {'params': model.classifier4.parameters(), 'lr': 0.1},
316 |              {'params': model.classifier5.parameters(), 'lr': 0.1},
317 |              #{'params': model.classifier6.parameters(), 'lr': 0.01},
318 |              #{'params': model.classifier7.parameters(), 'lr': 0.01}
319 |          ], weight_decay=5e-4, momentum=0.9, nesterov=True)
320 |
321 | # Decay LR by a factor of 0.1 every 40 epochs
322 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=40, gamma=0.1)
323 |
324 | ######################################################################
325 | # Train and evaluate
326 | # ^^^^^^^^^^^^^^^^^^
327 | #
328 | # It should take around 1-2 hours on GPU.
329 | #
330 | dir_name = os.path.join('./model',name)
331 | if not os.path.isdir(dir_name):
332 |     os.mkdir(dir_name)
333 |
334 | # save opts
335 | with open('%s/opts.json'%dir_name,'w') as fp:
336 |     json.dump(vars(opt), fp, indent=1)
337 |
338 | model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
339 |                        num_epochs=60)
340 |
341 |
--------------------------------------------------------------------------------
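Note on the learning-rate schedule in train_market.py: the script gives the backbone a base rate of 0.01 and the newly added layers 0.1, then decays both by a factor of 0.1 every 40 epochs over a 60-epoch run. The sketch below reproduces that schedule in plain Python using the closed-form step-decay rule. The names `step_lr`, `base_lrs`, and `schedule` are hypothetical helpers for illustration, not part of the repository, and the sketch ignores any one-epoch offset that may arise from the order in which `scheduler.step()` and `optimizer.step()` are called on a given PyTorch version.

```python
def step_lr(base_lr, epoch, step_size=40, gamma=0.1):
    """Closed-form step decay: scale base_lr by gamma once per completed step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Base rates taken from the optimizer setup in the script; the group names are illustrative.
base_lrs = {"backbone": 0.01, "new_layers": 0.1}


def schedule(num_epochs=60):
    """Return one dict of per-group learning rates for every epoch."""
    return [
        {name: step_lr(lr, epoch) for name, lr in base_lrs.items()}
        for epoch in range(num_epochs)
    ]


if __name__ == "__main__":
    lrs = schedule()
    # Print the rates just before and after the single decay point in a 60-epoch run.
    for epoch in (0, 39, 40, 59):
        print(epoch, lrs[epoch])
```

With 60 epochs and a step size of 40, only one decay occurs, so the last 20 epochs run at one tenth of the initial rates.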