├── README.md
├── config.py
├── data.py
├── model.py
├── sampler.py
├── test.py
├── train.py
├── tripletloss.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CRfusionGait_pytorch
This is the code for the paper "Robust Gait Recognition Based on Deep CNNs With Camera and Radar Sensor Fusion".
Further details will be updated soon.
The dataset can be found at https://github.com/LanDu-XD/CRfusionGait
## Citation
If you use this dataset or code, please cite the following paper:
```
@ARTICLE{10045763,
  author={Shi, Yu and Du, Lan and Chen, Xiaoyang and Liao, Xun and Yu, Zengyu and Li, Zenghui and Wang, Chunxin and Xue, Shikun},
  journal={IEEE Internet of Things Journal},
  title={Robust Gait Recognition Based on Deep CNNs With Camera and Radar Sensor Fusion},
  year={2023},
  volume={10},
  number={12},
  pages={10817-10832},
  doi={10.1109/JIOT.2023.3242417}}
```
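How to run (a minimal sketch; it assumes the dataset paths in `config.py` point at local copies of the GEI and spectrogram data, and that `non_local_gaussian.py` from the original release is present, since `model.py` imports it but only its compiled `.pyc` was tracked in this repository):

```python
# Both entry points are plain scripts guarded by __main__:
#   python train.py   # trains SetNet with triplet loss; checkpoints are saved
#                     # every 10k iterations under ./checkpoint/<model_name>/
#   python test.py    # loads the checkpoint at test_restore_iter and prints
#                     # rank-1 accuracy per condition (NM/BG/CL) and view angle
```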
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os

datapath = "/media/ai/d899633d-2ef9-4e8d-ac69-eb22cae4d04f/gyc/shiyu/gaitcode-set/GEI_spe/data/geiori"
#spe_path = "/media/ai/d899633d-2ef9-4e8d-ac69-eb22cae4d04f/gyc/shiyu/gaitcode-set/GEI_spe/data/spe"
spe_path = "/media/ai/d899633d-2ef9-4e8d-ac69-eb22cae4d04f/gyc/shiyu/GEI_spe/datasets/spectrogram2.4s/"
identity_list = sorted(os.listdir(datapath))
train_list = identity_list[0:74]   # first 74 identities for training
test_list = identity_list[74:]     # remaining identities for testing
resolution = 128
cut_padding = int(float(resolution) / 64 * 10)   # 20 columns cut from each side of the GEI
batch_size = (2, 4)                # (identities per batch, samples per identity)
# sample_type = 'random'
frame_num = 10
train_start_iteration = 0
hidden_dim = 256
margin = 0.2
lr = 0.0001
total_iter = 400000
model_name = '6-CNN-avgcut-shared-independent-fc'
save_name = '6-CNN-avgcut-shared-independent-fc'
test_probe_condition_list = [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['ct-01', 'ct-02']]
test_gallery_condition_list = [['nm-01', 'nm-02', 'nm-03', 'nm-04']]
angle_list = ['000', '030', '045', '060', '090', '300', '315', '330']
test_restore_iter = 330000
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
from config import *
import os
import xarray
import numpy as np
import cv2
from PIL import Image


class load_data(Dataset):

    def __init__(self, flag):
        super(load_data, self).__init__()

        self.flag = flag
        if flag == 'train':
            flag_list = train_list
        else:
            flag_list = test_list

        self.identity_list = []
        self.pic_list = []
        self.spe_list = []
        self.condition_list = []
        self.angle_list = []

        # Collect every (GEI, spectrogram) pair that exists for both sensors.
        for identity in flag_list:
            identity_path = os.path.join(datapath, identity)
            for condition in sorted(os.listdir(identity_path)):
                condition_path = os.path.join(identity_path, condition)
                for angle in sorted(angle_list):
                    radar_identity_path = os.path.join(spe_path, identity)
                    radar_condition_path = os.path.join(radar_identity_path, condition)
                    spe = radar_condition_path + '/' + identity + '-' + condition + '-' + angle + '.png'
                    pic = condition_path + '/' + identity + '-' + condition + '-' + angle + '.png'
                    if os.path.exists(spe) and os.path.exists(pic):
                        self.pic_list.append([pic])
                        self.spe_list.append([spe])
                        self.identity_list.append(identity)
                        self.condition_list.append(condition)
                        self.angle_list.append(angle)

        self.data_size = len(self.identity_list)
        self.label_set = sorted(list(set(self.identity_list)))
        self.condition_set = sorted(list(set(self.condition_list)))
        self.angle_set = sorted(list(set(self.angle_list)))
        _ = np.zeros((len(self.label_set),
                      len(self.condition_set),
                      len(self.angle_set))).astype('int')

        # Lookup table mapping (label, condition, angle) -> sample index,
        # used by the triplet sampler.
        self.index_dict = xarray.DataArray(
            _,
            coords={'label': self.label_set,
                    'condition': self.condition_set,
                    'angle': self.angle_set},
            dims=['label', 'condition', 'angle']
        )

        for i in range(self.data_size):
            label = self.identity_list[i]
            condition = self.condition_list[i]
            angle = self.angle_list[i]
            self.index_dict.loc[label, condition, angle] = i

    def __len__(self):
        return len(self.identity_list)

    def process_img(self, path):
        # Load the GEI as a single-channel frame and cut the empty side margins.
        frame_list = [np.reshape(
            cv2.imread(path),
            [resolution, resolution, -1])[:, :, 0]]
        num_list = list(range(len(frame_list)))
        data_dict = xarray.DataArray(
            frame_list,
            coords={'frame': num_list},
            dims=['frame', 'img_y', 'img_x'],
        )
        cut_array = data_dict[:, :, cut_padding:-cut_padding].astype('float32') / 255.0
        return cut_array

    def process_spe(self, paths):
        # Load the radar spectrogram, resize to 128x88 (w x h), and go HWC -> CHW.
        path1 = paths[0]
        img1 = Image.open(path1)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        img1 = np.array(img1.resize((128, 88), Image.LANCZOS)).astype('float32') / 255.0
        img = np.transpose(img1, (2, 0, 1))
        return img

    def __getitem__(self, item):
        data = [self.process_img(path) for path in self.pic_list[item]]
        spe_data = [self.process_spe(self.spe_list[item])]
        return data, spe_data, self.identity_list[item], \
               self.condition_list[item], self.angle_list[item]
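A minimal inspection sketch for `load_data` (assumes the dataset paths in `config.py` are valid and the PNGs are RGB):

```python
# Inspect one sample; a sketch assuming valid dataset paths in config.py.
from data import load_data

ds = load_data('train')
gei, spe, identity, condition, angle = ds[0]
print(gei[0].shape)   # (1, 128, 88): one GEI frame with side margins cut
print(spe[0].shape)   # (3, 88, 128): spectrogram resized to 88x128, CHW
print(identity, condition, angle)
```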
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from non_local_gaussian import NONLocalBlock1D   # non-local block from the original release


#### sigmoid MLP used to predict per-part attention scores ####
def conv1d(in_planes, out_planes, kernel_size, has_bias=False, **kwargs):
    return nn.Conv1d(in_planes, out_planes, kernel_size, bias=has_bias, **kwargs)


def mlp_sigmoid(in_planes, out_planes, kernel_size, **kwargs):
    return nn.Sequential(conv1d(in_planes, in_planes // 16, kernel_size, **kwargs),
                         nn.BatchNorm1d(in_planes // 16),
                         nn.LeakyReLU(inplace=True),
                         conv1d(in_planes // 16, out_planes, kernel_size, **kwargs),
                         nn.Sigmoid())


class part_score(nn.Module):
    # Squeeze-and-excitation style reweighting over the part dimension.
    def __init__(self, channel, reduction=2):
        super(part_score, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Conv1d(channel, channel // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv1d(channel // reduction, channel, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, p = x.size()
        x1 = x.permute(0, 2, 1)
        y = self.avg_pool(x1)
        y = self.fc(y).permute(0, 2, 1)
        return x + x * y.expand_as(x)


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return F.leaky_relu(x, inplace=True)


class SetBlock(nn.Module):
    # Applies a 2D block frame-wise to a [n, s, c, h, w] silhouette set.
    def __init__(self, forward_block, pooling=False):
        super(SetBlock, self).__init__()
        self.forward_block = forward_block
        self.pooling = pooling
        if pooling:
            self.pool2d = nn.MaxPool2d(2)

    def forward(self, x):
        n, s, c, h, w = x.size()
        x = self.forward_block(x.view(-1, c, h, w))
        if self.pooling:
            x = self.pool2d(x)
        _, c, h, w = x.size()
        return x.view(n, s, c, h, w)


class Conv2d(nn.Module):
    # Conv + local response norm + ReLU, with optional max pooling (radar branch).
    def __init__(self, in_channels, out_channels, kernel_size, p_size, p_stride, padding, do_pool):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.norm = nn.LocalResponseNorm(5)
        self.relu = nn.ReLU()
        self.do_pool = do_pool
        self.pool = nn.MaxPool2d(kernel_size=p_size, stride=p_stride)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.relu(x)
        if self.do_pool:
            x = self.pool(x)
        return x
class SetNet(nn.Module):

    def __init__(self, hidden_dim):
        super(SetNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_frame = None

        # GEI (camera) branch: three conv stages, pooling after stages 2 and 4.
        _set_in_channels = 1
        _set_channels = [32, 64, 128]
        self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 7, padding=3))
        self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 7, padding=3), True)
        self.set_layer3 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[1], 5, padding=2))
        self.set_layer4 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[1], 5, padding=2), True)
        self.set_layer5 = SetBlock(BasicConv2d(_set_channels[1], _set_channels[2], 3, padding=1))
        self.set_layer6 = SetBlock(BasicConv2d(_set_channels[2], _set_channels[2], 3, padding=1))

        # Radar spectrogram branch: six conv stages.
        self.channels = [3, 32, 64, 128]
        self.ks = [5, 3]
        self.p_size = 2
        self.p_stride = 2
        self.conv1 = Conv2d(self.channels[0], self.channels[1], self.ks[0], self.p_size, self.p_stride, padding=2, do_pool=False)
        self.conv2 = Conv2d(self.channels[1], self.channels[1], self.ks[0], self.p_size, self.p_stride, padding=2, do_pool=True)
        self.conv3 = Conv2d(self.channels[1], self.channels[2], self.ks[1], self.p_size, self.p_stride, padding=1, do_pool=False)
        self.conv4 = Conv2d(self.channels[2], self.channels[2], self.ks[1], self.p_size, self.p_stride, padding=1, do_pool=True)
        self.conv5 = Conv2d(self.channels[2], self.channels[3], self.ks[1], self.p_size, self.p_stride, padding=1, do_pool=False)
        self.conv6 = Conv2d(self.channels[3], self.channels[3], self.ks[1], self.p_size, self.p_stride, padding=0, do_pool=False)
        self.relation = NONLocalBlock1D(128)
        # Attention scores per feature map (camera: score_gei*, radar: score*).
        self.score = mlp_sigmoid(128, 128, 1, groups=1)
        self.score1 = mlp_sigmoid(64, 64, 1, groups=1)
        self.score2 = mlp_sigmoid(32, 32, 1, groups=1)
        self.score_gei = mlp_sigmoid(128, 1, 1, groups=1)
        self.score_gei1 = mlp_sigmoid(64, 1, 1, groups=1)
        self.score_gei2 = mlp_sigmoid(32, 1, 1, groups=1)
        # Radar part counts: 15 bins plus one attention-pooled global bin -> 16;
        # the shallower radar layers use 16 bins plus the global bin -> 17.
        self.radar_bin_num = [15]
        self.radar_bin_num1 = [16]
        self.radar_fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(sum([16]), 128, hidden_dim))
        )
        self.radar_fc_bin1 = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(sum([17]), 32, hidden_dim))
        )
        self.radar_fc_bin2 = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(sum([17]), 64, hidden_dim))
        )

        # Camera part counts: 16 horizontal bins per feature level.
        self.bin_num = [16]
        self.bin_num1 = [16]
        self.fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(sum([16]), 128, hidden_dim))
        )
        self.fc_bin1 = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(sum(self.bin_num1), 32, hidden_dim))
        )
        self.fc_bin2 = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(sum(self.bin_num), 64, hidden_dim))
        )

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal_(m.weight.data, 1.0, 0.02)   # in-place normal_; nn.init.normal is deprecated
                nn.init.constant_(m.bias.data, 0.0)
    def frame_max(self, x):
        # Max over the frame dimension; batch_frame supports variable-length sets.
        if self.batch_frame is None:
            return torch.max(x, 1)
        else:
            _tmp = [
                torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return max_list, arg_max_list

    def frame_median(self, x):
        if self.batch_frame is None:
            return torch.median(x, 1)
        else:
            _tmp = [
                torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
            ]
            median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return median_list, arg_median_list

    def forward(self, x, y):
        # x: GEI set [n, s, h, w]; y: radar spectrogram [n, 3, h, w].
        x = x.unsqueeze(2)

        x = self.set_layer1(x)
        x = self.set_layer2(x)
        mid1 = self.frame_max(x)[0]   # shallow camera features

        x = self.set_layer3(x)
        x = self.set_layer4(x)
        mid2 = self.frame_max(x)[0]   # mid-level camera features

        x = self.set_layer5(x)
        x = self.set_layer6(x)
        x = self.frame_max(x)[0]
        gl = x

        final_feature = []

        # Deep camera features: split into horizontal bins, pool, reweight, project.
        n, c, h, w = gl.size()
        feature = []
        for num_bin in self.bin_num:
            z = x.view(n, c, num_bin, -1)
            z = z.mean(3) + z.max(3)[0]
            pred_score = self.score_gei(z)
            long = z.mul(pred_score)
            feature.append(long)
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()
        feature = feature.matmul(self.fc_bin)
        feature = feature.permute(1, 0, 2).contiguous()
        final_feature.append(feature)

        # Shallow camera features.
        n, c, h, w = mid1.size()
        feature1 = []
        for num_bin in self.bin_num1:
            z = mid1.view(n, c, 16, -1)
            z = z.mean(3) + z.max(3)[0]
            pred_score = self.score_gei2(z)
            long = z.mul(pred_score)
            feature1.append(long)
        feature1 = torch.cat(feature1, 2).permute(2, 0, 1).contiguous()
        feature1 = feature1.matmul(self.fc_bin1)
        feature1 = feature1.permute(1, 0, 2).contiguous()
        final_feature.append(feature1)

        # Mid-level camera features.
        n, c, h, w = mid2.size()
        feature2 = []
        for num_bin in self.bin_num:
            z = mid2.view(n, c, num_bin, -1)
            z = z.mean(3) + z.max(3)[0]
            pred_score = self.score_gei1(z)
            long = z.mul(pred_score)
            feature2.append(long)
        feature2 = torch.cat(feature2, 2).permute(2, 0, 1).contiguous()
        feature2 = feature2.matmul(self.fc_bin2)
        feature2 = feature2.permute(1, 0, 2).contiguous()
        final_feature.append(feature2)
        # Radar branch: deep features plus two intermediate levels, each with
        # an extra attention-pooled global bin appended.
        y = self.conv1(y)
        y = self.conv2(y)
        radar_multi1 = y
        y = self.conv3(y)
        y = self.conv4(y)
        radar_multi2 = y
        y = self.conv5(y)
        y = self.conv6(y)

        n, c, h, w = y.size()
        radar_feature = []
        for num_bin in self.radar_bin_num:
            z = y.view(n, c, -1, num_bin)
            z = z.mean(2) + z.max(2)[0]
            pred_score = self.score(z)
            long = z.mul(pred_score).sum(-1).div(pred_score.sum(-1))   # attention-weighted global bin
            long = long.unsqueeze(2)
            z1 = torch.cat((z, long), 2)
            # z = self.relation(z)   # optional non-local refinement (unused)
            radar_feature.append(z1)
        radar_feature = torch.cat(radar_feature, 2).permute(2, 0, 1).contiguous()
        radar_feature = radar_feature.matmul(self.radar_fc_bin)
        radar_feature = radar_feature.permute(1, 0, 2).contiguous()
        final_feature.append(radar_feature)

        # Radar multi-layer feature 1 (shallow).
        n, c, h, w = radar_multi1.size()
        radar_feature1 = []
        for num_bin in self.radar_bin_num:
            z = radar_multi1.view(n, c, -1, 16)
            z = z.mean(2) + z.max(2)[0]
            pred_score = self.score2(z)
            long = z.mul(pred_score).sum(-1).div(pred_score.sum(-1))
            long = long.unsqueeze(2)
            z = torch.cat((z, long), 2)
            radar_feature1.append(z)
        radar_feature1 = torch.cat(radar_feature1, 2).permute(2, 0, 1).contiguous()
        radar_feature1 = radar_feature1.matmul(self.radar_fc_bin1)
        radar_feature1 = radar_feature1.permute(1, 0, 2).contiguous()
        final_feature.append(radar_feature1)

        # Radar multi-layer feature 2 (mid-level).
        n, c, h, w = radar_multi2.size()
        radar_feature2 = []
        for num_bin in self.radar_bin_num:
            z = radar_multi2.view(n, c, -1, 16)
            z = z.mean(2) + z.max(2)[0]
            pred_score = self.score1(z)
            long = z.mul(pred_score).sum(-1).div(pred_score.sum(-1))
            long = long.unsqueeze(2)
            z = torch.cat((z, long), 2)
            radar_feature2.append(z)
        radar_feature2 = torch.cat(radar_feature2, 2).permute(2, 0, 1).contiguous()
        radar_feature2 = radar_feature2.matmul(self.radar_fc_bin2)
        radar_feature2 = radar_feature2.permute(1, 0, 2).contiguous()
        final_feature.append(radar_feature2)

        # [n, 98, hidden_dim]: 16+16+16 camera parts and 16+17+17 radar parts.
        final_feature = torch.cat(final_feature, 1)

        return final_feature
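A shape smoke test for `SetNet` (a sketch; it assumes `non_local_gaussian.py` from the original release is importable, and uses the input sizes produced by `data.py`):

```python
# Shape smoke test: 16+16+16 camera parts and 16+17+17 radar parts -> 98.
import torch
from model import SetNet

model = SetNet(hidden_dim=256).float().eval()
gei = torch.rand(2, 1, 128, 88)    # [n, frames, h, w], as produced by data.py
spe = torch.rand(2, 3, 88, 128)    # [n, c, h, w] spectrogram
with torch.no_grad():
    feature = model(gei, spe)
print(feature.shape)               # torch.Size([2, 98, 256])
```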
--------------------------------------------------------------------------------
/sampler.py:
--------------------------------------------------------------------------------
from torch.utils.data.sampler import Sampler
import random


class TripletSampler(Sampler):
    # Endless sampler yielding batches of batch_size[0] identities with
    # batch_size[1] samples each, for triplet-loss training.

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            sample_indices = list()
            pid_list = random.sample(
                list(self.dataset.label_set),
                self.batch_size[0])
            for pid in pid_list:
                _index = self.dataset.index_dict.loc[pid, :, :].values
                # index_dict uses 0 both as "missing" and as the first sample's
                # index, so sample 0 can never be drawn here.
                _index = _index[_index > 0].flatten().tolist()
                _index = random.choices(
                    _index,
                    k=self.batch_size[1])
                sample_indices += _index
            yield sample_indices

    def __len__(self):
        return self.dataset.data_size
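Drawing one batch of indices (a sketch; assumes the train split is available; with `batch_size = (2, 4)` a batch holds 2 identities with 4 samples each):

```python
# Draw one batch of sample indices from the triplet sampler.
from data import load_data
from sampler import TripletSampler

ds = load_data('train')
sampler = TripletSampler(ds, (2, 4))
batch_indices = next(iter(sampler))
print(len(batch_indices))   # 8 = 2 identities x 4 samples each
```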
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from datetime import datetime
import numpy as np
from config import *
import torch
from model import SetNet
from data import load_data
from utils import collate_fn
import torch.autograd as autograd
import torch.nn.functional as F

test_batch_size = 8


def cuda_dist(x, y):
    # Pairwise Euclidean distance matrix between the rows of x and y, on GPU.
    x = torch.from_numpy(x).cuda()
    y = torch.from_numpy(y).cuda()
    dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
        1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
    dist = torch.sqrt(F.relu(dist))
    return dist


def de_diag(acc, each_angle=False):
    # Mean accuracy over the 7 gallery angles that differ from the probe angle.
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / 7.0
    if not each_angle:
        result = np.mean(result)
    return result


def test():
    model = SetNet(hidden_dim=hidden_dim).float()
    model.load_state_dict(torch.load(os.path.join(
        'checkpoint', model_name,
        '{}-{:0>5}-encoder-gei.ptm'.format(save_name, test_restore_iter))))
    model.cuda()
    model.eval()
    print('Complete loading model!')

    test_source = load_data('test')
    test_loader = torch.utils.data.DataLoader(
        dataset=test_source,
        batch_size=test_batch_size,
        sampler=torch.utils.data.sampler.SequentialSampler(test_source),
        collate_fn=collate_fn,
        num_workers=8
    )

    # Extract features for the whole test split.
    feature_list = []
    label_list = []
    condition_list = []
    angle_list = []
    for _, (seq, spe, identity, condition, angle) in enumerate(test_loader):
        for i in range(len(seq)):
            seq[i] = autograd.Variable(torch.from_numpy(seq[i])).cuda().float()
        for i in range(len(spe)):
            spe[i] = autograd.Variable(torch.from_numpy(spe[i])).cuda().float()
        feature = model(*seq, *spe)
        a, b, c = feature.size()
        feature_list.append(feature.view(a, -1).data.cpu().numpy())
        label_list += identity
        condition_list += condition
        angle_list += angle

    feature_list = np.concatenate(feature_list, 0)
    label_list = np.array(label_list)
    num_angle = len(set(angle_list))
    angle_set_list = list(set(angle_list))

    print('Finish Loading data!')
    # Rank-1 accuracy for every (probe angle, gallery angle) pair per condition.
    acc_table = np.zeros([len(test_probe_condition_list), num_angle, num_angle])
    for (con, probe_condition) in enumerate(test_probe_condition_list):
        for (a1, probe_angle) in enumerate(sorted(angle_set_list)):
            for (a2, gallery_angle) in enumerate(sorted(angle_set_list)):

                gallery_mask = np.isin(condition_list, test_gallery_condition_list) & np.isin(angle_list, [gallery_angle])
                gallery_feature = feature_list[gallery_mask, :]
                gallery_label = label_list[gallery_mask]

                probe_mask = np.isin(condition_list, probe_condition) & \
                             np.isin(angle_list, [probe_angle])
                probe_feature = feature_list[probe_mask, :]
                probe_label = label_list[probe_mask]

                dist = cuda_dist(probe_feature, gallery_feature)
                idx = dist.sort(1)[1].cpu().numpy()

                acc_table[con, a1, a2] = np.round(
                    np.sum((probe_label == gallery_label[idx[:, 0]]) > 0,
                           0) * 100 / dist.shape[0], 2)

    print('===Rank-%d (Include identical-view cases)===' % (1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        np.mean(acc_table[0, :, :]),
        np.mean(acc_table[1, :, :]),
        np.mean(acc_table[2, :, :])))

    print('===Rank-%d (Exclude identical-view cases)===' % (1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        de_diag(acc_table[0, :, :]),
        de_diag(acc_table[1, :, :]),
        de_diag(acc_table[2, :, :])))

    np.set_printoptions(precision=2, floatmode='fixed')
    print('===Rank-%d of each angle (Exclude identical-view cases)===' % (1))
    print('NM:', de_diag(acc_table[0, :, :], True))
    print('BG:', de_diag(acc_table[1, :, :], True))
    print('CL:', de_diag(acc_table[2, :, :], True))

    # Multiple-view protocol: all gallery angles pooled together.
    acc_table2 = np.zeros([len(test_probe_condition_list), num_angle])
    for (con, probe_condition) in enumerate(test_probe_condition_list):
        for gallery_condition in test_gallery_condition_list:
            for (a1, probe_angle) in enumerate(sorted(angle_set_list)):
                gallery_mask = np.isin(condition_list, gallery_condition)
                gallery_feature = feature_list[gallery_mask, :]
                gallery_label = label_list[gallery_mask]

                probe_mask = np.isin(condition_list, probe_condition) & \
                             np.isin(angle_list, [probe_angle])
                probe_feature = feature_list[probe_mask, :]
                probe_label = label_list[probe_mask]
                dist = cuda_dist(probe_feature, gallery_feature)
                idx = dist.sort(1)[1].cpu().numpy()

                acc_table2[con, a1] = np.round(
                    np.sum((probe_label == gallery_label[idx[:, 0]]) > 0,
                           0) * 100 / dist.shape[0], 2)
    print('===Rank-%d (Multiple-view cases)===' % (1))
    print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
        np.mean(acc_table2[0, :]),
        np.mean(acc_table2[1, :]),
        np.mean(acc_table2[2, :])))
    np.set_printoptions(precision=2, floatmode='fixed')
    print('===Rank-%d of each angle (Multiple-view cases)===' % (1))
    print('NM:', acc_table2[0, :])
    print('BG:', acc_table2[1, :])
    print('CL:', acc_table2[2, :])


if __name__ == '__main__':
    test()
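A tiny numeric check of `de_diag`'s off-diagonal averaging; the `/7.0` matches the 8-angle protocol (identical-view pairs are excluded). The function body is replicated inline so the sketch runs without the dataset:

```python
# Check that de_diag excludes the identical-view diagonal.
import numpy as np

def de_diag(acc, each_angle=False):
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / 7.0
    return result if each_angle else np.mean(result)

acc = np.full((8, 8), 90.0)      # 90% accuracy across views...
np.fill_diagonal(acc, 100.0)     # ...100% on identical views
print(de_diag(acc))              # 90.0: the diagonal entries are excluded
```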
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from config import *
from data import load_data
from sampler import TripletSampler
import torch
import numpy as np
import torch.autograd as autograd
import torch.optim as optim
from model import SetNet
from tripletloss import TripletLoss
from datetime import datetime
import sys
from utils import collate_fn
os.environ['CUDA_VISIBLE_DEVICES'] = '1'   # os comes in via config's import
import torch.nn.functional as F


def cuda_dist(x, y):
    # Pairwise Euclidean distance matrix between the rows of x and y, on GPU.
    x = torch.from_numpy(x).cuda()
    y = torch.from_numpy(y).cuda()
    dist = torch.sum(x ** 2, 1).unsqueeze(1) + torch.sum(y ** 2, 1).unsqueeze(
        1).transpose(0, 1) - 2 * torch.matmul(x, y.transpose(0, 1))
    dist = torch.sqrt(F.relu(dist))
    return dist


def de_diag(acc, each_angle=False):
    # Mean accuracy over the 7 gallery angles that differ from the probe angle.
    result = np.sum(acc - np.diag(np.diag(acc)), 1) / 7.0
    if not each_angle:
        result = np.mean(result)
    return result


def save(model, optimizer, iteration):
    os.makedirs(os.path.join('checkpoint', model_name), exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join('checkpoint', model_name,
                            '{}-{:0>5}-encoder-gei.ptm'.format(
                                save_name, iteration)))
    torch.save(optimizer.state_dict(),
               os.path.join('checkpoint', model_name,
                            '{}-{:0>5}-optimizer-gei.ptm'.format(
                                save_name, iteration)))


# iteration: index of the checkpoint to load
def load(iteration, model, optimizer):
    model.load_state_dict(torch.load(os.path.join(
        'checkpoint', model_name,
        '{}-{:0>5}-encoder-gei.ptm'.format(save_name, iteration))))
    optimizer.load_state_dict(torch.load(os.path.join(
        'checkpoint', model_name,
        '{}-{:0>5}-optimizer-gei.ptm'.format(save_name, iteration))))


def train():
    train_source = load_data(flag='train')
    triplet_sampler = TripletSampler(train_source, batch_size)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_source,
        batch_sampler=triplet_sampler,
        collate_fn=collate_fn,
        num_workers=8
    )
    test_source = load_data('test')
    test_loader = torch.utils.data.DataLoader(
        dataset=test_source,
        batch_size=1,
        sampler=torch.utils.data.sampler.SequentialSampler(test_source),
        collate_fn=collate_fn,
        num_workers=8
    )

    model = SetNet(hidden_dim=hidden_dim).float()
    model.cuda()
    model.train()
    num_person, num_sample = batch_size
    Loss = TripletLoss(num_person * num_sample, margin).cuda()
    optimizer = optim.Adam([{'params': model.parameters()}], lr=lr)

    iteration = train_start_iteration
    log_path = './log.txt'
    if train_start_iteration != 0:
        load(iteration, model, optimizer)

    hard_loss_metric_list = []
    full_loss_metric_list = []
    full_loss_num_list = []
    dist_list = []
    _time1 = datetime.now()
    for seq, spe, identity, condition, angle in train_loader:
        iteration += 1
        optimizer.zero_grad()
        for i in range(len(seq)):
            seq[i] = autograd.Variable(torch.from_numpy(seq[i])).cuda().float()
        for i in range(len(spe)):
            spe[i] = autograd.Variable(torch.from_numpy(spe[i])).cuda().float()
        label = [train_source.label_set.index(l) for l in identity]
        label = autograd.Variable(torch.from_numpy(np.array(label))).cuda().long()

        feature = model(*seq, *spe)

        # Triplet loss over part-wise features: [n_parts, batch, dim].
        triplet_feature = feature.permute(1, 0, 2).contiguous()
        triplet_label = label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
        (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
         ) = Loss(triplet_feature, triplet_label)
        loss = full_loss_metric.mean()

        hard_loss_metric_list.append(hard_loss_metric.mean().data.cpu().numpy())
        full_loss_metric_list.append(full_loss_metric.mean().data.cpu().numpy())
        full_loss_num_list.append(full_loss_num.mean().data.cpu().numpy())
        dist_list.append(mean_dist.mean().data.cpu().numpy())

        if loss > 1e-9:
            loss.backward()
            optimizer.step()

        if iteration % 1000 == 0:
            print(datetime.now() - _time1)
            _time1 = datetime.now()

        if iteration % 100 == 0:
            print('iter {}:'.format(iteration), end='')
            print(', hard_loss_metric={0:.8f}'.format(np.mean(hard_loss_metric_list)), end='')
            print(', full_loss_metric={0:.8f}'.format(np.mean(full_loss_metric_list)), end='')
            print(', full_loss_num={0:.8f}'.format(np.mean(full_loss_num_list)), end='')
            with open(log_path, 'a') as f:
                f.write('iter {}:'.format(iteration) + ', full_loss_metric={0:.8f}'.format(np.mean(full_loss_metric_list)) + '\n')
            mean_dist = np.mean(dist_list)
            print(', mean_dist={0:.8f}'.format(mean_dist), end='')
            print(', lr=%f' % optimizer.param_groups[0]['lr'], end='\n')
            sys.stdout.flush()
            hard_loss_metric_list = []
            full_loss_metric_list = []
            full_loss_num_list = []
            dist_list = []

        if iteration % 10000 == 0:
            save(model, optimizer, iteration)
        if iteration % 10000 == 0:
            # Validate on the test split every 10k iterations.
            model.eval()
            with torch.no_grad():
                feature_list = []
                label_list = []
                condition_list = []
                angle_list = []
                for _, (seq, spe, identity, condition, angle) in enumerate(test_loader):
                    for i in range(len(seq)):
                        seq[i] = autograd.Variable(torch.from_numpy(seq[i])).cuda().float()
                    for i in range(len(spe)):
                        spe[i] = autograd.Variable(torch.from_numpy(spe[i])).cuda().float()
                    feature = model(*seq, *spe)
                    a, b, c = feature.size()
                    feature_list.append(feature.view(a, -1).data.cpu().numpy())
                    label_list += identity
                    condition_list += condition
                    angle_list += angle

                feature_list = np.concatenate(feature_list, 0)
                label_list = np.array(label_list)
                num_angle = len(set(angle_list))
                angle_set_list = list(set(angle_list))

                print('Finish Loading data!')
                acc_table = np.zeros([len(test_probe_condition_list), num_angle, num_angle])
                for (con, probe_condition) in enumerate(test_probe_condition_list):
                    for (a1, probe_angle) in enumerate(sorted(angle_set_list)):
                        for (a2, gallery_angle) in enumerate(sorted(angle_set_list)):

                            gallery_mask = np.isin(condition_list, test_gallery_condition_list) & np.isin(angle_list, [gallery_angle])
                            gallery_feature = feature_list[gallery_mask, :]
                            gallery_label = label_list[gallery_mask]

                            probe_mask = np.isin(condition_list, probe_condition) & \
                                         np.isin(angle_list, [probe_angle])
                            probe_feature = feature_list[probe_mask, :]
                            probe_label = label_list[probe_mask]

                            dist = cuda_dist(probe_feature, gallery_feature)
                            idx = dist.sort(1)[1].cpu().numpy()

                            acc_table[con, a1, a2] = np.round(
                                np.sum((probe_label == gallery_label[idx[:, 0]]) > 0,
                                       0) * 100 / dist.shape[0], 2)

                print('===Rank-%d (Include identical-view cases)===' % (1))
                print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
                    np.mean(acc_table[0, :, :]),
                    np.mean(acc_table[1, :, :]),
                    np.mean(acc_table[2, :, :])))

                print('===Rank-%d (Exclude identical-view cases)===' % (1))
                print('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
                    de_diag(acc_table[0, :, :]),
                    de_diag(acc_table[1, :, :]),
                    de_diag(acc_table[2, :, :])))

                np.set_printoptions(precision=2, floatmode='fixed')
                print('===Rank-%d of each angle (Exclude identical-view cases)===' % (1))
                print('NM:', de_diag(acc_table[0, :, :], True))
                print('BG:', de_diag(acc_table[1, :, :], True))
                print('CL:', de_diag(acc_table[2, :, :], True))
            model.train()   # switch back to training mode after validation

        if iteration == total_iter:
            break


if __name__ == '__main__':
    train()

--------------------------------------------------------------------------------
/tripletloss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    def __init__(self, batch_size, margin):
        super(TripletLoss, self).__init__()
        self.batch_size = batch_size
        self.margin = margin

    def forward(self, feature, label):
        # feature: [n, m, d] (n parts, m samples, d dims), label: [n, m]
        n, m, d = feature.size()
        # Masks over the flattened m x m pair matrix: same-label (positive)
        # and different-label (negative) pairs.
        hp_mask = (label.unsqueeze(1) == label.unsqueeze(2)).bool().view(-1)
        hn_mask = (label.unsqueeze(1) != label.unsqueeze(2)).bool().view(-1)

        dist = self.batch_dist(feature)
        mean_dist = dist.mean(1).mean(1)
        dist = dist.view(-1)
        # Hard mining: hardest positive and hardest negative per anchor.
        hard_hp_dist = torch.max(torch.masked_select(dist, hp_mask).view(n, m, -1), 2)[0]
        hard_hn_dist = torch.min(torch.masked_select(dist, hn_mask).view(n, m, -1), 2)[0]
        hard_loss_metric = F.relu(self.margin + hard_hp_dist - hard_hn_dist).view(n, -1)

        hard_loss_metric_mean = torch.mean(hard_loss_metric, 1)

        # Full mining: every positive against every negative, averaged over
        # the non-zero terms only.
        full_hp_dist = torch.masked_select(dist, hp_mask).view(n, m, -1, 1)
        full_hn_dist = torch.masked_select(dist, hn_mask).view(n, m, 1, -1)
        full_loss_metric = F.relu(self.margin + full_hp_dist - full_hn_dist).view(n, -1)

        full_loss_metric_sum = full_loss_metric.sum(1)
        full_loss_num = (full_loss_metric != 0).sum(1).float()

        full_loss_metric_mean = full_loss_metric_sum / full_loss_num
        full_loss_metric_mean[full_loss_num == 0] = 0

        return full_loss_metric_mean, hard_loss_metric_mean, mean_dist, full_loss_num

    def batch_dist(self, x):
        # Pairwise Euclidean distances within each part: [n, m, m].
        x2 = torch.sum(x ** 2, 2)
        dist = x2.unsqueeze(2) + x2.unsqueeze(2).transpose(1, 2) - 2 * torch.matmul(x, x.transpose(1, 2))
        dist = torch.sqrt(F.relu(dist))
        return dist
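Calling `TripletLoss` on dummy features (a sketch; the shapes mirror `train.py`, where features are permuted to `[n_parts, batch, dim]` and labels repeated per part):

```python
# Triplet loss on random features; shapes mirror train.py's usage.
import torch
from tripletloss import TripletLoss

loss_fn = TripletLoss(batch_size=8, margin=0.2)
feature = torch.rand(98, 8, 256)                 # [n_parts, batch, dim]
label = torch.arange(2).repeat_interleave(4)     # 2 identities x 4 samples
label = label.unsqueeze(0).repeat(98, 1)         # [n_parts, batch]
full, hard, mean_dist, num = loss_fn(feature, label)
print(full.shape, hard.shape)                    # torch.Size([98]) each
```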
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np


def collate_fn(batch):
    # Stack the per-sample GEI and spectrogram lists into per-feature numpy
    # batches; labels, conditions, and views stay as python lists.
    batch_size = len(batch)
    feature_num = len(batch[0][0])
    seqs = [batch[i][0] for i in range(batch_size)]
    spes = [batch[i][1] for i in range(batch_size)]
    label = [batch[i][2] for i in range(batch_size)]
    seq_type = [batch[i][3] for i in range(batch_size)]
    view = [batch[i][4] for i in range(batch_size)]
    batch = [seqs, spes, label, seq_type, view]
    seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
    spes = [np.asarray([spes[i][j] for i in range(batch_size)]) for j in range(feature_num)]
    batch[0] = seqs
    batch[1] = spes
    return batch
--------------------------------------------------------------------------------