├── .idea ├── CNN-RNN2016.iml ├── codeStyles │ ├── Project.xml │ └── codeStyleConfig.xml ├── dictionaries │ └── ying.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── dataprocess ├── __init__.py ├── sampler.py ├── seqtransform.py └── video_loader.py ├── dataset ├── __init__.py └── prid2011.py ├── eval ├── __init__.py ├── eva_functions.py └── evaluator.py ├── individualImage.png ├── models ├── __init__.py └── cnnrnn.py ├── prid_data.py ├── splits_prid2011.json ├── train.py └── utils.py /.idea/CNN-RNN2016.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/codeStyles/Project.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 10 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/dictionaries/ying.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | dirname 5 | dirnames 6 | identites 7 | imgs 8 | infostruct 9 | prid 10 | tracklet 11 | tracklets 12 | ying 13 | 14 | 15 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 
| 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CNN-RNN2016 2 | A PyTorch reimplementation of the paper "Recurrent Convolutional Network for Video-based Person Re-Identification". 3 | # Preparation 4 | Python 3.6 5 | PyTorch >= 0.4.0 6 | # Results on PRID-2011 7 | 8 | |版本| map| rank1 | rank5 | rank10 | rank20 | 9 | | :---: | :---: |:-------:|:---: |:------:| -------:| 10 | | 复现 | 58.8%| 49.4% | 68.5% | 83.1% | 89.9% | 11 | | 原文 | --| 70% | 90% | 95% | 97% | 12 | 13 | 14 | # Problems 15 | P.1 With the official split, the dataset is too small. 16 | 17 | train identites: 89, test identites: 89 18 | 19 | => PRID-2011 loaded 20 | 21 | |subset |# ids| # tracklets | 22 | | :---: | :---: |:-------:| 23 | | train | 89| 178 | 24 | | query | 89| 89 | 25 | | gallery | 89| 89 | 26 | | total | 178| 356 | 27 | 28 | P.2 Data Augmentation 29 | 30 | 1. Mirroring is not implemented. 31 | 2. Images are resized from (128, 64) to (256, 128), unlike the (64, 48) used in the official code. 32 | 33 | P.3 Training 34 | 35 | 1. If the batch size is set to 1, the network does not converge. 36 | 2. The dataset is too small; the way training samples are generated could be changed to enlarge 37 | the dataset, for example as in the paper 'Video Person Re-identification with 38 | Competitive Snippet-similarity Aggregation and Co-attentive Snippet Embedding'. 
def No_index(a, b):
    """Return the positions in list ``a`` whose value is different from ``b``.

    The sampler uses this to find, for one identity, the tracklets that were
    recorded by a camera other than ``b``.
    """
    assert isinstance(a, list)
    return [i for i, j in enumerate(a) if j != b]


class RandomPairSampler(Sampler):
    """Yield alternating positive and negative tracklet pairs.

    ``data_source`` is a sequence of ``(img_paths, pid, camid)`` tuples.
    Iterating produces ``2 * len(data_source)`` items of the form
    ``(index_a, index_b, [label, pid_a, pid_b])``:

    * even positions are positive pairs (``label == 1``): a tracklet and the
      tracklet of the same identity seen by another camera;
    * odd positions are negative pairs (``label == -1``): tracklets of two
      *different* identities, the second one taken from the camera opposite
      to the randomly drawn tracklet of that identity.
    """

    def __init__(self, data_source):
        self.data_source = data_source
        self.index_pid = defaultdict(int)    # tracklet index -> pid
        self.pid_cam = defaultdict(list)     # pid -> camera ids, in tracklet order
        self.pid_index = defaultdict(list)   # pid -> tracklet indices, same order
        self.num_samples = len(self.data_source)

        for index, (_, pid, cam) in enumerate(self.data_source):
            self.index_pid[index] = pid
            self.pid_cam[pid].append(cam)
            self.pid_index[pid].append(index)

    def __len__(self):
        # One positive and one negative pair per tracklet.
        return self.num_samples * 2

    def _cross_camera_index(self, index):
        """Return a tracklet of the same identity as ``index`` from another camera.

        Raises:
            ValueError: the identity has no tracklet from a different camera.
        """
        _, _, cam = self.data_source[index]
        pid = self.index_pid[index]
        candidates = No_index(self.pid_cam[pid], cam)
        if not candidates:
            # The original code printed and carried on with an unbound or
            # stale index; fail loudly instead of building a wrong pair.
            raise ValueError(
                "pid {} has no tracklet from a camera other than {} "
                "(cameras: {})".format(pid, cam, self.pid_cam[pid]))
        return self.pid_index[pid][np.random.choice(candidates)]

    def __iter__(self):
        if self.num_samples == 0:
            return iter([])
        if len(self.pid_index) < 2:
            raise ValueError("negative pairs need at least two identities")

        indices = torch.randperm(self.num_samples)
        ret = []
        for i in range(2 * self.num_samples):
            if i % 2 == 0:  # positive pair
                j = int(indices[i // 2])
                pid_j = self.index_pid[j]
                ret.append((j, self._cross_camera_index(j), [1, pid_j, pid_j]))
            else:  # negative pair
                # Redraw until the two tracklets belong to different people;
                # two distinct indices alone do not guarantee that.
                while True:
                    p_rand_id = torch.randperm(self.num_samples)
                    a, b = int(p_rand_id[0]), int(p_rand_id[1])
                    pid_a, pid_b = self.index_pid[a], self.index_pid[b]
                    if pid_a != pid_b:
                        break
                ret.append((a, self._cross_camera_index(b), [-1, pid_a, pid_b]))
        return iter(ret)
class RectScale(object):
    """Bring every frame of every modality to a fixed ``(height, width)``.

    Frames that already have the requested size are passed through untouched;
    all others are resized with the configured PIL interpolation filter.
    """

    def __init__(self, height, width, interpolation=Image.BILINEAR):
        self.height, self.width = height, width
        self.interpolation = interpolation

    def __call__(self, seqs):
        # seqs is a list of modalities, each a list of PIL images.
        n_modal = len(seqs)
        n_frame = len(seqs[0])
        target = (self.width, self.height)  # PIL sizes are (width, height)
        out = [[[] for _ in range(n_frame)] for _ in range(n_modal)]

        for m, clip in enumerate(seqs):
            for f, image in enumerate(clip):
                if tuple(image.size) == target:
                    out[m][f] = image
                else:
                    out[m][f] = image.resize(target, self.interpolation)

        return out
class RandomSizedEarser(object):
    """Random-erasing augmentation applied independently to each frame.

    With probability ``p`` a rectangle of random size and aspect ratio is
    filled with one random RGB colour.

    Args:
        sl: lower bound of the erased area, as a fraction of the frame area.
        sh: upper bound of the erased area, as a fraction of the frame area.
        asratio: lower bound of the rectangle aspect ratio; the upper bound
            is ``1 / asratio``.
        p: probability of erasing a given frame.
    """

    # Upper bound on rejection-sampling rounds per frame. The original loop
    # was unbounded and could spin forever for parameters that never fit.
    max_attempts = 100

    def __init__(self, sl=0.02, sh=0.2, asratio=0.3, p=0.5):
        self.sl = sl
        self.sh = sh
        self.asratio = asratio
        self.p = p

    def __call__(self, seqs):
        modallen = len(seqs)
        framelen = len(seqs[0])
        new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
        for modal_ind, modal in enumerate(seqs):
            for frame_ind, frame in enumerate(modal):
                if random.uniform(0.0, 1.0) <= self.p:
                    frame = self._erase(frame)
                new_seqs[modal_ind][frame_ind] = frame

        return new_seqs

    def _erase(self, frame):
        """Return a copy of ``frame`` with one random rectangle painted over.

        The input frame is returned unchanged if no rectangle fits within
        ``max_attempts`` draws.
        """
        W, H = frame.size
        area = H * W

        for _ in range(self.max_attempts):
            Se = random.uniform(self.sl, self.sh) * area
            re = random.uniform(self.asratio, 1 / self.asratio)
            He = np.sqrt(Se * re)
            We = np.sqrt(Se / re)
            xe = random.uniform(0, W - We)
            ye = random.uniform(0, H - He)
            if xe + We <= W and ye + He <= H and xe > 0 and ye > 0:
                x1 = int(np.ceil(xe))
                y1 = int(np.ceil(ye))
                x2 = int(np.floor(x1 + We))
                y2 = int(np.floor(y1 + He))
                colour = (random.randint(0, 255),
                          random.randint(0, 255),
                          random.randint(0, 255))
                patch = Image.new('RGB', (x2 - x1, y2 - y1), colour)
                # Work on a copy so the caller's image is not modified, and
                # paste at the rectangle's top-left corner (the original
                # passed the patch *size* as the paste position).
                erased = frame.copy()
                erased.paste(patch, (x1, y1))
                return erased

        return frame
class ToTensor(object):
    """Convert every PIL frame of a sequence set into a ``torch`` tensor.

    ``seqs`` is a list of modalities, each a list of PIL images. The result
    has the same nesting; every frame becomes a ``(C, H, W)`` tensor.
    Integer modes (``'I'``, ``'I;16'``) keep their integer values, all other
    modes are read as bytes and scaled to ``[0, 1]`` floats.

    The image mode is taken from the first frame of the first modality, as
    in the original implementation; the size is read from each frame so
    modalities with different resolutions are handled.
    """

    def __call__(self, seqs):
        modallen = len(seqs)
        framelen = len(seqs[0])
        new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)]
        mode = seqs[0][0].mode

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if mode == 'YCbCr':
            nchannel = 3
        elif mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(mode)

        for modal_ind, modal in enumerate(seqs):
            for frame_ind, frame in enumerate(modal):
                new_seqs[modal_ind][frame_ind] = self._frame_to_tensor(frame, mode, nchannel)

        return new_seqs

    @staticmethod
    def _frame_to_tensor(frame, mode, nchannel):
        """Convert one frame to a ``(C, H, W)`` tensor."""
        width, height = frame.size
        if mode == 'I':
            # np.array copies by default; ``copy=False`` raises on NumPy 2
            # whenever a cast or copy is unavoidable.
            img = torch.from_numpy(np.array(frame, dtype=np.int32))
        elif mode == 'I;16':
            img = torch.from_numpy(np.array(frame, dtype=np.int16))
        else:
            # ``.copy()`` makes the buffer writable, which from_numpy expects.
            img = torch.from_numpy(np.frombuffer(frame.tobytes(), dtype=np.uint8).copy())

        img = img.view(height, width, nchannel)                 # (H, W, C)
        img = img.transpose(0, 1).transpose(0, 2).contiguous()  # (C, H, W)
        if mode in ('I', 'I;16'):
            return img
        return img.float().div(255)
class VideoDataset(Dataset):
    """Serve tracklet clips for training ('random') or testing ('dense').

    Args:
        dataset: sequence of ``(img_paths, pid, camid)`` tracklets.
        seq_len: number of frames per clip.
        sample: ``'random'`` or ``'dense'`` (see ``__getitem__``).
        transform: callable applied to the loaded data, or ``None``.
        flow_root: directory holding the pre-computed optical-flow frames;
            defaults to ``default_flow_root``.
    """
    sample_methods = ['random', 'dense']
    # Location used when no ``flow_root`` is given (the previously hard-coded path).
    default_flow_root = '/home/ying/Desktop/video_reid_mars/data/prid2011sequence/raw/prid2011flow/prid2011flow'

    def __init__(self, dataset, seq_len=16, sample='random', transform=None, flow_root=None):
        self.dataset = dataset
        self.seq_len = seq_len
        self.sample = sample
        self.transform = transform
        self.useOpticalFlow = 1
        self.flow_root = self.default_flow_root if flow_root is None else flow_root

    def __len__(self):
        return len(self.dataset)

    def getOfPath(self, rgb_path):
        """Map an RGB frame path to its optical-flow frame path.

        The last three path components of ``rgb_path`` are re-rooted under
        ``self.flow_root``.
        """
        fname_list = rgb_path.split('/')
        return osp.join(self.flow_root, fname_list[-3], fname_list[-2], fname_list[-1])

    @staticmethod
    def _read_rgb(path):
        """Load an image as RGB and release the file handle right away."""
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _pad_by_cycling(indices, length):
        """Repeat ``indices`` from the start until ``length`` entries exist."""
        base = list(indices)
        padded = list(base)
        position = 0
        while base and len(padded) < length:
            padded.append(base[position % len(base)])
            position += 1
        return padded

    def _random_clip_indices(self, num_frames):
        """Pick ``seq_len`` consecutive frame indices at a random offset."""
        rand_end = max(0, num_frames - self.seq_len - 1)
        begin = random.randint(0, rand_end)
        end = min(begin + self.seq_len, num_frames)
        return self._pad_by_cycling(range(begin, end), self.seq_len)

    def _dense_clip_indices(self, num_frames):
        """Split ``range(num_frames)`` into clips of exactly ``seq_len`` indices.

        The final, possibly short, clip is padded by cycling. The original
        padding test used ``>`` and produced ``seq_len + 1`` entries, which
        made the clips impossible to stack.
        """
        clips = []
        cur_index = 0
        while num_frames - cur_index > self.seq_len:
            clips.append(list(range(cur_index, cur_index + self.seq_len)))
            cur_index += self.seq_len
        clips.append(self._pad_by_cycling(range(cur_index, num_frames), self.seq_len))
        return clips

    def _load_random_clip(self, img_paths):
        """Load one random clip as ``[rgb_frames, flow_frames]`` and transform it."""
        imgseq = []
        flowseq = []
        for index in self._random_clip_indices(len(img_paths)):
            img_path = img_paths[index]
            # Both streams are forced to RGB; the second sequence used to
            # skip this conversion.
            imgseq.append(self._read_rgb(img_path))
            flowseq.append(self._read_rgb(self.getOfPath(img_path)))
        seq = [imgseq, flowseq]
        if self.transform is not None:
            seq = self.transform(seq)
        return seq

    def __getitem__(self, item):
        if self.sample == 'random':
            # Training: ``item`` is (index0, index1, target) as produced by
            # the pair sampler. seq_len consecutive frames are drawn from
            # each tracklet; shorter tracklets are padded by repetition.
            item0, item1, target = item
            img0_paths, _, _ = self.dataset[item0]
            img1_paths, _, _ = self.dataset[item1]
            img_tensor0 = self._load_random_clip(img0_paths)
            img_tensor1 = self._load_random_clip(img1_paths)
            return img_tensor0, img_tensor1, target
        elif self.sample == 'dense':
            # Testing: every frame of the tracklet is used, cut into clips
            # of seq_len frames; batch_size needs to be set to 1.
            img_paths, pid, camid = self.dataset[item]
            imgs_list = []
            for indices in self._dense_clip_indices(len(img_paths)):
                imgseq = []
                for index in indices:
                    img = self._read_rgb(img_paths[int(index)])
                    # NOTE(review): the transform is applied per image here
                    # but per sequence in the 'random' branch -- confirm the
                    # test-time transform accepts a single PIL image.
                    if self.transform is not None:
                        img = self.transform(img)
                    img = img.unsqueeze(0)
                    imgseq.append(img)
                imgseq = to_torch(imgseq)
                imgseq = torch.cat(imgseq, dim=0)
                imgs_list.append(imgseq)
            imgs_array = torch.stack(tuple(imgs_list))
            return imgs_array, pid, camid
        else:
            raise KeyError("unknown sample method: {}.".format(self.sample))
class infostruct(object):
    """Plain attribute container for per-subset metadata (pid, camid, tranum)."""
    pass


class PRID(object):
    """PRID-2011 multi-shot dataset split into train / query / gallery.

    code mainly from https://github.com/KaiyangZhou/deep-person-reid

    Args:
        split_id: index of the split to use from ``splits_prid2011.json``.
        min_seq_len: accepted for compatibility; not used by this class.
        root: dataset directory; defaults to the class-level ``root``.

    Raises:
        RuntimeError: the dataset directory or a tracklet's frames are missing.
        ValueError: ``split_id`` is outside the available splits.
    """
    root = '/home/ying/Desktop/video_reid_mars/data/prid2011sequence/raw/prid_2011'
    split_path = osp.join(root, 'splits_prid2011.json')
    cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a')
    cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b')

    def __init__(self, split_id=0, min_seq_len=0, root=None):
        if root is not None:
            # Re-derive every path from the caller-supplied root.
            self.root = root
            self.split_path = osp.join(root, 'splits_prid2011.json')
            self.cam_a_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_a')
            self.cam_b_path = osp.join(root, 'prid_2011', 'multi_shot', 'cam_b')

        self._check_before_run()

        splits = read_json(self.split_path)
        if split_id >= len(splits):
            # Report the number of splits, not the length of the path string.
            raise ValueError("split id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits) - 1))
        split = splits[split_id]
        train_split, test_split = split['train'], split['test']
        print("# train identites: {}, # test identites: {}".format(len(train_split), len(test_split)))

        # Train uses both cameras; query is camera A and gallery is camera B
        # of the test identities.
        train, num_train_tracklets, num_train_pids, num_imgs_train = \
            self._process_data(train_split, cam1=True, cam2=True)
        query, num_query_tracklets, num_query_pids, num_imgs_query, query_pid, query_camid = \
            self._process_data2(test_split, cam1=True, cam2=False)
        gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery, gallery_pid, gallery_camid = \
            self._process_data2(test_split, cam1=False, cam2=True)

        num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery
        min_num = np.min(num_imgs_per_tracklet)
        max_num = np.max(num_imgs_per_tracklet)
        avg_num = np.mean(num_imgs_per_tracklet)

        # Query and gallery hold the same test identities, so they are
        # counted once in the total.
        num_total_pids = num_train_pids + num_query_pids
        num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets

        self.train = train
        self.query = query
        self.gallery = gallery
        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

        self.queryinfo = infostruct()
        self.queryinfo.pid = query_pid
        self.queryinfo.camid = query_camid
        self.queryinfo.tranum = num_imgs_query

        self.galleryinfo = infostruct()
        self.galleryinfo.pid = gallery_pid
        self.galleryinfo.camid = gallery_camid
        self.galleryinfo.tranum = num_imgs_gallery

        print("=> PRID-2011 loaded")
        print("Dataset statistics:")
        print(" ------------------------------")
        print(" subset | # ids | # tracklets")
        print(" ------------------------------")
        print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets))
        print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets))
        print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets))
        print(" ------------------------------")
        print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets))
        print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
        print(" ------------------------------")

    def _check_before_run(self):
        """Raise ``RuntimeError`` if the dataset directory does not exist."""
        if not osp.exists(self.root):
            raise RuntimeError("'{}' is not available".format(self.root))

    def _collect_tracklets(self, dirnames, cam1, cam2):
        """Build ``(img_paths, pid, camid)`` tuples for the requested cameras.

        Identities are numbered by their position in ``dirnames``; camera A
        is camid 0 and camera B is camid 1.
        """
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}
        cameras = []
        if cam1:
            cameras.append((self.cam_a_path, 0))
        if cam2:
            cameras.append((self.cam_b_path, 1))

        tracklets = []
        for dirname in dirnames:
            pid = dirname2pid[dirname]
            for cam_path, camid in cameras:
                person_dir = osp.join(cam_path, dirname)
                # glob order is unspecified; sort so that consecutive
                # indices are consecutive frames (assumes frame file names
                # sort in temporal order -- TODO confirm).
                img_names = sorted(glob.glob(osp.join(person_dir, '*.png')))
                if not img_names:
                    raise RuntimeError("no '*.png' frames found in '{}'".format(person_dir))
                tracklets.append((tuple(img_names), pid, camid))
        return tracklets

    def _process_data(self, dirnames, cam1=True, cam2=True):
        """Return ``(tracklets, num_tracklets, num_pids, num_imgs_per_tracklet)``."""
        tracklets = self._collect_tracklets(dirnames, cam1, cam2)
        num_imgs_per_tracklet = [len(img_names) for img_names, _, _ in tracklets]
        return tracklets, len(tracklets), len(dirnames), num_imgs_per_tracklet

    def _process_data2(self, dirnames, cam1=True, cam2=True):
        """Like ``_process_data`` but also return the pid and camid lists."""
        tracklets, num_tracklets, num_pid, num_imgs_per_tracklet = \
            self._process_data(dirnames, cam1=cam1, cam2=cam2)
        pid_list = [pid for _, pid, _ in tracklets]
        camid_list = [camid for _, _, camid in tracklets]
        return tracklets, num_tracklets, num_pid, num_imgs_per_tracklet, pid_list, camid_list
num_valid_q = 0. 22 | for q_idx in range(num_q): 23 | # get query pid and camid 24 | q_pid = q_pids[q_idx] 25 | q_camid = q_camids[q_idx] 26 | # remove gallery samples that have the same pid and camid with query 27 | order = indices[q_idx] 28 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 29 | keep = np.invert(remove) 30 | 31 | # compute cmc curve 32 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 33 | 34 | if not np.any(orig_cmc): 35 | # this condition is true when query identity does not appear in gallery 36 | continue 37 | 38 | cmc = orig_cmc.cumsum() 39 | cmc[cmc > 1] = 1 40 | all_cmc.append(cmc[:max_rank]) 41 | num_valid_q += 1. 42 | 43 | # compute average precision 44 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 45 | num_rel = orig_cmc.sum() 46 | tmp_cmc = orig_cmc.cumsum() 47 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 48 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 49 | AP = tmp_cmc.sum() / num_rel 50 | 51 | all_AP.append(AP) 52 | 53 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 54 | 55 | all_cmc = np.asarray(all_cmc).astype(np.float32) 56 | all_cmc = all_cmc.sum(0) / num_valid_q 57 | mAP = np.mean(all_AP) 58 | 59 | return all_cmc, mAP 60 | 61 | 62 | def accuracy(output, target, topk=(1,)): 63 | output, target = to_torch(output), to_torch(target) 64 | maxk = max(topk) 65 | batch_size = target.size(0) 66 | 67 | _, pred = output.topk(maxk, 1, True, True) 68 | pred = pred.t() 69 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 70 | 71 | ret = [] 72 | for k in topk: 73 | correct_k = correct[:k].view(-1).float().sum(0) 74 | ret.append(correct_k.mul_(1. 
def pairwise_distance_tensor(query_x, gallery_x):
    """Return the squared Euclidean distance between every query/gallery pair.

    Each input is flattened to one row per sample (first dimension kept), so
    both must flatten to the same feature length.

    Args:
        query_x: tensor whose first dimension indexes the m query samples.
        gallery_x: tensor whose first dimension indexes the n gallery samples.

    Returns:
        An (m, n) tensor where entry (i, j) is ||x_i||^2 + ||y_j||^2 - 2 x_i.y_j.
    """
    m, n = query_x.size(0), gallery_x.size(0)
    # reshape (not view) so non-contiguous inputs are accepted too.
    x = query_x.reshape(m, -1)
    y = gallery_x.reshape(n, -1)
    # The addition materialises a fresh (m, n) tensor, so the in-place update
    # below never writes through the expanded views.
    dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword form: the positional (beta, alpha, mat1, mat2) overload is
    # deprecated and rejected by recent PyTorch releases.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)

    return dist
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 58 | imgs = imgs.to(device) 59 | # flows = flows.to(device) 60 | with torch.no_grad(): 61 | out_feat, out_raw = self.cnn_model(imgs) 62 | allfeatures = out_feat.view(n, -1) # torch.Size([8, 128]) 63 | # allfeatures_raw = out_raw.view(n, -1) # torch.Size([8, 128]) 64 | allfeatures = torch.mean(allfeatures, 0).data.cpu() # 汇总一个序列特征,取平均 65 | # allfeatures_raw = torch.mean(allfeatures_raw, 0).data 66 | qf.append(allfeatures) 67 | # qf_raw.append(allfeatures_raw) 68 | qf = torch.stack(qf) 69 | # qf_raw = torch.stack(allfeatures_raw) 70 | 71 | print("Extracted features for query/gallery set, obtained {}-by-{} matrix" 72 | .format(qf.size(0), qf.size(1))) 73 | return qf 74 | 75 | def evaluate(self, query_loader, gallery_loader, queryinfo, galleryinfo): 76 | # 1 77 | self.cnn_model.eval() 78 | 79 | querypid = queryinfo.pid # 100 ge id : [74, 20, 90, 151, 1, 69, 84, 149, 5, 111, -1, 154, ...] 80 | querycamid = queryinfo.camid # 00000000000 81 | 82 | gallerypid = galleryinfo.pid # : [74, 20, 90, 151, 1, 69, 84, 149, 5, 111, -1, 154, ...] 
def creat(name, *args, **kwargs):
    """Instantiate the model registered under ``name``.

    Args:
        name: key into the module-level ``__factory`` registry.
        *args: positional arguments forwarded to the model constructor.
        **kwargs: keyword arguments forwarded to the model constructor.

    Returns:
        The object returned by calling the registered constructor.

    Raises:
        KeyError: if ``name`` is not a registered model.
    """
    if name not in __factory:
        # One formatted message; passing two arguments made the exception
        # text render as a tuple.
        raise KeyError("unknown model: {}".format(name))
    return __factory[name](*args, **kwargs)
nFilter3, num_person_train, dropout=0.5, num_features=0, seq_len=0, batch=0): 21 | super(Net, self).__init__() 22 | self.batch = batch 23 | self.seq_len = seq_len 24 | self.num_person_train = num_person_train 25 | self.dropout = dropout # 随机失活的概率,0-1 26 | self.num_features = num_features # 输出的特征维度 128 27 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 28 | 29 | self.nFilters = [nFilter1, nFilter2, nFilter3] # 初始化每一层的卷积核个数 30 | self.filter_size = [5, 5, 5] # 卷积核尺寸 31 | 32 | self.poolsize = [2, 2, 2] # 最大池化的尺寸 33 | self.stepsize = [2, 2, 2] # 池化步长 34 | self.padDim = 4 # 零填充 35 | self.input_channel = 5 # 3img + 2optical flow 36 | 37 | # 构建卷积层,nn。Conv2d(输入通道,卷积核的个数,卷积核尺寸,步长,零填充) 38 | self.conv1 = nn.Conv2d(self.input_channel, self.nFilters[0], self.filter_size[0], stride=1, padding=self.padDim) 39 | self.conv2 = nn.Conv2d(self.nFilters[0], self.nFilters[1], self.filter_size[1], stride=1, padding=self.padDim) 40 | self.conv3 = nn.Conv2d(self.nFilters[1], self.nFilters[2], self.filter_size[2], stride=1, padding=self.padDim) 41 | 42 | # 构建最大池化层 43 | self.pooling1 = nn.MaxPool2d(self.poolsize[0], self.stepsize[0]) 44 | self.pooling2 = nn.MaxPool2d(self.poolsize[1], self.stepsize[1]) 45 | self.pooling3 = nn.MaxPool2d(self.poolsize[2], self.stepsize[2]) 46 | 47 | # tanh激活函数 48 | self.tanh = nn.Tanh() 49 | 50 | # FC层 51 | n_fully_connected = 21280 # 根据图片尺寸修改 52 | 53 | self.seq2 = nn.Sequential( 54 | nn.Dropout(self.dropout), 55 | nn.Linear(n_fully_connected, self.num_features) 56 | ) 57 | 58 | # rnn层 59 | self.rnn = nn.RNN(self.num_features, self.num_features) 60 | self.hid_weight = nn.Parameter( 61 | nn.init.xavier_uniform_( 62 | torch.Tensor(1, self.seq_len * self.batch, self.num_features).to(self.device), gain=np.sqrt(2.0) 63 | ), requires_grad=True, 64 | ) 65 | 66 | # final full connectlayer 67 | self.final_FC = nn.Linear(self.num_features, self.num_person_train) 68 | 69 | def build_net(self, input1, input2): 70 | seq1 = nn.Sequential( 71 | 
class Criterion(nn.Module):
    """Combined loss: per-pair hinge embedding loss plus two cross-entropy terms.

    Args:
        hinge_margin: margin passed to ``nn.HingeEmbeddingLoss``.
    """

    def __init__(self, hinge_margin=2):
        super(Criterion, self).__init__()
        self.hinge_margin = hinge_margin
        # Kept for backward compatibility with code that reads this attribute.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def forward(self, feature_p, feature_g, identity_p, identity_g, target):
        """Return the scalar training loss for one batch of sequence pairs.

        Args:
            feature_p: features of the first sequence of every pair.
            feature_g: features of the second sequence of every pair.
            identity_p: classification logits for the first sequences.
            identity_g: classification logits for the second sequences.
            target: indexable of three tensors: ``target[0]`` is the hinge
                label of each pair, ``target[1]`` / ``target[2]`` are the
                class labels matching ``identity_p`` / ``identity_g``.

        Returns:
            0-dim tensor: mean over pairs of
            (hinge loss + cross-entropy(p) + cross-entropy(g)).
        """
        # Euclidean distance between the two features of every pair.
        dist = nn.PairwiseDistance(p=2)
        pair_dist = dist(feature_p, feature_g)

        # 1. Hinge embedding loss, kept per pair. ``reduction='none'``
        #    replaces the deprecated ``reduce=False``.
        hing = nn.HingeEmbeddingLoss(margin=self.hinge_margin, reduction='none')
        # Labels follow the device of the tensors they are compared with
        # rather than a globally guessed device.
        label0 = target[0].to(pair_dist.device)
        hing_loss = hing(pair_dist, label0)

        # 2. Cross-entropy on both identity predictions (batch means).
        nll = nn.CrossEntropyLoss()
        label1 = target[1].to(identity_p.device)
        label2 = target[2].to(identity_g.device)
        loss_p = nll(identity_p, label1)
        loss_g = nll(identity_g, label2)

        # 3. The scalar cross-entropy terms broadcast over the per-pair hinge
        #    losses; the result is then averaged over pairs.
        total_loss = hing_loss + loss_p + loss_g
        mean_loss = torch.mean(total_loss)

        return mean_loss
def main(args):
    """Train the CNN-RNN model, or evaluate a saved checkpoint.

    Args:
        args: parsed command-line namespace. Fields read here: seed,
            evaluate, logs_dir, dataset, split, batch_size, seq_len, seq_srd,
            workers, a1, features, hingeMargin, lr1, momentum, weight_decay,
            start_epoch, epochs.
    """
    # 1. Seed every random number generator and pick the device.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # seed all GPUs
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 2. Redirect stdout to the log file for the selected mode.
    log_name = 'log_test.txt' if args.evaluate == 1 else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.logs_dir, log_name))

    # 3. Dataset and data loaders.
    dataset, numclasses, train_loader, query_loader, gallery_loader = getdata(
        args.dataset, args.split, args.batch_size, args.seq_len, args.seq_srd, args.workers)

    # 4. Network, loss and optimizer.
    cnn_rnn_model = models.creat(
        args.a1, 16, 32, 32, numclasses,
        num_features=args.features, seq_len=args.seq_len, batch=args.batch_size).to(device)
    criterion = Criterion(args.hingeMargin).to(device)
    optimizer = optim.SGD(cnn_rnn_model.parameters(), lr=args.lr1,
                          momentum=args.momentum, weight_decay=args.weight_decay)

    # 5. Trainer and evaluator share the same model instance.
    trainer = SEQTrainer(cnn_rnn_model, criterion)
    evaluator = Evaluator(cnn_rnn_model)
    best_top1 = 0

    # 6. Test mode: load the saved weights and evaluate once.
    if args.evaluate == 1:
        # NOTE(review): presumably save_checkpoint writes this "best" file
        # when is_best is true -- confirm in utils.
        checkpoint = load_checkpoint(osp.join(args.logs_dir, 'cnn_rnn_best.pth.tar'))
        cnn_rnn_model.load_state_dict(checkpoint['state_dict'])
        evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
        return

    # 6. Train mode.
    for epoch in range(args.start_epoch, args.epochs):
        # Evaluator.evaluate() switches the model to eval mode, so restore
        # train mode at the start of every epoch (dropout must be active).
        cnn_rnn_model.train()
        trainer.train(epoch, train_loader, optimizer)

        # Evaluate every 400 epochs and after the final epoch.
        if (epoch + 1) % 400 == 0 or (epoch + 1) == args.epochs:
            top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
            is_best = top1 > best_top1
            if is_best:
                best_top1 = top1

            save_checkpoint({
                'state_dict': cnn_rnn_model.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            }, is_best, fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))
parser.add_argument('--batch-size', type=int, default=8, help='depend on your device') 124 | parser.add_argument('--workers', type=int, default=4) 125 | parser.add_argument('--seq_len', type=int, default=16, help='the number of images in a sequence') 126 | parser.add_argument('--seq_srd', type=int, default=16, help='采样间隔步长') 127 | parser.add_argument('--split', type=int, default=0, help='total 10') 128 | # Model 129 | parser.add_argument('--a1', type=str, default='cnn-rnn') 130 | parser.add_argument('--nConvFilters', type=int, default=32) 131 | parser.add_argument('--features', type=int, default=128, help='features dimension') 132 | parser.add_argument('--hingeMargin', type=int, default=2) 133 | # parser.add_argument('--dropout', type=float, default=0.0) 134 | # Optimizer 135 | parser.add_argument('--lr1', type=float, default=0.001) 136 | parser.add_argument('--lrstep', type=int, default=20) 137 | parser.add_argument('--momentum', type=float, default=0.9) 138 | parser.add_argument('--weight-decay', type=float, default=5e-4) 139 | # Train 140 | parser.add_argument('--start-epoch', type=int, default=0) 141 | parser.add_argument('--epochs', type=int, default=400) 142 | parser.add_argument('--evaluate', type=int, default=0, help='0 => train; 1 =>test') 143 | # Path 144 | working_dir = osp.dirname(osp.abspath(__file__)) 145 | parser.add_argument('--dataset-dir', type=str, metavar='PATH', default=osp.join(working_dir, '../video_reid_mars/data')) 146 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=osp.join(working_dir, 'log/yuci')) 147 | args = parser.parse_args() 148 | # main func 149 | main(args) 150 | -------------------------------------------------------------------------------- /splits_prid2011.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "test": [ 4 | "person_0001", 5 | "person_0002", 6 | "person_0003", 7 | "person_0007", 8 | "person_0012", 9 | "person_0013", 10 | "person_0014", 11 | 
"person_0016", 12 | "person_0017", 13 | "person_0023", 14 | "person_0027", 15 | "person_0030", 16 | "person_0032", 17 | "person_0033", 18 | "person_0035", 19 | "person_0036", 20 | "person_0037", 21 | "person_0038", 22 | "person_0040", 23 | "person_0043", 24 | "person_0044", 25 | "person_0045", 26 | "person_0046", 27 | "person_0047", 28 | "person_0048", 29 | "person_0049", 30 | "person_0054", 31 | "person_0058", 32 | "person_0060", 33 | "person_0061", 34 | "person_0063", 35 | "person_0067", 36 | "person_0069", 37 | "person_0070", 38 | "person_0073", 39 | "person_0074", 40 | "person_0075", 41 | "person_0078", 42 | "person_0081", 43 | "person_0082", 44 | "person_0083", 45 | "person_0085", 46 | "person_0087", 47 | "person_0090", 48 | "person_0096", 49 | "person_0097", 50 | "person_0099", 51 | "person_0101", 52 | "person_0102", 53 | "person_0107", 54 | "person_0108", 55 | "person_0109", 56 | "person_0111", 57 | "person_0116", 58 | "person_0117", 59 | "person_0122", 60 | "person_0124", 61 | "person_0126", 62 | "person_0127", 63 | "person_0130", 64 | "person_0131", 65 | "person_0132", 66 | "person_0141", 67 | "person_0148", 68 | "person_0149", 69 | "person_0151", 70 | "person_0152", 71 | "person_0153", 72 | "person_0156", 73 | "person_0169", 74 | "person_0170", 75 | "person_0171", 76 | "person_0172", 77 | "person_0173", 78 | "person_0174", 79 | "person_0175", 80 | "person_0177", 81 | "person_0178", 82 | "person_0181", 83 | "person_0182", 84 | "person_0184", 85 | "person_0186", 86 | "person_0188", 87 | "person_0189", 88 | "person_0191", 89 | "person_0192", 90 | "person_0193", 91 | "person_0195", 92 | "person_0197" 93 | ], 94 | "train": [ 95 | "person_0004", 96 | "person_0005", 97 | "person_0008", 98 | "person_0009", 99 | "person_0010", 100 | "person_0011", 101 | "person_0015", 102 | "person_0018", 103 | "person_0020", 104 | "person_0021", 105 | "person_0022", 106 | "person_0024", 107 | "person_0025", 108 | "person_0026", 109 | "person_0028", 110 | "person_0029", 111 | 
"person_0034", 112 | "person_0039", 113 | "person_0041", 114 | "person_0042", 115 | "person_0050", 116 | "person_0051", 117 | "person_0053", 118 | "person_0055", 119 | "person_0056", 120 | "person_0059", 121 | "person_0064", 122 | "person_0065", 123 | "person_0066", 124 | "person_0068", 125 | "person_0072", 126 | "person_0076", 127 | "person_0077", 128 | "person_0079", 129 | "person_0080", 130 | "person_0084", 131 | "person_0086", 132 | "person_0088", 133 | "person_0089", 134 | "person_0091", 135 | "person_0092", 136 | "person_0093", 137 | "person_0095", 138 | "person_0098", 139 | "person_0100", 140 | "person_0103", 141 | "person_0104", 142 | "person_0105", 143 | "person_0110", 144 | "person_0112", 145 | "person_0113", 146 | "person_0114", 147 | "person_0115", 148 | "person_0118", 149 | "person_0119", 150 | "person_0121", 151 | "person_0123", 152 | "person_0125", 153 | "person_0128", 154 | "person_0129", 155 | "person_0133", 156 | "person_0134", 157 | "person_0135", 158 | "person_0137", 159 | "person_0140", 160 | "person_0143", 161 | "person_0144", 162 | "person_0146", 163 | "person_0147", 164 | "person_0150", 165 | "person_0154", 166 | "person_0155", 167 | "person_0157", 168 | "person_0160", 169 | "person_0161", 170 | "person_0162", 171 | "person_0163", 172 | "person_0164", 173 | "person_0165", 174 | "person_0167", 175 | "person_0168", 176 | "person_0176", 177 | "person_0179", 178 | "person_0180", 179 | "person_0185", 180 | "person_0187", 181 | "person_0194", 182 | "person_0196", 183 | "person_0198" 184 | ] 185 | }, 186 | { 187 | "test": [ 188 | "person_0001", 189 | "person_0004", 190 | "person_0005", 191 | "person_0007", 192 | "person_0008", 193 | "person_0009", 194 | "person_0010", 195 | "person_0011", 196 | "person_0020", 197 | "person_0021", 198 | "person_0022", 199 | "person_0024", 200 | "person_0025", 201 | "person_0026", 202 | "person_0030", 203 | "person_0032", 204 | "person_0034", 205 | "person_0035", 206 | "person_0041", 207 | "person_0043", 208 | 
"person_0046", 209 | "person_0050", 210 | "person_0051", 211 | "person_0055", 212 | "person_0056", 213 | "person_0060", 214 | "person_0061", 215 | "person_0066", 216 | "person_0067", 217 | "person_0068", 218 | "person_0069", 219 | "person_0073", 220 | "person_0074", 221 | "person_0076", 222 | "person_0077", 223 | "person_0078", 224 | "person_0081", 225 | "person_0083", 226 | "person_0085", 227 | "person_0087", 228 | "person_0088", 229 | "person_0089", 230 | "person_0090", 231 | "person_0091", 232 | "person_0092", 233 | "person_0095", 234 | "person_0096", 235 | "person_0099", 236 | "person_0100", 237 | "person_0107", 238 | "person_0108", 239 | "person_0110", 240 | "person_0112", 241 | "person_0113", 242 | "person_0115", 243 | "person_0116", 244 | "person_0117", 245 | "person_0119", 246 | "person_0121", 247 | "person_0123", 248 | "person_0137", 249 | "person_0141", 250 | "person_0146", 251 | "person_0148", 252 | "person_0149", 253 | "person_0153", 254 | "person_0154", 255 | "person_0155", 256 | "person_0165", 257 | "person_0167", 258 | "person_0169", 259 | "person_0172", 260 | "person_0173", 261 | "person_0174", 262 | "person_0175", 263 | "person_0179", 264 | "person_0180", 265 | "person_0181", 266 | "person_0185", 267 | "person_0186", 268 | "person_0187", 269 | "person_0189", 270 | "person_0191", 271 | "person_0192", 272 | "person_0194", 273 | "person_0195", 274 | "person_0196", 275 | "person_0197", 276 | "person_0198" 277 | ], 278 | "train": [ 279 | "person_0002", 280 | "person_0003", 281 | "person_0012", 282 | "person_0013", 283 | "person_0014", 284 | "person_0015", 285 | "person_0016", 286 | "person_0017", 287 | "person_0018", 288 | "person_0023", 289 | "person_0027", 290 | "person_0028", 291 | "person_0029", 292 | "person_0033", 293 | "person_0036", 294 | "person_0037", 295 | "person_0038", 296 | "person_0039", 297 | "person_0040", 298 | "person_0042", 299 | "person_0044", 300 | "person_0045", 301 | "person_0047", 302 | "person_0048", 303 | "person_0049", 304 | 
"person_0053", 305 | "person_0054", 306 | "person_0058", 307 | "person_0059", 308 | "person_0063", 309 | "person_0064", 310 | "person_0065", 311 | "person_0070", 312 | "person_0072", 313 | "person_0075", 314 | "person_0079", 315 | "person_0080", 316 | "person_0082", 317 | "person_0084", 318 | "person_0086", 319 | "person_0093", 320 | "person_0097", 321 | "person_0098", 322 | "person_0101", 323 | "person_0102", 324 | "person_0103", 325 | "person_0104", 326 | "person_0105", 327 | "person_0109", 328 | "person_0111", 329 | "person_0114", 330 | "person_0118", 331 | "person_0122", 332 | "person_0124", 333 | "person_0125", 334 | "person_0126", 335 | "person_0127", 336 | "person_0128", 337 | "person_0129", 338 | "person_0130", 339 | "person_0131", 340 | "person_0132", 341 | "person_0133", 342 | "person_0134", 343 | "person_0135", 344 | "person_0140", 345 | "person_0143", 346 | "person_0144", 347 | "person_0147", 348 | "person_0150", 349 | "person_0151", 350 | "person_0152", 351 | "person_0156", 352 | "person_0157", 353 | "person_0160", 354 | "person_0161", 355 | "person_0162", 356 | "person_0163", 357 | "person_0164", 358 | "person_0168", 359 | "person_0170", 360 | "person_0171", 361 | "person_0176", 362 | "person_0177", 363 | "person_0178", 364 | "person_0182", 365 | "person_0184", 366 | "person_0188", 367 | "person_0193" 368 | ] 369 | }, 370 | { 371 | "test": [ 372 | "person_0003", 373 | "person_0007", 374 | "person_0010", 375 | "person_0012", 376 | "person_0015", 377 | "person_0016", 378 | "person_0017", 379 | "person_0018", 380 | "person_0023", 381 | "person_0024", 382 | "person_0025", 383 | "person_0027", 384 | "person_0028", 385 | "person_0029", 386 | "person_0033", 387 | "person_0036", 388 | "person_0037", 389 | "person_0039", 390 | "person_0040", 391 | "person_0041", 392 | "person_0044", 393 | "person_0045", 394 | "person_0047", 395 | "person_0048", 396 | "person_0049", 397 | "person_0051", 398 | "person_0053", 399 | "person_0055", 400 | "person_0056", 401 | 
"person_0058", 402 | "person_0059", 403 | "person_0060", 404 | "person_0061", 405 | "person_0063", 406 | "person_0068", 407 | "person_0074", 408 | "person_0078", 409 | "person_0079", 410 | "person_0080", 411 | "person_0083", 412 | "person_0087", 413 | "person_0088", 414 | "person_0095", 415 | "person_0096", 416 | "person_0100", 417 | "person_0102", 418 | "person_0104", 419 | "person_0107", 420 | "person_0109", 421 | "person_0110", 422 | "person_0112", 423 | "person_0113", 424 | "person_0114", 425 | "person_0121", 426 | "person_0122", 427 | "person_0123", 428 | "person_0129", 429 | "person_0133", 430 | "person_0134", 431 | "person_0137", 432 | "person_0140", 433 | "person_0144", 434 | "person_0146", 435 | "person_0147", 436 | "person_0148", 437 | "person_0149", 438 | "person_0150", 439 | "person_0151", 440 | "person_0152", 441 | "person_0153", 442 | "person_0155", 443 | "person_0156", 444 | "person_0165", 445 | "person_0169", 446 | "person_0171", 447 | "person_0174", 448 | "person_0176", 449 | "person_0177", 450 | "person_0179", 451 | "person_0180", 452 | "person_0185", 453 | "person_0187", 454 | "person_0189", 455 | "person_0193", 456 | "person_0194", 457 | "person_0195", 458 | "person_0196", 459 | "person_0197", 460 | "person_0198" 461 | ], 462 | "train": [ 463 | "person_0001", 464 | "person_0002", 465 | "person_0004", 466 | "person_0005", 467 | "person_0008", 468 | "person_0009", 469 | "person_0011", 470 | "person_0013", 471 | "person_0014", 472 | "person_0020", 473 | "person_0021", 474 | "person_0022", 475 | "person_0026", 476 | "person_0030", 477 | "person_0032", 478 | "person_0034", 479 | "person_0035", 480 | "person_0038", 481 | "person_0042", 482 | "person_0043", 483 | "person_0046", 484 | "person_0050", 485 | "person_0054", 486 | "person_0064", 487 | "person_0065", 488 | "person_0066", 489 | "person_0067", 490 | "person_0069", 491 | "person_0070", 492 | "person_0072", 493 | "person_0073", 494 | "person_0075", 495 | "person_0076", 496 | "person_0077", 497 | 
"person_0081", 498 | "person_0082", 499 | "person_0084", 500 | "person_0085", 501 | "person_0086", 502 | "person_0089", 503 | "person_0090", 504 | "person_0091", 505 | "person_0092", 506 | "person_0093", 507 | "person_0097", 508 | "person_0098", 509 | "person_0099", 510 | "person_0101", 511 | "person_0103", 512 | "person_0105", 513 | "person_0108", 514 | "person_0111", 515 | "person_0115", 516 | "person_0116", 517 | "person_0117", 518 | "person_0118", 519 | "person_0119", 520 | "person_0124", 521 | "person_0125", 522 | "person_0126", 523 | "person_0127", 524 | "person_0128", 525 | "person_0130", 526 | "person_0131", 527 | "person_0132", 528 | "person_0135", 529 | "person_0141", 530 | "person_0143", 531 | "person_0154", 532 | "person_0157", 533 | "person_0160", 534 | "person_0161", 535 | "person_0162", 536 | "person_0163", 537 | "person_0164", 538 | "person_0167", 539 | "person_0168", 540 | "person_0170", 541 | "person_0172", 542 | "person_0173", 543 | "person_0175", 544 | "person_0178", 545 | "person_0181", 546 | "person_0182", 547 | "person_0184", 548 | "person_0186", 549 | "person_0188", 550 | "person_0191", 551 | "person_0192" 552 | ] 553 | }, 554 | { 555 | "test": [ 556 | "person_0002", 557 | "person_0004", 558 | "person_0007", 559 | "person_0008", 560 | "person_0009", 561 | "person_0016", 562 | "person_0021", 563 | "person_0022", 564 | "person_0023", 565 | "person_0024", 566 | "person_0026", 567 | "person_0027", 568 | "person_0030", 569 | "person_0035", 570 | "person_0036", 571 | "person_0038", 572 | "person_0041", 573 | "person_0042", 574 | "person_0043", 575 | "person_0046", 576 | "person_0049", 577 | "person_0050", 578 | "person_0056", 579 | "person_0058", 580 | "person_0060", 581 | "person_0064", 582 | "person_0065", 583 | "person_0069", 584 | "person_0070", 585 | "person_0072", 586 | "person_0076", 587 | "person_0077", 588 | "person_0078", 589 | "person_0081", 590 | "person_0082", 591 | "person_0083", 592 | "person_0087", 593 | "person_0088", 594 | 
"person_0090", 595 | "person_0092", 596 | "person_0095", 597 | "person_0100", 598 | "person_0102", 599 | "person_0103", 600 | "person_0105", 601 | "person_0107", 602 | "person_0108", 603 | "person_0109", 604 | "person_0117", 605 | "person_0118", 606 | "person_0119", 607 | "person_0122", 608 | "person_0124", 609 | "person_0125", 610 | "person_0127", 611 | "person_0128", 612 | "person_0131", 613 | "person_0133", 614 | "person_0134", 615 | "person_0135", 616 | "person_0137", 617 | "person_0140", 618 | "person_0141", 619 | "person_0146", 620 | "person_0148", 621 | "person_0149", 622 | "person_0150", 623 | "person_0152", 624 | "person_0153", 625 | "person_0154", 626 | "person_0156", 627 | "person_0157", 628 | "person_0161", 629 | "person_0162", 630 | "person_0165", 631 | "person_0167", 632 | "person_0168", 633 | "person_0174", 634 | "person_0176", 635 | "person_0177", 636 | "person_0178", 637 | "person_0179", 638 | "person_0180", 639 | "person_0181", 640 | "person_0184", 641 | "person_0192", 642 | "person_0193", 643 | "person_0194", 644 | "person_0198" 645 | ], 646 | "train": [ 647 | "person_0001", 648 | "person_0003", 649 | "person_0005", 650 | "person_0010", 651 | "person_0011", 652 | "person_0012", 653 | "person_0013", 654 | "person_0014", 655 | "person_0015", 656 | "person_0017", 657 | "person_0018", 658 | "person_0020", 659 | "person_0025", 660 | "person_0028", 661 | "person_0029", 662 | "person_0032", 663 | "person_0033", 664 | "person_0034", 665 | "person_0037", 666 | "person_0039", 667 | "person_0040", 668 | "person_0044", 669 | "person_0045", 670 | "person_0047", 671 | "person_0048", 672 | "person_0051", 673 | "person_0053", 674 | "person_0054", 675 | "person_0055", 676 | "person_0059", 677 | "person_0061", 678 | "person_0063", 679 | "person_0066", 680 | "person_0067", 681 | "person_0068", 682 | "person_0073", 683 | "person_0074", 684 | "person_0075", 685 | "person_0079", 686 | "person_0080", 687 | "person_0084", 688 | "person_0085", 689 | "person_0086", 690 | 
"person_0089", 691 | "person_0091", 692 | "person_0093", 693 | "person_0096", 694 | "person_0097", 695 | "person_0098", 696 | "person_0099", 697 | "person_0101", 698 | "person_0104", 699 | "person_0110", 700 | "person_0111", 701 | "person_0112", 702 | "person_0113", 703 | "person_0114", 704 | "person_0115", 705 | "person_0116", 706 | "person_0121", 707 | "person_0123", 708 | "person_0126", 709 | "person_0129", 710 | "person_0130", 711 | "person_0132", 712 | "person_0143", 713 | "person_0144", 714 | "person_0147", 715 | "person_0151", 716 | "person_0155", 717 | "person_0160", 718 | "person_0163", 719 | "person_0164", 720 | "person_0169", 721 | "person_0170", 722 | "person_0171", 723 | "person_0172", 724 | "person_0173", 725 | "person_0175", 726 | "person_0182", 727 | "person_0185", 728 | "person_0186", 729 | "person_0187", 730 | "person_0188", 731 | "person_0189", 732 | "person_0191", 733 | "person_0195", 734 | "person_0196", 735 | "person_0197" 736 | ] 737 | }, 738 | { 739 | "test": [ 740 | "person_0002", 741 | "person_0003", 742 | "person_0005", 743 | "person_0007", 744 | "person_0009", 745 | "person_0021", 746 | "person_0023", 747 | "person_0024", 748 | "person_0025", 749 | "person_0027", 750 | "person_0029", 751 | "person_0032", 752 | "person_0036", 753 | "person_0037", 754 | "person_0038", 755 | "person_0039", 756 | "person_0044", 757 | "person_0045", 758 | "person_0047", 759 | "person_0049", 760 | "person_0050", 761 | "person_0051", 762 | "person_0053", 763 | "person_0056", 764 | "person_0060", 765 | "person_0063", 766 | "person_0067", 767 | "person_0072", 768 | "person_0075", 769 | "person_0076", 770 | "person_0078", 771 | "person_0079", 772 | "person_0081", 773 | "person_0083", 774 | "person_0086", 775 | "person_0087", 776 | "person_0088", 777 | "person_0090", 778 | "person_0095", 779 | "person_0096", 780 | "person_0097", 781 | "person_0099", 782 | "person_0100", 783 | "person_0101", 784 | "person_0103", 785 | "person_0104", 786 | "person_0105", 787 | 
"person_0110", 788 | "person_0111", 789 | "person_0115", 790 | "person_0118", 791 | "person_0119", 792 | "person_0122", 793 | "person_0123", 794 | "person_0126", 795 | "person_0127", 796 | "person_0129", 797 | "person_0131", 798 | "person_0135", 799 | "person_0140", 800 | "person_0141", 801 | "person_0152", 802 | "person_0153", 803 | "person_0154", 804 | "person_0155", 805 | "person_0156", 806 | "person_0157", 807 | "person_0161", 808 | "person_0164", 809 | "person_0169", 810 | "person_0170", 811 | "person_0171", 812 | "person_0172", 813 | "person_0173", 814 | "person_0174", 815 | "person_0175", 816 | "person_0178", 817 | "person_0179", 818 | "person_0180", 819 | "person_0181", 820 | "person_0185", 821 | "person_0186", 822 | "person_0187", 823 | "person_0188", 824 | "person_0189", 825 | "person_0191", 826 | "person_0194", 827 | "person_0195", 828 | "person_0197" 829 | ], 830 | "train": [ 831 | "person_0001", 832 | "person_0004", 833 | "person_0008", 834 | "person_0010", 835 | "person_0011", 836 | "person_0012", 837 | "person_0013", 838 | "person_0014", 839 | "person_0015", 840 | "person_0016", 841 | "person_0017", 842 | "person_0018", 843 | "person_0020", 844 | "person_0022", 845 | "person_0026", 846 | "person_0028", 847 | "person_0030", 848 | "person_0033", 849 | "person_0034", 850 | "person_0035", 851 | "person_0040", 852 | "person_0041", 853 | "person_0042", 854 | "person_0043", 855 | "person_0046", 856 | "person_0048", 857 | "person_0054", 858 | "person_0055", 859 | "person_0058", 860 | "person_0059", 861 | "person_0061", 862 | "person_0064", 863 | "person_0065", 864 | "person_0066", 865 | "person_0068", 866 | "person_0069", 867 | "person_0070", 868 | "person_0073", 869 | "person_0074", 870 | "person_0077", 871 | "person_0080", 872 | "person_0082", 873 | "person_0084", 874 | "person_0085", 875 | "person_0089", 876 | "person_0091", 877 | "person_0092", 878 | "person_0093", 879 | "person_0098", 880 | "person_0102", 881 | "person_0107", 882 | "person_0108", 883 | 
"person_0109", 884 | "person_0112", 885 | "person_0113", 886 | "person_0114", 887 | "person_0116", 888 | "person_0117", 889 | "person_0121", 890 | "person_0124", 891 | "person_0125", 892 | "person_0128", 893 | "person_0130", 894 | "person_0132", 895 | "person_0133", 896 | "person_0134", 897 | "person_0137", 898 | "person_0143", 899 | "person_0144", 900 | "person_0146", 901 | "person_0147", 902 | "person_0148", 903 | "person_0149", 904 | "person_0150", 905 | "person_0151", 906 | "person_0160", 907 | "person_0162", 908 | "person_0163", 909 | "person_0165", 910 | "person_0167", 911 | "person_0168", 912 | "person_0176", 913 | "person_0177", 914 | "person_0182", 915 | "person_0184", 916 | "person_0192", 917 | "person_0193", 918 | "person_0196", 919 | "person_0198" 920 | ] 921 | }, 922 | { 923 | "test": [ 924 | "person_0001", 925 | "person_0002", 926 | "person_0004", 927 | "person_0005", 928 | "person_0008", 929 | "person_0009", 930 | "person_0010", 931 | "person_0012", 932 | "person_0013", 933 | "person_0014", 934 | "person_0016", 935 | "person_0018", 936 | "person_0020", 937 | "person_0021", 938 | "person_0024", 939 | "person_0026", 940 | "person_0029", 941 | "person_0034", 942 | "person_0037", 943 | "person_0040", 944 | "person_0041", 945 | "person_0042", 946 | "person_0044", 947 | "person_0045", 948 | "person_0046", 949 | "person_0048", 950 | "person_0050", 951 | "person_0054", 952 | "person_0056", 953 | "person_0059", 954 | "person_0061", 955 | "person_0063", 956 | "person_0065", 957 | "person_0069", 958 | "person_0072", 959 | "person_0077", 960 | "person_0078", 961 | "person_0079", 962 | "person_0082", 963 | "person_0083", 964 | "person_0084", 965 | "person_0086", 966 | "person_0092", 967 | "person_0096", 968 | "person_0102", 969 | "person_0104", 970 | "person_0108", 971 | "person_0109", 972 | "person_0110", 973 | "person_0111", 974 | "person_0114", 975 | "person_0117", 976 | "person_0118", 977 | "person_0119", 978 | "person_0121", 979 | "person_0122", 980 | 
"person_0123", 981 | "person_0124", 982 | "person_0125", 983 | "person_0128", 984 | "person_0129", 985 | "person_0130", 986 | "person_0131", 987 | "person_0132", 988 | "person_0140", 989 | "person_0141", 990 | "person_0146", 991 | "person_0147", 992 | "person_0148", 993 | "person_0152", 994 | "person_0153", 995 | "person_0160", 996 | "person_0161", 997 | "person_0163", 998 | "person_0164", 999 | "person_0165", 1000 | "person_0169", 1001 | "person_0170", 1002 | "person_0178", 1003 | "person_0180", 1004 | "person_0184", 1005 | "person_0185", 1006 | "person_0187", 1007 | "person_0188", 1008 | "person_0189", 1009 | "person_0192", 1010 | "person_0194", 1011 | "person_0196", 1012 | "person_0198" 1013 | ], 1014 | "train": [ 1015 | "person_0003", 1016 | "person_0007", 1017 | "person_0011", 1018 | "person_0015", 1019 | "person_0017", 1020 | "person_0022", 1021 | "person_0023", 1022 | "person_0025", 1023 | "person_0027", 1024 | "person_0028", 1025 | "person_0030", 1026 | "person_0032", 1027 | "person_0033", 1028 | "person_0035", 1029 | "person_0036", 1030 | "person_0038", 1031 | "person_0039", 1032 | "person_0043", 1033 | "person_0047", 1034 | "person_0049", 1035 | "person_0051", 1036 | "person_0053", 1037 | "person_0055", 1038 | "person_0058", 1039 | "person_0060", 1040 | "person_0064", 1041 | "person_0066", 1042 | "person_0067", 1043 | "person_0068", 1044 | "person_0070", 1045 | "person_0073", 1046 | "person_0074", 1047 | "person_0075", 1048 | "person_0076", 1049 | "person_0080", 1050 | "person_0081", 1051 | "person_0085", 1052 | "person_0087", 1053 | "person_0088", 1054 | "person_0089", 1055 | "person_0090", 1056 | "person_0091", 1057 | "person_0093", 1058 | "person_0095", 1059 | "person_0097", 1060 | "person_0098", 1061 | "person_0099", 1062 | "person_0100", 1063 | "person_0101", 1064 | "person_0103", 1065 | "person_0105", 1066 | "person_0107", 1067 | "person_0112", 1068 | "person_0113", 1069 | "person_0115", 1070 | "person_0116", 1071 | "person_0126", 1072 | 
"person_0127", 1073 | "person_0133", 1074 | "person_0134", 1075 | "person_0135", 1076 | "person_0137", 1077 | "person_0143", 1078 | "person_0144", 1079 | "person_0149", 1080 | "person_0150", 1081 | "person_0151", 1082 | "person_0154", 1083 | "person_0155", 1084 | "person_0156", 1085 | "person_0157", 1086 | "person_0162", 1087 | "person_0167", 1088 | "person_0168", 1089 | "person_0171", 1090 | "person_0172", 1091 | "person_0173", 1092 | "person_0174", 1093 | "person_0175", 1094 | "person_0176", 1095 | "person_0177", 1096 | "person_0179", 1097 | "person_0181", 1098 | "person_0182", 1099 | "person_0186", 1100 | "person_0191", 1101 | "person_0193", 1102 | "person_0195", 1103 | "person_0197" 1104 | ] 1105 | }, 1106 | { 1107 | "test": [ 1108 | "person_0004", 1109 | "person_0005", 1110 | "person_0008", 1111 | "person_0009", 1112 | "person_0012", 1113 | "person_0015", 1114 | "person_0016", 1115 | "person_0021", 1116 | "person_0023", 1117 | "person_0025", 1118 | "person_0028", 1119 | "person_0029", 1120 | "person_0032", 1121 | "person_0033", 1122 | "person_0035", 1123 | "person_0036", 1124 | "person_0040", 1125 | "person_0041", 1126 | "person_0044", 1127 | "person_0045", 1128 | "person_0048", 1129 | "person_0050", 1130 | "person_0051", 1131 | "person_0053", 1132 | "person_0054", 1133 | "person_0056", 1134 | "person_0059", 1135 | "person_0065", 1136 | "person_0067", 1137 | "person_0068", 1138 | "person_0072", 1139 | "person_0078", 1140 | "person_0080", 1141 | "person_0081", 1142 | "person_0082", 1143 | "person_0083", 1144 | "person_0084", 1145 | "person_0085", 1146 | "person_0086", 1147 | "person_0088", 1148 | "person_0090", 1149 | "person_0091", 1150 | "person_0092", 1151 | "person_0095", 1152 | "person_0097", 1153 | "person_0099", 1154 | "person_0107", 1155 | "person_0108", 1156 | "person_0109", 1157 | "person_0111", 1158 | "person_0113", 1159 | "person_0116", 1160 | "person_0118", 1161 | "person_0119", 1162 | "person_0121", 1163 | "person_0122", 1164 | "person_0124", 1165 
| "person_0126", 1166 | "person_0127", 1167 | "person_0129", 1168 | "person_0130", 1169 | "person_0134", 1170 | "person_0135", 1171 | "person_0137", 1172 | "person_0143", 1173 | "person_0146", 1174 | "person_0149", 1175 | "person_0150", 1176 | "person_0153", 1177 | "person_0154", 1178 | "person_0157", 1179 | "person_0165", 1180 | "person_0168", 1181 | "person_0170", 1182 | "person_0172", 1183 | "person_0173", 1184 | "person_0174", 1185 | "person_0175", 1186 | "person_0176", 1187 | "person_0177", 1188 | "person_0178", 1189 | "person_0179", 1190 | "person_0180", 1191 | "person_0182", 1192 | "person_0186", 1193 | "person_0189", 1194 | "person_0192", 1195 | "person_0197", 1196 | "person_0198" 1197 | ], 1198 | "train": [ 1199 | "person_0001", 1200 | "person_0002", 1201 | "person_0003", 1202 | "person_0007", 1203 | "person_0010", 1204 | "person_0011", 1205 | "person_0013", 1206 | "person_0014", 1207 | "person_0017", 1208 | "person_0018", 1209 | "person_0020", 1210 | "person_0022", 1211 | "person_0024", 1212 | "person_0026", 1213 | "person_0027", 1214 | "person_0030", 1215 | "person_0034", 1216 | "person_0037", 1217 | "person_0038", 1218 | "person_0039", 1219 | "person_0042", 1220 | "person_0043", 1221 | "person_0046", 1222 | "person_0047", 1223 | "person_0049", 1224 | "person_0055", 1225 | "person_0058", 1226 | "person_0060", 1227 | "person_0061", 1228 | "person_0063", 1229 | "person_0064", 1230 | "person_0066", 1231 | "person_0069", 1232 | "person_0070", 1233 | "person_0073", 1234 | "person_0074", 1235 | "person_0075", 1236 | "person_0076", 1237 | "person_0077", 1238 | "person_0079", 1239 | "person_0087", 1240 | "person_0089", 1241 | "person_0093", 1242 | "person_0096", 1243 | "person_0098", 1244 | "person_0100", 1245 | "person_0101", 1246 | "person_0102", 1247 | "person_0103", 1248 | "person_0104", 1249 | "person_0105", 1250 | "person_0110", 1251 | "person_0112", 1252 | "person_0114", 1253 | "person_0115", 1254 | "person_0117", 1255 | "person_0123", 1256 | 
"person_0125", 1257 | "person_0128", 1258 | "person_0131", 1259 | "person_0132", 1260 | "person_0133", 1261 | "person_0140", 1262 | "person_0141", 1263 | "person_0144", 1264 | "person_0147", 1265 | "person_0148", 1266 | "person_0151", 1267 | "person_0152", 1268 | "person_0155", 1269 | "person_0156", 1270 | "person_0160", 1271 | "person_0161", 1272 | "person_0162", 1273 | "person_0163", 1274 | "person_0164", 1275 | "person_0167", 1276 | "person_0169", 1277 | "person_0171", 1278 | "person_0181", 1279 | "person_0184", 1280 | "person_0185", 1281 | "person_0187", 1282 | "person_0188", 1283 | "person_0191", 1284 | "person_0193", 1285 | "person_0194", 1286 | "person_0195", 1287 | "person_0196" 1288 | ] 1289 | }, 1290 | { 1291 | "test": [ 1292 | "person_0002", 1293 | "person_0003", 1294 | "person_0009", 1295 | "person_0010", 1296 | "person_0012", 1297 | "person_0013", 1298 | "person_0015", 1299 | "person_0025", 1300 | "person_0028", 1301 | "person_0030", 1302 | "person_0033", 1303 | "person_0035", 1304 | "person_0036", 1305 | "person_0041", 1306 | "person_0047", 1307 | "person_0048", 1308 | "person_0049", 1309 | "person_0051", 1310 | "person_0053", 1311 | "person_0056", 1312 | "person_0058", 1313 | "person_0059", 1314 | "person_0066", 1315 | "person_0075", 1316 | "person_0078", 1317 | "person_0079", 1318 | "person_0080", 1319 | "person_0082", 1320 | "person_0086", 1321 | "person_0087", 1322 | "person_0091", 1323 | "person_0096", 1324 | "person_0097", 1325 | "person_0098", 1326 | "person_0099", 1327 | "person_0101", 1328 | "person_0102", 1329 | "person_0103", 1330 | "person_0104", 1331 | "person_0105", 1332 | "person_0107", 1333 | "person_0108", 1334 | "person_0109", 1335 | "person_0112", 1336 | "person_0115", 1337 | "person_0118", 1338 | "person_0121", 1339 | "person_0122", 1340 | "person_0126", 1341 | "person_0127", 1342 | "person_0128", 1343 | "person_0129", 1344 | "person_0130", 1345 | "person_0132", 1346 | "person_0134", 1347 | "person_0135", 1348 | "person_0137", 1349 
| "person_0141", 1350 | "person_0143", 1351 | "person_0144", 1352 | "person_0146", 1353 | "person_0148", 1354 | "person_0149", 1355 | "person_0151", 1356 | "person_0152", 1357 | "person_0153", 1358 | "person_0154", 1359 | "person_0156", 1360 | "person_0157", 1361 | "person_0162", 1362 | "person_0164", 1363 | "person_0167", 1364 | "person_0168", 1365 | "person_0170", 1366 | "person_0171", 1367 | "person_0172", 1368 | "person_0173", 1369 | "person_0174", 1370 | "person_0175", 1371 | "person_0178", 1372 | "person_0179", 1373 | "person_0181", 1374 | "person_0184", 1375 | "person_0188", 1376 | "person_0191", 1377 | "person_0192", 1378 | "person_0193", 1379 | "person_0194", 1380 | "person_0198" 1381 | ], 1382 | "train": [ 1383 | "person_0001", 1384 | "person_0004", 1385 | "person_0005", 1386 | "person_0007", 1387 | "person_0008", 1388 | "person_0011", 1389 | "person_0014", 1390 | "person_0016", 1391 | "person_0017", 1392 | "person_0018", 1393 | "person_0020", 1394 | "person_0021", 1395 | "person_0022", 1396 | "person_0023", 1397 | "person_0024", 1398 | "person_0026", 1399 | "person_0027", 1400 | "person_0029", 1401 | "person_0032", 1402 | "person_0034", 1403 | "person_0037", 1404 | "person_0038", 1405 | "person_0039", 1406 | "person_0040", 1407 | "person_0042", 1408 | "person_0043", 1409 | "person_0044", 1410 | "person_0045", 1411 | "person_0046", 1412 | "person_0050", 1413 | "person_0054", 1414 | "person_0055", 1415 | "person_0060", 1416 | "person_0061", 1417 | "person_0063", 1418 | "person_0064", 1419 | "person_0065", 1420 | "person_0067", 1421 | "person_0068", 1422 | "person_0069", 1423 | "person_0070", 1424 | "person_0072", 1425 | "person_0073", 1426 | "person_0074", 1427 | "person_0076", 1428 | "person_0077", 1429 | "person_0081", 1430 | "person_0083", 1431 | "person_0084", 1432 | "person_0085", 1433 | "person_0088", 1434 | "person_0089", 1435 | "person_0090", 1436 | "person_0092", 1437 | "person_0093", 1438 | "person_0095", 1439 | "person_0100", 1440 | 
"person_0110", 1441 | "person_0111", 1442 | "person_0113", 1443 | "person_0114", 1444 | "person_0116", 1445 | "person_0117", 1446 | "person_0119", 1447 | "person_0123", 1448 | "person_0124", 1449 | "person_0125", 1450 | "person_0131", 1451 | "person_0133", 1452 | "person_0140", 1453 | "person_0147", 1454 | "person_0150", 1455 | "person_0155", 1456 | "person_0160", 1457 | "person_0161", 1458 | "person_0163", 1459 | "person_0165", 1460 | "person_0169", 1461 | "person_0176", 1462 | "person_0177", 1463 | "person_0180", 1464 | "person_0182", 1465 | "person_0185", 1466 | "person_0186", 1467 | "person_0187", 1468 | "person_0189", 1469 | "person_0195", 1470 | "person_0196", 1471 | "person_0197" 1472 | ] 1473 | }, 1474 | { 1475 | "test": [ 1476 | "person_0003", 1477 | "person_0009", 1478 | "person_0010", 1479 | "person_0011", 1480 | "person_0012", 1481 | "person_0015", 1482 | "person_0017", 1483 | "person_0022", 1484 | "person_0023", 1485 | "person_0026", 1486 | "person_0027", 1487 | "person_0034", 1488 | "person_0035", 1489 | "person_0036", 1490 | "person_0038", 1491 | "person_0041", 1492 | "person_0043", 1493 | "person_0044", 1494 | "person_0047", 1495 | "person_0048", 1496 | "person_0050", 1497 | "person_0051", 1498 | "person_0053", 1499 | "person_0055", 1500 | "person_0056", 1501 | "person_0064", 1502 | "person_0066", 1503 | "person_0069", 1504 | "person_0070", 1505 | "person_0072", 1506 | "person_0076", 1507 | "person_0078", 1508 | "person_0083", 1509 | "person_0084", 1510 | "person_0086", 1511 | "person_0088", 1512 | "person_0089", 1513 | "person_0092", 1514 | "person_0093", 1515 | "person_0095", 1516 | "person_0097", 1517 | "person_0098", 1518 | "person_0099", 1519 | "person_0101", 1520 | "person_0102", 1521 | "person_0103", 1522 | "person_0107", 1523 | "person_0108", 1524 | "person_0109", 1525 | "person_0115", 1526 | "person_0117", 1527 | "person_0119", 1528 | "person_0122", 1529 | "person_0123", 1530 | "person_0124", 1531 | "person_0126", 1532 | "person_0128", 1533 
| "person_0129", 1534 | "person_0130", 1535 | "person_0131", 1536 | "person_0133", 1537 | "person_0141", 1538 | "person_0144", 1539 | "person_0146", 1540 | "person_0148", 1541 | "person_0151", 1542 | "person_0154", 1543 | "person_0155", 1544 | "person_0157", 1545 | "person_0160", 1546 | "person_0161", 1547 | "person_0162", 1548 | "person_0163", 1549 | "person_0168", 1550 | "person_0172", 1551 | "person_0174", 1552 | "person_0175", 1553 | "person_0179", 1554 | "person_0184", 1555 | "person_0185", 1556 | "person_0186", 1557 | "person_0187", 1558 | "person_0188", 1559 | "person_0189", 1560 | "person_0191", 1561 | "person_0193", 1562 | "person_0194", 1563 | "person_0195", 1564 | "person_0196" 1565 | ], 1566 | "train": [ 1567 | "person_0001", 1568 | "person_0002", 1569 | "person_0004", 1570 | "person_0005", 1571 | "person_0007", 1572 | "person_0008", 1573 | "person_0013", 1574 | "person_0014", 1575 | "person_0016", 1576 | "person_0018", 1577 | "person_0020", 1578 | "person_0021", 1579 | "person_0024", 1580 | "person_0025", 1581 | "person_0028", 1582 | "person_0029", 1583 | "person_0030", 1584 | "person_0032", 1585 | "person_0033", 1586 | "person_0037", 1587 | "person_0039", 1588 | "person_0040", 1589 | "person_0042", 1590 | "person_0045", 1591 | "person_0046", 1592 | "person_0049", 1593 | "person_0054", 1594 | "person_0058", 1595 | "person_0059", 1596 | "person_0060", 1597 | "person_0061", 1598 | "person_0063", 1599 | "person_0065", 1600 | "person_0067", 1601 | "person_0068", 1602 | "person_0073", 1603 | "person_0074", 1604 | "person_0075", 1605 | "person_0077", 1606 | "person_0079", 1607 | "person_0080", 1608 | "person_0081", 1609 | "person_0082", 1610 | "person_0085", 1611 | "person_0087", 1612 | "person_0090", 1613 | "person_0091", 1614 | "person_0096", 1615 | "person_0100", 1616 | "person_0104", 1617 | "person_0105", 1618 | "person_0110", 1619 | "person_0111", 1620 | "person_0112", 1621 | "person_0113", 1622 | "person_0114", 1623 | "person_0116", 1624 | 
"person_0118", 1625 | "person_0121", 1626 | "person_0125", 1627 | "person_0127", 1628 | "person_0132", 1629 | "person_0134", 1630 | "person_0135", 1631 | "person_0137", 1632 | "person_0140", 1633 | "person_0143", 1634 | "person_0147", 1635 | "person_0149", 1636 | "person_0150", 1637 | "person_0152", 1638 | "person_0153", 1639 | "person_0156", 1640 | "person_0164", 1641 | "person_0165", 1642 | "person_0167", 1643 | "person_0169", 1644 | "person_0170", 1645 | "person_0171", 1646 | "person_0173", 1647 | "person_0176", 1648 | "person_0177", 1649 | "person_0178", 1650 | "person_0180", 1651 | "person_0181", 1652 | "person_0182", 1653 | "person_0192", 1654 | "person_0197", 1655 | "person_0198" 1656 | ] 1657 | }, 1658 | { 1659 | "test": [ 1660 | "person_0001", 1661 | "person_0004", 1662 | "person_0009", 1663 | "person_0010", 1664 | "person_0011", 1665 | "person_0014", 1666 | "person_0015", 1667 | "person_0016", 1668 | "person_0021", 1669 | "person_0024", 1670 | "person_0025", 1671 | "person_0030", 1672 | "person_0032", 1673 | "person_0034", 1674 | "person_0035", 1675 | "person_0041", 1676 | "person_0042", 1677 | "person_0043", 1678 | "person_0044", 1679 | "person_0045", 1680 | "person_0047", 1681 | "person_0049", 1682 | "person_0053", 1683 | "person_0054", 1684 | "person_0056", 1685 | "person_0058", 1686 | "person_0059", 1687 | "person_0060", 1688 | "person_0064", 1689 | "person_0066", 1690 | "person_0067", 1691 | "person_0070", 1692 | "person_0073", 1693 | "person_0074", 1694 | "person_0075", 1695 | "person_0076", 1696 | "person_0077", 1697 | "person_0080", 1698 | "person_0081", 1699 | "person_0082", 1700 | "person_0083", 1701 | "person_0085", 1702 | "person_0086", 1703 | "person_0087", 1704 | "person_0088", 1705 | "person_0095", 1706 | "person_0096", 1707 | "person_0097", 1708 | "person_0100", 1709 | "person_0104", 1710 | "person_0108", 1711 | "person_0110", 1712 | "person_0111", 1713 | "person_0113", 1714 | "person_0114", 1715 | "person_0115", 1716 | "person_0117", 1717 
| "person_0119", 1718 | "person_0122", 1719 | "person_0126", 1720 | "person_0127", 1721 | "person_0129", 1722 | "person_0134", 1723 | "person_0140", 1724 | "person_0143", 1725 | "person_0146", 1726 | "person_0147", 1727 | "person_0150", 1728 | "person_0151", 1729 | "person_0153", 1730 | "person_0156", 1731 | "person_0157", 1732 | "person_0160", 1733 | "person_0162", 1734 | "person_0168", 1735 | "person_0174", 1736 | "person_0176", 1737 | "person_0179", 1738 | "person_0180", 1739 | "person_0181", 1740 | "person_0182", 1741 | "person_0184", 1742 | "person_0185", 1743 | "person_0186", 1744 | "person_0189", 1745 | "person_0192", 1746 | "person_0193", 1747 | "person_0196", 1748 | "person_0198" 1749 | ], 1750 | "train": [ 1751 | "person_0002", 1752 | "person_0003", 1753 | "person_0005", 1754 | "person_0007", 1755 | "person_0008", 1756 | "person_0012", 1757 | "person_0013", 1758 | "person_0017", 1759 | "person_0018", 1760 | "person_0020", 1761 | "person_0022", 1762 | "person_0023", 1763 | "person_0026", 1764 | "person_0027", 1765 | "person_0028", 1766 | "person_0029", 1767 | "person_0033", 1768 | "person_0036", 1769 | "person_0037", 1770 | "person_0038", 1771 | "person_0039", 1772 | "person_0040", 1773 | "person_0046", 1774 | "person_0048", 1775 | "person_0050", 1776 | "person_0051", 1777 | "person_0055", 1778 | "person_0061", 1779 | "person_0063", 1780 | "person_0065", 1781 | "person_0068", 1782 | "person_0069", 1783 | "person_0072", 1784 | "person_0078", 1785 | "person_0079", 1786 | "person_0084", 1787 | "person_0089", 1788 | "person_0090", 1789 | "person_0091", 1790 | "person_0092", 1791 | "person_0093", 1792 | "person_0098", 1793 | "person_0099", 1794 | "person_0101", 1795 | "person_0102", 1796 | "person_0103", 1797 | "person_0105", 1798 | "person_0107", 1799 | "person_0109", 1800 | "person_0112", 1801 | "person_0116", 1802 | "person_0118", 1803 | "person_0121", 1804 | "person_0123", 1805 | "person_0124", 1806 | "person_0125", 1807 | "person_0128", 1808 | 
class BaseTrainer(object):
    """Generic one-epoch training loop.

    Subclasses supply the two hooks ``_parse_data`` (turn a raw batch from the
    data loader into two network inputs plus targets) and ``_forward`` (run
    the model and return a scalar loss tensor).
    """

    def __init__(self, model, criterion):
        super(BaseTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        # Use the first GPU when CUDA is available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def train(self, epoch, data_loader, optimizer, print_freq=10):
        """Run one optimisation pass over ``data_loader``.

        Args:
            epoch: epoch index, used only in the progress message.
            data_loader: sized iterable of batches accepted by ``_parse_data``.
            optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
            print_freq: print the running loss every ``print_freq`` batches.
        """
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            # Time spent waiting for the loader to produce this batch.
            data_time.update(time.time() - end)

            netinputs0, netinputs1, targets = self._parse_data(inputs)

            # 1. Forward pass.
            loss = self._forward(netinputs0, netinputs1, targets)
            # The running average is weighted by len(targets[0])
            # (presumably the batch size -- confirm against the loader).
            losses.update(loss.item(), len(targets[0]))

            # 2. Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Loss {:.3f} ({:.3f})\t'
                      .format(epoch, i + 1, len(data_loader),
                              losses.val, losses.avg))

    def _parse_data(self, inputs):
        """Return ``(netinputs0, netinputs1, targets)`` for one raw batch."""
        raise NotImplementedError

    def _forward(self, netinputs0, netinputs1, targets):
        """Return the loss tensor for one parsed batch."""
        raise NotImplementedError
def mkdir_if_missing(dir_path):
    """Create ``dir_path`` (and any missing parents) unless it already exists.

    An empty ``dir_path`` is treated as "the current directory" and is a
    no-op. ``os.path.dirname`` returns ``''`` for a bare file name such as
    ``'checkpoint.pth.tar'``, and ``os.makedirs('')`` would otherwise raise.

    Args:
        dir_path: directory path to create; may be empty.

    Raises:
        OSError: if the directory cannot be created for any reason other
            than the path already existing.
    """
    if not dir_path:
        return
    try:
        os.makedirs(dir_path)
    except OSError as e:
        # An already-existing path is fine; anything else is a real failure.
        if e.errno != errno.EEXIST:
            raise
class Logger(object):
    """Write console output to the terminal and, optionally, a log file.

    The console stream is whatever ``sys.stdout`` was at construction time.
    It is borrowed, so ``close()`` leaves it open and only closes the log
    file that this object opened itself.

    Usable as a context manager::

        with Logger('logs/train.txt') as log:
            log.write('hello\n')
    """

    def __init__(self, fpath=None):
        """Open ``fpath`` for writing (truncating it) when it is given."""
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            log_dir = os.path.dirname(fpath)
            # A bare file name has no directory part; nothing to create.
            if log_dir:
                mkdir_if_missing(log_dir)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so ``with Logger(...) as log`` binds something.
        return self

    def __exit__(self, *args):
        # The context-manager protocol passes (exc_type, exc_value, tb).
        # Returning None lets any exception propagate.
        self.close()

    def write(self, msg):
        """Write ``msg`` to the console and to the log file if one is open."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush both streams and force the log file to disk."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file, if any. Safe to call more than once."""
        # getattr guards against __del__ running on a partially built object.
        log_file = getattr(self, 'file', None)
        if log_file is not None:
            log_file.close()
            self.file = None
--------------------------------------------------------------------------------